diff --git a/tests/regression/flash_attention/kernel.cpp b/tests/regression/flash_attention/kernel.cpp index 54cc749a..00a5323a 100644 --- a/tests/regression/flash_attention/kernel.cpp +++ b/tests/regression/flash_attention/kernel.cpp @@ -17,6 +17,8 @@ constexpr uint32_t ROWMAX_SETS = 3; constexpr bool DEBUG = true; constexpr bool DOUBLE_BUF = true; +constexpr uint32_t DEV_FAKE_SMEM_START_ADDR = 0xf0000000; + // temporary safety stop for wrong configs static_assert(NUM_CORES == 4); static_assert(NUM_THREADS == 8); @@ -154,8 +156,8 @@ inline float exponential_taylor_term(const float x) { } __attribute__((always_inline)) inline void thread_block_online_softmax( - const float *smem_S, float *smem_O, float *smem_P, - const uint32_t tid_in_threadblock, const uint32_t threads_per_threadblock, + const float *smem_S, float *smem_P, const uint32_t tid_in_threadblock, + const uint32_t threads_per_threadblock, const uint32_t threadblock_id_in_cluster, float *smem_scratchpad, float *smem_rowmax, float *smem_rowsum, float *smem_O_row_scale) { asm volatile("thread_block_online_softmax_start_%=:" ::); @@ -466,7 +468,7 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { // FIXME: headdim not considered constexpr uint32_t threads_per_threadblock_theoretical = - (B_ROW * B_COL) / (ELEM_PER_THREAD) / (DOUBLE_BUF ? 2 : 1); + (B_ROW * B_COL) / (ELEM_PER_THREAD); constexpr uint32_t hw_threads_per_cluster = CORES_PER_CLUSTER * NUM_THREADS * NUM_WARPS; // cap maximum threadblock size to # of HW threads in cluster, to prevent @@ -477,14 +479,29 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { : threads_per_threadblock_theoretical; constexpr uint32_t threadblocks_per_cluster = hw_threads_per_cluster / threads_per_threadblock; + constexpr uint32_t warps_per_threadblock_per_core = + NUM_WARPS / threadblocks_per_cluster; - const int threadblock_id = task_id / threads_per_threadblock; - const int threadblock_id_in_cluster = + const uint32_t threadblock_id = task_id / threads_per_threadblock; + const uint32_t threadblock_id_in_cluster = threadblock_id % threadblocks_per_cluster; - const int tid_in_threadblock = task_id % threads_per_threadblock; + const uint32_t tid_in_threadblock = task_id % threads_per_threadblock; + const uint32_t warp_id = tid_in_threadblock / NUM_THREADS; + constexpr uint32_t warps_in_threadblock = + threads_per_threadblock / NUM_THREADS; + + // warpgroup context + constexpr uint32_t threads_per_warpgroup = threads_per_threadblock / 2; + constexpr uint32_t warpgroups_per_cluster = threadblocks_per_cluster * 2; + const uint32_t warps_per_warpgroup_per_core = + NUM_WARPS / warpgroups_per_cluster; + const uint32_t warpgroup_id = task_id / threads_per_warpgroup; + const uint32_t warpgroup_id_in_cluster = + warpgroup_id % warpgroups_per_cluster; + const uint32_t tid_in_warpgroup = tid_in_threadblock % threads_per_warpgroup; // FIXME do proper software pipelining - if (DOUBLE_BUF && threadblock_id != 0) { + if (DOUBLE_BUF && warpgroup_id_in_cluster != 1) { return; } @@ -507,11 +524,12 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { float *smem_K = smem_Q + smem_Q_size; // in-place multiplication of QK into Q float *smem_S = reinterpret_cast(smem_per_threadblock); - float *smem_P = smem_S; // in-place update from S to P - float *smem_V = - reinterpret_cast(smem_per_threadblock) + smem_QK_size; - float *smem_O = reinterpret_cast(smem_per_threadblock) + - smem_QK_size + smem_V_size; + float *smem_P0 = smem_S; // in-place update from S to P + float *smem_P1 = smem_P0 + smem_QK_size; + float *smem_O = smem_P1 + smem_QK_size; + float *smem_V0 = + reinterpret_cast(DEV_FAKE_SMEM_START_ADDR) + smem_QK_size; + float *smem_V1 = smem_V0 + smem_QK_size; // allocate rowmax/rowsum storage at the end of the sharedmem address space constexpr uint32_t smem_rowmax_size = B_ROW * ROWMAX_SETS; @@ -528,13 +546,10 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { // TODO: reduce this from B_ROW to NUM_WARPS constexpr uint32_t smem_scratchpad_size = B_ROW * NUM_THREADS * 2 /*arbitrary slack*/; - float *smem_scratchpad = smem_rowsum - smem_scratchpad_size; - - const uint32_t warps_per_threadblock_per_core = - NUM_WARPS / threadblocks_per_cluster; + float *smem_scratchpad = smem_O_row_scale - smem_scratchpad_size; // initialize rowmax/rowsum values in sharedmem - thread_block_init_sharedmem(tid_in_threadblock, threads_per_threadblock, + thread_block_init_sharedmem(tid_in_warpgroup, threads_per_warpgroup, smem_O, smem_rowmax, smem_rowsum, smem_O_row_scale); @@ -554,212 +569,243 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { float *gmem_tmp_e2 = reinterpret_cast(0xe2000000UL); float *gmem_tmp_e3 = reinterpret_cast(0xe3000000UL); + asm volatile ("tile_loop_start_%=:" :: ); + // "inner loop" along the columns of K^T const uint32_t k_tiles = (dim_seqlen / B_COL); for (uint32_t tile_k = 0; tile_k < k_tiles; tile_k++) { + // float *smem_P_produce = (tile_k % 2) ? smem_P0 : smem_P1; + // float *smem_P_consume = (tile_k % 2) ? smem_P1 : smem_P0; + // float *smem_V_produce = (tile_k % 2) ? smem_V0 : smem_V1; + // float *smem_V_consume = (tile_k % 2) ? smem_V1 : smem_V0; + float *smem_P_produce = smem_P0; + float *smem_P_consume = smem_P0; + float *smem_V_produce = smem_V0; + float *smem_V_consume = smem_V0; - const float *tile_S = nullptr; + // if (warpgroup_id == 0) { + { + // Pipeline stage 1 + // + // skip pipeline drain + // if (tile_k == k_tiles) { + // continue; + // } + const uint32_t tile_k_ = tile_k; + + constexpr bool skip_gemm_qk = true; + if constexpr (!skip_gemm_qk) { + // clear out accumulators + initialize_accum_regs<0>(); + initialize_accum_regs<1>(); + + static_assert(B_ROW == B_COL, "currently only supports square tiles"); + + // load Q + load_tile_to_smem( + dim_seqlen, 0 /*FIXME: only work on first B_ROW rows of Q for now*/, + 0 /* always 0 because dim_k == headdim */, gmem_Q, smem_Q, + tid_in_warpgroup); + + // load K + load_tile_to_smem( + dim_seqlen, tile_k_, 0 /* always 0 because dim_k == headdim */, + gmem_K, smem_K, tid_in_warpgroup); + + // GMEM->SMEM and compute barrier + threadblock_barrier(warpgroup_id_in_cluster, + warps_per_warpgroup_per_core); + + // GEMM I: S = Q*K + thread_block_gemm_single_tile( + smem_Q, smem_K, nullptr /*ignore accum*/, smem_S, + tid_in_warpgroup, threads_per_warpgroup, + warpgroups_per_cluster, warpgroup_id_in_cluster); + } else { + // load Q*K + load_tile_to_smem( + dim_seqlen, 0, tile_k_, gmem_Q /*=gmem_S*/, smem_S, + tid_in_warpgroup); + // the above should be equivalent to: + // load_tile_to_smem(dim_seqlen, tile_k_, 0, gmem_Q + // /*=gmem_S*/, + // smem_S, tid_in_warpgroup); + } + + // protect GEMM result writes (smem_S) before softmax + threadblock_barrier(warpgroup_id_in_cluster, + warps_per_warpgroup_per_core); + + thread_block_online_softmax(smem_S, smem_P_produce, tid_in_warpgroup, + threads_per_warpgroup, + warpgroup_id_in_cluster, smem_scratchpad, + smem_rowmax, smem_rowsum, smem_O_row_scale); + + // FIXME unnecessary? + threadblock_barrier(warpgroup_id_in_cluster, + warps_per_warpgroup_per_core); + + if constexpr (DEBUG) { + if (tile_k_ == 0) { + thread_block_copy_tile(smem_P_produce, gmem_tmp_d0, + tid_in_warpgroup, threads_per_warpgroup, + warpgroup_id_in_cluster); + thread_block_copy_rowmax(smem_rowmax, gmem_tmp_e0, tid_in_warpgroup, + threads_per_warpgroup, + warpgroup_id_in_cluster); + thread_block_copy_rowmax(smem_rowsum, gmem_tmp_e2, tid_in_warpgroup, + threads_per_warpgroup, + warpgroup_id_in_cluster); + } else if (tile_k_ == k_tiles - 1) { + thread_block_copy_tile(smem_P_produce, gmem_tmp_d1, + tid_in_warpgroup, threads_per_warpgroup, + warpgroup_id_in_cluster); + thread_block_copy_rowmax(smem_rowmax, gmem_tmp_e1, tid_in_warpgroup, + threads_per_warpgroup, + warpgroup_id_in_cluster); + thread_block_copy_rowmax(smem_rowsum, gmem_tmp_e3, tid_in_warpgroup, + threads_per_warpgroup, + warpgroup_id_in_cluster); + } + + threadblock_barrier(warpgroup_id_in_cluster, + warps_per_warpgroup_per_core); + } + // } else if (warpgroup_id == 1) { + } + { + // Pipeline stage 2 + // + // skip pipeline start + // if (tile_k == 0) { + // continue; + // } + // const uint32_t tile_k_ = tile_k - 1; + const uint32_t tile_k_ = tile_k; + + // GEMM II: O = O + P*V - constexpr bool skip_gemm_qk = true; - if constexpr (!skip_gemm_qk) { // clear out accumulators initialize_accum_regs<0>(); initialize_accum_regs<1>(); - static_assert(B_ROW == B_COL, "currently only supports square tiles"); - - // load Q - load_tile_to_smem( - dim_seqlen, 0 /*FIXME: only work on first B_ROW rows of Q for now*/, - 0 /* always 0 because dim_k == headdim */, gmem_Q, smem_Q, - tid_in_threadblock); - - // load K + // V dimension is [seqlen, headdim], stored N(headdim)-major load_tile_to_smem( - dim_seqlen, tile_k, 0 /* always 0 because dim_k == headdim */, gmem_K, - smem_K, tid_in_threadblock); + HEADDIM, threads_per_warpgroup>( + HEADDIM, 0 /* 0 because always reads the full N-dimension */, tile_k_, + gmem_V, smem_V_consume, tid_in_warpgroup); - // GMEM->SMEM and compute barrier - threadblock_barrier(threadblock_id_in_cluster, - warps_per_threadblock_per_core); + // FIXME: should be removable + threadblock_barrier(warpgroup_id_in_cluster, + warps_per_warpgroup_per_core); - // GEMM I: S = Q*K - thread_block_gemm_single_tile( - smem_Q, smem_K, nullptr /*ignore accum*/, smem_S, tid_in_threadblock, - threads_per_threadblock, threadblocks_per_cluster, - threadblock_id_in_cluster); + // Oi rescale + thread_block_O_rescale(smem_O, smem_O /*in-place*/, smem_O_row_scale, + tid_in_warpgroup, threads_per_warpgroup, + warpgroup_id_in_cluster); - // tile_S = smem_S; - } else { - // load Q*K - load_tile_to_smem( - dim_seqlen, 0, tile_k, gmem_Q /*=gmem_S*/, smem_S, - tid_in_threadblock); - // the above should be equivalent to: - // load_tile_to_smem(dim_seqlen, tile_k, 0, gmem_Q /*=gmem_S*/, - // smem_S, tid_in_threadblock); + // rescale-to-PV-GEMM barrier + threadblock_barrier(warpgroup_id_in_cluster, + warps_per_warpgroup_per_core); - // tile_S = reinterpret_cast(arg->addr_q); - } + if constexpr (DEBUG) { + // O before PV + if (tile_k_ == 0) { + thread_block_copy_tile(smem_O, gmem_tmp_d2, tid_in_warpgroup, + threads_per_warpgroup, + warpgroup_id_in_cluster); + } else if (tile_k_ == k_tiles - 1) { + thread_block_copy_tile(smem_O, gmem_tmp_d3, tid_in_warpgroup, + threads_per_warpgroup, + warpgroup_id_in_cluster); + } - // protect GEMM result writes (smem_S) before softmax - threadblock_barrier(threadblock_id_in_cluster, - warps_per_threadblock_per_core); - - thread_block_online_softmax( - smem_S, smem_O, smem_P, tid_in_threadblock, threads_per_threadblock, - threadblock_id_in_cluster, smem_scratchpad, - smem_rowmax, smem_rowsum, smem_O_row_scale); - - // FIXME unnecessary? - threadblock_barrier(threadblock_id_in_cluster, - warps_per_threadblock_per_core); - - if constexpr (DEBUG) { - if (tile_k == 0) { - thread_block_copy_tile(smem_P, gmem_tmp_d0, tid_in_threadblock, - threads_per_threadblock, - threadblock_id_in_cluster); - thread_block_copy_rowmax(smem_rowmax, gmem_tmp_e0, tid_in_threadblock, - threads_per_threadblock, - threadblock_id_in_cluster); - thread_block_copy_rowmax(smem_rowsum, gmem_tmp_e2, tid_in_threadblock, - threads_per_threadblock, - threadblock_id_in_cluster); - } else if (tile_k == k_tiles - 1) { - thread_block_copy_tile(smem_P, gmem_tmp_d1, tid_in_threadblock, - threads_per_threadblock, - threadblock_id_in_cluster); - thread_block_copy_rowmax(smem_rowmax, gmem_tmp_e1, tid_in_threadblock, - threads_per_threadblock, - threadblock_id_in_cluster); - thread_block_copy_rowmax(smem_rowsum, gmem_tmp_e3, tid_in_threadblock, - threads_per_threadblock, - threadblock_id_in_cluster); + threadblock_barrier(warpgroup_id_in_cluster, + warps_per_warpgroup_per_core); } - threadblock_barrier(threadblock_id_in_cluster, - warps_per_threadblock_per_core); - } + if constexpr (!DOUBLE_BUF) { + thread_block_gemm_single_tile( + smem_P_consume, smem_V_consume, smem_O /*load accum*/, smem_O, + tid_in_warpgroup, threads_per_warpgroup, + warpgroups_per_cluster, warpgroup_id_in_cluster); - // GEMM II: O = O + P*V + // FIXME: wrong but fast + // thread_block_gemm_single_tile( + // smem_P_consume, smem_V_consume, smem_O /*load accum*/, smem_O, + // tid_in_warpgroup, threads_per_warpgroup, + // warpgroups_per_cluster, warpgroup_id_in_cluster); + } else { + // when warp-specialized, there's only enough warps to do 64x32 tile + // size so we need to do 2 GEMM calls + static_assert(B_ROW / 2 == 32, + "tile size assumption for warp-specialization not met"); - // clear out accumulators - initialize_accum_regs<0>(); - initialize_accum_regs<1>(); + // assumes smem_P is K-major + float *smem_P_half0 = smem_P_consume; + float *smem_P_half1 = smem_P_consume + (B_ROW / 2) * B_COL; + float *smem_O_half0 = smem_O; + float *smem_O_half1 = smem_O + (B_ROW / 2) * HEADDIM; - // V dimension is [seqlen, headdim], stored N(headdim)-major - load_tile_to_smem( - HEADDIM, 0 /* 0 because always reads the full N-dimension */, tile_k, - gmem_V, smem_V, tid_in_threadblock); + // split by rows into 2 chunks + thread_block_gemm_single_tile( + smem_P_half0, smem_V_consume, smem_O_half0 /*load accum*/, + smem_O_half0, tid_in_warpgroup, threads_per_warpgroup, + warpgroups_per_cluster, warpgroup_id_in_cluster); - // FIXME: should be removable - threadblock_barrier(threadblock_id_in_cluster, - warps_per_threadblock_per_core); - - // Oi rescale - thread_block_O_rescale(smem_O, smem_O /*in-place*/, smem_O_row_scale, - tid_in_threadblock, threads_per_threadblock, - threadblock_id_in_cluster); - - // rescale-to-PV-GEMM barrier - threadblock_barrier(threadblock_id_in_cluster, - warps_per_threadblock_per_core); - - if constexpr (DEBUG) { - // O before PV - if (tile_k == 0) { - thread_block_copy_tile(smem_O, gmem_tmp_d2, tid_in_threadblock, - threads_per_threadblock, - threadblock_id_in_cluster); - } else if (tile_k == k_tiles - 1) { - thread_block_copy_tile(smem_O, gmem_tmp_d3, tid_in_threadblock, - threads_per_threadblock, - threadblock_id_in_cluster); + thread_block_gemm_single_tile( + smem_P_half1, smem_V_consume, smem_O_half1 /*load accum*/, + smem_O_half1, tid_in_warpgroup, threads_per_warpgroup, + warpgroups_per_cluster, warpgroup_id_in_cluster); } - threadblock_barrier(threadblock_id_in_cluster, - warps_per_threadblock_per_core); - } + threadblock_barrier(warpgroup_id_in_cluster, + warps_per_warpgroup_per_core); - if constexpr (!DOUBLE_BUF) { - thread_block_gemm_single_tile( - smem_P, smem_V, smem_O /*load accum*/, smem_O, tid_in_threadblock, - threads_per_threadblock, threadblocks_per_cluster, - threadblock_id_in_cluster); + if constexpr (DEBUG) { + // O after PV + if (tile_k_ == 0) { + thread_block_copy_tile(smem_O, gmem_tmp_d4, tid_in_warpgroup, + threads_per_warpgroup, + warpgroup_id_in_cluster); + } else if (tile_k_ == k_tiles - 1) { + thread_block_copy_tile(smem_O, gmem_tmp_d5, tid_in_warpgroup, + threads_per_warpgroup, + warpgroup_id_in_cluster); + } - // FIXME: wrong but fast - // thread_block_gemm_single_tile( - // smem_P, smem_V, smem_O /*load accum*/, smem_O, tid_in_threadblock, - // threads_per_threadblock, threadblocks_per_cluster, - // threadblock_id_in_cluster); - } else { - // when warp-specialized, there's only enough warps to do 64x32 tile size - // so we need to do 2 GEMM calls - static_assert(B_ROW / 2 == 32, - "tile size assumption for warp-specialization not met"); - - // assumes smem_P is K-major - float *smem_P0 = smem_P; - float *smem_P1 = smem_P + (B_ROW / 2) * B_COL; - float *smem_O0 = smem_O; - float *smem_O1 = smem_O + (B_ROW / 2) * HEADDIM; - - // split by rows into 2 chunks - thread_block_gemm_single_tile( - smem_P0, smem_V, smem_O0 /*load accum*/, smem_O0, tid_in_threadblock, - threads_per_threadblock, threadblocks_per_cluster, - threadblock_id_in_cluster); - - thread_block_gemm_single_tile( - smem_P1, smem_V, smem_O1 /*load accum*/, smem_O1, tid_in_threadblock, - threads_per_threadblock, threadblocks_per_cluster, - threadblock_id_in_cluster); - } - - threadblock_barrier(threadblock_id_in_cluster, - warps_per_threadblock_per_core); - - if constexpr (DEBUG) { - // O after PV - if (tile_k == 0) { - thread_block_copy_tile(smem_O, gmem_tmp_d4, tid_in_threadblock, - threads_per_threadblock, - threadblock_id_in_cluster); - } else if (tile_k == k_tiles - 1) { - thread_block_copy_tile(smem_O, gmem_tmp_d5, tid_in_threadblock, - threads_per_threadblock, - threadblock_id_in_cluster); + threadblock_barrier(warpgroup_id_in_cluster, + warps_per_warpgroup_per_core); } - - threadblock_barrier(threadblock_id_in_cluster, - warps_per_threadblock_per_core); } } + + asm volatile ("tile_loop_finish_%=:" :: ); } int main() {