flash: Pass smem_P arg to softmax func

This commit is contained in:
Hansung Kim
2024-08-18 15:21:05 -07:00
parent d0809d292a
commit 90f6effa97

View File

@@ -53,13 +53,11 @@ inline void thread_block_init_sharedmem(const uint32_t tid_in_threadblock,
}
}
inline void thread_block_flashattn(float *smem_S, float *smem_O,
const uint32_t tid_in_threadblock,
const uint32_t threads_per_threadblock,
const uint32_t threadblock_id_in_cluster,
float *smem_scratchpad,
float *smem_rowmax,
float *smem_rowsum) {
inline void thread_block_online_softmax(
float *smem_S, float *smem_O, float *smem_P,
const uint32_t tid_in_threadblock, const uint32_t threads_per_threadblock,
const uint32_t threadblock_id_in_cluster, float *smem_scratchpad,
float *smem_rowmax, float *smem_rowsum) {
asm volatile("thread_block_flashattn_start_%=:" ::);
const uint32_t tid_in_warp = tid_in_threadblock % NUM_THREADS;
@@ -191,7 +189,6 @@ inline void thread_block_flashattn(float *smem_S, float *smem_O,
// Store S transposed to the shared memory
// update S in-place into P
smem_S[thread_offset] = f0;
// S[thread_offset + 1] = f1;
gmem_tmp1[thread_offset] = f0;
@@ -232,7 +229,6 @@ inline void thread_block_flashattn(float *smem_S, float *smem_O,
const float mi_prev = smem_rowmax_prev[row];
const float mi_this = smem_rowmax_this[row];
const float mi_new = smem_rowmax_new[row];
const float exp = mi_prev - mi_this;
// update rowsum
@@ -261,7 +257,6 @@ inline void thread_block_flashattn(float *smem_S, float *smem_O,
// update Oi in-place
smem_O[thread_offset] = fval;
gmem_tmp2[thread_offset] = fval;
thread_offset += NUM_THREADS;
}
@@ -316,8 +311,10 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
threadblock_id_in_cluster);
uint8_t *smem_S = smem_per_threadblock;
uint8_t *smem_O = smem_per_threadblock +
sizeof(float) * (smem_QK_size + smem_V_size);
uint8_t *smem_P = smem_S; // in-place update from S to P
uint8_t *smem_V = smem_per_threadblock + sizeof(float) * smem_QK_size;
uint8_t *smem_O =
smem_per_threadblock + sizeof(float) * (smem_QK_size + smem_V_size);
// allocate rowmax/rowsum storage at the end of the sharedmem address space
constexpr uint32_t smem_rowmax_size = sizeof(float) * B_ROW * 3 /* mi, mi~, minew */;
@@ -361,10 +358,17 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
float *tile_S = (float *)arg->addr_q;
#endif
thread_block_flashattn(tile_S, (float *)smem_O, tid_in_threadblock,
thread_block_online_softmax(
tile_S, (float *)smem_O, (float *)smem_P, tid_in_threadblock,
threads_per_threadblock, threadblock_id_in_cluster,
(float *)smem_scratchpad, (float *)smem_rowmax,
(float *)smem_rowsum);
(float *)smem_scratchpad, (float *)smem_rowmax, (float *)smem_rowsum);
// FIXME unnecessary?
threadblock_barrier(threadblock_id_in_cluster,
warps_per_threadblock_per_core);
thread_block_gemm_single_tile(smem_P, smem_V, tid_in_threadblock,
threads_per_threadblock);
}
int main() {