From 03308f8033e2deb75ff686952d2741fbfde6e0cc Mon Sep 17 00:00:00 2001 From: Hansung Kim Date: Sat, 7 Sep 2024 20:46:58 -0700 Subject: [PATCH] flash: Write fast config for DMA MAC utilization is 20-25% for the loop. --- tests/regression/flash_attention/kernel.cpp | 142 ++++++++++++++------ 1 file changed, 101 insertions(+), 41 deletions(-) diff --git a/tests/regression/flash_attention/kernel.cpp b/tests/regression/flash_attention/kernel.cpp index 13d743ea..15538ba2 100644 --- a/tests/regression/flash_attention/kernel.cpp +++ b/tests/regression/flash_attention/kernel.cpp @@ -13,11 +13,12 @@ #define HEADDIM 64 constexpr uint32_t ROWMAX_SETS = 3; -constexpr bool DEBUG = true; +constexpr bool DEBUG = false; constexpr bool WARP_SPECIALIZED = true; constexpr uint32_t DEV_FAKE_SMEM_START_ADDR = 0xf0000000; +constexpr bool GEMMINI_DMA_FAST = true; constexpr bool Q_IS_K_MAJOR = true; // temporary safety stop for wrong configs @@ -763,6 +764,7 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { // GEMM I: S = Q*K // // FIXME: deduplicate this between GEMM II + asm volatile("gemm_qk_start_%=:" ::); if constexpr (!WARP_SPECIALIZED) { // clear out accumulators before GEMM initialize_accum_regs<0>(); @@ -815,15 +817,27 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { // split by rows into 2 chunks if constexpr (GEMMINI_DMA) { - thread_block_gemm_single_tile( - smem_Q_half0, smem_K, nullptr /*ignore accum*/, smem_S_half0, - tid_in_warpgroup, threads_per_warpgroup, warpgroups_per_cluster, - warpgroup_id_in_cluster); + if constexpr (GEMMINI_DMA_FAST) { + thread_block_gemm_single_tile( + smem_Q_half0, smem_K, nullptr /*ignore accum*/, smem_S_half0, + tid_in_warpgroup, threads_per_warpgroup, warpgroups_per_cluster, + warpgroup_id_in_cluster); + } else { + thread_block_gemm_single_tile( + smem_Q_half0, smem_K, nullptr /*ignore accum*/, smem_S_half0, + tid_in_warpgroup, threads_per_warpgroup, warpgroups_per_cluster, + warpgroup_id_in_cluster); + } } else if constexpr (Q_IS_K_MAJOR) { thread_block_gemm_single_tile< float, MemLayout::K_major, MemLayout::MN_major, B_ROW / 2, B_COL, @@ -848,15 +862,27 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { initialize_accum_regs<1>(); if constexpr (GEMMINI_DMA) { - thread_block_gemm_single_tile( - smem_Q_half1, smem_K, nullptr /*ignore accum*/, smem_S_half1, - tid_in_warpgroup, threads_per_warpgroup, warpgroups_per_cluster, - warpgroup_id_in_cluster); + if constexpr (GEMMINI_DMA_FAST) { + thread_block_gemm_single_tile( + smem_Q_half1, smem_K, nullptr /*ignore accum*/, smem_S_half1, + tid_in_warpgroup, threads_per_warpgroup, warpgroups_per_cluster, + warpgroup_id_in_cluster); + } else { + thread_block_gemm_single_tile( + smem_Q_half1, smem_K, nullptr /*ignore accum*/, smem_S_half1, + tid_in_warpgroup, threads_per_warpgroup, warpgroups_per_cluster, + warpgroup_id_in_cluster); + } } else if constexpr (Q_IS_K_MAJOR) { thread_block_gemm_single_tile< float, MemLayout::K_major, MemLayout::MN_major, B_ROW / 2, B_COL, @@ -888,6 +914,8 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { // protect write to SMEM (smem_S) before softmax threadblock_barrier(warpgroup_id_in_cluster, warps_per_warpgroup_per_core); + asm volatile("gemm_qk_finish_%=:" ::); + if constexpr (DEBUG) { if (warpgroup_id == 0) { if (tile_k == 0) { @@ -921,6 +949,7 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { // data movement for K and V // // Q stays in SMEM for the entire loop + asm volatile("move_k_v_start_%=:" ::); if constexpr (GEMMINI_DMA) { // NOTE: Beware of race conditions; with warp specialization, we need to // make sure below command code to DMA is not executed simultaneously @@ -965,6 +994,7 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { HEADDIM, 0 /* full N-dimension */, tile_k, gmem_V, smem_V, tid_in_warpgroup); } + asm volatile("move_k_v_finish_%=:" ::); // protect write to SMEM threadblock_barrier(warpgroup_id_in_cluster, warps_per_warpgroup_per_core); @@ -995,8 +1025,6 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { // inter-warpgroup barrier before GEMM II threadblock_barrier(global_barrier_id, warps_per_threadblock_per_core); - // GEMM II: O = O + P*V - // Oi rescale thread_block_O_rescale(smem_O, smem_O /*in-place*/, smem_O_row_scale, tid_in_warpgroup, @@ -1029,6 +1057,10 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { } } + // GEMM II: O = O + P*V + + asm volatile("gemm_pv_start_%=:" ::); + if constexpr (!WARP_SPECIALIZED) { // clear out accumulators before GEMM initialize_accum_regs<0>(); @@ -1084,16 +1116,29 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { // split by rows into 2 chunks if constexpr (GEMMINI_DMA) { - thread_block_gemm_single_tile< - float, MemLayout::K_major /* P matrix is row-major */, - MemLayout::block_row_major, B_ROW / 2, HEADDIM, B_COL, - /*leading_dim_a=*/0, - /*leading_dim_b=*/0, - /*load_accum=*/true, - /*write_to_smem=*/true>( - smem_P_half0, smem_V, smem_O_half0 /*load accum*/, smem_O_half0, - tid_in_warpgroup, threads_per_warpgroup, warpgroups_per_cluster, - warpgroup_id_in_cluster); + if constexpr (GEMMINI_DMA_FAST) { + thread_block_gemm_single_tile( + smem_P_half0, smem_V, smem_O_half0 /*load accum*/, smem_O_half0, + tid_in_warpgroup, threads_per_warpgroup, warpgroups_per_cluster, + warpgroup_id_in_cluster); + } else { + thread_block_gemm_single_tile< + float, MemLayout::K_major /* P matrix is row-major */, + MemLayout::block_row_major, B_ROW / 2, HEADDIM, B_COL, + /*leading_dim_a=*/0, + /*leading_dim_b=*/0, + /*load_accum=*/true, + /*write_to_smem=*/true>( + smem_P_half0, smem_V, smem_O_half0 /*load accum*/, smem_O_half0, + tid_in_warpgroup, threads_per_warpgroup, warpgroups_per_cluster, + warpgroup_id_in_cluster); + } } else { thread_block_gemm_single_tile< float, MemLayout::K_major, MemLayout::MN_major, B_ROW / 2, HEADDIM, @@ -1109,16 +1154,29 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { initialize_accum_regs<1>(); if constexpr (GEMMINI_DMA) { - thread_block_gemm_single_tile< - float, MemLayout::K_major /* P matrix is row-major */, - MemLayout::block_row_major, B_ROW / 2, HEADDIM, B_COL, - /*leading_dim_a=*/0, - /*leading_dim_b=*/0, - /*load_accum=*/true, - /*write_to_smem=*/true>( - smem_P_half1, smem_V, smem_O_half1 /*load accum*/, smem_O_half1, - tid_in_warpgroup, threads_per_warpgroup, warpgroups_per_cluster, - warpgroup_id_in_cluster); + if constexpr (GEMMINI_DMA_FAST) { + thread_block_gemm_single_tile( + smem_P_half1, smem_V, smem_O_half1 /*load accum*/, smem_O_half1, + tid_in_warpgroup, threads_per_warpgroup, warpgroups_per_cluster, + warpgroup_id_in_cluster); + } else { + thread_block_gemm_single_tile< + float, MemLayout::K_major /* P matrix is row-major */, + MemLayout::block_row_major, B_ROW / 2, HEADDIM, B_COL, + /*leading_dim_a=*/0, + /*leading_dim_b=*/0, + /*load_accum=*/true, + /*write_to_smem=*/true>( + smem_P_half1, smem_V, smem_O_half1 /*load accum*/, smem_O_half1, + tid_in_warpgroup, threads_per_warpgroup, warpgroups_per_cluster, + warpgroup_id_in_cluster); + } } else { thread_block_gemm_single_tile< float, MemLayout::K_major, MemLayout::MN_major, B_ROW / 2, HEADDIM, @@ -1133,6 +1191,8 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { threadblock_barrier(warpgroup_id_in_cluster, warps_per_warpgroup_per_core); + asm volatile("gemm_pv_finish_%=:" ::); + if constexpr (DEBUG) { if (warpgroup_id == 0) { // O after PV