sgemm_impl: Refactor DMA layout remap logic into constexpr func

This commit is contained in:
Hansung Kim
2024-09-03 16:20:31 -07:00
parent 58fa2a3e91
commit ced98a6ff4

View File

@@ -72,7 +72,7 @@ using float_type = float16_t;
#define TRANSPOSE_AT_PRODUCE 0
#define TRANSPOSE_AT_CONSUME 0
#define GEMMINI_DMA 0
#define GEMMINI_DMA 1
#define GEMMINI_DMA_MN_MAJOR 1
#if SMEM_SIZE == 0x4000
#define SMEM_ADDR_Q0 ((float * const) 0xff000000)
@@ -200,6 +200,28 @@ inline void vx_wmma(const int dest_reg) {
}
}
// Remap logical row/col coordinate of a matrix element to a memory index that
// follows the 2-level block-row-major layout that Gemmini DMA uses
template <bool use_dma, uint32_t dim_col>
inline constexpr std::pair<uint32_t, uint32_t>
remap_to_gemmini_dma_layout(const uint32_t logical_row,
const uint32_t logical_col) {
static_assert(DIM == 8,
"GEMMINI_DMA layout remapping code only written for DIM == 8");
if constexpr (use_dma) {
constexpr int dim_blocks_in_row = (dim_col / DIM);
const uint32_t row =
(logical_row / dim_blocks_in_row) * DIM + (logical_col / DIM);
const uint32_t col =
(logical_row % dim_blocks_in_row) * DIM + (logical_col % DIM);
return {row, col};
} else {
// pass-through
return {logical_row, logical_col};
}
}
// `local_k` is assumed to be multiple of TCK
template <typename T, MemLayout layout,
uint32_t leading_dim // stride in sizeof(T) between consecutive
@@ -242,24 +264,13 @@ inline void wmma_load_a(volatile const T *smem_A, const int local_k,
// f8-f15 stores a single row of A
const uint32_t smem_logical_row = WM * warp_row + TCM * wm_iter + row;
const uint32_t smem_logical_col = local_k_adjusted + 0; /* FIXME: fp16 adjust necessary? */
uint32_t smem_row;
uint32_t smem_col;
if constexpr (GEMMINI_DMA) {
const uint32_t smem_logical_col =
local_k_adjusted + 0; /* FIXME: fp16 adjust necessary? */
// if using Gemmini DMA, remap logical row/col to Gemmini's 2-level
// block-row-major layout
static_assert(
DIM == 8,
"GEMMINI_DMA layout remapping code only written for DIM == 8");
constexpr int dim_blocks_in_row = (smem_A_cols / DIM);
smem_row = (smem_logical_row / dim_blocks_in_row) * DIM +
(smem_logical_col / DIM);
smem_col = (smem_logical_row % dim_blocks_in_row) * DIM +
(smem_logical_col % DIM);
} else {
smem_row = smem_logical_row;
smem_col = smem_logical_col;
}
const auto [smem_row, smem_col] =
remap_to_gemmini_dma_layout<GEMMINI_DMA, smem_A_cols>(smem_logical_row,
smem_logical_col);
const volatile uint8_t *smem_addr;
smem_addr = reinterpret_cast<const volatile uint8_t *>(
@@ -356,20 +367,11 @@ inline void wmma_load_b(const volatile T *smem_B, const int local_k,
const uint32_t smem_logical_row = local_k_adjusted + 0;
const uint32_t smem_logical_col = (WN * warp_col + TCN * wn_iter) + col;
uint32_t smem_row;
uint32_t smem_col;
if constexpr (GEMMINI_DMA) {
// if using Gemmini DMA, remap logical row/col to Gemmini's 2-level
// block-row-major layout
constexpr int dim_blocks_in_row = (smem_B_cols / DIM);
smem_row =
(smem_logical_row / dim_blocks_in_row) * DIM + (smem_logical_col / DIM);
smem_col =
(smem_logical_row % dim_blocks_in_row) * DIM + (smem_logical_col % DIM);
} else {
smem_row = smem_logical_row;
smem_col = smem_logical_col;
}
const auto [smem_row, smem_col] =
remap_to_gemmini_dma_layout<GEMMINI_DMA, smem_B_cols>(smem_logical_row,
smem_logical_col);
const volatile uint8_t *smem_addr;
smem_addr = reinterpret_cast<const volatile uint8_t *>(
@@ -475,6 +477,7 @@ wmma_load_accum(const int thread_in_warp, const int warp_col,
asm volatile("wmma_load_accum_finish_%=:" ::);
}
// Write out the matrix data stored in RF to memory
__attribute__((always_inline)) inline void
wmma_store(const int thread_in_warp, const int warp_col, const int warp_row,
const int wn_iter, const int wm_iter, const int dim_n,