From d3de1b674aa8585a16bfbf3d27ee4a253dc043e6 Mon Sep 17 00:00:00 2001 From: Hansung Kim Date: Thu, 15 Aug 2024 22:09:13 -0700 Subject: [PATCH] flash: Compute exponents using prev/next/this rowmax values maybe there is a better way than storing all three in sharedmem? --- tests/regression/flash_attention/kernel.cpp | 44 +++++++++++++++------ 1 file changed, 32 insertions(+), 12 deletions(-) diff --git a/tests/regression/flash_attention/kernel.cpp b/tests/regression/flash_attention/kernel.cpp index dd144277..6b3bb4fb 100644 --- a/tests/regression/flash_attention/kernel.cpp +++ b/tests/regression/flash_attention/kernel.cpp @@ -31,14 +31,17 @@ inline void thread_block_init_sharedmem(const uint32_t tid_in_threadblock, constexpr uint32_t num_warps = B_ROW / NUM_THREADS; if (warp_id < num_warps) { uint32_t offset = NUM_THREADS * warp_id + tid_in_warp; + // mi, mi~, minew smem_rowmax[offset] = FLT_MIN; + smem_rowmax[offset + B_ROW] = FLT_MIN; + smem_rowmax[offset + 2 * B_ROW] = FLT_MIN; smem_rowsum[offset] = 0.0f; } + // FIXME: dedup this pattern for (int warp_offset = 0; warp_offset < B_COL; warp_offset += warps_in_threadblock) { // each warp clears out a row of smem_O - // FIXME: dedup this pattern const uint32_t row = warp_offset + warp_id; uint32_t thread_offset = HEADDIM * row + tid_in_warp; constexpr uint32_t per_row_iter = HEADDIM / NUM_THREADS; @@ -79,6 +82,10 @@ inline void thread_block_flashattn(float *smem_S, float *smem_O, volatile float *gmem_tmp1 = reinterpret_cast(0xe0000000UL); volatile float *gmem_tmp2 = reinterpret_cast(0xf0000000UL); + float *smem_rowmax_prev = smem_rowmax; + float *smem_rowmax_new = smem_rowmax + B_ROW; + float *smem_rowmax_this = smem_rowmax + 2 * B_ROW; + for (int warp_offset = 0; warp_offset < B_ROW; warp_offset += warps_in_threadblock) { const uint32_t row = warp_offset + warp_id; @@ -136,15 +143,15 @@ inline void thread_block_flashattn(float *smem_S, float *smem_O, : "=f"(rowmax) : "f"(rowmax), "f"(other)); } + smem_rowmax_this[row] = rowmax; // update previous rowmax // i.e. mi_new = max(mi, mij) - float prev_rowmax = smem_rowmax[row]; + float prev_rowmax = smem_rowmax_prev[row]; asm volatile("fmax.s %0, %1, %2" : "=f"(rowmax) : "f"(rowmax), "f"(prev_rowmax)); - - smem_rowmax[row] = rowmax; + smem_rowmax_new[row] = rowmax; gmem_tmp0[row] = rowmax; } #endif @@ -160,7 +167,7 @@ inline void thread_block_flashattn(float *smem_S, float *smem_O, // (exp_elem_per_thread * threads_per_threadblock) / B_COL; // broadcast rowmax to all threads in the warp - const float row_max = smem_rowmax[row]; + const float rowmax_new = smem_rowmax_new[row]; // each thread computes two fp32 elements, downconverts it to fp16, then // packs them into one fp32 @@ -177,8 +184,8 @@ inline void thread_block_flashattn(float *smem_S, float *smem_O, // float f1 = S[thread_offset + 1]; // FIXME: placeholder for proper exp - f0 -= row_max; - // f1 -= row_max; + f0 -= rowmax_new; + // f1 -= rowmax_new; // float16_t h0 = NN_float_to_half(f0); // float16_t h1 = NN_float_to_half(f1); @@ -217,13 +224,22 @@ inline void thread_block_flashattn(float *smem_S, float *smem_O, // 0-th thread collects all other thread's values in the warp if (tid_in_warp == 0) { + float rowsum = per_thread_sum; for (int iter = 1; iter < NUM_THREADS; iter++) { float other = warp_smem[iter]; - per_thread_sum += other; + rowsum += other; } - // TODO: update previous rowsum here - smem_rowsum[row] = per_thread_sum; + const float mi_prev = smem_rowmax_prev[row]; + const float mi_this = smem_rowmax_this[row]; + const float mi_new = smem_rowmax_new[row]; + const float exp = mi_prev - mi_this; + + // update rowsum + const float rowsum_prev = smem_rowsum[row]; + // FIXME: placeholder for exponential + float rowsum_new = exp * rowsum_prev + rowsum; + smem_rowsum[row] = rowsum_new; } threadblock_barrier(threadblock_id_in_cluster, @@ -236,8 +252,12 @@ inline void thread_block_flashattn(float *smem_S, float *smem_O, for (int i = 0; i < per_row_iter; i++) { float fval = smem_O[thread_offset]; + const float mi_prev = smem_rowmax_prev[row]; + const float mi_new = smem_rowmax_new[row]; + const float exp = mi_prev - mi_new; + // FIXME: placeholder for proper exp - fval *= 2.0f; + fval *= exp; // update Oi in-place smem_O[thread_offset] = fval; @@ -300,7 +320,7 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { sizeof(float) * (smem_QK_size + smem_V_size); // allocate rowmax/rowsum storage at the end of the sharedmem address space - constexpr uint32_t smem_rowmax_size = sizeof(float) * B_ROW; + constexpr uint32_t smem_rowmax_size = sizeof(float) * B_ROW * 3 /* mi, mi~, minew */; constexpr uint32_t smem_rowsum_size = sizeof(float) * B_ROW; uint8_t *smem_rowmax = reinterpret_cast(SMEM_ADDR_END) - smem_rowmax_size;