sgemm_impl: Add DMA_FAST option; fix dbuf offset for dma

This commit is contained in:
Hansung Kim
2024-09-08 14:56:48 -07:00
parent 42913c00c4
commit 443a37be6c
2 changed files with 31 additions and 16 deletions

View File

@@ -90,13 +90,22 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
thread_block_gemm<float_type, threads_per_threadblock,
/*write_to_gmem=*/true,
/*smem_a_offset=*/0,
/*smem_a_dbuf_offset=*/0,
#ifdef GEMMINI_DMA
/*smem_a_dbuf_offset=*/1 * 128 * 128 * sizeof(float_type),
/*smem_b_offset=*/2 * 128 * 128 * sizeof(float_type),
/*smem_b_dbuf_offset=*/3 * 128 * 128 * sizeof(float_type)
// FIXME: above offsets are hardcoded to agree with CISC
// spadQuartile
#else
/*smem_a_dbuf_offset=*/1 * BM * BK * sizeof(float_type),
/*smem_b_offset=*/2 * BM * BK * sizeof(float_type),
/*smem_b_dbuf_offset=*/2 * BM * BK * sizeof(float_type)>(
(const float_type *)arg->addr_a, (const float_type *)arg->addr_b,
(float *)arg->addr_c, arg->dim_m, arg->dim_n, arg->dim_k,
tid_in_threadblock, threadblocks_per_cluster, threadblock_id_in_cluster,
sharedmem_per_threadblock);
/*smem_b_dbuf_offset=*/(2 * BM * BK + BK * BN) * sizeof(float_type)
#endif
>((const float_type *)arg->addr_a,
(const float_type *)arg->addr_b, (float *)arg->addr_c,
arg->dim_m, arg->dim_n, arg->dim_k, tid_in_threadblock,
threadblocks_per_cluster, threadblock_id_in_cluster,
sharedmem_per_threadblock);
float *gmem_tmp_d0 = reinterpret_cast<float *>(0xd0000000UL);
float *gmem_tmp_d1 = reinterpret_cast<float *>(0xd1000000UL);

View File

@@ -72,8 +72,9 @@ static_assert(WMITER * WNITER * TCM * TCN * NUM_WARPS * CORES_PER_CLUSTER ==
#define TRANSPOSE_AT_PRODUCE 0
#define TRANSPOSE_AT_CONSUME 0
#define GEMMINI_DMA 0
#define GEMMINI_DMA_FLEXIBLE_LAYOUT 0
#define GEMMINI_DMA 1
#define GEMMINI_DMA_FAST 1
#define GEMMINI_DMA_FLEXIBLE_LAYOUT 1
#if SMEM_SIZE == 0x4000
#define SMEM_ADDR_Q0 ((float * const) 0xff000000)
#define SMEM_ADDR_Q1 ((float * const) 0xff001000)
@@ -207,7 +208,7 @@ template <bool use_dma, uint32_t dim_col>
inline constexpr std::pair<uint32_t, uint32_t>
remap_to_gemmini_dma_layout(const uint32_t logical_row,
const uint32_t logical_col) {
static_assert(GEMMINI_DMA_FLEXIBLE_LAYOUT || DIM == 8,
static_assert(!use_dma || DIM == 8,
"GEMMINI_DMA layout remapping code only written for DIM == 8");
if constexpr (use_dma) {
@@ -915,7 +916,6 @@ inline void thread_block_gemm(const T *A, const T *B, float *C,
// pipeline initiation
if (tid_in_threadblock == 0) {
// configure dma gmem address to load from
// FIXME: block_k is wrong
ROCC_INSTRUCTION_RS1_RS2(
XCUSTOM_ACC,
(uint64_t)(A + block_m * BM * dim_k + /*block_k:*/0 * BK),
@@ -963,7 +963,6 @@ inline void thread_block_gemm(const T *A, const T *B, float *C,
#if (GEMMINI_DMA == 1)
if ((tid_in_threadblock == 0) && ((block_k * BK) != (dim_k - BK))) {
// configure dma gmem address to load from
// FIXME: block_k is wrong
ROCC_INSTRUCTION_RS1_RS2(
XCUSTOM_ACC,
(uint64_t)(A + block_m * BM * dim_k + (block_k + 1/*runahead*/) * BK),
@@ -976,7 +975,7 @@ inline void thread_block_gemm(const T *A, const T *B, float *C,
// block_k is even: opcode 11 (write to local_a_buf)
// block_k is odd: opcode 10 (write to local_a)
const uint32_t opcode = 11 - (block_k & 1);
GEMMINI_CISC_CMD_R(opcode);
GEMMINI_CISC_CMD_I(opcode);
// // TODO: branch is probably slow
// if (block_k & 1) {
// GEMMINI_CISC_CMD_I(12);
@@ -1061,8 +1060,12 @@ inline void thread_block_gemm(const T *A, const T *B, float *C,
// local_b_consume = reinterpret_cast<T *>(
// (mask_odd & reinterpret_cast<uintmax_t>(local_b_buf)) |
// (mask_even & reinterpret_cast<uintmax_t>(local_b)));
local_a_consume = local_a + (block_k & 1) * (BM * BK);
local_b_consume = local_b + (block_k & 1) * (BK * BN);
local_a_consume = local_a + (block_k & 1) *
(smem_a_dbuf_offset - smem_a_offset) /
sizeof(T);
local_b_consume = local_b + (block_k & 1) *
(smem_b_dbuf_offset - smem_b_offset) /
sizeof(T);
} else {
// no double-buffering without DMA
local_a_consume = local_a;
@@ -1071,11 +1074,14 @@ inline void thread_block_gemm(const T *A, const T *B, float *C,
asm volatile("dbuf_sel_end_%=:" ::);
constexpr MemLayout layout_a =
GEMMINI_DMA ? MemLayout::block_row_major
GEMMINI_DMA ? (GEMMINI_DMA_FAST ? MemLayout::MN_major
: MemLayout::block_row_major)
: (TRANSPOSE_AT_CONSUME ? MemLayout::K_major
: MemLayout::MN_major);
constexpr MemLayout layout_b =
GEMMINI_DMA ? MemLayout::block_row_major : MemLayout::MN_major;
GEMMINI_DMA ? (GEMMINI_DMA_FAST ? MemLayout::MN_major
: MemLayout::block_row_major)
: MemLayout::MN_major;
thread_block_gemm_single_tile<T, layout_a, layout_b,
BM, BN, BK, 0, 0,
/*load_accum=*/false,