flash: Compute exponents using prev/next/this rowmax values
maybe there is a better way than storing all three in sharedmem?
This commit is contained in:
@@ -31,14 +31,17 @@ inline void thread_block_init_sharedmem(const uint32_t tid_in_threadblock,
|
|||||||
constexpr uint32_t num_warps = B_ROW / NUM_THREADS;
|
constexpr uint32_t num_warps = B_ROW / NUM_THREADS;
|
||||||
if (warp_id < num_warps) {
|
if (warp_id < num_warps) {
|
||||||
uint32_t offset = NUM_THREADS * warp_id + tid_in_warp;
|
uint32_t offset = NUM_THREADS * warp_id + tid_in_warp;
|
||||||
|
// mi, mi~, minew
|
||||||
smem_rowmax[offset] = FLT_MIN;
|
smem_rowmax[offset] = FLT_MIN;
|
||||||
|
smem_rowmax[offset + B_ROW] = FLT_MIN;
|
||||||
|
smem_rowmax[offset + 2 * B_ROW] = FLT_MIN;
|
||||||
smem_rowsum[offset] = 0.0f;
|
smem_rowsum[offset] = 0.0f;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// FIXME: dedup this pattern
|
||||||
for (int warp_offset = 0; warp_offset < B_COL;
|
for (int warp_offset = 0; warp_offset < B_COL;
|
||||||
warp_offset += warps_in_threadblock) {
|
warp_offset += warps_in_threadblock) {
|
||||||
// each warp clears out a row of smem_O
|
// each warp clears out a row of smem_O
|
||||||
// FIXME: dedup this pattern
|
|
||||||
const uint32_t row = warp_offset + warp_id;
|
const uint32_t row = warp_offset + warp_id;
|
||||||
uint32_t thread_offset = HEADDIM * row + tid_in_warp;
|
uint32_t thread_offset = HEADDIM * row + tid_in_warp;
|
||||||
constexpr uint32_t per_row_iter = HEADDIM / NUM_THREADS;
|
constexpr uint32_t per_row_iter = HEADDIM / NUM_THREADS;
|
||||||
@@ -79,6 +82,10 @@ inline void thread_block_flashattn(float *smem_S, float *smem_O,
|
|||||||
volatile float *gmem_tmp1 = reinterpret_cast<volatile float *>(0xe0000000UL);
|
volatile float *gmem_tmp1 = reinterpret_cast<volatile float *>(0xe0000000UL);
|
||||||
volatile float *gmem_tmp2 = reinterpret_cast<volatile float *>(0xf0000000UL);
|
volatile float *gmem_tmp2 = reinterpret_cast<volatile float *>(0xf0000000UL);
|
||||||
|
|
||||||
|
float *smem_rowmax_prev = smem_rowmax;
|
||||||
|
float *smem_rowmax_new = smem_rowmax + B_ROW;
|
||||||
|
float *smem_rowmax_this = smem_rowmax + 2 * B_ROW;
|
||||||
|
|
||||||
for (int warp_offset = 0; warp_offset < B_ROW;
|
for (int warp_offset = 0; warp_offset < B_ROW;
|
||||||
warp_offset += warps_in_threadblock) {
|
warp_offset += warps_in_threadblock) {
|
||||||
const uint32_t row = warp_offset + warp_id;
|
const uint32_t row = warp_offset + warp_id;
|
||||||
@@ -136,15 +143,15 @@ inline void thread_block_flashattn(float *smem_S, float *smem_O,
|
|||||||
: "=f"(rowmax)
|
: "=f"(rowmax)
|
||||||
: "f"(rowmax), "f"(other));
|
: "f"(rowmax), "f"(other));
|
||||||
}
|
}
|
||||||
|
smem_rowmax_this[row] = rowmax;
|
||||||
|
|
||||||
// update previous rowmax
|
// update previous rowmax
|
||||||
// i.e. mi_new = max(mi, mij)
|
// i.e. mi_new = max(mi, mij)
|
||||||
float prev_rowmax = smem_rowmax[row];
|
float prev_rowmax = smem_rowmax_prev[row];
|
||||||
asm volatile("fmax.s %0, %1, %2"
|
asm volatile("fmax.s %0, %1, %2"
|
||||||
: "=f"(rowmax)
|
: "=f"(rowmax)
|
||||||
: "f"(rowmax), "f"(prev_rowmax));
|
: "f"(rowmax), "f"(prev_rowmax));
|
||||||
|
smem_rowmax_new[row] = rowmax;
|
||||||
smem_rowmax[row] = rowmax;
|
|
||||||
gmem_tmp0[row] = rowmax;
|
gmem_tmp0[row] = rowmax;
|
||||||
}
|
}
|
||||||
#endif
|
#endif
|
||||||
@@ -160,7 +167,7 @@ inline void thread_block_flashattn(float *smem_S, float *smem_O,
|
|||||||
// (exp_elem_per_thread * threads_per_threadblock) / B_COL;
|
// (exp_elem_per_thread * threads_per_threadblock) / B_COL;
|
||||||
|
|
||||||
// broadcast rowmax to all threads in the warp
|
// broadcast rowmax to all threads in the warp
|
||||||
const float row_max = smem_rowmax[row];
|
const float rowmax_new = smem_rowmax_new[row];
|
||||||
|
|
||||||
// each thread computes two fp32 elements, downconverts it to fp16, then
|
// each thread computes two fp32 elements, downconverts it to fp16, then
|
||||||
// packs them into one fp32
|
// packs them into one fp32
|
||||||
@@ -177,8 +184,8 @@ inline void thread_block_flashattn(float *smem_S, float *smem_O,
|
|||||||
// float f1 = S[thread_offset + 1];
|
// float f1 = S[thread_offset + 1];
|
||||||
|
|
||||||
// FIXME: placeholder for proper exp
|
// FIXME: placeholder for proper exp
|
||||||
f0 -= row_max;
|
f0 -= rowmax_new;
|
||||||
// f1 -= row_max;
|
// f1 -= rowmax_new;
|
||||||
// float16_t h0 = NN_float_to_half(f0);
|
// float16_t h0 = NN_float_to_half(f0);
|
||||||
// float16_t h1 = NN_float_to_half(f1);
|
// float16_t h1 = NN_float_to_half(f1);
|
||||||
|
|
||||||
@@ -217,13 +224,22 @@ inline void thread_block_flashattn(float *smem_S, float *smem_O,
|
|||||||
|
|
||||||
// 0-th thread collects all other thread's values in the warp
|
// 0-th thread collects all other thread's values in the warp
|
||||||
if (tid_in_warp == 0) {
|
if (tid_in_warp == 0) {
|
||||||
|
float rowsum = per_thread_sum;
|
||||||
for (int iter = 1; iter < NUM_THREADS; iter++) {
|
for (int iter = 1; iter < NUM_THREADS; iter++) {
|
||||||
float other = warp_smem[iter];
|
float other = warp_smem[iter];
|
||||||
per_thread_sum += other;
|
rowsum += other;
|
||||||
}
|
}
|
||||||
|
|
||||||
// TODO: update previous rowsum here
|
const float mi_prev = smem_rowmax_prev[row];
|
||||||
smem_rowsum[row] = per_thread_sum;
|
const float mi_this = smem_rowmax_this[row];
|
||||||
|
const float mi_new = smem_rowmax_new[row];
|
||||||
|
const float exp = mi_prev - mi_this;
|
||||||
|
|
||||||
|
// update rowsum
|
||||||
|
const float rowsum_prev = smem_rowsum[row];
|
||||||
|
// FIXME: placeholder for exponential
|
||||||
|
float rowsum_new = exp * rowsum_prev + rowsum;
|
||||||
|
smem_rowsum[row] = rowsum_new;
|
||||||
}
|
}
|
||||||
|
|
||||||
threadblock_barrier(threadblock_id_in_cluster,
|
threadblock_barrier(threadblock_id_in_cluster,
|
||||||
@@ -236,8 +252,12 @@ inline void thread_block_flashattn(float *smem_S, float *smem_O,
|
|||||||
for (int i = 0; i < per_row_iter; i++) {
|
for (int i = 0; i < per_row_iter; i++) {
|
||||||
float fval = smem_O[thread_offset];
|
float fval = smem_O[thread_offset];
|
||||||
|
|
||||||
|
const float mi_prev = smem_rowmax_prev[row];
|
||||||
|
const float mi_new = smem_rowmax_new[row];
|
||||||
|
const float exp = mi_prev - mi_new;
|
||||||
|
|
||||||
// FIXME: placeholder for proper exp
|
// FIXME: placeholder for proper exp
|
||||||
fval *= 2.0f;
|
fval *= exp;
|
||||||
|
|
||||||
// update Oi in-place
|
// update Oi in-place
|
||||||
smem_O[thread_offset] = fval;
|
smem_O[thread_offset] = fval;
|
||||||
@@ -300,7 +320,7 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
|
|||||||
sizeof(float) * (smem_QK_size + smem_V_size);
|
sizeof(float) * (smem_QK_size + smem_V_size);
|
||||||
|
|
||||||
// allocate rowmax/rowsum storage at the end of the sharedmem address space
|
// allocate rowmax/rowsum storage at the end of the sharedmem address space
|
||||||
constexpr uint32_t smem_rowmax_size = sizeof(float) * B_ROW;
|
constexpr uint32_t smem_rowmax_size = sizeof(float) * B_ROW * 3 /* mi, mi~, minew */;
|
||||||
constexpr uint32_t smem_rowsum_size = sizeof(float) * B_ROW;
|
constexpr uint32_t smem_rowsum_size = sizeof(float) * B_ROW;
|
||||||
uint8_t *smem_rowmax =
|
uint8_t *smem_rowmax =
|
||||||
reinterpret_cast<uint8_t *>(SMEM_ADDR_END) - smem_rowmax_size;
|
reinterpret_cast<uint8_t *>(SMEM_ADDR_END) - smem_rowmax_size;
|
||||||
|
|||||||
Reference in New Issue
Block a user