diff --git a/tests/regression/flash_attention/kernel.cpp b/tests/regression/flash_attention/kernel.cpp index 72d25a61..522dcb0b 100644 --- a/tests/regression/flash_attention/kernel.cpp +++ b/tests/regression/flash_attention/kernel.cpp @@ -82,10 +82,8 @@ inline void thread_block_flashattn(float *S, thread_offset += NUM_THREADS; } // stage per-thread max value in smem - // FIXME: we could warp_id instead of row here, but we need another barrier - // at the end of the loop iteration to prevent write-after-read hazard // FIXME: threadblock_id needs to be in here too - float *warp_smem = sharedmem_scratchpad + (row * NUM_THREADS); + float *warp_smem = sharedmem_scratchpad + (warp_id * NUM_THREADS); warp_smem[tid_in_warp] = per_thread_max; // sync writes to warp_smem @@ -115,22 +113,20 @@ inline void thread_block_flashattn(float *S, // const uint32_t row_stride = // (exp_elem_per_thread * threads_per_threadblock) / B_COL; - thread_offset = first_thread_offset + tid_in_warp; - // broadcast rowmax to all threads in the warp const float row_max = sharedmem_row_max_sum[row]; + thread_offset = first_thread_offset + tid_in_warp; #pragma GCC unroll for (int i = 0; i < per_row_iter; i++) { float val = S[thread_offset]; // FIXME: placeholder for proper exp - val = val; + val -= row_max; // update S in-place to P - // S[thread_offset] = val; + S[thread_offset] = val; gmem_tmp1[thread_offset] = val; - gmem_tmp2[thread_offset] = val - row_max; thread_offset += NUM_THREADS; } @@ -142,7 +138,7 @@ inline void thread_block_flashattn(float *S, // // two-level tree reduction, similar to rowmax -#if 0 + thread_offset = first_thread_offset + tid_in_warp; float per_thread_sum = 0.0f; #pragma GCC unroll for (int i = 0; i < per_row_iter; i++) { @@ -151,7 +147,7 @@ inline void thread_block_flashattn(float *S, } // stage per-thread sum value in smem // FIXME: threadblock_id needs to be in here too - warp_smem = sharedmem_scratchpad + (row * NUM_THREADS); + warp_smem = sharedmem_scratchpad + (warp_id * NUM_THREADS); warp_smem[tid_in_warp] = per_thread_sum; // sync writes to warp_smem @@ -167,7 +163,6 @@ inline void thread_block_flashattn(float *S, sharedmem_row_max_sum[row] = per_thread_sum; gmem_tmp2[row] = per_thread_sum; } -#endif threadblock_barrier(threadblock_id_in_cluster, warps_per_threadblock_per_core);