flash: Move fence to start of loop; wrap all MMIO in one tid=0 branch
This commit is contained in:
@@ -402,22 +402,28 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
|
|||||||
asm volatile ("dbuf_sel_end_%=:" :: );
|
asm volatile ("dbuf_sel_end_%=:" :: );
|
||||||
|
|
||||||
{
|
{
|
||||||
|
// do all of GEMM kickoffs before the SIMT compute
|
||||||
|
//
|
||||||
|
if (tid_in_warpgroup == 0) {
|
||||||
|
// fence completion of the GEMMs in the previous loop iterations. Note
|
||||||
|
// this is done at the start of the loop to maximize window of
|
||||||
|
// overlapping.
|
||||||
|
gemmini_fence();
|
||||||
|
|
||||||
if (tile_k >= 2) // delay GEMM II by 2 iters for pipelining
|
if (tile_k >= 2) // delay GEMM II by 2 iters for pipelining
|
||||||
{
|
{
|
||||||
const uint32_t tile_k_ = tile_k - 2;
|
const uint32_t tile_k_ = tile_k - 2;
|
||||||
|
|
||||||
// GEMM II: O = O + P*V
|
// GEMM II: O = O + P*V
|
||||||
// --------------------
|
// --------------------
|
||||||
// This is done *before* GEMM I in the software pipeline, working on the
|
// This is done *before* GEMM I in the software pipeline, working on
|
||||||
// online softmax result tile from the previous iteration
|
// the online softmax result tile from the previous iteration
|
||||||
|
|
||||||
asm volatile("gemm_pv_start_%=:" ::);
|
asm volatile("gemm_pv_start_%=:" ::);
|
||||||
|
|
||||||
if (tid_in_warpgroup == 0) {
|
// kick off GEMM II
|
||||||
// kickoff matmul
|
//
|
||||||
|
|
||||||
// FIXME: perf: prevent GMEM->SMEM load for O tile
|
// FIXME: perf: prevent GMEM->SMEM load for O tile
|
||||||
gemmini_fence();
|
|
||||||
#ifdef GEMMINI_NEW_CISC
|
#ifdef GEMMINI_NEW_CISC
|
||||||
gemmini_tile_compute</*store_to_spad=*/true>(
|
gemmini_tile_compute</*store_to_spad=*/true>(
|
||||||
spad_hex_P_consume, spad_hex_V_consume, spad_hex_O,
|
spad_hex_P_consume, spad_hex_V_consume, spad_hex_O,
|
||||||
@@ -432,11 +438,6 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
|
|||||||
/*a_transpose=*/0, /*b_transpose=*/0, /*full_C=*/0, /*low_D=*/0,
|
/*a_transpose=*/0, /*b_transpose=*/0, /*full_C=*/0, /*low_D=*/0,
|
||||||
/*acc=*/0, /*act=*/NO_ACTIVATION, /*skips=*/skips_matmul);
|
/*acc=*/0, /*act=*/NO_ACTIVATION, /*skips=*/skips_matmul);
|
||||||
#endif
|
#endif
|
||||||
}
|
|
||||||
|
|
||||||
// // reconverge from mmio divergence
|
|
||||||
// threadblock_barrier(warpgroup_id_in_cluster,
|
|
||||||
// warps_per_warpgroup_per_core);
|
|
||||||
|
|
||||||
asm volatile("gemm_pv_finish_%=:" ::);
|
asm volatile("gemm_pv_finish_%=:" ::);
|
||||||
}
|
}
|
||||||
@@ -446,7 +447,6 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
|
|||||||
// kick off asynchronously; fence later
|
// kick off asynchronously; fence later
|
||||||
asm volatile("gemm_qk_start_%=:" ::);
|
asm volatile("gemm_qk_start_%=:" ::);
|
||||||
|
|
||||||
if (tid_in_warpgroup == 0) {
|
|
||||||
// FIXME: remove
|
// FIXME: remove
|
||||||
// // fence to GEMM II completion
|
// // fence to GEMM II completion
|
||||||
// gemmini_fence();
|
// gemmini_fence();
|
||||||
@@ -492,9 +492,12 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
|
|||||||
|
|
||||||
// do DMA
|
// do DMA
|
||||||
if (tile_k == 0) {
|
if (tile_k == 0) {
|
||||||
|
// commented out as we do two move-ins before the loop starts
|
||||||
|
//
|
||||||
// // configure address strides for the DMA
|
// // configure address strides for the DMA
|
||||||
// // FIXME: unnecessary?
|
// // FIXME: unnecessary?
|
||||||
// GEMMINI_CISC_CMD_R((HEADDIM /*V*/ << 20) | (dim_seqlen /*KT*/ << 8) |
|
// GEMMINI_CISC_CMD_R((HEADDIM /*V*/ << 20) | (dim_seqlen /*KT*/ <<
|
||||||
|
// 8) |
|
||||||
// 8 /*k_LOOP_WS_CONFIG_STRIDES_AB*/);
|
// 8 /*k_LOOP_WS_CONFIG_STRIDES_AB*/);
|
||||||
// gemmini_fence();
|
// gemmini_fence();
|
||||||
//
|
//
|
||||||
@@ -502,9 +505,11 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
|
|||||||
// sp_tiled_matmul_full_spad_ws(
|
// sp_tiled_matmul_full_spad_ws(
|
||||||
// spad_addr_K_produce, spad_addr_V_produce,
|
// spad_addr_K_produce, spad_addr_V_produce,
|
||||||
// /*spad_D=*/0, /*spad_C=*/0,
|
// /*spad_D=*/0, /*spad_C=*/0,
|
||||||
// /*I=*/(B_ROW / DIM), /*J=*/(HEADDIM / DIM), /*K=*/(B_COL / DIM),
|
// /*I=*/(B_ROW / DIM), /*J=*/(HEADDIM / DIM), /*K=*/(B_COL /
|
||||||
|
// DIM),
|
||||||
// /*pad_I=*/0, /*pad_J=*/0, /*pad_K=*/0,
|
// /*pad_I=*/0, /*pad_J=*/0, /*pad_K=*/0,
|
||||||
// /*a_transpose=*/0, /*b_transpose=*/0, /*full_C=*/0, /*low_D=*/0,
|
// /*a_transpose=*/0, /*b_transpose=*/0, /*full_C=*/0,
|
||||||
|
// /*low_D=*/0,
|
||||||
// /*acc=*/0, /*act=*/NO_ACTIVATION, /*skips=*/skips_only_a);
|
// /*acc=*/0, /*act=*/NO_ACTIVATION, /*skips=*/skips_only_a);
|
||||||
} else {
|
} else {
|
||||||
#ifdef GEMMINI_NEW_CISC
|
#ifdef GEMMINI_NEW_CISC
|
||||||
@@ -528,21 +533,13 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
|
|||||||
/*acc=*/0, /*act=*/NO_ACTIVATION, /*skips=*/skips);
|
/*acc=*/0, /*act=*/NO_ACTIVATION, /*skips=*/skips);
|
||||||
#endif
|
#endif
|
||||||
}
|
}
|
||||||
|
|
||||||
|
asm volatile("move_k_v_finish_%=:" ::);
|
||||||
}
|
}
|
||||||
|
|
||||||
// reconverge from mmio divergence
|
// reconverge from mmio divergence
|
||||||
threadblock_barrier(warpgroup_id_in_cluster,
|
threadblock_barrier(warpgroup_id_in_cluster,
|
||||||
warps_per_warpgroup_per_core);
|
warps_per_warpgroup_per_core);
|
||||||
|
|
||||||
asm volatile("move_k_v_finish_%=:" ::);
|
|
||||||
|
|
||||||
// FIXME: remove for nowarpspec
|
|
||||||
//
|
|
||||||
// NOTE: cannot put barrier here; thread 1-7 in warp 0 will skip the
|
|
||||||
// branch and call this barrier earlier than when thread 0 finishes.
|
|
||||||
// Since tmask is not considered, that will be a barrier resolve done too
|
|
||||||
// early
|
|
||||||
// threadblock_barrier(0, 1);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
{
|
{
|
||||||
@@ -630,8 +627,13 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
|
|||||||
|
|
||||||
if constexpr (DEBUG) {
|
if constexpr (DEBUG) {
|
||||||
if (warpgroup_id_in_cluster == 0) {
|
if (warpgroup_id_in_cluster == 0) {
|
||||||
|
if (tid_in_warpgroup == 0) {
|
||||||
gemmini_fence();
|
gemmini_fence();
|
||||||
gemmini_fence();
|
gemmini_fence();
|
||||||
|
}
|
||||||
|
// reconverge from mmio divergence
|
||||||
|
threadblock_barrier(warpgroup_id_in_cluster,
|
||||||
|
warps_per_warpgroup_per_core);
|
||||||
|
|
||||||
// O after PV
|
// O after PV
|
||||||
if (tile_k_ == 1 /*wait until GEMM II finshes */) {
|
if (tile_k_ == 1 /*wait until GEMM II finshes */) {
|
||||||
@@ -687,8 +689,10 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
|
|||||||
threadblock_barrier(warpgroup_id_in_cluster,
|
threadblock_barrier(warpgroup_id_in_cluster,
|
||||||
warps_per_warpgroup_per_core);
|
warps_per_warpgroup_per_core);
|
||||||
|
|
||||||
// fence everything before going to the next tile
|
|
||||||
gemmini_fence();
|
// instead of fencing here, we fence at the start of the loop to maximize
|
||||||
|
// overlapping
|
||||||
|
// gemmini_fence();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user