diff --git a/tests/regression/flash_attention/kernel.cpp b/tests/regression/flash_attention/kernel.cpp index ffd3e495..dd144277 100644 --- a/tests/regression/flash_attention/kernel.cpp +++ b/tests/regression/flash_attention/kernel.cpp @@ -10,35 +10,53 @@ #define B_ROW BM #define B_COL BN +// FIXME +#define HEADDIM B_COL inline void thread_block_init_sharedmem(const uint32_t tid_in_threadblock, const uint32_t threads_per_threadblock, - float *sharedmem_scratchpad, - float *sharedmem_rowmax, - float *sharedmem_rowsum) { + float *smem_O, + float *smem_rowmax, + float *smem_rowsum) { const uint32_t tid_in_warp = tid_in_threadblock % NUM_THREADS; const uint32_t warp_id = tid_in_threadblock / NUM_THREADS; + const uint32_t warps_in_threadblock = threads_per_threadblock / NUM_THREADS; static_assert((B_ROW % NUM_THREADS) == 0, "B_ROW must be a multiple of NUM_THREADS"); // FIXME: this shouldn't be necessary static_assert(B_ROW < (NUM_THREADS * CORES_PER_CLUSTER * NUM_WARPS), - "Not enough warps to initialize rowmax/rowsum"); + "not enough warps to initialize rowmax/rowsum"); constexpr uint32_t num_warps = B_ROW / NUM_THREADS; if (warp_id < num_warps) { uint32_t offset = NUM_THREADS * warp_id + tid_in_warp; - sharedmem_rowmax[offset] = FLT_MIN; - sharedmem_rowsum[offset] = 0.0f; + smem_rowmax[offset] = FLT_MIN; + smem_rowsum[offset] = 0.0f; + } + + for (int warp_offset = 0; warp_offset < B_COL; + warp_offset += warps_in_threadblock) { + // each warp clears out a row of smem_O + // FIXME: dedup this pattern + const uint32_t row = warp_offset + warp_id; + uint32_t thread_offset = HEADDIM * row + tid_in_warp; + constexpr uint32_t per_row_iter = HEADDIM / NUM_THREADS; +#pragma GCC unroll + for (int i = 0; i < per_row_iter; i++) { + smem_O[thread_offset] = 0.0f; + thread_offset += NUM_THREADS; + } } } -inline void thread_block_flashattn(float *S, const uint32_t tid_in_threadblock, +inline void thread_block_flashattn(float *smem_S, float *smem_O, + const uint32_t tid_in_threadblock, const uint32_t threads_per_threadblock, const uint32_t threadblock_id_in_cluster, - float *sharedmem_scratchpad, - float *sharedmem_rowmax, - float *sharedmem_rowsum) { + float *smem_scratchpad, + float *smem_rowmax, + float *smem_rowsum) { asm volatile("thread_block_flashattn_start_%=:" ::); const uint32_t tid_in_warp = tid_in_threadblock % NUM_THREADS; @@ -82,7 +100,7 @@ inline void thread_block_flashattn(float *S, const uint32_t tid_in_threadblock, : "=f"(max) : "f"(max), "f"(S[first_thread_offset + i])); } - sharedmem_rowmax[row] = max; + smem_rowmax[row] = max; gmem_tmp0[row] = max; } @@ -94,7 +112,7 @@ inline void thread_block_flashattn(float *S, const uint32_t tid_in_threadblock, float per_thread_max = FLT_MIN; #pragma GCC unroll for (int i = 0; i < per_row_iter; i++) { - const float next = S[thread_offset]; + const float next = smem_S[thread_offset]; asm volatile("fmax.s %0, %1, %2" : "=f"(per_thread_max) : "f"(per_thread_max), "f"(next)); @@ -102,7 +120,7 @@ inline void thread_block_flashattn(float *S, const uint32_t tid_in_threadblock, } // stage per-thread max value in smem // FIXME: threadblock_id needs to be in here too - float *warp_smem = sharedmem_scratchpad + (warp_id * NUM_THREADS); + float *warp_smem = smem_scratchpad + (warp_id * NUM_THREADS); warp_smem[tid_in_warp] = per_thread_max; // sync writes to warp_smem @@ -121,12 +139,12 @@ inline void thread_block_flashattn(float *S, const uint32_t tid_in_threadblock, // update previous rowmax // i.e. mi_new = max(mi, mij) - float prev_rowmax = sharedmem_rowmax[row]; + float prev_rowmax = smem_rowmax[row]; asm volatile("fmax.s %0, %1, %2" : "=f"(rowmax) : "f"(rowmax), "f"(prev_rowmax)); - sharedmem_rowmax[row] = rowmax; + smem_rowmax[row] = rowmax; gmem_tmp0[row] = rowmax; } #endif @@ -142,7 +160,7 @@ inline void thread_block_flashattn(float *S, const uint32_t tid_in_threadblock, // (exp_elem_per_thread * threads_per_threadblock) / B_COL; // broadcast rowmax to all threads in the warp - const float row_max = sharedmem_rowmax[row]; + const float row_max = smem_rowmax[row]; // each thread computes two fp32 elements, downconverts it to fp16, then // packs them into one fp32 @@ -155,7 +173,7 @@ inline void thread_block_flashattn(float *S, const uint32_t tid_in_threadblock, B_COL / (elem_per_thread * NUM_THREADS); #pragma GCC unroll for (int i = 0; i < exp_per_row_iter; i++) { - float f0 = S[thread_offset]; + float f0 = smem_S[thread_offset]; // float f1 = S[thread_offset + 1]; // FIXME: placeholder for proper exp @@ -167,7 +185,7 @@ inline void thread_block_flashattn(float *S, const uint32_t tid_in_threadblock, // Store S transposed to the shared memory // update S in-place into P - S[thread_offset] = f0; + smem_S[thread_offset] = f0; // S[thread_offset + 1] = f1; gmem_tmp1[thread_offset] = f0; @@ -185,12 +203,12 @@ inline void thread_block_flashattn(float *S, const uint32_t tid_in_threadblock, float per_thread_sum = 0.0f; #pragma GCC unroll for (int i = 0; i < per_row_iter; i++) { - per_thread_sum += S[thread_offset]; + per_thread_sum += smem_S[thread_offset]; thread_offset += NUM_THREADS; } // stage per-thread sum value in smem // FIXME: threadblock_id needs to be in here too - warp_smem = sharedmem_scratchpad + (warp_id * NUM_THREADS); + warp_smem = smem_scratchpad + (warp_id * NUM_THREADS); warp_smem[tid_in_warp] = per_thread_sum; // sync writes to warp_smem @@ -205,12 +223,32 @@ inline void thread_block_flashattn(float *S, const uint32_t tid_in_threadblock, } // TODO: update previous rowsum here - sharedmem_rowsum[row] = per_thread_sum; - gmem_tmp2[row] = per_thread_sum; + smem_rowsum[row] = per_thread_sum; } threadblock_barrier(threadblock_id_in_cluster, warps_per_threadblock_per_core); + + // Oi rescale + // + thread_offset = first_thread_offset + tid_in_warp; +#pragma GCC unroll + for (int i = 0; i < per_row_iter; i++) { + float fval = smem_O[thread_offset]; + + // FIXME: placeholder for proper exp + fval *= 2.0f; + + // update Oi in-place + smem_O[thread_offset] = fval; + gmem_tmp2[thread_offset] = fval; + + thread_offset += NUM_THREADS; + } + + threadblock_barrier(threadblock_id_in_cluster, + warps_per_threadblock_per_core); + } asm volatile("thread_block_flashattn_finish_%=:" ::); @@ -226,7 +264,8 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { constexpr uint32_t cores_per_cluster = 1; #endif - uint32_t threads_per_threadblock = (BM * BN) / (ELEM_PER_THREAD); + // FIXME: headdim not considered + uint32_t threads_per_threadblock = (B_ROW * B_COL) / (ELEM_PER_THREAD); const uint32_t hw_threads_per_cluster = cores_per_cluster * vx_num_threads() * vx_num_warps(); // cap maximum threadblock size to # of HW threads in cluster, to prevent @@ -245,36 +284,44 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { const uint32_t dim_seqlen = arg->dim_seqlen; const uint32_t dim_headdim = arg->dim_headdim; - // "static" shared memory allocation. This would determine threadblock - // occupancy of a single cluster - uint8_t *sharedmem_per_threadblock = reinterpret_cast( - DEV_SMEM_START_ADDR + sizeof(float_type) * 2 /*overkill for non-dma*/ * - (2 * BM * BK) * threadblock_id_in_cluster); + // "static" shared memory allocation. This would determine maximum + // threadblock occupancy in a cluster + const uint32_t smem_QK_size = B_ROW * B_COL; + const uint32_t smem_V_size = B_COL * HEADDIM; + const uint32_t smem_O_size = B_COL * HEADDIM; + uint8_t *smem_per_threadblock = reinterpret_cast( + DEV_SMEM_START_ADDR + + sizeof(float_type) * + (smem_QK_size + smem_V_size + smem_O_size) * + threadblock_id_in_cluster); - uint8_t *smem_S = sharedmem_per_threadblock; - constexpr uint32_t sharedmem_rowmax_size = sizeof(float) * B_ROW; - constexpr uint32_t sharedmem_rowsum_size = sizeof(float) * B_ROW; - // sharedmem area to store rowmax/rowsum values in softmax - uint8_t *sharedmem_rowmax = - reinterpret_cast(SMEM_ADDR_END) - sharedmem_rowmax_size; - uint8_t *sharedmem_rowsum = sharedmem_rowmax - sharedmem_rowsum_size; + uint8_t *smem_S = smem_per_threadblock; + uint8_t *smem_O = smem_per_threadblock + + sizeof(float) * (smem_QK_size + smem_V_size); + + // allocate rowmax/rowsum storage at the end of the sharedmem address space + constexpr uint32_t smem_rowmax_size = sizeof(float) * B_ROW; + constexpr uint32_t smem_rowsum_size = sizeof(float) * B_ROW; + uint8_t *smem_rowmax = + reinterpret_cast(SMEM_ADDR_END) - smem_rowmax_size; + uint8_t *smem_rowsum = smem_rowmax - smem_rowsum_size; // sharedmem "scratchpad" area to put temporary data, e.g. for tree reduction // in rowsum // NOTE: out-of bounds is not checked - constexpr uint32_t sharedmem_scratchpad_size = + constexpr uint32_t smem_scratchpad_size = sizeof(float) * B_ROW * NUM_THREADS * 2 /*arbitrary slack*/; - uint8_t *sharedmem_scratchpad = - sharedmem_rowmax - sharedmem_scratchpad_size; + uint8_t *smem_scratchpad = + smem_rowmax - smem_scratchpad_size; const uint32_t warps_per_threadblock_per_core = NUM_WARPS / threads_per_threadblock; // initialize rowmax/rowsum values in sharedmem thread_block_init_sharedmem(tid_in_threadblock, threads_per_threadblock, - (float *)sharedmem_scratchpad, - (float *)sharedmem_rowmax, - (float *)sharedmem_rowsum); + (float *)smem_O, + (float *)smem_rowmax, + (float *)smem_rowsum); #define SKIP_GEMM #ifndef SKIP_GEMM @@ -283,7 +330,7 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { (float *)smem_S /*write result to SMEM */, arg->dim_m, arg->dim_n, arg->dim_k, tid_in_threadblock, threads_per_threadblock, threadblocks_per_cluster, threadblock_id_in_cluster, - sharedmem_per_threadblock); + smem_per_threadblock); // protect writes of GEMM results before softmax threadblock_barrier(threadblock_id_in_cluster, @@ -294,10 +341,10 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { float *tile_S = (float *)arg->addr_q; #endif - thread_block_flashattn(tile_S, tid_in_threadblock, + thread_block_flashattn(tile_S, (float *)smem_O, tid_in_threadblock, threads_per_threadblock, threadblock_id_in_cluster, - (float *)sharedmem_scratchpad, - (float *)sharedmem_rowmax, (float *)sharedmem_rowsum); + (float *)smem_scratchpad, (float *)smem_rowmax, + (float *)smem_rowsum); } int main() {