flash: Specify leading_dim for split QK GEMM; fix uninit'd RF before GEMM

This commit is contained in:
Hansung Kim
2024-09-02 00:15:57 -07:00
parent bdd955836d
commit 8125192846

View File

@@ -15,7 +15,7 @@
constexpr uint32_t ROWMAX_SETS = 3;
constexpr bool DEBUG = true;
constexpr bool WARP_SPECIALIZED = false;
constexpr bool WARP_SPECIALIZED = true;
constexpr uint32_t DEV_FAKE_SMEM_START_ADDR = 0xf0000000;
@@ -630,7 +630,11 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
tid_in_warpgroup);
} else {
// FIXME: transpose to K-major in SMEM for correctness
load_tile_to_smem<float, MemLayout::K_major, MemLayout::K_major, B_ROW,
// load_tile_to_smem<float, MemLayout::K_major, MemLayout::K_major, B_ROW,
// HEADDIM, threads_per_warpgroup>(
// HEADDIM, warpgroup_id, 0 /* dim_k == headdim */, gmem_Q, smem_Q,
// tid_in_warpgroup);
load_tile_to_smem<float, MemLayout::MN_major, MemLayout::MN_major, B_ROW,
HEADDIM, threads_per_warpgroup>(
dim_seqlen, warpgroup_id, 0 /* dim_k == headdim */, gmem_Q, smem_Q,
tid_in_warpgroup);
@@ -669,11 +673,11 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
initialize_accum_regs<0>();
initialize_accum_regs<1>();
thread_block_gemm_single_tile<float, MemLayout::MN_major,
MemLayout::MN_major, B_ROW, B_COL,
HEADDIM,
/*load_accum=*/false,
/*write_to_smem=*/true>(
thread_block_gemm_single_tile<
float, MemLayout::MN_major, MemLayout::MN_major, B_ROW, B_COL,
HEADDIM, /*leading_dim_a=*/0, /*leading_dim_b=*/0,
/*load_accum=*/false,
/*write_to_smem=*/true>(
smem_Q, smem_K, nullptr /*ignore accum*/, smem_S, tid_in_warpgroup,
threads_per_warpgroup, warpgroups_per_cluster,
warpgroup_id_in_cluster);
@@ -686,7 +690,8 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
// assumes smem_Q is K-major
// FIXME: fix this to MN-major
float *smem_Q_half0 = smem_Q;
float *smem_Q_half1 = smem_Q + (B_ROW / 2) * HEADDIM;
float *smem_Q_half1 = smem_Q + (B_ROW / 2); // MN-major
// float *smem_Q_half1 = smem_Q + (B_ROW / 2) * HEADDIM; // K-major
float *smem_S_half0 = smem_S;
float *smem_S_half1 = smem_S + (B_ROW / 2) * B_COL;
@@ -695,19 +700,23 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
initialize_accum_regs<1>();
// split by rows into 2 chunks
thread_block_gemm_single_tile<float, MemLayout::K_major /*FIXME*/,
MemLayout::MN_major, B_ROW / 2, B_COL,
HEADDIM,
/*load_accum=*/false,
/*write_to_smem=*/true>(
thread_block_gemm_single_tile<
float, MemLayout::MN_major /*FIXME*/, MemLayout::MN_major, B_ROW / 2,
B_COL, HEADDIM, /*leading_dim_a=*/B_ROW, /*leading_dim_b=*/0,
/*load_accum=*/false,
/*write_to_smem=*/true>(
smem_Q_half0, smem_K, nullptr /*ignore accum*/, smem_S_half0,
tid_in_warpgroup, threads_per_warpgroup, warpgroups_per_cluster,
warpgroup_id_in_cluster);
thread_block_gemm_single_tile<float, MemLayout::K_major /*FIXME*/,
MemLayout::MN_major, B_ROW / 2, B_COL,
HEADDIM,
/*load_accum=*/false,
/*write_to_smem=*/true>(
initialize_accum_regs<0>();
initialize_accum_regs<1>();
thread_block_gemm_single_tile<
float, MemLayout::MN_major /*FIXME*/, MemLayout::MN_major, B_ROW / 2,
B_COL, HEADDIM, /*leading_dim_a=*/B_ROW, /*leading_dim_b=*/0,
/*load_accum=*/false,
/*write_to_smem=*/true>(
smem_Q_half1, smem_K, nullptr /*ignore accum*/, smem_S_half1,
tid_in_warpgroup, threads_per_warpgroup, warpgroups_per_cluster,
warpgroup_id_in_cluster);
@@ -837,16 +846,18 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
thread_block_gemm_single_tile<float, MemLayout::K_major,
MemLayout::MN_major, B_ROW, HEADDIM, B_COL,
/*leading_dim_a=*/0, /*leading_dim_b=*/0,
/*load_accum=*/true,
/*write_to_smem=*/true>(
smem_P, smem_V, smem_O /*load accum*/, smem_O,
tid_in_warpgroup, threads_per_warpgroup, warpgroups_per_cluster,
smem_P, smem_V, smem_O /*load accum*/, smem_O, tid_in_warpgroup,
threads_per_warpgroup, warpgroups_per_cluster,
warpgroup_id_in_cluster);
// FIXME: wrong but fast
// thread_block_gemm_single_tile<float, MemLayout::MN_major,
// MemLayout::MN_major,
// B_ROW, HEADDIM, B_COL,
// /*leading_dim_a=*/0, /*leading_dim_b=*/0,
// /*load_accum=*/true,
// /*write_to_smem=*/true>(
// smem_P, smem_V, smem_O /*load accum*/, smem_O,
@@ -869,23 +880,26 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
initialize_accum_regs<1>();
// split by rows into 2 chunks
thread_block_gemm_single_tile<float, MemLayout::K_major,
MemLayout::MN_major, B_ROW / 2, HEADDIM,
B_COL,
/*load_accum=*/true,
/*write_to_smem=*/true>(
smem_P_half0, smem_V, smem_O_half0 /*load accum*/,
smem_O_half0, tid_in_warpgroup, threads_per_warpgroup,
warpgroups_per_cluster, warpgroup_id_in_cluster);
thread_block_gemm_single_tile<
float, MemLayout::K_major, MemLayout::MN_major, B_ROW / 2, HEADDIM,
B_COL, /*leading_dim_a=*/0, /*leading_dim_b=*/0,
/*load_accum=*/true,
/*write_to_smem=*/true>(
smem_P_half0, smem_V, smem_O_half0 /*load accum*/, smem_O_half0,
tid_in_warpgroup, threads_per_warpgroup, warpgroups_per_cluster,
warpgroup_id_in_cluster);
thread_block_gemm_single_tile<float, MemLayout::K_major,
MemLayout::MN_major, B_ROW / 2, HEADDIM,
B_COL,
/*load_accum=*/true,
/*write_to_smem=*/true>(
smem_P_half1, smem_V, smem_O_half1 /*load accum*/,
smem_O_half1, tid_in_warpgroup, threads_per_warpgroup,
warpgroups_per_cluster, warpgroup_id_in_cluster);
initialize_accum_regs<0>();
initialize_accum_regs<1>();
thread_block_gemm_single_tile<
float, MemLayout::K_major, MemLayout::MN_major, B_ROW / 2, HEADDIM,
B_COL, /*leading_dim_a=*/0, /*leading_dim_b=*/0,
/*load_accum=*/true,
/*write_to_smem=*/true>(
smem_P_half1, smem_V, smem_O_half1 /*load accum*/, smem_O_half1,
tid_in_warpgroup, threads_per_warpgroup, warpgroups_per_cluster,
warpgroup_id_in_cluster);
}
threadblock_barrier(warpgroup_id_in_cluster, warps_per_warpgroup_per_core);