From ae9e7072807870115a2470f31fa8f730e2fb1956 Mon Sep 17 00:00:00 2001 From: Hansung Kim Date: Tue, 18 Jun 2024 17:59:46 -0700 Subject: [PATCH] sgemm_{gemmini_dma,tcore}: Separate activate_block --- .../sgemm_gemmini_dma/kernel.activation.cpp | 296 ++++++++++------- .../sgemm_tcore/kernel.activation.cpp | 298 ++++++++++++------ 2 files changed, 379 insertions(+), 215 deletions(-) diff --git a/tests/regression/sgemm_gemmini_dma/kernel.activation.cpp b/tests/regression/sgemm_gemmini_dma/kernel.activation.cpp index 566fa2cb..704bb273 100644 --- a/tests/regression/sgemm_gemmini_dma/kernel.activation.cpp +++ b/tests/regression/sgemm_gemmini_dma/kernel.activation.cpp @@ -57,6 +57,154 @@ inline void threadblock_barrier(unsigned int barrier_id, unsigned int count) { vx_barrier(barrier_id, count); } +inline void activate_block(const uint32_t dim_n, const float *const C, + const uint32_t tile_i, const uint32_t tile_j, + const uint32_t warp_row, const uint32_t warp_col, + const uint32_t tid_in_threadblock) { + // activation code currently assumes that the column-width of a warp + // tile exactly matches SIMD width + static_assert(WN == NUM_THREADS); + + const uint32_t col_in_warptile = tid_in_threadblock % WN; + // const uint32_t row_in_warptile = elem_i; // FIXME: doesn't work with WN != + // NUM_THREADS + const uint32_t row_in_warptile = 0; + const uint32_t C_row = (tile_i * TILE_M) + (warp_row * WM) + row_in_warptile; + const uint32_t C_col = (tile_j * TILE_N) + (warp_col * WN) + col_in_warptile; + const float *const global_C = C + dim_n * C_row + C_col; + const float *global_C_curr = global_C; + + // read in elements from GMEM to RF + // each thread works on ELEM_PER_THREAD elements, which can be larger than 1 + static_assert(ELEM_PER_THREAD == 16, "currently assumes ELEM_PER_THREAD == 16"); + + constexpr uint32_t asm_unrolled = 8; // working with f0~f7 at a time + + for (int i = 0; i < ELEM_PER_THREAD; i += asm_unrolled) { + asm volatile("mv t6, %0" ::"r"(global_C_curr)); + asm volatile("flw f0, (t6)"); + asm volatile("add t6, t6, %0" ::"r"(dim_n * sizeof(float))); + asm volatile("flw f1, (t6)"); + asm volatile("add t6, t6, %0" ::"r"(dim_n * sizeof(float))); + asm volatile("flw f2, (t6)"); + asm volatile("add t6, t6, %0" ::"r"(dim_n * sizeof(float))); + asm volatile("flw f3, (t6)"); + asm volatile("add t6, t6, %0" ::"r"(dim_n * sizeof(float))); + asm volatile("flw f4, (t6)"); + asm volatile("add t6, t6, %0" ::"r"(dim_n * sizeof(float))); + asm volatile("flw f5, (t6)"); + asm volatile("add t6, t6, %0" ::"r"(dim_n * sizeof(float))); + asm volatile("flw f6, (t6)"); + asm volatile("add t6, t6, %0" ::"r"(dim_n * sizeof(float))); + asm volatile("flw f7, (t6)"); + asm volatile("add t6, t6, %0" ::"r"(dim_n * sizeof(float))); + + // do elem-wise e^x + // each register has 3 temporary registers: + // f0 has f8, f9, f10 + // f1 has f11, f12, f13 + asm volatile("fcvt.s.w f9, %0" ::"r"(1)); + asm volatile("fadd.s f8, f9, f0"); // acc = 1 + x + asm volatile("fcvt.s.w f9, %0" ::"r"(2)); + asm volatile("fdiv.s f10, f0, f9"); // x / 2 + asm volatile("fmadd.s f8, f10, f0, f8"); // acc += (x / 2) * x + asm volatile("fcvt.s.w f9, %0" ::"r"(3)); + asm volatile("fmul.s f10, f10, f0"); // (x * x) / 2 + asm volatile("fdiv.s f10, f10, f9"); // (x * x) / (2 * 3) + asm volatile("fmadd.s f0, f10, f0, f8"); // acc += (x * x) / (2 * 3) * x + + asm volatile("fcvt.s.w f12, %0" ::"r"(1)); + asm volatile("fadd.s f11, f12, f1"); + asm volatile("fcvt.s.w f12, %0" ::"r"(2)); + asm volatile("fdiv.s f13, f1, f12"); + asm volatile("fmadd.s f11, f13, f1, f11"); + asm volatile("fcvt.s.w f12, %0" ::"r"(3)); + asm volatile("fmul.s f13, f13, f1"); + asm volatile("fdiv.s f13, f13, f12"); + asm volatile("fmadd.s f1, f13, f1, f11"); + + asm volatile("fcvt.s.w f15, %0" ::"r"(1)); + asm volatile("fadd.s f14, f15, f2"); + asm volatile("fcvt.s.w f15, %0" ::"r"(2)); + asm volatile("fdiv.s f16, f2, f15"); + asm volatile("fmadd.s f14, f16, f2, f14"); + asm volatile("fcvt.s.w f15, %0" ::"r"(3)); + asm volatile("fmul.s f16, f16, f2"); + asm volatile("fdiv.s f16, f16, f15"); + asm volatile("fmadd.s f2, f16, f2, f14"); + + asm volatile("fcvt.s.w f18, %0" ::"r"(1)); + asm volatile("fadd.s f17, f18, f3"); + asm volatile("fcvt.s.w f18, %0" ::"r"(2)); + asm volatile("fdiv.s f19, f3, f18"); + asm volatile("fmadd.s f17, f19, f3, f17"); + asm volatile("fcvt.s.w f18, %0" ::"r"(3)); + asm volatile("fmul.s f19, f19, f3"); + asm volatile("fdiv.s f19, f19, f18"); + asm volatile("fmadd.s f3, f19, f3, f17"); + + asm volatile("fcvt.s.w f21, %0" ::"r"(1)); + asm volatile("fadd.s f20, f21, f4"); + asm volatile("fcvt.s.w f21, %0" ::"r"(2)); + asm volatile("fdiv.s f22, f4, f21"); + asm volatile("fmadd.s f20, f22, f4, f20"); + asm volatile("fcvt.s.w f21, %0" ::"r"(3)); + asm volatile("fmul.s f22, f22, f4"); + asm volatile("fdiv.s f22, f22, f21"); + asm volatile("fmadd.s f4, f22, f4, f20"); + + asm volatile("fcvt.s.w f24, %0" ::"r"(1)); + asm volatile("fadd.s f23, f24, f5"); + asm volatile("fcvt.s.w f24, %0" ::"r"(2)); + asm volatile("fdiv.s f25, f5, f24"); + asm volatile("fmadd.s f23, f25, f5, f23"); + asm volatile("fcvt.s.w f24, %0" ::"r"(3)); + asm volatile("fmul.s f25, f25, f5"); + asm volatile("fdiv.s f25, f25, f24"); + asm volatile("fmadd.s f5, f25, f5, f23"); + + asm volatile("fcvt.s.w f27, %0" ::"r"(1)); + asm volatile("fadd.s f26, f27, f6"); + asm volatile("fcvt.s.w f27, %0" ::"r"(2)); + asm volatile("fdiv.s f28, f6, f27"); + asm volatile("fmadd.s f26, f28, f6, f26"); + asm volatile("fcvt.s.w f27, %0" ::"r"(3)); + asm volatile("fmul.s f28, f28, f6"); + asm volatile("fdiv.s f28, f28, f27"); + asm volatile("fmadd.s f6, f28, f6, f26"); + + asm volatile("fcvt.s.w f30, %0" ::"r"(1)); + asm volatile("fadd.s f29, f30, f7"); + asm volatile("fcvt.s.w f30, %0" ::"r"(2)); + asm volatile("fdiv.s f31, f7, f30"); + asm volatile("fmadd.s f29, f31, f7, f29"); + asm volatile("fcvt.s.w f30, %0" ::"r"(3)); + asm volatile("fmul.s f31, f31, f7"); + asm volatile("fdiv.s f31, f31, f30"); + asm volatile("fmadd.s f7, f31, f7, f29"); + + // move back from RF to gmem + asm volatile("mv t6, %0" ::"r"(global_C_curr)); + asm volatile("fsw f0, (t6)"); + asm volatile("add t6, t6, %0" ::"r"(dim_n * sizeof(float))); + asm volatile("fsw f1, (t6)"); + asm volatile("add t6, t6, %0" ::"r"(dim_n * sizeof(float))); + asm volatile("fsw f2, (t6)"); + asm volatile("add t6, t6, %0" ::"r"(dim_n * sizeof(float))); + asm volatile("fsw f3, (t6)"); + asm volatile("add t6, t6, %0" ::"r"(dim_n * sizeof(float))); + asm volatile("fsw f4, (t6)"); + asm volatile("add t6, t6, %0" ::"r"(dim_n * sizeof(float))); + asm volatile("fsw f5, (t6)"); + asm volatile("add t6, t6, %0" ::"r"(dim_n * sizeof(float))); + asm volatile("fsw f6, (t6)"); + asm volatile("add t6, t6, %0" ::"r"(dim_n * sizeof(float))); + asm volatile("fsw f7, (t6)"); + asm volatile("add t6, t6, %0" ::"r"(dim_n * sizeof(float))); + asm volatile("mv %0, t6" :"=r"(global_C_curr)); + } +} + void thread_block_matmul_gemmini(kernel_arg_t *__UNIFORM__ arg, const uint32_t threadblock_id, const uint32_t tid_in_threadblock) { @@ -87,10 +235,6 @@ void thread_block_matmul_gemmini(kernel_arg_t *__UNIFORM__ arg, const uint32_t local_c_row = tid_in_threadblock / TILE_N; const uint32_t local_c_col = tid_in_threadblock % TILE_N; - const uint32_t warp_id_in_threadblock = tid_in_threadblock / NUM_THREADS; - const uint32_t warp_row = warp_id_in_threadblock / (TILE_N / WN); - const uint32_t warp_col = warp_id_in_threadblock % (TILE_N / WN); - const uint32_t num_tile_rows_per_tb = num_tiles_m / NUM_CLUSTERS; if (HW_TID() == 0) { @@ -101,10 +245,12 @@ void thread_block_matmul_gemmini(kernel_arg_t *__UNIFORM__ arg, // gemmini_extended_config_st(stride_C * sizeof_C, act & 3, scale); } - for (uint32_t tile_i = num_tile_rows_per_tb * threadblock_id; - tile_i < num_tile_rows_per_tb * (threadblock_id + 1); - tile_i += 1) { - for (int tile_j = 0; tile_j < num_tiles_n; tile_j += 1) { + uint32_t tile_i = 0; + uint32_t tile_j = 0; + for (tile_i = num_tile_rows_per_tb * threadblock_id; + tile_i < num_tile_rows_per_tb * (threadblock_id + 1); + tile_i += 1) { + for (tile_j = 0; tile_j < num_tiles_n; tile_j += 1) { for (int tile_k = 0; tile_k < num_tiles_k; tile_k += 1) { if (HW_TID() == 0) { ROCC_INSTRUCTION_RS1_RS2(XCUSTOM_ACC, @@ -134,111 +280,20 @@ void thread_block_matmul_gemmini(kernel_arg_t *__UNIFORM__ arg, // while Gemmini is computing, software-pipeline with activation on the // previous (M,N) tile - if (true || (tid_in_threadblock >= NUM_THREADS)) { - // activation code currently assumes that the column-width of a warp - // tile exactly matches SIMD width - static_assert(WN == NUM_THREADS); + if ((tid_in_threadblock >= NUM_THREADS /*excludes warp 0*/)) { + const uint32_t warp_id_in_threadblock = tid_in_threadblock / NUM_THREADS; + const uint32_t warp_row = warp_id_in_threadblock / (TILE_N / WN); + const uint32_t warp_col = warp_id_in_threadblock % (TILE_N / WN); + activate_block(dim_n, C, tile_i, tile_j, warp_row, warp_col, + tid_in_threadblock); - float elem[ELEM_PER_THREAD]; - const uint32_t col_in_warptile = tid_in_threadblock % WN; - // const uint32_t row_in_warptile = elem_i; // FIXME: doesn't work with WN != NUM_THREADS - const uint32_t row_in_warptile = 0; - const uint32_t C_row = (tile_i * TILE_M) + (warp_row * WM) + row_in_warptile; - const uint32_t C_col = (tile_j * TILE_N) + (warp_col * WN) + col_in_warptile; - const float *const global_C = C + dim_n * C_row + C_col; - - // read in elements from GMEM to RF - asm volatile("mv t6, %0" :: "r"(global_C)); - asm volatile("flw f0, (t6)"); - asm volatile("add t6, t6, %0" :: "r"(dim_n * sizeof(float))); - asm volatile("flw f1, (t6)"); - asm volatile("add t6, t6, %0" :: "r"(dim_n * sizeof(float))); - asm volatile("flw f2, (t6)"); - asm volatile("add t6, t6, %0" :: "r"(dim_n * sizeof(float))); - asm volatile("flw f3, (t6)"); - asm volatile("add t6, t6, %0" :: "r"(dim_n * sizeof(float))); - asm volatile("flw f4, (t6)"); - asm volatile("add t6, t6, %0" :: "r"(dim_n * sizeof(float))); - asm volatile("flw f5, (t6)"); - asm volatile("add t6, t6, %0" :: "r"(dim_n * sizeof(float))); - asm volatile("flw f6, (t6)"); - asm volatile("add t6, t6, %0" :: "r"(dim_n * sizeof(float))); - asm volatile("flw f7, (t6)"); - asm volatile("add t6, t6, %0" :: "r"(dim_n * sizeof(float))); - asm volatile("flw f8, (t6)"); - asm volatile("add t6, t6, %0" :: "r"(dim_n * sizeof(float))); - asm volatile("flw f9, (t6)"); - asm volatile("add t6, t6, %0" :: "r"(dim_n * sizeof(float))); - asm volatile("flw f10, (t6)"); - asm volatile("add t6, t6, %0" :: "r"(dim_n * sizeof(float))); - asm volatile("flw f11, (t6)"); - asm volatile("add t6, t6, %0" :: "r"(dim_n * sizeof(float))); - asm volatile("flw f12, (t6)"); - asm volatile("add t6, t6, %0" :: "r"(dim_n * sizeof(float))); - asm volatile("flw f13, (t6)"); - asm volatile("add t6, t6, %0" :: "r"(dim_n * sizeof(float))); - asm volatile("flw f14, (t6)"); - asm volatile("add t6, t6, %0" :: "r"(dim_n * sizeof(float))); - asm volatile("flw f15, (t6)"); - asm volatile("add t6, t6, %0" :: "r"(dim_n * sizeof(float))); - - asm volatile("fcvt.s.w f16, %0, rtz" :: "r"(2)); - - // do elem-wise compute in RF -#pragma GCC unroll 4 - for (uint32_t count = 0; count < 128; count++) { - asm volatile("fmul.s f0, f0, f16"); - asm volatile("fmul.s f1, f1, f16"); - asm volatile("fmul.s f2, f2, f16"); - asm volatile("fmul.s f3, f3, f16"); - asm volatile("fmul.s f4, f4, f16"); - asm volatile("fmul.s f5, f5, f16"); - asm volatile("fmul.s f6, f6, f16"); - asm volatile("fmul.s f7, f7, f16"); - asm volatile("fmul.s f8, f8, f16"); - asm volatile("fmul.s f9, f9, f16"); - asm volatile("fmul.s f10, f10, f16"); - asm volatile("fmul.s f11, f11, f16"); - asm volatile("fmul.s f12, f12, f16"); - asm volatile("fmul.s f13, f13, f16"); - asm volatile("fmul.s f14, f14, f16"); - asm volatile("fmul.s f15, f15, f16"); + // for warp 1, do warp 0's worth of work as well + if (vx_warp_id() == 1) { + const uint32_t warp_row = (warp_id_in_threadblock - 1) / (TILE_N / WN); + const uint32_t warp_col = (warp_id_in_threadblock - 1) % (TILE_N / WN); + activate_block(dim_n, C, tile_i, tile_j, warp_row, warp_col, + tid_in_threadblock); } - - // move back from RF to gmem - asm volatile("mv t6, %0" :: "r"(global_C)); - asm volatile("fsw f0, (t6)"); - asm volatile("add t6, t6, %0" :: "r"(dim_n * sizeof(float))); - asm volatile("fsw f1, (t6)"); - asm volatile("add t6, t6, %0" :: "r"(dim_n * sizeof(float))); - asm volatile("fsw f2, (t6)"); - asm volatile("add t6, t6, %0" :: "r"(dim_n * sizeof(float))); - asm volatile("fsw f3, (t6)"); - asm volatile("add t6, t6, %0" :: "r"(dim_n * sizeof(float))); - asm volatile("fsw f4, (t6)"); - asm volatile("add t6, t6, %0" :: "r"(dim_n * sizeof(float))); - asm volatile("fsw f5, (t6)"); - asm volatile("add t6, t6, %0" :: "r"(dim_n * sizeof(float))); - asm volatile("fsw f6, (t6)"); - asm volatile("add t6, t6, %0" :: "r"(dim_n * sizeof(float))); - asm volatile("fsw f7, (t6)"); - asm volatile("add t6, t6, %0" :: "r"(dim_n * sizeof(float))); - asm volatile("fsw f8, (t6)"); - asm volatile("add t6, t6, %0" :: "r"(dim_n * sizeof(float))); - asm volatile("fsw f9, (t6)"); - asm volatile("add t6, t6, %0" :: "r"(dim_n * sizeof(float))); - asm volatile("fsw f10, (t6)"); - asm volatile("add t6, t6, %0" :: "r"(dim_n * sizeof(float))); - asm volatile("fsw f11, (t6)"); - asm volatile("add t6, t6, %0" :: "r"(dim_n * sizeof(float))); - asm volatile("fsw f12, (t6)"); - asm volatile("add t6, t6, %0" :: "r"(dim_n * sizeof(float))); - asm volatile("fsw f13, (t6)"); - asm volatile("add t6, t6, %0" :: "r"(dim_n * sizeof(float))); - asm volatile("fsw f14, (t6)"); - asm volatile("add t6, t6, %0" :: "r"(dim_n * sizeof(float))); - asm volatile("fsw f15, (t6)"); - asm volatile("add t6, t6, %0" :: "r"(dim_n * sizeof(float))); } if (HW_TID() == 0) { @@ -268,6 +323,25 @@ void thread_block_matmul_gemmini(kernel_arg_t *__UNIFORM__ arg, } } } + + // last (M,N) block activation + if ((tid_in_threadblock >= NUM_THREADS /*excludes warp 0*/)) { + const uint32_t warp_id_in_threadblock = tid_in_threadblock / NUM_THREADS; + const uint32_t warp_row = warp_id_in_threadblock / (TILE_N / WN); + const uint32_t warp_col = warp_id_in_threadblock % (TILE_N / WN); + activate_block(dim_n, C, tile_i, tile_j, warp_row, warp_col, + tid_in_threadblock); + + // for warp 1, do warp 0's worth of work as well + if (vx_warp_id() == 1) { + const uint32_t warp_row = (warp_id_in_threadblock - 1) / (TILE_N / WN); + const uint32_t warp_col = (warp_id_in_threadblock - 1) % (TILE_N / WN); + activate_block(dim_n, C, tile_i, tile_j, warp_row, warp_col, + tid_in_threadblock); + } + } + + // last thread block complete if (threadblock_id == NUM_CLUSTERS - 1) { threadblock_barrier(/*barrier_id=*/0, /*count=*/NUM_WARPS); diff --git a/tests/regression/sgemm_tcore/kernel.activation.cpp b/tests/regression/sgemm_tcore/kernel.activation.cpp index e4c39ca2..db089ee0 100644 --- a/tests/regression/sgemm_tcore/kernel.activation.cpp +++ b/tests/regression/sgemm_tcore/kernel.activation.cpp @@ -33,6 +33,8 @@ #endif #define WARP_SPECIALIZED 1 +#define SWISH(beta, x) ((x) / (1 + exp(-(beta) * (x)))) + static_assert( !WARP_SPECIALIZED || GEMMINI_DMA, "warp specialization is currently only supported with GEMMINI_DMA"); @@ -251,6 +253,192 @@ inline void global_dmem_load(const uint32_t dim_n, const uint32_t dim_k, } } +inline void activate_block(const uint32_t dim_n, const float *const C, + const uint32_t tile_i, const uint32_t tile_j, + const uint32_t warp_row, const uint32_t warp_col, + const uint32_t tid_in_threadblock) { + // activation code currently assumes that the column-width of a warp + // tile exactly matches SIMD width + static_assert(WN == NUM_THREADS); + + const uint32_t col_in_warptile = tid_in_threadblock % WN; + // const uint32_t row_in_warptile = elem_i; // FIXME: doesn't work with WN != + // NUM_THREADS + const uint32_t row_in_warptile = 0; + const uint32_t C_row = (tile_i * BM) + (warp_row * WM) + row_in_warptile; + const uint32_t C_col = (tile_j * BN) + (warp_col * WN) + col_in_warptile; + const float *const global_C = C + dim_n * C_row + C_col; + const float *global_C_curr = global_C; + + // ELEM_PER_THREAD macro does not take into account warp-specialization + constexpr uint32_t elem_per_thread = + ELEM_PER_THREAD * (WARP_SPECIALIZED ? 2 : 1); + constexpr uint32_t asm_unrolled = 8; // working with f0~f7 at a time + // each thread works on ELEM_PER_THREAD elements, which can be larger than 1 + static_assert((elem_per_thread % asm_unrolled) == 0, + "unmet manual unroll condition for elem_per_thread"); + + for (int i = 0; i < elem_per_thread; i += asm_unrolled) { + // read in elements from GMEM to RF + asm volatile("mv t6, %0" ::"r"(global_C_curr)); + asm volatile("flw f0, (t6)"); + asm volatile("add t6, t6, %0" ::"r"(dim_n * sizeof(float))); + asm volatile("flw f1, (t6)"); + asm volatile("add t6, t6, %0" ::"r"(dim_n * sizeof(float))); + asm volatile("flw f2, (t6)"); + asm volatile("add t6, t6, %0" ::"r"(dim_n * sizeof(float))); + asm volatile("flw f3, (t6)"); + asm volatile("add t6, t6, %0" ::"r"(dim_n * sizeof(float))); + asm volatile("flw f4, (t6)"); + asm volatile("add t6, t6, %0" ::"r"(dim_n * sizeof(float))); + asm volatile("flw f5, (t6)"); + asm volatile("add t6, t6, %0" ::"r"(dim_n * sizeof(float))); + asm volatile("flw f6, (t6)"); + asm volatile("add t6, t6, %0" ::"r"(dim_n * sizeof(float))); + asm volatile("flw f7, (t6)"); + asm volatile("add t6, t6, %0" ::"r"(dim_n * sizeof(float))); + + if constexpr (true) { + register float x0 asm("f0"); + register float x1 asm("f1"); + register float x2 asm("f2"); + register float x3 asm("f3"); + register float x4 asm("f4"); + register float x5 asm("f5"); + register float x6 asm("f6"); + register float x7 asm("f7"); + asm volatile("fmv.s %0, f0" :"=f"(x0)); + x0 = SWISH(1.0f, x0); + asm volatile("fmv.s f0, %0" ::"f"(x0)); + asm volatile("fmv.s %0, f1" :"=f"(x1)); + x1 = SWISH(1.0f, x1); + asm volatile("fmv.s f1, %0" ::"f"(x1)); + asm volatile("fmv.s %0, f1" :"=f"(x2)); + x2 = SWISH(1.0f, x2); + asm volatile("fmv.s f1, %0" ::"f"(x2)); + asm volatile("fmv.s %0, f1" :"=f"(x3)); + x3 = SWISH(1.0f, x3); + asm volatile("fmv.s f1, %0" ::"f"(x3)); + asm volatile("fmv.s %0, f1" :"=f"(x4)); + x4 = SWISH(1.0f, x4); + asm volatile("fmv.s f1, %0" ::"f"(x4)); + asm volatile("fmv.s %0, f1" :"=f"(x5)); + x5 = SWISH(1.0f, x5); + asm volatile("fmv.s f1, %0" ::"f"(x5)); + asm volatile("fmv.s %0, f1" :"=f"(x6)); + x6 = SWISH(1.0f, x6); + asm volatile("fmv.s f1, %0" ::"f"(x6)); + asm volatile("fmv.s %0, f1" :"=f"(x7)); + x7 = SWISH(1.0f, x7); + asm volatile("fmv.s f1, %0" ::"f"(x7)); + } else { + // do elem-wise e^x + // each register has 3 temporary registers: + // f0 has f8, f9, f10 + // f1 has f11, f12, f13 + asm volatile("fcvt.s.w f9, %0" ::"r"(1)); + asm volatile("fadd.s f8, f9, f0"); // acc = 1 + x + asm volatile("fcvt.s.w f9, %0" ::"r"(2)); + asm volatile("fdiv.s f10, f0, f9"); // x / 2 + asm volatile("fmadd.s f8, f10, f0, f8"); // acc += (x / 2) * x + asm volatile("fcvt.s.w f9, %0" ::"r"(3)); + asm volatile("fmul.s f10, f10, f0"); // (x * x) / 2 + asm volatile("fdiv.s f10, f10, f9"); // (x * x) / (2 * 3) + asm volatile("fmadd.s f0, f10, f0, f8"); // acc += (x * x) / (2 * 3) * x + + asm volatile("fcvt.s.w f12, %0" ::"r"(1)); + asm volatile("fadd.s f11, f12, f1"); + asm volatile("fcvt.s.w f12, %0" ::"r"(2)); + asm volatile("fdiv.s f13, f1, f12"); + asm volatile("fmadd.s f11, f13, f1, f11"); + asm volatile("fcvt.s.w f12, %0" ::"r"(3)); + asm volatile("fmul.s f13, f13, f1"); + asm volatile("fdiv.s f13, f13, f12"); + asm volatile("fmadd.s f1, f13, f1, f11"); + + asm volatile("fcvt.s.w f15, %0" ::"r"(1)); + asm volatile("fadd.s f14, f15, f2"); + asm volatile("fcvt.s.w f15, %0" ::"r"(2)); + asm volatile("fdiv.s f16, f2, f15"); + asm volatile("fmadd.s f14, f16, f2, f14"); + asm volatile("fcvt.s.w f15, %0" ::"r"(3)); + asm volatile("fmul.s f16, f16, f2"); + asm volatile("fdiv.s f16, f16, f15"); + asm volatile("fmadd.s f2, f16, f2, f14"); + + asm volatile("fcvt.s.w f18, %0" ::"r"(1)); + asm volatile("fadd.s f17, f18, f3"); + asm volatile("fcvt.s.w f18, %0" ::"r"(2)); + asm volatile("fdiv.s f19, f3, f18"); + asm volatile("fmadd.s f17, f19, f3, f17"); + asm volatile("fcvt.s.w f18, %0" ::"r"(3)); + asm volatile("fmul.s f19, f19, f3"); + asm volatile("fdiv.s f19, f19, f18"); + asm volatile("fmadd.s f3, f19, f3, f17"); + + asm volatile("fcvt.s.w f21, %0" ::"r"(1)); + asm volatile("fadd.s f20, f21, f4"); + asm volatile("fcvt.s.w f21, %0" ::"r"(2)); + asm volatile("fdiv.s f22, f4, f21"); + asm volatile("fmadd.s f20, f22, f4, f20"); + asm volatile("fcvt.s.w f21, %0" ::"r"(3)); + asm volatile("fmul.s f22, f22, f4"); + asm volatile("fdiv.s f22, f22, f21"); + asm volatile("fmadd.s f4, f22, f4, f20"); + + asm volatile("fcvt.s.w f24, %0" ::"r"(1)); + asm volatile("fadd.s f23, f24, f5"); + asm volatile("fcvt.s.w f24, %0" ::"r"(2)); + asm volatile("fdiv.s f25, f5, f24"); + asm volatile("fmadd.s f23, f25, f5, f23"); + asm volatile("fcvt.s.w f24, %0" ::"r"(3)); + asm volatile("fmul.s f25, f25, f5"); + asm volatile("fdiv.s f25, f25, f24"); + asm volatile("fmadd.s f5, f25, f5, f23"); + + asm volatile("fcvt.s.w f27, %0" ::"r"(1)); + asm volatile("fadd.s f26, f27, f6"); + asm volatile("fcvt.s.w f27, %0" ::"r"(2)); + asm volatile("fdiv.s f28, f6, f27"); + asm volatile("fmadd.s f26, f28, f6, f26"); + asm volatile("fcvt.s.w f27, %0" ::"r"(3)); + asm volatile("fmul.s f28, f28, f6"); + asm volatile("fdiv.s f28, f28, f27"); + asm volatile("fmadd.s f6, f28, f6, f26"); + + asm volatile("fcvt.s.w f30, %0" ::"r"(1)); + asm volatile("fadd.s f29, f30, f7"); + asm volatile("fcvt.s.w f30, %0" ::"r"(2)); + asm volatile("fdiv.s f31, f7, f30"); + asm volatile("fmadd.s f29, f31, f7, f29"); + asm volatile("fcvt.s.w f30, %0" ::"r"(3)); + asm volatile("fmul.s f31, f31, f7"); + asm volatile("fdiv.s f31, f31, f30"); + asm volatile("fmadd.s f7, f31, f7, f29"); + } + + // move back from RF to gmem + asm volatile("mv t6, %0" ::"r"(global_C_curr)); + asm volatile("fsw f0, (t6)"); + asm volatile("add t6, t6, %0" ::"r"(dim_n * sizeof(float))); + asm volatile("fsw f1, (t6)"); + asm volatile("add t6, t6, %0" ::"r"(dim_n * sizeof(float))); + asm volatile("fsw f2, (t6)"); + asm volatile("add t6, t6, %0" ::"r"(dim_n * sizeof(float))); + asm volatile("fsw f3, (t6)"); + asm volatile("add t6, t6, %0" ::"r"(dim_n * sizeof(float))); + asm volatile("fsw f4, (t6)"); + asm volatile("add t6, t6, %0" ::"r"(dim_n * sizeof(float))); + asm volatile("fsw f5, (t6)"); + asm volatile("add t6, t6, %0" ::"r"(dim_n * sizeof(float))); + asm volatile("fsw f6, (t6)"); + asm volatile("add t6, t6, %0" ::"r"(dim_n * sizeof(float))); + asm volatile("fsw f7, (t6)"); + asm volatile("add t6, t6, %0" ::"r"(dim_n * sizeof(float))); + asm volatile("mv %0, t6" :"=r"(global_C_curr)); + } +} + inline void thread_block_gemm(kernel_arg_t *__UNIFORM__ arg, const uint32_t tid_in_threadblock, const uint32_t threads_per_threadblock, @@ -515,110 +703,12 @@ inline void thread_block_gemm(kernel_arg_t *__UNIFORM__ arg, // warp specialization: activation on the previous (M,N) tile in the 2nd // warpgroup if (WARP_SPECIALIZED && warpgroup_id == 1) { - // activation code currently assumes that the column-width of a warp - // tile exactly matches SIMD width - static_assert(WN == NUM_THREADS); - - float elem[ELEM_PER_THREAD]; - const uint32_t col_in_warptile = tid_in_warpgroup % WN; - // const uint32_t row_in_warptile = elem_i; // FIXME: doesn't work with WN != NUM_THREADS - const uint32_t row_in_warptile = 0; - const uint32_t C_row = (block_m * BM) + (warp_row * WM) + row_in_warptile; - const uint32_t C_col = (block_n * BN) + (warp_col * WN) + col_in_warptile; - const float *const global_C = C + dim_n * C_row + C_col; - - // read in elements from GMEM to RF - asm volatile("mv t6, %0" :: "r"(global_C)); - asm volatile("flw f0, (t6)"); - asm volatile("add t6, t6, %0" :: "r"(dim_n * sizeof(float))); - asm volatile("flw f1, (t6)"); - asm volatile("add t6, t6, %0" :: "r"(dim_n * sizeof(float))); - asm volatile("flw f2, (t6)"); - asm volatile("add t6, t6, %0" :: "r"(dim_n * sizeof(float))); - asm volatile("flw f3, (t6)"); - asm volatile("add t6, t6, %0" :: "r"(dim_n * sizeof(float))); - asm volatile("flw f4, (t6)"); - asm volatile("add t6, t6, %0" :: "r"(dim_n * sizeof(float))); - asm volatile("flw f5, (t6)"); - asm volatile("add t6, t6, %0" :: "r"(dim_n * sizeof(float))); - asm volatile("flw f6, (t6)"); - asm volatile("add t6, t6, %0" :: "r"(dim_n * sizeof(float))); - asm volatile("flw f7, (t6)"); - asm volatile("add t6, t6, %0" :: "r"(dim_n * sizeof(float))); - asm volatile("flw f8, (t6)"); - asm volatile("add t6, t6, %0" :: "r"(dim_n * sizeof(float))); - asm volatile("flw f9, (t6)"); - asm volatile("add t6, t6, %0" :: "r"(dim_n * sizeof(float))); - asm volatile("flw f10, (t6)"); - asm volatile("add t6, t6, %0" :: "r"(dim_n * sizeof(float))); - asm volatile("flw f11, (t6)"); - asm volatile("add t6, t6, %0" :: "r"(dim_n * sizeof(float))); - asm volatile("flw f12, (t6)"); - asm volatile("add t6, t6, %0" :: "r"(dim_n * sizeof(float))); - asm volatile("flw f13, (t6)"); - asm volatile("add t6, t6, %0" :: "r"(dim_n * sizeof(float))); - asm volatile("flw f14, (t6)"); - asm volatile("add t6, t6, %0" :: "r"(dim_n * sizeof(float))); - asm volatile("flw f15, (t6)"); - asm volatile("add t6, t6, %0" :: "r"(dim_n * sizeof(float))); - - asm volatile("fcvt.s.w f16, %0, rtz" :: "r"(2)); - - // do elem-wise compute in RF -#pragma GCC unroll 4 - for (uint32_t count = 0; count < 32; count++) { - asm volatile("fmul.s f0, f0, f16"); - asm volatile("fmul.s f1, f1, f16"); - asm volatile("fmul.s f2, f2, f16"); - asm volatile("fmul.s f3, f3, f16"); - asm volatile("fmul.s f4, f4, f16"); - asm volatile("fmul.s f5, f5, f16"); - asm volatile("fmul.s f6, f6, f16"); - asm volatile("fmul.s f7, f7, f16"); - asm volatile("fmul.s f8, f8, f16"); - asm volatile("fmul.s f9, f9, f16"); - asm volatile("fmul.s f10, f10, f16"); - asm volatile("fmul.s f11, f11, f16"); - asm volatile("fmul.s f12, f12, f16"); - asm volatile("fmul.s f13, f13, f16"); - asm volatile("fmul.s f14, f14, f16"); - asm volatile("fmul.s f15, f15, f16"); - } - - // move back from RF to gmem - asm volatile("mv t6, %0" :: "r"(global_C)); - asm volatile("fsw f0, (t6)"); - asm volatile("add t6, t6, %0" :: "r"(dim_n * sizeof(float))); - asm volatile("fsw f1, (t6)"); - asm volatile("add t6, t6, %0" :: "r"(dim_n * sizeof(float))); - asm volatile("fsw f2, (t6)"); - asm volatile("add t6, t6, %0" :: "r"(dim_n * sizeof(float))); - asm volatile("fsw f3, (t6)"); - asm volatile("add t6, t6, %0" :: "r"(dim_n * sizeof(float))); - asm volatile("fsw f4, (t6)"); - asm volatile("add t6, t6, %0" :: "r"(dim_n * sizeof(float))); - asm volatile("fsw f5, (t6)"); - asm volatile("add t6, t6, %0" :: "r"(dim_n * sizeof(float))); - asm volatile("fsw f6, (t6)"); - asm volatile("add t6, t6, %0" :: "r"(dim_n * sizeof(float))); - asm volatile("fsw f7, (t6)"); - asm volatile("add t6, t6, %0" :: "r"(dim_n * sizeof(float))); - asm volatile("fsw f8, (t6)"); - asm volatile("add t6, t6, %0" :: "r"(dim_n * sizeof(float))); - asm volatile("fsw f9, (t6)"); - asm volatile("add t6, t6, %0" :: "r"(dim_n * sizeof(float))); - asm volatile("fsw f10, (t6)"); - asm volatile("add t6, t6, %0" :: "r"(dim_n * sizeof(float))); - asm volatile("fsw f11, (t6)"); - asm volatile("add t6, t6, %0" :: "r"(dim_n * sizeof(float))); - asm volatile("fsw f12, (t6)"); - asm volatile("add t6, t6, %0" :: "r"(dim_n * sizeof(float))); - asm volatile("fsw f13, (t6)"); - asm volatile("add t6, t6, %0" :: "r"(dim_n * sizeof(float))); - asm volatile("fsw f14, (t6)"); - asm volatile("add t6, t6, %0" :: "r"(dim_n * sizeof(float))); - asm volatile("fsw f15, (t6)"); - asm volatile("add t6, t6, %0" :: "r"(dim_n * sizeof(float))); + const uint32_t warp_id_in_threadblock = + tid_in_threadblock / NUM_THREADS; + const uint32_t warp_row = warp_id_in_threadblock / (BN / WN); + const uint32_t warp_col = warp_id_in_threadblock % (BN / WN); + activate_block(dim_n, C, block_m, block_n, warp_row, warp_col, + tid_in_threadblock); } // global barrier that synchronizes both warpgroups at every M-N