flash: Apply the exponential approximation to the rowsum and Oi rescaling as well
This commit is contained in:
@@ -204,8 +204,9 @@ inline void thread_block_online_softmax(
|
||||
// check Q*K result
|
||||
gmem_tmp0[thread_offset] = f0;
|
||||
|
||||
// FIXME: placeholder for proper exp
|
||||
f0 -= rowmax_new;
|
||||
|
||||
// 2nd-order Taylor approximation
|
||||
float exp = 1.0f;
|
||||
exp += exponential_taylor_term<1>(f0);
|
||||
exp += exponential_taylor_term<2>(f0);
|
||||
@@ -228,6 +229,8 @@ inline void thread_block_online_softmax(
|
||||
//
|
||||
// two-level tree reduction, similar to rowmax
|
||||
|
||||
asm volatile("flashattn_rowsum_start_%=:" ::);
|
||||
|
||||
thread_offset = first_thread_offset + tid_in_warp;
|
||||
float per_thread_sum = 0.0f;
|
||||
#pragma GCC unroll
|
||||
@@ -254,38 +257,54 @@ inline void thread_block_online_softmax(
|
||||
|
||||
const float mi_prev = smem_rowmax_prev[row];
|
||||
const float mi_this = smem_rowmax_this[row];
|
||||
const float exp = mi_prev - mi_this;
|
||||
|
||||
const float x = mi_prev - mi_this;
|
||||
// 2nd-order Taylor approximation
|
||||
float exp = 1.0f;
|
||||
exp += exponential_taylor_term<1>(x);
|
||||
exp += exponential_taylor_term<2>(x);
|
||||
|
||||
// update rowsum
|
||||
const float rowsum_prev = smem_rowsum[row];
|
||||
// FIXME: placeholder for exponential
|
||||
float rowsum_new = exp * rowsum_prev + rowsum;
|
||||
|
||||
smem_rowsum[row] = rowsum_new;
|
||||
}
|
||||
|
||||
asm volatile("flashattn_rowsum_end_%=:" ::);
|
||||
|
||||
threadblock_barrier(threadblock_id_in_cluster,
|
||||
warps_per_threadblock_per_core);
|
||||
|
||||
// Oi rescale
|
||||
//
|
||||
asm volatile("flashattn_o_rescale_start_%=:" ::);
|
||||
|
||||
thread_offset = first_thread_offset + tid_in_warp;
|
||||
#pragma GCC unroll
|
||||
for (int i = 0; i < per_row_iter; i++) {
|
||||
float fval = smem_O[thread_offset];
|
||||
float o = smem_O[thread_offset];
|
||||
|
||||
const float mi_prev = smem_rowmax_prev[row];
|
||||
const float mi_new = smem_rowmax_new[row];
|
||||
const float exp = mi_prev - mi_new;
|
||||
|
||||
// FIXME: placeholder for proper exp
|
||||
fval *= exp;
|
||||
const float x = mi_prev - mi_new;
|
||||
// 2nd-order Taylor approximation
|
||||
float exp = 1.0f;
|
||||
exp += exponential_taylor_term<1>(x);
|
||||
exp += exponential_taylor_term<2>(x);
|
||||
|
||||
// @perf: div vs. expansion on e(-x)?
|
||||
o /= exp;
|
||||
|
||||
// update Oi in-place
|
||||
smem_O[thread_offset] = fval;
|
||||
smem_O[thread_offset] = o;
|
||||
|
||||
thread_offset += NUM_THREADS;
|
||||
}
|
||||
|
||||
asm volatile("flashattn_o_rescale_end_%=:" ::);
|
||||
|
||||
threadblock_barrier(threadblock_id_in_cluster,
|
||||
warps_per_threadblock_per_core);
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user