flash: Replace CISC with RISC
spadQuartile in hw does not match spad addresses in kernel; match them later for optimization.
This commit is contained in:
@@ -199,9 +199,15 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
|
||||
constexpr uint32_t skips_only_a =
|
||||
loop_matmul_skips(/*skip_lda=*/0, /*skip_ldb=*/1, /*skip_ldd=*/1,
|
||||
/*skip_ex=*/1, /*skip_stc=*/1);
|
||||
constexpr uint32_t skips_only_b =
|
||||
loop_matmul_skips(/*skip_lda=*/1, /*skip_ldb=*/0, /*skip_ldd=*/1,
|
||||
/*skip_ex=*/1, /*skip_stc=*/1);
|
||||
constexpr uint32_t skips_mvout_spad =
|
||||
loop_matmul_skips(/*skip_lda=*/1, /*skip_ldb=*/1, /*skip_ldd=*/1,
|
||||
/*skip_ex=*/1, /*skip_stc=*/0);
|
||||
constexpr uint32_t skips_matmul =
|
||||
loop_matmul_skips(/*skip_lda=*/1, /*skip_ldb=*/1, /*skip_ldd=*/1,
|
||||
/*skip_ex=*/0, /*skip_stc=*/1);
|
||||
constexpr uint32_t skips_matmul_preload =
|
||||
loop_matmul_skips(/*skip_lda=*/1, /*skip_ldb=*/1, /*skip_ldd=*/0,
|
||||
/*skip_ex=*/0, /*skip_stc=*/1);
|
||||
@@ -231,12 +237,11 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
|
||||
// other warps behave differently on the branch condition.
|
||||
// threadblock_barrier(warpgroup_id_in_cluster, warps_per_warpgroup_per_core);
|
||||
|
||||
// move Q and K into SMEM before the loop starts
|
||||
//
|
||||
static_assert(B_ROW == B_COL, "currently only supports square tiles");
|
||||
|
||||
// move Q and K into SMEM before the loop starts
|
||||
//
|
||||
asm volatile("dma_move_start_%=:" ::);
|
||||
|
||||
if (tid_in_warpgroup == 0) {
|
||||
// make sure to read from the correct row of Q
|
||||
const float *gmem_Q_tile = gmem_Q + HEADDIM * B_ROW * warpgroup_id;
|
||||
@@ -249,35 +254,48 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
|
||||
8 /*k_LOOP_WS_CONFIG_STRIDES_AB*/);
|
||||
gemmini_fence();
|
||||
|
||||
#define GEMMINI_DMA_CISC
|
||||
// #define GEMMINI_DMA_CISC
|
||||
#ifdef GEMMINI_DMA_CISC
|
||||
// the target addresses of this should match with spad_addr_Q0 and
|
||||
// spad_addr_K0 set in this kernel
|
||||
GEMMINI_CISC_CMD_I(10);
|
||||
gemmini_fence();
|
||||
|
||||
// need to also move to Q1 for the next iteration
|
||||
ROCC_INSTRUCTION_RS1_RS2(XCUSTOM_ACC, (uint64_t)(gmem_Q_tile),
|
||||
(uint64_t)(gmem_K_tile), k_LOOP_WS_CONFIG_ADDRS_AB)
|
||||
GEMMINI_CISC_CMD_R((dim_seqlen << 20) | (HEADDIM << 8) |
|
||||
8 /*k_LOOP_WS_CONFIG_STRIDES_AB*/);
|
||||
gemmini_fence();
|
||||
GEMMINI_CISC_CMD_I(11);
|
||||
gemmini_fence();
|
||||
#else
|
||||
// do DMA
|
||||
//
|
||||
// among other things, this also configures CONFIG_BOUNDS so that the
|
||||
// DMA knows the full matrix dimensions
|
||||
sp_tiled_matmul_full_spad_ws(
|
||||
spad_addr_Q, spad_addr_K,
|
||||
/*spad_D=*/0, /*spad_C=*/spad_addr_S,
|
||||
spad_addr_Q0, spad_addr_K0,
|
||||
/*spad_D=*/0, /*spad_C=*/spad_addr_S0/*bogus*/,
|
||||
/*I=*/(B_ROW / DIM), /*J=*/(B_COL / DIM), /*K=*/(HEADDIM / DIM),
|
||||
/*pad_I=*/0, /*pad_J=*/0, /*pad_K=*/0,
|
||||
/*a_transpose=*/0, /*b_transpose=*/0, /*full_C=*/0, /*low_D=*/0,
|
||||
/*acc=*/0, /*act=*/NO_ACTIVATION, /*skips=*/skips);
|
||||
gemmini_fence();
|
||||
#endif
|
||||
gemmini_fence();
|
||||
|
||||
// need to also move Q to spad_addr_Q1 for the next iteration
|
||||
// FIXME: re-configure necessary?
|
||||
gmem_K_tile = gmem_K + (B_COL * 1);
|
||||
ROCC_INSTRUCTION_RS1_RS2(XCUSTOM_ACC, (uint64_t)(gmem_Q_tile),
|
||||
(uint64_t)(gmem_K_tile), k_LOOP_WS_CONFIG_ADDRS_AB)
|
||||
GEMMINI_CISC_CMD_R((dim_seqlen << 20) | (HEADDIM << 8) |
|
||||
8 /*k_LOOP_WS_CONFIG_STRIDES_AB*/);
|
||||
gemmini_fence();
|
||||
#ifdef GEMMINI_DMA_CISC
|
||||
// GEMMINI_CISC_CMD_I(11);
|
||||
#else
|
||||
sp_tiled_matmul_full_spad_ws(
|
||||
spad_addr_Q1, spad_addr_K1/*bogus*/,
|
||||
/*spad_D=*/0, /*spad_C=*/spad_addr_S0/*bogus*/,
|
||||
/*I=*/(B_ROW / DIM), /*J=*/(B_COL / DIM), /*K=*/(HEADDIM / DIM),
|
||||
/*pad_I=*/0, /*pad_J=*/0, /*pad_K=*/0,
|
||||
/*a_transpose=*/0, /*b_transpose=*/0, /*full_C=*/0, /*low_D=*/0,
|
||||
/*acc=*/0, /*act=*/NO_ACTIVATION, /*skips=*/skips_only_a);
|
||||
#endif
|
||||
|
||||
gemmini_fence();
|
||||
gemmini_fence();
|
||||
|
||||
// re-configure DMA for K and V load that will later happen in the loop
|
||||
// GMEM addr stride for K
|
||||
@@ -376,6 +394,7 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
|
||||
// do matmul
|
||||
// among other things, this also configures CONFIG_BOUNDS so that the
|
||||
// DMA knows the full matrix dimensions
|
||||
gemmini_fence();
|
||||
sp_tiled_matmul_full_spad_ws(
|
||||
spad_addr_P_consume, spad_addr_V_consume,
|
||||
/*spad_D=*/spad_addr_O, /*spad_C=*/spad_addr_O,
|
||||
@@ -437,11 +456,18 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
|
||||
asm volatile("gemm_qk_start_%=:" ::);
|
||||
|
||||
if (tid_in_warpgroup == 0) {
|
||||
gemmini_fence();
|
||||
// 0,2,.: opcode 0 (quartile 0/2, no accum)
|
||||
// 1,3,.: opcode 3 (quartile 1/3, no accum)
|
||||
const uint32_t opcode = 3 * (tile_k & 1);
|
||||
gemmini_fence();
|
||||
GEMMINI_CISC_CMD_I(opcode);
|
||||
//GEMMINI_CISC_CMD_I(opcode);
|
||||
sp_tiled_matmul_full_spad_ws(
|
||||
spad_addr_Q, spad_addr_K_consume,
|
||||
/*spad_D=*/0, /*spad_C=*/spad_addr_S_produce,
|
||||
/*I=*/(B_ROW / DIM), /*J=*/(B_COL / DIM), /*K=*/(HEADDIM / DIM),
|
||||
/*pad_I=*/0, /*pad_J=*/0, /*pad_K=*/0,
|
||||
/*a_transpose=*/0, /*b_transpose=*/0, /*full_C=*/0, /*low_D=*/0,
|
||||
/*acc=*/0, /*act=*/NO_ACTIVATION, /*skips=*/skips_matmul);
|
||||
|
||||
gemmini_fence();
|
||||
gemmini_fence();
|
||||
@@ -574,7 +600,7 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
|
||||
}
|
||||
}
|
||||
|
||||
// data movement for K and V
|
||||
// data move for K and V
|
||||
//
|
||||
// Q stays in SMEM for the entire loop
|
||||
asm volatile("move_k_v_start_%=:" ::);
|
||||
@@ -606,7 +632,7 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
|
||||
sp_tiled_matmul_full_spad_ws(
|
||||
spad_addr_K_produce, spad_addr_V_produce,
|
||||
/*spad_D=*/0, /*spad_C=*/spad_addr_S_produce /*FIXME:bogus*/,
|
||||
/*I=*/(HEADDIM / DIM), /*J=*/(HEADDIM / DIM), /*K=*/(B_COL / DIM),
|
||||
/*I=*/(B_ROW / DIM), /*J=*/(HEADDIM / DIM), /*K=*/(B_COL / DIM),
|
||||
/*pad_I=*/0, /*pad_J=*/0, /*pad_K=*/0,
|
||||
/*a_transpose=*/0, /*b_transpose=*/0, /*full_C=*/0, /*low_D=*/0,
|
||||
/*acc=*/0, /*act=*/NO_ACTIVATION, /*skips=*/skips_only_a);
|
||||
@@ -614,12 +640,14 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
|
||||
sp_tiled_matmul_full_spad_ws(
|
||||
spad_addr_K_produce, spad_addr_V_produce,
|
||||
/*spad_D=*/0, /*spad_C=*/spad_addr_S_produce /*FIXME:bogus*/,
|
||||
/*I=*/(HEADDIM / DIM), /*J=*/(HEADDIM / DIM), /*K=*/(B_COL / DIM),
|
||||
/*I=*/(B_ROW / DIM), /*J=*/(HEADDIM / DIM), /*K=*/(B_COL / DIM),
|
||||
/*pad_I=*/0, /*pad_J=*/0, /*pad_K=*/0,
|
||||
/*a_transpose=*/0, /*b_transpose=*/0, /*full_C=*/0, /*low_D=*/0,
|
||||
/*acc=*/0, /*act=*/NO_ACTIVATION, /*skips=*/skips);
|
||||
}
|
||||
gemmini_fence();
|
||||
gemmini_fence();
|
||||
gemmini_fence();
|
||||
}
|
||||
|
||||
threadblock_barrier(warpgroup_id_in_cluster, warps_per_warpgroup_per_core);
|
||||
|
||||
Reference in New Issue
Block a user