flash: Remove unnecessary dmem preload, fix rowmax/rowsum dependency

This commit is contained in:
Hansung Kim
2024-09-08 21:11:59 -07:00
parent a4dd45bc1b
commit 6911843a82

View File

@@ -216,10 +216,11 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
if (tid_in_warpgroup == 0) { if (tid_in_warpgroup == 0) {
gemmini_extended_config_ex(WEIGHT_STATIONARY, 0, 0, 1, 0, 0); gemmini_extended_config_ex(WEIGHT_STATIONARY, 0, 0, 1, 0, 0);
// configure DMA for the full Q matrix // configure DMA with GMEM address strides
// Q matrix
gemmini_extended3_config_ld(HEADDIM * sizeof(elem_t), MVIN_SCALE_IDENTITY, gemmini_extended3_config_ld(HEADDIM * sizeof(elem_t), MVIN_SCALE_IDENTITY,
false, 0); false, 0);
// configure DMA for the full K matrix // K matrix
gemmini_extended3_config_ld(dim_seqlen * sizeof(elem_t), MVIN_SCALE_IDENTITY, gemmini_extended3_config_ld(dim_seqlen * sizeof(elem_t), MVIN_SCALE_IDENTITY,
false, 1); false, 1);
// configure DMA for Q*K store // configure DMA for Q*K store
@@ -344,16 +345,13 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
float *smem_S_produce = (tile_k & 1) ? smem_S0 : smem_S1; float *smem_S_produce = (tile_k & 1) ? smem_S0 : smem_S1;
float *smem_P_consume = (tile_k & 1) ? smem_P1 : smem_P0; float *smem_P_consume = (tile_k & 1) ? smem_P1 : smem_P0;
float *smem_P_produce = (tile_k & 1) ? smem_P0 : smem_P1; float *smem_P_produce = (tile_k & 1) ? smem_P0 : smem_P1;
// O tile is sequentially updated at every iteration; no ping-pong // O, rowmax/rowsum etc. is sequentially updated at every iteration; no
// necessary // ping-pong necessary
float *smem_O = smem_O0; float *smem_O = smem_O0;
// FIXME: O_row_scale/rowmax/rowsum/spad shouldn't really need ping-pong float *smem_O_row_scale = smem_O_row_scale_0;
float *smem_O_row_scale = float *smem_rowmax = smem_rowmax_0;
(tile_k & 1) ? smem_O_row_scale_1 : smem_O_row_scale_0; float *smem_rowsum = smem_rowsum_0;
float *smem_rowmax = (tile_k & 1) ? smem_rowmax_1 : smem_rowmax_0; float *smem_scratchpad = smem_scratchpad_0;
float *smem_rowsum = (tile_k & 1) ? smem_rowsum_1 : smem_rowsum_0;
float *smem_scratchpad =
(tile_k & 1) ? smem_scratchpad_1 : smem_scratchpad_0;
const auto spad_addr_Q = spad_addr_Q0; const auto spad_addr_Q = spad_addr_Q0;
const auto spad_addr_K_consume = (tile_k & 1) ? spad_addr_K1 : spad_addr_K0; const auto spad_addr_K_consume = (tile_k & 1) ? spad_addr_K1 : spad_addr_K0;
@@ -394,6 +392,7 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
// do matmul // do matmul
// among other things, this also configures CONFIG_BOUNDS so that the // among other things, this also configures CONFIG_BOUNDS so that the
// DMA knows the full matrix dimensions // DMA knows the full matrix dimensions
// FIXME: perf: prevent GMEM->SMEM load for O tile
gemmini_fence(); gemmini_fence();
sp_tiled_matmul_full_spad_ws( sp_tiled_matmul_full_spad_ws(
spad_addr_P_consume, spad_addr_V_consume, spad_addr_P_consume, spad_addr_V_consume,
@@ -401,7 +400,7 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
/*I=*/(B_ROW / DIM), /*J=*/(HEADDIM / DIM), /*K=*/(B_COL / DIM), /*I=*/(B_ROW / DIM), /*J=*/(HEADDIM / DIM), /*K=*/(B_COL / DIM),
/*pad_I=*/0, /*pad_J=*/0, /*pad_K=*/0, /*pad_I=*/0, /*pad_J=*/0, /*pad_K=*/0,
/*a_transpose=*/0, /*b_transpose=*/0, /*full_C=*/0, /*low_D=*/0, /*a_transpose=*/0, /*b_transpose=*/0, /*full_C=*/0, /*low_D=*/0,
/*acc=*/0, /*act=*/NO_ACTIVATION, /*skips=*/skips_matmul_preload); /*acc=*/0, /*act=*/NO_ACTIVATION, /*skips=*/skips_matmul);
#endif #endif
gemmini_fence(); gemmini_fence();