diff --git a/tests/regression/flash_attention/kernel.cpp b/tests/regression/flash_attention/kernel.cpp index a7d21591..dd6cb4f3 100644 --- a/tests/regression/flash_attention/kernel.cpp +++ b/tests/regression/flash_attention/kernel.cpp @@ -73,6 +73,37 @@ inline float exponential_taylor_term(const float x) { return res; } +inline void thread_block_copy_data(const float *src, float *dest, + const uint32_t tid_in_threadblock, + const uint32_t threads_per_threadblock, + const uint32_t threadblocks_per_cluster, + const uint32_t threadblock_id_in_cluster) { + const uint32_t tid_in_warp = tid_in_threadblock % NUM_THREADS; + const uint32_t warp_id = tid_in_threadblock / NUM_THREADS; + const uint32_t warps_in_threadblock = threads_per_threadblock / NUM_THREADS; + const uint32_t warps_per_threadblock_per_core = + NUM_WARPS / threadblocks_per_cluster; + + for (int warp_offset = 0; warp_offset < B_ROW; + warp_offset += warps_in_threadblock) { + const uint32_t row = warp_offset + warp_id; + const uint32_t first_thread_offset = B_COL * row; + + constexpr uint32_t per_row_iter = B_COL / NUM_THREADS; + uint32_t thread_offset = first_thread_offset + tid_in_warp; + float per_thread_max = FLT_MIN; +#pragma GCC unroll + for (int i = 0; i < per_row_iter; i++) { + const float f = src[thread_offset]; + dest[thread_offset] = f; + thread_offset += NUM_THREADS; + } + + threadblock_barrier(threadblock_id_in_cluster, + warps_per_threadblock_per_core); + } +} + inline void thread_block_online_softmax( const float *smem_S, float *smem_O, float *smem_P, const uint32_t tid_in_threadblock, const uint32_t threads_per_threadblock, @@ -97,9 +128,6 @@ inline void thread_block_online_softmax( // asm volatile("fmv.s %0, f22" : "=f"(ft[6])); // asm volatile("fmv.s %0, f23" : "=f"(ft[7])); - volatile float *gmem_tmp0 = reinterpret_cast(0xd0000000UL); - volatile float *gmem_tmp1 = reinterpret_cast(0xe0000000UL); - float *smem_rowmax_prev = smem_rowmax; float *smem_rowmax_new = smem_rowmax + B_ROW; float *smem_rowmax_this = smem_rowmax + 2 * B_ROW; @@ -201,9 +229,6 @@ inline void thread_block_online_softmax( for (int i = 0; i < exp_per_row_iter; i++) { float f0 = smem_S[thread_offset]; - // check Q*K result - gmem_tmp0[thread_offset] = f0; - f0 -= rowmax_new; // 2nd-order Taylor approximation @@ -214,7 +239,6 @@ inline void thread_block_online_softmax( // Store S transposed to the shared memory smem_P[thread_offset] = exp; - gmem_tmp1[thread_offset] = exp; thread_offset += NUM_THREADS; } @@ -389,85 +413,97 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { const float *gmem_K = reinterpret_cast(arg->addr_k); const float *gmem_V = reinterpret_cast(arg->addr_v); float *gmem_O = reinterpret_cast(arg->addr_o); - float *gmem_tmp0 = reinterpret_cast(0xd0000000UL); + + // "inner loop" along the columns of K^T + for (uint32_t tile_k = 0; tile_k < (dim_seqlen / B_COL); tile_k++) { // #define SKIP_GEMM #ifndef SKIP_GEMM -#if 0 - thread_block_gemm( - (const float_type *)arg->addr_q, (const float_type *)arg->addr_k, - (float *)smem_S /*write result to SMEM */, B_ROW, B_COL, - HEADDIM, tid_in_threadblock, threads_per_threadblock, - threadblocks_per_cluster, threadblock_id_in_cluster, - smem_per_threadblock); + // clear out accumulators + initialize_accum_regs<0>(); + initialize_accum_regs<1>(); + static_assert(B_ROW == B_COL, "currently only supports square tiles"); + + // load Q + load_tile_to_smem( + dim_seqlen, 0 /*FIXME: only work on first B_ROW rows of Q for now*/, + 0 /* always 0 because dim_k == headdim */, gmem_Q, smem_Q, + tid_in_threadblock); + + // load K + load_tile_to_smem(dim_seqlen, tile_k, + 0 /* always 0 because dim_k == headdim */, + gmem_K, smem_K, tid_in_threadblock); + + // GMEM->SMEM and compute barrier + threadblock_barrier(threadblock_id_in_cluster, + warps_per_threadblock_per_core); + + // GEMM I: S = Q*K + thread_block_gemm_single_tile( + smem_Q, smem_K, nullptr /*ignore accum*/, smem_S, tid_in_threadblock, + threads_per_threadblock, threadblocks_per_cluster, + threadblock_id_in_cluster); + + // protect GEMM result writes (smem_S) before softmax + threadblock_barrier(threadblock_id_in_cluster, + warps_per_threadblock_per_core); + + const float *tile_S = (float *)smem_S; #else - - // clear out accumulators - initialize_accum_regs<0>(); - initialize_accum_regs<1>(); - - // load Q - static_assert(B_ROW == B_COL, "currently only supports square tiles"); - load_tile_to_smem(B_ROW, 0, 0, gmem_Q, smem_Q, tid_in_threadblock); - // load K - load_tile_to_smem(B_COL, 0, 0, gmem_K, smem_K, tid_in_threadblock); - - // GMEM->SMEM and compute barrier - threadblock_barrier(threadblock_id_in_cluster, - warps_per_threadblock_per_core); - - // GEMM I: S = Q*K - thread_block_gemm_single_tile( - smem_Q, smem_K, nullptr /*ignore accum*/, smem_S, tid_in_threadblock, - threads_per_threadblock, threadblocks_per_cluster, - threadblock_id_in_cluster); + float *tile_S = (float *)arg->addr_q; #endif - // protect GEMM result writes (smem_S) before softmax - threadblock_barrier(threadblock_id_in_cluster, - warps_per_threadblock_per_core); + thread_block_online_softmax( + tile_S, smem_O, smem_P, tid_in_threadblock, threads_per_threadblock, + threadblocks_per_cluster, threadblock_id_in_cluster, smem_scratchpad, + smem_rowmax, smem_rowsum); - const float *tile_S = (float *)smem_S; -#else - float *tile_S = (float *)arg->addr_q; -#endif + // FIXME unnecessary? + threadblock_barrier(threadblock_id_in_cluster, + warps_per_threadblock_per_core); - thread_block_online_softmax(tile_S, smem_O, smem_P, tid_in_threadblock, - threads_per_threadblock, threadblocks_per_cluster, - threadblock_id_in_cluster, smem_scratchpad, - smem_rowmax, smem_rowsum); + // GEMM II: O = O + P*V - // FIXME unnecessary? - threadblock_barrier(threadblock_id_in_cluster, - warps_per_threadblock_per_core); + // clear out accumulators + initialize_accum_regs<0>(); + initialize_accum_regs<1>(); - // GEMM II: O = O + P*V + load_tile_to_smem( + B_COL, 0, 0, gmem_V, smem_V, tid_in_threadblock); - // clear out accumulators - initialize_accum_regs<0>(); - initialize_accum_regs<1>(); + threadblock_barrier(threadblock_id_in_cluster, + warps_per_threadblock_per_core); - load_tile_to_smem( - B_COL, 0, 0, gmem_V, smem_V, tid_in_threadblock); + // FIXME: support MN_major for A for ideal performance + thread_block_gemm_single_tile( + smem_P, smem_V, smem_O /*load accum*/, smem_O, + tid_in_threadblock, threads_per_threadblock, threadblocks_per_cluster, + threadblock_id_in_cluster); - threadblock_barrier(threadblock_id_in_cluster, - warps_per_threadblock_per_core); + threadblock_barrier(threadblock_id_in_cluster, + warps_per_threadblock_per_core); + } - // FIXME: support MN_major for A for ideal performance - thread_block_gemm_single_tile( - smem_P, smem_V, smem_O, gmem_O /*smem_O*/, - tid_in_threadblock, threads_per_threadblock, threadblocks_per_cluster, - threadblock_id_in_cluster); + float *gmem_tmp0 = reinterpret_cast(0xd0000000UL); + float *gmem_tmp1 = reinterpret_cast(0xe0000000UL); - threadblock_barrier(threadblock_id_in_cluster, - warps_per_threadblock_per_core); + // copy out tile data to GMEM for debugging + thread_block_copy_data(smem_P, gmem_tmp0, tid_in_threadblock, + threads_per_threadblock, threadblocks_per_cluster, + threadblock_id_in_cluster); + thread_block_copy_data(smem_O, gmem_tmp1, tid_in_threadblock, + threads_per_threadblock, threadblocks_per_cluster, + threadblock_id_in_cluster); } int main() {