flash: Pass smem_P arg to softmax func

This commit is contained in:
Hansung Kim
2024-08-18 15:21:05 -07:00
parent d0809d292a
commit 90f6effa97

View File

@@ -53,13 +53,11 @@ inline void thread_block_init_sharedmem(const uint32_t tid_in_threadblock,
} }
} }
inline void thread_block_flashattn(float *smem_S, float *smem_O, inline void thread_block_online_softmax(
const uint32_t tid_in_threadblock, float *smem_S, float *smem_O, float *smem_P,
const uint32_t threads_per_threadblock, const uint32_t tid_in_threadblock, const uint32_t threads_per_threadblock,
const uint32_t threadblock_id_in_cluster, const uint32_t threadblock_id_in_cluster, float *smem_scratchpad,
float *smem_scratchpad, float *smem_rowmax, float *smem_rowsum) {
float *smem_rowmax,
float *smem_rowsum) {
asm volatile("thread_block_flashattn_start_%=:" ::); asm volatile("thread_block_flashattn_start_%=:" ::);
const uint32_t tid_in_warp = tid_in_threadblock % NUM_THREADS; const uint32_t tid_in_warp = tid_in_threadblock % NUM_THREADS;
@@ -191,7 +189,6 @@ inline void thread_block_flashattn(float *smem_S, float *smem_O,
// Store S transposed to the shared memory // Store S transposed to the shared memory
// update S in-place into P
smem_S[thread_offset] = f0; smem_S[thread_offset] = f0;
// S[thread_offset + 1] = f1; // S[thread_offset + 1] = f1;
gmem_tmp1[thread_offset] = f0; gmem_tmp1[thread_offset] = f0;
@@ -232,7 +229,6 @@ inline void thread_block_flashattn(float *smem_S, float *smem_O,
const float mi_prev = smem_rowmax_prev[row]; const float mi_prev = smem_rowmax_prev[row];
const float mi_this = smem_rowmax_this[row]; const float mi_this = smem_rowmax_this[row];
const float mi_new = smem_rowmax_new[row];
const float exp = mi_prev - mi_this; const float exp = mi_prev - mi_this;
// update rowsum // update rowsum
@@ -261,7 +257,6 @@ inline void thread_block_flashattn(float *smem_S, float *smem_O,
// update Oi in-place // update Oi in-place
smem_O[thread_offset] = fval; smem_O[thread_offset] = fval;
gmem_tmp2[thread_offset] = fval;
thread_offset += NUM_THREADS; thread_offset += NUM_THREADS;
} }
@@ -316,8 +311,10 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
threadblock_id_in_cluster); threadblock_id_in_cluster);
uint8_t *smem_S = smem_per_threadblock; uint8_t *smem_S = smem_per_threadblock;
uint8_t *smem_O = smem_per_threadblock + uint8_t *smem_P = smem_S; // in-place update from S to P
sizeof(float) * (smem_QK_size + smem_V_size); uint8_t *smem_V = smem_per_threadblock + sizeof(float) * smem_QK_size;
uint8_t *smem_O =
smem_per_threadblock + sizeof(float) * (smem_QK_size + smem_V_size);
// allocate rowmax/rowsum storage at the end of the sharedmem address space // allocate rowmax/rowsum storage at the end of the sharedmem address space
constexpr uint32_t smem_rowmax_size = sizeof(float) * B_ROW * 3 /* mi, mi~, minew */; constexpr uint32_t smem_rowmax_size = sizeof(float) * B_ROW * 3 /* mi, mi~, minew */;
@@ -361,10 +358,17 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
float *tile_S = (float *)arg->addr_q; float *tile_S = (float *)arg->addr_q;
#endif #endif
thread_block_flashattn(tile_S, (float *)smem_O, tid_in_threadblock, thread_block_online_softmax(
threads_per_threadblock, threadblock_id_in_cluster, tile_S, (float *)smem_O, (float *)smem_P, tid_in_threadblock,
(float *)smem_scratchpad, (float *)smem_rowmax, threads_per_threadblock, threadblock_id_in_cluster,
(float *)smem_rowsum); (float *)smem_scratchpad, (float *)smem_rowmax, (float *)smem_rowsum);
// FIXME unnecessary?
threadblock_barrier(threadblock_id_in_cluster,
warps_per_threadblock_per_core);
thread_block_gemm_single_tile(smem_P, smem_V, tid_in_threadblock,
threads_per_threadblock);
} }
int main() { int main() {