flash: Enable skipping Q*K for larger dimensions
This commit is contained in:
@@ -465,53 +465,65 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
|
|||||||
float *gmem_tmp_e3 = reinterpret_cast<float *>(0xe3000000UL);
|
float *gmem_tmp_e3 = reinterpret_cast<float *>(0xe3000000UL);
|
||||||
|
|
||||||
// "inner loop" along the columns of K^T
|
// "inner loop" along the columns of K^T
|
||||||
for (uint32_t tile_k = 0; tile_k < (dim_seqlen / B_COL); tile_k++) {
|
const uint32_t k_tiles = (dim_seqlen / B_COL);
|
||||||
|
for (uint32_t tile_k = 0; tile_k < k_tiles; tile_k++) {
|
||||||
|
|
||||||
// #define SKIP_GEMM
|
const float *tile_S = nullptr;
|
||||||
#ifndef SKIP_GEMM
|
|
||||||
// clear out accumulators
|
|
||||||
initialize_accum_regs<0>();
|
|
||||||
initialize_accum_regs<1>();
|
|
||||||
|
|
||||||
static_assert(B_ROW == B_COL, "currently only supports square tiles");
|
constexpr bool skip_gemm_qk = true;
|
||||||
|
if constexpr (!skip_gemm_qk) {
|
||||||
|
// clear out accumulators
|
||||||
|
initialize_accum_regs<0>();
|
||||||
|
initialize_accum_regs<1>();
|
||||||
|
|
||||||
// load Q
|
static_assert(B_ROW == B_COL, "currently only supports square tiles");
|
||||||
load_tile_to_smem<float, MemLayout::MN_major, MemLayout::MN_major, B_ROW,
|
|
||||||
HEADDIM>(
|
|
||||||
dim_seqlen, 0 /*FIXME: only work on first B_ROW rows of Q for now*/,
|
|
||||||
0 /* always 0 because dim_k == headdim */, gmem_Q, smem_Q,
|
|
||||||
tid_in_threadblock);
|
|
||||||
|
|
||||||
// load K
|
// load Q
|
||||||
load_tile_to_smem<float, MemLayout::MN_major, MemLayout::MN_major, B_COL,
|
load_tile_to_smem<float, MemLayout::MN_major, MemLayout::MN_major, B_ROW,
|
||||||
HEADDIM>(dim_seqlen, tile_k,
|
HEADDIM>(
|
||||||
0 /* always 0 because dim_k == headdim */,
|
dim_seqlen, 0 /*FIXME: only work on first B_ROW rows of Q for now*/,
|
||||||
gmem_K, smem_K, tid_in_threadblock);
|
0 /* always 0 because dim_k == headdim */, gmem_Q, smem_Q,
|
||||||
|
tid_in_threadblock);
|
||||||
|
|
||||||
// GMEM->SMEM and compute barrier
|
// load K
|
||||||
threadblock_barrier(threadblock_id_in_cluster,
|
load_tile_to_smem<float, MemLayout::MN_major, MemLayout::MN_major, B_COL,
|
||||||
warps_per_threadblock_per_core);
|
HEADDIM>(dim_seqlen, tile_k,
|
||||||
|
0 /* always 0 because dim_k == headdim */,
|
||||||
|
gmem_K, smem_K, tid_in_threadblock);
|
||||||
|
|
||||||
// GEMM I: S = Q*K
|
// GMEM->SMEM and compute barrier
|
||||||
thread_block_gemm_single_tile<float, MemLayout::MN_major,
|
threadblock_barrier(threadblock_id_in_cluster,
|
||||||
MemLayout::MN_major,
|
warps_per_threadblock_per_core);
|
||||||
/*load_accum=*/false,
|
|
||||||
/*write_to_smem=*/true>(
|
// GEMM I: S = Q*K
|
||||||
smem_Q, smem_K, nullptr /*ignore accum*/, smem_S, tid_in_threadblock,
|
thread_block_gemm_single_tile<float, MemLayout::MN_major,
|
||||||
threads_per_threadblock, threadblocks_per_cluster,
|
MemLayout::MN_major,
|
||||||
threadblock_id_in_cluster);
|
/*load_accum=*/false,
|
||||||
|
/*write_to_smem=*/true>(
|
||||||
|
smem_Q, smem_K, nullptr /*ignore accum*/, smem_S, tid_in_threadblock,
|
||||||
|
threads_per_threadblock, threadblocks_per_cluster,
|
||||||
|
threadblock_id_in_cluster);
|
||||||
|
|
||||||
|
// tile_S = smem_S;
|
||||||
|
} else {
|
||||||
|
// load Q*K
|
||||||
|
load_tile_to_smem<float, MemLayout::K_major, MemLayout::K_major, B_COL,
|
||||||
|
HEADDIM>(dim_seqlen, 0, tile_k, gmem_Q /*=gmem_S*/,
|
||||||
|
smem_S, tid_in_threadblock);
|
||||||
|
// the above should be equivalent to:
|
||||||
|
// load_tile_to_smem<float, MemLayout::MN_major, MemLayout::MN_major, B_COL,
|
||||||
|
// HEADDIM>(dim_seqlen, tile_k, 0, gmem_Q /*=gmem_S*/,
|
||||||
|
// smem_S, tid_in_threadblock);
|
||||||
|
|
||||||
|
// tile_S = reinterpret_cast<float *>(arg->addr_q);
|
||||||
|
}
|
||||||
|
|
||||||
// protect GEMM result writes (smem_S) before softmax
|
// protect GEMM result writes (smem_S) before softmax
|
||||||
threadblock_barrier(threadblock_id_in_cluster,
|
threadblock_barrier(threadblock_id_in_cluster,
|
||||||
warps_per_threadblock_per_core);
|
warps_per_threadblock_per_core);
|
||||||
|
|
||||||
const float *tile_S = (float *)smem_S;
|
|
||||||
#else
|
|
||||||
float *tile_S = (float *)arg->addr_q;
|
|
||||||
#endif
|
|
||||||
|
|
||||||
thread_block_online_softmax(
|
thread_block_online_softmax(
|
||||||
tile_S, smem_O, smem_P, tid_in_threadblock, threads_per_threadblock,
|
smem_S, smem_O, smem_P, tid_in_threadblock, threads_per_threadblock,
|
||||||
threadblocks_per_cluster, threadblock_id_in_cluster, smem_scratchpad,
|
threadblocks_per_cluster, threadblock_id_in_cluster, smem_scratchpad,
|
||||||
smem_rowmax, smem_rowsum);
|
smem_rowmax, smem_rowsum);
|
||||||
|
|
||||||
@@ -535,7 +547,7 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
|
|||||||
threads_per_threadblock,
|
threads_per_threadblock,
|
||||||
threadblocks_per_cluster,
|
threadblocks_per_cluster,
|
||||||
threadblock_id_in_cluster);
|
threadblock_id_in_cluster);
|
||||||
} else if (tile_k == 1) {
|
} else if (tile_k == k_tiles - 1) {
|
||||||
thread_block_copy_tile(
|
thread_block_copy_tile(
|
||||||
smem_P, gmem_tmp_d1, tid_in_threadblock, threads_per_threadblock,
|
smem_P, gmem_tmp_d1, tid_in_threadblock, threads_per_threadblock,
|
||||||
threadblocks_per_cluster, threadblock_id_in_cluster);
|
threadblocks_per_cluster, threadblock_id_in_cluster);
|
||||||
@@ -565,8 +577,8 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
|
|||||||
// V dimension is [seqlen, headdim], stored N(headdim)-major
|
// V dimension is [seqlen, headdim], stored N(headdim)-major
|
||||||
load_tile_to_smem<float, MemLayout::MN_major, MemLayout::MN_major, B_COL,
|
load_tile_to_smem<float, MemLayout::MN_major, MemLayout::MN_major, B_COL,
|
||||||
HEADDIM>(
|
HEADDIM>(
|
||||||
HEADDIM, 0 /* 0 because always reads the full N-dimension */,
|
HEADDIM, 0 /* 0 because always reads the full N-dimension */, tile_k,
|
||||||
tile_k * B_COL, gmem_V, smem_V, tid_in_threadblock);
|
gmem_V, smem_V, tid_in_threadblock);
|
||||||
|
|
||||||
threadblock_barrier(threadblock_id_in_cluster,
|
threadblock_barrier(threadblock_id_in_cluster,
|
||||||
warps_per_threadblock_per_core);
|
warps_per_threadblock_per_core);
|
||||||
@@ -588,7 +600,7 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
|
|||||||
thread_block_copy_tile(
|
thread_block_copy_tile(
|
||||||
smem_O, gmem_tmp_d4, tid_in_threadblock, threads_per_threadblock,
|
smem_O, gmem_tmp_d4, tid_in_threadblock, threads_per_threadblock,
|
||||||
threadblocks_per_cluster, threadblock_id_in_cluster);
|
threadblocks_per_cluster, threadblock_id_in_cluster);
|
||||||
} else if (tile_k == 1) {
|
} else if (tile_k == k_tiles - 1) {
|
||||||
thread_block_copy_tile(
|
thread_block_copy_tile(
|
||||||
smem_O, gmem_tmp_d5, tid_in_threadblock, threads_per_threadblock,
|
smem_O, gmem_tmp_d5, tid_in_threadblock, threads_per_threadblock,
|
||||||
threadblocks_per_cluster, threadblock_id_in_cluster);
|
threadblocks_per_cluster, threadblock_id_in_cluster);
|
||||||
|
|||||||
Reference in New Issue
Block a user