flash: Move fence to start of loop; wrap all MMIO in one tid=0 branch

This commit is contained in:
Hansung Kim
2024-11-09 20:59:26 -08:00
parent fcd8b0b892
commit 68054689c9

View File

@@ -402,22 +402,28 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
asm volatile ("dbuf_sel_end_%=:" :: ); asm volatile ("dbuf_sel_end_%=:" :: );
{ {
// do all of GEMM kickoffs before the SIMT compute
//
if (tid_in_warpgroup == 0) {
// fence on completion of the GEMMs from the previous loop iterations. Note
// this is done at the start of the loop to maximize the window of
// overlap.
gemmini_fence();
if (tile_k >= 2) // delay GEMM II by 2 iters for pipelining if (tile_k >= 2) // delay GEMM II by 2 iters for pipelining
{ {
const uint32_t tile_k_ = tile_k - 2; const uint32_t tile_k_ = tile_k - 2;
// GEMM II: O = O + P*V // GEMM II: O = O + P*V
// -------------------- // --------------------
// This is done *before* GEMM I in the software pipeline, working on the // This is done *before* GEMM I in the software pipeline, working on
// online softmax result tile from the previous iteration // the online softmax result tile from the previous iteration
asm volatile("gemm_pv_start_%=:" ::); asm volatile("gemm_pv_start_%=:" ::);
if (tid_in_warpgroup == 0) { // kick off GEMM II
// kickoff matmul //
// FIXME: perf: prevent GMEM->SMEM load for O tile // FIXME: perf: prevent GMEM->SMEM load for O tile
gemmini_fence();
#ifdef GEMMINI_NEW_CISC #ifdef GEMMINI_NEW_CISC
gemmini_tile_compute</*store_to_spad=*/true>( gemmini_tile_compute</*store_to_spad=*/true>(
spad_hex_P_consume, spad_hex_V_consume, spad_hex_O, spad_hex_P_consume, spad_hex_V_consume, spad_hex_O,
@@ -432,11 +438,6 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
/*a_transpose=*/0, /*b_transpose=*/0, /*full_C=*/0, /*low_D=*/0, /*a_transpose=*/0, /*b_transpose=*/0, /*full_C=*/0, /*low_D=*/0,
/*acc=*/0, /*act=*/NO_ACTIVATION, /*skips=*/skips_matmul); /*acc=*/0, /*act=*/NO_ACTIVATION, /*skips=*/skips_matmul);
#endif #endif
}
// // reconverge from mmio divergence
// threadblock_barrier(warpgroup_id_in_cluster,
// warps_per_warpgroup_per_core);
asm volatile("gemm_pv_finish_%=:" ::); asm volatile("gemm_pv_finish_%=:" ::);
} }
@@ -446,7 +447,6 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
// kick off asynchronously; fence later // kick off asynchronously; fence later
asm volatile("gemm_qk_start_%=:" ::); asm volatile("gemm_qk_start_%=:" ::);
if (tid_in_warpgroup == 0) {
// FIXME: remove // FIXME: remove
// // fence to GEMM II completion // // fence to GEMM II completion
// gemmini_fence(); // gemmini_fence();
@@ -492,9 +492,12 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
// do DMA // do DMA
if (tile_k == 0) { if (tile_k == 0) {
// commented out as we do two move-ins before the loop starts
//
// // configure address strides for the DMA // // configure address strides for the DMA
// // FIXME: unnecessary? // // FIXME: unnecessary?
// GEMMINI_CISC_CMD_R((HEADDIM /*V*/ << 20) | (dim_seqlen /*KT*/ << 8) | // GEMMINI_CISC_CMD_R((HEADDIM /*V*/ << 20) | (dim_seqlen /*KT*/ <<
// 8) |
// 8 /*k_LOOP_WS_CONFIG_STRIDES_AB*/); // 8 /*k_LOOP_WS_CONFIG_STRIDES_AB*/);
// gemmini_fence(); // gemmini_fence();
// //
@@ -502,9 +505,11 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
// sp_tiled_matmul_full_spad_ws( // sp_tiled_matmul_full_spad_ws(
// spad_addr_K_produce, spad_addr_V_produce, // spad_addr_K_produce, spad_addr_V_produce,
// /*spad_D=*/0, /*spad_C=*/0, // /*spad_D=*/0, /*spad_C=*/0,
// /*I=*/(B_ROW / DIM), /*J=*/(HEADDIM / DIM), /*K=*/(B_COL / DIM), // /*I=*/(B_ROW / DIM), /*J=*/(HEADDIM / DIM), /*K=*/(B_COL /
// DIM),
// /*pad_I=*/0, /*pad_J=*/0, /*pad_K=*/0, // /*pad_I=*/0, /*pad_J=*/0, /*pad_K=*/0,
// /*a_transpose=*/0, /*b_transpose=*/0, /*full_C=*/0, /*low_D=*/0, // /*a_transpose=*/0, /*b_transpose=*/0, /*full_C=*/0,
// /*low_D=*/0,
// /*acc=*/0, /*act=*/NO_ACTIVATION, /*skips=*/skips_only_a); // /*acc=*/0, /*act=*/NO_ACTIVATION, /*skips=*/skips_only_a);
} else { } else {
#ifdef GEMMINI_NEW_CISC #ifdef GEMMINI_NEW_CISC
@@ -528,21 +533,13 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
/*acc=*/0, /*act=*/NO_ACTIVATION, /*skips=*/skips); /*acc=*/0, /*act=*/NO_ACTIVATION, /*skips=*/skips);
#endif #endif
} }
asm volatile("move_k_v_finish_%=:" ::);
} }
// reconverge from mmio divergence // reconverge from mmio divergence
threadblock_barrier(warpgroup_id_in_cluster, threadblock_barrier(warpgroup_id_in_cluster,
warps_per_warpgroup_per_core); warps_per_warpgroup_per_core);
asm volatile("move_k_v_finish_%=:" ::);
// FIXME: remove for nowarpspec
//
// NOTE: cannot put barrier here; thread 1-7 in warp 0 will skip the
// branch and call this barrier earlier than when thread 0 finishes.
// Since tmask is not considered, that will be a barrier resolve done too
// early
// threadblock_barrier(0, 1);
} }
{ {
@@ -630,8 +627,13 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
if constexpr (DEBUG) { if constexpr (DEBUG) {
if (warpgroup_id_in_cluster == 0) { if (warpgroup_id_in_cluster == 0) {
if (tid_in_warpgroup == 0) {
gemmini_fence(); gemmini_fence();
gemmini_fence(); gemmini_fence();
}
// reconverge from mmio divergence
threadblock_barrier(warpgroup_id_in_cluster,
warps_per_warpgroup_per_core);
// O after PV // O after PV
if (tile_k_ == 1 /*wait until GEMM II finishes */) { if (tile_k_ == 1 /*wait until GEMM II finishes */) {
@@ -687,8 +689,10 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
threadblock_barrier(warpgroup_id_in_cluster, threadblock_barrier(warpgroup_id_in_cluster,
warps_per_warpgroup_per_core); warps_per_warpgroup_per_core);
// fence everything before going to the next tile
gemmini_fence(); // instead of fencing here, we fence at the start of the loop to maximize
// overlapping
// gemmini_fence();
} }
} }