flash: Compute exponents using prev/next/this rowmax values

maybe there is a better way than storing all three in sharedmem?
This commit is contained in:
Hansung Kim
2024-08-15 22:09:13 -07:00
parent be08204e65
commit d3de1b674a

View File

@@ -31,14 +31,17 @@ inline void thread_block_init_sharedmem(const uint32_t tid_in_threadblock,
constexpr uint32_t num_warps = B_ROW / NUM_THREADS;
if (warp_id < num_warps) {
uint32_t offset = NUM_THREADS * warp_id + tid_in_warp;
// mi, mi~, minew
smem_rowmax[offset] = FLT_MIN;
smem_rowmax[offset + B_ROW] = FLT_MIN;
smem_rowmax[offset + 2 * B_ROW] = FLT_MIN;
smem_rowsum[offset] = 0.0f;
}
// FIXME: dedup this pattern
for (int warp_offset = 0; warp_offset < B_COL;
warp_offset += warps_in_threadblock) {
// each warp clears out a row of smem_O
// FIXME: dedup this pattern
const uint32_t row = warp_offset + warp_id;
uint32_t thread_offset = HEADDIM * row + tid_in_warp;
constexpr uint32_t per_row_iter = HEADDIM / NUM_THREADS;
@@ -79,6 +82,10 @@ inline void thread_block_flashattn(float *smem_S, float *smem_O,
volatile float *gmem_tmp1 = reinterpret_cast<volatile float *>(0xe0000000UL);
volatile float *gmem_tmp2 = reinterpret_cast<volatile float *>(0xf0000000UL);
float *smem_rowmax_prev = smem_rowmax;
float *smem_rowmax_new = smem_rowmax + B_ROW;
float *smem_rowmax_this = smem_rowmax + 2 * B_ROW;
for (int warp_offset = 0; warp_offset < B_ROW;
warp_offset += warps_in_threadblock) {
const uint32_t row = warp_offset + warp_id;
@@ -136,15 +143,15 @@ inline void thread_block_flashattn(float *smem_S, float *smem_O,
: "=f"(rowmax)
: "f"(rowmax), "f"(other));
}
smem_rowmax_this[row] = rowmax;
// update previous rowmax
// i.e. mi_new = max(mi, mij)
float prev_rowmax = smem_rowmax[row];
float prev_rowmax = smem_rowmax_prev[row];
asm volatile("fmax.s %0, %1, %2"
: "=f"(rowmax)
: "f"(rowmax), "f"(prev_rowmax));
smem_rowmax[row] = rowmax;
smem_rowmax_new[row] = rowmax;
gmem_tmp0[row] = rowmax;
}
#endif
@@ -160,7 +167,7 @@ inline void thread_block_flashattn(float *smem_S, float *smem_O,
// (exp_elem_per_thread * threads_per_threadblock) / B_COL;
// broadcast rowmax to all threads in the warp
const float row_max = smem_rowmax[row];
const float rowmax_new = smem_rowmax_new[row];
// each thread computes two fp32 elements, downconverts it to fp16, then
// packs them into one fp32
@@ -177,8 +184,8 @@ inline void thread_block_flashattn(float *smem_S, float *smem_O,
// float f1 = S[thread_offset + 1];
// FIXME: placeholder for proper exp
f0 -= row_max;
// f1 -= row_max;
f0 -= rowmax_new;
// f1 -= rowmax_new;
// float16_t h0 = NN_float_to_half(f0);
// float16_t h1 = NN_float_to_half(f1);
@@ -217,13 +224,22 @@ inline void thread_block_flashattn(float *smem_S, float *smem_O,
// 0-th thread collects all other thread's values in the warp
if (tid_in_warp == 0) {
float rowsum = per_thread_sum;
for (int iter = 1; iter < NUM_THREADS; iter++) {
float other = warp_smem[iter];
per_thread_sum += other;
rowsum += other;
}
// TODO: update previous rowsum here
smem_rowsum[row] = per_thread_sum;
const float mi_prev = smem_rowmax_prev[row];
const float mi_this = smem_rowmax_this[row];
const float mi_new = smem_rowmax_new[row];
const float exp = mi_prev - mi_this;
// update rowsum
const float rowsum_prev = smem_rowsum[row];
// FIXME: placeholder for exponential
float rowsum_new = exp * rowsum_prev + rowsum;
smem_rowsum[row] = rowsum_new;
}
threadblock_barrier(threadblock_id_in_cluster,
@@ -236,8 +252,12 @@ inline void thread_block_flashattn(float *smem_S, float *smem_O,
for (int i = 0; i < per_row_iter; i++) {
float fval = smem_O[thread_offset];
const float mi_prev = smem_rowmax_prev[row];
const float mi_new = smem_rowmax_new[row];
const float exp = mi_prev - mi_new;
// FIXME: placeholder for proper exp
fval *= 2.0f;
fval *= exp;
// update Oi in-place
smem_O[thread_offset] = fval;
@@ -300,7 +320,7 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
sizeof(float) * (smem_QK_size + smem_V_size);
// allocate rowmax/rowsum storage at the end of the sharedmem address space
constexpr uint32_t smem_rowmax_size = sizeof(float) * B_ROW;
constexpr uint32_t smem_rowmax_size = sizeof(float) * B_ROW * 3 /* mi, mi~, minew */;
constexpr uint32_t smem_rowsum_size = sizeof(float) * B_ROW;
uint8_t *smem_rowmax =
reinterpret_cast<uint8_t *>(SMEM_ADDR_END) - smem_rowmax_size;