flash: Remove unnecessary dmem preload, fix rowmax/rowsum dependency
This commit is contained in:
@@ -216,10 +216,11 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
|
||||
if (tid_in_warpgroup == 0) {
|
||||
gemmini_extended_config_ex(WEIGHT_STATIONARY, 0, 0, 1, 0, 0);
|
||||
|
||||
// configure DMA for the full Q matrix
|
||||
// configure DMA with GMEM address strides
|
||||
// Q matrix
|
||||
gemmini_extended3_config_ld(HEADDIM * sizeof(elem_t), MVIN_SCALE_IDENTITY,
|
||||
false, 0);
|
||||
// configure DMA for the full K matrix
|
||||
// K matrix
|
||||
gemmini_extended3_config_ld(dim_seqlen * sizeof(elem_t), MVIN_SCALE_IDENTITY,
|
||||
false, 1);
|
||||
// configure DMA for Q*K store
|
||||
@@ -344,16 +345,13 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
|
||||
float *smem_S_produce = (tile_k & 1) ? smem_S0 : smem_S1;
|
||||
float *smem_P_consume = (tile_k & 1) ? smem_P1 : smem_P0;
|
||||
float *smem_P_produce = (tile_k & 1) ? smem_P0 : smem_P1;
|
||||
// O tile is sequentially updated at every iteration; no ping-pong
|
||||
// necessary
|
||||
// O, rowmax/rowsum etc. is sequentially updated at every iteration; no
|
||||
// ping-pong necessary
|
||||
float *smem_O = smem_O0;
|
||||
// FIXME: O_row_scale/rowmax/rowsum/spad shouldn't really need ping-pong
|
||||
float *smem_O_row_scale =
|
||||
(tile_k & 1) ? smem_O_row_scale_1 : smem_O_row_scale_0;
|
||||
float *smem_rowmax = (tile_k & 1) ? smem_rowmax_1 : smem_rowmax_0;
|
||||
float *smem_rowsum = (tile_k & 1) ? smem_rowsum_1 : smem_rowsum_0;
|
||||
float *smem_scratchpad =
|
||||
(tile_k & 1) ? smem_scratchpad_1 : smem_scratchpad_0;
|
||||
float *smem_O_row_scale = smem_O_row_scale_0;
|
||||
float *smem_rowmax = smem_rowmax_0;
|
||||
float *smem_rowsum = smem_rowsum_0;
|
||||
float *smem_scratchpad = smem_scratchpad_0;
|
||||
|
||||
const auto spad_addr_Q = spad_addr_Q0;
|
||||
const auto spad_addr_K_consume = (tile_k & 1) ? spad_addr_K1 : spad_addr_K0;
|
||||
@@ -394,6 +392,7 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
|
||||
// do matmul
|
||||
// among other things, this also configures CONFIG_BOUNDS so that the
|
||||
// DMA knows the full matrix dimensions
|
||||
// FIXME: perf: prevent GMEM->SMEM load for O tile
|
||||
gemmini_fence();
|
||||
sp_tiled_matmul_full_spad_ws(
|
||||
spad_addr_P_consume, spad_addr_V_consume,
|
||||
@@ -401,7 +400,7 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
|
||||
/*I=*/(B_ROW / DIM), /*J=*/(HEADDIM / DIM), /*K=*/(B_COL / DIM),
|
||||
/*pad_I=*/0, /*pad_J=*/0, /*pad_K=*/0,
|
||||
/*a_transpose=*/0, /*b_transpose=*/0, /*full_C=*/0, /*low_D=*/0,
|
||||
/*acc=*/0, /*act=*/NO_ACTIVATION, /*skips=*/skips_matmul_preload);
|
||||
/*acc=*/0, /*act=*/NO_ACTIVATION, /*skips=*/skips_matmul);
|
||||
#endif
|
||||
|
||||
gemmini_fence();
|
||||
|
||||
Reference in New Issue
Block a user