From 64e48de8af8d746ec1613e1a9066570d74fb55d2 Mon Sep 17 00:00:00 2001 From: Hansung Kim Date: Mon, 19 Aug 2024 18:03:06 -0700 Subject: [PATCH] flash: Do accumulation of PV into O using the single_tile API --- tests/regression/flash_attention/kernel.cpp | 16 +++++++++++----- 1 file changed, 11 insertions(+), 5 deletions(-) diff --git a/tests/regression/flash_attention/kernel.cpp b/tests/regression/flash_attention/kernel.cpp index 5ee64f75..5737b00a 100644 --- a/tests/regression/flash_attention/kernel.cpp +++ b/tests/regression/flash_attention/kernel.cpp @@ -371,10 +371,13 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { threadblock_barrier(threadblock_id_in_cluster, warps_per_threadblock_per_core); + // GEMM I: S = Q*K thread_block_gemm_single_tile( - smem_Q, smem_K, smem_S, tid_in_threadblock, threads_per_threadblock, - threadblocks_per_cluster, threadblock_id_in_cluster); + smem_Q, smem_K, nullptr /*ignore accum*/, smem_S, tid_in_threadblock, + threads_per_threadblock, threadblocks_per_cluster, + threadblock_id_in_cluster); #endif // protect GEMM result writes (smem_S) before softmax @@ -395,21 +398,24 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { threadblock_barrier(threadblock_id_in_cluster, warps_per_threadblock_per_core); + // GEMM II: O = O + P*V + // clear out accumulators initialize_accum_regs<0>(); initialize_accum_regs<1>(); load_tile_to_smem( - B_COL, 0 /*FIXME*/, 0 /*FIXME*/, gmem_V, smem_V, tid_in_threadblock); + B_COL, 0, 0, gmem_V, smem_V, tid_in_threadblock); threadblock_barrier(threadblock_id_in_cluster, warps_per_threadblock_per_core); // FIXME: support MN_major for A for ideal performance thread_block_gemm_single_tile( - smem_P, smem_V, gmem_O /*smem_O*/, tid_in_threadblock, - threads_per_threadblock, threadblocks_per_cluster, + smem_P, smem_V, smem_O, gmem_O /*smem_O*/, + tid_in_threadblock, threads_per_threadblock, threadblocks_per_cluster, threadblock_id_in_cluster); threadblock_barrier(threadblock_id_in_cluster,