From 062403066ef6054679ca96154523fc9a8366dbcd Mon Sep 17 00:00:00 2001 From: Hansung Kim Date: Thu, 6 Jun 2024 15:22:01 -0700 Subject: [PATCH] sgemm_tcore: Bring M/N-loop inside the kernel Instead of spawning multiple threadblocks which comes with stack access overhead, have 1 threadblock work on the entire M/N-space thru a loop. Grid size is fixed to the hardware parallelism. TODO currently only works with 1 cluster in the system. --- tests/regression/sgemm_tcore/kernel.cpp | 346 ++++++++++++------------ 1 file changed, 175 insertions(+), 171 deletions(-) diff --git a/tests/regression/sgemm_tcore/kernel.cpp b/tests/regression/sgemm_tcore/kernel.cpp index 4838e9d8..11187644 100644 --- a/tests/regression/sgemm_tcore/kernel.cpp +++ b/tests/regression/sgemm_tcore/kernel.cpp @@ -9,7 +9,6 @@ #define NUM_LANES 8 #define USE_TENSOR_CORE 1 -#define TC_SINGLE_WARP 0 // number of loop around the inner 0..TCK..BK loop to simulate perfect-DRAM // scenario #define BK_LOOP 1 @@ -267,7 +266,7 @@ inline void initialize_C(const int dest_reg) { inline void write_results(const int thread_in_warp, const int warp_col, const int warp_row, const int wn_iter, - const int wm_iter, const int dim_m, const int dim_n, + const int wm_iter, const int dim_n, float *C, const int threadblock_id_x, const int threadblock_id_y) { int tid = thread_in_warp; @@ -333,12 +332,12 @@ inline void threadblock_barrier(const uint32_t barrier_id, const uint32_t count) // vx_barrier(0, count); } -inline void -global_dmem_load(const uint32_t dim_n, const uint32_t dim_k, const uint32_t k, - const float *A, const float *B, volatile float *local_a, - volatile float *local_b, const uint32_t tid_in_threadblock, - const uint32_t threadblock_id_x, - const uint32_t threadblock_id_y) { +inline void global_dmem_load(const uint32_t dim_n, const uint32_t dim_k, + const uint32_t k, const float *A, const float *B, + volatile float *local_a, volatile float *local_b, + const uint32_t tid_in_threadblock, + const uint32_t threadblock_id_x, + const uint32_t threadblock_id_y) { const uint32_t local_a_row = tid_in_threadblock / BK; const uint32_t local_a_col = tid_in_threadblock % BK; const uint32_t local_as_row = tid_in_threadblock / BM; @@ -546,8 +545,8 @@ void thread_block_gemm(kernel_arg_t *__UNIFORM__ arg, const uint32_t threads_per_threadblock, const uint32_t threadblock_dim_x, const uint32_t threadblock_dim_y, - const uint32_t threadblock_id_x, - const uint32_t threadblock_id_y, + /*const uint32_t threadblock_id_x, + const uint32_t threadblock_id_y,*/ const uint32_t threadblock_id_in_cluster, float *sharedmem_per_threadblock) { const float *A = (const float *)arg->addr_a; @@ -593,198 +592,198 @@ void thread_block_gemm(kernel_arg_t *__UNIFORM__ arg, volatile float *local_a_buf = local_b + local_b_elems; volatile float *local_b_buf = local_a_buf + local_a_elems; - // clear out C - initialize_C(0); - initialize_C(1); - - if constexpr (DOUBLE_BUFFER) { - // initiate software pipeline - if (warpgroup_id == 0) { - global_dmem_load(dim_n, dim_k, 0 /*k*/, A, B, local_a, local_b, - tid_in_warpgroup, threadblock_id_x, threadblock_id_y); - } - - threadblock_barrier(threadblock_id_in_cluster, threadblock_dim_y); - } - if (warpgroup_id == 0) { - // TODO: bring initiation pipeline here - // NOTE: this *should* be signed integer to trigger arithmetic right-shift - int32_t k_index = 0; #pragma GCC unroll 1 - for (uint32_t k = 0; k < dim_k - BK; k += BK) { - volatile float *local_a_produce; - volatile float *local_b_produce; - if constexpr (DOUBLE_BUFFER) { - const uint32_t mask_odd = (k_index & 1) << 31 >> 31; - const uint32_t mask_even = ((k_index & 1) ^ 1) << 31 >> 31; - // local_a_produce = (k_index % 2) ? local_a : local_a_buf; - // local_b_produce = (k_index % 2) ? local_b : local_b_buf; - local_a_produce = reinterpret_cast( - (mask_odd & reinterpret_cast(local_a)) | - (mask_even & reinterpret_cast(local_a_buf))); - local_b_produce = reinterpret_cast( - (mask_odd & reinterpret_cast(local_b)) | - (mask_even & reinterpret_cast(local_b_buf))); - } else { - local_a_produce = local_a; - local_b_produce = local_b; + for (uint32_t block_m = 0; (block_m * BM) < dim_m; block_m++) { +#pragma GCC unroll 1 + for (uint32_t block_n = 0; (block_n * BN) < dim_n; block_n++) { + if constexpr (DOUBLE_BUFFER) { + // initiate software pipeline + global_dmem_load(dim_n, dim_k, 0 /*k*/, A, B, local_a, local_b, + tid_in_warpgroup, block_n, block_m); + + threadblock_barrier(threadblock_id_in_cluster, threadblock_dim_y); + } + + // NOTE: this *should* be signed integer to trigger arithmetic + // right-shift + int32_t k_index = 0; +#pragma GCC unroll 1 + for (uint32_t k = 0; k < (dim_k) - BK; k += BK) { + volatile float *local_a_produce; + volatile float *local_b_produce; + if constexpr (DOUBLE_BUFFER) { + const uint32_t mask_odd = (k_index & 1) << 31 >> 31; + const uint32_t mask_even = ((k_index & 1) ^ 1) << 31 >> 31; + // local_a_produce = (k_index % 2) ? local_a : local_a_buf; + // local_b_produce = (k_index % 2) ? local_b : local_b_buf; + local_a_produce = reinterpret_cast( + (mask_odd & reinterpret_cast(local_a)) | + (mask_even & reinterpret_cast(local_a_buf))); + local_b_produce = reinterpret_cast( + (mask_odd & reinterpret_cast(local_b)) | + (mask_even & reinterpret_cast(local_b_buf))); + } else { + local_a_produce = local_a; + local_b_produce = local_b; + } + k_index++; + + global_dmem_load(dim_n, dim_k, k + BK /*runahead*/, A, B, + local_a_produce, local_b_produce, tid_in_warpgroup, + block_n, block_m); + + threadblock_barrier(threadblock_id_in_cluster, threadblock_dim_y); + } + + // sync with final consumer stage in the k-loop + threadblock_barrier(threadblock_id_in_cluster, threadblock_dim_y); } - k_index++; - - global_dmem_load(dim_n, dim_k, k + BK /*runahead*/, A, B, local_a_produce, - local_b_produce, tid_in_warpgroup, threadblock_id_x, - threadblock_id_y); - - threadblock_barrier(threadblock_id_in_cluster, threadblock_dim_y); } - - threadblock_barrier(threadblock_id_in_cluster, threadblock_dim_y); } else { - // NOTE: this *should* be signed integer to trigger arithmetic right-shift - int32_t k_index = 0; #pragma GCC unroll 1 - for (uint32_t k = 0; k < dim_k; k += BK) { - volatile float *local_a_consume; - volatile float *local_b_consume; - if constexpr (DOUBLE_BUFFER) { - // local_a_consume = (k_index % 2) ? local_a_buf : local_a; - // local_b_consume = (k_index % 2) ? local_b_buf : local_b; - // FIXME: swap multiply with bitshifts - const uint32_t mask_odd = (k_index & 1) << 31 >> 31; - const uint32_t mask_even = ((k_index & 1) ^ 1) << 31 >> 31; - local_a_consume = reinterpret_cast( - (mask_odd & reinterpret_cast(local_a_buf)) | - (mask_even & reinterpret_cast(local_a))); - local_b_consume = reinterpret_cast( - (mask_odd & reinterpret_cast(local_b_buf)) | - (mask_even & reinterpret_cast(local_b))); - } else { - local_a_consume = local_a; - local_b_consume = local_b; - } - k_index++; + for (uint32_t block_m = 0; (block_m * BM) < dim_m; block_m++) { +#pragma GCC unroll 1 + for (uint32_t block_n = 0; (block_n * BN) < dim_n; block_n++) { + // clear out C + initialize_C(0); + initialize_C(1); + + // sync with initial producer stage in the k-loop + threadblock_barrier(threadblock_id_in_cluster, threadblock_dim_y); + + // NOTE: this *should* be signed integer to trigger arithmetic + // right-shift + int32_t k_index = 0; +#pragma GCC unroll 1 + for (uint32_t k = 0; k < (dim_k); k += BK) { + volatile float *local_a_consume; + volatile float *local_b_consume; + if constexpr (DOUBLE_BUFFER) { + // local_a_consume = (k_index % 2) ? local_a_buf : local_a; + // local_b_consume = (k_index % 2) ? local_b_buf : local_b; + // FIXME: swap multiply with bitshifts + const uint32_t mask_odd = (k_index & 1) << 31 >> 31; + const uint32_t mask_even = ((k_index & 1) ^ 1) << 31 >> 31; + local_a_consume = reinterpret_cast( + (mask_odd & reinterpret_cast(local_a_buf)) | + (mask_even & reinterpret_cast(local_a))); + local_b_consume = reinterpret_cast( + (mask_odd & reinterpret_cast(local_b_buf)) | + (mask_even & reinterpret_cast(local_b))); + } else { + local_a_consume = local_a; + local_b_consume = local_b; + } + k_index++; #if USE_TENSOR_CORE - // @perf: this loop spills to stack a lot because of all the flws in - // vx_wmma_load + // @perf: this loop spills to stack a lot because of all the flws in + // vx_wmma_load #pragma GCC unroll 1 - for (int i = 0; i < BK_LOOP; i++) { -#pragma GCC unroll 4 - for (uint32_t local_k = 0; local_k < BK; local_k += TCK) { - // perform wmma - // vx_wmma_load(local_a_consume, local_b_consume, warp_x, warp_y, - // tid_in_warp); - // FIXME: this is wrong!! need separate accumulation register for - // WM/WN_ITERS -#pragma GCC unroll 2 - for (int wn_iter = 0; wn_iter < WNITER; wn_iter++) { - vx_wmma_load_b(local_b_consume, local_k, warp_col, wn_iter, - tid_in_warp); - // vx_wmma_load_b(local_b_consume, 0, 0, 0, tid_in_warp); -#pragma GCC unroll 2 - for (int wm_iter = 0; wm_iter < WMITER; wm_iter++) { -#if TC_SINGLE_WARP - if (warp_in_warpgroup == 0) { -#endif - // if ((threadblock_id_in_cluster % 2) == 0) { - // asm volatile("addi a0, a0, 0"); - // asm volatile("addi a0, a0, 0"); - // asm volatile("addi a0, a0, 0"); - // asm volatile("addi a0, a0, 0"); - // asm volatile("addi a0, a0, 0"); - // asm volatile("addi a0, a0, 0"); - // asm volatile("addi a0, a0, 0"); - // asm volatile("addi a0, a0, 0"); - // asm volatile("addi a0, a0, 0"); - // } - // SMEM -> RF - vx_wmma_load_a(local_a_consume, local_k, warp_row, wm_iter, + for (int i = 0; i < BK_LOOP; i++) { +#pragma GCC unroll 1 + for (uint32_t local_k = 0; local_k < BK; local_k += TCK) { + // perform wmma + // vx_wmma_load(local_a_consume, local_b_consume, warp_x, warp_y, + // tid_in_warp); + // FIXME: this is wrong!! need separate accumulation register for + // WM/WN_ITERS +#pragma GCC unroll 1 + for (int wn_iter = 0; wn_iter < WNITER; wn_iter++) { + vx_wmma_load_b(local_b_consume, local_k, warp_col, wn_iter, tid_in_warp); - // vx_wmma_load_a(local_a_consume, 0, 0, 0, tid_in_warp); - // compute - vx_wmma(wm_iter); -#if TC_SINGLE_WARP + // vx_wmma_load_b(local_b_consume, 0, 0, 0, tid_in_warp); +#pragma GCC unroll 1 + for (int wm_iter = 0; wm_iter < WMITER; wm_iter++) { + // if ((threadblock_id_in_cluster % 2) == 0) { + // asm volatile("addi a0, a0, 0"); + // asm volatile("addi a0, a0, 0"); + // asm volatile("addi a0, a0, 0"); + // asm volatile("addi a0, a0, 0"); + // asm volatile("addi a0, a0, 0"); + // asm volatile("addi a0, a0, 0"); + // asm volatile("addi a0, a0, 0"); + // asm volatile("addi a0, a0, 0"); + // asm volatile("addi a0, a0, 0"); + // } + // SMEM -> RF + vx_wmma_load_a(local_a_consume, local_k, warp_row, wm_iter, + tid_in_warp); + // vx_wmma_load_a(local_a_consume, 0, 0, 0, tid_in_warp); + // compute + vx_wmma(wm_iter); + } } -#endif } } - } - } - threadblock_barrier(threadblock_id_in_cluster, threadblock_dim_y); + threadblock_barrier(threadblock_id_in_cluster, threadblock_dim_y); #else - // Compute single tile*tile matmul + // Compute single tile*tile matmul #pragma GCC unroll 4 - for (uint32_t local_k = 0; local_k < BK; local_k++) { - // First, pump data from SMEM->RF + for (uint32_t local_k = 0; local_k < BK; local_k++) { + // First, pump data from SMEM->RF #pragma GCC unroll TM - for (uint32_t res_idx_m = 0; res_idx_m < TM; res_idx_m++) { - reg_a[res_idx_m] = - local_a[BK * (TM * local_c_row + res_idx_m) + local_k]; - } -#pragma GCC unroll TN - for (uint32_t res_idx_n = 0; res_idx_n < TN; res_idx_n++) { - reg_b[res_idx_n] = - local_b[BN * local_k + (TN * local_c_col + res_idx_n)]; - } - - // Next, compute multiple result elements (TM*TN) by reusing data in - // RF -#pragma GCC unroll TM - for (uint32_t res_idx_m = 0; res_idx_m < TM; res_idx_m++) { + for (uint32_t res_idx_m = 0; res_idx_m < TM; res_idx_m++) { + reg_a[res_idx_m] = + local_a[BK * (TM * local_c_row + res_idx_m) + local_k]; + } #pragma GCC unroll TN for (uint32_t res_idx_n = 0; res_idx_n < TN; res_idx_n++) { - // NOTE use of local_b_row - reg_c[TN * res_idx_m + res_idx_n] += - reg_a[res_idx_m] * reg_b[res_idx_n]; - // reg_c[TN * res_idx_m + res_idx_n] += - // local_a[BK * (TM * local_c_row + res_idx_m) + local_k] * - // local_b[BN * local_k + (TN * local_c_col + res_idx_n)]; + reg_b[res_idx_n] = + local_b[BN * local_k + (TN * local_c_col + res_idx_n)]; + } + + // Next, compute multiple result elements (TM*TN) by reusing data in + // RF +#pragma GCC unroll TM + for (uint32_t res_idx_m = 0; res_idx_m < TM; res_idx_m++) { +#pragma GCC unroll TN + for (uint32_t res_idx_n = 0; res_idx_n < TN; res_idx_n++) { + // NOTE use of local_b_row + reg_c[TN * res_idx_m + res_idx_n] += + reg_a[res_idx_m] * reg_b[res_idx_n]; + // reg_c[TN * res_idx_m + res_idx_n] += + // local_a[BK * (TM * local_c_row + res_idx_m) + local_k] * + // local_b[BN * local_k + (TN * local_c_col + res_idx_n)]; + } } } - } - threadblock_barrier(tid_in_threadblock, threadblock_id_in_cluster, - threadblock_dim_y); + threadblock_barrier(tid_in_threadblock, threadblock_id_in_cluster, + threadblock_dim_y); #endif - } - } + } #if USE_TENSOR_CORE #pragma GCC unroll 1 - for (int wm_iter = 0; wm_iter < WMITER; wm_iter++) { + for (int wm_iter = 0; wm_iter < WMITER; wm_iter++) { #pragma GCC unroll 1 - for (int wn_iter = 0; wn_iter < WNITER; wn_iter++) { -#if TC_SINGLE_WARP - if (warp_in_warpgroup == 0) { -#endif - if (warpgroup_id == 1) { - write_results(tid_in_warp, warp_col, warp_row, wn_iter, wm_iter, - dim_m, dim_n, C, threadblock_id_x, threadblock_id_y); - } -#if TC_SINGLE_WARP - } -#endif - } - } - + for (int wn_iter = 0; wn_iter < WNITER; wn_iter++) { + if (warpgroup_id == 1) { + write_results(tid_in_warp, warp_col, warp_row, wn_iter, wm_iter, + dim_n, C, block_n, block_m); + } #else - - // Store result data from RF to GMEM + // Store result data from RF to GMEM #pragma GCC unroll TM - for (uint32_t res_idx_m = 0; res_idx_m < TM; res_idx_m++) { + for (uint32_t res_idx_m = 0; res_idx_m < TM; res_idx_m++) { #pragma GCC unroll TN - for (uint32_t res_idx_n = 0; res_idx_n < TN; res_idx_n++) { - C[dim_n * (BM * threadblock_id_y + TM * local_c_row + res_idx_m) + - (BN * threadblock_id_x + TN * local_c_col + res_idx_n)] = - reg_c[TN * res_idx_m + res_idx_n]; + for (uint32_t res_idx_n = 0; res_idx_n < TN; res_idx_n++) { + C[dim_n * (BM * threadblock_id_y + TM * local_c_row + res_idx_m) + + (BN * threadblock_id_x + TN * local_c_col + res_idx_n)] = + reg_c[TN * res_idx_m + res_idx_n]; + } + } +#endif + } + } + } } } -#endif - } void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { @@ -819,14 +818,19 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { const int warp_id = vx_warp_id(); thread_block_gemm(arg, tid_in_threadblock, threads_per_threadblock, - threadblock_dim_x, threadblock_dim_y, threadblock_id_x, - threadblock_id_y, threadblock_id_in_cluster, + threadblock_dim_x, threadblock_dim_y, /*threadblock_id_x, + threadblock_id_y,*/ threadblock_id_in_cluster, sharedmem_per_threadblock); } int main() { kernel_arg_t *arg = (kernel_arg_t *)KERNEL_ARG_DEV_MEM_ADDR; - const uint32_t grid_size = arg->dim_m * arg->dim_n / ELEM_PER_THREAD; + + const uint32_t threads_per_cluster = + CORES_PER_CLUSTER * vx_num_threads() * vx_num_warps(); + // const uint32_t grid_size = arg->dim_m * arg->dim_n / ELEM_PER_THREAD; + const uint32_t grid_size = threads_per_cluster; + #ifdef RADIANCE vx_spawn_tasks_cluster(grid_size, (vx_spawn_tasks_cb)kernel_body, arg); #else