From 3f20dd59c0459979933b2bc54f5652ebc84501c0 Mon Sep 17 00:00:00 2001
From: Hansung Kim
Date: Tue, 20 Aug 2024 19:50:45 -0700
Subject: [PATCH] flash: Supply correct tile dims to single_tile

---
 tests/regression/flash_attention/kernel.cpp | 34 ++++++++++-----------
 1 file changed, 17 insertions(+), 17 deletions(-)

diff --git a/tests/regression/flash_attention/kernel.cpp b/tests/regression/flash_attention/kernel.cpp
index 64fcf201..52947c73 100644
--- a/tests/regression/flash_attention/kernel.cpp
+++ b/tests/regression/flash_attention/kernel.cpp
@@ -141,7 +141,7 @@ __attribute__((always_inline)) inline void thread_block_online_softmax(
     const uint32_t threadblocks_per_cluster,
     const uint32_t threadblock_id_in_cluster, float *smem_scratchpad,
     float *smem_rowmax, float *smem_rowsum) {
-  asm volatile("thread_block_flashattn_start_%=:" ::);
+  asm volatile("thread_block_online_softmax_start_%=:" ::);

   const uint32_t tid_in_warp = tid_in_threadblock % NUM_THREADS;
   const uint32_t warp_id = tid_in_threadblock / NUM_THREADS;
@@ -250,20 +250,11 @@ __attribute__((always_inline)) inline void thread_block_online_softmax(
     // broadcast updated rowmax to all threads in the warp
     const float rowmax_new = smem_rowmax[row];

-    // each thread computes two fp32 elements, downconverts it to fp16, then
-    // packs them into one fp32
-    constexpr uint32_t elem_per_thread = 1;
-    static_assert((B_COL % (elem_per_thread * NUM_THREADS)) == 0,
-                  "B_COL condition not met for P compute");
-
-    thread_offset = first_thread_offset + (elem_per_thread * tid_in_warp);
-    constexpr uint32_t exp_per_row_iter =
-        B_COL / (elem_per_thread * NUM_THREADS);
-
     asm volatile("flashattn_exp_p_start_%=:" ::);

+    thread_offset = first_thread_offset + tid_in_warp;
 #pragma GCC unroll
-    for (int i = 0; i < exp_per_row_iter; i++) {
+    for (int i = 0; i < per_row_iter; i++) {
       float f0 = smem_S[thread_offset];
       f0 -= rowmax_new;

@@ -292,8 +283,9 @@ __attribute__((always_inline)) inline void thread_block_online_softmax(

     asm volatile("flashattn_rowsum_start_%=:" ::);

-    thread_offset = first_thread_offset + tid_in_warp;
     float per_thread_sum = 0.0f;
+
+    thread_offset = first_thread_offset + tid_in_warp;
 #pragma GCC unroll
     for (int i = 0; i < per_row_iter; i++) {
       per_thread_sum += smem_P[thread_offset];
@@ -317,7 +309,6 @@ __attribute__((always_inline)) inline void thread_block_online_softmax(
     }

     const float mi_prev = rowmax_prev;
-    // TODO: replace this with a register?
     const float mi_this = rowmax_this;
     const float x = mi_prev - mi_this;

@@ -371,7 +362,7 @@ __attribute__((always_inline)) inline void thread_block_online_softmax(
                         warps_per_threadblock_per_core);
   }

-  asm volatile("thread_block_flashattn_finish_%=:" ::);
+  asm volatile("thread_block_online_softmax_finish_%=:" ::);
 }

 void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
@@ -497,7 +488,7 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
     // GEMM I: S = Q*K
     thread_block_gemm_single_tile(
         smem_Q, smem_K, nullptr /*ignore accum*/, smem_S, tid_in_threadblock,
@@ -583,14 +574,23 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
     threadblock_barrier(threadblock_id_in_cluster,
                         warps_per_threadblock_per_core);

-    // FIXME: support MN_major for A for ideal performance
     thread_block_gemm_single_tile(
         smem_P, smem_V, smem_O /*load accum*/, smem_O, tid_in_threadblock,
         threads_per_threadblock, threadblocks_per_cluster,
         threadblock_id_in_cluster);
+    // FIXME: wrong but fast
+    // thread_block_gemm_single_tile(
+    //     smem_P, smem_V, smem_O /*load accum*/, smem_O, tid_in_threadblock,
+    //     threads_per_threadblock, threadblocks_per_cluster,
+    //     threadblock_id_in_cluster);

     threadblock_barrier(threadblock_id_in_cluster,
                         warps_per_threadblock_per_core);