diff --git a/tests/regression/flash_attention/kernel.cpp b/tests/regression/flash_attention/kernel.cpp index d9ad63f9..61c73cc5 100644 --- a/tests/regression/flash_attention/kernel.cpp +++ b/tests/regression/flash_attention/kernel.cpp @@ -623,22 +623,10 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { static_assert(B_ROW == B_COL, "currently only supports square tiles"); // load Q; this stays in SMEM for the entire loop - if constexpr (!WARP_SPECIALIZED) { - load_tile_to_smem( - dim_seqlen, warpgroup_id, 0 /* dim_k == headdim */, gmem_Q, smem_Q, - tid_in_warpgroup); - } else { - // FIXME: transpose to K-major in SMEM for correctness - // load_tile_to_smem( - // HEADDIM, warpgroup_id, 0 /* dim_k == headdim */, gmem_Q, smem_Q, - // tid_in_warpgroup); - load_tile_to_smem( - dim_seqlen, warpgroup_id, 0 /* dim_k == headdim */, gmem_Q, smem_Q, - tid_in_warpgroup); - } + load_tile_to_smem( + dim_seqlen, warpgroup_id, 0 /* dim_k == headdim */, gmem_Q, smem_Q, + tid_in_warpgroup); // load K load_tile_to_smem( @@ -713,7 +701,7 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { initialize_accum_regs<1>(); thread_block_gemm_single_tile< - float, MemLayout::MN_major /*FIXME*/, MemLayout::MN_major, B_ROW / 2, + float, MemLayout::MN_major, MemLayout::MN_major, B_ROW / 2, B_COL, HEADDIM, /*leading_dim_a=*/B_ROW, /*leading_dim_b=*/0, /*load_accum=*/false, /*write_to_smem=*/true>(