flash: Pass smem_P arg to softmax func
This commit is contained in:
@@ -53,13 +53,11 @@ inline void thread_block_init_sharedmem(const uint32_t tid_in_threadblock,
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
inline void thread_block_flashattn(float *smem_S, float *smem_O,
|
inline void thread_block_online_softmax(
|
||||||
const uint32_t tid_in_threadblock,
|
float *smem_S, float *smem_O, float *smem_P,
|
||||||
const uint32_t threads_per_threadblock,
|
const uint32_t tid_in_threadblock, const uint32_t threads_per_threadblock,
|
||||||
const uint32_t threadblock_id_in_cluster,
|
const uint32_t threadblock_id_in_cluster, float *smem_scratchpad,
|
||||||
float *smem_scratchpad,
|
float *smem_rowmax, float *smem_rowsum) {
|
||||||
float *smem_rowmax,
|
|
||||||
float *smem_rowsum) {
|
|
||||||
asm volatile("thread_block_flashattn_start_%=:" ::);
|
asm volatile("thread_block_flashattn_start_%=:" ::);
|
||||||
|
|
||||||
const uint32_t tid_in_warp = tid_in_threadblock % NUM_THREADS;
|
const uint32_t tid_in_warp = tid_in_threadblock % NUM_THREADS;
|
||||||
@@ -191,7 +189,6 @@ inline void thread_block_flashattn(float *smem_S, float *smem_O,
|
|||||||
|
|
||||||
// Store S transposed to the shared memory
|
// Store S transposed to the shared memory
|
||||||
|
|
||||||
// update S in-place into P
|
|
||||||
smem_S[thread_offset] = f0;
|
smem_S[thread_offset] = f0;
|
||||||
// S[thread_offset + 1] = f1;
|
// S[thread_offset + 1] = f1;
|
||||||
gmem_tmp1[thread_offset] = f0;
|
gmem_tmp1[thread_offset] = f0;
|
||||||
@@ -232,7 +229,6 @@ inline void thread_block_flashattn(float *smem_S, float *smem_O,
|
|||||||
|
|
||||||
const float mi_prev = smem_rowmax_prev[row];
|
const float mi_prev = smem_rowmax_prev[row];
|
||||||
const float mi_this = smem_rowmax_this[row];
|
const float mi_this = smem_rowmax_this[row];
|
||||||
const float mi_new = smem_rowmax_new[row];
|
|
||||||
const float exp = mi_prev - mi_this;
|
const float exp = mi_prev - mi_this;
|
||||||
|
|
||||||
// update rowsum
|
// update rowsum
|
||||||
@@ -261,7 +257,6 @@ inline void thread_block_flashattn(float *smem_S, float *smem_O,
|
|||||||
|
|
||||||
// update Oi in-place
|
// update Oi in-place
|
||||||
smem_O[thread_offset] = fval;
|
smem_O[thread_offset] = fval;
|
||||||
gmem_tmp2[thread_offset] = fval;
|
|
||||||
|
|
||||||
thread_offset += NUM_THREADS;
|
thread_offset += NUM_THREADS;
|
||||||
}
|
}
|
||||||
@@ -316,8 +311,10 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
|
|||||||
threadblock_id_in_cluster);
|
threadblock_id_in_cluster);
|
||||||
|
|
||||||
uint8_t *smem_S = smem_per_threadblock;
|
uint8_t *smem_S = smem_per_threadblock;
|
||||||
uint8_t *smem_O = smem_per_threadblock +
|
uint8_t *smem_P = smem_S; // in-place update from S to P
|
||||||
sizeof(float) * (smem_QK_size + smem_V_size);
|
uint8_t *smem_V = smem_per_threadblock + sizeof(float) * smem_QK_size;
|
||||||
|
uint8_t *smem_O =
|
||||||
|
smem_per_threadblock + sizeof(float) * (smem_QK_size + smem_V_size);
|
||||||
|
|
||||||
// allocate rowmax/rowsum storage at the end of the sharedmem address space
|
// allocate rowmax/rowsum storage at the end of the sharedmem address space
|
||||||
constexpr uint32_t smem_rowmax_size = sizeof(float) * B_ROW * 3 /* mi, mi~, minew */;
|
constexpr uint32_t smem_rowmax_size = sizeof(float) * B_ROW * 3 /* mi, mi~, minew */;
|
||||||
@@ -361,10 +358,17 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
|
|||||||
float *tile_S = (float *)arg->addr_q;
|
float *tile_S = (float *)arg->addr_q;
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
thread_block_flashattn(tile_S, (float *)smem_O, tid_in_threadblock,
|
thread_block_online_softmax(
|
||||||
threads_per_threadblock, threadblock_id_in_cluster,
|
tile_S, (float *)smem_O, (float *)smem_P, tid_in_threadblock,
|
||||||
(float *)smem_scratchpad, (float *)smem_rowmax,
|
threads_per_threadblock, threadblock_id_in_cluster,
|
||||||
(float *)smem_rowsum);
|
(float *)smem_scratchpad, (float *)smem_rowmax, (float *)smem_rowsum);
|
||||||
|
|
||||||
|
// FIXME unnecessary?
|
||||||
|
threadblock_barrier(threadblock_id_in_cluster,
|
||||||
|
warps_per_threadblock_per_core);
|
||||||
|
|
||||||
|
thread_block_gemm_single_tile(smem_P, smem_V, tid_in_threadblock,
|
||||||
|
threads_per_threadblock);
|
||||||
}
|
}
|
||||||
|
|
||||||
int main() {
|
int main() {
|
||||||
|
|||||||
Reference in New Issue
Block a user