From 95b5719847da2a08b13aedc00a02f454aae73160 Mon Sep 17 00:00:00 2001 From: Hansung Kim Date: Wed, 5 Jun 2024 17:14:39 -0700 Subject: [PATCH] sgemm_tcore: Split K-dim loop between consumer/producer ... so that you don't have to run (warpgroup_id == 0) condition at every loop iteration which is expensive due to vx_split/join. --- tests/regression/sgemm_tcore/kernel.cpp | 134 +++++++++++++----------- 1 file changed, 75 insertions(+), 59 deletions(-) diff --git a/tests/regression/sgemm_tcore/kernel.cpp b/tests/regression/sgemm_tcore/kernel.cpp index e5a9cf33..cbd3b1df 100644 --- a/tests/regression/sgemm_tcore/kernel.cpp +++ b/tests/regression/sgemm_tcore/kernel.cpp @@ -557,44 +557,57 @@ void thread_block_gemm(kernel_arg_t *__UNIFORM__ arg, threadblock_barrier(threadblock_id_in_cluster, threadblock_dim_y); } - uint32_t k_index = 0; - + if (warpgroup_id == 0) { + // TODO: bring initiation pipeline here + uint32_t k_index = 0; #pragma GCC unroll 1 - for (uint32_t k = 0; k < dim_k; k += BK) { - // register volatile float *local_a_produce asm("t0"); - // register volatile float *local_b_produce asm("t1"); - // register volatile float *local_a_consume asm("t2"); - // register volatile float *local_b_consume asm("t3"); - volatile float *local_a_produce; - volatile float *local_b_produce; - volatile float *local_a_consume; - volatile float *local_b_consume; - if constexpr (DOUBLE_BUFFER) { - local_a_produce = (k_index % 2) ? local_a : local_a_buf; - local_b_produce = (k_index % 2) ? local_b : local_b_buf; - local_a_consume = (k_index % 2) ? local_a_buf : local_a; - local_b_consume = (k_index % 2) ? local_b_buf : local_b; - // local_a_consume = local_a_produce; - // local_b_consume = local_b_produce; - } else { - local_a_produce = local_a; - local_b_produce = local_b; - local_a_consume = local_a; - local_b_consume = local_b; - } - k_index++; - - if (warpgroup_id == 0) { - if (k != (dim_k - BK)) { - global_dmem_load(dim_n, dim_k, k + BK /*runahead*/, A, B, - local_a_produce, local_b_produce, tid_in_warpgroup, - threadblock_id_x, threadblock_id_y); + for (uint32_t k = 0; k < dim_k - BK; k += BK) { + volatile float *local_a_produce; + volatile float *local_b_produce; + volatile float *local_a_consume; + volatile float *local_b_consume; + if constexpr (DOUBLE_BUFFER) { + local_a_produce = (k_index % 2) ? local_a : local_a_buf; + local_b_produce = (k_index % 2) ? local_b : local_b_buf; + local_a_consume = (k_index % 2) ? local_a_buf : local_a; + local_b_consume = (k_index % 2) ? local_b_buf : local_b; + } else { + local_a_produce = local_a; + local_b_produce = local_b; + local_a_consume = local_a; + local_b_consume = local_b; } + k_index++; + + global_dmem_load(dim_n, dim_k, k + BK /*runahead*/, A, B, local_a_produce, + local_b_produce, tid_in_warpgroup, threadblock_id_x, + threadblock_id_y); threadblock_barrier(threadblock_id_in_cluster, threadblock_dim_y); } - else { + threadblock_barrier(threadblock_id_in_cluster, threadblock_dim_y); + } else { + uint32_t k_index = 0; +#pragma GCC unroll 1 + for (uint32_t k = 0; k < dim_k; k += BK) { + volatile float *local_a_produce; + volatile float *local_b_produce; + volatile float *local_a_consume; + volatile float *local_b_consume; + if constexpr (DOUBLE_BUFFER) { + local_a_produce = (k_index % 2) ? local_a : local_a_buf; + local_b_produce = (k_index % 2) ? local_b : local_b_buf; + local_a_consume = (k_index % 2) ? local_a_buf : local_a; + local_b_consume = (k_index % 2) ? local_b_buf : local_b; + } else { + local_a_produce = local_a; + local_b_produce = local_b; + local_a_consume = local_a; + local_b_consume = local_b; + } + k_index++; + #if USE_TENSOR_CORE // @perf: this loop spills to stack a lot because of all the flws in // vx_wmma_load @@ -603,12 +616,14 @@ void thread_block_gemm(kernel_arg_t *__UNIFORM__ arg, #pragma GCC unroll 1 for (uint32_t local_k = 0; local_k < BK; local_k += TCK) { // perform wmma - // vx_wmma_load(local_a_consume, local_b_consume, warp_x, warp_y, tid_in_warp); + // vx_wmma_load(local_a_consume, local_b_consume, warp_x, warp_y, + // tid_in_warp); // FIXME: this is wrong!! need separate accumulation register for // WM/WN_ITERS #pragma GCC unroll 2 for (int wn_iter = 0; wn_iter < WNITER; wn_iter++) { - vx_wmma_load_b(local_b_consume, local_k, warp_col, wn_iter, tid_in_warp); + vx_wmma_load_b(local_b_consume, local_k, warp_col, wn_iter, + tid_in_warp); // vx_wmma_load_b(local_b_consume, 0, 0, 0, tid_in_warp); #pragma GCC unroll 1 for (int wm_iter = 0; wm_iter < WMITER; wm_iter++) { @@ -641,43 +656,44 @@ void thread_block_gemm(kernel_arg_t *__UNIFORM__ arg, } threadblock_barrier(threadblock_id_in_cluster, threadblock_dim_y); - } #else - // Compute single tile*tile matmul + // Compute single tile*tile matmul #pragma GCC unroll 4 - for (uint32_t local_k = 0; local_k < BK; local_k++) { - // First, pump data from SMEM->RF + for (uint32_t local_k = 0; local_k < BK; local_k++) { + // First, pump data from SMEM->RF #pragma GCC unroll TM - for (uint32_t res_idx_m = 0; res_idx_m < TM; res_idx_m++) { - reg_a[res_idx_m] = - local_a[BK * (TM * local_c_row + res_idx_m) + local_k]; - } + for (uint32_t res_idx_m = 0; res_idx_m < TM; res_idx_m++) { + reg_a[res_idx_m] = + local_a[BK * (TM * local_c_row + res_idx_m) + local_k]; + } #pragma GCC unroll TN - for (uint32_t res_idx_n = 0; res_idx_n < TN; res_idx_n++) { - reg_b[res_idx_n] = - local_b[BN * local_k + (TN * local_c_col + res_idx_n)]; - } + for (uint32_t res_idx_n = 0; res_idx_n < TN; res_idx_n++) { + reg_b[res_idx_n] = + local_b[BN * local_k + (TN * local_c_col + res_idx_n)]; + } - // Next, compute multiple result elements (TM*TN) by reusing data in RF + // Next, compute multiple result elements (TM*TN) by reusing data in + // RF #pragma GCC unroll TM - for (uint32_t res_idx_m = 0; res_idx_m < TM; res_idx_m++) { + for (uint32_t res_idx_m = 0; res_idx_m < TM; res_idx_m++) { #pragma GCC unroll TN - for (uint32_t res_idx_n = 0; res_idx_n < TN; res_idx_n++) { - // NOTE use of local_b_row - reg_c[TN * res_idx_m + res_idx_n] += - reg_a[res_idx_m] * reg_b[res_idx_n]; - // reg_c[TN * res_idx_m + res_idx_n] += - // local_a[BK * (TM * local_c_row + res_idx_m) + local_k] * - // local_b[BN * local_k + (TN * local_c_col + res_idx_n)]; + for (uint32_t res_idx_n = 0; res_idx_n < TN; res_idx_n++) { + // NOTE use of local_b_row + reg_c[TN * res_idx_m + res_idx_n] += + reg_a[res_idx_m] * reg_b[res_idx_n]; + // reg_c[TN * res_idx_m + res_idx_n] += + // local_a[BK * (TM * local_c_row + res_idx_m) + local_k] * + // local_b[BN * local_k + (TN * local_c_col + res_idx_n)]; + } + } } - } - } - threadblock_barrier(tid_in_threadblock, threadblock_id_in_cluster, - threadblock_dim_y); + threadblock_barrier(tid_in_threadblock, threadblock_id_in_cluster, + threadblock_dim_y); #endif + } } #if USE_TENSOR_CORE