flash: Enable skipping Q*K for larger dimensions
This commit is contained in:
@@ -465,53 +465,65 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
|
||||
float *gmem_tmp_e3 = reinterpret_cast<float *>(0xe3000000UL);
|
||||
|
||||
// "inner loop" along the columns of K^T
|
||||
for (uint32_t tile_k = 0; tile_k < (dim_seqlen / B_COL); tile_k++) {
|
||||
const uint32_t k_tiles = (dim_seqlen / B_COL);
|
||||
for (uint32_t tile_k = 0; tile_k < k_tiles; tile_k++) {
|
||||
|
||||
// #define SKIP_GEMM
|
||||
#ifndef SKIP_GEMM
|
||||
// clear out accumulators
|
||||
initialize_accum_regs<0>();
|
||||
initialize_accum_regs<1>();
|
||||
const float *tile_S = nullptr;
|
||||
|
||||
static_assert(B_ROW == B_COL, "currently only supports square tiles");
|
||||
constexpr bool skip_gemm_qk = true;
|
||||
if constexpr (!skip_gemm_qk) {
|
||||
// clear out accumulators
|
||||
initialize_accum_regs<0>();
|
||||
initialize_accum_regs<1>();
|
||||
|
||||
// load Q
|
||||
load_tile_to_smem<float, MemLayout::MN_major, MemLayout::MN_major, B_ROW,
|
||||
HEADDIM>(
|
||||
dim_seqlen, 0 /*FIXME: only work on first B_ROW rows of Q for now*/,
|
||||
0 /* always 0 because dim_k == headdim */, gmem_Q, smem_Q,
|
||||
tid_in_threadblock);
|
||||
static_assert(B_ROW == B_COL, "currently only supports square tiles");
|
||||
|
||||
// load K
|
||||
load_tile_to_smem<float, MemLayout::MN_major, MemLayout::MN_major, B_COL,
|
||||
HEADDIM>(dim_seqlen, tile_k,
|
||||
0 /* always 0 because dim_k == headdim */,
|
||||
gmem_K, smem_K, tid_in_threadblock);
|
||||
// load Q
|
||||
load_tile_to_smem<float, MemLayout::MN_major, MemLayout::MN_major, B_ROW,
|
||||
HEADDIM>(
|
||||
dim_seqlen, 0 /*FIXME: only work on first B_ROW rows of Q for now*/,
|
||||
0 /* always 0 because dim_k == headdim */, gmem_Q, smem_Q,
|
||||
tid_in_threadblock);
|
||||
|
||||
// GMEM->SMEM and compute barrier
|
||||
threadblock_barrier(threadblock_id_in_cluster,
|
||||
warps_per_threadblock_per_core);
|
||||
// load K
|
||||
load_tile_to_smem<float, MemLayout::MN_major, MemLayout::MN_major, B_COL,
|
||||
HEADDIM>(dim_seqlen, tile_k,
|
||||
0 /* always 0 because dim_k == headdim */,
|
||||
gmem_K, smem_K, tid_in_threadblock);
|
||||
|
||||
// GEMM I: S = Q*K
|
||||
thread_block_gemm_single_tile<float, MemLayout::MN_major,
|
||||
MemLayout::MN_major,
|
||||
/*load_accum=*/false,
|
||||
/*write_to_smem=*/true>(
|
||||
smem_Q, smem_K, nullptr /*ignore accum*/, smem_S, tid_in_threadblock,
|
||||
threads_per_threadblock, threadblocks_per_cluster,
|
||||
threadblock_id_in_cluster);
|
||||
// GMEM->SMEM and compute barrier
|
||||
threadblock_barrier(threadblock_id_in_cluster,
|
||||
warps_per_threadblock_per_core);
|
||||
|
||||
// GEMM I: S = Q*K
|
||||
thread_block_gemm_single_tile<float, MemLayout::MN_major,
|
||||
MemLayout::MN_major,
|
||||
/*load_accum=*/false,
|
||||
/*write_to_smem=*/true>(
|
||||
smem_Q, smem_K, nullptr /*ignore accum*/, smem_S, tid_in_threadblock,
|
||||
threads_per_threadblock, threadblocks_per_cluster,
|
||||
threadblock_id_in_cluster);
|
||||
|
||||
// tile_S = smem_S;
|
||||
} else {
|
||||
// load Q*K
|
||||
load_tile_to_smem<float, MemLayout::K_major, MemLayout::K_major, B_COL,
|
||||
HEADDIM>(dim_seqlen, 0, tile_k, gmem_Q /*=gmem_S*/,
|
||||
smem_S, tid_in_threadblock);
|
||||
// the above should be equivalent to:
|
||||
// load_tile_to_smem<float, MemLayout::MN_major, MemLayout::MN_major, B_COL,
|
||||
// HEADDIM>(dim_seqlen, tile_k, 0, gmem_Q /*=gmem_S*/,
|
||||
// smem_S, tid_in_threadblock);
|
||||
|
||||
// tile_S = reinterpret_cast<float *>(arg->addr_q);
|
||||
}
|
||||
|
||||
// protect GEMM result writes (smem_S) before softmax
|
||||
threadblock_barrier(threadblock_id_in_cluster,
|
||||
warps_per_threadblock_per_core);
|
||||
|
||||
const float *tile_S = (float *)smem_S;
|
||||
#else
|
||||
float *tile_S = (float *)arg->addr_q;
|
||||
#endif
|
||||
|
||||
thread_block_online_softmax(
|
||||
tile_S, smem_O, smem_P, tid_in_threadblock, threads_per_threadblock,
|
||||
smem_S, smem_O, smem_P, tid_in_threadblock, threads_per_threadblock,
|
||||
threadblocks_per_cluster, threadblock_id_in_cluster, smem_scratchpad,
|
||||
smem_rowmax, smem_rowsum);
|
||||
|
||||
@@ -535,7 +547,7 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
|
||||
threads_per_threadblock,
|
||||
threadblocks_per_cluster,
|
||||
threadblock_id_in_cluster);
|
||||
} else if (tile_k == 1) {
|
||||
} else if (tile_k == k_tiles - 1) {
|
||||
thread_block_copy_tile(
|
||||
smem_P, gmem_tmp_d1, tid_in_threadblock, threads_per_threadblock,
|
||||
threadblocks_per_cluster, threadblock_id_in_cluster);
|
||||
@@ -565,8 +577,8 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
|
||||
// V dimension is [seqlen, headdim], stored N(headdim)-major
|
||||
load_tile_to_smem<float, MemLayout::MN_major, MemLayout::MN_major, B_COL,
|
||||
HEADDIM>(
|
||||
HEADDIM, 0 /* 0 because always reads the full N-dimension */,
|
||||
tile_k * B_COL, gmem_V, smem_V, tid_in_threadblock);
|
||||
HEADDIM, 0 /* 0 because always reads the full N-dimension */, tile_k,
|
||||
gmem_V, smem_V, tid_in_threadblock);
|
||||
|
||||
threadblock_barrier(threadblock_id_in_cluster,
|
||||
warps_per_threadblock_per_core);
|
||||
@@ -588,7 +600,7 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
|
||||
thread_block_copy_tile(
|
||||
smem_O, gmem_tmp_d4, tid_in_threadblock, threads_per_threadblock,
|
||||
threadblocks_per_cluster, threadblock_id_in_cluster);
|
||||
} else if (tile_k == 1) {
|
||||
} else if (tile_k == k_tiles - 1) {
|
||||
thread_block_copy_tile(
|
||||
smem_O, gmem_tmp_d5, tid_in_threadblock, threads_per_threadblock,
|
||||
threadblocks_per_cluster, threadblock_id_in_cluster);
|
||||
|
||||
Reference in New Issue
Block a user