flash: Supply correct tile dims to single_tile

This commit is contained in:
Hansung Kim
2024-08-20 19:50:45 -07:00
parent 091f40c365
commit 3f20dd59c0

View File

@@ -141,7 +141,7 @@ __attribute__((always_inline)) inline void thread_block_online_softmax(
const uint32_t threadblocks_per_cluster,
const uint32_t threadblock_id_in_cluster, float *smem_scratchpad,
float *smem_rowmax, float *smem_rowsum) {
asm volatile("thread_block_flashattn_start_%=:" ::);
asm volatile("thread_block_online_softmax_start_%=:" ::);
const uint32_t tid_in_warp = tid_in_threadblock % NUM_THREADS;
const uint32_t warp_id = tid_in_threadblock / NUM_THREADS;
@@ -250,20 +250,11 @@ __attribute__((always_inline)) inline void thread_block_online_softmax(
// broadcast updated rowmax to all threads in the warp
const float rowmax_new = smem_rowmax[row];
// each thread computes two fp32 elements, downconverts them to fp16, then
// packs them into one fp32
constexpr uint32_t elem_per_thread = 1;
static_assert((B_COL % (elem_per_thread * NUM_THREADS)) == 0,
"B_COL condition not met for P compute");
thread_offset = first_thread_offset + (elem_per_thread * tid_in_warp);
constexpr uint32_t exp_per_row_iter =
B_COL / (elem_per_thread * NUM_THREADS);
asm volatile("flashattn_exp_p_start_%=:" ::);
thread_offset = first_thread_offset + tid_in_warp;
#pragma GCC unroll
for (int i = 0; i < exp_per_row_iter; i++) {
for (int i = 0; i < per_row_iter; i++) {
float f0 = smem_S[thread_offset];
f0 -= rowmax_new;
@@ -292,8 +283,9 @@ __attribute__((always_inline)) inline void thread_block_online_softmax(
asm volatile("flashattn_rowsum_start_%=:" ::);
thread_offset = first_thread_offset + tid_in_warp;
float per_thread_sum = 0.0f;
thread_offset = first_thread_offset + tid_in_warp;
#pragma GCC unroll
for (int i = 0; i < per_row_iter; i++) {
per_thread_sum += smem_P[thread_offset];
@@ -317,7 +309,6 @@ __attribute__((always_inline)) inline void thread_block_online_softmax(
}
const float mi_prev = rowmax_prev;
// TODO: replace this with a register?
const float mi_this = rowmax_this;
const float x = mi_prev - mi_this;
@@ -371,7 +362,7 @@ __attribute__((always_inline)) inline void thread_block_online_softmax(
warps_per_threadblock_per_core);
}
asm volatile("thread_block_flashattn_finish_%=:" ::);
asm volatile("thread_block_online_softmax_finish_%=:" ::);
}
void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
@@ -497,7 +488,7 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
// GEMM I: S = Q*K
thread_block_gemm_single_tile<float, MemLayout::MN_major,
MemLayout::MN_major,
MemLayout::MN_major, B_ROW, B_COL, HEADDIM,
/*load_accum=*/false,
/*write_to_smem=*/true>(
smem_Q, smem_K, nullptr /*ignore accum*/, smem_S, tid_in_threadblock,
@@ -583,14 +574,23 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
threadblock_barrier(threadblock_id_in_cluster,
warps_per_threadblock_per_core);
// FIXME: support MN_major for A for ideal performance
thread_block_gemm_single_tile<float, MemLayout::K_major,
MemLayout::MN_major,
B_ROW, HEADDIM, B_COL,
/*load_accum=*/true,
/*write_to_smem=*/true>(
smem_P, smem_V, smem_O /*load accum*/, smem_O,
tid_in_threadblock, threads_per_threadblock, threadblocks_per_cluster,
threadblock_id_in_cluster);
// FIXME: wrong but fast
// thread_block_gemm_single_tile<float, MemLayout::MN_major,
// MemLayout::MN_major,
// B_ROW, HEADDIM, B_COL,
// /*load_accum=*/true,
// /*write_to_smem=*/true>(
// smem_P, smem_V, smem_O /*load accum*/, smem_O, tid_in_threadblock,
// threads_per_threadblock, threadblocks_per_cluster,
// threadblock_id_in_cluster);
threadblock_barrier(threadblock_id_in_cluster,
warps_per_threadblock_per_core);