From e93e54cdec9975086c0777c67db7cf2ccfc626e6 Mon Sep 17 00:00:00 2001 From: Hansung Kim Date: Mon, 19 Aug 2024 15:19:35 -0700 Subject: [PATCH] sgemm_impl: Drop volatile qualifier doesn't seem to do much & creates excessive type errors. --- tests/regression/sgemm_tcore/sgemm_impl.hpp | 51 ++++++++++----------- 1 file changed, 25 insertions(+), 26 deletions(-) diff --git a/tests/regression/sgemm_tcore/sgemm_impl.hpp b/tests/regression/sgemm_tcore/sgemm_impl.hpp index fe4f7586..f500280e 100644 --- a/tests/regression/sgemm_tcore/sgemm_impl.hpp +++ b/tests/regression/sgemm_tcore/sgemm_impl.hpp @@ -246,7 +246,7 @@ inline void wmma_load_a(volatile const T *smem_A, const int local_k, asm volatile("flw f5, %0(%1)" ::"i"(5 * sizeof(float)), "r"(smem_addr)); asm volatile("flw f6, %0(%1)" ::"i"(6 * sizeof(float)), "r"(smem_addr)); asm volatile("flw f7, %0(%1)" ::"i"(7 * sizeof(float)), "r"(smem_addr)); - } else if (layout == MemLayout::MN_major) { + } else if constexpr (layout == MemLayout::MN_major) { constexpr int smem_AS_rows = BK_adjusted; constexpr int smem_AS_cols = BM; @@ -395,7 +395,8 @@ inline void threadblock_barrier(const uint32_t barrier_id, const uint32_t count) } // Move a single matrix tile from global memory (GMEM) to shared memory (SMEM). -// `dim_col`: column dimension of the global matrix. +// `dim_major`: major dimension of the matrix in GMEM, e.g. if K-major, K; or +// MN-major, M/N. template __attribute__((always_inline)) inline void -load_tile_to_smem(const uint32_t dim_col, const uint32_t mn_index, +load_tile_to_smem(const uint32_t dim_major, const uint32_t mn_index, const uint32_t k, const T *global_addr, volatile T *local_addr, const uint32_t tid_in_threadblock) { asm volatile("global_dmem_load_start_new_%=:" ::); @@ -425,8 +426,8 @@ load_tile_to_smem(const uint32_t dim_col, const uint32_t mn_index, constexpr uint32_t smem_dim_col = (smem_layout == MemLayout::K_major) ? 
tile_dim_k_packed : tile_dim_mn; - const uint32_t dim_col_ = - (gmem_layout == MemLayout::K_major) ? dim_col / packed_factor : dim_col; + const uint32_t dim_major_ = + (gmem_layout == MemLayout::K_major) ? dim_major / packed_factor : dim_major; // FIXME: unsure about this const uint32_t k_ = k / packed_factor; @@ -456,7 +457,7 @@ load_tile_to_smem(const uint32_t dim_col, const uint32_t mn_index, : global_col_mn_major; const float *global = reinterpret_cast(global_addr) + - dim_col_ * global_row + global_col; + dim_major_ * global_row + global_col; volatile float *local = reinterpret_cast(local_addr) + smem_dim_col * local_row_smem + local_col_smem; @@ -475,26 +476,26 @@ load_tile_to_smem(const uint32_t dim_col, const uint32_t mn_index, // equivalent code: // // *local = *global; - // global += dim_col * row_stride; + // global += dim_major * row_stride; // local += BN * row_stride; // read same-column elements into fp registers asm volatile("flw ft0, (%0)" ::"r"(global)); - global += dim_col_ * row_stride; + global += dim_major_ * row_stride; asm volatile("flw ft1, (%0)" ::"r"(global)); - global += dim_col_ * row_stride; + global += dim_major_ * row_stride; asm volatile("flw ft2, (%0)" ::"r"(global)); - global += dim_col_ * row_stride; + global += dim_major_ * row_stride; asm volatile("flw ft3, (%0)" ::"r"(global)); - global += dim_col_ * row_stride; + global += dim_major_ * row_stride; asm volatile("flw ft4, (%0)" ::"r"(global)); - global += dim_col_ * row_stride; + global += dim_major_ * row_stride; asm volatile("flw ft5, (%0)" ::"r"(global)); - global += dim_col_ * row_stride; + global += dim_major_ * row_stride; asm volatile("flw ft6, (%0)" ::"r"(global)); - global += dim_col_ * row_stride; + global += dim_major_ * row_stride; asm volatile("flw ft7, (%0)" ::"r"(global)); - global += dim_col_ * row_stride; + global += dim_major_ * row_stride; // need to branch because address offset constant in the inline assembly // cannot be larger than a certain limit @@ 
-656,13 +657,11 @@ inline void thread_block_gemm(const T *A, const T *B, float *C, const uint32_t warps_per_threadblock_per_core = NUM_WARPS / threads_per_threadblock; - volatile T *local_a = - reinterpret_cast(sharedmem_per_threadblock + smem_a_offset); - volatile T *local_a_buf = + T *local_a = reinterpret_cast(sharedmem_per_threadblock + smem_a_offset); + T *local_a_buf = reinterpret_cast(sharedmem_per_threadblock + smem_a_dbuf_offset); - volatile T *local_b = - reinterpret_cast(sharedmem_per_threadblock + smem_b_offset); - volatile T *local_b_buf = + T *local_b = reinterpret_cast(sharedmem_per_threadblock + smem_b_offset); + T *local_b_buf = reinterpret_cast(sharedmem_per_threadblock + smem_b_dbuf_offset); constexpr uint32_t skips = @@ -831,18 +830,18 @@ inline void thread_block_gemm(const T *A, const T *B, float *C, // consumer code: SMEM->RF and compute // ---------------------------------------------------------------------- // @perf: this loop spills to stack a lot because of all the flws in - const volatile T *local_a_consume; - const volatile T *local_b_consume; + const T *local_a_consume; + const T *local_b_consume; if constexpr (GEMMINI_DMA) { // local_a_consume = (k_index % 2) ? local_a_buf : local_a; // local_b_consume = (k_index % 2) ? 
local_b_buf : local_b; // FIXME: swap multiply with bitshifts // const uint32_t mask_odd = (block_k & 1) << 31 >> 31; // const uint32_t mask_even = ((block_k & 1) ^ 1) << 31 >> 31; - // local_a_consume = reinterpret_cast( + // local_a_consume = reinterpret_cast( // (mask_odd & reinterpret_cast(local_a_buf)) | // (mask_even & reinterpret_cast(local_a))); - // local_b_consume = reinterpret_cast( + // local_b_consume = reinterpret_cast( // (mask_odd & reinterpret_cast(local_b_buf)) | // (mask_even & reinterpret_cast(local_b))); local_a_consume = local_a + (block_k & 1) * (BM * BK); @@ -858,7 +857,7 @@ inline void thread_block_gemm(const T *A, const T *B, float *C, thread_block_gemm_single_tile( local_a_consume, local_b_consume, - static_cast(nullptr) /*ignore*/, tid_in_threadblock, + static_cast(nullptr) /*ignore*/, tid_in_threadblock, threads_per_threadblock); if constexpr (GEMMINI_DMA) {