diff --git a/tests/regression/sgemm_tcore/kernel.cpp b/tests/regression/sgemm_tcore/kernel.cpp index 090df810..e10a3c0d 100644 --- a/tests/regression/sgemm_tcore/kernel.cpp +++ b/tests/regression/sgemm_tcore/kernel.cpp @@ -6,6 +6,9 @@ #include #include "common.h" +#define USE_TENSOR_CORE 1 +#define TC_SINGLE_WARP 0 + #define NUM_LANES 8 // Constraints on parameters: @@ -20,18 +23,19 @@ // (BM*BN) / (TM*TN) == threadblock size >= NT * CORES_PER_CLUSTER // * Combining BM * BK >= (BM*BN) / (TM*TN) == threadblock yields // BM <= BK*TM*TN -#define BM 8 -#define BN BM -#define BK 8 +#define BM 16 +#define BN 16 +#define BK 32 #define TCM 8 #define TCN 8 +#define TCK 8 #define WM 8 #define WN 8 #define WMITER (WM / TCM) #define WNITER (WN / TCN) #define TM 1 -// #define TN ((TCM * TCN) / NUM_LANES / TM) -#define TN 1 +#define TN ((TCM * TCN) / NUM_LANES / TM) +// #define TN 1 inline constexpr void map_operand_32lanes(const int tid, int &row, int &col) { @@ -125,9 +129,10 @@ inline void vx_wmma() { asm volatile (".insn r %0, 0, 0, x0, x0, x0" :: "i"(RISCV_CUSTOM3)); } -void vx_wmma_load(volatile float *smem_A, volatile float *smem_B, int warp_col, - int warp_row, int wn_iter, int wm_iter, - int thread_in_warp) { +// `local_k` is assumed to be multiple of TCK +void vx_wmma_load(volatile float *smem_A, volatile float *smem_B, const int local_k, + const int warp_col, const int warp_row, const int wn_iter, + const int wm_iter, const int thread_in_warp) { int tid = thread_in_warp; int tg = tid / 4; @@ -142,23 +147,24 @@ void vx_wmma_load(volatile float *smem_A, volatile float *smem_B, int warp_col, int A_offset = (row + WM * warp_row + TCM * wm_iter) * smem_A_cols; - asm volatile("flw f0, %0" ::"m"(smem_A[A_offset + 0])); - asm volatile("flw f1, %0" ::"m"(smem_A[A_offset + 1])); - asm volatile("flw f2, %0" ::"m"(smem_A[A_offset + 2])); - asm volatile("flw f3, %0" ::"m"(smem_A[A_offset + 3])); - asm volatile("flw f4, %0" ::"m"(smem_A[A_offset + 4])); - asm volatile("flw f5, %0" 
::"m"(smem_A[A_offset + 5])); - asm volatile("flw f6, %0" ::"m"(smem_A[A_offset + 6])); - asm volatile("flw f7, %0" ::"m"(smem_A[A_offset + 7])); + // @perf: bank conflicts + asm volatile("flw f0, %0" ::"m"(smem_A[A_offset + (local_k + 0)])); + asm volatile("flw f1, %0" ::"m"(smem_A[A_offset + (local_k + 1)])); + asm volatile("flw f2, %0" ::"m"(smem_A[A_offset + (local_k + 2)])); + asm volatile("flw f3, %0" ::"m"(smem_A[A_offset + (local_k + 3)])); + asm volatile("flw f4, %0" ::"m"(smem_A[A_offset + (local_k + 4)])); + asm volatile("flw f5, %0" ::"m"(smem_A[A_offset + (local_k + 5)])); + asm volatile("flw f6, %0" ::"m"(smem_A[A_offset + (local_k + 6)])); + asm volatile("flw f7, %0" ::"m"(smem_A[A_offset + (local_k + 7)])); - asm volatile("flw f8 , %0" ::"m"(smem_B[(0 * smem_B_cols) + (WN * warp_col + TCN * wn_iter) + col])); - asm volatile("flw f9 , %0" ::"m"(smem_B[(1 * smem_B_cols) + (WN * warp_col + TCN * wn_iter) + col])); - asm volatile("flw f10, %0" ::"m"(smem_B[(2 * smem_B_cols) + (WN * warp_col + TCN * wn_iter) + col])); - asm volatile("flw f11, %0" ::"m"(smem_B[(3 * smem_B_cols) + (WN * warp_col + TCN * wn_iter) + col])); - asm volatile("flw f12, %0" ::"m"(smem_B[(4 * smem_B_cols) + (WN * warp_col + TCN * wn_iter) + col])); - asm volatile("flw f13, %0" ::"m"(smem_B[(5 * smem_B_cols) + (WN * warp_col + TCN * wn_iter) + col])); - asm volatile("flw f14, %0" ::"m"(smem_B[(6 * smem_B_cols) + (WN * warp_col + TCN * wn_iter) + col])); - asm volatile("flw f15, %0" ::"m"(smem_B[(7 * smem_B_cols) + (WN * warp_col + TCN * wn_iter) + col])); + asm volatile("flw f8 , %0" ::"m"(smem_B[((local_k + 0) * smem_B_cols) + (WN * warp_col + TCN * wn_iter) + col])); + asm volatile("flw f9 , %0" ::"m"(smem_B[((local_k + 1) * smem_B_cols) + (WN * warp_col + TCN * wn_iter) + col])); + asm volatile("flw f10, %0" ::"m"(smem_B[((local_k + 2) * smem_B_cols) + (WN * warp_col + TCN * wn_iter) + col])); + asm volatile("flw f11, %0" ::"m"(smem_B[((local_k + 3) * smem_B_cols) + (WN * 
warp_col + TCN * wn_iter) + col])); + asm volatile("flw f12, %0" ::"m"(smem_B[((local_k + 4) * smem_B_cols) + (WN * warp_col + TCN * wn_iter) + col])); + asm volatile("flw f13, %0" ::"m"(smem_B[((local_k + 5) * smem_B_cols) + (WN * warp_col + TCN * wn_iter) + col])); + asm volatile("flw f14, %0" ::"m"(smem_B[((local_k + 6) * smem_B_cols) + (WN * warp_col + TCN * wn_iter) + col])); + asm volatile("flw f15, %0" ::"m"(smem_B[((local_k + 7) * smem_B_cols) + (WN * warp_col + TCN * wn_iter) + col])); } inline void initialize_C() { @@ -232,6 +238,14 @@ void thread_block_gemm(kernel_arg_t *__UNIFORM__ arg, const uint32_t global_a_row = BM * threadblock_id_y + local_a_row; const uint32_t global_b_col = BN * threadblock_id_x + local_b_col; + const uint32_t local_c_row = tid_in_threadblock / (BN / TN); + const uint32_t local_c_col = tid_in_threadblock % (BN / TN); + + // each thread accumulates TM * TN output elements + float reg_c[TM * TN] = { 0.0f }; + float reg_a[TM] = { 0.0f }; + float reg_b[TN] = { 0.0f }; + const uint32_t warp_in_threadblock = tid_in_threadblock / NUM_LANES; const uint32_t warp_row = warp_in_threadblock / (BN / WN); const uint32_t warp_col = warp_in_threadblock % (BN / WN); @@ -239,11 +253,9 @@ void thread_block_gemm(kernel_arg_t *__UNIFORM__ arg, volatile float *local_a = sharedmem_per_threadblock; // const size_t local_a_elems = threadblock_dim_x * threadblock_dim_y; - // FIXME: this better be BM * BK, but the GMEM->SMEM load assumes all threads - // in TB participates in the load - const size_t local_a_elems = (BM * BN); + const size_t local_a_elems = (BM * BK); volatile float *local_b = sharedmem_per_threadblock + local_a_elems; - const size_t local_b_elems = (BM * BN); + const size_t local_b_elems = (BK * BN); volatile float *local_warp_results = local_b + local_b_elems + (warp_in_threadblock * TCM * TCN); @@ -281,36 +293,95 @@ void thread_block_gemm(kernel_arg_t *__UNIFORM__ arg, threadblock_barrier(tid_in_threadblock, threadblock_id_in_cluster, 
threadblock_dim_y); - // perform wmma - // vx_wmma_load(local_a, local_b, warp_x, warp_y, tid_in_warp); - // FIXME: If multiple warps try to issue to Tensor Core at the same time, - // does one stall the other? - // FIXME: this is wrong!! need separate accumulation register for - // WM/WN_ITERS - if (warp_in_threadblock == 0) { +#if USE_TENSOR_CORE + for (uint32_t local_k = 0; local_k < BK; local_k += TCK) { + // perform wmma + // vx_wmma_load(local_a, local_b, warp_x, warp_y, tid_in_warp); + // FIXME: If multiple warps try to issue to Tensor Core at the same time, + // does one stall the other? + // FIXME: this is wrong!! need separate accumulation register for + // WM/WN_ITERS for (int wm_iter = 0; wm_iter < WMITER; wm_iter++) { for (int wn_iter = 0; wn_iter < WNITER; wn_iter++) { - vx_wmma_load(local_a, local_b, warp_col, warp_row, wn_iter, wm_iter, - tid_in_warp); - vx_wmma(); +#if TC_SINGLE_WARP + if (warp_in_threadblock == 0) { +#endif + vx_wmma_load(local_a, local_b, local_k, warp_col, warp_row, wn_iter, + wm_iter, tid_in_warp); + vx_wmma(); +#if TC_SINGLE_WARP + } +#endif } } } +#else + + // Compute single tile*tile matmul +#pragma GCC unroll 4 + for (uint32_t local_k = 0; local_k < BK; local_k++) { + // First, pump data from SMEM->RF +#pragma GCC unroll TM + for (uint32_t res_idx_m = 0; res_idx_m < TM; res_idx_m++) { + reg_a[res_idx_m] = + local_a[BK * (TM * local_c_row + res_idx_m) + local_k]; + } +#pragma GCC unroll TN + for (uint32_t res_idx_n = 0; res_idx_n < TN; res_idx_n++) { + reg_b[res_idx_n] = + local_b[BN * local_k + (TN * local_c_col + res_idx_n)]; + } + + // Next, compute multiple result elements (TM*TN) by reusing data in RF +#pragma GCC unroll TM + for (uint32_t res_idx_m = 0; res_idx_m < TM; res_idx_m++) { +#pragma GCC unroll TN + for (uint32_t res_idx_n = 0; res_idx_n < TN; res_idx_n++) { + // NOTE use of local_b_row + reg_c[TN * res_idx_m + res_idx_n] += + reg_a[res_idx_m] * reg_b[res_idx_n]; + // reg_c[TN * res_idx_m + res_idx_n] += + // 
local_a[BK * (TM * local_c_row + res_idx_m) + local_k] * + // local_b[BN * local_k + (TN * local_c_col + res_idx_n)]; + } + } + } +#endif + threadblock_barrier(tid_in_threadblock, threadblock_id_in_cluster, threadblock_dim_y); } - if (warp_in_threadblock == 0) { - for (int wm_iter = 0; wm_iter < WMITER; wm_iter++) { - for (int wn_iter = 0; wn_iter < WNITER; wn_iter++) { - write_results(local_warp_results, tid_in_warp, - warp_col, warp_row, - wn_iter, wm_iter, - dim_m, dim_n, C, threadblock_id_x, threadblock_id_y); +#if USE_TENSOR_CORE + for (int wm_iter = 0; wm_iter < WMITER; wm_iter++) { + for (int wn_iter = 0; wn_iter < WNITER; wn_iter++) { +#if TC_SINGLE_WARP + if (warp_in_threadblock == 0) { +#endif + write_results(local_warp_results, tid_in_warp, warp_col, warp_row, + wn_iter, wm_iter, dim_m, dim_n, C, threadblock_id_x, + threadblock_id_y); +#if TC_SINGLE_WARP } +#endif } } + +#else + + // Store result data from RF to GMEM +#pragma GCC unroll TM + for (uint32_t res_idx_m = 0; res_idx_m < TM; res_idx_m++) { +#pragma GCC unroll TN + for (uint32_t res_idx_n = 0; res_idx_n < TN; res_idx_n++) { + C[dim_n * (BM * threadblock_id_y + TM * local_c_row + res_idx_m) + + (BN * threadblock_id_x + TN * local_c_col + res_idx_n)] = + reg_c[TN * res_idx_m + res_idx_n]; + } + } +#endif + } void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { @@ -340,8 +411,9 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { // "static" shared memory allocation. 
This would determine threadblock // occupancy of a single cluster + // FIXME: the 4x stride over-allocates; a threadblock uses BM*BK + BK*BN floats plus TCM*TCN per warp, so this is padding to keep threadblocks from overlapping float *sharedmem_per_threadblock = - (float *)DEV_SMEM_START_ADDR + (2 * BM * BK) * threadblock_id_in_cluster; + (float *)DEV_SMEM_START_ADDR + (4 * BM * BK) * threadblock_id_in_cluster; thread_block_gemm(arg, tid_in_threadblock, threadblock_dim_x, threadblock_dim_y, threadblock_id_x, threadblock_id_y, threadblock_id_in_cluster, sharedmem_per_threadblock); diff --git a/tests/regression/sgemm_tcore/main.cpp b/tests/regression/sgemm_tcore/main.cpp index e34b7066..5294f27b 100644 --- a/tests/regression/sgemm_tcore/main.cpp +++ b/tests/regression/sgemm_tcore/main.cpp @@ -147,9 +147,9 @@ int main(int argc, char *argv[]) { RT_CHECK(vx_dev_open(&device)); // FIXME: hardcoded - uint32_t dim_m = 64; - uint32_t dim_n = 64; - uint32_t dim_k = 64; + uint32_t dim_m = 32; + uint32_t dim_n = 32; + uint32_t dim_k = 32; generate_source_matrix(dim_m, dim_n, dim_k); generate_reference_matmul(dim_m, dim_n, dim_k);