From f844d96eea4385428ae5e1b7d8a29b9f60c9eb25 Mon Sep 17 00:00:00 2001 From: Hansung Kim Date: Thu, 15 Aug 2024 17:28:36 -0700 Subject: [PATCH] flash: Initialize rowmax/rowsum cache in sharedmem --- tests/regression/flash_attention/kernel.cpp | 72 ++++++++++++++++----- 1 file changed, 57 insertions(+), 15 deletions(-) diff --git a/tests/regression/flash_attention/kernel.cpp b/tests/regression/flash_attention/kernel.cpp index 522dcb0b..918f3607 100644 --- a/tests/regression/flash_attention/kernel.cpp +++ b/tests/regression/flash_attention/kernel.cpp @@ -14,12 +14,34 @@ using float_type = float16_t; #define B_ROW BM #define B_COL BN -inline void thread_block_flashattn(float *S, - const uint32_t tid_in_threadblock, +inline void thread_block_init_sharedmem(const uint32_t tid_in_threadblock, + const uint32_t threads_per_threadblock, + float *sharedmem_scratchpad, + float *sharedmem_rowmax, + float *sharedmem_rowsum) { + const uint32_t tid_in_warp = tid_in_threadblock % NUM_THREADS; + const uint32_t warp_id = tid_in_threadblock / NUM_THREADS; + + static_assert((B_ROW % NUM_THREADS) == 0, + "B_ROW must be a multiple of NUM_THREADS"); + // FIXME: this shouldn't be necessary + static_assert(B_ROW < (NUM_THREADS * CORES_PER_CLUSTER * NUM_WARPS), + "Not enough warps to initialize rowmax/rowsum"); + + constexpr uint32_t num_warps = B_ROW / NUM_THREADS; + if (warp_id < num_warps) { + uint32_t offset = NUM_THREADS * warp_id + tid_in_warp; + sharedmem_rowmax[offset] = FLT_MIN; + sharedmem_rowsum[offset] = 0.0f; + } +} + +inline void thread_block_flashattn(float *S, const uint32_t tid_in_threadblock, const uint32_t threads_per_threadblock, const uint32_t threadblock_id_in_cluster, float *sharedmem_scratchpad, - float *sharedmem_row_max_sum) { + float *sharedmem_rowmax, + float *sharedmem_rowsum) { asm volatile("thread_block_flashattn_start_%=:" ::); const uint32_t tid_in_warp = tid_in_threadblock % NUM_THREADS; @@ -63,7 +85,7 @@ inline void thread_block_flashattn(float *S, : "=f"(max) : "f"(max), "f"(S[first_thread_offset + i])); } - sharedmem_row_max_sum[row] = max; + sharedmem_rowmax[row] = max; gmem_tmp0[row] = max; } @@ -92,14 +114,23 @@ inline void thread_block_flashattn(float *S, // elect 0-th thread to reduce all other thread's values in the warp if (tid_in_warp == 0) { + float rowmax = per_thread_max; for (int iter = 1; iter < NUM_THREADS; iter++) { float other = warp_smem[iter]; asm volatile("fmax.s %0, %1, %2" - : "=f"(per_thread_max) - : "f"(per_thread_max), "f"(other)); + : "=f"(rowmax) + : "f"(rowmax), "f"(other)); } - sharedmem_row_max_sum[row] = per_thread_max; - gmem_tmp0[row] = per_thread_max; + + // update previous rowsum + // i.e. mi_new = max(mi, mij) + float prev_rowmax = sharedmem_rowmax[row]; + asm volatile("fmax.s %0, %1, %2" + : "=f"(rowmax) + : "f"(rowmax), "f"(prev_rowmax)); + + sharedmem_rowmax[row] = rowmax; + gmem_tmp0[row] = rowmax; } #endif @@ -114,7 +145,7 @@ inline void thread_block_flashattn(float *S, // (exp_elem_per_thread * threads_per_threadblock) / B_COL; // broadcast rowmax to all threads in the warp - const float row_max = sharedmem_row_max_sum[row]; + const float row_max = sharedmem_rowmax[row]; thread_offset = first_thread_offset + tid_in_warp; #pragma GCC unroll @@ -160,7 +191,9 @@ inline void thread_block_flashattn(float *S, float other = warp_smem[iter]; per_thread_sum += other; } - sharedmem_row_max_sum[row] = per_thread_sum; + + // TODO: update previous rowsum here + sharedmem_rowsum[row] = per_thread_sum; gmem_tmp2[row] = per_thread_sum; } @@ -212,17 +245,26 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { (2 * BM * BK) * threadblock_id_in_cluster); uint8_t *smem_S = sharedmem_per_threadblock; - constexpr uint32_t sharedmem_row_max_sum_size = 2 * sizeof(float) * B_ROW; + constexpr uint32_t sharedmem_rowmax_size = sizeof(float) * B_ROW; + constexpr uint32_t sharedmem_rowsum_size = sizeof(float) * B_ROW; // sharedmem area to store rowmax/rowsum values in softmax - uint8_t *sharedmem_row_max_sum = - reinterpret_cast(SMEM_ADDR_END) - sharedmem_row_max_sum_size; + uint8_t *sharedmem_rowmax = + reinterpret_cast(SMEM_ADDR_END) - sharedmem_rowmax_size; + uint8_t *sharedmem_rowsum = sharedmem_rowmax - sharedmem_rowsum_size; + // sharedmem "scratchpad" area to put temporary data, e.g. for tree reduction // in rowsum // NOTE: out-of bounds is not checked constexpr uint32_t sharedmem_scratchpad_size = sizeof(float) * B_ROW * NUM_THREADS * 2 /*arbitrary slack*/; uint8_t *sharedmem_scratchpad = - sharedmem_row_max_sum - sharedmem_scratchpad_size; + sharedmem_rowmax - sharedmem_scratchpad_size; + + // initialize rowmax/rowsum values in sharedmem + thread_block_init_sharedmem(tid_in_threadblock, threads_per_threadblock, + (float *)sharedmem_scratchpad, + (float *)sharedmem_rowmax, + (float *)sharedmem_rowsum); // thread_block_gemm( // (const float_type *)arg->addr_a, (const float_type *)arg->addr_b, @@ -240,7 +282,7 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { thread_block_flashattn((float *)arg->addr_a /* smem_S, */, tid_in_threadblock, threads_per_threadblock, threadblock_id_in_cluster, (float *)sharedmem_scratchpad, - (float *)sharedmem_row_max_sum); + (float *)sharedmem_rowmax, (float *)sharedmem_rowsum); } int main() {