From 0534e5d1f6a84b3fdf3285b972cfd9c6f659bf1d Mon Sep 17 00:00:00 2001 From: Hansung Kim Date: Wed, 14 Aug 2024 15:28:52 -0700 Subject: [PATCH] sgemm_tcore: Fix addr gen for GMEM->SMEM for M-major A This fixes correctness for TRANSPOSE_AT_PRODUCE/COLUMN=0/0, provided the matrices are already stored in the correct layout in GMEM. --- tests/regression/sgemm_tcore/kernel.cpp | 42 +++++++++++++------------ tests/regression/sgemm_tcore/util.hpp | 39 +++++++++++++---------- 2 files changed, 45 insertions(+), 36 deletions(-) diff --git a/tests/regression/sgemm_tcore/kernel.cpp b/tests/regression/sgemm_tcore/kernel.cpp index b0a7e9e2..a92d69ff 100644 --- a/tests/regression/sgemm_tcore/kernel.cpp +++ b/tests/regression/sgemm_tcore/kernel.cpp @@ -40,13 +40,16 @@ // using float_type = float; using float_type = float16_t; +// TODO: reduce args by passing leading A/B dimensions template -inline void global_dmem_load(const uint32_t dim_n, const uint32_t dim_k, +inline void global_dmem_load(const uint32_t dim_m, const uint32_t dim_n, const uint32_t dim_k, const uint32_t k, const T *A, const T *B, volatile T *local_a, volatile T *local_b, const uint32_t tid_in_threadblock, const uint32_t threadblock_id_x, const uint32_t threadblock_id_y) { + asm volatile ("global_dmem_load_start_%=:" :: ); + // In fp16 mode, bit-pack two fp16 elements into each fp32 element, and do // data movement at the fp32 granularity. Assuming that the matrix is stored // row-major in GMEM, the packed fp16 pairs belong to the same row, @@ -79,28 +82,28 @@ inline void global_dmem_load(const uint32_t dim_n, const uint32_t dim_k, // move A if constexpr (!TRANSPOSE_AT_PRODUCE) { - // No transpose at GMEM->SMEM movement - // FIXME: !TRANSPOSE_AS code is old - - const uint32_t global_a_row = BM * threadblock_id_y + local_a_row; + // A is stored M-major in GMEM; + // no transpose at GMEM->SMEM movement + const uint32_t block_m = threadblock_id_y; + const uint32_t global_a_row = k_adjusted + local_as_row; + const uint32_t global_a_col = BM * block_m + local_as_col; // number of rows a full TB can read at a time // this is equivalent to threadblock_dim_y (assuming threadblock_dim_x == // BK) - constexpr uint32_t row_stride_a = threads_in_threadblock / BK_adjusted; + constexpr uint32_t row_stride_as = threads_in_threadblock / BM; const float *global_a = reinterpret_cast(A) + - dim_k_adjusted * global_a_row + - (k_adjusted + local_a_col); + dim_m * global_a_row + global_a_col; volatile float *local_a_tmp = reinterpret_cast(local_a) + - BK_adjusted * local_a_row + local_a_col; + BM * local_as_row + local_as_col; #pragma GCC unroll 1 - for (uint32_t local_row_offset = 0; local_row_offset < BM; - local_row_offset += row_stride_a) { + for (uint32_t local_row_offset = 0; local_row_offset < BK_adjusted; + local_row_offset += row_stride_as) { + // TODO: the code GCC generates for below seems fine atm, but unroll to + // assembly to be absolutely sure *local_a_tmp = *global_a; - - // move to the next "row-chunk", when threadblock is smaller than BM*BK - global_a += dim_k_adjusted * row_stride_a; - local_a_tmp += BK_adjusted * row_stride_a; + global_a += dim_m * row_stride_as; + local_a_tmp += BM * row_stride_as; } } else { if constexpr (!GMEM_COALESCED_A) { @@ -126,9 +129,6 @@ inline void global_dmem_load(const uint32_t dim_n, const uint32_t dim_k, // @perf: bank conflicts here // const uint32_t global_a_offset = // dim_k_adjusted * (global_a_row) + (k + local_as_row + local_row_offset); - // FIXME experimenting with global coalescing - // const uint32_t global_a_offset = - // dim_k_adjusted * (global_a_row + local_row_offset) + (k + local_as_col); // local_a[BM * (local_as_row + local_row_offset) + local_as_col] = // A[global_a_offset]; @@ -278,6 +278,8 @@ inline void global_dmem_load(const uint32_t dim_n, const uint32_t dim_k, asm volatile ("fsw ft7, %0(%1)" :: "i"(BN_adjusted * row_stride_b * 1 * sizeof(float)), "r"(local_b_tmp)); local_b_tmp += BN_adjusted * row_stride_b * 2; } + + asm volatile ("global_dmem_load_finish_%=:" :: ); } template @@ -464,8 +466,8 @@ inline void thread_block_gemm(kernel_arg_t *__UNIFORM__ arg, #endif } #else - global_dmem_load(dim_n, dim_k, block_k * BK, A, B, local_a, local_b, - tid_in_threadblock, block_n, block_m); + global_dmem_load(dim_m, dim_n, dim_k, block_k * BK, A, B, local_a, + local_b, tid_in_threadblock, block_n, block_m); threadblock_barrier(threadblock_id_in_cluster, threadblock_dim_y); #endif diff --git a/tests/regression/sgemm_tcore/util.hpp b/tests/regression/sgemm_tcore/util.hpp index e1db7bb8..6f8772f1 100644 --- a/tests/regression/sgemm_tcore/util.hpp +++ b/tests/regression/sgemm_tcore/util.hpp @@ -158,6 +158,8 @@ template inline void vx_wmma_load_a(volatile const T *smem_A, const int local_k, const int warp_row, const int wm_iter, const int thread_in_warp) { + asm volatile ("vx_wmma_load_a_start_%=:" :: ); + const int tid = thread_in_warp; const int tg = tid / 4; @@ -174,17 +176,13 @@ inline void vx_wmma_load_a(volatile const T *smem_A, const int local_k, // by a factor of two. constexpr int packed_factor = (std::is_same_v ? 2 : 1); constexpr int BK_adjusted = BK / packed_factor; - constexpr int BM_adjusted = BM / packed_factor; const int local_k_adjusted = local_k / packed_factor; - constexpr int smem_A_rows = BM; - constexpr int smem_A_cols = BK_adjusted; - constexpr int smem_AS_rows = BK_adjusted; - constexpr int smem_AS_cols = BM; - // constexpr int smem_AS_rows = BK; - // constexpr int smem_AS_cols = BM_adjusted; - if constexpr (TRANSPOSE_AT_CONSUME) { + // A is stored K-major in smem + constexpr int smem_A_rows = BM; + constexpr int smem_A_cols = BK_adjusted; + // int A_offset = (WM * warp_row + TCM * wm_iter + row) * smem_A_cols; // @perf: bank conflicts @@ -205,15 +203,16 @@ inline void vx_wmma_load_a(volatile const T *smem_A, const int local_k, asm volatile("flw f6, %0(%1)" ::"i"(6 * sizeof(float)), "r"(smem_addr)); asm volatile("flw f7, %0(%1)" ::"i"(7 * sizeof(float)), "r"(smem_addr)); } else { - // read smem A tile as-is; bank-conflict-free AS load - // smem A tile is stored column-major - // f8-f15 stores a single row of A + // A is stored M-major in smem + constexpr int smem_AS_rows = BK_adjusted; + constexpr int smem_AS_cols = BM; + const volatile uint8_t *smem_addr; smem_addr = reinterpret_cast( &reinterpret_cast( smem_A)[((local_k_adjusted + 0) * smem_AS_cols) + (WM * warp_row + TCM * wm_iter) + row]); - // step to the next row + // f8-f15 stores a single row of A // threads read from different columns; no bank conflicts asm volatile("flw f0, %0(%1)" :: "i"(smem_AS_cols * 0 * sizeof(float)), "r"(smem_addr)); asm volatile("flw f1, %0(%1)" :: "i"(smem_AS_cols * 1 * sizeof(float)), "r"(smem_addr)); @@ -224,6 +223,8 @@ inline void vx_wmma_load_a(volatile const T *smem_A, const int local_k, asm volatile("flw f6, %0(%1)" :: "i"(smem_AS_cols * 6 * sizeof(float)), "r"(smem_addr)); asm volatile("flw f7, %0(%1)" :: "i"(smem_AS_cols * 7 * sizeof(float)), "r"(smem_addr)); } + + asm volatile ("vx_wmma_load_a_finish_%=:" :: ); } // `local_k` is assumed to be multiple of TCK @@ -231,6 +232,8 @@ template inline void vx_wmma_load_b(const volatile T *smem_B, const int local_k, const int warp_col, const int wn_iter, const int thread_in_warp) { + asm volatile ("vx_wmma_load_b_start_%=:" :: ); + const int tid = thread_in_warp; const int tg = tid / 4; @@ -244,18 +247,16 @@ inline void vx_wmma_load_b(const volatile T *smem_B, const int local_k, constexpr int BN_adjusted = BN / packed_factor; const int local_k_adjusted = local_k / packed_factor; - // constexpr int smem_B_rows = BK; - // constexpr int smem_B_cols = BN_adjusted; + // B is stored N-major in smem constexpr int smem_B_rows = BK_adjusted; constexpr int smem_B_cols = BN; - // f8-f15 stores a single column of B const volatile uint8_t *smem_addr; smem_addr = reinterpret_cast( &reinterpret_cast( smem_B)[((local_k_adjusted + 0) * smem_B_cols) + (WN * warp_col + TCN * wn_iter) + col]); - // step to the next row + // f8-f15 stores a single column of B // threads read from different columns; no bank conflicts asm volatile("flw f8, %0(%1)" :: "i"(smem_B_cols * 0 * sizeof(float)), "r"(smem_addr)); asm volatile("flw f9, %0(%1)" :: "i"(smem_B_cols * 1 * sizeof(float)), "r"(smem_addr)); @@ -265,6 +266,8 @@ inline void vx_wmma_load_b(const volatile T *smem_B, const int local_k, asm volatile("flw f13, %0(%1)" :: "i"(smem_B_cols * 5 * sizeof(float)), "r"(smem_addr)); asm volatile("flw f14, %0(%1)" :: "i"(smem_B_cols * 6 * sizeof(float)), "r"(smem_addr)); asm volatile("flw f15, %0(%1)" :: "i"(smem_B_cols * 7 * sizeof(float)), "r"(smem_addr)); + + asm volatile ("vx_wmma_load_b_finish_%=:" :: ); } inline void initialize_C(const int dest_reg) { @@ -295,6 +298,8 @@ inline void write_results(const int thread_in_warp, const int warp_col, const int wm_iter, const int dim_n, float *C, const int threadblock_id_x, const int threadblock_id_y) { + asm volatile ("write_results_start_%=:" :: ); + int tid = thread_in_warp; // these are [0, TCM/TCN) @@ -342,6 +347,8 @@ inline void write_results(const int thread_in_warp, const int warp_col, asm volatile ("fsw f30, %0(%1)" :: "i"(4 * sizeof(float)), "r"(gmem_addr_tmp)); asm volatile ("fsw f31, %0(%1)" :: "i"(5 * sizeof(float)), "r"(gmem_addr_tmp)); } + + asm volatile ("write_results_finish_%=:" :: ); } inline void threadblock_barrier(const uint32_t barrier_id, const uint32_t count) {