flash: Do accumulation of PV into O using the single_tile API
This commit is contained in:
@@ -371,10 +371,13 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
|
||||
threadblock_barrier(threadblock_id_in_cluster,
|
||||
warps_per_threadblock_per_core);
|
||||
|
||||
// GEMM I: S = Q*K
|
||||
thread_block_gemm_single_tile<float, MemLayout::MN_major, MemLayout::MN_major,
|
||||
/*load_accum=*/false,
|
||||
/*write_to_smem=*/true>(
|
||||
smem_Q, smem_K, smem_S, tid_in_threadblock, threads_per_threadblock,
|
||||
threadblocks_per_cluster, threadblock_id_in_cluster);
|
||||
smem_Q, smem_K, nullptr /*ignore accum*/, smem_S, tid_in_threadblock,
|
||||
threads_per_threadblock, threadblocks_per_cluster,
|
||||
threadblock_id_in_cluster);
|
||||
#endif
|
||||
|
||||
// protect GEMM result writes (smem_S) before softmax
|
||||
@@ -395,21 +398,24 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
|
||||
threadblock_barrier(threadblock_id_in_cluster,
|
||||
warps_per_threadblock_per_core);
|
||||
|
||||
// GEMM II: O = O + P*V
|
||||
|
||||
// clear out accumulators
|
||||
initialize_accum_regs<0>();
|
||||
initialize_accum_regs<1>();
|
||||
|
||||
load_tile_to_smem<float, MemLayout::MN_major, MemLayout::MN_major, BN, BK>(
|
||||
B_COL, 0 /*FIXME*/, 0 /*FIXME*/, gmem_V, smem_V, tid_in_threadblock);
|
||||
B_COL, 0, 0, gmem_V, smem_V, tid_in_threadblock);
|
||||
|
||||
threadblock_barrier(threadblock_id_in_cluster,
|
||||
warps_per_threadblock_per_core);
|
||||
|
||||
// FIXME: support MN_major for A for ideal performance
|
||||
thread_block_gemm_single_tile<float, MemLayout::K_major, MemLayout::MN_major,
|
||||
/*load_accum=*/false,
|
||||
/*write_to_smem=*/true>(
|
||||
smem_P, smem_V, gmem_O /*smem_O*/, tid_in_threadblock,
|
||||
threads_per_threadblock, threadblocks_per_cluster,
|
||||
smem_P, smem_V, smem_O, gmem_O /*smem_O*/,
|
||||
tid_in_threadblock, threads_per_threadblock, threadblocks_per_cluster,
|
||||
threadblock_id_in_cluster);
|
||||
|
||||
threadblock_barrier(threadblock_id_in_cluster,
|
||||
|
||||
Reference in New Issue
Block a user