sgemm_impl: Drop volatile quanitifier
doesn't seem to do much & creates excessive type errors.
This commit is contained in:
@@ -246,7 +246,7 @@ inline void wmma_load_a(volatile const T *smem_A, const int local_k,
|
|||||||
asm volatile("flw f5, %0(%1)" ::"i"(5 * sizeof(float)), "r"(smem_addr));
|
asm volatile("flw f5, %0(%1)" ::"i"(5 * sizeof(float)), "r"(smem_addr));
|
||||||
asm volatile("flw f6, %0(%1)" ::"i"(6 * sizeof(float)), "r"(smem_addr));
|
asm volatile("flw f6, %0(%1)" ::"i"(6 * sizeof(float)), "r"(smem_addr));
|
||||||
asm volatile("flw f7, %0(%1)" ::"i"(7 * sizeof(float)), "r"(smem_addr));
|
asm volatile("flw f7, %0(%1)" ::"i"(7 * sizeof(float)), "r"(smem_addr));
|
||||||
} else if (layout == MemLayout::MN_major) {
|
} else if constexpr (layout == MemLayout::MN_major) {
|
||||||
constexpr int smem_AS_rows = BK_adjusted;
|
constexpr int smem_AS_rows = BK_adjusted;
|
||||||
constexpr int smem_AS_cols = BM;
|
constexpr int smem_AS_cols = BM;
|
||||||
|
|
||||||
@@ -395,7 +395,8 @@ inline void threadblock_barrier(const uint32_t barrier_id, const uint32_t count)
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Move a single matrix tile from global memory (GMEM) to shared memory (SMEM).
|
// Move a single matrix tile from global memory (GMEM) to shared memory (SMEM).
|
||||||
// `dim_col`: column dimension of the global matrix.
|
// `dim_major`: major dimension of the matrix in GMEM, e.g. if K-major, K; or
|
||||||
|
// MN-major, M/N.
|
||||||
template <typename T,
|
template <typename T,
|
||||||
MemLayout gmem_layout, // memory layout of the GMEM tile
|
MemLayout gmem_layout, // memory layout of the GMEM tile
|
||||||
MemLayout smem_layout, // memory layout of the GMEM tile
|
MemLayout smem_layout, // memory layout of the GMEM tile
|
||||||
@@ -403,7 +404,7 @@ template <typename T,
|
|||||||
uint32_t tile_dim_k // column dimension of the SMEM tile
|
uint32_t tile_dim_k // column dimension of the SMEM tile
|
||||||
>
|
>
|
||||||
__attribute__((always_inline)) inline void
|
__attribute__((always_inline)) inline void
|
||||||
load_tile_to_smem(const uint32_t dim_col, const uint32_t mn_index,
|
load_tile_to_smem(const uint32_t dim_major, const uint32_t mn_index,
|
||||||
const uint32_t k, const T *global_addr,
|
const uint32_t k, const T *global_addr,
|
||||||
volatile T *local_addr, const uint32_t tid_in_threadblock) {
|
volatile T *local_addr, const uint32_t tid_in_threadblock) {
|
||||||
asm volatile("global_dmem_load_start_new_%=:" ::);
|
asm volatile("global_dmem_load_start_new_%=:" ::);
|
||||||
@@ -425,8 +426,8 @@ load_tile_to_smem(const uint32_t dim_col, const uint32_t mn_index,
|
|||||||
constexpr uint32_t smem_dim_col =
|
constexpr uint32_t smem_dim_col =
|
||||||
(smem_layout == MemLayout::K_major) ? tile_dim_k_packed : tile_dim_mn;
|
(smem_layout == MemLayout::K_major) ? tile_dim_k_packed : tile_dim_mn;
|
||||||
|
|
||||||
const uint32_t dim_col_ =
|
const uint32_t dim_major_ =
|
||||||
(gmem_layout == MemLayout::K_major) ? dim_col / packed_factor : dim_col;
|
(gmem_layout == MemLayout::K_major) ? dim_major / packed_factor : dim_major;
|
||||||
// FIXME: unsure about this
|
// FIXME: unsure about this
|
||||||
const uint32_t k_ = k / packed_factor;
|
const uint32_t k_ = k / packed_factor;
|
||||||
|
|
||||||
@@ -456,7 +457,7 @@ load_tile_to_smem(const uint32_t dim_col, const uint32_t mn_index,
|
|||||||
: global_col_mn_major;
|
: global_col_mn_major;
|
||||||
|
|
||||||
const float *global = reinterpret_cast<const float *>(global_addr) +
|
const float *global = reinterpret_cast<const float *>(global_addr) +
|
||||||
dim_col_ * global_row + global_col;
|
dim_major_ * global_row + global_col;
|
||||||
volatile float *local = reinterpret_cast<volatile float *>(local_addr) +
|
volatile float *local = reinterpret_cast<volatile float *>(local_addr) +
|
||||||
smem_dim_col * local_row_smem + local_col_smem;
|
smem_dim_col * local_row_smem + local_col_smem;
|
||||||
|
|
||||||
@@ -475,26 +476,26 @@ load_tile_to_smem(const uint32_t dim_col, const uint32_t mn_index,
|
|||||||
// equivalent code:
|
// equivalent code:
|
||||||
//
|
//
|
||||||
// *local = *global;
|
// *local = *global;
|
||||||
// global += dim_col * row_stride;
|
// global += dim_major * row_stride;
|
||||||
// local += BN * row_stride;
|
// local += BN * row_stride;
|
||||||
|
|
||||||
// read same-column elements into fp registers
|
// read same-column elements into fp registers
|
||||||
asm volatile("flw ft0, (%0)" ::"r"(global));
|
asm volatile("flw ft0, (%0)" ::"r"(global));
|
||||||
global += dim_col_ * row_stride;
|
global += dim_major_ * row_stride;
|
||||||
asm volatile("flw ft1, (%0)" ::"r"(global));
|
asm volatile("flw ft1, (%0)" ::"r"(global));
|
||||||
global += dim_col_ * row_stride;
|
global += dim_major_ * row_stride;
|
||||||
asm volatile("flw ft2, (%0)" ::"r"(global));
|
asm volatile("flw ft2, (%0)" ::"r"(global));
|
||||||
global += dim_col_ * row_stride;
|
global += dim_major_ * row_stride;
|
||||||
asm volatile("flw ft3, (%0)" ::"r"(global));
|
asm volatile("flw ft3, (%0)" ::"r"(global));
|
||||||
global += dim_col_ * row_stride;
|
global += dim_major_ * row_stride;
|
||||||
asm volatile("flw ft4, (%0)" ::"r"(global));
|
asm volatile("flw ft4, (%0)" ::"r"(global));
|
||||||
global += dim_col_ * row_stride;
|
global += dim_major_ * row_stride;
|
||||||
asm volatile("flw ft5, (%0)" ::"r"(global));
|
asm volatile("flw ft5, (%0)" ::"r"(global));
|
||||||
global += dim_col_ * row_stride;
|
global += dim_major_ * row_stride;
|
||||||
asm volatile("flw ft6, (%0)" ::"r"(global));
|
asm volatile("flw ft6, (%0)" ::"r"(global));
|
||||||
global += dim_col_ * row_stride;
|
global += dim_major_ * row_stride;
|
||||||
asm volatile("flw ft7, (%0)" ::"r"(global));
|
asm volatile("flw ft7, (%0)" ::"r"(global));
|
||||||
global += dim_col_ * row_stride;
|
global += dim_major_ * row_stride;
|
||||||
|
|
||||||
// need to branch because address offset constant in the inline assembly
|
// need to branch because address offset constant in the inline assembly
|
||||||
// cannot be larger than a certain limit
|
// cannot be larger than a certain limit
|
||||||
@@ -656,13 +657,11 @@ inline void thread_block_gemm(const T *A, const T *B, float *C,
|
|||||||
const uint32_t warps_per_threadblock_per_core =
|
const uint32_t warps_per_threadblock_per_core =
|
||||||
NUM_WARPS / threads_per_threadblock;
|
NUM_WARPS / threads_per_threadblock;
|
||||||
|
|
||||||
volatile T *local_a =
|
T *local_a = reinterpret_cast<T *>(sharedmem_per_threadblock + smem_a_offset);
|
||||||
reinterpret_cast<T *>(sharedmem_per_threadblock + smem_a_offset);
|
T *local_a_buf =
|
||||||
volatile T *local_a_buf =
|
|
||||||
reinterpret_cast<T *>(sharedmem_per_threadblock + smem_a_dbuf_offset);
|
reinterpret_cast<T *>(sharedmem_per_threadblock + smem_a_dbuf_offset);
|
||||||
volatile T *local_b =
|
T *local_b = reinterpret_cast<T *>(sharedmem_per_threadblock + smem_b_offset);
|
||||||
reinterpret_cast<T *>(sharedmem_per_threadblock + smem_b_offset);
|
T *local_b_buf =
|
||||||
volatile T *local_b_buf =
|
|
||||||
reinterpret_cast<T *>(sharedmem_per_threadblock + smem_b_dbuf_offset);
|
reinterpret_cast<T *>(sharedmem_per_threadblock + smem_b_dbuf_offset);
|
||||||
|
|
||||||
constexpr uint32_t skips =
|
constexpr uint32_t skips =
|
||||||
@@ -831,18 +830,18 @@ inline void thread_block_gemm(const T *A, const T *B, float *C,
|
|||||||
// consumer code: SMEM->RF and compute
|
// consumer code: SMEM->RF and compute
|
||||||
// ----------------------------------------------------------------------
|
// ----------------------------------------------------------------------
|
||||||
// @perf: this loop spills to stack a lot because of all the flws in
|
// @perf: this loop spills to stack a lot because of all the flws in
|
||||||
const volatile T *local_a_consume;
|
const T *local_a_consume;
|
||||||
const volatile T *local_b_consume;
|
const T *local_b_consume;
|
||||||
if constexpr (GEMMINI_DMA) {
|
if constexpr (GEMMINI_DMA) {
|
||||||
// local_a_consume = (k_index % 2) ? local_a_buf : local_a;
|
// local_a_consume = (k_index % 2) ? local_a_buf : local_a;
|
||||||
// local_b_consume = (k_index % 2) ? local_b_buf : local_b;
|
// local_b_consume = (k_index % 2) ? local_b_buf : local_b;
|
||||||
// FIXME: swap multiply with bitshifts
|
// FIXME: swap multiply with bitshifts
|
||||||
// const uint32_t mask_odd = (block_k & 1) << 31 >> 31;
|
// const uint32_t mask_odd = (block_k & 1) << 31 >> 31;
|
||||||
// const uint32_t mask_even = ((block_k & 1) ^ 1) << 31 >> 31;
|
// const uint32_t mask_even = ((block_k & 1) ^ 1) << 31 >> 31;
|
||||||
// local_a_consume = reinterpret_cast<volatile T *>(
|
// local_a_consume = reinterpret_cast<T *>(
|
||||||
// (mask_odd & reinterpret_cast<uintmax_t>(local_a_buf)) |
|
// (mask_odd & reinterpret_cast<uintmax_t>(local_a_buf)) |
|
||||||
// (mask_even & reinterpret_cast<uintmax_t>(local_a)));
|
// (mask_even & reinterpret_cast<uintmax_t>(local_a)));
|
||||||
// local_b_consume = reinterpret_cast<volatile T *>(
|
// local_b_consume = reinterpret_cast<T *>(
|
||||||
// (mask_odd & reinterpret_cast<uintmax_t>(local_b_buf)) |
|
// (mask_odd & reinterpret_cast<uintmax_t>(local_b_buf)) |
|
||||||
// (mask_even & reinterpret_cast<uintmax_t>(local_b)));
|
// (mask_even & reinterpret_cast<uintmax_t>(local_b)));
|
||||||
local_a_consume = local_a + (block_k & 1) * (BM * BK);
|
local_a_consume = local_a + (block_k & 1) * (BM * BK);
|
||||||
@@ -858,7 +857,7 @@ inline void thread_block_gemm(const T *A, const T *B, float *C,
|
|||||||
thread_block_gemm_single_tile<T, layout_a, MemLayout::MN_major,
|
thread_block_gemm_single_tile<T, layout_a, MemLayout::MN_major,
|
||||||
/*write_to_smem=*/false>(
|
/*write_to_smem=*/false>(
|
||||||
local_a_consume, local_b_consume,
|
local_a_consume, local_b_consume,
|
||||||
static_cast<volatile T *>(nullptr) /*ignore*/, tid_in_threadblock,
|
static_cast<T *>(nullptr) /*ignore*/, tid_in_threadblock,
|
||||||
threads_per_threadblock);
|
threads_per_threadblock);
|
||||||
|
|
||||||
if constexpr (GEMMINI_DMA) {
|
if constexpr (GEMMINI_DMA) {
|
||||||
|
|||||||
Reference in New Issue
Block a user