sgemm_impl: 128x64 tile; fix unrolled asm, comment out actual gemm
This commit is contained in:
@@ -29,7 +29,7 @@ using float_type = float16_t;
|
||||
// (BM*BN) / (TM*TN) == threadblock size >= NT * CORES_PER_CLUSTER
|
||||
// * Combining BM * BK >= (BM*BN) / (TM*TN) == threadblock yields
|
||||
// BM <= BK*TM*TN
|
||||
#define BM 64
|
||||
#define BM 128
|
||||
#define BN 64
|
||||
#if (FP_SIZE == 32)
|
||||
#define BK 64
|
||||
@@ -72,7 +72,7 @@ static_assert(WMITER * WNITER * TCM * TCN * NUM_WARPS * CORES_PER_CLUSTER ==
|
||||
#define TRANSPOSE_AT_PRODUCE 0
|
||||
#define TRANSPOSE_AT_CONSUME 0
|
||||
|
||||
#define GEMMINI_DMA 1
|
||||
#define GEMMINI_DMA 0
|
||||
#define GEMMINI_DMA_MN_MAJOR 1
|
||||
#if SMEM_SIZE == 0x4000
|
||||
#define SMEM_ADDR_Q0 ((float * const) 0xff000000)
|
||||
@@ -299,14 +299,23 @@ inline void wmma_load_a(volatile const T *smem_A, const int local_k,
|
||||
(WM * warp_row + TCM * wm_iter) + row]);
|
||||
// f8-f15 stores a single row of A
|
||||
// threads read from different columns; no bank conflicts
|
||||
// asm volatile("flw f0, %0(%1)" :: "i"(smem_AS_cols * 0 * sizeof(float)), "r"(smem_addr));
|
||||
// asm volatile("flw f1, %0(%1)" :: "i"(smem_AS_cols * 1 * sizeof(float)), "r"(smem_addr));
|
||||
// asm volatile("flw f2, %0(%1)" :: "i"(smem_AS_cols * 2 * sizeof(float)), "r"(smem_addr));
|
||||
// asm volatile("flw f3, %0(%1)" :: "i"(smem_AS_cols * 3 * sizeof(float)), "r"(smem_addr));
|
||||
// asm volatile("flw f4, %0(%1)" :: "i"(smem_AS_cols * 4 * sizeof(float)), "r"(smem_addr));
|
||||
// asm volatile("flw f5, %0(%1)" :: "i"(smem_AS_cols * 5 * sizeof(float)), "r"(smem_addr));
|
||||
// asm volatile("flw f6, %0(%1)" :: "i"(smem_AS_cols * 6 * sizeof(float)), "r"(smem_addr));
|
||||
// asm volatile("flw f7, %0(%1)" :: "i"(smem_AS_cols * 7 * sizeof(float)), "r"(smem_addr));
|
||||
asm volatile("flw f0, %0(%1)" :: "i"(smem_AS_cols * 0 * sizeof(float)), "r"(smem_addr));
|
||||
asm volatile("flw f1, %0(%1)" :: "i"(smem_AS_cols * 1 * sizeof(float)), "r"(smem_addr));
|
||||
asm volatile("flw f2, %0(%1)" :: "i"(smem_AS_cols * 2 * sizeof(float)), "r"(smem_addr));
|
||||
asm volatile("flw f3, %0(%1)" :: "i"(smem_AS_cols * 3 * sizeof(float)), "r"(smem_addr));
|
||||
asm volatile("flw f4, %0(%1)" :: "i"(smem_AS_cols * 4 * sizeof(float)), "r"(smem_addr));
|
||||
asm volatile("flw f5, %0(%1)" :: "i"(smem_AS_cols * 5 * sizeof(float)), "r"(smem_addr));
|
||||
asm volatile("flw f6, %0(%1)" :: "i"(smem_AS_cols * 6 * sizeof(float)), "r"(smem_addr));
|
||||
asm volatile("flw f7, %0(%1)" :: "i"(smem_AS_cols * 7 * sizeof(float)), "r"(smem_addr));
|
||||
smem_addr += smem_AS_cols * 4 * sizeof(float);
|
||||
asm volatile("flw f4, %0(%1)" :: "i"(smem_AS_cols * 0 * sizeof(float)), "r"(smem_addr));
|
||||
asm volatile("flw f5, %0(%1)" :: "i"(smem_AS_cols * 1 * sizeof(float)), "r"(smem_addr));
|
||||
asm volatile("flw f6, %0(%1)" :: "i"(smem_AS_cols * 2 * sizeof(float)), "r"(smem_addr));
|
||||
asm volatile("flw f7, %0(%1)" :: "i"(smem_AS_cols * 3 * sizeof(float)), "r"(smem_addr));
|
||||
} else {
|
||||
static_assert(layout ==
|
||||
MemLayout::K_major /* fake cond that is always false */,
|
||||
@@ -638,34 +647,67 @@ load_tile_to_smem(const uint32_t dim_major, const uint32_t mn_index,
|
||||
// need to branch because the address offset is passed to the inline assembly
|
||||
// as an immediate, and RISC-V load/store immediates are limited to a signed 12-bit range (-2048..2047)
|
||||
if constexpr (!transposed_write) {
|
||||
// asm volatile("fsw ft0, %0(%1)" ::"i"(smem_dim_col * row_stride * 0 *
|
||||
// sizeof(float)),
|
||||
// "r"(local));
|
||||
// asm volatile("fsw ft1, %0(%1)" ::"i"(smem_dim_col * row_stride * 1 *
|
||||
// sizeof(float)),
|
||||
// "r"(local));
|
||||
// local += smem_dim_col * row_stride * 2;
|
||||
// asm volatile("fsw ft2, %0(%1)" ::"i"(smem_dim_col * row_stride * 0 *
|
||||
// sizeof(float)),
|
||||
// "r"(local));
|
||||
// asm volatile("fsw ft3, %0(%1)" ::"i"(smem_dim_col * row_stride * 1 *
|
||||
// sizeof(float)),
|
||||
// "r"(local));
|
||||
// local += smem_dim_col * row_stride * 2;
|
||||
// asm volatile("fsw ft4, %0(%1)" ::"i"(smem_dim_col * row_stride * 0 *
|
||||
// sizeof(float)),
|
||||
// "r"(local));
|
||||
// asm volatile("fsw ft5, %0(%1)" ::"i"(smem_dim_col * row_stride * 1 *
|
||||
// sizeof(float)),
|
||||
// "r"(local));
|
||||
// local += smem_dim_col * row_stride * 2;
|
||||
// asm volatile("fsw ft6, %0(%1)" ::"i"(smem_dim_col * row_stride * 0 *
|
||||
// sizeof(float)),
|
||||
// "r"(local));
|
||||
// asm volatile("fsw ft7, %0(%1)" ::"i"(smem_dim_col * row_stride * 1 *
|
||||
// sizeof(float)),
|
||||
// "r"(local));
|
||||
// local += smem_dim_col * row_stride * 2;
|
||||
|
||||
asm volatile("fsw ft0, %0(%1)" ::"i"(smem_dim_col * row_stride * 0 *
|
||||
sizeof(float)),
|
||||
"r"(local));
|
||||
asm volatile("fsw ft1, %0(%1)" ::"i"(smem_dim_col * row_stride * 1 *
|
||||
local += smem_dim_col * row_stride;
|
||||
asm volatile("fsw ft1, %0(%1)" ::"i"(smem_dim_col * row_stride * 0 *
|
||||
sizeof(float)),
|
||||
"r"(local));
|
||||
local += smem_dim_col * row_stride * 2;
|
||||
local += smem_dim_col * row_stride;
|
||||
asm volatile("fsw ft2, %0(%1)" ::"i"(smem_dim_col * row_stride * 0 *
|
||||
sizeof(float)),
|
||||
"r"(local));
|
||||
asm volatile("fsw ft3, %0(%1)" ::"i"(smem_dim_col * row_stride * 1 *
|
||||
local += smem_dim_col * row_stride;
|
||||
asm volatile("fsw ft3, %0(%1)" ::"i"(smem_dim_col * row_stride * 0 *
|
||||
sizeof(float)),
|
||||
"r"(local));
|
||||
local += smem_dim_col * row_stride * 2;
|
||||
local += smem_dim_col * row_stride;
|
||||
asm volatile("fsw ft4, %0(%1)" ::"i"(smem_dim_col * row_stride * 0 *
|
||||
sizeof(float)),
|
||||
"r"(local));
|
||||
asm volatile("fsw ft5, %0(%1)" ::"i"(smem_dim_col * row_stride * 1 *
|
||||
local += smem_dim_col * row_stride;
|
||||
asm volatile("fsw ft5, %0(%1)" ::"i"(smem_dim_col * row_stride * 0 *
|
||||
sizeof(float)),
|
||||
"r"(local));
|
||||
local += smem_dim_col * row_stride * 2;
|
||||
local += smem_dim_col * row_stride;
|
||||
asm volatile("fsw ft6, %0(%1)" ::"i"(smem_dim_col * row_stride * 0 *
|
||||
sizeof(float)),
|
||||
"r"(local));
|
||||
asm volatile("fsw ft7, %0(%1)" ::"i"(smem_dim_col * row_stride * 1 *
|
||||
local += smem_dim_col * row_stride;
|
||||
asm volatile("fsw ft7, %0(%1)" ::"i"(smem_dim_col * row_stride * 0 *
|
||||
sizeof(float)),
|
||||
"r"(local));
|
||||
local += smem_dim_col * row_stride * 2;
|
||||
local += smem_dim_col * row_stride;
|
||||
} else {
|
||||
// currently, tensor core hardware only supports MN-major SMEM tile
|
||||
// layout for correct results
|
||||
@@ -996,6 +1038,7 @@ inline void thread_block_gemm(const T *A, const T *B, float *C,
|
||||
warps_per_threadblock_per_core);
|
||||
#endif
|
||||
|
||||
#if 0
|
||||
// consumer code: SMEM->RF and compute
|
||||
// ----------------------------------------------------------------------
|
||||
// @perf: this loop spills to stack a lot because of all the flws in
|
||||
@@ -1044,6 +1087,7 @@ inline void thread_block_gemm(const T *A, const T *B, float *C,
|
||||
|
||||
threadblock_barrier(threadblock_id_in_cluster,
|
||||
warps_per_threadblock_per_core);
|
||||
#endif
|
||||
}
|
||||
|
||||
if constexpr (write_to_gmem) {
|
||||
|
||||
Reference in New Issue
Block a user