flash: Specify leading_dim for split QK GEMM; fix uninit'd RF before GEMM
This commit is contained in:
@@ -15,7 +15,7 @@
|
||||
|
||||
constexpr uint32_t ROWMAX_SETS = 3;
|
||||
constexpr bool DEBUG = true;
|
||||
constexpr bool WARP_SPECIALIZED = false;
|
||||
constexpr bool WARP_SPECIALIZED = true;
|
||||
|
||||
constexpr uint32_t DEV_FAKE_SMEM_START_ADDR = 0xf0000000;
|
||||
|
||||
@@ -630,7 +630,11 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
|
||||
tid_in_warpgroup);
|
||||
} else {
|
||||
// FIXME: transpose to K-major in SMEM for correctness
|
||||
load_tile_to_smem<float, MemLayout::K_major, MemLayout::K_major, B_ROW,
|
||||
// load_tile_to_smem<float, MemLayout::K_major, MemLayout::K_major, B_ROW,
|
||||
// HEADDIM, threads_per_warpgroup>(
|
||||
// HEADDIM, warpgroup_id, 0 /* dim_k == headdim */, gmem_Q, smem_Q,
|
||||
// tid_in_warpgroup);
|
||||
load_tile_to_smem<float, MemLayout::MN_major, MemLayout::MN_major, B_ROW,
|
||||
HEADDIM, threads_per_warpgroup>(
|
||||
dim_seqlen, warpgroup_id, 0 /* dim_k == headdim */, gmem_Q, smem_Q,
|
||||
tid_in_warpgroup);
|
||||
@@ -669,11 +673,11 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
|
||||
initialize_accum_regs<0>();
|
||||
initialize_accum_regs<1>();
|
||||
|
||||
thread_block_gemm_single_tile<float, MemLayout::MN_major,
|
||||
MemLayout::MN_major, B_ROW, B_COL,
|
||||
HEADDIM,
|
||||
/*load_accum=*/false,
|
||||
/*write_to_smem=*/true>(
|
||||
thread_block_gemm_single_tile<
|
||||
float, MemLayout::MN_major, MemLayout::MN_major, B_ROW, B_COL,
|
||||
HEADDIM, /*leading_dim_a=*/0, /*leading_dim_b=*/0,
|
||||
/*load_accum=*/false,
|
||||
/*write_to_smem=*/true>(
|
||||
smem_Q, smem_K, nullptr /*ignore accum*/, smem_S, tid_in_warpgroup,
|
||||
threads_per_warpgroup, warpgroups_per_cluster,
|
||||
warpgroup_id_in_cluster);
|
||||
@@ -686,7 +690,8 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
|
||||
// assumes smem_Q is K-major
|
||||
// FIXME: fix this to MN-major
|
||||
float *smem_Q_half0 = smem_Q;
|
||||
float *smem_Q_half1 = smem_Q + (B_ROW / 2) * HEADDIM;
|
||||
float *smem_Q_half1 = smem_Q + (B_ROW / 2); // MN-major
|
||||
// float *smem_Q_half1 = smem_Q + (B_ROW / 2) * HEADDIM; // K-major
|
||||
float *smem_S_half0 = smem_S;
|
||||
float *smem_S_half1 = smem_S + (B_ROW / 2) * B_COL;
|
||||
|
||||
@@ -695,19 +700,23 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
|
||||
initialize_accum_regs<1>();
|
||||
|
||||
// split by rows into 2 chunks
|
||||
thread_block_gemm_single_tile<float, MemLayout::K_major /*FIXME*/,
|
||||
MemLayout::MN_major, B_ROW / 2, B_COL,
|
||||
HEADDIM,
|
||||
/*load_accum=*/false,
|
||||
/*write_to_smem=*/true>(
|
||||
thread_block_gemm_single_tile<
|
||||
float, MemLayout::MN_major /*FIXME*/, MemLayout::MN_major, B_ROW / 2,
|
||||
B_COL, HEADDIM, /*leading_dim_a=*/B_ROW, /*leading_dim_b=*/0,
|
||||
/*load_accum=*/false,
|
||||
/*write_to_smem=*/true>(
|
||||
smem_Q_half0, smem_K, nullptr /*ignore accum*/, smem_S_half0,
|
||||
tid_in_warpgroup, threads_per_warpgroup, warpgroups_per_cluster,
|
||||
warpgroup_id_in_cluster);
|
||||
thread_block_gemm_single_tile<float, MemLayout::K_major /*FIXME*/,
|
||||
MemLayout::MN_major, B_ROW / 2, B_COL,
|
||||
HEADDIM,
|
||||
/*load_accum=*/false,
|
||||
/*write_to_smem=*/true>(
|
||||
|
||||
initialize_accum_regs<0>();
|
||||
initialize_accum_regs<1>();
|
||||
|
||||
thread_block_gemm_single_tile<
|
||||
float, MemLayout::MN_major /*FIXME*/, MemLayout::MN_major, B_ROW / 2,
|
||||
B_COL, HEADDIM, /*leading_dim_a=*/B_ROW, /*leading_dim_b=*/0,
|
||||
/*load_accum=*/false,
|
||||
/*write_to_smem=*/true>(
|
||||
smem_Q_half1, smem_K, nullptr /*ignore accum*/, smem_S_half1,
|
||||
tid_in_warpgroup, threads_per_warpgroup, warpgroups_per_cluster,
|
||||
warpgroup_id_in_cluster);
|
||||
@@ -837,16 +846,18 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
|
||||
|
||||
thread_block_gemm_single_tile<float, MemLayout::K_major,
|
||||
MemLayout::MN_major, B_ROW, HEADDIM, B_COL,
|
||||
/*leading_dim_a=*/0, /*leading_dim_b=*/0,
|
||||
/*load_accum=*/true,
|
||||
/*write_to_smem=*/true>(
|
||||
smem_P, smem_V, smem_O /*load accum*/, smem_O,
|
||||
tid_in_warpgroup, threads_per_warpgroup, warpgroups_per_cluster,
|
||||
smem_P, smem_V, smem_O /*load accum*/, smem_O, tid_in_warpgroup,
|
||||
threads_per_warpgroup, warpgroups_per_cluster,
|
||||
warpgroup_id_in_cluster);
|
||||
|
||||
// FIXME: wrong but fast
|
||||
// thread_block_gemm_single_tile<float, MemLayout::MN_major,
|
||||
// MemLayout::MN_major,
|
||||
// B_ROW, HEADDIM, B_COL,
|
||||
// /*leading_dim_a=*/0, /*leading_dim_b=*/0,
|
||||
// /*load_accum=*/true,
|
||||
// /*write_to_smem=*/true>(
|
||||
// smem_P, smem_V, smem_O /*load accum*/, smem_O,
|
||||
@@ -869,23 +880,26 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
|
||||
initialize_accum_regs<1>();
|
||||
|
||||
// split by rows into 2 chunks
|
||||
thread_block_gemm_single_tile<float, MemLayout::K_major,
|
||||
MemLayout::MN_major, B_ROW / 2, HEADDIM,
|
||||
B_COL,
|
||||
/*load_accum=*/true,
|
||||
/*write_to_smem=*/true>(
|
||||
smem_P_half0, smem_V, smem_O_half0 /*load accum*/,
|
||||
smem_O_half0, tid_in_warpgroup, threads_per_warpgroup,
|
||||
warpgroups_per_cluster, warpgroup_id_in_cluster);
|
||||
thread_block_gemm_single_tile<
|
||||
float, MemLayout::K_major, MemLayout::MN_major, B_ROW / 2, HEADDIM,
|
||||
B_COL, /*leading_dim_a=*/0, /*leading_dim_b=*/0,
|
||||
/*load_accum=*/true,
|
||||
/*write_to_smem=*/true>(
|
||||
smem_P_half0, smem_V, smem_O_half0 /*load accum*/, smem_O_half0,
|
||||
tid_in_warpgroup, threads_per_warpgroup, warpgroups_per_cluster,
|
||||
warpgroup_id_in_cluster);
|
||||
|
||||
thread_block_gemm_single_tile<float, MemLayout::K_major,
|
||||
MemLayout::MN_major, B_ROW / 2, HEADDIM,
|
||||
B_COL,
|
||||
/*load_accum=*/true,
|
||||
/*write_to_smem=*/true>(
|
||||
smem_P_half1, smem_V, smem_O_half1 /*load accum*/,
|
||||
smem_O_half1, tid_in_warpgroup, threads_per_warpgroup,
|
||||
warpgroups_per_cluster, warpgroup_id_in_cluster);
|
||||
initialize_accum_regs<0>();
|
||||
initialize_accum_regs<1>();
|
||||
|
||||
thread_block_gemm_single_tile<
|
||||
float, MemLayout::K_major, MemLayout::MN_major, B_ROW / 2, HEADDIM,
|
||||
B_COL, /*leading_dim_a=*/0, /*leading_dim_b=*/0,
|
||||
/*load_accum=*/true,
|
||||
/*write_to_smem=*/true>(
|
||||
smem_P_half1, smem_V, smem_O_half1 /*load accum*/, smem_O_half1,
|
||||
tid_in_warpgroup, threads_per_warpgroup, warpgroups_per_cluster,
|
||||
warpgroup_id_in_cluster);
|
||||
}
|
||||
|
||||
threadblock_barrier(warpgroup_id_in_cluster, warps_per_warpgroup_per_core);
|
||||
|
||||
Reference in New Issue
Block a user