From 90f6effa971fbd3b17f371d9839024ad45bde590 Mon Sep 17 00:00:00 2001 From: Hansung Kim Date: Sun, 18 Aug 2024 15:21:05 -0700 Subject: [PATCH] flash: Pass smem_P arg to softmax func --- tests/regression/flash_attention/kernel.cpp | 36 ++++++++++++--------- 1 file changed, 20 insertions(+), 16 deletions(-) diff --git a/tests/regression/flash_attention/kernel.cpp b/tests/regression/flash_attention/kernel.cpp index 6b3bb4fb..bd238bb6 100644 --- a/tests/regression/flash_attention/kernel.cpp +++ b/tests/regression/flash_attention/kernel.cpp @@ -53,13 +53,11 @@ inline void thread_block_init_sharedmem(const uint32_t tid_in_threadblock, } } -inline void thread_block_flashattn(float *smem_S, float *smem_O, - const uint32_t tid_in_threadblock, - const uint32_t threads_per_threadblock, - const uint32_t threadblock_id_in_cluster, - float *smem_scratchpad, - float *smem_rowmax, - float *smem_rowsum) { +inline void thread_block_online_softmax( + float *smem_S, float *smem_O, float *smem_P, + const uint32_t tid_in_threadblock, const uint32_t threads_per_threadblock, + const uint32_t threadblock_id_in_cluster, float *smem_scratchpad, + float *smem_rowmax, float *smem_rowsum) { asm volatile("thread_block_flashattn_start_%=:" ::); const uint32_t tid_in_warp = tid_in_threadblock % NUM_THREADS; @@ -191,7 +189,6 @@ inline void thread_block_flashattn(float *smem_S, float *smem_O, // Store S transposed to the shared memory - // update S in-place into P smem_S[thread_offset] = f0; // S[thread_offset + 1] = f1; gmem_tmp1[thread_offset] = f0; @@ -232,7 +229,6 @@ inline void thread_block_flashattn(float *smem_S, float *smem_O, const float mi_prev = smem_rowmax_prev[row]; const float mi_this = smem_rowmax_this[row]; - const float mi_new = smem_rowmax_new[row]; const float exp = mi_prev - mi_this; // update rowsum @@ -261,7 +257,6 @@ inline void thread_block_flashattn(float *smem_S, float *smem_O, // update Oi in-place smem_O[thread_offset] = fval; - gmem_tmp2[thread_offset] = fval; thread_offset += NUM_THREADS; } @@ -316,8 +311,10 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { threadblock_id_in_cluster); uint8_t *smem_S = smem_per_threadblock; - uint8_t *smem_O = smem_per_threadblock + - sizeof(float) * (smem_QK_size + smem_V_size); + uint8_t *smem_P = smem_S; // in-place update from S to P + uint8_t *smem_V = smem_per_threadblock + sizeof(float) * smem_QK_size; + uint8_t *smem_O = + smem_per_threadblock + sizeof(float) * (smem_QK_size + smem_V_size); // allocate rowmax/rowsum storage at the end of the sharedmem address space constexpr uint32_t smem_rowmax_size = sizeof(float) * B_ROW * 3 /* mi, mi~, minew */; @@ -361,10 +358,17 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { float *tile_S = (float *)arg->addr_q; #endif - thread_block_flashattn(tile_S, (float *)smem_O, tid_in_threadblock, - threads_per_threadblock, threadblock_id_in_cluster, - (float *)smem_scratchpad, (float *)smem_rowmax, - (float *)smem_rowsum); + thread_block_online_softmax( + tile_S, (float *)smem_O, (float *)smem_P, tid_in_threadblock, + threads_per_threadblock, threadblock_id_in_cluster, + (float *)smem_scratchpad, (float *)smem_rowmax, (float *)smem_rowsum); + + // FIXME unnecessary? + threadblock_barrier(threadblock_id_in_cluster, + warps_per_threadblock_per_core); + + thread_block_gemm_single_tile(smem_P, smem_V, tid_in_threadblock, + threads_per_threadblock); } int main() {