From 4e087a8aab3c3ba8fd51f359c9a24f462352f097 Mon Sep 17 00:00:00 2001 From: Hansung Kim Date: Fri, 8 Nov 2024 16:43:08 -0800 Subject: [PATCH] flash: Fix loop iteration for gemmini Kernel is software-pipelined around 2 GEMMs and softmax; it requires two iterations to fully complete a tile. --- tests/regression/flash_attention/kernel.gemmini.cpp | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/tests/regression/flash_attention/kernel.gemmini.cpp b/tests/regression/flash_attention/kernel.gemmini.cpp index ac3788d4..089f0a3f 100644 --- a/tests/regression/flash_attention/kernel.gemmini.cpp +++ b/tests/regression/flash_attention/kernel.gemmini.cpp @@ -336,7 +336,11 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { // "inner loop" along the columns of K^T const uint32_t k_tiles = (dim_seqlen / B_COL); for (uint32_t tile_k = 0; - tile_k < (4 /*for perf measurement*/ * k_tiles) + 2 /*pipeline latency*/; + tile_k < (4 /*for perf measurement*/ * + // virgo kernel is fully pipelined around (2 GEMMs | softmax); + // requires two loop iterations to finish one tile compute + (2 * k_tiles)) + + 2 /*pipeline latency*/; tile_k++) { if constexpr (DEBUG || true) { threadblock_barrier(global_barrier_id, warps_per_threadblock_per_core);