flash: Pass smem_P arg to softmax func
This commit is contained in:
@@ -53,13 +53,11 @@ inline void thread_block_init_sharedmem(const uint32_t tid_in_threadblock,
|
||||
}
|
||||
}
|
||||
|
||||
inline void thread_block_flashattn(float *smem_S, float *smem_O,
|
||||
const uint32_t tid_in_threadblock,
|
||||
const uint32_t threads_per_threadblock,
|
||||
const uint32_t threadblock_id_in_cluster,
|
||||
float *smem_scratchpad,
|
||||
float *smem_rowmax,
|
||||
float *smem_rowsum) {
|
||||
inline void thread_block_online_softmax(
|
||||
float *smem_S, float *smem_O, float *smem_P,
|
||||
const uint32_t tid_in_threadblock, const uint32_t threads_per_threadblock,
|
||||
const uint32_t threadblock_id_in_cluster, float *smem_scratchpad,
|
||||
float *smem_rowmax, float *smem_rowsum) {
|
||||
asm volatile("thread_block_flashattn_start_%=:" ::);
|
||||
|
||||
const uint32_t tid_in_warp = tid_in_threadblock % NUM_THREADS;
|
||||
@@ -191,7 +189,6 @@ inline void thread_block_flashattn(float *smem_S, float *smem_O,
|
||||
|
||||
// Store S transposed to the shared memory
|
||||
|
||||
// update S in-place into P
|
||||
smem_S[thread_offset] = f0;
|
||||
// S[thread_offset + 1] = f1;
|
||||
gmem_tmp1[thread_offset] = f0;
|
||||
@@ -232,7 +229,6 @@ inline void thread_block_flashattn(float *smem_S, float *smem_O,
|
||||
|
||||
const float mi_prev = smem_rowmax_prev[row];
|
||||
const float mi_this = smem_rowmax_this[row];
|
||||
const float mi_new = smem_rowmax_new[row];
|
||||
const float exp = mi_prev - mi_this;
|
||||
|
||||
// update rowsum
|
||||
@@ -261,7 +257,6 @@ inline void thread_block_flashattn(float *smem_S, float *smem_O,
|
||||
|
||||
// update Oi in-place
|
||||
smem_O[thread_offset] = fval;
|
||||
gmem_tmp2[thread_offset] = fval;
|
||||
|
||||
thread_offset += NUM_THREADS;
|
||||
}
|
||||
@@ -316,8 +311,10 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
|
||||
threadblock_id_in_cluster);
|
||||
|
||||
uint8_t *smem_S = smem_per_threadblock;
|
||||
uint8_t *smem_O = smem_per_threadblock +
|
||||
sizeof(float) * (smem_QK_size + smem_V_size);
|
||||
uint8_t *smem_P = smem_S; // in-place update from S to P
|
||||
uint8_t *smem_V = smem_per_threadblock + sizeof(float) * smem_QK_size;
|
||||
uint8_t *smem_O =
|
||||
smem_per_threadblock + sizeof(float) * (smem_QK_size + smem_V_size);
|
||||
|
||||
// allocate rowmax/rowsum storage at the end of the sharedmem address space
|
||||
constexpr uint32_t smem_rowmax_size = sizeof(float) * B_ROW * 3 /* mi, mi~, minew */;
|
||||
@@ -361,10 +358,17 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
|
||||
float *tile_S = (float *)arg->addr_q;
|
||||
#endif
|
||||
|
||||
thread_block_flashattn(tile_S, (float *)smem_O, tid_in_threadblock,
|
||||
threads_per_threadblock, threadblock_id_in_cluster,
|
||||
(float *)smem_scratchpad, (float *)smem_rowmax,
|
||||
(float *)smem_rowsum);
|
||||
thread_block_online_softmax(
|
||||
tile_S, (float *)smem_O, (float *)smem_P, tid_in_threadblock,
|
||||
threads_per_threadblock, threadblock_id_in_cluster,
|
||||
(float *)smem_scratchpad, (float *)smem_rowmax, (float *)smem_rowsum);
|
||||
|
||||
// FIXME unnecessary?
|
||||
threadblock_barrier(threadblock_id_in_cluster,
|
||||
warps_per_threadblock_per_core);
|
||||
|
||||
thread_block_gemm_single_tile(smem_P, smem_V, tid_in_threadblock,
|
||||
threads_per_threadblock);
|
||||
}
|
||||
|
||||
int main() {
|
||||
|
||||
Reference in New Issue
Block a user