flash: Remove unnecessary dmem preload, fix rowmax/rowsum dependency
This commit is contained in:
@@ -216,10 +216,11 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
|
|||||||
if (tid_in_warpgroup == 0) {
|
if (tid_in_warpgroup == 0) {
|
||||||
gemmini_extended_config_ex(WEIGHT_STATIONARY, 0, 0, 1, 0, 0);
|
gemmini_extended_config_ex(WEIGHT_STATIONARY, 0, 0, 1, 0, 0);
|
||||||
|
|
||||||
// configure DMA for the full Q matrix
|
// configure DMA with GMEM address strides
|
||||||
|
// Q matrix
|
||||||
gemmini_extended3_config_ld(HEADDIM * sizeof(elem_t), MVIN_SCALE_IDENTITY,
|
gemmini_extended3_config_ld(HEADDIM * sizeof(elem_t), MVIN_SCALE_IDENTITY,
|
||||||
false, 0);
|
false, 0);
|
||||||
// configure DMA for the full K matrix
|
// K matrix
|
||||||
gemmini_extended3_config_ld(dim_seqlen * sizeof(elem_t), MVIN_SCALE_IDENTITY,
|
gemmini_extended3_config_ld(dim_seqlen * sizeof(elem_t), MVIN_SCALE_IDENTITY,
|
||||||
false, 1);
|
false, 1);
|
||||||
// configure DMA for Q*K store
|
// configure DMA for Q*K store
|
||||||
@@ -344,16 +345,13 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
|
|||||||
float *smem_S_produce = (tile_k & 1) ? smem_S0 : smem_S1;
|
float *smem_S_produce = (tile_k & 1) ? smem_S0 : smem_S1;
|
||||||
float *smem_P_consume = (tile_k & 1) ? smem_P1 : smem_P0;
|
float *smem_P_consume = (tile_k & 1) ? smem_P1 : smem_P0;
|
||||||
float *smem_P_produce = (tile_k & 1) ? smem_P0 : smem_P1;
|
float *smem_P_produce = (tile_k & 1) ? smem_P0 : smem_P1;
|
||||||
// O tile is sequentially updated at every iteration; no ping-pong
|
// O, rowmax/rowsum etc. are sequentially updated at every iteration; no
|
||||||
// necessary
|
// ping-pong necessary
|
||||||
float *smem_O = smem_O0;
|
float *smem_O = smem_O0;
|
||||||
// FIXME: O_row_scale/rowmax/rowsum/spad shouldn't really need ping-pong
|
float *smem_O_row_scale = smem_O_row_scale_0;
|
||||||
float *smem_O_row_scale =
|
float *smem_rowmax = smem_rowmax_0;
|
||||||
(tile_k & 1) ? smem_O_row_scale_1 : smem_O_row_scale_0;
|
float *smem_rowsum = smem_rowsum_0;
|
||||||
float *smem_rowmax = (tile_k & 1) ? smem_rowmax_1 : smem_rowmax_0;
|
float *smem_scratchpad = smem_scratchpad_0;
|
||||||
float *smem_rowsum = (tile_k & 1) ? smem_rowsum_1 : smem_rowsum_0;
|
|
||||||
float *smem_scratchpad =
|
|
||||||
(tile_k & 1) ? smem_scratchpad_1 : smem_scratchpad_0;
|
|
||||||
|
|
||||||
const auto spad_addr_Q = spad_addr_Q0;
|
const auto spad_addr_Q = spad_addr_Q0;
|
||||||
const auto spad_addr_K_consume = (tile_k & 1) ? spad_addr_K1 : spad_addr_K0;
|
const auto spad_addr_K_consume = (tile_k & 1) ? spad_addr_K1 : spad_addr_K0;
|
||||||
@@ -394,6 +392,7 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
|
|||||||
// do matmul
|
// do matmul
|
||||||
// among other things, this also configures CONFIG_BOUNDS so that the
|
// among other things, this also configures CONFIG_BOUNDS so that the
|
||||||
// DMA knows the full matrix dimensions
|
// DMA knows the full matrix dimensions
|
||||||
|
// FIXME: perf: prevent GMEM->SMEM load for O tile
|
||||||
gemmini_fence();
|
gemmini_fence();
|
||||||
sp_tiled_matmul_full_spad_ws(
|
sp_tiled_matmul_full_spad_ws(
|
||||||
spad_addr_P_consume, spad_addr_V_consume,
|
spad_addr_P_consume, spad_addr_V_consume,
|
||||||
@@ -401,7 +400,7 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
|
|||||||
/*I=*/(B_ROW / DIM), /*J=*/(HEADDIM / DIM), /*K=*/(B_COL / DIM),
|
/*I=*/(B_ROW / DIM), /*J=*/(HEADDIM / DIM), /*K=*/(B_COL / DIM),
|
||||||
/*pad_I=*/0, /*pad_J=*/0, /*pad_K=*/0,
|
/*pad_I=*/0, /*pad_J=*/0, /*pad_K=*/0,
|
||||||
/*a_transpose=*/0, /*b_transpose=*/0, /*full_C=*/0, /*low_D=*/0,
|
/*a_transpose=*/0, /*b_transpose=*/0, /*full_C=*/0, /*low_D=*/0,
|
||||||
/*acc=*/0, /*act=*/NO_ACTIVATION, /*skips=*/skips_matmul_preload);
|
/*acc=*/0, /*act=*/NO_ACTIVATION, /*skips=*/skips_matmul);
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
gemmini_fence();
|
gemmini_fence();
|
||||||
|
|||||||
Reference in New Issue
Block a user