flash: Replace CISC with RISC

spadQuartile in hw does not match spad addresses in kernel; match them
later for optimization.
This commit is contained in:
Hansung Kim
2024-09-08 20:52:28 -07:00
parent 6547e92757
commit a4dd45bc1b

View File

@@ -199,9 +199,15 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
constexpr uint32_t skips_only_a =
loop_matmul_skips(/*skip_lda=*/0, /*skip_ldb=*/1, /*skip_ldd=*/1,
/*skip_ex=*/1, /*skip_stc=*/1);
constexpr uint32_t skips_only_b =
loop_matmul_skips(/*skip_lda=*/1, /*skip_ldb=*/0, /*skip_ldd=*/1,
/*skip_ex=*/1, /*skip_stc=*/1);
constexpr uint32_t skips_mvout_spad =
loop_matmul_skips(/*skip_lda=*/1, /*skip_ldb=*/1, /*skip_ldd=*/1,
/*skip_ex=*/1, /*skip_stc=*/0);
constexpr uint32_t skips_matmul =
loop_matmul_skips(/*skip_lda=*/1, /*skip_ldb=*/1, /*skip_ldd=*/1,
/*skip_ex=*/0, /*skip_stc=*/1);
constexpr uint32_t skips_matmul_preload =
loop_matmul_skips(/*skip_lda=*/1, /*skip_ldb=*/1, /*skip_ldd=*/0,
/*skip_ex=*/0, /*skip_stc=*/1);
@@ -231,12 +237,11 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
// other warps behave differently on the branch condition.
// threadblock_barrier(warpgroup_id_in_cluster, warps_per_warpgroup_per_core);
// move Q and K into SMEM before the loop starts
//
static_assert(B_ROW == B_COL, "currently only supports square tiles");
// move Q and K into SMEM before the loop starts
//
asm volatile("dma_move_start_%=:" ::);
if (tid_in_warpgroup == 0) {
// make sure to read from the correct row of Q
const float *gmem_Q_tile = gmem_Q + HEADDIM * B_ROW * warpgroup_id;
@@ -249,35 +254,48 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
8 /*k_LOOP_WS_CONFIG_STRIDES_AB*/);
gemmini_fence();
#define GEMMINI_DMA_CISC
// #define GEMMINI_DMA_CISC
#ifdef GEMMINI_DMA_CISC
// the target addresses of this should match with spad_addr_Q0 and
// spad_addr_K0 set in this kernel
GEMMINI_CISC_CMD_I(10);
gemmini_fence();
// need to also move to Q1 for the next iteration
ROCC_INSTRUCTION_RS1_RS2(XCUSTOM_ACC, (uint64_t)(gmem_Q_tile),
(uint64_t)(gmem_K_tile), k_LOOP_WS_CONFIG_ADDRS_AB)
GEMMINI_CISC_CMD_R((dim_seqlen << 20) | (HEADDIM << 8) |
8 /*k_LOOP_WS_CONFIG_STRIDES_AB*/);
gemmini_fence();
GEMMINI_CISC_CMD_I(11);
gemmini_fence();
#else
// do DMA
//
// among other things, this also configures CONFIG_BOUNDS so that the
// DMA knows the full matrix dimensions
sp_tiled_matmul_full_spad_ws(
spad_addr_Q, spad_addr_K,
/*spad_D=*/0, /*spad_C=*/spad_addr_S,
spad_addr_Q0, spad_addr_K0,
/*spad_D=*/0, /*spad_C=*/spad_addr_S0/*bogus*/,
/*I=*/(B_ROW / DIM), /*J=*/(B_COL / DIM), /*K=*/(HEADDIM / DIM),
/*pad_I=*/0, /*pad_J=*/0, /*pad_K=*/0,
/*a_transpose=*/0, /*b_transpose=*/0, /*full_C=*/0, /*low_D=*/0,
/*acc=*/0, /*act=*/NO_ACTIVATION, /*skips=*/skips);
gemmini_fence();
#endif
gemmini_fence();
// need to also move Q to spad_addr_Q1 for the next iteration
// FIXME: re-configure necessary?
gmem_K_tile = gmem_K + (B_COL * 1);
ROCC_INSTRUCTION_RS1_RS2(XCUSTOM_ACC, (uint64_t)(gmem_Q_tile),
(uint64_t)(gmem_K_tile), k_LOOP_WS_CONFIG_ADDRS_AB)
GEMMINI_CISC_CMD_R((dim_seqlen << 20) | (HEADDIM << 8) |
8 /*k_LOOP_WS_CONFIG_STRIDES_AB*/);
gemmini_fence();
#ifdef GEMMINI_DMA_CISC
// GEMMINI_CISC_CMD_I(11);
#else
sp_tiled_matmul_full_spad_ws(
spad_addr_Q1, spad_addr_K1/*bogus*/,
/*spad_D=*/0, /*spad_C=*/spad_addr_S0/*bogus*/,
/*I=*/(B_ROW / DIM), /*J=*/(B_COL / DIM), /*K=*/(HEADDIM / DIM),
/*pad_I=*/0, /*pad_J=*/0, /*pad_K=*/0,
/*a_transpose=*/0, /*b_transpose=*/0, /*full_C=*/0, /*low_D=*/0,
/*acc=*/0, /*act=*/NO_ACTIVATION, /*skips=*/skips_only_a);
#endif
gemmini_fence();
gemmini_fence();
// re-configure DMA for K and V load that will later happen in the loop
// GMEM addr stride for K
@@ -376,6 +394,7 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
// do matmul
// among other things, this also configures CONFIG_BOUNDS so that the
// DMA knows the full matrix dimensions
gemmini_fence();
sp_tiled_matmul_full_spad_ws(
spad_addr_P_consume, spad_addr_V_consume,
/*spad_D=*/spad_addr_O, /*spad_C=*/spad_addr_O,
@@ -437,11 +456,18 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
asm volatile("gemm_qk_start_%=:" ::);
if (tid_in_warpgroup == 0) {
gemmini_fence();
// 0,2,.: opcode 0 (quartile 0/2, no accum)
// 1,3,.: opcode 3 (quartile 1/3, no accum)
const uint32_t opcode = 3 * (tile_k & 1);
gemmini_fence();
GEMMINI_CISC_CMD_I(opcode);
//GEMMINI_CISC_CMD_I(opcode);
sp_tiled_matmul_full_spad_ws(
spad_addr_Q, spad_addr_K_consume,
/*spad_D=*/0, /*spad_C=*/spad_addr_S_produce,
/*I=*/(B_ROW / DIM), /*J=*/(B_COL / DIM), /*K=*/(HEADDIM / DIM),
/*pad_I=*/0, /*pad_J=*/0, /*pad_K=*/0,
/*a_transpose=*/0, /*b_transpose=*/0, /*full_C=*/0, /*low_D=*/0,
/*acc=*/0, /*act=*/NO_ACTIVATION, /*skips=*/skips_matmul);
gemmini_fence();
gemmini_fence();
@@ -574,7 +600,7 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
}
}
// data movement for K and V
// data move for K and V
//
// Q stays in SMEM for the entire loop
asm volatile("move_k_v_start_%=:" ::);
@@ -606,7 +632,7 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
sp_tiled_matmul_full_spad_ws(
spad_addr_K_produce, spad_addr_V_produce,
/*spad_D=*/0, /*spad_C=*/spad_addr_S_produce /*FIXME:bogus*/,
/*I=*/(HEADDIM / DIM), /*J=*/(HEADDIM / DIM), /*K=*/(B_COL / DIM),
/*I=*/(B_ROW / DIM), /*J=*/(HEADDIM / DIM), /*K=*/(B_COL / DIM),
/*pad_I=*/0, /*pad_J=*/0, /*pad_K=*/0,
/*a_transpose=*/0, /*b_transpose=*/0, /*full_C=*/0, /*low_D=*/0,
/*acc=*/0, /*act=*/NO_ACTIVATION, /*skips=*/skips_only_a);
@@ -614,12 +640,14 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
sp_tiled_matmul_full_spad_ws(
spad_addr_K_produce, spad_addr_V_produce,
/*spad_D=*/0, /*spad_C=*/spad_addr_S_produce /*FIXME:bogus*/,
/*I=*/(HEADDIM / DIM), /*J=*/(HEADDIM / DIM), /*K=*/(B_COL / DIM),
/*I=*/(B_ROW / DIM), /*J=*/(HEADDIM / DIM), /*K=*/(B_COL / DIM),
/*pad_I=*/0, /*pad_J=*/0, /*pad_K=*/0,
/*a_transpose=*/0, /*b_transpose=*/0, /*full_C=*/0, /*low_D=*/0,
/*acc=*/0, /*act=*/NO_ACTIVATION, /*skips=*/skips);
}
gemmini_fence();
gemmini_fence();
gemmini_fence();
}
threadblock_barrier(warpgroup_id_in_cluster, warps_per_warpgroup_per_core);