sgemm_impl: Add new block-row-major layout for DMA

This commit is contained in:
Hansung Kim
2024-09-07 16:38:22 -07:00
parent ed9bf6f73e
commit a967c262b1

View File

@@ -70,10 +70,10 @@ static_assert(WMITER * WNITER * TCM * TCN * NUM_WARPS * CORES_PER_CLUSTER ==
// To model the case where the A matrix is already stored column-major in GMEM, // To model the case where the A matrix is already stored column-major in GMEM,
// set both to 0. // set both to 0.
#define TRANSPOSE_AT_PRODUCE 0 #define TRANSPOSE_AT_PRODUCE 0
#define TRANSPOSE_AT_CONSUME 1 #define TRANSPOSE_AT_CONSUME 0
#define GEMMINI_DMA 1 #define GEMMINI_DMA 1
#define GEMMINI_DMA_MN_MAJOR 0 #define GEMMINI_DMA_FLEXIBLE_LAYOUT 0
#if SMEM_SIZE == 0x4000 #if SMEM_SIZE == 0x4000
#define SMEM_ADDR_Q0 ((float * const) 0xff000000) #define SMEM_ADDR_Q0 ((float * const) 0xff000000)
#define SMEM_ADDR_Q1 ((float * const) 0xff001000) #define SMEM_ADDR_Q1 ((float * const) 0xff001000)
@@ -101,6 +101,7 @@ static_assert(WMITER * WNITER * TCM * TCN * NUM_WARPS * CORES_PER_CLUSTER ==
enum class MemLayout { enum class MemLayout {
MN_major, MN_major,
K_major, K_major,
block_row_major, // Gemmini DMA
}; };
inline constexpr void map_operand_32lanes(const int tid, int &row, int &col) { inline constexpr void map_operand_32lanes(const int tid, int &row, int &col) {
@@ -253,13 +254,14 @@ inline void wmma_load_a(volatile const T *smem_A, const int local_k,
constexpr int packed_factor = (std::is_same_v<T, float16_t> ? 2 : 1); constexpr int packed_factor = (std::is_same_v<T, float16_t> ? 2 : 1);
const int local_k_adjusted = local_k / packed_factor; const int local_k_adjusted = local_k / packed_factor;
static_assert(!GEMMINI_DMA || (layout == MemLayout::K_major) || static_assert(!GEMMINI_DMA || (layout == MemLayout::block_row_major) ||
GEMMINI_DMA_MN_MAJOR, GEMMINI_DMA_FLEXIBLE_LAYOUT,
"GEMMINI_DMA only supported for K-major A tile"); "wrong memory layout selected for DMA");
static_assert((layout != MemLayout::K_major) || (FP_SIZE == 32), static_assert((layout != MemLayout::K_major) || (FP_SIZE == 32),
"fp16 is not really tested for K-major A layout"); "fp16 is not really tested for K-major A layout");
if constexpr (layout == MemLayout::K_major) { if constexpr (layout == MemLayout::K_major ||
layout == MemLayout::block_row_major) {
constexpr int smem_A_cols = leading_dim; constexpr int smem_A_cols = leading_dim;
// f8-f15 stores a single row of A // f8-f15 stores a single row of A
@@ -269,8 +271,9 @@ inline void wmma_load_a(volatile const T *smem_A, const int local_k,
// if using Gemmini DMA, remap logical row/col to Gemmini's 2-level // if using Gemmini DMA, remap logical row/col to Gemmini's 2-level
// block-row-major layout // block-row-major layout
const auto [smem_row, smem_col] = const auto [smem_row, smem_col] =
remap_to_gemmini_dma_layout<GEMMINI_DMA, smem_A_cols>(smem_logical_row, remap_to_gemmini_dma_layout<layout == MemLayout::block_row_major,
smem_logical_col); smem_A_cols>(smem_logical_row,
smem_logical_col);
const volatile uint8_t *smem_addr; const volatile uint8_t *smem_addr;
smem_addr = reinterpret_cast<const volatile uint8_t *>( smem_addr = reinterpret_cast<const volatile uint8_t *>(
@@ -356,8 +359,9 @@ inline void wmma_load_b(const volatile T *smem_B, const int local_k,
const int thread_in_warp) { const int thread_in_warp) {
asm volatile ("wmma_load_b_start_%=:" :: ); asm volatile ("wmma_load_b_start_%=:" :: );
static_assert(layout == MemLayout::MN_major, static_assert(
"only N-major layout for the B tile is supported"); layout == MemLayout::MN_major || layout == MemLayout::block_row_major,
"only N-major or block-row-major layout are supported for the B tile");
const int tid = thread_in_warp; const int tid = thread_in_warp;
const int tg = tid / 4; const int tg = tid / 4;
@@ -379,8 +383,9 @@ inline void wmma_load_b(const volatile T *smem_B, const int local_k,
// if using Gemmini DMA, remap logical row/col to Gemmini's 2-level // if using Gemmini DMA, remap logical row/col to Gemmini's 2-level
// block-row-major layout // block-row-major layout
const auto [smem_row, smem_col] = const auto [smem_row, smem_col] =
remap_to_gemmini_dma_layout<GEMMINI_DMA, smem_B_cols>(smem_logical_row, remap_to_gemmini_dma_layout<layout == MemLayout::block_row_major,
smem_logical_col); smem_B_cols>(smem_logical_row,
smem_logical_col);
const volatile uint8_t *smem_addr; const volatile uint8_t *smem_addr;
smem_addr = reinterpret_cast<const volatile uint8_t *>( smem_addr = reinterpret_cast<const volatile uint8_t *>(
@@ -388,10 +393,10 @@ inline void wmma_load_b(const volatile T *smem_B, const int local_k,
smem_B)[smem_B_cols * smem_row + smem_col]); smem_B)[smem_B_cols * smem_row + smem_col]);
// f8-f15 stores a single column of B // f8-f15 stores a single column of B
// threads read from different columns; no bank conflicts // threads read from different columns; no bank conflicts
if constexpr (GEMMINI_DMA) { if constexpr (layout == MemLayout::block_row_major) {
// for GEMMINI_DMA, moving rows for the next 7 elements in the same column // for the block-row-major layout, moving rows for the next 7 elements in
// is the same as moving DIM elements forward in the memory because of the // the same column is the same as moving DIM elements forward in the memory
// block-row-major layout // because of the block-row-major layout
asm volatile("flw f8, %0(%1)" :: "i"(DIM * 0 * sizeof(float)), "r"(smem_addr)); asm volatile("flw f8, %0(%1)" :: "i"(DIM * 0 * sizeof(float)), "r"(smem_addr));
asm volatile("flw f9, %0(%1)" :: "i"(DIM * 1 * sizeof(float)), "r"(smem_addr)); asm volatile("flw f9, %0(%1)" :: "i"(DIM * 1 * sizeof(float)), "r"(smem_addr));
asm volatile("flw f10, %0(%1)" :: "i"(DIM * 2 * sizeof(float)), "r"(smem_addr)); asm volatile("flw f10, %0(%1)" :: "i"(DIM * 2 * sizeof(float)), "r"(smem_addr));
@@ -1064,8 +1069,12 @@ inline void thread_block_gemm(const T *A, const T *B, float *C,
} }
constexpr MemLayout layout_a = constexpr MemLayout layout_a =
TRANSPOSE_AT_CONSUME ? MemLayout::K_major : MemLayout::MN_major; GEMMINI_DMA ? MemLayout::block_row_major
thread_block_gemm_single_tile<T, layout_a, MemLayout::MN_major, : (TRANSPOSE_AT_CONSUME ? MemLayout::K_major
: MemLayout::MN_major);
constexpr MemLayout layout_b =
GEMMINI_DMA ? MemLayout::block_row_major : MemLayout::MN_major;
thread_block_gemm_single_tile<T, layout_a, layout_b,
BM, BN, BK, 0, 0, BM, BN, BK, 0, 0,
/*load_accum=*/false, /*load_accum=*/false,
/*write_to_mem=*/false>( /*write_to_mem=*/false>(