From 8125192846c3170e790507b834dbbd3b4bcfee1f Mon Sep 17 00:00:00 2001 From: Hansung Kim Date: Mon, 2 Sep 2024 00:15:57 -0700 Subject: [PATCH] flash: Specify leading_dim for split QK GEMM; fix uninit'd RF before GEMM --- tests/regression/flash_attention/kernel.cpp | 86 ++++++++++++--------- 1 file changed, 50 insertions(+), 36 deletions(-) diff --git a/tests/regression/flash_attention/kernel.cpp b/tests/regression/flash_attention/kernel.cpp index 12abcd8e..d9ad63f9 100644 --- a/tests/regression/flash_attention/kernel.cpp +++ b/tests/regression/flash_attention/kernel.cpp @@ -15,7 +15,7 @@ constexpr uint32_t ROWMAX_SETS = 3; constexpr bool DEBUG = true; -constexpr bool WARP_SPECIALIZED = false; +constexpr bool WARP_SPECIALIZED = true; constexpr uint32_t DEV_FAKE_SMEM_START_ADDR = 0xf0000000; @@ -630,7 +630,11 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { tid_in_warpgroup); } else { // FIXME: transpose to K-major in SMEM for correctness - load_tile_to_smem( + // HEADDIM, warpgroup_id, 0 /* dim_k == headdim */, gmem_Q, smem_Q, + // tid_in_warpgroup); + load_tile_to_smem( dim_seqlen, warpgroup_id, 0 /* dim_k == headdim */, gmem_Q, smem_Q, tid_in_warpgroup); @@ -669,11 +673,11 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { initialize_accum_regs<0>(); initialize_accum_regs<1>(); - thread_block_gemm_single_tile( + thread_block_gemm_single_tile< + float, MemLayout::MN_major, MemLayout::MN_major, B_ROW, B_COL, + HEADDIM, /*leading_dim_a=*/0, /*leading_dim_b=*/0, + /*load_accum=*/false, + /*write_to_smem=*/true>( smem_Q, smem_K, nullptr /*ignore accum*/, smem_S, tid_in_warpgroup, threads_per_warpgroup, warpgroups_per_cluster, warpgroup_id_in_cluster); @@ -686,7 +690,8 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { // assumes smem_Q is K-major // FIXME: fix this to MN-major float *smem_Q_half0 = smem_Q; - float *smem_Q_half1 = smem_Q + (B_ROW / 2) * HEADDIM; + float *smem_Q_half1 = smem_Q + (B_ROW / 2); // MN-major + // float *smem_Q_half1 = smem_Q + (B_ROW / 2) * HEADDIM; // K-major float *smem_S_half0 = smem_S; float *smem_S_half1 = smem_S + (B_ROW / 2) * B_COL; @@ -695,19 +700,23 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { initialize_accum_regs<1>(); // split by rows into 2 chunks - thread_block_gemm_single_tile( + thread_block_gemm_single_tile< + float, MemLayout::MN_major /*FIXME*/, MemLayout::MN_major, B_ROW / 2, + B_COL, HEADDIM, /*leading_dim_a=*/B_ROW, /*leading_dim_b=*/0, + /*load_accum=*/false, + /*write_to_smem=*/true>( smem_Q_half0, smem_K, nullptr /*ignore accum*/, smem_S_half0, tid_in_warpgroup, threads_per_warpgroup, warpgroups_per_cluster, warpgroup_id_in_cluster); - thread_block_gemm_single_tile( + + initialize_accum_regs<0>(); + initialize_accum_regs<1>(); + + thread_block_gemm_single_tile< + float, MemLayout::MN_major /*FIXME*/, MemLayout::MN_major, B_ROW / 2, + B_COL, HEADDIM, /*leading_dim_a=*/B_ROW, /*leading_dim_b=*/0, + /*load_accum=*/false, + /*write_to_smem=*/true>( smem_Q_half1, smem_K, nullptr /*ignore accum*/, smem_S_half1, tid_in_warpgroup, threads_per_warpgroup, warpgroups_per_cluster, warpgroup_id_in_cluster); @@ -837,16 +846,18 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { thread_block_gemm_single_tile( - smem_P, smem_V, smem_O /*load accum*/, smem_O, - tid_in_warpgroup, threads_per_warpgroup, warpgroups_per_cluster, + smem_P, smem_V, smem_O /*load accum*/, smem_O, tid_in_warpgroup, + threads_per_warpgroup, warpgroups_per_cluster, warpgroup_id_in_cluster); // FIXME: wrong but fast // thread_block_gemm_single_tile( // smem_P, smem_V, smem_O /*load accum*/, smem_O, @@ -869,23 +880,26 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { initialize_accum_regs<1>(); // split by rows into 2 chunks - thread_block_gemm_single_tile( - smem_P_half0, smem_V, smem_O_half0 /*load accum*/, - smem_O_half0, tid_in_warpgroup, threads_per_warpgroup, - warpgroups_per_cluster, warpgroup_id_in_cluster); + thread_block_gemm_single_tile< + float, MemLayout::K_major, MemLayout::MN_major, B_ROW / 2, HEADDIM, + B_COL, /*leading_dim_a=*/0, /*leading_dim_b=*/0, + /*load_accum=*/true, + /*write_to_smem=*/true>( + smem_P_half0, smem_V, smem_O_half0 /*load accum*/, smem_O_half0, + tid_in_warpgroup, threads_per_warpgroup, warpgroups_per_cluster, + warpgroup_id_in_cluster); - thread_block_gemm_single_tile( - smem_P_half1, smem_V, smem_O_half1 /*load accum*/, - smem_O_half1, tid_in_warpgroup, threads_per_warpgroup, - warpgroups_per_cluster, warpgroup_id_in_cluster); + initialize_accum_regs<0>(); + initialize_accum_regs<1>(); + + thread_block_gemm_single_tile< + float, MemLayout::K_major, MemLayout::MN_major, B_ROW / 2, HEADDIM, + B_COL, /*leading_dim_a=*/0, /*leading_dim_b=*/0, + /*load_accum=*/true, + /*write_to_smem=*/true>( + smem_P_half1, smem_V, smem_O_half1 /*load accum*/, smem_O_half1, + tid_in_warpgroup, threads_per_warpgroup, warpgroups_per_cluster, + warpgroup_id_in_cluster); } threadblock_barrier(warpgroup_id_in_cluster, warps_per_warpgroup_per_core);