flash: Clean up debug code
This commit is contained in:
@@ -623,22 +623,10 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
|
||||
static_assert(B_ROW == B_COL, "currently only supports square tiles");
|
||||
|
||||
// load Q; this stays in SMEM for the entire loop
|
||||
if constexpr (!WARP_SPECIALIZED) {
|
||||
load_tile_to_smem<float, MemLayout::MN_major, MemLayout::MN_major, B_ROW,
|
||||
HEADDIM, threads_per_warpgroup>(
|
||||
dim_seqlen, warpgroup_id, 0 /* dim_k == headdim */, gmem_Q, smem_Q,
|
||||
tid_in_warpgroup);
|
||||
} else {
|
||||
// FIXME: transpose to K-major in SMEM for correctness
|
||||
// load_tile_to_smem<float, MemLayout::K_major, MemLayout::K_major, B_ROW,
|
||||
// HEADDIM, threads_per_warpgroup>(
|
||||
// HEADDIM, warpgroup_id, 0 /* dim_k == headdim */, gmem_Q, smem_Q,
|
||||
// tid_in_warpgroup);
|
||||
load_tile_to_smem<float, MemLayout::MN_major, MemLayout::MN_major, B_ROW,
|
||||
HEADDIM, threads_per_warpgroup>(
|
||||
dim_seqlen, warpgroup_id, 0 /* dim_k == headdim */, gmem_Q, smem_Q,
|
||||
tid_in_warpgroup);
|
||||
}
|
||||
load_tile_to_smem<float, MemLayout::MN_major, MemLayout::MN_major, B_ROW,
|
||||
HEADDIM, threads_per_warpgroup>(
|
||||
dim_seqlen, warpgroup_id, 0 /* dim_k == headdim */, gmem_Q, smem_Q,
|
||||
tid_in_warpgroup);
|
||||
|
||||
// load K
|
||||
load_tile_to_smem<float, MemLayout::MN_major, MemLayout::MN_major, B_COL,
|
||||
@@ -701,7 +689,7 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
|
||||
|
||||
// split by rows into 2 chunks
|
||||
thread_block_gemm_single_tile<
|
||||
float, MemLayout::MN_major /*FIXME*/, MemLayout::MN_major, B_ROW / 2,
|
||||
float, MemLayout::MN_major, MemLayout::MN_major, B_ROW / 2,
|
||||
B_COL, HEADDIM, /*leading_dim_a=*/B_ROW, /*leading_dim_b=*/0,
|
||||
/*load_accum=*/false,
|
||||
/*write_to_smem=*/true>(
|
||||
@@ -713,7 +701,7 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
|
||||
initialize_accum_regs<1>();
|
||||
|
||||
thread_block_gemm_single_tile<
|
||||
float, MemLayout::MN_major /*FIXME*/, MemLayout::MN_major, B_ROW / 2,
|
||||
float, MemLayout::MN_major, MemLayout::MN_major, B_ROW / 2,
|
||||
B_COL, HEADDIM, /*leading_dim_a=*/B_ROW, /*leading_dim_b=*/0,
|
||||
/*load_accum=*/false,
|
||||
/*write_to_smem=*/true>(
|
||||
|
||||
Reference in New Issue
Block a user