diff --git a/tests/regression/sgemm_tcore/kernel.cpp b/tests/regression/sgemm_tcore/kernel.cpp index 2354a3e0..18f52917 100644 --- a/tests/regression/sgemm_tcore/kernel.cpp +++ b/tests/regression/sgemm_tcore/kernel.cpp @@ -8,7 +8,6 @@ #define NUM_LANES 8 -#define USE_TENSOR_CORE 1 // number of loop around the inner 0..TCK..BK loop to simulate perfect-DRAM // scenario #define BK_LOOP 1 @@ -42,14 +41,7 @@ #define TCK 8 #define WMITER (WM / TCM) #define WNITER (WN / TCN) -#if USE_TENSOR_CORE == 1 -#define TM 1 -#define TN ((TCM * TCN) / NUM_LANES / TM) -#else -#define TM 1 -#define TN 1 -#endif -#define ELEM_PER_THREAD (WMITER * WNITER * TM * TN / (DOUBLE_BUFFER ? 2 : 1)) +#define ELEM_PER_THREAD (WMITER * WNITER * ((TCM * TCN) / NUM_LANES) / (DOUBLE_BUFFER ? 2 : 1)) // FIXME: NUM_THREADS and NUM_WARPS hardcoded #if ((BM * BN / ELEM_PER_THREAD) > (CORES_PER_CLUSTER * 8 * 8)) @@ -564,16 +556,6 @@ void thread_block_gemm(kernel_arg_t *__UNIFORM__ arg, const uint32_t local_b_row = tid_in_threadblock / BN; const uint32_t local_b_col = tid_in_threadblock % BN; - const uint32_t local_c_row = tid_in_threadblock / (BN / TN); - const uint32_t local_c_col = tid_in_threadblock % (BN / TN); - -#if !USE_TENSOR_CORE - // each thread generates TM output element - float reg_c[TM * TN] = { 0.0f }; - float reg_a[TM] = { 0.0f }; - float reg_b[TN] = { 0.0f }; -#endif - const uint32_t threads_per_warpgroup = threads_per_threadblock / (DOUBLE_BUFFER ? 2 : 1); const uint32_t warpgroup_id = tid_in_threadblock / threads_per_warpgroup; const uint32_t tid_in_warpgroup = tid_in_threadblock % threads_per_warpgroup; // FIXME @@ -677,41 +659,22 @@ void thread_block_gemm(kernel_arg_t *__UNIFORM__ arg, } k_index++; -#if USE_TENSOR_CORE // @perf: this loop spills to stack a lot because of all the flws in - // vx_wmma_load #pragma GCC unroll 1 for (int i = 0; i < BK_LOOP; i++) { #pragma GCC unroll 2 for (uint32_t local_k = 0; local_k < BK; local_k += TCK) { - // perform wmma - // vx_wmma_load(local_a_consume, local_b_consume, warp_x, warp_y, - // tid_in_warp); - // FIXME: this is wrong!! need separate accumulation register for - // WM/WN_ITERS #pragma GCC unroll 2 for (int wn_iter = 0; wn_iter < WNITER; wn_iter++) { + // SMEM -> RF vx_wmma_load_b(local_b_consume, local_k, warp_col, wn_iter, tid_in_warp); - // vx_wmma_load_b(local_b_consume, 0, 0, 0, tid_in_warp); #pragma GCC unroll 2 for (int wm_iter = 0; wm_iter < WMITER; wm_iter++) { - // if ((threadblock_id_in_cluster % 2) == 0) { - // asm volatile("addi a0, a0, 0"); - // asm volatile("addi a0, a0, 0"); - // asm volatile("addi a0, a0, 0"); - // asm volatile("addi a0, a0, 0"); - // asm volatile("addi a0, a0, 0"); - // asm volatile("addi a0, a0, 0"); - // asm volatile("addi a0, a0, 0"); - // asm volatile("addi a0, a0, 0"); - // asm volatile("addi a0, a0, 0"); - // } // SMEM -> RF vx_wmma_load_a(local_a_consume, local_k, warp_row, wm_iter, tid_in_warp); - // vx_wmma_load_a(local_a_consume, 0, 0, 0, tid_in_warp); - // compute + // perform mma vx_wmma(wm_iter); } } @@ -719,46 +682,8 @@ void thread_block_gemm(kernel_arg_t *__UNIFORM__ arg, } threadblock_barrier(0/*threadblock_id_in_cluster*/, threadblock_dim_y); - -#else - - // Compute single tile*tile matmul -#pragma GCC unroll 4 - for (uint32_t local_k = 0; local_k < BK; local_k++) { - // First, pump data from SMEM->RF -#pragma GCC unroll TM - for (uint32_t res_idx_m = 0; res_idx_m < TM; res_idx_m++) { - reg_a[res_idx_m] = - local_a[BK * (TM * local_c_row + res_idx_m) + local_k]; - } -#pragma GCC unroll TN - for (uint32_t res_idx_n = 0; res_idx_n < TN; res_idx_n++) { - reg_b[res_idx_n] = - local_b[BN * local_k + (TN * local_c_col + res_idx_n)]; - } - - // Next, compute multiple result elements (TM*TN) by reusing data in - // RF -#pragma GCC unroll TM - for (uint32_t res_idx_m = 0; res_idx_m < TM; res_idx_m++) { -#pragma GCC unroll TN - for (uint32_t res_idx_n = 0; res_idx_n < TN; res_idx_n++) { - // NOTE use of local_b_row - reg_c[TN * res_idx_m + res_idx_n] += - reg_a[res_idx_m] * reg_b[res_idx_n]; - // reg_c[TN * res_idx_m + res_idx_n] += - // local_a[BK * (TM * local_c_row + res_idx_m) + local_k] * - // local_b[BN * local_k + (TN * local_c_col + res_idx_n)]; - } - } - } - - threadblock_barrier(tid_in_threadblock, threadblock_id_in_cluster, - threadblock_dim_y); -#endif } -#if USE_TENSOR_CORE #pragma GCC unroll 1 for (int wm_iter = 0; wm_iter < WMITER; wm_iter++) { #pragma GCC unroll 1 @@ -767,18 +692,6 @@ void thread_block_gemm(kernel_arg_t *__UNIFORM__ arg, write_results(tid_in_warp, warp_col, warp_row, wn_iter, wm_iter, dim_n, C, block_n, block_m); } -#else - // Store result data from RF to GMEM -#pragma GCC unroll TM - for (uint32_t res_idx_m = 0; res_idx_m < TM; res_idx_m++) { -#pragma GCC unroll TN - for (uint32_t res_idx_n = 0; res_idx_n < TN; res_idx_n++) { - C[dim_n * (BM * threadblock_id_y + TM * local_c_row + res_idx_m) + - (BN * threadblock_id_x + TN * local_c_col + res_idx_n)] = - reg_c[TN * res_idx_m + res_idx_n]; - } - } -#endif } } }