diff --git a/tests/regression/sgemm_tcore/kernel.cpp b/tests/regression/sgemm_tcore/kernel.cpp index 6b080369..cec214f1 100644 --- a/tests/regression/sgemm_tcore/kernel.cpp +++ b/tests/regression/sgemm_tcore/kernel.cpp @@ -31,7 +31,6 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { const uint32_t threadblocks_per_cluster = hw_threads_per_cluster / threads_per_threadblock; - const uint32_t threadblock_dim_y = vx_num_warps() / threadblocks_per_cluster; const int threadblock_id = task_id / threads_per_threadblock; const int threadblock_id_in_cluster = threadblock_id % threadblocks_per_cluster; @@ -51,11 +50,10 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { DEV_SMEM_START_ADDR + sizeof(float_type) * 2 /*overkill for non-dma*/ * (2 * BM * BK) * threadblock_id_in_cluster); - thread_block_gemm( - arg, tid_in_threadblock, threads_per_threadblock, threadblock_dim_y, - /*threadblock_id_x, threadblock_id_y,*/ - threadblocks_per_cluster, - // threadblock_id, + thread_block_gemm( + (const float_type *)arg->addr_a, (const float_type *)arg->addr_b, + (float *)arg->addr_c, arg->dim_m, arg->dim_n, arg->dim_k, + tid_in_threadblock, threads_per_threadblock, threadblocks_per_cluster, threadblock_id_in_cluster, sharedmem_per_threadblock); } diff --git a/tests/regression/sgemm_tcore/sgemm_impl.hpp b/tests/regression/sgemm_tcore/sgemm_impl.hpp index a05e886d..8e22b16a 100644 --- a/tests/regression/sgemm_tcore/sgemm_impl.hpp +++ b/tests/regression/sgemm_tcore/sgemm_impl.hpp @@ -28,7 +28,7 @@ #define TCK 16 #define WMITER (WM / TCM) #define WNITER (WN / TCN) -#define ELEM_PER_THREAD (WMITER * WNITER * (TCM * TCN) / NUM_THREADS) +#define ELEM_PER_THREAD (WM * WN / NUM_THREADS) // FIXME: NUM_THREADS and NUM_WARPS hardcoded #if ((BM * BN / ELEM_PER_THREAD) > (CORES_PER_CLUSTER * 8 * 8)) #error "threadblock size too big for cluster" @@ -433,8 +433,6 @@ inline void global_dmem_load(const uint32_t dim_m, const uint32_t dim_n, const u const uint32_t 
global_a_row = k_adjusted + local_as_row; const uint32_t global_a_col = BM * block_m + local_as_col; // number of rows a full TB can read at a time - // this is equivalent to threadblock_dim_y (assuming threadblock_dim_x == - // BK) constexpr uint32_t row_stride_as = threads_in_threadblock / BM; const float *global_a = reinterpret_cast(A) + dim_m * global_a_row + global_a_col; @@ -628,23 +626,15 @@ inline void global_dmem_load(const uint32_t dim_m, const uint32_t dim_n, const u } template -inline void thread_block_gemm(kernel_arg_t *__UNIFORM__ arg, +inline void thread_block_gemm(const T *A, const T *B, float *C, + const uint32_t dim_m, + const uint32_t dim_n, + const uint32_t dim_k, const uint32_t tid_in_threadblock, const uint32_t threads_per_threadblock, - const uint32_t threadblock_dim_y, - /*const uint32_t threadblock_id_x, - const uint32_t threadblock_id_y,*/ const uint32_t threadblocks_per_cluster, const uint32_t threadblock_id_in_cluster, uint8_t *sharedmem_per_threadblock) { - const T *A = (const T *)arg->addr_a; - const T *B = (const T *)arg->addr_b; - float *C = (float *)arg->addr_c; - - const uint32_t dim_m = arg->dim_m; - const uint32_t dim_n = arg->dim_n; - const uint32_t dim_k = arg->dim_k; - const uint32_t local_a_row = tid_in_threadblock / BK; const uint32_t local_a_col = tid_in_threadblock % BK; const uint32_t local_as_row = tid_in_threadblock / BM; @@ -658,6 +648,7 @@ inline void thread_block_gemm(kernel_arg_t *__UNIFORM__ arg, const uint32_t warp_row = warp_id_in_warpgroup / (BN / WN); const uint32_t warp_col = warp_id_in_warpgroup % (BN / WN); const uint32_t tid_in_warp = tid_in_threadblock % NUM_THREADS; + const uint32_t warps_per_threadblock_per_core = NUM_WARPS / threadblocks_per_cluster; /* barrier participant count: NUM_WARPS split across threadblocks_per_cluster, mirroring the former caller-side threadblock_dim_y (vx_num_warps() / threadblocks_per_cluster) */ volatile T *local_a = reinterpret_cast(sharedmem_per_threadblock); constexpr size_t local_a_elems = (BM * BK); @@ -737,7 +728,8 @@ inline void thread_block_gemm(kernel_arg_t *__UNIFORM__ arg, #endif } - threadblock_barrier(threadblock_id_in_cluster, 
threadblock_dim_y); + threadblock_barrier(threadblock_id_in_cluster, + warps_per_threadblock_per_core); } #pragma GCC unroll 1 @@ -814,7 +806,8 @@ inline void thread_block_gemm(kernel_arg_t *__UNIFORM__ arg, global_dmem_load(dim_m, dim_n, dim_k, block_k * BK, A, B, local_a, local_b, tid_in_threadblock, block_n, block_m); - threadblock_barrier(threadblock_id_in_cluster, threadblock_dim_y); + threadblock_barrier(threadblock_id_in_cluster, + warps_per_threadblock_per_core); #endif // consumer code: SMEM->RF and compute @@ -865,14 +858,15 @@ inline void thread_block_gemm(kernel_arg_t *__UNIFORM__ arg, if constexpr (GEMMINI_DMA) { // Call gemmini fence at the end of the loop to overlap dma & wmma. - // Hopefully by this time, dma would have finished so that this is a - // no-op + // Usually, by this time, dma has finished the copy so that this + // becomes a no-op. if (tid_in_threadblock == 0) { gemmini_fence(); } } - threadblock_barrier(threadblock_id_in_cluster, threadblock_dim_y); + threadblock_barrier(threadblock_id_in_cluster, + warps_per_threadblock_per_core); } if constexpr (write_to_gmem) {