From 829af5d429ed60f1fb803427fe928a35c816df97 Mon Sep 17 00:00:00 2001 From: Hansung Kim Date: Mon, 9 Sep 2024 15:21:49 -0700 Subject: [PATCH] flash: Comment out mvout to smem Verified up to O_before_PV; still stalls without DEBUG --- tests/regression/flash_attention/kernel.gemmini.cpp | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/tests/regression/flash_attention/kernel.gemmini.cpp b/tests/regression/flash_attention/kernel.gemmini.cpp index 82a83aa4..a8188dc4 100644 --- a/tests/regression/flash_attention/kernel.gemmini.cpp +++ b/tests/regression/flash_attention/kernel.gemmini.cpp @@ -82,6 +82,7 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { float *gmem_tmp_e3 = reinterpret_cast<float *>(0xe3000000UL); // static shared memory allocation + // these are in float elements, not bytes constexpr uint32_t smem_Q_size = B_ROW * HEADDIM; constexpr uint32_t smem_K_size = B_COL * HEADDIM; constexpr uint32_t smem_QK_size = B_ROW * B_COL; @@ -384,6 +385,8 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { // FIXME: perf: prevent GMEM->SMEM load for O tile gemmini_fence(); gemmini_fence(); + gemmini_fence(); + gemmini_fence(); sp_tiled_matmul_full_spad_ws( spad_addr_P_consume, spad_addr_V_consume, /*spad_D=*/spad_addr_O, /*spad_C=*/spad_addr_O, @@ -446,7 +449,7 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { gemmini_fence(); gemmini_fence(); -#if 1 +#if 0 // mvout to SMEM // GEMMINI_CISC_CMD_I(9); sp_tiled_matmul_full_spad_ws( @@ -493,10 +496,11 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { gemmini_fence(); gemmini_fence(); gemmini_fence(); + gemmini_fence(); // 0,2,.: opcode 0 (quartile 0/2, no accum) // 1,3,.: opcode 3 (quartile 1/3, no accum) - const uint32_t opcode = 3 * (tile_k & 1); + // const uint32_t opcode = 3 * (tile_k & 1); //GEMMINI_CISC_CMD_I(opcode); sp_tiled_matmul_full_spad_ws( spad_addr_Q, spad_addr_K_consume, @@ -571,7 +575,7 @@ void kernel_body(int task_id, 
kernel_arg_t *__UNIFORM__ arg) { gemmini_fence(); gemmini_fence(); -#if 1 +#if 0 // mvout to SMEM // GEMMINI_CISC_CMD_I(9); sp_tiled_matmul_full_spad_ws( @@ -656,6 +660,7 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { gemmini_fence(); gemmini_fence(); gemmini_fence(); + gemmini_fence(); } threadblock_barrier(warpgroup_id_in_cluster, warps_per_warpgroup_per_core);