From 615d36a5c29d6d4a894d5c90e69fab8632fd215e Mon Sep 17 00:00:00 2001 From: Hansung Kim Date: Tue, 20 Aug 2024 14:34:45 -0700 Subject: [PATCH] flash: Reduce smem use for rowmax; verify result --- tests/regression/flash_attention/kernel.cpp | 201 ++++++++++++++------ 1 file changed, 146 insertions(+), 55 deletions(-) diff --git a/tests/regression/flash_attention/kernel.cpp b/tests/regression/flash_attention/kernel.cpp index b1086e18..ccd20fbd 100644 --- a/tests/regression/flash_attention/kernel.cpp +++ b/tests/regression/flash_attention/kernel.cpp @@ -13,6 +13,8 @@ // FIXME #define HEADDIM B_COL +constexpr bool DEBUG = true; + inline void thread_block_init_sharedmem(const uint32_t tid_in_threadblock, const uint32_t threads_per_threadblock, float *smem_O, @@ -53,6 +55,66 @@ inline void thread_block_init_sharedmem(const uint32_t tid_in_threadblock, } } +inline void thread_block_copy_rowmax(const float *src, float *dest, + const uint32_t tid_in_threadblock, + const uint32_t threads_per_threadblock, + const uint32_t threadblocks_per_cluster, + const uint32_t threadblock_id_in_cluster) { + asm volatile("threadblock_copy_rowmax_start_%=:" ::); + + const uint32_t tid_in_warp = tid_in_threadblock % NUM_THREADS; + const uint32_t warp_id = tid_in_threadblock / NUM_THREADS; + const uint32_t warps_in_threadblock = threads_per_threadblock / NUM_THREADS; + const uint32_t warps_per_threadblock_per_core = + NUM_WARPS / threadblocks_per_cluster; + + constexpr uint32_t num_warps = B_ROW / NUM_THREADS; + if (warp_id < num_warps) { + uint32_t offset = NUM_THREADS * warp_id + tid_in_warp; + dest[offset] = src[offset]; + } + + threadblock_barrier(threadblock_id_in_cluster, + warps_per_threadblock_per_core); + + asm volatile("threadblock_copy_rowmax_finish_%=:" ::); +} + +inline void thread_block_copy_tile(const float *src, float *dest, + const uint32_t tid_in_threadblock, + const uint32_t threads_per_threadblock, + const uint32_t threadblocks_per_cluster, + const uint32_t threadblock_id_in_cluster) { + asm volatile("threadblock_copy_tile_start_%=:" ::); + + const uint32_t tid_in_warp = tid_in_threadblock % NUM_THREADS; + const uint32_t warp_id = tid_in_threadblock / NUM_THREADS; + const uint32_t warps_in_threadblock = threads_per_threadblock / NUM_THREADS; + const uint32_t warps_per_threadblock_per_core = + NUM_WARPS / threadblocks_per_cluster; + + // FIXME: dedup this pattern + for (int warp_offset = 0; warp_offset < B_ROW; + warp_offset += warps_in_threadblock) { + const uint32_t row = warp_offset + warp_id; + const uint32_t first_thread_offset = B_COL * row; + + constexpr uint32_t per_row_iter = B_COL / NUM_THREADS; + uint32_t thread_offset = first_thread_offset + tid_in_warp; + float per_thread_max = FLT_MIN; +#pragma GCC unroll + for (int i = 0; i < per_row_iter; i++) { + dest[thread_offset] = src[thread_offset]; + thread_offset += NUM_THREADS; + } + + threadblock_barrier(threadblock_id_in_cluster, + warps_per_threadblock_per_core); + } + + asm volatile("threadblock_copy_tile_finish_%=:" ::); +} + template inline float exponential_taylor_term(const float x) { asm volatile("exponential_taylor_term_start_%=:" ::); @@ -73,38 +135,7 @@ inline float exponential_taylor_term(const float x) { return res; } -inline void thread_block_copy_data(const float *src, float *dest, - const uint32_t tid_in_threadblock, - const uint32_t threads_per_threadblock, - const uint32_t threadblocks_per_cluster, - const uint32_t threadblock_id_in_cluster) { - const uint32_t tid_in_warp = tid_in_threadblock % NUM_THREADS; - const uint32_t warp_id = tid_in_threadblock / NUM_THREADS; - const uint32_t warps_in_threadblock = threads_per_threadblock / NUM_THREADS; - const uint32_t warps_per_threadblock_per_core = - NUM_WARPS / threadblocks_per_cluster; - - for (int warp_offset = 0; warp_offset < B_ROW; - warp_offset += warps_in_threadblock) { - const uint32_t row = warp_offset + warp_id; - const uint32_t first_thread_offset = B_COL * row; - - constexpr uint32_t per_row_iter = B_COL / NUM_THREADS; - uint32_t thread_offset = first_thread_offset + tid_in_warp; - float per_thread_max = FLT_MIN; -#pragma GCC unroll - for (int i = 0; i < per_row_iter; i++) { - const float f = src[thread_offset]; - dest[thread_offset] = f; - thread_offset += NUM_THREADS; - } - - threadblock_barrier(threadblock_id_in_cluster, - warps_per_threadblock_per_core); - } -} - -inline void thread_block_online_softmax( +__attribute__((always_inline)) inline void thread_block_online_softmax( const float *smem_S, float *smem_O, float *smem_P, const uint32_t tid_in_threadblock, const uint32_t threads_per_threadblock, const uint32_t threadblocks_per_cluster, @@ -128,9 +159,7 @@ inline void thread_block_online_softmax( // asm volatile("fmv.s %0, f22" : "=f"(ft[6])); // asm volatile("fmv.s %0, f23" : "=f"(ft[7])); - float *smem_rowmax_prev = smem_rowmax; - float *smem_rowmax_new = smem_rowmax + B_ROW; - float *smem_rowmax_this = smem_rowmax + 2 * B_ROW; + float *smem_rowmax_this = smem_rowmax + B_ROW; for (int warp_offset = 0; warp_offset < B_ROW; warp_offset += warps_in_threadblock) { @@ -192,26 +221,34 @@ inline void thread_block_online_softmax( // update previous rowmax // i.e. mi_new = max(mi, mij) - float prev_rowmax = smem_rowmax_prev[row]; + float prev_rowmax = smem_rowmax[row]; + // stage prev rowmax in scratchpad for warp-wide broadcast + warp_smem[0] = prev_rowmax; asm volatile("fmax.s %0, %1, %2" : "=f"(rowmax) : "f"(rowmax), "f"(prev_rowmax)); - smem_rowmax_new[row] = rowmax; + smem_rowmax[row] = rowmax; } + #endif // FIXME: unnecessary? threadblock_barrier(threadblock_id_in_cluster, warps_per_threadblock_per_core); + // broadcast prev rowmax to all threads in the warp + // NOTE: memory consistency is a little sketchy here + const float rowmax_prev = warp_smem[0]; + const float rowmax_this = smem_rowmax_this[row]; + // exponential // // B_ROW / (B_ROW * B_COL / (exp_elem * threads_per_threadblock)) // const uint32_t row_stride = // (exp_elem_per_thread * threads_per_threadblock) / B_COL; - // broadcast rowmax to all threads in the warp - const float rowmax_new = smem_rowmax_new[row]; + // broadcast updated rowmax to all threads in the warp + const float rowmax_new = smem_rowmax[row]; // each thread computes two fp32 elements, downconverts it to fp16, then // packs them into one fp32 @@ -279,8 +316,9 @@ inline void thread_block_online_softmax( rowsum += other; } - const float mi_prev = smem_rowmax_prev[row]; - const float mi_this = smem_rowmax_this[row]; + const float mi_prev = rowmax_prev; + // TODO: replace this with a register? + const float mi_this = rowmax_this; const float x = mi_prev - mi_this; // 2nd-order Taylor approximation @@ -309,8 +347,8 @@ inline void thread_block_online_softmax( for (int i = 0; i < per_row_iter; i++) { float o = smem_O[thread_offset]; - const float mi_prev = smem_rowmax_prev[row]; - const float mi_new = smem_rowmax_new[row]; + const float mi_prev = rowmax_prev; + const float mi_new = rowmax_new; const float x = mi_prev - mi_new; // 2nd-order Taylor approximation @@ -398,9 +436,10 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { // sharedmem "scratchpad" area to put temporary data, e.g. for tree reduction // in rowsum // NOTE: out-of bounds is not checked + // TODO: reduce this from B_ROW to NUM_WARPS constexpr uint32_t smem_scratchpad_size = B_ROW * NUM_THREADS * 2 /*arbitrary slack*/; - float *smem_scratchpad = smem_rowmax - smem_scratchpad_size; + float *smem_scratchpad = smem_rowsum - smem_scratchpad_size; const uint32_t warps_per_threadblock_per_core = NUM_WARPS / threadblocks_per_cluster; @@ -414,6 +453,17 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { const float *gmem_V = reinterpret_cast(arg->addr_v); float *gmem_O = reinterpret_cast(arg->addr_o); + float *gmem_tmp_d0 = reinterpret_cast(0xd0000000UL); + float *gmem_tmp_d1 = reinterpret_cast(0xd1000000UL); + float *gmem_tmp_d2 = reinterpret_cast(0xd2000000UL); + float *gmem_tmp_d3 = reinterpret_cast(0xd3000000UL); + float *gmem_tmp_d4 = reinterpret_cast(0xd4000000UL); + float *gmem_tmp_d5 = reinterpret_cast(0xd5000000UL); + float *gmem_tmp_e0 = reinterpret_cast(0xe0000000UL); + float *gmem_tmp_e1 = reinterpret_cast(0xe1000000UL); + float *gmem_tmp_e2 = reinterpret_cast(0xe2000000UL); + float *gmem_tmp_e3 = reinterpret_cast(0xe3000000UL); + // "inner loop" along the columns of K^T for (uint32_t tile_k = 0; tile_k < (dim_seqlen / B_COL); tile_k++) { @@ -469,6 +519,43 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { threadblock_barrier(threadblock_id_in_cluster, warps_per_threadblock_per_core); + if constexpr (DEBUG) { + if (tile_k == 0) { + thread_block_copy_tile( + smem_P, gmem_tmp_d0, tid_in_threadblock, threads_per_threadblock, + threadblocks_per_cluster, threadblock_id_in_cluster); + thread_block_copy_tile( + smem_O, gmem_tmp_d2, tid_in_threadblock, threads_per_threadblock, + threadblocks_per_cluster, threadblock_id_in_cluster); + thread_block_copy_rowmax(smem_rowmax, gmem_tmp_e0, tid_in_threadblock, + threads_per_threadblock, + threadblocks_per_cluster, + threadblock_id_in_cluster); + thread_block_copy_rowmax(smem_rowsum, gmem_tmp_e2, tid_in_threadblock, + threads_per_threadblock, + threadblocks_per_cluster, + threadblock_id_in_cluster); + } else if (tile_k == 1) { + thread_block_copy_tile( + smem_P, gmem_tmp_d1, tid_in_threadblock, threads_per_threadblock, + threadblocks_per_cluster, threadblock_id_in_cluster); + thread_block_copy_tile( + smem_O, gmem_tmp_d3, tid_in_threadblock, threads_per_threadblock, + threadblocks_per_cluster, threadblock_id_in_cluster); + thread_block_copy_rowmax(smem_rowmax, gmem_tmp_e1, tid_in_threadblock, + threads_per_threadblock, + threadblocks_per_cluster, + threadblock_id_in_cluster); + thread_block_copy_rowmax(smem_rowsum, gmem_tmp_e3, tid_in_threadblock, + threads_per_threadblock, + threadblocks_per_cluster, + threadblock_id_in_cluster); + } + + threadblock_barrier(threadblock_id_in_cluster, + warps_per_threadblock_per_core); + } + // GEMM II: O = O + P*V // clear out accumulators @@ -495,18 +582,22 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { threadblock_barrier(threadblock_id_in_cluster, warps_per_threadblock_per_core); + + if constexpr (DEBUG) { + if (tile_k == 0) { + thread_block_copy_tile( + smem_O, gmem_tmp_d4, tid_in_threadblock, threads_per_threadblock, + threadblocks_per_cluster, threadblock_id_in_cluster); + } else if (tile_k == 1) { + thread_block_copy_tile( + smem_O, gmem_tmp_d5, tid_in_threadblock, threads_per_threadblock, + threadblocks_per_cluster, threadblock_id_in_cluster); + } + + threadblock_barrier(threadblock_id_in_cluster, + warps_per_threadblock_per_core); + } } - - float *gmem_tmp0 = reinterpret_cast(0xd0000000UL); - float *gmem_tmp1 = reinterpret_cast(0xe0000000UL); - - // copy out tile data to GMEM for debugging - thread_block_copy_data(smem_P, gmem_tmp0, tid_in_threadblock, - threads_per_threadblock, threadblocks_per_cluster, - threadblock_id_in_cluster); - thread_block_copy_data(smem_O, gmem_tmp1, tid_in_threadblock, - threads_per_threadblock, threadblocks_per_cluster, - threadblock_id_in_cluster); } int main() {