diff --git a/tests/regression/flash_attention/kernel.cpp b/tests/regression/flash_attention/kernel.cpp index 772d4db1..75b20803 100644 --- a/tests/regression/flash_attention/kernel.cpp +++ b/tests/regression/flash_attention/kernel.cpp @@ -13,14 +13,22 @@ // FIXME #define HEADDIM B_COL +constexpr uint32_t ROWMAX_SETS = 3; constexpr bool DEBUG = true; -constexpr bool DOUBLE_BUF = false; +constexpr bool DOUBLE_BUF = true; + +// temporary safety stop for wrong configs +static_assert(NUM_CORES == 4); +static_assert(NUM_THREADS == 8); +static_assert(NUM_WARPS == 8); inline void thread_block_init_sharedmem(const uint32_t tid_in_threadblock, const uint32_t threads_per_threadblock, float *smem_O, float *smem_rowmax, float *smem_rowsum) { + asm volatile("threadblock_init_sharedmem_start_%=:" ::); + const uint32_t tid_in_warp = tid_in_threadblock % NUM_THREADS; const uint32_t warp_id = tid_in_threadblock / NUM_THREADS; const uint32_t warps_in_threadblock = threads_per_threadblock / NUM_THREADS; @@ -36,26 +44,30 @@ inline void thread_block_init_sharedmem(const uint32_t tid_in_threadblock, constexpr uint32_t needed_warps = B_ROW / NUM_THREADS; if (warp_id < needed_warps /* more warps in HW than needed? 
*/) { uint32_t offset = NUM_THREADS * warp_id + tid_in_warp; - // mi, mi~, minew - smem_rowmax[offset] = FLT_MIN; - smem_rowmax[offset + B_ROW] = FLT_MIN; - smem_rowmax[offset + 2 * B_ROW] = FLT_MIN; +#pragma GCC unroll + for (int i = 0; i < ROWMAX_SETS; i++) { + smem_rowmax[offset + i * B_ROW] = FLT_MIN; /* set i begins at i * B_ROW (smem_rowmax holds ROWMAX_SETS sets of B_ROW floats). NOTE(review): FLT_MIN is the smallest positive float, not the most negative; confirm it is the intended rowmax seed. */ + } smem_rowsum[offset] = 0.0f; } // each warp clears out a row of smem_O // FIXME: dedup this pattern - for (int warp_offset = 0; warp_offset < B_COL; - warp_offset += warps_in_threadblock) { - const uint32_t row = warp_offset + warp_id; +#pragma GCC unroll 1 + for (int row_offset = 0; row_offset < B_COL; + row_offset += warps_in_threadblock) { + const uint32_t row = row_offset + warp_id; uint32_t thread_offset = HEADDIM * row + tid_in_warp; constexpr uint32_t per_row_iter = HEADDIM / NUM_THREADS; + const float one = 0.0f; #pragma GCC unroll for (int i = 0; i < per_row_iter; i++) { smem_O[thread_offset] = 0.0f; thread_offset += NUM_THREADS; } } + + asm volatile("threadblock_init_sharedmem_finish_%=:" ::); } inline void thread_block_copy_rowmax(const float *src, float *dest, @@ -97,9 +109,10 @@ inline void thread_block_copy_tile(const float *src, float *dest, warps_in_threadblock / CORES_PER_CLUSTER; // FIXME: dedup this pattern - for (int warp_offset = 0; warp_offset < B_ROW; - warp_offset += warps_in_threadblock) { - const uint32_t row = warp_offset + warp_id; +#pragma GCC unroll 1 + for (int row_offset = 0; row_offset < B_ROW; + row_offset += warps_in_threadblock) { + const uint32_t row = row_offset + warp_id; const uint32_t first_thread_offset = B_COL * row; constexpr uint32_t per_row_iter = B_COL / NUM_THREADS; @@ -163,6 +176,7 @@ __attribute__((always_inline)) inline void thread_block_online_softmax( float *smem_rowmax_this = smem_rowmax + B_ROW; +#pragma GCC unroll 1 for (int row_offset = 0; row_offset < B_ROW; row_offset += warps_in_threadblock) { const uint32_t row = row_offset + warp_id; @@ -171,27 +185,46 @@ __attribute__((always_inline)) inline 
void thread_block_online_softmax( // rowmax // // two-level tree reduction: reduce each row into NUM_THREADS intermediate - // maxes, then reduce it to one global max + // maxes, then reduce it down to one row max // one warp handles one row in tile + constexpr uint32_t per_row_iter = B_COL / NUM_THREADS; + uint32_t thread_offset = first_thread_offset + tid_in_warp; + // FIXME: threadblock_id needs to be in here too + float *warp_smem = smem_scratchpad + (warp_id * NUM_THREADS); + // #define DUMB_ROWMAX #ifdef DUMB_ROWMAX + // FIXME remove + threadblock_barrier(threadblock_id_in_cluster, + warps_per_threadblock_per_core); + + // no tree reduction; a single thread in a warp does serialized max across + // the entire row if (tid_in_warp == 0) { - float max = S[first_thread_offset]; -#pragma GCC unroll + float rowmax = smem_S[first_thread_offset]; +#pragma GCC unroll 16 for (int i = 0; i < B_COL; i++) { asm volatile("fmax.s %0, %1, %2" - : "=f"(max) - : "f"(max), "f"(S[first_thread_offset + i])); + : "=f"(rowmax) + : "f"(rowmax), "f"(smem_S[first_thread_offset + i])); } - smem_rowmax[row] = max; + smem_rowmax_this[row] = rowmax; + + // update previous rowmax + // i.e. 
mi_new = max(mi, mij) + float prev_rowmax = smem_rowmax[row]; + // stage prev rowmax in scratchpad for warp-wide broadcast + warp_smem[0] = prev_rowmax; + asm volatile("fmax.s %0, %1, %2" + : "=f"(rowmax) + : "f"(rowmax), "f"(prev_rowmax)); + smem_rowmax[row] = rowmax; } #else static_assert((B_COL % NUM_THREADS) == 0, "B_COL must be a multiple of NUM_THREADS"); - constexpr uint32_t per_row_iter = B_COL / NUM_THREADS; - uint32_t thread_offset = first_thread_offset + tid_in_warp; float per_thread_max = FLT_MIN; #pragma GCC unroll for (int i = 0; i < per_row_iter; i++) { @@ -202,8 +235,6 @@ __attribute__((always_inline)) inline void thread_block_online_softmax( thread_offset += NUM_THREADS; } // stage per-thread max value in smem - // FIXME: threadblock_id needs to be in here too - float *warp_smem = smem_scratchpad + (warp_id * NUM_THREADS); warp_smem[tid_in_warp] = per_thread_max; // sync writes to warp_smem @@ -233,9 +264,7 @@ __attribute__((always_inline)) inline void thread_block_online_softmax( : "f"(rowmax), "f"(prev_rowmax)); smem_rowmax[row] = rowmax; } - #else - if (warp_id < warps_in_threadblock / NUM_THREADS) { const uint32_t row = row_offset + NUM_THREADS * warp_id + tid_in_warp; float *const thread_smem = smem_scratchpad + (tid_in_warp * NUM_THREADS); @@ -257,8 +286,7 @@ __attribute__((always_inline)) inline void thread_block_online_softmax( : "f"(rowmax), "f"(prev_rowmax)); smem_rowmax[row] = rowmax; } -#endif - +#endif // PARALLEL_ROWMAX #endif // DUMB_ROWMAX threadblock_barrier(threadblock_id_in_cluster, @@ -404,16 +432,17 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { #endif // FIXME: headdim not considered - uint32_t threads_per_threadblock = + constexpr uint32_t threads_per_threadblock_theoretical = (B_ROW * B_COL) / (ELEM_PER_THREAD) / (DOUBLE_BUF ? 
2 : 1); - const uint32_t hw_threads_per_cluster = - cores_per_cluster * vx_num_threads() * vx_num_warps(); + constexpr uint32_t hw_threads_per_cluster = + CORES_PER_CLUSTER * NUM_THREADS * NUM_WARPS; // cap maximum threadblock size to # of HW threads in cluster, to prevent // multiple "wave" invocations which slows down the kernel - if (threads_per_threadblock > hw_threads_per_cluster) { - threads_per_threadblock = hw_threads_per_cluster; - } - const uint32_t threadblocks_per_cluster = + constexpr uint32_t threads_per_threadblock = + (threads_per_threadblock_theoretical > hw_threads_per_cluster) + ? hw_threads_per_cluster + : threads_per_threadblock_theoretical; + constexpr uint32_t threadblocks_per_cluster = hw_threads_per_cluster / threads_per_threadblock; const int threadblock_id = task_id / threads_per_threadblock; @@ -452,7 +481,7 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { smem_QK_size + smem_V_size; // allocate rowmax/rowsum storage at the end of the sharedmem address space - constexpr uint32_t smem_rowmax_size = B_ROW * 3 /* mi, mi~, minew */; + constexpr uint32_t smem_rowmax_size = B_ROW * ROWMAX_SETS; constexpr uint32_t smem_rowsum_size = B_ROW; float *smem_rowmax = reinterpret_cast(SMEM_ADDR_END) - smem_rowmax_size; @@ -505,16 +534,16 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { // load Q load_tile_to_smem( + HEADDIM, threads_per_threadblock>( dim_seqlen, 0 /*FIXME: only work on first B_ROW rows of Q for now*/, 0 /* always 0 because dim_k == headdim */, gmem_Q, smem_Q, tid_in_threadblock); // load K load_tile_to_smem(dim_seqlen, tile_k, - 0 /* always 0 because dim_k == headdim */, - gmem_K, smem_K, tid_in_threadblock); + HEADDIM, threads_per_threadblock>( + dim_seqlen, tile_k, 0 /* always 0 because dim_k == headdim */, gmem_K, + smem_K, tid_in_threadblock); // GMEM->SMEM and compute barrier threadblock_barrier(threadblock_id_in_cluster, @@ -533,8 +562,9 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ 
arg) { } else { // load Q*K load_tile_to_smem(dim_seqlen, 0, tile_k, gmem_Q /*=gmem_S*/, - smem_S, tid_in_threadblock); + HEADDIM, threads_per_threadblock>( + dim_seqlen, 0, tile_k, gmem_Q /*=gmem_S*/, smem_S, + tid_in_threadblock); // the above should be equivalent to: // load_tile_to_smem( + HEADDIM, threads_per_threadblock>( HEADDIM, 0 /* 0 because always reads the full N-dimension */, tile_k, gmem_V, smem_V, tid_in_threadblock);