From 68054689c96d0405e0c62bd296cadadd97f0a886 Mon Sep 17 00:00:00 2001 From: Hansung Kim Date: Sat, 9 Nov 2024 20:59:26 -0800 Subject: [PATCH] flash: Move fence to start of loop; wrap all MMIO in one tid=0 branch --- .../flash_attention/kernel.gemmini.cpp | 100 +++++++++--------- 1 file changed, 52 insertions(+), 48 deletions(-) diff --git a/tests/regression/flash_attention/kernel.gemmini.cpp b/tests/regression/flash_attention/kernel.gemmini.cpp index 882d8f96..e3335861 100644 --- a/tests/regression/flash_attention/kernel.gemmini.cpp +++ b/tests/regression/flash_attention/kernel.gemmini.cpp @@ -402,22 +402,28 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { asm volatile ("dbuf_sel_end_%=:" :: ); { - if (tile_k >= 2) // delay GEMM II by 2 iters for pipelining - { - const uint32_t tile_k_ = tile_k - 2; + // do all of GEMM kickoffs before the SIMT compute + // + if (tid_in_warpgroup == 0) { + // fence completion of the GEMMs in the previous loop iterations. Note + // this is done at the start of the loop to maximize window of + // overlapping. + gemmini_fence(); - // GEMM II: O = O + P*V - // -------------------- - // This is done *before* GEMM I in the software pipeline, working on the - // online softmax result tile from the previous iteration + if (tile_k >= 2) // delay GEMM II by 2 iters for pipelining + { + const uint32_t tile_k_ = tile_k - 2; - asm volatile("gemm_pv_start_%=:" ::); + // GEMM II: O = O + P*V + // -------------------- + // This is done *before* GEMM I in the software pipeline, working on + // the online softmax result tile from the previous iteration - if (tid_in_warpgroup == 0) { - // kickoff matmul + asm volatile("gemm_pv_start_%=:" ::); + // kick off GEMM II + // // FIXME: perf: prevent GMEM->SMEM load for O tile - gemmini_fence(); #ifdef GEMMINI_NEW_CISC gemmini_tile_compute( spad_hex_P_consume, spad_hex_V_consume, spad_hex_O, @@ -432,32 +438,26 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { /*a_transpose=*/0, /*b_transpose=*/0, /*full_C=*/0, /*low_D=*/0, /*acc=*/0, /*act=*/NO_ACTIVATION, /*skips=*/skips_matmul); #endif + + asm volatile("gemm_pv_finish_%=:" ::); } - // // reconverge from mmio divergence - // threadblock_barrier(warpgroup_id_in_cluster, - // warps_per_warpgroup_per_core); + // GEMM I: S = Q*K + // + // kick off asynchronously; fence later + asm volatile("gemm_qk_start_%=:" ::); - asm volatile("gemm_pv_finish_%=:" ::); - } - - // GEMM I: S = Q*K - // - // kick off asynchronously; fence later - asm volatile("gemm_qk_start_%=:" ::); - - if (tid_in_warpgroup == 0) { // FIXME: remove // // fence to GEMM II completion // gemmini_fence(); -// #ifdef FENCE_GEMM_II -// asm volatile("rescale_fence_write_start_%=:" ::); -// // signal that GEMM II is finished to O rescale step -// *smem_O_flag = 1; -// vx_fence(); -// asm volatile("rescale_fence_write_end_%=:" ::); -// #endif + // #ifdef FENCE_GEMM_II + // asm volatile("rescale_fence_write_start_%=:" ::); + // // signal that GEMM II is finished to O rescale step + // *smem_O_flag = 1; + // vx_fence(); + // asm volatile("rescale_fence_write_end_%=:" ::); + // #endif // Kick off GEMM I // @@ -492,9 +492,12 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { // do DMA if (tile_k == 0) { + // commented out as we do two move-ins before the loop starts + // // // configure address strides for the DMA // // FIXME: unnecessary? - // GEMMINI_CISC_CMD_R((HEADDIM /*V*/ << 20) | (dim_seqlen /*KT*/ << 8) | + // GEMMINI_CISC_CMD_R((HEADDIM /*V*/ << 20) | (dim_seqlen /*KT*/ << + // 8) | // 8 /*k_LOOP_WS_CONFIG_STRIDES_AB*/); // gemmini_fence(); // @@ -502,9 +505,11 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { // sp_tiled_matmul_full_spad_ws( // spad_addr_K_produce, spad_addr_V_produce, // /*spad_D=*/0, /*spad_C=*/0, - // /*I=*/(B_ROW / DIM), /*J=*/(HEADDIM / DIM), /*K=*/(B_COL / DIM), + // /*I=*/(B_ROW / DIM), /*J=*/(HEADDIM / DIM), /*K=*/(B_COL / + // DIM), // /*pad_I=*/0, /*pad_J=*/0, /*pad_K=*/0, - // /*a_transpose=*/0, /*b_transpose=*/0, /*full_C=*/0, /*low_D=*/0, + // /*a_transpose=*/0, /*b_transpose=*/0, /*full_C=*/0, + // /*low_D=*/0, // /*acc=*/0, /*act=*/NO_ACTIVATION, /*skips=*/skips_only_a); } else { #ifdef GEMMINI_NEW_CISC @@ -528,22 +533,14 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { /*acc=*/0, /*act=*/NO_ACTIVATION, /*skips=*/skips); #endif } + + asm volatile("move_k_v_finish_%=:" ::); } // reconverge from mmio divergence threadblock_barrier(warpgroup_id_in_cluster, warps_per_warpgroup_per_core); - - asm volatile("move_k_v_finish_%=:" ::); - - // FIXME: remove for nowarpspec - // - // NOTE: cannot put barrier here; thread 1-7 in warp 0 will skip the - // branch and call this barrier earlier than when thread 0 finishes. - // Since tmask is not considered, that will be a barrier resolve done too - // early - // threadblock_barrier(0, 1); - } + } { if (tile_k >= 1) // delay online softmax by 1 iters @@ -630,8 +627,13 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { if constexpr (DEBUG) { if (warpgroup_id_in_cluster == 0) { - gemmini_fence(); - gemmini_fence(); + if (tid_in_warpgroup == 0) { + gemmini_fence(); + gemmini_fence(); + } + // reconverge from mmio divergence + threadblock_barrier(warpgroup_id_in_cluster, + warps_per_warpgroup_per_core); // O after PV if (tile_k_ == 1 /*wait until GEMM II finshes */) { @@ -687,8 +689,10 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { threadblock_barrier(warpgroup_id_in_cluster, warps_per_warpgroup_per_core); - // fence everything before going to the next tile - gemmini_fence(); + + // instead of fencing here, we fence at the start of the loop to maximize + // overlapping + // gemmini_fence(); } }