flash: Replace CISC with RISC

spadQuartile in the hardware does not yet match the spad addresses used in
the kernel; align them later as an optimization.
This commit is contained in:
Hansung Kim
2024-09-08 20:52:28 -07:00
parent 6547e92757
commit a4dd45bc1b

View File

@@ -199,9 +199,15 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
constexpr uint32_t skips_only_a = constexpr uint32_t skips_only_a =
loop_matmul_skips(/*skip_lda=*/0, /*skip_ldb=*/1, /*skip_ldd=*/1, loop_matmul_skips(/*skip_lda=*/0, /*skip_ldb=*/1, /*skip_ldd=*/1,
/*skip_ex=*/1, /*skip_stc=*/1); /*skip_ex=*/1, /*skip_stc=*/1);
constexpr uint32_t skips_only_b =
loop_matmul_skips(/*skip_lda=*/1, /*skip_ldb=*/0, /*skip_ldd=*/1,
/*skip_ex=*/1, /*skip_stc=*/1);
constexpr uint32_t skips_mvout_spad = constexpr uint32_t skips_mvout_spad =
loop_matmul_skips(/*skip_lda=*/1, /*skip_ldb=*/1, /*skip_ldd=*/1, loop_matmul_skips(/*skip_lda=*/1, /*skip_ldb=*/1, /*skip_ldd=*/1,
/*skip_ex=*/1, /*skip_stc=*/0); /*skip_ex=*/1, /*skip_stc=*/0);
constexpr uint32_t skips_matmul =
loop_matmul_skips(/*skip_lda=*/1, /*skip_ldb=*/1, /*skip_ldd=*/1,
/*skip_ex=*/0, /*skip_stc=*/1);
constexpr uint32_t skips_matmul_preload = constexpr uint32_t skips_matmul_preload =
loop_matmul_skips(/*skip_lda=*/1, /*skip_ldb=*/1, /*skip_ldd=*/0, loop_matmul_skips(/*skip_lda=*/1, /*skip_ldb=*/1, /*skip_ldd=*/0,
/*skip_ex=*/0, /*skip_stc=*/1); /*skip_ex=*/0, /*skip_stc=*/1);
@@ -231,12 +237,11 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
// other warps behave differently on the branch condition. // other warps behave differently on the branch condition.
// threadblock_barrier(warpgroup_id_in_cluster, warps_per_warpgroup_per_core); // threadblock_barrier(warpgroup_id_in_cluster, warps_per_warpgroup_per_core);
// move Q and K into SMEM before the loop starts
//
static_assert(B_ROW == B_COL, "currently only supports square tiles"); static_assert(B_ROW == B_COL, "currently only supports square tiles");
// move Q and K into SMEM before the loop starts
//
asm volatile("dma_move_start_%=:" ::); asm volatile("dma_move_start_%=:" ::);
if (tid_in_warpgroup == 0) { if (tid_in_warpgroup == 0) {
// make sure to read from the correct row of Q // make sure to read from the correct row of Q
const float *gmem_Q_tile = gmem_Q + HEADDIM * B_ROW * warpgroup_id; const float *gmem_Q_tile = gmem_Q + HEADDIM * B_ROW * warpgroup_id;
@@ -249,35 +254,48 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
8 /*k_LOOP_WS_CONFIG_STRIDES_AB*/); 8 /*k_LOOP_WS_CONFIG_STRIDES_AB*/);
gemmini_fence(); gemmini_fence();
#define GEMMINI_DMA_CISC // #define GEMMINI_DMA_CISC
#ifdef GEMMINI_DMA_CISC #ifdef GEMMINI_DMA_CISC
// the target addresses of this should match with spad_addr_Q0 and // the target addresses of this should match with spad_addr_Q0 and
// spad_addr_K0 set in this kernel // spad_addr_K0 set in this kernel
GEMMINI_CISC_CMD_I(10); GEMMINI_CISC_CMD_I(10);
gemmini_fence();
// need to also move to Q1 for the next iteration
ROCC_INSTRUCTION_RS1_RS2(XCUSTOM_ACC, (uint64_t)(gmem_Q_tile),
(uint64_t)(gmem_K_tile), k_LOOP_WS_CONFIG_ADDRS_AB)
GEMMINI_CISC_CMD_R((dim_seqlen << 20) | (HEADDIM << 8) |
8 /*k_LOOP_WS_CONFIG_STRIDES_AB*/);
gemmini_fence();
GEMMINI_CISC_CMD_I(11);
gemmini_fence();
#else #else
// do DMA // do DMA
// //
// among other things, this also configures CONFIG_BOUNDS so that the // among other things, this also configures CONFIG_BOUNDS so that the
// DMA knows the full matrix dimensions // DMA knows the full matrix dimensions
sp_tiled_matmul_full_spad_ws( sp_tiled_matmul_full_spad_ws(
spad_addr_Q, spad_addr_K, spad_addr_Q0, spad_addr_K0,
/*spad_D=*/0, /*spad_C=*/spad_addr_S, /*spad_D=*/0, /*spad_C=*/spad_addr_S0/*bogus*/,
/*I=*/(B_ROW / DIM), /*J=*/(B_COL / DIM), /*K=*/(HEADDIM / DIM), /*I=*/(B_ROW / DIM), /*J=*/(B_COL / DIM), /*K=*/(HEADDIM / DIM),
/*pad_I=*/0, /*pad_J=*/0, /*pad_K=*/0, /*pad_I=*/0, /*pad_J=*/0, /*pad_K=*/0,
/*a_transpose=*/0, /*b_transpose=*/0, /*full_C=*/0, /*low_D=*/0, /*a_transpose=*/0, /*b_transpose=*/0, /*full_C=*/0, /*low_D=*/0,
/*acc=*/0, /*act=*/NO_ACTIVATION, /*skips=*/skips); /*acc=*/0, /*act=*/NO_ACTIVATION, /*skips=*/skips);
gemmini_fence();
#endif #endif
gemmini_fence();
// need to also move Q to spad_addr_Q1 for the next iteration
// FIXME: re-configure necessary?
gmem_K_tile = gmem_K + (B_COL * 1);
ROCC_INSTRUCTION_RS1_RS2(XCUSTOM_ACC, (uint64_t)(gmem_Q_tile),
(uint64_t)(gmem_K_tile), k_LOOP_WS_CONFIG_ADDRS_AB)
GEMMINI_CISC_CMD_R((dim_seqlen << 20) | (HEADDIM << 8) |
8 /*k_LOOP_WS_CONFIG_STRIDES_AB*/);
gemmini_fence();
#ifdef GEMMINI_DMA_CISC
// GEMMINI_CISC_CMD_I(11);
#else
sp_tiled_matmul_full_spad_ws(
spad_addr_Q1, spad_addr_K1/*bogus*/,
/*spad_D=*/0, /*spad_C=*/spad_addr_S0/*bogus*/,
/*I=*/(B_ROW / DIM), /*J=*/(B_COL / DIM), /*K=*/(HEADDIM / DIM),
/*pad_I=*/0, /*pad_J=*/0, /*pad_K=*/0,
/*a_transpose=*/0, /*b_transpose=*/0, /*full_C=*/0, /*low_D=*/0,
/*acc=*/0, /*act=*/NO_ACTIVATION, /*skips=*/skips_only_a);
#endif
gemmini_fence();
gemmini_fence();
// re-configure DMA for K and V load that will later happen in the loop // re-configure DMA for K and V load that will later happen in the loop
// GMEM addr stride for K // GMEM addr stride for K
@@ -376,6 +394,7 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
// do matmul // do matmul
// among other things, this also configures CONFIG_BOUNDS so that the // among other things, this also configures CONFIG_BOUNDS so that the
// DMA knows the full matrix dimensions // DMA knows the full matrix dimensions
gemmini_fence();
sp_tiled_matmul_full_spad_ws( sp_tiled_matmul_full_spad_ws(
spad_addr_P_consume, spad_addr_V_consume, spad_addr_P_consume, spad_addr_V_consume,
/*spad_D=*/spad_addr_O, /*spad_C=*/spad_addr_O, /*spad_D=*/spad_addr_O, /*spad_C=*/spad_addr_O,
@@ -437,11 +456,18 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
asm volatile("gemm_qk_start_%=:" ::); asm volatile("gemm_qk_start_%=:" ::);
if (tid_in_warpgroup == 0) { if (tid_in_warpgroup == 0) {
gemmini_fence();
// 0,2,.: opcode 0 (quartile 0/2, no accum) // 0,2,.: opcode 0 (quartile 0/2, no accum)
// 1,3,.: opcode 3 (quartile 1/3, no accum) // 1,3,.: opcode 3 (quartile 1/3, no accum)
const uint32_t opcode = 3 * (tile_k & 1); const uint32_t opcode = 3 * (tile_k & 1);
gemmini_fence(); //GEMMINI_CISC_CMD_I(opcode);
GEMMINI_CISC_CMD_I(opcode); sp_tiled_matmul_full_spad_ws(
spad_addr_Q, spad_addr_K_consume,
/*spad_D=*/0, /*spad_C=*/spad_addr_S_produce,
/*I=*/(B_ROW / DIM), /*J=*/(B_COL / DIM), /*K=*/(HEADDIM / DIM),
/*pad_I=*/0, /*pad_J=*/0, /*pad_K=*/0,
/*a_transpose=*/0, /*b_transpose=*/0, /*full_C=*/0, /*low_D=*/0,
/*acc=*/0, /*act=*/NO_ACTIVATION, /*skips=*/skips_matmul);
gemmini_fence(); gemmini_fence();
gemmini_fence(); gemmini_fence();
@@ -574,7 +600,7 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
} }
} }
// data movement for K and V // data move for K and V
// //
// Q stays in SMEM for the entire loop // Q stays in SMEM for the entire loop
asm volatile("move_k_v_start_%=:" ::); asm volatile("move_k_v_start_%=:" ::);
@@ -606,7 +632,7 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
sp_tiled_matmul_full_spad_ws( sp_tiled_matmul_full_spad_ws(
spad_addr_K_produce, spad_addr_V_produce, spad_addr_K_produce, spad_addr_V_produce,
/*spad_D=*/0, /*spad_C=*/spad_addr_S_produce /*FIXME:bogus*/, /*spad_D=*/0, /*spad_C=*/spad_addr_S_produce /*FIXME:bogus*/,
/*I=*/(HEADDIM / DIM), /*J=*/(HEADDIM / DIM), /*K=*/(B_COL / DIM), /*I=*/(B_ROW / DIM), /*J=*/(HEADDIM / DIM), /*K=*/(B_COL / DIM),
/*pad_I=*/0, /*pad_J=*/0, /*pad_K=*/0, /*pad_I=*/0, /*pad_J=*/0, /*pad_K=*/0,
/*a_transpose=*/0, /*b_transpose=*/0, /*full_C=*/0, /*low_D=*/0, /*a_transpose=*/0, /*b_transpose=*/0, /*full_C=*/0, /*low_D=*/0,
/*acc=*/0, /*act=*/NO_ACTIVATION, /*skips=*/skips_only_a); /*acc=*/0, /*act=*/NO_ACTIVATION, /*skips=*/skips_only_a);
@@ -614,12 +640,14 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
sp_tiled_matmul_full_spad_ws( sp_tiled_matmul_full_spad_ws(
spad_addr_K_produce, spad_addr_V_produce, spad_addr_K_produce, spad_addr_V_produce,
/*spad_D=*/0, /*spad_C=*/spad_addr_S_produce /*FIXME:bogus*/, /*spad_D=*/0, /*spad_C=*/spad_addr_S_produce /*FIXME:bogus*/,
/*I=*/(HEADDIM / DIM), /*J=*/(HEADDIM / DIM), /*K=*/(B_COL / DIM), /*I=*/(B_ROW / DIM), /*J=*/(HEADDIM / DIM), /*K=*/(B_COL / DIM),
/*pad_I=*/0, /*pad_J=*/0, /*pad_K=*/0, /*pad_I=*/0, /*pad_J=*/0, /*pad_K=*/0,
/*a_transpose=*/0, /*b_transpose=*/0, /*full_C=*/0, /*low_D=*/0, /*a_transpose=*/0, /*b_transpose=*/0, /*full_C=*/0, /*low_D=*/0,
/*acc=*/0, /*act=*/NO_ACTIVATION, /*skips=*/skips); /*acc=*/0, /*act=*/NO_ACTIVATION, /*skips=*/skips);
} }
gemmini_fence(); gemmini_fence();
gemmini_fence();
gemmini_fence();
} }
threadblock_barrier(warpgroup_id_in_cluster, warps_per_warpgroup_per_core); threadblock_barrier(warpgroup_id_in_cluster, warps_per_warpgroup_per_core);