sgemm_impl: Add DMA_FAST option; fix dbuf offset for dma
This commit is contained in:
@@ -90,13 +90,22 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
|
|||||||
thread_block_gemm<float_type, threads_per_threadblock,
|
thread_block_gemm<float_type, threads_per_threadblock,
|
||||||
/*write_to_gmem=*/true,
|
/*write_to_gmem=*/true,
|
||||||
/*smem_a_offset=*/0,
|
/*smem_a_offset=*/0,
|
||||||
/*smem_a_dbuf_offset=*/0,
|
#ifdef GEMMINI_DMA
|
||||||
|
/*smem_a_dbuf_offset=*/1 * 128 * 128 * sizeof(float_type),
|
||||||
|
/*smem_b_offset=*/2 * 128 * 128 * sizeof(float_type),
|
||||||
|
/*smem_b_dbuf_offset=*/3 * 128 * 128 * sizeof(float_type)
|
||||||
|
// FIXME: above offsets are hardcoded to agree with CISC
|
||||||
|
// spadQuartile
|
||||||
|
#else
|
||||||
|
/*smem_a_dbuf_offset=*/1 * BM * BK * sizeof(float_type),
|
||||||
/*smem_b_offset=*/2 * BM * BK * sizeof(float_type),
|
/*smem_b_offset=*/2 * BM * BK * sizeof(float_type),
|
||||||
/*smem_b_dbuf_offset=*/2 * BM * BK * sizeof(float_type)>(
|
/*smem_b_dbuf_offset=*/(2 * BM * BK + BK * BN) * sizeof(float_type)
|
||||||
(const float_type *)arg->addr_a, (const float_type *)arg->addr_b,
|
#endif
|
||||||
(float *)arg->addr_c, arg->dim_m, arg->dim_n, arg->dim_k,
|
>((const float_type *)arg->addr_a,
|
||||||
tid_in_threadblock, threadblocks_per_cluster, threadblock_id_in_cluster,
|
(const float_type *)arg->addr_b, (float *)arg->addr_c,
|
||||||
sharedmem_per_threadblock);
|
arg->dim_m, arg->dim_n, arg->dim_k, tid_in_threadblock,
|
||||||
|
threadblocks_per_cluster, threadblock_id_in_cluster,
|
||||||
|
sharedmem_per_threadblock);
|
||||||
|
|
||||||
float *gmem_tmp_d0 = reinterpret_cast<float *>(0xd0000000UL);
|
float *gmem_tmp_d0 = reinterpret_cast<float *>(0xd0000000UL);
|
||||||
float *gmem_tmp_d1 = reinterpret_cast<float *>(0xd1000000UL);
|
float *gmem_tmp_d1 = reinterpret_cast<float *>(0xd1000000UL);
|
||||||
|
|||||||
@@ -72,8 +72,9 @@ static_assert(WMITER * WNITER * TCM * TCN * NUM_WARPS * CORES_PER_CLUSTER ==
|
|||||||
#define TRANSPOSE_AT_PRODUCE 0
|
#define TRANSPOSE_AT_PRODUCE 0
|
||||||
#define TRANSPOSE_AT_CONSUME 0
|
#define TRANSPOSE_AT_CONSUME 0
|
||||||
|
|
||||||
#define GEMMINI_DMA 0
|
#define GEMMINI_DMA 1
|
||||||
#define GEMMINI_DMA_FLEXIBLE_LAYOUT 0
|
#define GEMMINI_DMA_FAST 1
|
||||||
|
#define GEMMINI_DMA_FLEXIBLE_LAYOUT 1
|
||||||
#if SMEM_SIZE == 0x4000
|
#if SMEM_SIZE == 0x4000
|
||||||
#define SMEM_ADDR_Q0 ((float * const) 0xff000000)
|
#define SMEM_ADDR_Q0 ((float * const) 0xff000000)
|
||||||
#define SMEM_ADDR_Q1 ((float * const) 0xff001000)
|
#define SMEM_ADDR_Q1 ((float * const) 0xff001000)
|
||||||
@@ -207,7 +208,7 @@ template <bool use_dma, uint32_t dim_col>
|
|||||||
inline constexpr std::pair<uint32_t, uint32_t>
|
inline constexpr std::pair<uint32_t, uint32_t>
|
||||||
remap_to_gemmini_dma_layout(const uint32_t logical_row,
|
remap_to_gemmini_dma_layout(const uint32_t logical_row,
|
||||||
const uint32_t logical_col) {
|
const uint32_t logical_col) {
|
||||||
static_assert(GEMMINI_DMA_FLEXIBLE_LAYOUT || DIM == 8,
|
static_assert(!use_dma || DIM == 8,
|
||||||
"GEMMINI_DMA layout remapping code only written for DIM == 8");
|
"GEMMINI_DMA layout remapping code only written for DIM == 8");
|
||||||
|
|
||||||
if constexpr (use_dma) {
|
if constexpr (use_dma) {
|
||||||
@@ -915,7 +916,6 @@ inline void thread_block_gemm(const T *A, const T *B, float *C,
|
|||||||
// pipeline initiation
|
// pipeline initiation
|
||||||
if (tid_in_threadblock == 0) {
|
if (tid_in_threadblock == 0) {
|
||||||
// configure dma gmem address to load from
|
// configure dma gmem address to load from
|
||||||
// FIXME: block_k is wrong
|
|
||||||
ROCC_INSTRUCTION_RS1_RS2(
|
ROCC_INSTRUCTION_RS1_RS2(
|
||||||
XCUSTOM_ACC,
|
XCUSTOM_ACC,
|
||||||
(uint64_t)(A + block_m * BM * dim_k + /*block_k:*/0 * BK),
|
(uint64_t)(A + block_m * BM * dim_k + /*block_k:*/0 * BK),
|
||||||
@@ -963,7 +963,6 @@ inline void thread_block_gemm(const T *A, const T *B, float *C,
|
|||||||
#if (GEMMINI_DMA == 1)
|
#if (GEMMINI_DMA == 1)
|
||||||
if ((tid_in_threadblock == 0) && ((block_k * BK) != (dim_k - BK))) {
|
if ((tid_in_threadblock == 0) && ((block_k * BK) != (dim_k - BK))) {
|
||||||
// configure dma gmem address to load from
|
// configure dma gmem address to load from
|
||||||
// FIXME: block_k is wrong
|
|
||||||
ROCC_INSTRUCTION_RS1_RS2(
|
ROCC_INSTRUCTION_RS1_RS2(
|
||||||
XCUSTOM_ACC,
|
XCUSTOM_ACC,
|
||||||
(uint64_t)(A + block_m * BM * dim_k + (block_k + 1/*runahead*/) * BK),
|
(uint64_t)(A + block_m * BM * dim_k + (block_k + 1/*runahead*/) * BK),
|
||||||
@@ -976,7 +975,7 @@ inline void thread_block_gemm(const T *A, const T *B, float *C,
|
|||||||
// block_k is even: opcode 11 (write to local_a_buf)
|
// block_k is even: opcode 11 (write to local_a_buf)
|
||||||
// block_k is odd: opcode 10 (write to local_a)
|
// block_k is odd: opcode 10 (write to local_a)
|
||||||
const uint32_t opcode = 11 - (block_k & 1);
|
const uint32_t opcode = 11 - (block_k & 1);
|
||||||
GEMMINI_CISC_CMD_R(opcode);
|
GEMMINI_CISC_CMD_I(opcode);
|
||||||
// // TODO: branch is probably slow
|
// // TODO: branch is probably slow
|
||||||
// if (block_k & 1) {
|
// if (block_k & 1) {
|
||||||
// GEMMINI_CISC_CMD_I(12);
|
// GEMMINI_CISC_CMD_I(12);
|
||||||
@@ -1061,8 +1060,12 @@ inline void thread_block_gemm(const T *A, const T *B, float *C,
|
|||||||
// local_b_consume = reinterpret_cast<T *>(
|
// local_b_consume = reinterpret_cast<T *>(
|
||||||
// (mask_odd & reinterpret_cast<uintmax_t>(local_b_buf)) |
|
// (mask_odd & reinterpret_cast<uintmax_t>(local_b_buf)) |
|
||||||
// (mask_even & reinterpret_cast<uintmax_t>(local_b)));
|
// (mask_even & reinterpret_cast<uintmax_t>(local_b)));
|
||||||
local_a_consume = local_a + (block_k & 1) * (BM * BK);
|
local_a_consume = local_a + (block_k & 1) *
|
||||||
local_b_consume = local_b + (block_k & 1) * (BK * BN);
|
(smem_a_dbuf_offset - smem_a_offset) /
|
||||||
|
sizeof(T);
|
||||||
|
local_b_consume = local_b + (block_k & 1) *
|
||||||
|
(smem_b_dbuf_offset - smem_b_offset) /
|
||||||
|
sizeof(T);
|
||||||
} else {
|
} else {
|
||||||
// no double-buffering without DMA
|
// no double-buffering without DMA
|
||||||
local_a_consume = local_a;
|
local_a_consume = local_a;
|
||||||
@@ -1071,11 +1074,14 @@ inline void thread_block_gemm(const T *A, const T *B, float *C,
|
|||||||
asm volatile("dbuf_sel_end_%=:" ::);
|
asm volatile("dbuf_sel_end_%=:" ::);
|
||||||
|
|
||||||
constexpr MemLayout layout_a =
|
constexpr MemLayout layout_a =
|
||||||
GEMMINI_DMA ? MemLayout::block_row_major
|
GEMMINI_DMA ? (GEMMINI_DMA_FAST ? MemLayout::MN_major
|
||||||
|
: MemLayout::block_row_major)
|
||||||
: (TRANSPOSE_AT_CONSUME ? MemLayout::K_major
|
: (TRANSPOSE_AT_CONSUME ? MemLayout::K_major
|
||||||
: MemLayout::MN_major);
|
: MemLayout::MN_major);
|
||||||
constexpr MemLayout layout_b =
|
constexpr MemLayout layout_b =
|
||||||
GEMMINI_DMA ? MemLayout::block_row_major : MemLayout::MN_major;
|
GEMMINI_DMA ? (GEMMINI_DMA_FAST ? MemLayout::MN_major
|
||||||
|
: MemLayout::block_row_major)
|
||||||
|
: MemLayout::MN_major;
|
||||||
thread_block_gemm_single_tile<T, layout_a, layout_b,
|
thread_block_gemm_single_tile<T, layout_a, layout_b,
|
||||||
BM, BN, BK, 0, 0,
|
BM, BN, BK, 0, 0,
|
||||||
/*load_accum=*/false,
|
/*load_accum=*/false,
|
||||||
|
|||||||
Reference in New Issue
Block a user