From b44b202a21d669cbfda2ee382460f9c6468967f0 Mon Sep 17 00:00:00 2001 From: Hansung Kim Date: Sun, 18 Aug 2024 16:21:22 -0700 Subject: [PATCH] sgemm_impl: Rename to wmma --- tests/regression/flash_attention/kernel.cpp | 43 ++++++++++----------- tests/regression/sgemm_tcore/sgemm_impl.hpp | 20 +++++----- 2 files changed, 31 insertions(+), 32 deletions(-) diff --git a/tests/regression/flash_attention/kernel.cpp b/tests/regression/flash_attention/kernel.cpp index bd238bb6..d627e413 100644 --- a/tests/regression/flash_attention/kernel.cpp +++ b/tests/regression/flash_attention/kernel.cpp @@ -78,7 +78,6 @@ inline void thread_block_online_softmax( volatile float *gmem_tmp0 = reinterpret_cast(0xd0000000UL); volatile float *gmem_tmp1 = reinterpret_cast(0xe0000000UL); - volatile float *gmem_tmp2 = reinterpret_cast(0xf0000000UL); float *smem_rowmax_prev = smem_rowmax; float *smem_rowmax_new = smem_rowmax + B_ROW; @@ -310,35 +309,33 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { (smem_QK_size + smem_V_size + smem_O_size) * threadblock_id_in_cluster); - uint8_t *smem_S = smem_per_threadblock; - uint8_t *smem_P = smem_S; // in-place update from S to P - uint8_t *smem_V = smem_per_threadblock + sizeof(float) * smem_QK_size; - uint8_t *smem_O = - smem_per_threadblock + sizeof(float) * (smem_QK_size + smem_V_size); + float *smem_S = reinterpret_cast(smem_per_threadblock); + float *smem_P = smem_S; // in-place update from S to P + float *smem_V = + reinterpret_cast(smem_per_threadblock) + smem_QK_size; + float *smem_O = reinterpret_cast(smem_per_threadblock) + + smem_QK_size + smem_V_size; // allocate rowmax/rowsum storage at the end of the sharedmem address space - constexpr uint32_t smem_rowmax_size = sizeof(float) * B_ROW * 3 /* mi, mi~, minew */; - constexpr uint32_t smem_rowsum_size = sizeof(float) * B_ROW; - uint8_t *smem_rowmax = - reinterpret_cast(SMEM_ADDR_END) - smem_rowmax_size; - uint8_t *smem_rowsum = smem_rowmax - smem_rowsum_size; + constexpr uint32_t smem_rowmax_size = B_ROW * 3 /* mi, mi~, minew */; + constexpr uint32_t smem_rowsum_size = B_ROW; + float *smem_rowmax = + reinterpret_cast(SMEM_ADDR_END) - smem_rowmax_size; + float *smem_rowsum = smem_rowmax - smem_rowsum_size; // sharedmem "scratchpad" area to put temporary data, e.g. for tree reduction // in rowsum // NOTE: out-of bounds is not checked constexpr uint32_t smem_scratchpad_size = - sizeof(float) * B_ROW * NUM_THREADS * 2 /*arbitrary slack*/; - uint8_t *smem_scratchpad = - smem_rowmax - smem_scratchpad_size; + B_ROW * NUM_THREADS * 2 /*arbitrary slack*/; + float *smem_scratchpad = smem_rowmax - smem_scratchpad_size; const uint32_t warps_per_threadblock_per_core = NUM_WARPS / threads_per_threadblock; // initialize rowmax/rowsum values in sharedmem thread_block_init_sharedmem(tid_in_threadblock, threads_per_threadblock, - (float *)smem_O, - (float *)smem_rowmax, - (float *)smem_rowsum); + smem_O, smem_rowmax, smem_rowsum); #define SKIP_GEMM #ifndef SKIP_GEMM @@ -359,16 +356,18 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { #endif thread_block_online_softmax( - tile_S, (float *)smem_O, (float *)smem_P, tid_in_threadblock, - threads_per_threadblock, threadblock_id_in_cluster, - (float *)smem_scratchpad, (float *)smem_rowmax, (float *)smem_rowsum); + tile_S, smem_O, smem_P, tid_in_threadblock, threads_per_threadblock, + threadblock_id_in_cluster, smem_scratchpad, smem_rowmax, smem_rowsum); // FIXME unnecessary? threadblock_barrier(threadblock_id_in_cluster, warps_per_threadblock_per_core); - thread_block_gemm_single_tile(smem_P, smem_V, tid_in_threadblock, - threads_per_threadblock); + float *gmem_tmp2 = reinterpret_cast(0xf0000000UL); + + thread_block_gemm_single_tile( + smem_P, smem_V, gmem_tmp2 /*smem_O*/, tid_in_threadblock, + threads_per_threadblock); } int main() { diff --git a/tests/regression/sgemm_tcore/sgemm_impl.hpp b/tests/regression/sgemm_tcore/sgemm_impl.hpp index 6fb18e7e..0785c5bd 100644 --- a/tests/regression/sgemm_tcore/sgemm_impl.hpp +++ b/tests/regression/sgemm_tcore/sgemm_impl.hpp @@ -205,15 +205,15 @@ inline void vx_wmma(const int dest_reg) { // `local_k` is assumed to be multiple of TCK template -inline void vx_wmma_load_a(volatile const T *smem_A, const int local_k, +inline void wmma_load_a(volatile const T *smem_A, const int local_k, const int warp_row, const int wm_iter, const int thread_in_warp) { - asm volatile ("vx_wmma_load_a_start_%=:" :: ); + asm volatile ("wmma_load_a_start_%=:" :: ); const int tid = thread_in_warp; const int tg = tid / 4; - // @perf: this is duplicately computed in vx_wmma_load_a and vx_wmma_load_b + // @perf: this is duplicately computed in wmma_load_a and wmma_load_b int row = 0; int col = 0; map_operand(tid, row, col); @@ -273,15 +273,15 @@ inline void vx_wmma_load_a(volatile const T *smem_A, const int local_k, asm volatile("flw f7, %0(%1)" :: "i"(smem_AS_cols * 7 * sizeof(float)), "r"(smem_addr)); } - asm volatile ("vx_wmma_load_a_finish_%=:" :: ); + asm volatile ("wmma_load_a_finish_%=:" :: ); } // `local_k` is assumed to be multiple of TCK template -inline void vx_wmma_load_b(const volatile T *smem_B, const int local_k, +inline void wmma_load_b(const volatile T *smem_B, const int local_k, const int warp_col, const int wn_iter, const int thread_in_warp) { - asm volatile ("vx_wmma_load_b_start_%=:" :: ); + asm volatile ("wmma_load_b_start_%=:" :: ); const int tid = thread_in_warp; const int tg = tid / 4; @@ -290,7 +290,7 @@ inline void vx_wmma_load_b(const volatile T *smem_B, const int local_k, int col = 0; map_operand(tid, row, col); - // see comment in vx_wmma_load_a + // see comment in wmma_load_a constexpr int packed_factor = (std::is_same_v ? 2 : 1); constexpr int BK_adjusted = BN / packed_factor; constexpr int BN_adjusted = BN / packed_factor; @@ -316,7 +316,7 @@ inline void vx_wmma_load_b(const volatile T *smem_B, const int local_k, asm volatile("flw f14, %0(%1)" :: "i"(smem_B_cols * 6 * sizeof(float)), "r"(smem_addr)); asm volatile("flw f15, %0(%1)" :: "i"(smem_B_cols * 7 * sizeof(float)), "r"(smem_addr)); - asm volatile ("vx_wmma_load_b_finish_%=:" :: ); + asm volatile ("wmma_load_b_finish_%=:" :: ); } inline void initialize_C(const int dest_reg) { @@ -659,11 +659,11 @@ thread_block_gemm_single_tile(const T *local_a, const T *local_b, T *local_c, #pragma GCC unroll 2 for (int wn_iter = 0; wn_iter < WNITER; wn_iter++) { // SMEM -> RF - vx_wmma_load_b(local_b, local_k, warp_col, wn_iter, tid_in_warp); + wmma_load_b(local_b, local_k, warp_col, wn_iter, tid_in_warp); #pragma GCC unroll 2 for (int wm_iter = 0; wm_iter < WMITER; wm_iter++) { // SMEM -> RF - vx_wmma_load_a(local_a, local_k, warp_row, wm_iter, tid_in_warp); + wmma_load_a(local_a, local_k, warp_row, wm_iter, tid_in_warp); // perform mma vx_wmma(wm_iter); }