flash: Do exponential approx to rowsum and Oi as well

This commit is contained in:
Hansung Kim
2024-08-19 20:52:57 -07:00
parent f6cc61241b
commit 4080dec9d6

View File

@@ -204,8 +204,9 @@ inline void thread_block_online_softmax(
// check Q*K result
gmem_tmp0[thread_offset] = f0;
// FIXME: placeholder for proper exp
f0 -= rowmax_new;
// 2nd-order Taylor approximation
float exp = 1.0f;
exp += exponential_taylor_term<1>(f0);
exp += exponential_taylor_term<2>(f0);
@@ -228,6 +229,8 @@ inline void thread_block_online_softmax(
//
// two-level tree reduction, similar to rowmax
asm volatile("flashattn_rowsum_start_%=:" ::);
thread_offset = first_thread_offset + tid_in_warp;
float per_thread_sum = 0.0f;
#pragma GCC unroll
@@ -254,38 +257,54 @@ inline void thread_block_online_softmax(
const float mi_prev = smem_rowmax_prev[row];
const float mi_this = smem_rowmax_this[row];
const float exp = mi_prev - mi_this;
const float x = mi_prev - mi_this;
// 2nd-order Taylor approximation
float exp = 1.0f;
exp += exponential_taylor_term<1>(x);
exp += exponential_taylor_term<2>(x);
// update rowsum
const float rowsum_prev = smem_rowsum[row];
// FIXME: placeholder for exponential
float rowsum_new = exp * rowsum_prev + rowsum;
smem_rowsum[row] = rowsum_new;
}
asm volatile("flashattn_rowsum_end_%=:" ::);
threadblock_barrier(threadblock_id_in_cluster,
warps_per_threadblock_per_core);
// Oi rescale
//
asm volatile("flashattn_o_rescale_start_%=:" ::);
thread_offset = first_thread_offset + tid_in_warp;
#pragma GCC unroll
for (int i = 0; i < per_row_iter; i++) {
float fval = smem_O[thread_offset];
float o = smem_O[thread_offset];
const float mi_prev = smem_rowmax_prev[row];
const float mi_new = smem_rowmax_new[row];
const float exp = mi_prev - mi_new;
// FIXME: placeholder for proper exp
fval *= exp;
const float x = mi_prev - mi_new;
// 2nd-order Taylor approximation
float exp = 1.0f;
exp += exponential_taylor_term<1>(x);
exp += exponential_taylor_term<2>(x);
// @perf: div vs. expansion on e(-x)?
o /= exp;
// update Oi in-place
smem_O[thread_offset] = fval;
smem_O[thread_offset] = o;
thread_offset += NUM_THREADS;
}
asm volatile("flashattn_o_rescale_end_%=:" ::);
threadblock_barrier(threadblock_id_in_cluster,
warps_per_threadblock_per_core);
}