From b892c22f003a6e2e3233ebea41804a60fcd8cf14 Mon Sep 17 00:00:00 2001 From: Hansung Kim Date: Thu, 16 May 2024 23:31:52 -0700 Subject: [PATCH] sgemm_tcore: Reflect WMITER/WNITER in threadblock size --- tests/regression/sgemm_tcore/kernel.cpp | 37 ++++++++++++++++++------- tests/regression/sgemm_tcore/main.cpp | 6 ++-- 2 files changed, 30 insertions(+), 13 deletions(-) diff --git a/tests/regression/sgemm_tcore/kernel.cpp b/tests/regression/sgemm_tcore/kernel.cpp index 2e95f047..ad741156 100644 --- a/tests/regression/sgemm_tcore/kernel.cpp +++ b/tests/regression/sgemm_tcore/kernel.cpp @@ -20,9 +20,9 @@ // (BM*BN) / (TM*TN) == threadblock size >= NT * CORES_PER_CLUSTER // * Combining BM * BK >= (BM*BN) / (TM*TN) == threadblock yields // BM <= BK*TM*TN -#define BM 32 -#define BN 32 -#define BK 32 +#define BM 16 +#define BN 16 +#define BK 8 #define TCM 8 #define TCN 8 #define TCK 8 @@ -33,12 +33,13 @@ #define TM 1 #define TN ((TCM * TCN) / NUM_LANES / TM) // #define TN 1 +#define ELEM_PER_THREAD (WMITER * WNITER * TM * TN) #define USE_TENSOR_CORE 1 #define TC_SINGLE_WARP 0 // number of loop around the inner 0..TCK..BK loop to simulate perfect-DRAM // scenario -#define BK_LOOP 8 +#define BK_LOOP 16 #define TRANSPOSE_AS 1 inline constexpr void map_operand_32lanes(const int tid, int &row, int &col) { @@ -281,6 +282,7 @@ void thread_block_gemm(kernel_arg_t *__UNIFORM__ arg, // clear out C initialize_C(); +#pragma GCC unroll 1 for (uint32_t k = 0; k < dim_k; k += BK) { // Data move from GMEM to SMEM // @@ -291,7 +293,7 @@ void thread_block_gemm(kernel_arg_t *__UNIFORM__ arg, if constexpr (!TRANSPOSE_AS) { const uint32_t global_a_row = BM * threadblock_id_y + local_a_row; // number of rows a full TB can read at a time - constexpr uint32_t row_stride_a = (BM * BN) / BK / (TM * TN); + constexpr uint32_t row_stride_a = (BM * BN) / ELEM_PER_THREAD / BK; #pragma GCC unroll 1 for (uint32_t load_offset = 0; load_offset < BM; load_offset += row_stride_a) { const uint32_t global_a_offset = @@ -303,7 +305,7 @@ void thread_block_gemm(kernel_arg_t *__UNIFORM__ arg, } } else { const uint32_t global_a_row = BM * threadblock_id_y + local_as_col; - constexpr uint32_t row_stride_a = (BM * BN) / BM / (TM * TN); + constexpr uint32_t row_stride_a = (BM * BN) / ELEM_PER_THREAD / BM; #pragma GCC unroll 1 for (uint32_t load_offset = 0; load_offset < BK; load_offset += row_stride_a) { const uint32_t global_a_offset = @@ -313,7 +315,7 @@ void thread_block_gemm(kernel_arg_t *__UNIFORM__ arg, } } - constexpr uint32_t row_stride_b = (BM * BN) / BN / (TM * TN); + constexpr uint32_t row_stride_b = (BM * BN) / ELEM_PER_THREAD / BN; const uint32_t global_b_col = BN * threadblock_id_x + local_b_col; #pragma GCC unroll 1 for (uint32_t load_offset = 0; load_offset < BK; load_offset += row_stride_b) { @@ -329,8 +331,8 @@ void thread_block_gemm(kernel_arg_t *__UNIFORM__ arg, #if USE_TENSOR_CORE // #pragma GCC unroll 1 for (int i = 0; i < BK_LOOP; i++) { -#pragma GCC unroll 1 // @perf: this loop spills to stack a lot because of all the flws in vx_wmma_load +#pragma GCC unroll 1 for (uint32_t local_k = 0; local_k < BK; local_k += TCK) { // perform wmma // vx_wmma_load(local_a, local_b, warp_x, warp_y, tid_in_warp); @@ -338,11 +340,24 @@ void thread_block_gemm(kernel_arg_t *__UNIFORM__ arg, // does one stall the other? // FIXME: this is wrong!! need separate accumulation register for // WM/WN_ITERS +#pragma GCC unroll 1 for (int wm_iter = 0; wm_iter < WMITER; wm_iter++) { +#pragma GCC unroll 1 for (int wn_iter = 0; wn_iter < WNITER; wn_iter++) { #if TC_SINGLE_WARP if (warp_in_threadblock == 0) { #endif + // if ((threadblock_id_in_cluster % 2) == 0) { + // asm volatile("addi a0, a0, 0"); + // asm volatile("addi a0, a0, 0"); + // asm volatile("addi a0, a0, 0"); + // asm volatile("addi a0, a0, 0"); + // asm volatile("addi a0, a0, 0"); + // asm volatile("addi a0, a0, 0"); + // asm volatile("addi a0, a0, 0"); + // asm volatile("addi a0, a0, 0"); + // asm volatile("addi a0, a0, 0"); + // } // SMEM -> RF vx_wmma_load(local_a, local_b, local_k, warp_col, warp_row, wn_iter, wm_iter, tid_in_warp); @@ -394,7 +409,9 @@ void thread_block_gemm(kernel_arg_t *__UNIFORM__ arg, } #if USE_TENSOR_CORE +#pragma GCC unroll 1 for (int wm_iter = 0; wm_iter < WMITER; wm_iter++) { +#pragma GCC unroll 1 for (int wn_iter = 0; wn_iter < WNITER; wn_iter++) { #if TC_SINGLE_WARP if (warp_in_threadblock == 0) { @@ -428,7 +445,7 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { // @perf: All threads are running these compute whose result is mostly same // across the threadblock - const uint32_t threads_per_threadblock = (BM * BN) / (TM * TN); + const uint32_t threads_per_threadblock = (BM * BN) / (ELEM_PER_THREAD); #ifdef RADIANCE const uint32_t threadblocks_per_core = CORES_PER_CLUSTER * vx_num_threads() * vx_num_warps() / threads_per_threadblock; @@ -460,7 +477,7 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { int main() { kernel_arg_t *arg = (kernel_arg_t *)KERNEL_ARG_DEV_MEM_ADDR; - const uint32_t grid_size = arg->dim_m * arg->dim_n / (TM * TN); + const uint32_t grid_size = arg->dim_m * arg->dim_n / ELEM_PER_THREAD; #ifdef RADIANCE vx_spawn_tasks_cluster(grid_size, (vx_spawn_tasks_cb)kernel_body, arg); #else diff --git a/tests/regression/sgemm_tcore/main.cpp b/tests/regression/sgemm_tcore/main.cpp index 5294f27b..e34b7066 100644 --- a/tests/regression/sgemm_tcore/main.cpp +++ b/tests/regression/sgemm_tcore/main.cpp @@ -147,9 +147,9 @@ int main(int argc, char *argv[]) { RT_CHECK(vx_dev_open(&device)); // FIXME: hardcoded - uint32_t dim_m = 32; - uint32_t dim_n = 32; - uint32_t dim_k = 32; + uint32_t dim_m = 64; + uint32_t dim_n = 64; + uint32_t dim_k = 64; generate_source_matrix(dim_m, dim_n, dim_k); generate_reference_matmul(dim_m, dim_n, dim_k);