sgemm_impl: Split out smem address generation into functions

so that the address-generation code can also be used for wgmma.
This commit is contained in:
Hansung Kim
2024-10-29 01:30:48 -07:00
parent ae98ae6e93
commit bd7a8e39b9

View File

@@ -276,10 +276,11 @@ template <typename T, MemLayout layout,
// becomes the stride between the 1st M-dim
// vector and the 2nd M-dim vector.
>
inline void wmma_load_a(volatile const T *smem_A, const int local_k,
const int warp_row, const int wm_iter,
const int thread_in_warp) {
asm volatile ("wmma_load_a_start_%=:" :: );
inline volatile const uint8_t *
generate_smem_addr_a(volatile const T *smem_A, const int local_k,
const int warp_row, const int wm_iter,
const int thread_in_warp) {
asm volatile ("generate_smem_addr_a_start_%=:" :: );
const int tid = thread_in_warp;
const int tg = tid / 4;
@@ -316,10 +317,36 @@ inline void wmma_load_a(volatile const T *smem_A, const int local_k,
smem_A_cols>(smem_logical_row,
smem_logical_col);
const volatile uint8_t *smem_addr;
smem_addr = reinterpret_cast<const volatile uint8_t *>(
return reinterpret_cast<const volatile uint8_t *>(
&reinterpret_cast<const volatile float *>(
smem_A)[smem_A_cols * smem_row + smem_col]);
} else if constexpr (layout == MemLayout::MN_major) {
constexpr int smem_AS_cols = leading_dim;
return reinterpret_cast<const volatile uint8_t *>(
&reinterpret_cast<const volatile float *>(
smem_A)[((local_k_adjusted + 0) * smem_AS_cols) +
(WM * warp_row + TCM * wm_iter) + row]);
} else {
static_assert(layout ==
MemLayout::K_major /* fake cond that is always false */,
"unsupported memory layout");
}
asm volatile ("generate_smem_addr_a_finish_%=:" :: );
}
template <typename T, MemLayout layout, uint32_t leading_dim>
inline void wmma_load_a(volatile const T *smem_A, const int local_k,
const int warp_row, const int wm_iter,
const int thread_in_warp) {
asm volatile ("wmma_load_a_start_%=:" :: );
if constexpr (layout == MemLayout::K_major ||
layout == MemLayout::block_row_major) {
const volatile uint8_t *smem_addr =
generate_smem_addr_a<T, layout, leading_dim>(smem_A, local_k, warp_row,
wm_iter, thread_in_warp);
// step to the next column
// @perf: bank conflicts; threads read from different rows
// below is correct for GEMMINI_DMA; smem_col is always a multiple of 8,
@@ -336,11 +363,9 @@ inline void wmma_load_a(volatile const T *smem_A, const int local_k,
} else if constexpr (layout == MemLayout::MN_major) {
constexpr int smem_AS_cols = leading_dim;
const volatile uint8_t *smem_addr;
smem_addr = reinterpret_cast<const volatile uint8_t *>(
&reinterpret_cast<const volatile float *>(
smem_A)[((local_k_adjusted + 0) * smem_AS_cols) +
(WM * warp_row + TCM * wm_iter) + row]);
const volatile uint8_t *smem_addr =
generate_smem_addr_a<T, layout, leading_dim>(smem_A, local_k, warp_row,
wm_iter, thread_in_warp);
// f8-f15 stores a single row of A
// threads read from different columns; no bank conflicts
// asm volatile("flw f0, %0(%1)" :: "i"(smem_AS_cols * 0 * sizeof(float)), "r"(smem_addr));
@@ -393,12 +418,13 @@ inline void wmma_load_a(volatile const T *smem_A, const int local_k,
}
// `local_k` is assumed to be multiple of TCK
template <typename T, MemLayout layout, uint32_t tile_dim_m,
uint32_t tile_dim_n, uint32_t tile_dim_k>
inline void wmma_load_b(const volatile T *smem_B, const int local_k,
const int warp_col, const int wn_iter,
const int thread_in_warp) {
asm volatile ("wmma_load_b_start_%=:" :: );
template <typename T, MemLayout layout, uint32_t leading_dim,
uint32_t tile_dim_k>
inline volatile const uint8_t *
generate_smem_addr_b(const volatile T *smem_B, const int local_k,
const int warp_col, const int wn_iter,
const int thread_in_warp) {
asm volatile ("generate_smem_addr_b_start_%=:" :: );
static_assert(
layout == MemLayout::MN_major || layout == MemLayout::block_row_major,
@@ -417,7 +443,7 @@ inline void wmma_load_b(const volatile T *smem_B, const int local_k,
const int local_k_adjusted = local_k / packed_factor;
// B is stored N-major in smem
constexpr int smem_B_cols = tile_dim_n;
constexpr int smem_B_cols = leading_dim;
const uint32_t smem_logical_row = local_k_adjusted + 0;
const uint32_t smem_logical_col = (WN * warp_col + TCN * wn_iter) + col;
@@ -428,10 +454,27 @@ inline void wmma_load_b(const volatile T *smem_B, const int local_k,
smem_B_cols>(smem_logical_row,
smem_logical_col);
const volatile uint8_t *smem_addr;
smem_addr = reinterpret_cast<const volatile uint8_t *>(
return reinterpret_cast<const volatile uint8_t *>(
&reinterpret_cast<const volatile float *>(
smem_B)[smem_B_cols * smem_row + smem_col]);
asm volatile ("generate_smem_addr_b_finish_%=:" :: );
}
template <typename T, MemLayout layout, uint32_t leading_dim,
uint32_t tile_dim_k>
inline void wmma_load_b(const volatile T *smem_B, const int local_k,
const int warp_col, const int wn_iter,
const int thread_in_warp) {
asm volatile ("wmma_load_b_start_%=:" :: );
// B is stored N-major in smem
constexpr int smem_B_cols = leading_dim;
const volatile uint8_t *smem_addr =
generate_smem_addr_b<T, layout, leading_dim, tile_dim_k>(
smem_B, local_k, warp_col, wn_iter, thread_in_warp);
// f8-f15 stores a single column of B
// threads read from different columns; no bank conflicts
if constexpr (layout == MemLayout::block_row_major) {
@@ -849,7 +892,7 @@ __attribute__((always_inline)) inline void thread_block_gemm_single_tile(
// SMEM -> RF
static_assert(leading_dim_b == 0,
"leading_dim for wmma_load_b is not implemented yet");
wmma_load_b<T, layout_b, tile_dim_m, tile_dim_n,
wmma_load_b<T, layout_b, tile_dim_n,
tile_dim_k /*leading_dim_b is TODO */>(
local_b, local_k, warp_col, wn_iter, tid_in_warp);
#pragma GCC unroll 2