diff --git a/tests/regression/sgemm_tcore/sgemm_impl.hpp b/tests/regression/sgemm_tcore/sgemm_impl.hpp index ab2a9233..fe4f7586 100644 --- a/tests/regression/sgemm_tcore/sgemm_impl.hpp +++ b/tests/regression/sgemm_tcore/sgemm_impl.hpp @@ -69,7 +69,7 @@ using float_type = float16_t; // generates the NN kernel where both A and B are stored row-major in GMEM. // To model the case where the A matrix is already stored column-major in GMEM, // set both to 0. -#define TRANSPOSE_AT_PRODUCE 1 +#define TRANSPOSE_AT_PRODUCE 0 #define TRANSPOSE_AT_CONSUME 0 #define GEMMINI_DMA 0 @@ -97,6 +97,11 @@ using float_type = float16_t; #error Unsupported smem size #endif +enum class MemLayout { + MN_major, + K_major, +}; + inline constexpr void map_operand_32lanes(const int tid, int &row, int &col) { const int tg = tid / 4; @@ -195,10 +200,10 @@ inline void vx_wmma(const int dest_reg) { } // `local_k` is assumed to be multiple of TCK -template +template inline void wmma_load_a(volatile const T *smem_A, const int local_k, - const int warp_row, const int wm_iter, - const int thread_in_warp) { + const int warp_row, const int wm_iter, + const int thread_in_warp) { asm volatile ("wmma_load_a_start_%=:" :: ); const int tid = thread_in_warp; @@ -219,8 +224,7 @@ inline void wmma_load_a(volatile const T *smem_A, const int local_k, constexpr int BK_adjusted = BK / packed_factor; const int local_k_adjusted = local_k / packed_factor; - if constexpr (TRANSPOSE_AT_CONSUME) { - // A is stored K-major in smem + if constexpr (layout == MemLayout::K_major) { constexpr int smem_A_rows = BM; constexpr int smem_A_cols = BK_adjusted; @@ -242,8 +246,7 @@ inline void wmma_load_a(volatile const T *smem_A, const int local_k, asm volatile("flw f5, %0(%1)" ::"i"(5 * sizeof(float)), "r"(smem_addr)); asm volatile("flw f6, %0(%1)" ::"i"(6 * sizeof(float)), "r"(smem_addr)); asm volatile("flw f7, %0(%1)" ::"i"(7 * sizeof(float)), "r"(smem_addr)); - } else { - // A is stored M-major in smem + } else if (layout == MemLayout::MN_major) { constexpr int smem_AS_rows = BK_adjusted; constexpr int smem_AS_cols = BM; @@ -262,18 +265,25 @@ inline void wmma_load_a(volatile const T *smem_A, const int local_k, asm volatile("flw f5, %0(%1)" :: "i"(smem_AS_cols * 5 * sizeof(float)), "r"(smem_addr)); asm volatile("flw f6, %0(%1)" :: "i"(smem_AS_cols * 6 * sizeof(float)), "r"(smem_addr)); asm volatile("flw f7, %0(%1)" :: "i"(smem_AS_cols * 7 * sizeof(float)), "r"(smem_addr)); + } else { + static_assert(layout == + MemLayout::K_major /* fake cond that is always false */, + "unsupported memory layout"); } asm volatile ("wmma_load_a_finish_%=:" :: ); } // `local_k` is assumed to be multiple of TCK -template +template inline void wmma_load_b(const volatile T *smem_B, const int local_k, const int warp_col, const int wn_iter, const int thread_in_warp) { asm volatile ("wmma_load_b_start_%=:" :: ); + static_assert(layout == MemLayout::MN_major, + "only N-major layout for the B tile is supported"); + const int tid = thread_in_warp; const int tg = tid / 4; @@ -384,11 +394,6 @@ inline void threadblock_barrier(const uint32_t barrier_id, const uint32_t count) vx_barrier(barrier_id, count); } -enum class MemLayout { - MN_major, - K_major, -}; - // Move a single matrix tile from global memory (GMEM) to shared memory (SMEM). // `dim_col`: column dimension of the global matrix. template ? 2 : 1); constexpr uint32_t tile_dim_k_packed = tile_dim_k / packed_factor; @@ -555,6 +559,8 @@ load_tile_to_smem(const uint32_t dim_col, const uint32_t mn_index, // Do a single tile*tile matrix multiplication using the matrix data stored in // SMEM. Useful in fused kernels where GEMMs are done at a per-tile scope. template @@ -577,11 +583,13 @@ thread_block_gemm_single_tile(const T *local_a, const T *local_b, T *local_c, #pragma GCC unroll 2 for (int wn_iter = 0; wn_iter < WNITER; wn_iter++) { // SMEM -> RF - wmma_load_b(local_b, local_k, warp_col, wn_iter, tid_in_warp); + wmma_load_b(local_b, local_k, warp_col, wn_iter, + tid_in_warp); #pragma GCC unroll 2 for (int wm_iter = 0; wm_iter < WMITER; wm_iter++) { // SMEM -> RF - wmma_load_a(local_a, local_k, warp_row, wm_iter, tid_in_warp); + wmma_load_a(local_a, local_k, warp_row, wm_iter, + tid_in_warp); // perform mma vx_wmma(wm_iter); } @@ -845,7 +853,10 @@ inline void thread_block_gemm(const T *A, const T *B, float *C, local_b_consume = local_b; } - thread_block_gemm_single_tile( + constexpr MemLayout layout_a = + TRANSPOSE_AT_CONSUME ? MemLayout::K_major : MemLayout::MN_major; + thread_block_gemm_single_tile( local_a_consume, local_b_consume, static_cast(nullptr) /*ignore*/, tid_in_threadblock, threads_per_threadblock);