sgemm_impl: Refactor DMA layout remap logic into constexpr func
This commit is contained in:
@@ -72,7 +72,7 @@ using float_type = float16_t;
|
||||
#define TRANSPOSE_AT_PRODUCE 0
|
||||
#define TRANSPOSE_AT_CONSUME 0
|
||||
|
||||
#define GEMMINI_DMA 0
|
||||
#define GEMMINI_DMA 1
|
||||
#define GEMMINI_DMA_MN_MAJOR 1
|
||||
#if SMEM_SIZE == 0x4000
|
||||
#define SMEM_ADDR_Q0 ((float * const) 0xff000000)
|
||||
@@ -200,6 +200,28 @@ inline void vx_wmma(const int dest_reg) {
|
||||
}
|
||||
}
|
||||
|
||||
// Remap logical row/col coordinate of a matrix element to a memory index that
|
||||
// follows the 2-level block-row-major layout that Gemmini DMA uses
|
||||
template <bool use_dma, uint32_t dim_col>
|
||||
inline constexpr std::pair<uint32_t, uint32_t>
|
||||
remap_to_gemmini_dma_layout(const uint32_t logical_row,
|
||||
const uint32_t logical_col) {
|
||||
static_assert(DIM == 8,
|
||||
"GEMMINI_DMA layout remapping code only written for DIM == 8");
|
||||
|
||||
if constexpr (use_dma) {
|
||||
constexpr int dim_blocks_in_row = (dim_col / DIM);
|
||||
const uint32_t row =
|
||||
(logical_row / dim_blocks_in_row) * DIM + (logical_col / DIM);
|
||||
const uint32_t col =
|
||||
(logical_row % dim_blocks_in_row) * DIM + (logical_col % DIM);
|
||||
return {row, col};
|
||||
} else {
|
||||
// pass-through
|
||||
return {logical_row, logical_col};
|
||||
}
|
||||
}
|
||||
|
||||
// `local_k` is assumed to be multiple of TCK
|
||||
template <typename T, MemLayout layout,
|
||||
uint32_t leading_dim // stride in sizeof(T) between consecutive
|
||||
@@ -242,24 +264,13 @@ inline void wmma_load_a(volatile const T *smem_A, const int local_k,
|
||||
|
||||
// f8-f15 stores a single row of A
|
||||
const uint32_t smem_logical_row = WM * warp_row + TCM * wm_iter + row;
|
||||
const uint32_t smem_logical_col = local_k_adjusted + 0; /* FIXME: fp16 adjust necessary? */
|
||||
uint32_t smem_row;
|
||||
uint32_t smem_col;
|
||||
if constexpr (GEMMINI_DMA) {
|
||||
// if using Gemmini DMA, remap logical row/col to Gemmini's 2-level
|
||||
// block-row-major layout
|
||||
static_assert(
|
||||
DIM == 8,
|
||||
"GEMMINI_DMA layout remapping code only written for DIM == 8");
|
||||
constexpr int dim_blocks_in_row = (smem_A_cols / DIM);
|
||||
smem_row = (smem_logical_row / dim_blocks_in_row) * DIM +
|
||||
(smem_logical_col / DIM);
|
||||
smem_col = (smem_logical_row % dim_blocks_in_row) * DIM +
|
||||
(smem_logical_col % DIM);
|
||||
} else {
|
||||
smem_row = smem_logical_row;
|
||||
smem_col = smem_logical_col;
|
||||
}
|
||||
const uint32_t smem_logical_col =
|
||||
local_k_adjusted + 0; /* FIXME: fp16 adjust necessary? */
|
||||
// if using Gemmini DMA, remap logical row/col to Gemmini's 2-level
|
||||
// block-row-major layout
|
||||
const auto [smem_row, smem_col] =
|
||||
remap_to_gemmini_dma_layout<GEMMINI_DMA, smem_A_cols>(smem_logical_row,
|
||||
smem_logical_col);
|
||||
|
||||
const volatile uint8_t *smem_addr;
|
||||
smem_addr = reinterpret_cast<const volatile uint8_t *>(
|
||||
@@ -356,20 +367,11 @@ inline void wmma_load_b(const volatile T *smem_B, const int local_k,
|
||||
|
||||
const uint32_t smem_logical_row = local_k_adjusted + 0;
|
||||
const uint32_t smem_logical_col = (WN * warp_col + TCN * wn_iter) + col;
|
||||
uint32_t smem_row;
|
||||
uint32_t smem_col;
|
||||
if constexpr (GEMMINI_DMA) {
|
||||
// if using Gemmini DMA, remap logical row/col to Gemmini's 2-level
|
||||
// block-row-major layout
|
||||
constexpr int dim_blocks_in_row = (smem_B_cols / DIM);
|
||||
smem_row =
|
||||
(smem_logical_row / dim_blocks_in_row) * DIM + (smem_logical_col / DIM);
|
||||
smem_col =
|
||||
(smem_logical_row % dim_blocks_in_row) * DIM + (smem_logical_col % DIM);
|
||||
} else {
|
||||
smem_row = smem_logical_row;
|
||||
smem_col = smem_logical_col;
|
||||
}
|
||||
// if using Gemmini DMA, remap logical row/col to Gemmini's 2-level
|
||||
// block-row-major layout
|
||||
const auto [smem_row, smem_col] =
|
||||
remap_to_gemmini_dma_layout<GEMMINI_DMA, smem_B_cols>(smem_logical_row,
|
||||
smem_logical_col);
|
||||
|
||||
const volatile uint8_t *smem_addr;
|
||||
smem_addr = reinterpret_cast<const volatile uint8_t *>(
|
||||
@@ -475,6 +477,7 @@ wmma_load_accum(const int thread_in_warp, const int warp_col,
|
||||
asm volatile("wmma_load_accum_finish_%=:" ::);
|
||||
}
|
||||
|
||||
// Write out the matrix data stored in RF to memory
|
||||
__attribute__((always_inline)) inline void
|
||||
wmma_store(const int thread_in_warp, const int warp_col, const int warp_row,
|
||||
const int wn_iter, const int wm_iter, const int dim_n,
|
||||
|
||||
Reference in New Issue
Block a user