sgemm_impl: Add new block-row-major layout for DMA
This commit is contained in:
@@ -70,10 +70,10 @@ static_assert(WMITER * WNITER * TCM * TCN * NUM_WARPS * CORES_PER_CLUSTER ==
|
|||||||
// To model the case where the A matrix is already stored column-major in GMEM,
|
// To model the case where the A matrix is already stored column-major in GMEM,
|
||||||
// set both to 0.
|
// set both to 0.
|
||||||
#define TRANSPOSE_AT_PRODUCE 0
|
#define TRANSPOSE_AT_PRODUCE 0
|
||||||
#define TRANSPOSE_AT_CONSUME 1
|
#define TRANSPOSE_AT_CONSUME 0
|
||||||
|
|
||||||
#define GEMMINI_DMA 1
|
#define GEMMINI_DMA 1
|
||||||
#define GEMMINI_DMA_MN_MAJOR 0
|
#define GEMMINI_DMA_FLEXIBLE_LAYOUT 0
|
||||||
#if SMEM_SIZE == 0x4000
|
#if SMEM_SIZE == 0x4000
|
||||||
#define SMEM_ADDR_Q0 ((float * const) 0xff000000)
|
#define SMEM_ADDR_Q0 ((float * const) 0xff000000)
|
||||||
#define SMEM_ADDR_Q1 ((float * const) 0xff001000)
|
#define SMEM_ADDR_Q1 ((float * const) 0xff001000)
|
||||||
@@ -101,6 +101,7 @@ static_assert(WMITER * WNITER * TCM * TCN * NUM_WARPS * CORES_PER_CLUSTER ==
|
|||||||
enum class MemLayout {
|
enum class MemLayout {
|
||||||
MN_major,
|
MN_major,
|
||||||
K_major,
|
K_major,
|
||||||
|
block_row_major, // Gemmini DMA
|
||||||
};
|
};
|
||||||
|
|
||||||
inline constexpr void map_operand_32lanes(const int tid, int &row, int &col) {
|
inline constexpr void map_operand_32lanes(const int tid, int &row, int &col) {
|
||||||
@@ -253,13 +254,14 @@ inline void wmma_load_a(volatile const T *smem_A, const int local_k,
|
|||||||
constexpr int packed_factor = (std::is_same_v<T, float16_t> ? 2 : 1);
|
constexpr int packed_factor = (std::is_same_v<T, float16_t> ? 2 : 1);
|
||||||
const int local_k_adjusted = local_k / packed_factor;
|
const int local_k_adjusted = local_k / packed_factor;
|
||||||
|
|
||||||
static_assert(!GEMMINI_DMA || (layout == MemLayout::K_major) ||
|
static_assert(!GEMMINI_DMA || (layout == MemLayout::block_row_major) ||
|
||||||
GEMMINI_DMA_MN_MAJOR,
|
GEMMINI_DMA_FLEXIBLE_LAYOUT,
|
||||||
"GEMMINI_DMA only supported for K-major A tile");
|
"wrong memory layout selected for DMA");
|
||||||
static_assert((layout != MemLayout::K_major) || (FP_SIZE == 32),
|
static_assert((layout != MemLayout::K_major) || (FP_SIZE == 32),
|
||||||
"fp16 is not really tested for K-major A layout");
|
"fp16 is not really tested for K-major A layout");
|
||||||
|
|
||||||
if constexpr (layout == MemLayout::K_major) {
|
if constexpr (layout == MemLayout::K_major ||
|
||||||
|
layout == MemLayout::block_row_major) {
|
||||||
constexpr int smem_A_cols = leading_dim;
|
constexpr int smem_A_cols = leading_dim;
|
||||||
|
|
||||||
// f8-f15 stores a single row of A
|
// f8-f15 stores a single row of A
|
||||||
@@ -269,8 +271,9 @@ inline void wmma_load_a(volatile const T *smem_A, const int local_k,
|
|||||||
// if using Gemmini DMA, remap logical row/col to Gemmini's 2-level
|
// if using Gemmini DMA, remap logical row/col to Gemmini's 2-level
|
||||||
// block-row-major layout
|
// block-row-major layout
|
||||||
const auto [smem_row, smem_col] =
|
const auto [smem_row, smem_col] =
|
||||||
remap_to_gemmini_dma_layout<GEMMINI_DMA, smem_A_cols>(smem_logical_row,
|
remap_to_gemmini_dma_layout<layout == MemLayout::block_row_major,
|
||||||
smem_logical_col);
|
smem_A_cols>(smem_logical_row,
|
||||||
|
smem_logical_col);
|
||||||
|
|
||||||
const volatile uint8_t *smem_addr;
|
const volatile uint8_t *smem_addr;
|
||||||
smem_addr = reinterpret_cast<const volatile uint8_t *>(
|
smem_addr = reinterpret_cast<const volatile uint8_t *>(
|
||||||
@@ -356,8 +359,9 @@ inline void wmma_load_b(const volatile T *smem_B, const int local_k,
|
|||||||
const int thread_in_warp) {
|
const int thread_in_warp) {
|
||||||
asm volatile ("wmma_load_b_start_%=:" :: );
|
asm volatile ("wmma_load_b_start_%=:" :: );
|
||||||
|
|
||||||
static_assert(layout == MemLayout::MN_major,
|
static_assert(
|
||||||
"only N-major layout for the B tile is supported");
|
layout == MemLayout::MN_major || layout == MemLayout::block_row_major,
|
||||||
|
"only N-major or block-row-major layout are supported for the B tile");
|
||||||
|
|
||||||
const int tid = thread_in_warp;
|
const int tid = thread_in_warp;
|
||||||
const int tg = tid / 4;
|
const int tg = tid / 4;
|
||||||
@@ -379,8 +383,9 @@ inline void wmma_load_b(const volatile T *smem_B, const int local_k,
|
|||||||
// if using Gemmini DMA, remap logical row/col to Gemmini's 2-level
|
// if using Gemmini DMA, remap logical row/col to Gemmini's 2-level
|
||||||
// block-row-major layout
|
// block-row-major layout
|
||||||
const auto [smem_row, smem_col] =
|
const auto [smem_row, smem_col] =
|
||||||
remap_to_gemmini_dma_layout<GEMMINI_DMA, smem_B_cols>(smem_logical_row,
|
remap_to_gemmini_dma_layout<layout == MemLayout::block_row_major,
|
||||||
smem_logical_col);
|
smem_B_cols>(smem_logical_row,
|
||||||
|
smem_logical_col);
|
||||||
|
|
||||||
const volatile uint8_t *smem_addr;
|
const volatile uint8_t *smem_addr;
|
||||||
smem_addr = reinterpret_cast<const volatile uint8_t *>(
|
smem_addr = reinterpret_cast<const volatile uint8_t *>(
|
||||||
@@ -388,10 +393,10 @@ inline void wmma_load_b(const volatile T *smem_B, const int local_k,
|
|||||||
smem_B)[smem_B_cols * smem_row + smem_col]);
|
smem_B)[smem_B_cols * smem_row + smem_col]);
|
||||||
// f8-f15 stores a single column of B
|
// f8-f15 stores a single column of B
|
||||||
// threads read from different columns; no bank conflicts
|
// threads read from different columns; no bank conflicts
|
||||||
if constexpr (GEMMINI_DMA) {
|
if constexpr (layout == MemLayout::block_row_major) {
|
||||||
// for GEMMINI_DMA, moving rows for the next 7 elements in the same column
|
// in the block-row-major layout, consecutive rows of the same column sit
|
||||||
// is the same as moving DIM elements forward in the memory because of the
|
// DIM elements apart in memory, so the next 7 elements of this column are
|
||||||
// block-row-major layout
|
// loaded at fixed offsets of DIM * i floats from smem_addr
|
||||||
asm volatile("flw f8, %0(%1)" :: "i"(DIM * 0 * sizeof(float)), "r"(smem_addr));
|
asm volatile("flw f8, %0(%1)" :: "i"(DIM * 0 * sizeof(float)), "r"(smem_addr));
|
||||||
asm volatile("flw f9, %0(%1)" :: "i"(DIM * 1 * sizeof(float)), "r"(smem_addr));
|
asm volatile("flw f9, %0(%1)" :: "i"(DIM * 1 * sizeof(float)), "r"(smem_addr));
|
||||||
asm volatile("flw f10, %0(%1)" :: "i"(DIM * 2 * sizeof(float)), "r"(smem_addr));
|
asm volatile("flw f10, %0(%1)" :: "i"(DIM * 2 * sizeof(float)), "r"(smem_addr));
|
||||||
@@ -1064,8 +1069,12 @@ inline void thread_block_gemm(const T *A, const T *B, float *C,
|
|||||||
}
|
}
|
||||||
|
|
||||||
constexpr MemLayout layout_a =
|
constexpr MemLayout layout_a =
|
||||||
TRANSPOSE_AT_CONSUME ? MemLayout::K_major : MemLayout::MN_major;
|
GEMMINI_DMA ? MemLayout::block_row_major
|
||||||
thread_block_gemm_single_tile<T, layout_a, MemLayout::MN_major,
|
: (TRANSPOSE_AT_CONSUME ? MemLayout::K_major
|
||||||
|
: MemLayout::MN_major);
|
||||||
|
constexpr MemLayout layout_b =
|
||||||
|
GEMMINI_DMA ? MemLayout::block_row_major : MemLayout::MN_major;
|
||||||
|
thread_block_gemm_single_tile<T, layout_a, layout_b,
|
||||||
BM, BN, BK, 0, 0,
|
BM, BN, BK, 0, 0,
|
||||||
/*load_accum=*/false,
|
/*load_accum=*/false,
|
||||||
/*write_to_mem=*/false>(
|
/*write_to_mem=*/false>(
|
||||||
|
|||||||
Reference in New Issue
Block a user