From a9b0814211b760b2b0299614ead326af1e989c46 Mon Sep 17 00:00:00 2001 From: Hansung Kim Date: Thu, 28 Mar 2024 18:17:00 -0700 Subject: [PATCH] sgemm_wg: Document tiling parameter constraints --- tests/regression/sgemm_wg/kernel.cpp | 18 +++++++++++++++--- 1 file changed, 15 insertions(+), 3 deletions(-) diff --git a/tests/regression/sgemm_wg/kernel.cpp b/tests/regression/sgemm_wg/kernel.cpp index 5fc1b8b8..11612db1 100644 --- a/tests/regression/sgemm_wg/kernel.cpp +++ b/tests/regression/sgemm_wg/kernel.cpp @@ -4,11 +4,20 @@ #include #include "common.h" +// Constraints on parameters: +// * Memory: +// (BM + BN) * BK * sizeof(float) <= sharedmem size. +// BM * BK == BN * BK >= threadblock size >= NT * CORES_PER_CLUSTER +// When larger, the kernel runs a sequential loop to read into sharedmem; +// but smaller case is not handled. +// * Compute: +// ( M* N) / (TM*TN) == grid size >= NC*NW*NT +// (BM*BN) / (TM*TN) == threadblock size >= NT * CORES_PER_CLUSTER +// * Combining BM * BK >= (BM*BN) / (TM*TN) == threadblock yields +// BM <= BK*TM*TN. #define BM 8 #define BN BM #define BK 2 -// #define TM (BM/BK) -// #define TN (BN/BK) #define TM 2 #define TN 2 @@ -82,7 +91,9 @@ void thread_block_gemm(kernel_arg_t *__UNIFORM__ arg, threadblock_barrier(tid_in_threadblock, threadblock_id_in_core, threadblock_dim_y); + // Compute single tile*tile matmul for (uint32_t local_k = 0; local_k < BK; local_k++) { + // First, pump data from SMEM->RF #pragma GCC unroll TM for (uint32_t res_idx_m = 0; res_idx_m < TM; res_idx_m++) { reg_a[res_idx_m] = @@ -94,7 +105,7 @@ void thread_block_gemm(kernel_arg_t *__UNIFORM__ arg, local_b[BN * local_k + (TN * local_c_col + res_idx_n)]; } - // Compute multiple result elements (TM) per thread + // Next, compute multiple result elements (TM*TN) by reusing data in RF #pragma GCC unroll TM for (uint32_t res_idx_m = 0; res_idx_m < TM; res_idx_m++) { #pragma GCC unroll TN @@ -113,6 +124,7 @@ void thread_block_gemm(kernel_arg_t *__UNIFORM__ arg, threadblock_dim_y); } + // Store result data from RF to GMEM #pragma GCC unroll TM for (uint32_t res_idx_m = 0; res_idx_m < TM; res_idx_m++) { #pragma GCC unroll TN