From a17edac8759a131c5ababf9dd492be97b1fe14bf Mon Sep 17 00:00:00 2001 From: Hansung Kim Date: Mon, 9 Sep 2024 17:02:05 -0700 Subject: [PATCH] flash: Fix barrier stall with DEBUG Verified for up to P_expected on 2nd iter; O_before_PV is partially correct --- .../regression/flash_attention/flash_impl.hpp | 10 +-- .../flash_attention/kernel.gemmini.cpp | 63 ++++++++++--------- 2 files changed, 40 insertions(+), 33 deletions(-) diff --git a/tests/regression/flash_attention/flash_impl.hpp b/tests/regression/flash_attention/flash_impl.hpp index 8aac50ab..bd4aee9d 100644 --- a/tests/regression/flash_attention/flash_impl.hpp +++ b/tests/regression/flash_attention/flash_impl.hpp @@ -89,8 +89,9 @@ inline void thread_block_copy_rowmax(const float *src, float *dest, dest[offset] = src[offset]; } - threadblock_barrier(threadblock_id_in_cluster, - warps_per_threadblock_per_core); + // threadblock_barrier(threadblock_id_in_cluster, + // warps_per_threadblock_per_core); + threadblock_barrier(1, 7); asm volatile("threadblock_copy_rowmax_finish_%=:" ::); } @@ -127,8 +128,9 @@ inline void thread_block_copy_tile(const float *src, float *dest, dest[gmem_offset] = src[smem_offset]; } - threadblock_barrier(threadblock_id_in_cluster, - warps_per_threadblock_per_core); + // threadblock_barrier(threadblock_id_in_cluster, + // warps_per_threadblock_per_core); + threadblock_barrier(1, 7); } asm volatile("threadblock_copy_tile_finish_%=:" ::); diff --git a/tests/regression/flash_attention/kernel.gemmini.cpp b/tests/regression/flash_attention/kernel.gemmini.cpp index 9d611ba2..f85755e1 100644 --- a/tests/regression/flash_attention/kernel.gemmini.cpp +++ b/tests/regression/flash_attention/kernel.gemmini.cpp @@ -8,7 +8,7 @@ #include "gemmini_mmio.h" #include "flash_impl.hpp" -constexpr bool DEBUG = false; +constexpr bool DEBUG = true; static_assert(GEMMINI_DMA && !WARP_SPECIALIZED, "GEMMINI_DMA should be set and WARP_SPECIALIZED unset"); @@ -528,8 +528,10 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { asm volatile("move_k_v_finish_%=:" ::); - // // intra-warpgroup barrier - // // FIXME hardcoded + // NOTE: cannot put barrier here; thread 1-7 in warp 0 will skip the + // branch and call this barrier earlier than when thread 0 finishes. + // Since tmask is not considered, that will be a barrier resolve done too + // early // threadblock_barrier(0, 1); } else /* warp_id != 0 */ { @@ -538,6 +540,24 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { { const uint32_t tile_k_ = tile_k - 1; + if constexpr (DEBUG) { + // verify S = Q*K + + if (warpgroup_id == 0) { + if (tile_k == 0) { + thread_block_copy_tile( + smem_S_produce, gmem_tmp_d0, tid_in_warpgroup_simt, + threads_per_warpgroup_simt, warpgroup_id_simt); + } else if (tile_k == 1) { + thread_block_copy_tile( + smem_S_produce, gmem_tmp_d1, tid_in_warpgroup_simt, + threads_per_warpgroup_simt, warpgroup_id_simt); + } + + threadblock_barrier(barrier_id_simt, barrier_count_simt); + } + } + // Online softmax // thread_block_online_softmax( @@ -550,25 +570,26 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { if constexpr (DEBUG) { if (warpgroup_id == 0) { if (tile_k_ == 0) { - thread_block_copy_rowmax(smem_rowmax, gmem_tmp_e0, - tid_in_warpgroup_simt, threads_per_warpgroup, - warpgroup_id_in_cluster); - thread_block_copy_rowmax(smem_rowsum, gmem_tmp_e2, - tid_in_warpgroup_simt, threads_per_warpgroup, - warpgroup_id_in_cluster); + thread_block_copy_rowmax( + smem_rowmax, gmem_tmp_e0, tid_in_warpgroup_simt, + threads_per_warpgroup_simt, warpgroup_id_simt); + thread_block_copy_rowmax( + smem_rowsum, gmem_tmp_e2, tid_in_warpgroup_simt, + threads_per_warpgroup_simt, warpgroup_id_simt); } else if (tile_k_ == 1) { thread_block_copy_rowmax(smem_rowmax, gmem_tmp_e1, - tid_in_warpgroup_simt, threads_per_warpgroup, - warpgroup_id_in_cluster); + tid_in_warpgroup_simt, threads_per_warpgroup_simt, + warpgroup_id_simt); thread_block_copy_rowmax(smem_rowsum, gmem_tmp_e3, - tid_in_warpgroup_simt, threads_per_warpgroup, - warpgroup_id_in_cluster); + tid_in_warpgroup_simt, threads_per_warpgroup_simt, + warpgroup_id_simt); } threadblock_barrier(barrier_id_simt, barrier_count_simt); } } + // FIXME: put synchronization with GEMM II here #if 0 // fence GEMM II to make sure dependency on O tile is settled if (tid_in_warpgroup == 0) { @@ -662,22 +683,6 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { warps_per_warpgroup_per_core); #endif - if constexpr (DEBUG) { - if (warpgroup_id == 0) { - if (tile_k == 0) { - thread_block_copy_tile( - smem_S_produce, gmem_tmp_d0, tid_in_warpgroup_simt, - threads_per_warpgroup_simt, warpgroup_id_simt); - } else if (tile_k == 1) { - thread_block_copy_tile( - smem_S_produce, gmem_tmp_d1, tid_in_warpgroup_simt, - threads_per_warpgroup_simt, warpgroup_id_simt); - } - - threadblock_barrier(barrier_id_simt, barrier_count_simt); - } - } - // intra-warpgroup barrier threadblock_barrier(barrier_id_simt, barrier_count_simt); }