From dab9d7c6fcd649770953628193e61c087eb33bf8 Mon Sep 17 00:00:00 2001 From: Hansung Kim Date: Tue, 11 Jun 2024 14:09:31 -0700 Subject: [PATCH] sgemm_tcore: Fix kernel launch for smaller TBs than cluster threads E.g. bm32bn32bk32wm16wn8 --- tests/regression/sgemm_tcore/kernel.cpp | 137 ++++++++++++++++-------- 1 file changed, 92 insertions(+), 45 deletions(-) diff --git a/tests/regression/sgemm_tcore/kernel.cpp b/tests/regression/sgemm_tcore/kernel.cpp index 5f95eed4..db9114d1 100644 --- a/tests/regression/sgemm_tcore/kernel.cpp +++ b/tests/regression/sgemm_tcore/kernel.cpp @@ -10,28 +10,6 @@ #define NUM_LANES 8 -#if SMEM_SIZE != 0x4000 -#error Currently only supports 16K spad -#endif -#define SMEM_ADDR_Q0 ((float * const) 0xff000000) -#define SMEM_ADDR_Q1 ((float * const) 0xff001000) -#define SMEM_ADDR_Q2 ((float * const) 0xff002000) -#define SMEM_ADDR_Q3 ((float * const) 0xff003000) -#define SPAD_ADDR_Q0 0x0 -#define SPAD_ADDR_Q1 0x80 -#define SPAD_ADDR_Q2 0x100 -#define SPAD_ADDR_Q3 0x180 - -// number of loop around the inner 0..TCK..BK loop to simulate perfect-DRAM -// scenario -#define BK_LOOP 1 -#define TRANSPOSE_AS 1 -// GMEM_COALESCED sets bank conflict-free accesses for -// 1: GMEM loads of A matrix -// 0: SMEM stores of A matrix -#define GMEM_COALESCED_A 1 -#define GEMMINI_DMA 1 - // Constraints on parameters: // * Memory: // (BM + BN) * BK * sizeof(float) <= sharedmem size. @@ -56,6 +34,27 @@ #define WNITER (WN / TCN) #define ELEM_PER_THREAD (WMITER * WNITER * (TCM * TCN) / NUM_LANES) +// number of loop around the inner 0..TCK..BK loop to simulate perfect-DRAM +// scenario +#define BK_LOOP 1 +#define TRANSPOSE_AS 1 +// GMEM_COALESCED sets bank conflict-free accesses for +// 1: GMEM loads of A matrix +// 0: SMEM stores of A matrix +#define GMEM_COALESCED_A 1 +#define GEMMINI_DMA 0 +#if SMEM_SIZE != 0x4000 +#error Currently only supports 16K spad +#endif +#define SMEM_ADDR_Q0 ((float * const) 0xff000000) +#define SMEM_ADDR_Q1 ((float * const) 0xff001000) +#define SMEM_ADDR_Q2 ((float * const) 0xff002000) +#define SMEM_ADDR_Q3 ((float * const) 0xff003000) +#define SPAD_ADDR_Q0 0x0 +#define SPAD_ADDR_Q1 0x80 +#define SPAD_ADDR_Q2 0x100 +#define SPAD_ADDR_Q3 0x180 + // FIXME: NUM_THREADS and NUM_WARPS hardcoded #if ((BM * BN / ELEM_PER_THREAD) > (CORES_PER_CLUSTER * 8 * 8)) #error "threadblock size too big for cluster" @@ -612,12 +611,45 @@ inline void thread_block_gemm(kernel_arg_t *__UNIFORM__ arg, initialize_C(0); initialize_C(1); - // NOTE: this *should* be signed integer to trigger arithmetic - // right-shift - int32_t k_index = 0; + if constexpr (GEMMINI_DMA) { + // pipeline initiation + if (tid_in_threadblock == 0) { + // configure dma gmem address to load from + // FIXME: block_k is wrong + ROCC_INSTRUCTION_RS1_RS2( + XCUSTOM_ACC, + (uint64_t)(A + block_m * BM * dim_k + /*block_k:*/0 * BK), + (uint64_t)(B + /*block_k:*/0 * BK * dim_n + block_n * BN), + k_LOOP_WS_CONFIG_ADDRS_AB) + // GEMMINI_CISC(8) does k_LOOP_WS_CONFIG_STRIDES_AB + GEMMINI_CISC_CMD_R((dim_n << 16) | (dim_k << 8) | 8); + gemmini_fence(); + + // GEMMINI_CISC_CMD_I(12); + // gemmini_fence(); + + // sp_tiled_matmul_full_spad_ws includes CONFIG_BOUNDS + // FIXME: block_k is 0 for two times + sp_tiled_matmul_full_spad_ws( +#if 1 + SPAD_ADDR_Q0, SPAD_ADDR_Q1, +#else + (/*block_k:*/ 0 & 1) ? SPAD_ADDR_Q2 : SPAD_ADDR_Q0, + (/*block_k:*/ 0 & 1) ? SPAD_ADDR_Q3 : SPAD_ADDR_Q1, +#endif + /*spad_D=*/0, /*spad_C=*/SPAD_ADDR_Q3, + /*I=*/BM / DIM, /*J=*/BN / DIM, /*K=*/BK / DIM, /*pad_I=*/0, + /*pad_J=*/0, /*pad_K=*/0, + /*a_transpose=*/0, /*b_transpose=*/0, /*full_C=*/0, /*low_D=*/0, + /*acc=*/0, /*act=*/NO_ACTIVATION, /*skips=*/skips) + gemmini_fence(); + } + + threadblock_barrier(0 /*threadblock_id_in_cluster*/, threadblock_dim_y); + } + #pragma GCC unroll 1 for (uint32_t block_k = 0; (block_k * BK) < (dim_k); block_k++) { - k_index++; // producer code: GMEM->SMEM memory movement // --------------------------------------------------------------------- @@ -635,9 +667,14 @@ inline void thread_block_gemm(kernel_arg_t *__UNIFORM__ arg, k_LOOP_WS_CONFIG_ADDRS_AB) // GEMMINI_CISC(8) does k_LOOP_WS_CONFIG_STRIDES_AB GEMMINI_CISC_CMD_R((dim_n << 16) | (dim_k << 8) | 8); - // gemmini_fence(); - GEMMINI_CISC_CMD_I(13); + + // TODO: this is probably slow + // if (block_k & 1) { + // GEMMINI_CISC_CMD_I(12); + // } else { // block_k == 0 is here + // GEMMINI_CISC_CMD_I(13); + // } // configure loop iteration bounds // FIXME: shouldn't be necessary @@ -730,22 +767,27 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { // @perf: All threads are running these compute whose result is mostly same // across the threadblock - // const uint32_t threads_per_threadblock = (BM * BN) / (ELEM_PER_THREAD); #ifdef RADIANCE - const uint32_t threads_per_threadblock = - CORES_PER_CLUSTER * vx_num_threads() * vx_num_warps(); - const uint32_t threadblocks_per_core = CORES_PER_CLUSTER * vx_num_threads() * - vx_num_warps() / - threads_per_threadblock; + constexpr uint32_t cores_per_cluster = CORES_PER_CLUSTER; #else - const uint32_t threads_per_threadblock = vx_num_threads() * vx_num_warps(); - const uint32_t threadblocks_per_core = - vx_num_threads() * vx_num_warps() / threads_per_threadblock; + constexpr uint32_t cores_per_cluster = 1; #endif - const uint32_t threadblock_dim_x = vx_num_threads(); - const uint32_t threadblock_dim_y = vx_num_warps() / threadblocks_per_core; + + uint32_t threads_per_threadblock = (BM * BN) / (ELEM_PER_THREAD); + const uint32_t hw_threads_per_cluster = + cores_per_cluster * vx_num_threads() * vx_num_warps(); + // cap maximum threadblock size to # of HW threads in cluster, to prevent + // multiple "wave" invocations which slows down the kernel + if (threads_per_threadblock > hw_threads_per_cluster) { + threads_per_threadblock = hw_threads_per_cluster; + } + const uint32_t threadblocks_per_cluster = + hw_threads_per_cluster / threads_per_threadblock; + + const uint32_t threadblock_dim_y = vx_num_warps() / threadblocks_per_cluster; const int threadblock_id = task_id / threads_per_threadblock; - const int threadblock_id_in_cluster = threadblock_id % threadblocks_per_core; + const int threadblock_id_in_cluster = + threadblock_id % threadblocks_per_cluster; const int tid_in_threadblock = task_id % threads_per_threadblock; const uint32_t dim_m = arg->dim_m; @@ -761,18 +803,23 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { const int warp_id = vx_warp_id(); thread_block_gemm(arg, tid_in_threadblock, threads_per_threadblock, - threadblock_dim_y, /*threadblock_id_x, - threadblock_id_y,*/ /*threadblock_id_in_cluster, */ + threadblock_dim_y, + /*threadblock_id_x, + threadblock_id_y,*/ /*threadblock_id_in_cluster, */ sharedmem_per_threadblock); } int main() { kernel_arg_t *arg = (kernel_arg_t *)KERNEL_ARG_DEV_MEM_ADDR; - const uint32_t threads_per_cluster = + const uint32_t problem_size = (arg->dim_m * arg->dim_n) / (ELEM_PER_THREAD); + const uint32_t hw_threads_per_cluster = CORES_PER_CLUSTER * vx_num_threads() * vx_num_warps(); - // const uint32_t grid_size = arg->dim_m * arg->dim_n / ELEM_PER_THREAD; - const uint32_t grid_size = threads_per_cluster; + // prevent launching more threads than the necessary problem size + // TODO: this does not take into account multiple clusters + const uint32_t grid_size = (problem_size > hw_threads_per_cluster) + ? hw_threads_per_cluster + : problem_size; #ifdef RADIANCE vx_spawn_tasks_cluster(grid_size, (vx_spawn_tasks_cb)kernel_body, arg);