flash: Cleanup debug code

This commit is contained in:
Hansung Kim
2024-09-02 00:40:05 -07:00
parent 8125192846
commit 70273fd00d

View File

@@ -623,22 +623,10 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
static_assert(B_ROW == B_COL, "currently only supports square tiles");
// load Q; this stays in SMEM for the entire loop
if constexpr (!WARP_SPECIALIZED) {
load_tile_to_smem<float, MemLayout::MN_major, MemLayout::MN_major, B_ROW,
HEADDIM, threads_per_warpgroup>(
dim_seqlen, warpgroup_id, 0 /* dim_k == headdim */, gmem_Q, smem_Q,
tid_in_warpgroup);
} else {
// FIXME: transpose to K-major in SMEM for correctness
// load_tile_to_smem<float, MemLayout::K_major, MemLayout::K_major, B_ROW,
// HEADDIM, threads_per_warpgroup>(
// HEADDIM, warpgroup_id, 0 /* dim_k == headdim */, gmem_Q, smem_Q,
// tid_in_warpgroup);
load_tile_to_smem<float, MemLayout::MN_major, MemLayout::MN_major, B_ROW,
HEADDIM, threads_per_warpgroup>(
dim_seqlen, warpgroup_id, 0 /* dim_k == headdim */, gmem_Q, smem_Q,
tid_in_warpgroup);
}
// load K
load_tile_to_smem<float, MemLayout::MN_major, MemLayout::MN_major, B_COL,
@@ -701,7 +689,7 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
// split by rows into 2 chunks
thread_block_gemm_single_tile<
float, MemLayout::MN_major /*FIXME*/, MemLayout::MN_major, B_ROW / 2,
float, MemLayout::MN_major, MemLayout::MN_major, B_ROW / 2,
B_COL, HEADDIM, /*leading_dim_a=*/B_ROW, /*leading_dim_b=*/0,
/*load_accum=*/false,
/*write_to_smem=*/true>(
@@ -713,7 +701,7 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
initialize_accum_regs<1>();
thread_block_gemm_single_tile<
float, MemLayout::MN_major /*FIXME*/, MemLayout::MN_major, B_ROW / 2,
float, MemLayout::MN_major, MemLayout::MN_major, B_ROW / 2,
B_COL, HEADDIM, /*leading_dim_a=*/B_ROW, /*leading_dim_b=*/0,
/*load_accum=*/false,
/*write_to_smem=*/true>(