sgemm_impl: Drop volatile quanitifier

doesn't seem to do much & creates excessive type errors.
This commit is contained in:
Hansung Kim
2024-08-19 15:19:35 -07:00
parent 1e042af571
commit e93e54cdec

View File

@@ -246,7 +246,7 @@ inline void wmma_load_a(volatile const T *smem_A, const int local_k,
asm volatile("flw f5, %0(%1)" ::"i"(5 * sizeof(float)), "r"(smem_addr));
asm volatile("flw f6, %0(%1)" ::"i"(6 * sizeof(float)), "r"(smem_addr));
asm volatile("flw f7, %0(%1)" ::"i"(7 * sizeof(float)), "r"(smem_addr));
} else if (layout == MemLayout::MN_major) {
} else if constexpr (layout == MemLayout::MN_major) {
constexpr int smem_AS_rows = BK_adjusted;
constexpr int smem_AS_cols = BM;
@@ -395,7 +395,8 @@ inline void threadblock_barrier(const uint32_t barrier_id, const uint32_t count)
}
// Move a single matrix tile from global memory (GMEM) to shared memory (SMEM).
// `dim_col`: column dimension of the global matrix.
// `dim_major`: major dimension of the matrix in GMEM, e.g. if K-major, K; or
// MN-major, M/N.
template <typename T,
MemLayout gmem_layout, // memory layout of the GMEM tile
MemLayout smem_layout, // memory layout of the GMEM tile
@@ -403,7 +404,7 @@ template <typename T,
uint32_t tile_dim_k // column dimension of the SMEM tile
>
__attribute__((always_inline)) inline void
load_tile_to_smem(const uint32_t dim_col, const uint32_t mn_index,
load_tile_to_smem(const uint32_t dim_major, const uint32_t mn_index,
const uint32_t k, const T *global_addr,
volatile T *local_addr, const uint32_t tid_in_threadblock) {
asm volatile("global_dmem_load_start_new_%=:" ::);
@@ -425,8 +426,8 @@ load_tile_to_smem(const uint32_t dim_col, const uint32_t mn_index,
constexpr uint32_t smem_dim_col =
(smem_layout == MemLayout::K_major) ? tile_dim_k_packed : tile_dim_mn;
const uint32_t dim_col_ =
(gmem_layout == MemLayout::K_major) ? dim_col / packed_factor : dim_col;
const uint32_t dim_major_ =
(gmem_layout == MemLayout::K_major) ? dim_major / packed_factor : dim_major;
// FIXME: unsure about this
const uint32_t k_ = k / packed_factor;
@@ -456,7 +457,7 @@ load_tile_to_smem(const uint32_t dim_col, const uint32_t mn_index,
: global_col_mn_major;
const float *global = reinterpret_cast<const float *>(global_addr) +
dim_col_ * global_row + global_col;
dim_major_ * global_row + global_col;
volatile float *local = reinterpret_cast<volatile float *>(local_addr) +
smem_dim_col * local_row_smem + local_col_smem;
@@ -475,26 +476,26 @@ load_tile_to_smem(const uint32_t dim_col, const uint32_t mn_index,
// equivalent code:
//
// *local = *global;
// global += dim_col * row_stride;
// global += dim_major * row_stride;
// local += BN * row_stride;
// read same-column elements into fp registers
asm volatile("flw ft0, (%0)" ::"r"(global));
global += dim_col_ * row_stride;
global += dim_major_ * row_stride;
asm volatile("flw ft1, (%0)" ::"r"(global));
global += dim_col_ * row_stride;
global += dim_major_ * row_stride;
asm volatile("flw ft2, (%0)" ::"r"(global));
global += dim_col_ * row_stride;
global += dim_major_ * row_stride;
asm volatile("flw ft3, (%0)" ::"r"(global));
global += dim_col_ * row_stride;
global += dim_major_ * row_stride;
asm volatile("flw ft4, (%0)" ::"r"(global));
global += dim_col_ * row_stride;
global += dim_major_ * row_stride;
asm volatile("flw ft5, (%0)" ::"r"(global));
global += dim_col_ * row_stride;
global += dim_major_ * row_stride;
asm volatile("flw ft6, (%0)" ::"r"(global));
global += dim_col_ * row_stride;
global += dim_major_ * row_stride;
asm volatile("flw ft7, (%0)" ::"r"(global));
global += dim_col_ * row_stride;
global += dim_major_ * row_stride;
// need to branch because address offset constant in the inline assembly
// cannot be larger than a certain limit
@@ -656,13 +657,11 @@ inline void thread_block_gemm(const T *A, const T *B, float *C,
const uint32_t warps_per_threadblock_per_core =
NUM_WARPS / threads_per_threadblock;
volatile T *local_a =
reinterpret_cast<T *>(sharedmem_per_threadblock + smem_a_offset);
volatile T *local_a_buf =
T *local_a = reinterpret_cast<T *>(sharedmem_per_threadblock + smem_a_offset);
T *local_a_buf =
reinterpret_cast<T *>(sharedmem_per_threadblock + smem_a_dbuf_offset);
volatile T *local_b =
reinterpret_cast<T *>(sharedmem_per_threadblock + smem_b_offset);
volatile T *local_b_buf =
T *local_b = reinterpret_cast<T *>(sharedmem_per_threadblock + smem_b_offset);
T *local_b_buf =
reinterpret_cast<T *>(sharedmem_per_threadblock + smem_b_dbuf_offset);
constexpr uint32_t skips =
@@ -831,18 +830,18 @@ inline void thread_block_gemm(const T *A, const T *B, float *C,
// consumer code: SMEM->RF and compute
// ----------------------------------------------------------------------
// @perf: this loop spills to stack a lot because of all the flws in
const volatile T *local_a_consume;
const volatile T *local_b_consume;
const T *local_a_consume;
const T *local_b_consume;
if constexpr (GEMMINI_DMA) {
// local_a_consume = (k_index % 2) ? local_a_buf : local_a;
// local_b_consume = (k_index % 2) ? local_b_buf : local_b;
// FIXME: swap multiply with bitshifts
// const uint32_t mask_odd = (block_k & 1) << 31 >> 31;
// const uint32_t mask_even = ((block_k & 1) ^ 1) << 31 >> 31;
// local_a_consume = reinterpret_cast<volatile T *>(
// local_a_consume = reinterpret_cast<T *>(
// (mask_odd & reinterpret_cast<uintmax_t>(local_a_buf)) |
// (mask_even & reinterpret_cast<uintmax_t>(local_a)));
// local_b_consume = reinterpret_cast<volatile T *>(
// local_b_consume = reinterpret_cast<T *>(
// (mask_odd & reinterpret_cast<uintmax_t>(local_b_buf)) |
// (mask_even & reinterpret_cast<uintmax_t>(local_b)));
local_a_consume = local_a + (block_k & 1) * (BM * BK);
@@ -858,7 +857,7 @@ inline void thread_block_gemm(const T *A, const T *B, float *C,
thread_block_gemm_single_tile<T, layout_a, MemLayout::MN_major,
/*write_to_smem=*/false>(
local_a_consume, local_b_consume,
static_cast<volatile T *>(nullptr) /*ignore*/, tid_in_threadblock,
static_cast<T *>(nullptr) /*ignore*/, tid_in_threadblock,
threads_per_threadblock);
if constexpr (GEMMINI_DMA) {