From bdd955836d4ccd151b6cbe71edc25e280e8129d9 Mon Sep 17 00:00:00 2001 From: Hansung Kim Date: Mon, 2 Sep 2024 00:14:35 -0700 Subject: [PATCH] sgemm_impl: Specify leading dimension to wmma load This is necessary when loading a subtile of a full SMEM tile into RF, where that subtile is split along the non-major dimension. --- tests/regression/sgemm_tcore/sgemm_impl.hpp | 88 +++++++++++++++------ 1 file changed, 65 insertions(+), 23 deletions(-) diff --git a/tests/regression/sgemm_tcore/sgemm_impl.hpp b/tests/regression/sgemm_tcore/sgemm_impl.hpp index e77fea35..13818d43 100644 --- a/tests/regression/sgemm_tcore/sgemm_impl.hpp +++ b/tests/regression/sgemm_tcore/sgemm_impl.hpp @@ -200,7 +200,14 @@ inline void vx_wmma(const int dest_reg) { } // `local_k` is assumed to be multiple of TCK -template +template inline void wmma_load_a(volatile const T *smem_A, const int local_k, const int warp_row, const int wm_iter, const int thread_in_warp) { @@ -221,12 +228,10 @@ inline void wmma_load_a(volatile const T *smem_A, const int local_k, // moving a fp32 matrix whose column dimensions (dim_k/BK/k) are compressed // by a factor of two. constexpr int packed_factor = (std::is_same_v ? 
2 : 1); - constexpr int BK_adjusted = BK / packed_factor; const int local_k_adjusted = local_k / packed_factor; if constexpr (layout == MemLayout::K_major) { - constexpr int smem_A_rows = BM; - constexpr int smem_A_cols = BK_adjusted; + constexpr int smem_A_cols = leading_dim; // int A_offset = (WM * warp_row + TCM * wm_iter + row) * smem_A_cols; @@ -247,8 +252,7 @@ inline void wmma_load_a(volatile const T *smem_A, const int local_k, asm volatile("flw f6, %0(%1)" ::"i"(6 * sizeof(float)), "r"(smem_addr)); asm volatile("flw f7, %0(%1)" ::"i"(7 * sizeof(float)), "r"(smem_addr)); } else if constexpr (layout == MemLayout::MN_major) { - constexpr int smem_AS_rows = BK_adjusted; - constexpr int smem_AS_cols = BM; + constexpr int smem_AS_cols = leading_dim; const volatile uint8_t *smem_addr; smem_addr = reinterpret_cast( @@ -274,11 +278,35 @@ inline void wmma_load_a(volatile const T *smem_A, const int local_k, asm volatile ("wmma_load_a_finish_%=:" :: ); } +// Convenience wrapper for wmma_load_a if tile layout is packed, i.e. +// leading_dim == col. +template +inline void wmma_load_a(volatile const T *smem_A, const int local_k, + const int warp_row, const int wm_iter, + const int thread_in_warp) { + // In fp16 mode, bit-pack two fp16 elements into each fp32 element, and do + // data movement at the fp32 granularity. Assuming that the matrix is stored + // row-major in GMEM, the packed fp16 pairs belong to the same row, + // neighboring columns; therefore, it essentially becomes equivalent to + // moving a fp32 matrix whose column dimensions (dim_k/BK/k) are compressed + // by a factor of two. + constexpr int packed_factor = (std::is_same_v ? 2 : 1); + constexpr int tile_dim_k_adjusted = tile_dim_k / packed_factor; + constexpr int leading_dim = (layout == MemLayout::K_major) + ? 
tile_dim_k_adjusted + : tile_dim_m; + + wmma_load_a(smem_A, local_k, warp_row, wm_iter, + thread_in_warp); +} + // `local_k` is assumed to be multiple of TCK -template +template inline void wmma_load_b(const volatile T *smem_B, const int local_k, - const int warp_col, const int wn_iter, - const int thread_in_warp) { + const int warp_col, const int wn_iter, + const int thread_in_warp) { asm volatile ("wmma_load_b_start_%=:" :: ); static_assert(layout == MemLayout::MN_major, @@ -293,12 +321,12 @@ inline void wmma_load_b(const volatile T *smem_B, const int local_k, // see comment in wmma_load_a constexpr int packed_factor = (std::is_same_v ? 2 : 1); - constexpr int BK_adjusted = BK / packed_factor; + constexpr int tile_dim_k_adjusted = tile_dim_k / packed_factor; const int local_k_adjusted = local_k / packed_factor; // B is stored N-major in smem - constexpr int smem_B_rows = BK_adjusted; - constexpr int smem_B_cols = BN; + constexpr int smem_B_rows = tile_dim_k_adjusted; + constexpr int smem_B_cols = tile_dim_n; const volatile uint8_t *smem_addr; smem_addr = reinterpret_cast( @@ -445,6 +473,11 @@ inline void threadblock_barrier(const uint32_t barrier_id, const uint32_t count) // Move a single matrix tile from global memory (GMEM) to shared memory (SMEM). // `dim_major`: major dimension of the matrix in GMEM, e.g. if K-major, K; or // MN-major, M/N. +// +// Note that there's not a single way to specify a layout of the matrix. +// Identifying a matrix to be K-major and specifying the mn_index of a tile, +// is equivalent to identifying it as MN-major and specifying the k_index +// (provided `dim_major` is set accordingly). 
template RF - wmma_load_b(local_b, local_k, warp_col, wn_iter, - tid_in_warp); + static_assert(leading_dim_b == 0, + "leading_dim for wmma_load_b is not implemented yet"); + wmma_load_b( + local_b, local_k, warp_col, wn_iter, tid_in_warp); #pragma GCC unroll 2 for (int wm_iter = 0; wm_iter < WMITER; wm_iter++) { // SMEM -> RF - wmma_load_a(local_a, local_k, warp_row, wm_iter, - tid_in_warp); + if constexpr (leading_dim_a == 0) { + wmma_load_a( + local_a, local_k, warp_row, wm_iter, tid_in_warp); + } else { + wmma_load_a(local_a, local_k, warp_row, + wm_iter, tid_in_warp); + } // perform mma vx_wmma(wm_iter); } @@ -925,7 +967,7 @@ inline void thread_block_gemm(const T *A, const T *B, float *C, constexpr MemLayout layout_a = TRANSPOSE_AT_CONSUME ? MemLayout::K_major : MemLayout::MN_major; thread_block_gemm_single_tile( local_a_consume, local_b_consume,