flash: Compute exponents using prev/next/this rowmax values
maybe there is a better way than storing all three in sharedmem?
This commit is contained in:
@@ -31,14 +31,17 @@ inline void thread_block_init_sharedmem(const uint32_t tid_in_threadblock,
|
||||
constexpr uint32_t num_warps = B_ROW / NUM_THREADS;
|
||||
if (warp_id < num_warps) {
|
||||
uint32_t offset = NUM_THREADS * warp_id + tid_in_warp;
|
||||
// mi, mi~, minew
|
||||
smem_rowmax[offset] = FLT_MIN;
|
||||
smem_rowmax[offset + B_ROW] = FLT_MIN;
|
||||
smem_rowmax[offset + 2 * B_ROW] = FLT_MIN;
|
||||
smem_rowsum[offset] = 0.0f;
|
||||
}
|
||||
|
||||
// FIXME: dedup this pattern
|
||||
for (int warp_offset = 0; warp_offset < B_COL;
|
||||
warp_offset += warps_in_threadblock) {
|
||||
// each warp clears out a row of smem_O
|
||||
// FIXME: dedup this pattern
|
||||
const uint32_t row = warp_offset + warp_id;
|
||||
uint32_t thread_offset = HEADDIM * row + tid_in_warp;
|
||||
constexpr uint32_t per_row_iter = HEADDIM / NUM_THREADS;
|
||||
@@ -79,6 +82,10 @@ inline void thread_block_flashattn(float *smem_S, float *smem_O,
|
||||
volatile float *gmem_tmp1 = reinterpret_cast<volatile float *>(0xe0000000UL);
|
||||
volatile float *gmem_tmp2 = reinterpret_cast<volatile float *>(0xf0000000UL);
|
||||
|
||||
float *smem_rowmax_prev = smem_rowmax;
|
||||
float *smem_rowmax_new = smem_rowmax + B_ROW;
|
||||
float *smem_rowmax_this = smem_rowmax + 2 * B_ROW;
|
||||
|
||||
for (int warp_offset = 0; warp_offset < B_ROW;
|
||||
warp_offset += warps_in_threadblock) {
|
||||
const uint32_t row = warp_offset + warp_id;
|
||||
@@ -136,15 +143,15 @@ inline void thread_block_flashattn(float *smem_S, float *smem_O,
|
||||
: "=f"(rowmax)
|
||||
: "f"(rowmax), "f"(other));
|
||||
}
|
||||
smem_rowmax_this[row] = rowmax;
|
||||
|
||||
// update previous rowmax
|
||||
// i.e. mi_new = max(mi, mij)
|
||||
float prev_rowmax = smem_rowmax[row];
|
||||
float prev_rowmax = smem_rowmax_prev[row];
|
||||
asm volatile("fmax.s %0, %1, %2"
|
||||
: "=f"(rowmax)
|
||||
: "f"(rowmax), "f"(prev_rowmax));
|
||||
|
||||
smem_rowmax[row] = rowmax;
|
||||
smem_rowmax_new[row] = rowmax;
|
||||
gmem_tmp0[row] = rowmax;
|
||||
}
|
||||
#endif
|
||||
@@ -160,7 +167,7 @@ inline void thread_block_flashattn(float *smem_S, float *smem_O,
|
||||
// (exp_elem_per_thread * threads_per_threadblock) / B_COL;
|
||||
|
||||
// broadcast rowmax to all threads in the warp
|
||||
const float row_max = smem_rowmax[row];
|
||||
const float rowmax_new = smem_rowmax_new[row];
|
||||
|
||||
// each thread computes two fp32 elements, downconverts it to fp16, then
|
||||
// packs them into one fp32
|
||||
@@ -177,8 +184,8 @@ inline void thread_block_flashattn(float *smem_S, float *smem_O,
|
||||
// float f1 = S[thread_offset + 1];
|
||||
|
||||
// FIXME: placeholder for proper exp
|
||||
f0 -= row_max;
|
||||
// f1 -= row_max;
|
||||
f0 -= rowmax_new;
|
||||
// f1 -= rowmax_new;
|
||||
// float16_t h0 = NN_float_to_half(f0);
|
||||
// float16_t h1 = NN_float_to_half(f1);
|
||||
|
||||
@@ -217,13 +224,22 @@ inline void thread_block_flashattn(float *smem_S, float *smem_O,
|
||||
|
||||
// 0-th thread collects all other thread's values in the warp
|
||||
if (tid_in_warp == 0) {
|
||||
float rowsum = per_thread_sum;
|
||||
for (int iter = 1; iter < NUM_THREADS; iter++) {
|
||||
float other = warp_smem[iter];
|
||||
per_thread_sum += other;
|
||||
rowsum += other;
|
||||
}
|
||||
|
||||
// TODO: update previous rowsum here
|
||||
smem_rowsum[row] = per_thread_sum;
|
||||
const float mi_prev = smem_rowmax_prev[row];
|
||||
const float mi_this = smem_rowmax_this[row];
|
||||
const float mi_new = smem_rowmax_new[row];
|
||||
const float exp = mi_prev - mi_this;
|
||||
|
||||
// update rowsum
|
||||
const float rowsum_prev = smem_rowsum[row];
|
||||
// FIXME: placeholder for exponential
|
||||
float rowsum_new = exp * rowsum_prev + rowsum;
|
||||
smem_rowsum[row] = rowsum_new;
|
||||
}
|
||||
|
||||
threadblock_barrier(threadblock_id_in_cluster,
|
||||
@@ -236,8 +252,12 @@ inline void thread_block_flashattn(float *smem_S, float *smem_O,
|
||||
for (int i = 0; i < per_row_iter; i++) {
|
||||
float fval = smem_O[thread_offset];
|
||||
|
||||
const float mi_prev = smem_rowmax_prev[row];
|
||||
const float mi_new = smem_rowmax_new[row];
|
||||
const float exp = mi_prev - mi_new;
|
||||
|
||||
// FIXME: placeholder for proper exp
|
||||
fval *= 2.0f;
|
||||
fval *= exp;
|
||||
|
||||
// update Oi in-place
|
||||
smem_O[thread_offset] = fval;
|
||||
@@ -300,7 +320,7 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
|
||||
sizeof(float) * (smem_QK_size + smem_V_size);
|
||||
|
||||
// allocate rowmax/rowsum storage at the end of the sharedmem address space
|
||||
constexpr uint32_t smem_rowmax_size = sizeof(float) * B_ROW;
|
||||
constexpr uint32_t smem_rowmax_size = sizeof(float) * B_ROW * 3 /* mi, mi~, minew */;
|
||||
constexpr uint32_t smem_rowsum_size = sizeof(float) * B_ROW;
|
||||
uint8_t *smem_rowmax =
|
||||
reinterpret_cast<uint8_t *>(SMEM_ADDR_END) - smem_rowmax_size;
|
||||
|
||||
Reference in New Issue
Block a user