sgemm_impl: Add new block-row-major layout for DMA

This commit is contained in:
Hansung Kim
2024-09-07 16:38:22 -07:00
parent ed9bf6f73e
commit a967c262b1

View File

@@ -70,10 +70,10 @@ static_assert(WMITER * WNITER * TCM * TCN * NUM_WARPS * CORES_PER_CLUSTER ==
// To model the case where the A matrix is already stored column-major in GMEM,
// set both to 0.
#define TRANSPOSE_AT_PRODUCE 0
#define TRANSPOSE_AT_CONSUME 1
#define TRANSPOSE_AT_CONSUME 0
#define GEMMINI_DMA 1
#define GEMMINI_DMA_MN_MAJOR 0
#define GEMMINI_DMA_FLEXIBLE_LAYOUT 0
#if SMEM_SIZE == 0x4000
#define SMEM_ADDR_Q0 ((float * const) 0xff000000)
#define SMEM_ADDR_Q1 ((float * const) 0xff001000)
@@ -101,6 +101,7 @@ static_assert(WMITER * WNITER * TCM * TCN * NUM_WARPS * CORES_PER_CLUSTER ==
enum class MemLayout {
MN_major,
K_major,
block_row_major, // Gemmini DMA
};
inline constexpr void map_operand_32lanes(const int tid, int &row, int &col) {
@@ -253,13 +254,14 @@ inline void wmma_load_a(volatile const T *smem_A, const int local_k,
constexpr int packed_factor = (std::is_same_v<T, float16_t> ? 2 : 1);
const int local_k_adjusted = local_k / packed_factor;
static_assert(!GEMMINI_DMA || (layout == MemLayout::K_major) ||
GEMMINI_DMA_MN_MAJOR,
"GEMMINI_DMA only supported for K-major A tile");
static_assert(!GEMMINI_DMA || (layout == MemLayout::block_row_major) ||
GEMMINI_DMA_FLEXIBLE_LAYOUT,
"wrong memory layout selected for DMA");
static_assert((layout != MemLayout::K_major) || (FP_SIZE == 32),
"fp16 is not really tested for K-major A layout");
if constexpr (layout == MemLayout::K_major) {
if constexpr (layout == MemLayout::K_major ||
layout == MemLayout::block_row_major) {
constexpr int smem_A_cols = leading_dim;
// f8-f15 stores a single row of A
@@ -269,8 +271,9 @@ inline void wmma_load_a(volatile const T *smem_A, const int local_k,
// if using Gemmini DMA, remap logical row/col to Gemmini's 2-level
// block-row-major layout
const auto [smem_row, smem_col] =
remap_to_gemmini_dma_layout<GEMMINI_DMA, smem_A_cols>(smem_logical_row,
smem_logical_col);
remap_to_gemmini_dma_layout<layout == MemLayout::block_row_major,
smem_A_cols>(smem_logical_row,
smem_logical_col);
const volatile uint8_t *smem_addr;
smem_addr = reinterpret_cast<const volatile uint8_t *>(
@@ -356,8 +359,9 @@ inline void wmma_load_b(const volatile T *smem_B, const int local_k,
const int thread_in_warp) {
asm volatile ("wmma_load_b_start_%=:" :: );
static_assert(layout == MemLayout::MN_major,
"only N-major layout for the B tile is supported");
static_assert(
layout == MemLayout::MN_major || layout == MemLayout::block_row_major,
"only N-major or block-row-major layout are supported for the B tile");
const int tid = thread_in_warp;
const int tg = tid / 4;
@@ -379,8 +383,9 @@ inline void wmma_load_b(const volatile T *smem_B, const int local_k,
// if using Gemmini DMA, remap logical row/col to Gemmini's 2-level
// block-row-major layout
const auto [smem_row, smem_col] =
remap_to_gemmini_dma_layout<GEMMINI_DMA, smem_B_cols>(smem_logical_row,
smem_logical_col);
remap_to_gemmini_dma_layout<layout == MemLayout::block_row_major,
smem_B_cols>(smem_logical_row,
smem_logical_col);
const volatile uint8_t *smem_addr;
smem_addr = reinterpret_cast<const volatile uint8_t *>(
@@ -388,10 +393,10 @@ inline void wmma_load_b(const volatile T *smem_B, const int local_k,
smem_B)[smem_B_cols * smem_row + smem_col]);
// f8-f15 stores a single column of B
// threads read from different columns; no bank conflicts
if constexpr (GEMMINI_DMA) {
// for GEMMINI_DMA, moving rows for the next 7 elements in the same column
// is the same as moving DIM elements forward in the memory because of the
// block-row-major layout
if constexpr (layout == MemLayout::block_row_major) {
// for the block-row-major layout, moving rows for the next 7 elements in
// the same column is the same as moving DIM elements forward in the memory
// because of the block-row-major layout
asm volatile("flw f8, %0(%1)" :: "i"(DIM * 0 * sizeof(float)), "r"(smem_addr));
asm volatile("flw f9, %0(%1)" :: "i"(DIM * 1 * sizeof(float)), "r"(smem_addr));
asm volatile("flw f10, %0(%1)" :: "i"(DIM * 2 * sizeof(float)), "r"(smem_addr));
@@ -1064,8 +1069,12 @@ inline void thread_block_gemm(const T *A, const T *B, float *C,
}
constexpr MemLayout layout_a =
TRANSPOSE_AT_CONSUME ? MemLayout::K_major : MemLayout::MN_major;
thread_block_gemm_single_tile<T, layout_a, MemLayout::MN_major,
GEMMINI_DMA ? MemLayout::block_row_major
: (TRANSPOSE_AT_CONSUME ? MemLayout::K_major
: MemLayout::MN_major);
constexpr MemLayout layout_b =
GEMMINI_DMA ? MemLayout::block_row_major : MemLayout::MN_major;
thread_block_gemm_single_tile<T, layout_a, layout_b,
BM, BN, BK, 0, 0,
/*load_accum=*/false,
/*write_to_mem=*/false>(