From 46b5047775b6338efb7a5c5e98df302c2b9c453b Mon Sep 17 00:00:00 2001 From: Hansung Kim Date: Sun, 18 Aug 2024 22:21:17 -0700 Subject: [PATCH] sgemm_impl: Remove GMEM_COALESCED_A option Uncoalesced GMEM accesses are verified to yield slow performance and the relevant code is not used anymore; remove the cruft --- tests/regression/sgemm_tcore/sgemm_impl.hpp | 157 +++----------------- 1 file changed, 20 insertions(+), 137 deletions(-) diff --git a/tests/regression/sgemm_tcore/sgemm_impl.hpp b/tests/regression/sgemm_tcore/sgemm_impl.hpp index 4b57e28c..9daaae2a 100644 --- a/tests/regression/sgemm_tcore/sgemm_impl.hpp +++ b/tests/regression/sgemm_tcore/sgemm_impl.hpp @@ -71,14 +71,6 @@ using float_type = float16_t; // set both to 0. #define TRANSPOSE_AT_PRODUCE 1 #define TRANSPOSE_AT_CONSUME 0 -// GMEM_COALESCED: When TRANSPOSE_AT_PRODUCE == 1 (i.e. transpose at -// GMEM->SMEM), determines whether we do bank-conflict-free accesses for -// 1: GMEM loads of A matrix, or -// 0: SMEM stores of A matrix. -// -// Usually, GMEM_COALESCED==1 yields better performance since the memory -// behavior of GMEM is more sensitive to bank conflicts. 
-#define GMEM_COALESCED_A 1 #define GEMMINI_DMA 0 #if SMEM_SIZE == 0x4000 @@ -403,8 +395,7 @@ template __attribute__((always_inline)) inline void global_dmem_load_new(const uint32_t dim_col, const uint32_t mn_index, @@ -450,9 +441,6 @@ global_dmem_load_new(const uint32_t dim_col, const uint32_t mn_index, // FIXME: don't hardcode this here constexpr uint32_t threads_per_threadblock = (BM * BN) / ELEM_PER_THREAD; - static_assert(gmem_contiguous == true, - "currently only supports contiguous accesses in GMEM"); - const uint32_t global_row_mn_major = k_ + local_row_gmem; const uint32_t global_col_mn_major = smem_dim_col * mn_index + local_col_gmem; const uint32_t global_row_k_major = gmem_dim_row * mn_index + local_row_gmem; @@ -505,12 +493,9 @@ global_dmem_load_new(const uint32_t dim_col, const uint32_t mn_index, asm volatile("flw ft7, (%0)" ::"r"(global)); global += dim_col_ * row_stride; - // do we need to do transposed write? + // need to branch because address offset constant in the inline assembly + // cannot be larger than a certain limit if constexpr (!transposed_write) { - static_assert(gmem_layout == MemLayout::MN_major); - - // if not, do the same along-the-column accesses for registers as we did - // for gmem asm volatile("fsw ft0, %0(%1)" ::"i"(smem_dim_col * row_stride * 0 * sizeof(float)), "r"(local)); @@ -540,11 +525,11 @@ global_dmem_load_new(const uint32_t dim_col, const uint32_t mn_index, "r"(local)); local += smem_dim_col * row_stride * 2; } else { + // currently, tensor core hardware only supports MN-major SMEM tile + // layout for correct results static_assert(gmem_layout == MemLayout::K_major); static_assert(smem_layout == MemLayout::MN_major); - // if yes, write the registers along the row, doing a transpose - // @perf: this will incur bank conflicts in smem asm volatile("fsw ft0, %0(%1)" ::"i"(row_stride * 0 * sizeof(float)), "r"(local)); asm volatile("fsw ft1, %0(%1)" ::"i"(row_stride * 1 * sizeof(float)), @@ -568,121 +553,6 @@ 
global_dmem_load_new(const uint32_t dim_col, const uint32_t mn_index, asm volatile("global_dmem_load_finish_new_%=:" ::); } -// TODO: reduce args by passing leading A/B dimensions -template -__attribute__((always_inline)) -inline void global_dmem_load(const uint32_t dim_m, const uint32_t dim_n, const uint32_t dim_k, - const uint32_t k, const T *A, const T *B, - volatile T *local_a, volatile T *local_b, - const uint32_t tid_in_threadblock, - const uint32_t threadblock_id_x, - const uint32_t threadblock_id_y) { - asm volatile ("global_dmem_load_start_%=:" :: ); - - // In fp16 mode, bit-pack two fp16 elements into each fp32 element, and do - // data movement at the fp32 granularity. Assuming that the matrix is stored - // row-major in GMEM, the packed fp16 pairs belong to the same row, - // neighboring columns; therefore, it essentially becomes equivalent to - // moving a fp32 matrix whose column dimensions (dim_k/BK/k) are compressed - // by a factor of two. - constexpr uint32_t packed_factor = (std::is_same_v ? 2 : 1); - constexpr uint32_t BK_adjusted = BK / packed_factor; - const uint32_t dim_k_adjusted = dim_k / packed_factor; - const uint32_t k_adjusted = k / packed_factor; - - const uint32_t local_a_row = tid_in_threadblock / BK_adjusted; - const uint32_t local_a_col = tid_in_threadblock % BK_adjusted; - const uint32_t local_as_row = tid_in_threadblock / BM; - const uint32_t local_as_col = tid_in_threadblock % BM; - const uint32_t local_b_row = tid_in_threadblock / BN; - const uint32_t local_b_col = tid_in_threadblock % BN; - - // FIXME: need fix for fp16? - constexpr uint32_t threads_per_threadblock = (BM * BN) / ELEM_PER_THREAD; - - // Data move from GMEM to SMEM - // - // Make sure global offset values for A and B are contiguous between - // neighboring threads to ensure GMEM coalescing. 
- // - // TODO: Sharedmem swizzling is important here - - // move A - if constexpr (!TRANSPOSE_AT_PRODUCE) { - global_dmem_load_new( - dim_m, threadblock_id_y, k, A, local_a, tid_in_threadblock); - } else { - if constexpr (!GMEM_COALESCED_A) { - // !GMEM_COALESCED_A: threads do uncoalesced read from neighboring row in - // GMEM, writes to neighboring cols in SMEM - constexpr uint32_t row_stride_as = threads_per_threadblock / BM; - const uint32_t global_a_row = BM * threadblock_id_y + local_as_col; - const float *global_a = - reinterpret_cast(A) + dim_k_adjusted * global_a_row + (k_adjusted + local_as_row); - volatile float *local_a_tmp = - reinterpret_cast(local_a) + BM * local_as_row + local_as_col; - - static_assert( - row_stride_as * 8 <= BK_adjusted, - "manual loop unrolling condition not met; consider increasing BK"); - static_assert( - (BK_adjusted % (row_stride_as * 8)) == 0, - "manual loop unrolling condition not met; BK should be power-of-two"); - -#pragma GCC unroll 1 - for (uint32_t local_row_offset = 0; local_row_offset < BK_adjusted; - local_row_offset += row_stride_as * 8) { - // const uint32_t global_a_offset = - // dim_k_adjusted * (global_a_row) + (k + local_as_row + local_row_offset); - // local_a[BM * (local_as_row + local_row_offset) + local_as_col] = - // A[global_a_offset]; - - // @perf: bank conflicts - asm volatile ("flw ft0, (%0)" :: "r"(global_a)); - global_a += row_stride_as; - asm volatile ("flw ft1, (%0)" :: "r"(global_a)); - global_a += row_stride_as; - asm volatile ("flw ft2, (%0)" :: "r"(global_a)); - global_a += row_stride_as; - asm volatile ("flw ft3, (%0)" :: "r"(global_a)); - global_a += row_stride_as; - asm volatile ("flw ft4, (%0)" :: "r"(global_a)); - global_a += row_stride_as; - asm volatile ("flw ft5, (%0)" :: "r"(global_a)); - global_a += row_stride_as; - asm volatile ("flw ft6, (%0)" :: "r"(global_a)); - global_a += row_stride_as; - asm volatile ("flw ft7, (%0)" :: "r"(global_a)); - global_a += row_stride_as; - - // NOTE: 
stride is fixed to word size , i.e. sizeof(float) = 4, - // regardless of fp16 or fp32. Since Vortex core does not support fp16, - // load things at word granularity and reinterpret bits inside the - // tensor core. - asm volatile ("fsw ft0, %0(%1)" :: "i"(BM * row_stride_as * 0 * sizeof(float)), "r"(local_a_tmp)); - asm volatile ("fsw ft1, %0(%1)" :: "i"(BM * row_stride_as * 1 * sizeof(float)), "r"(local_a_tmp)); - asm volatile ("fsw ft2, %0(%1)" :: "i"(BM * row_stride_as * 2 * sizeof(float)), "r"(local_a_tmp)); - asm volatile ("fsw ft3, %0(%1)" :: "i"(BM * row_stride_as * 3 * sizeof(float)), "r"(local_a_tmp)); - asm volatile ("fsw ft4, %0(%1)" :: "i"(BM * row_stride_as * 4 * sizeof(float)), "r"(local_a_tmp)); - asm volatile ("fsw ft5, %0(%1)" :: "i"(BM * row_stride_as * 5 * sizeof(float)), "r"(local_a_tmp)); - asm volatile ("fsw ft6, %0(%1)" :: "i"(BM * row_stride_as * 6 * sizeof(float)), "r"(local_a_tmp)); - asm volatile ("fsw ft7, %0(%1)" :: "i"(BM * row_stride_as * 7 * sizeof(float)), "r"(local_a_tmp)); - local_a_tmp += BM * row_stride_as * 8; - } - } else { - global_dmem_load_new(dim_k, threadblock_id_y, k, A, local_a, - tid_in_threadblock); - } - } // end move A - - // move B - global_dmem_load_new( - dim_n, threadblock_id_x, k, B, local_b, tid_in_threadblock); - - asm volatile ("global_dmem_load_finish_%=:" :: ); -} - // Do a single tile*tile matrix multiplication using the matrix data stored in // SMEM. Useful in fused kernels where GEMMs are done at a per-tile scope. 
template (dim_m, dim_n, dim_k, block_k * BK, A, B, local_a, - local_b, tid_in_threadblock, block_n, block_m); + // move A + if constexpr (!TRANSPOSE_AT_PRODUCE) { + global_dmem_load_new(dim_m, block_m, block_k * BK, A, local_a, + tid_in_threadblock); + } else { + global_dmem_load_new(dim_k, block_m, block_k * BK, A, local_a, + tid_in_threadblock); + } + + // move B + global_dmem_load_new(dim_n, block_n, block_k * BK, B, local_b, + tid_in_threadblock); threadblock_barrier(threadblock_id_in_cluster, warps_per_threadblock_per_core);