flash: Do accumulation of PV into O using the single_tile API

This commit is contained in:
Hansung Kim
2024-08-19 18:03:06 -07:00
parent 03c61d72ff
commit 64e48de8af

View File

@@ -371,10 +371,13 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
threadblock_barrier(threadblock_id_in_cluster,
warps_per_threadblock_per_core);
// GEMM I: S = Q*K
thread_block_gemm_single_tile<float, MemLayout::MN_major, MemLayout::MN_major,
/*load_accum=*/false,
/*write_to_smem=*/true>(
smem_Q, smem_K, smem_S, tid_in_threadblock, threads_per_threadblock,
threadblocks_per_cluster, threadblock_id_in_cluster);
smem_Q, smem_K, nullptr /*ignore accum*/, smem_S, tid_in_threadblock,
threads_per_threadblock, threadblocks_per_cluster,
threadblock_id_in_cluster);
#endif
// protect GEMM result writes (smem_S) before softmax
@@ -395,21 +398,24 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
threadblock_barrier(threadblock_id_in_cluster,
warps_per_threadblock_per_core);
// GEMM II: O = O + P*V
// clear out accumulators
initialize_accum_regs<0>();
initialize_accum_regs<1>();
load_tile_to_smem<float, MemLayout::MN_major, MemLayout::MN_major, BN, BK>(
B_COL, 0 /*FIXME*/, 0 /*FIXME*/, gmem_V, smem_V, tid_in_threadblock);
B_COL, 0, 0, gmem_V, smem_V, tid_in_threadblock);
threadblock_barrier(threadblock_id_in_cluster,
warps_per_threadblock_per_core);
// FIXME: support MN_major for A for ideal performance
thread_block_gemm_single_tile<float, MemLayout::K_major, MemLayout::MN_major,
/*load_accum=*/false,
/*write_to_smem=*/true>(
smem_P, smem_V, gmem_O /*smem_O*/, tid_in_threadblock,
threads_per_threadblock, threadblocks_per_cluster,
smem_P, smem_V, smem_O, gmem_O /*smem_O*/,
tid_in_threadblock, threads_per_threadblock, threadblocks_per_cluster,
threadblock_id_in_cluster);
threadblock_barrier(threadblock_id_in_cluster,