sgemm_impl: Rename to wmma

This commit is contained in:
Hansung Kim
2024-08-18 16:21:22 -07:00
parent b978bf8757
commit b44b202a21
2 changed files with 31 additions and 32 deletions

View File

@@ -78,7 +78,6 @@ inline void thread_block_online_softmax(
volatile float *gmem_tmp0 = reinterpret_cast<volatile float *>(0xd0000000UL);
volatile float *gmem_tmp1 = reinterpret_cast<volatile float *>(0xe0000000UL);
volatile float *gmem_tmp2 = reinterpret_cast<volatile float *>(0xf0000000UL);
float *smem_rowmax_prev = smem_rowmax;
float *smem_rowmax_new = smem_rowmax + B_ROW;
@@ -310,35 +309,33 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
(smem_QK_size + smem_V_size + smem_O_size) *
threadblock_id_in_cluster);
uint8_t *smem_S = smem_per_threadblock;
uint8_t *smem_P = smem_S; // in-place update from S to P
uint8_t *smem_V = smem_per_threadblock + sizeof(float) * smem_QK_size;
uint8_t *smem_O =
smem_per_threadblock + sizeof(float) * (smem_QK_size + smem_V_size);
float *smem_S = reinterpret_cast<float *>(smem_per_threadblock);
float *smem_P = smem_S; // in-place update from S to P
float *smem_V =
reinterpret_cast<float *>(smem_per_threadblock) + smem_QK_size;
float *smem_O = reinterpret_cast<float *>(smem_per_threadblock) +
smem_QK_size + smem_V_size;
// allocate rowmax/rowsum storage at the end of the sharedmem address space
constexpr uint32_t smem_rowmax_size = sizeof(float) * B_ROW * 3 /* mi, mi~, minew */;
constexpr uint32_t smem_rowsum_size = sizeof(float) * B_ROW;
uint8_t *smem_rowmax =
reinterpret_cast<uint8_t *>(SMEM_ADDR_END) - smem_rowmax_size;
uint8_t *smem_rowsum = smem_rowmax - smem_rowsum_size;
constexpr uint32_t smem_rowmax_size = B_ROW * 3 /* mi, mi~, minew */;
constexpr uint32_t smem_rowsum_size = B_ROW;
float *smem_rowmax =
reinterpret_cast<float *>(SMEM_ADDR_END) - smem_rowmax_size;
float *smem_rowsum = smem_rowmax - smem_rowsum_size;
// sharedmem "scratchpad" area to put temporary data, e.g. for tree reduction
// in rowsum
// NOTE: out-of bounds is not checked
constexpr uint32_t smem_scratchpad_size =
sizeof(float) * B_ROW * NUM_THREADS * 2 /*arbitrary slack*/;
uint8_t *smem_scratchpad =
smem_rowmax - smem_scratchpad_size;
B_ROW * NUM_THREADS * 2 /*arbitrary slack*/;
float *smem_scratchpad = smem_rowmax - smem_scratchpad_size;
const uint32_t warps_per_threadblock_per_core =
NUM_WARPS / threads_per_threadblock;
// initialize rowmax/rowsum values in sharedmem
thread_block_init_sharedmem(tid_in_threadblock, threads_per_threadblock,
(float *)smem_O,
(float *)smem_rowmax,
(float *)smem_rowsum);
smem_O, smem_rowmax, smem_rowsum);
#define SKIP_GEMM
#ifndef SKIP_GEMM
@@ -359,16 +356,18 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
#endif
thread_block_online_softmax(
tile_S, (float *)smem_O, (float *)smem_P, tid_in_threadblock,
threads_per_threadblock, threadblock_id_in_cluster,
(float *)smem_scratchpad, (float *)smem_rowmax, (float *)smem_rowsum);
tile_S, smem_O, smem_P, tid_in_threadblock, threads_per_threadblock,
threadblock_id_in_cluster, smem_scratchpad, smem_rowmax, smem_rowsum);
// FIXME unnecessary?
threadblock_barrier(threadblock_id_in_cluster,
warps_per_threadblock_per_core);
thread_block_gemm_single_tile(smem_P, smem_V, tid_in_threadblock,
threads_per_threadblock);
float *gmem_tmp2 = reinterpret_cast<float *>(0xf0000000UL);
thread_block_gemm_single_tile<float, /*write_to_smem=*/true>(
smem_P, smem_V, gmem_tmp2 /*smem_O*/, tid_in_threadblock,
threads_per_threadblock);
}
int main() {