sgemm_impl: 128x64 tile; fix unrolled asm, comment out actual gemm

This commit is contained in:
Hansung Kim
2024-09-05 16:22:19 -07:00
parent 137df9bee2
commit a832fa7b84

View File

@@ -29,7 +29,7 @@ using float_type = float16_t;
// (BM*BN) / (TM*TN) == threadblock size >= NT * CORES_PER_CLUSTER
// * Combining BM * BK >= (BM*BN) / (TM*TN) == threadblock yields
// BM <= BK*TM*TN
#define BM 64
#define BM 128
#define BN 64
#if (FP_SIZE == 32)
#define BK 64
@@ -72,7 +72,7 @@ static_assert(WMITER * WNITER * TCM * TCN * NUM_WARPS * CORES_PER_CLUSTER ==
#define TRANSPOSE_AT_PRODUCE 0
#define TRANSPOSE_AT_CONSUME 0
#define GEMMINI_DMA 1
#define GEMMINI_DMA 0
#define GEMMINI_DMA_MN_MAJOR 1
#if SMEM_SIZE == 0x4000
#define SMEM_ADDR_Q0 ((float * const) 0xff000000)
@@ -299,14 +299,23 @@ inline void wmma_load_a(volatile const T *smem_A, const int local_k,
(WM * warp_row + TCM * wm_iter) + row]);
// f8-f15 stores a single row of A
// threads read from different columns; no bank conflicts
// asm volatile("flw f0, %0(%1)" :: "i"(smem_AS_cols * 0 * sizeof(float)), "r"(smem_addr));
// asm volatile("flw f1, %0(%1)" :: "i"(smem_AS_cols * 1 * sizeof(float)), "r"(smem_addr));
// asm volatile("flw f2, %0(%1)" :: "i"(smem_AS_cols * 2 * sizeof(float)), "r"(smem_addr));
// asm volatile("flw f3, %0(%1)" :: "i"(smem_AS_cols * 3 * sizeof(float)), "r"(smem_addr));
// asm volatile("flw f4, %0(%1)" :: "i"(smem_AS_cols * 4 * sizeof(float)), "r"(smem_addr));
// asm volatile("flw f5, %0(%1)" :: "i"(smem_AS_cols * 5 * sizeof(float)), "r"(smem_addr));
// asm volatile("flw f6, %0(%1)" :: "i"(smem_AS_cols * 6 * sizeof(float)), "r"(smem_addr));
// asm volatile("flw f7, %0(%1)" :: "i"(smem_AS_cols * 7 * sizeof(float)), "r"(smem_addr));
asm volatile("flw f0, %0(%1)" :: "i"(smem_AS_cols * 0 * sizeof(float)), "r"(smem_addr));
asm volatile("flw f1, %0(%1)" :: "i"(smem_AS_cols * 1 * sizeof(float)), "r"(smem_addr));
asm volatile("flw f2, %0(%1)" :: "i"(smem_AS_cols * 2 * sizeof(float)), "r"(smem_addr));
asm volatile("flw f3, %0(%1)" :: "i"(smem_AS_cols * 3 * sizeof(float)), "r"(smem_addr));
asm volatile("flw f4, %0(%1)" :: "i"(smem_AS_cols * 4 * sizeof(float)), "r"(smem_addr));
asm volatile("flw f5, %0(%1)" :: "i"(smem_AS_cols * 5 * sizeof(float)), "r"(smem_addr));
asm volatile("flw f6, %0(%1)" :: "i"(smem_AS_cols * 6 * sizeof(float)), "r"(smem_addr));
asm volatile("flw f7, %0(%1)" :: "i"(smem_AS_cols * 7 * sizeof(float)), "r"(smem_addr));
smem_addr += smem_AS_cols * 4 * sizeof(float);
asm volatile("flw f4, %0(%1)" :: "i"(smem_AS_cols * 0 * sizeof(float)), "r"(smem_addr));
asm volatile("flw f5, %0(%1)" :: "i"(smem_AS_cols * 1 * sizeof(float)), "r"(smem_addr));
asm volatile("flw f6, %0(%1)" :: "i"(smem_AS_cols * 2 * sizeof(float)), "r"(smem_addr));
asm volatile("flw f7, %0(%1)" :: "i"(smem_AS_cols * 3 * sizeof(float)), "r"(smem_addr));
} else {
static_assert(layout ==
MemLayout::K_major /* fake cond that is always false */,
@@ -638,34 +647,67 @@ load_tile_to_smem(const uint32_t dim_major, const uint32_t mn_index,
// need to branch because address offset constant in the inline assembly
// cannot be larger than a certain limit
if constexpr (!transposed_write) {
// asm volatile("fsw ft0, %0(%1)" ::"i"(smem_dim_col * row_stride * 0 *
// sizeof(float)),
// "r"(local));
// asm volatile("fsw ft1, %0(%1)" ::"i"(smem_dim_col * row_stride * 1 *
// sizeof(float)),
// "r"(local));
// local += smem_dim_col * row_stride * 2;
// asm volatile("fsw ft2, %0(%1)" ::"i"(smem_dim_col * row_stride * 0 *
// sizeof(float)),
// "r"(local));
// asm volatile("fsw ft3, %0(%1)" ::"i"(smem_dim_col * row_stride * 1 *
// sizeof(float)),
// "r"(local));
// local += smem_dim_col * row_stride * 2;
// asm volatile("fsw ft4, %0(%1)" ::"i"(smem_dim_col * row_stride * 0 *
// sizeof(float)),
// "r"(local));
// asm volatile("fsw ft5, %0(%1)" ::"i"(smem_dim_col * row_stride * 1 *
// sizeof(float)),
// "r"(local));
// local += smem_dim_col * row_stride * 2;
// asm volatile("fsw ft6, %0(%1)" ::"i"(smem_dim_col * row_stride * 0 *
// sizeof(float)),
// "r"(local));
// asm volatile("fsw ft7, %0(%1)" ::"i"(smem_dim_col * row_stride * 1 *
// sizeof(float)),
// "r"(local));
// local += smem_dim_col * row_stride * 2;
asm volatile("fsw ft0, %0(%1)" ::"i"(smem_dim_col * row_stride * 0 *
sizeof(float)),
"r"(local));
asm volatile("fsw ft1, %0(%1)" ::"i"(smem_dim_col * row_stride * 1 *
local += smem_dim_col * row_stride;
asm volatile("fsw ft1, %0(%1)" ::"i"(smem_dim_col * row_stride * 0 *
sizeof(float)),
"r"(local));
local += smem_dim_col * row_stride * 2;
local += smem_dim_col * row_stride;
asm volatile("fsw ft2, %0(%1)" ::"i"(smem_dim_col * row_stride * 0 *
sizeof(float)),
"r"(local));
asm volatile("fsw ft3, %0(%1)" ::"i"(smem_dim_col * row_stride * 1 *
local += smem_dim_col * row_stride;
asm volatile("fsw ft3, %0(%1)" ::"i"(smem_dim_col * row_stride * 0 *
sizeof(float)),
"r"(local));
local += smem_dim_col * row_stride * 2;
local += smem_dim_col * row_stride;
asm volatile("fsw ft4, %0(%1)" ::"i"(smem_dim_col * row_stride * 0 *
sizeof(float)),
"r"(local));
asm volatile("fsw ft5, %0(%1)" ::"i"(smem_dim_col * row_stride * 1 *
local += smem_dim_col * row_stride;
asm volatile("fsw ft5, %0(%1)" ::"i"(smem_dim_col * row_stride * 0 *
sizeof(float)),
"r"(local));
local += smem_dim_col * row_stride * 2;
local += smem_dim_col * row_stride;
asm volatile("fsw ft6, %0(%1)" ::"i"(smem_dim_col * row_stride * 0 *
sizeof(float)),
"r"(local));
asm volatile("fsw ft7, %0(%1)" ::"i"(smem_dim_col * row_stride * 1 *
local += smem_dim_col * row_stride;
asm volatile("fsw ft7, %0(%1)" ::"i"(smem_dim_col * row_stride * 0 *
sizeof(float)),
"r"(local));
local += smem_dim_col * row_stride * 2;
local += smem_dim_col * row_stride;
} else {
// currently, tensor core hardware only supports MN-major SMEM tile
// layout for correct results
@@ -996,6 +1038,7 @@ inline void thread_block_gemm(const T *A, const T *B, float *C,
warps_per_threadblock_per_core);
#endif
#if 0
// consumer code: SMEM->RF and compute
// ----------------------------------------------------------------------
// @perf: this loop spills to stack a lot because of all the flws in
@@ -1044,6 +1087,7 @@ inline void thread_block_gemm(const T *A, const T *B, float *C,
threadblock_barrier(threadblock_id_in_cluster,
warps_per_threadblock_per_core);
#endif
}
if constexpr (write_to_gmem) {