From 4e723c46558d71a1a033bfdf3409baebfaa31e1d Mon Sep 17 00:00:00 2001 From: Hansung Kim Date: Sat, 1 Jun 2024 01:12:08 -0700 Subject: [PATCH] sgemm_tcore: Support two accumulation reg tiles --- tests/regression/sgemm_tcore/kernel.cpp | 86 ++++++++++++++++--------- 1 file changed, 57 insertions(+), 29 deletions(-) diff --git a/tests/regression/sgemm_tcore/kernel.cpp b/tests/regression/sgemm_tcore/kernel.cpp index 4ded4758..4ac80775 100644 --- a/tests/regression/sgemm_tcore/kernel.cpp +++ b/tests/regression/sgemm_tcore/kernel.cpp @@ -27,10 +27,10 @@ // (BM*BN) / (TM*TN) == threadblock size >= NT * CORES_PER_CLUSTER // * Combining BM * BK >= (BM*BN) / (TM*TN) == threadblock yields // BM <= BK*TM*TN -#define BM 8 -#define BN 8 -#define BK 8 -#define WM 8 +#define BM 32 +#define BN 32 +#define BK 32 +#define WM 16 #define WN 8 #define TCM 8 #define TCN 8 @@ -133,8 +133,12 @@ inline constexpr void map_c(const int tid, int &row, int &col) { } } -inline void vx_wmma() { - asm volatile (".insn r %0, 0, 0, x0, x0, x0" :: "i"(RISCV_CUSTOM3)); +inline void vx_wmma(const int dest_reg) { + if (dest_reg == 0) { + asm volatile (".insn r %0, 0, 0, x0, x0, x0" :: "i"(RISCV_CUSTOM3)); + } else { + asm volatile (".insn r %0, 0, 0, x1, x0, x0" :: "i"(RISCV_CUSTOM3)); + } } // `local_k` is assumed to be multiple of TCK @@ -196,23 +200,35 @@ inline void vx_wmma_load(volatile float *smem_A, volatile float *smem_B, const i asm volatile("flw f15, %0" ::"m"(smem_B[((local_k + 7) * smem_B_cols) + (WN * warp_col + TCN * wn_iter) + col])); } -inline void initialize_C() { +inline void initialize_C(const int dest_reg) { // initialize C to zeros - asm volatile("fmv.w.x f16, x0"); - asm volatile("fmv.w.x f17, x0"); - asm volatile("fmv.w.x f18, x0"); - asm volatile("fmv.w.x f19, x0"); - asm volatile("fmv.w.x f20, x0"); - asm volatile("fmv.w.x f21, x0"); - asm volatile("fmv.w.x f22, x0"); - asm volatile("fmv.w.x f23, x0"); + if (dest_reg == 0) { + asm volatile("fmv.w.x f16, x0"); + asm volatile("fmv.w.x f17, x0"); + asm volatile("fmv.w.x f18, x0"); + asm volatile("fmv.w.x f19, x0"); + asm volatile("fmv.w.x f20, x0"); + asm volatile("fmv.w.x f21, x0"); + asm volatile("fmv.w.x f22, x0"); + asm volatile("fmv.w.x f23, x0"); + } else { + asm volatile("fmv.w.x f24, x0"); + asm volatile("fmv.w.x f25, x0"); + asm volatile("fmv.w.x f26, x0"); + asm volatile("fmv.w.x f27, x0"); + asm volatile("fmv.w.x f28, x0"); + asm volatile("fmv.w.x f29, x0"); + asm volatile("fmv.w.x f30, x0"); + asm volatile("fmv.w.x f31, x0"); + } } inline void write_results(volatile float *local_warp_results, - int thread_in_warp, int warp_col, int warp_row, - int wn_iter, int wm_iter, int dim_m, int dim_n, - float *C, int threadblock_id_x, - int threadblock_id_y) { + const int thread_in_warp, const int warp_col, + const int warp_row, const int wn_iter, + const int wm_iter, const int dim_m, const int dim_n, + float *C, const int threadblock_id_x, + const int threadblock_id_y) { int tid = thread_in_warp; int tg = tid / 4; @@ -229,14 +245,25 @@ inline void write_results(volatile float *local_warp_results, BN * threadblock_id_x; // @perf: this likely causes a lot of gmem bank conflicts - asm volatile ("fsw f16, %0" :: "m"(global_offset_C[dim_n * (local_row + 0) + (local_col + 0)])); - asm volatile ("fsw f17, %0" :: "m"(global_offset_C[dim_n * (local_row + 0) + (local_col + 1)])); - asm volatile ("fsw f18, %0" :: "m"(global_offset_C[dim_n * (local_row + 2) + (local_col + 0)])); - asm volatile ("fsw f19, %0" :: "m"(global_offset_C[dim_n * (local_row + 2) + (local_col + 1)])); - asm volatile ("fsw f20, %0" :: "m"(global_offset_C[dim_n * (local_row + 0) + (local_col + 4)])); - asm volatile ("fsw f21, %0" :: "m"(global_offset_C[dim_n * (local_row + 0) + (local_col + 5)])); - asm volatile ("fsw f22, %0" :: "m"(global_offset_C[dim_n * (local_row + 2) + (local_col + 4)])); - asm volatile ("fsw f23, %0" :: "m"(global_offset_C[dim_n * (local_row + 2) + (local_col + 5)])); + if (wm_iter == 0) { + asm volatile ("fsw f16, %0" :: "m"(global_offset_C[dim_n * (local_row + 0) + (local_col + 0)])); + asm volatile ("fsw f17, %0" :: "m"(global_offset_C[dim_n * (local_row + 0) + (local_col + 1)])); + asm volatile ("fsw f18, %0" :: "m"(global_offset_C[dim_n * (local_row + 2) + (local_col + 0)])); + asm volatile ("fsw f19, %0" :: "m"(global_offset_C[dim_n * (local_row + 2) + (local_col + 1)])); + asm volatile ("fsw f20, %0" :: "m"(global_offset_C[dim_n * (local_row + 0) + (local_col + 4)])); + asm volatile ("fsw f21, %0" :: "m"(global_offset_C[dim_n * (local_row + 0) + (local_col + 5)])); + asm volatile ("fsw f22, %0" :: "m"(global_offset_C[dim_n * (local_row + 2) + (local_col + 4)])); + asm volatile ("fsw f23, %0" :: "m"(global_offset_C[dim_n * (local_row + 2) + (local_col + 5)])); + } else { + asm volatile ("fsw f24, %0" :: "m"(global_offset_C[dim_n * (local_row + 0) + (local_col + 0)])); + asm volatile ("fsw f25, %0" :: "m"(global_offset_C[dim_n * (local_row + 0) + (local_col + 1)])); + asm volatile ("fsw f26, %0" :: "m"(global_offset_C[dim_n * (local_row + 2) + (local_col + 0)])); + asm volatile ("fsw f27, %0" :: "m"(global_offset_C[dim_n * (local_row + 2) + (local_col + 1)])); + asm volatile ("fsw f28, %0" :: "m"(global_offset_C[dim_n * (local_row + 0) + (local_col + 4)])); + asm volatile ("fsw f29, %0" :: "m"(global_offset_C[dim_n * (local_row + 0) + (local_col + 5)])); + asm volatile ("fsw f30, %0" :: "m"(global_offset_C[dim_n * (local_row + 2) + (local_col + 4)])); + asm volatile ("fsw f31, %0" :: "m"(global_offset_C[dim_n * (local_row + 2) + (local_col + 5)])); + } } inline void threadblock_barrier(unsigned int tid_in_threadblock, @@ -349,7 +376,8 @@ void thread_block_gemm(kernel_arg_t *__UNIFORM__ arg, local_b + local_b_elems + (warp_in_threadblock * TCM * TCN); // clear out C - initialize_C(); + initialize_C(0); + initialize_C(1); #pragma GCC unroll 1 for (uint32_t k = 0; k < dim_k; k += BK) { @@ -394,7 +422,7 @@ void thread_block_gemm(kernel_arg_t *__UNIFORM__ arg, vx_wmma_load(local_a, local_b, local_k, warp_col, warp_row, wn_iter, wm_iter, tid_in_warp); // compute - vx_wmma(); + vx_wmma(wm_iter); #if TC_SINGLE_WARP } #endif