diff --git a/tests/regression/flash_attention/kernel.cpp b/tests/regression/flash_attention/kernel.cpp index b1141875..a7d21591 100644 --- a/tests/regression/flash_attention/kernel.cpp +++ b/tests/regression/flash_attention/kernel.cpp @@ -204,8 +204,9 @@ inline void thread_block_online_softmax( // check Q*K result gmem_tmp0[thread_offset] = f0; - // FIXME: placeholder for proper exp f0 -= rowmax_new; + + // 2nd-order Taylor approximation float exp = 1.0f; exp += exponential_taylor_term<1>(f0); exp += exponential_taylor_term<2>(f0); @@ -228,6 +229,8 @@ inline void thread_block_online_softmax( // // two-level tree reduction, similar to rowmax + asm volatile("flashattn_rowsum_start_%=:" ::); + thread_offset = first_thread_offset + tid_in_warp; float per_thread_sum = 0.0f; #pragma GCC unroll @@ -254,38 +257,54 @@ inline void thread_block_online_softmax( const float mi_prev = smem_rowmax_prev[row]; const float mi_this = smem_rowmax_this[row]; - const float exp = mi_prev - mi_this; + + const float x = mi_prev - mi_this; + // 2nd-order Taylor approximation + float exp = 1.0f; + exp += exponential_taylor_term<1>(x); + exp += exponential_taylor_term<2>(x); // update rowsum const float rowsum_prev = smem_rowsum[row]; - // FIXME: placeholder for exponential float rowsum_new = exp * rowsum_prev + rowsum; + smem_rowsum[row] = rowsum_new; } + asm volatile("flashattn_rowsum_end_%=:" ::); + threadblock_barrier(threadblock_id_in_cluster, warps_per_threadblock_per_core); // Oi rescale // + asm volatile("flashattn_o_rescale_start_%=:" ::); + thread_offset = first_thread_offset + tid_in_warp; #pragma GCC unroll for (int i = 0; i < per_row_iter; i++) { - float fval = smem_O[thread_offset]; + float o = smem_O[thread_offset]; const float mi_prev = smem_rowmax_prev[row]; const float mi_new = smem_rowmax_new[row]; - const float exp = mi_prev - mi_new; - // FIXME: placeholder for proper exp - fval *= exp; + const float x = mi_prev - mi_new; + // 2nd-order Taylor approximation + float exp = 1.0f; + exp += exponential_taylor_term<1>(x); + exp += exponential_taylor_term<2>(x); + + // @perf: div vs. expansion on e(-x)? + o /= exp; // update Oi in-place - smem_O[thread_offset] = fval; + smem_O[thread_offset] = o; thread_offset += NUM_THREADS; } + asm volatile("flashattn_o_rescale_end_%=:" ::); + threadblock_barrier(threadblock_id_in_cluster, warps_per_threadblock_per_core); }