flash: Enable skipping Q*K for larger dimensions

This commit is contained in:
Hansung Kim
2024-08-20 19:15:16 -07:00
parent 526c2bd334
commit dde0372769

View File

@@ -465,53 +465,65 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
float *gmem_tmp_e3 = reinterpret_cast<float *>(0xe3000000UL);
// "inner loop" along the columns of K^T
for (uint32_t tile_k = 0; tile_k < (dim_seqlen / B_COL); tile_k++) {
const uint32_t k_tiles = (dim_seqlen / B_COL);
for (uint32_t tile_k = 0; tile_k < k_tiles; tile_k++) {
// #define SKIP_GEMM
#ifndef SKIP_GEMM
// clear out accumulators
initialize_accum_regs<0>();
initialize_accum_regs<1>();
const float *tile_S = nullptr;
static_assert(B_ROW == B_COL, "currently only supports square tiles");
constexpr bool skip_gemm_qk = true;
if constexpr (!skip_gemm_qk) {
// clear out accumulators
initialize_accum_regs<0>();
initialize_accum_regs<1>();
// load Q
load_tile_to_smem<float, MemLayout::MN_major, MemLayout::MN_major, B_ROW,
HEADDIM>(
dim_seqlen, 0 /*FIXME: only work on first B_ROW rows of Q for now*/,
0 /* always 0 because dim_k == headdim */, gmem_Q, smem_Q,
tid_in_threadblock);
static_assert(B_ROW == B_COL, "currently only supports square tiles");
// load K
load_tile_to_smem<float, MemLayout::MN_major, MemLayout::MN_major, B_COL,
HEADDIM>(dim_seqlen, tile_k,
0 /* always 0 because dim_k == headdim */,
gmem_K, smem_K, tid_in_threadblock);
// load Q
load_tile_to_smem<float, MemLayout::MN_major, MemLayout::MN_major, B_ROW,
HEADDIM>(
dim_seqlen, 0 /*FIXME: only work on first B_ROW rows of Q for now*/,
0 /* always 0 because dim_k == headdim */, gmem_Q, smem_Q,
tid_in_threadblock);
// GMEM->SMEM and compute barrier
threadblock_barrier(threadblock_id_in_cluster,
warps_per_threadblock_per_core);
// load K
load_tile_to_smem<float, MemLayout::MN_major, MemLayout::MN_major, B_COL,
HEADDIM>(dim_seqlen, tile_k,
0 /* always 0 because dim_k == headdim */,
gmem_K, smem_K, tid_in_threadblock);
// GEMM I: S = Q*K
thread_block_gemm_single_tile<float, MemLayout::MN_major,
MemLayout::MN_major,
/*load_accum=*/false,
/*write_to_smem=*/true>(
smem_Q, smem_K, nullptr /*ignore accum*/, smem_S, tid_in_threadblock,
threads_per_threadblock, threadblocks_per_cluster,
threadblock_id_in_cluster);
// GMEM->SMEM and compute barrier
threadblock_barrier(threadblock_id_in_cluster,
warps_per_threadblock_per_core);
// GEMM I: S = Q*K
thread_block_gemm_single_tile<float, MemLayout::MN_major,
MemLayout::MN_major,
/*load_accum=*/false,
/*write_to_smem=*/true>(
smem_Q, smem_K, nullptr /*ignore accum*/, smem_S, tid_in_threadblock,
threads_per_threadblock, threadblocks_per_cluster,
threadblock_id_in_cluster);
// tile_S = smem_S;
} else {
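// load the Q*K tile (S) straight into smem_S instead of computing it;
// gmem_Q stands in for gmem_S in this path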
load_tile_to_smem<float, MemLayout::K_major, MemLayout::K_major, B_COL,
HEADDIM>(dim_seqlen, 0, tile_k, gmem_Q /*=gmem_S*/,
smem_S, tid_in_threadblock);
// the above should be equivalent to:
// load_tile_to_smem<float, MemLayout::MN_major, MemLayout::MN_major, B_COL,
// HEADDIM>(dim_seqlen, tile_k, 0, gmem_Q /*=gmem_S*/,
// smem_S, tid_in_threadblock);
// tile_S = reinterpret_cast<float *>(arg->addr_q);
}
// protect smem_S writes (GEMM result or direct tile load) before softmax
threadblock_barrier(threadblock_id_in_cluster,
warps_per_threadblock_per_core);
const float *tile_S = (float *)smem_S;
#else
float *tile_S = (float *)arg->addr_q;
#endif
thread_block_online_softmax(
tile_S, smem_O, smem_P, tid_in_threadblock, threads_per_threadblock,
smem_S, smem_O, smem_P, tid_in_threadblock, threads_per_threadblock,
threadblocks_per_cluster, threadblock_id_in_cluster, smem_scratchpad,
smem_rowmax, smem_rowsum);
@@ -535,7 +547,7 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
threads_per_threadblock,
threadblocks_per_cluster,
threadblock_id_in_cluster);
} else if (tile_k == 1) {
} else if (tile_k == k_tiles - 1) {
thread_block_copy_tile(
smem_P, gmem_tmp_d1, tid_in_threadblock, threads_per_threadblock,
threadblocks_per_cluster, threadblock_id_in_cluster);
@@ -565,8 +577,8 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
// V dimension is [seqlen, headdim], stored N(headdim)-major
load_tile_to_smem<float, MemLayout::MN_major, MemLayout::MN_major, B_COL,
HEADDIM>(
HEADDIM, 0 /* 0 because always reads the full N-dimension */,
tile_k * B_COL, gmem_V, smem_V, tid_in_threadblock);
HEADDIM, 0 /* 0 because always reads the full N-dimension */, tile_k,
gmem_V, smem_V, tid_in_threadblock);
threadblock_barrier(threadblock_id_in_cluster,
warps_per_threadblock_per_core);
@@ -588,7 +600,7 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
thread_block_copy_tile(
smem_O, gmem_tmp_d4, tid_in_threadblock, threads_per_threadblock,
threadblocks_per_cluster, threadblock_id_in_cluster);
} else if (tile_k == 1) {
} else if (tile_k == k_tiles - 1) {
thread_block_copy_tile(
smem_O, gmem_tmp_d5, tid_in_threadblock, threads_per_threadblock,
threadblocks_per_cluster, threadblock_id_in_cluster);