flash: Restructure for full software pipelining

Results verified up to the P and O tiles before the PV GEMM; the iteration index used for the V load still needs to be fixed.
This commit is contained in:
Hansung Kim
2024-09-08 18:45:32 -07:00
parent cdb8377b62
commit 8efa6868ea
2 changed files with 239 additions and 206 deletions

View File

@@ -184,10 +184,10 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
constexpr uint32_t global_barrier_id = NUM_WARPS - 1; // arbitrary
// delay warpgroup 0 by 1 iteration to do ping-pong scheduling
if (WARP_SPECIALIZED && warpgroup_id == 1) {
threadblock_barrier(global_barrier_id, warps_per_threadblock_per_core);
}
// // delay warpgroup 0 by 1 iteration to do ping-pong scheduling
// if (WARP_SPECIALIZED && warpgroup_id == 1) {
// threadblock_barrier(global_barrier_id, warps_per_threadblock_per_core);
// }
static_assert(!GEMMINI_DMA || Q_IS_K_MAJOR,
"DMA code assumes Q matrix is stored K-major");
@@ -196,6 +196,9 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
constexpr uint32_t skips =
loop_matmul_skips(/*skip_lda=*/0, /*skip_ldb=*/0, /*skip_ldd=*/1,
/*skip_ex=*/1, /*skip_stc=*/1);
constexpr uint32_t skips_only_a =
loop_matmul_skips(/*skip_lda=*/0, /*skip_ldb=*/1, /*skip_ldd=*/1,
/*skip_ex=*/1, /*skip_stc=*/1);
constexpr uint32_t skips_mvout_spad =
loop_matmul_skips(/*skip_lda=*/1, /*skip_ldb=*/1, /*skip_ldd=*/1,
/*skip_ex=*/1, /*skip_stc=*/0);
@@ -248,6 +251,8 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
#define GEMMINI_DMA_CISC
#ifdef GEMMINI_DMA_CISC
// the target addresses of this should match with spad_addr_Q0 and
// spad_addr_K0 set in this kernel
GEMMINI_CISC_CMD_I(10);
gemmini_fence();
#else
@@ -292,15 +297,30 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
// "inner loop" along the columns of K^T
const uint32_t k_tiles = (dim_seqlen / B_COL);
for (uint32_t tile_k = 0; tile_k < k_tiles; tile_k++) {
for (uint32_t tile_k = 0; tile_k < k_tiles + 2 /*pipeline latency*/;
tile_k++) {
if constexpr (DEBUG) {
// barrier for debugging
threadblock_barrier(global_barrier_id, warps_per_threadblock_per_core);
}
// select the correct double buffer by tile iteration
// FIXME do correct double buffering
float *smem_Q = (tile_k & 1) ? smem_Q1 : smem_Q0;
float *smem_K = (tile_k & 1) ? smem_K1 : smem_K0;
float *smem_V = (tile_k & 1) ? smem_V1 : smem_V0;
float *smem_S = (tile_k & 1) ? smem_S1 : smem_S0;
float *smem_P = (tile_k & 1) ? smem_P1 : smem_P0;
float *smem_O = (tile_k & 1) ? smem_O1 : smem_O0;
// all iterations work on the same Q row tile; no ping-pong necessary
asm volatile ("dbuf_sel_start_%=:" :: );
// FIXME speedup by doing arithmetic
float *smem_Q = smem_Q0;
float *smem_K_consume = (tile_k & 1) ? smem_K1 : smem_K0;
float *smem_K_produce = (tile_k & 1) ? smem_K0 : smem_K1;
float *smem_V_consume = (tile_k & 1) ? smem_V1 : smem_V0;
float *smem_V_produce = (tile_k & 1) ? smem_V0 : smem_V1;
float *smem_S_consume = (tile_k & 1) ? smem_S1 : smem_S0;
float *smem_S_produce = (tile_k & 1) ? smem_S0 : smem_S1;
float *smem_P_consume = (tile_k & 1) ? smem_P1 : smem_P0;
float *smem_P_produce = (tile_k & 1) ? smem_P0 : smem_P1;
// O tile is sequentially updated at every iteration; no ping-pong
// necessary
float *smem_O = smem_O0;
// FIXME: O_row_scale/rowmax/rowsum/spad shouldn't really need ping-pong
float *smem_O_row_scale =
(tile_k & 1) ? smem_O_row_scale_1 : smem_O_row_scale_0;
float *smem_rowmax = (tile_k & 1) ? smem_rowmax_1 : smem_rowmax_0;
@@ -308,28 +328,111 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
float *smem_scratchpad =
(tile_k & 1) ? smem_scratchpad_1 : smem_scratchpad_0;
const auto spad_addr_Q = (tile_k & 1) ? spad_addr_Q1 : spad_addr_Q0;
const auto spad_addr_K = (tile_k & 1) ? spad_addr_K1 : spad_addr_K0;
const auto spad_addr_V = (tile_k & 1) ? spad_addr_V1 : spad_addr_V0;
const auto spad_addr_S = (tile_k & 1) ? spad_addr_S1 : spad_addr_S0;
const auto spad_addr_P = (tile_k & 1) ? spad_addr_P1 : spad_addr_P0;
const auto spad_addr_Q = spad_addr_Q0;
const auto spad_addr_K_consume = (tile_k & 1) ? spad_addr_K1 : spad_addr_K0;
const auto spad_addr_K_produce = (tile_k & 1) ? spad_addr_K0 : spad_addr_K1;
const auto spad_addr_V_consume = (tile_k & 1) ? spad_addr_V1 : spad_addr_V0;
const auto spad_addr_V_produce = (tile_k & 1) ? spad_addr_V0 : spad_addr_V1;
const auto spad_addr_S_consume = (tile_k & 1) ? spad_addr_S1 : spad_addr_S0;
const auto spad_addr_S_produce = (tile_k & 1) ? spad_addr_S0 : spad_addr_S1;
const auto spad_addr_P_consume = (tile_k & 1) ? spad_addr_P1 : spad_addr_P0;
const auto spad_addr_P_produce = (tile_k & 1) ? spad_addr_P0 : spad_addr_P1;
const auto spad_addr_O = spad_addr_O0; // NOTE: there's only single O tile
asm volatile ("dbuf_sel_end_%=:" :: );
// GEMM I: S = Q*K
//
asm volatile("gemm_qk_start_%=:" ::);
// GEMM II: O = O + P*V
// --------------------
// This is done *before* GEMM I in the software pipeline, working on the
// online softmax result tile from the previous iteration
if (tid_in_warpgroup == 0) {
if (tile_k == 0) {
if (tile_k >= 2) // delay by 2 iters for pipelining
{
const uint32_t tile_k_ = tile_k - 2;
asm volatile("gemm_pv_start_%=:" ::);
if (tid_in_warpgroup == 0) {
#if 0
if (tile_k_ == 0) {
gemmini_fence();
GEMMINI_CISC_CMD_I(0);
} else if (tile_k & 1) {
} else if (tile_k_ & 1) {
gemmini_fence();
GEMMINI_CISC_CMD_I(2);
} else {
gemmini_fence();
GEMMINI_CISC_CMD_I(1);
}
#else
// do matmul
// among other things, this also configures CONFIG_BOUNDS so that the
// DMA knows the full matrix dimensions
sp_tiled_matmul_full_spad_ws(
spad_addr_P_consume, spad_addr_V_consume,
/*spad_D=*/0, /*spad_C=*/spad_addr_O,
/*I=*/(B_ROW / DIM), /*J=*/(HEADDIM / DIM), /*K=*/(B_COL / DIM),
/*pad_I=*/0, /*pad_J=*/0, /*pad_K=*/0,
/*a_transpose=*/0, /*b_transpose=*/0, /*full_C=*/0, /*low_D=*/0,
/*acc=*/0, /*act=*/NO_ACTIVATION, /*skips=*/skips_matmul);
#endif
gemmini_fence();
gemmini_fence();
gemmini_fence();
gemmini_fence();
// mvout to SMEM
// GEMMINI_CISC_CMD_I(9);
sp_tiled_matmul_full_spad_ws(
/*spad_A=*/spad_addr_P_consume, /*spad_B=*/spad_addr_V_consume,
/*spad_D=*/0, /*spad_C=*/spad_addr_O,
/*I=*/(B_ROW / DIM), /*J=*/(HEADDIM / DIM), /*K=*/(B_COL / DIM),
/*pad_I=*/0, /*pad_J=*/0, /*pad_K=*/0,
/*a_transpose=*/0, /*b_transpose=*/0, /*full_C=*/0, /*low_D=*/0,
/*acc=*/0, /*act=*/NO_ACTIVATION, /*skips=*/skips_mvout_spad);
gemmini_fence();
if constexpr (DEBUG) {
// for copy-out to GMEM
gemmini_fence();
}
}
// reconverge from mmio divergence
threadblock_barrier(warpgroup_id_in_cluster,
warps_per_warpgroup_per_core);
asm volatile("gemm_pv_finish_%=:" ::);
if constexpr (DEBUG) {
if (warpgroup_id == 0) {
// O after PV
if (tile_k_ == 0) {
thread_block_copy_tile<B_ROW, HEADDIM, GEMMINI_DMA>(
smem_O, gmem_tmp_d6, tid_in_warpgroup, threads_per_warpgroup,
warpgroup_id_in_cluster);
} else if (tile_k_ == 1) {
thread_block_copy_tile<B_ROW, HEADDIM, GEMMINI_DMA>(
smem_O, gmem_tmp_d7, tid_in_warpgroup, threads_per_warpgroup,
warpgroup_id_in_cluster);
}
threadblock_barrier(warpgroup_id_in_cluster,
warps_per_warpgroup_per_core);
}
}
}
// GEMM I: S = Q*K
//
asm volatile("gemm_qk_start_%=:" ::);
if (tid_in_warpgroup == 0) {
// 0,2,.: opcode 0 (quartile 0/2, no accum)
// 1,3,.: opcode 3 (quartile 1/3, no accum)
const uint32_t opcode = 3 * (tile_k & 1);
gemmini_fence();
GEMMINI_CISC_CMD_I(opcode);
gemmini_fence();
gemmini_fence();
@@ -352,8 +455,8 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
// mvout to SMEM
// GEMMINI_CISC_CMD_I(9);
sp_tiled_matmul_full_spad_ws(
/*spad_A=*/spad_addr_Q, /*spad_B=*/spad_addr_K,
/*spad_D=*/0, /*spad_C=*/spad_addr_S,
/*spad_A=*/spad_addr_Q, /*spad_B=*/spad_addr_K_consume,
/*spad_D=*/0, /*spad_C=*/spad_addr_S_produce,
/*I=*/(B_ROW / DIM), /*J=*/(B_COL / DIM), /*K=*/(HEADDIM / DIM),
/*pad_I=*/0, /*pad_J=*/0, /*pad_K=*/0,
/*a_transpose=*/0, /*b_transpose=*/0, /*full_C=*/0, /*low_D=*/0,
@@ -375,11 +478,11 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
if (warpgroup_id == 0) {
if (tile_k == 0) {
thread_block_copy_tile<B_ROW, B_COL, GEMMINI_DMA>(
smem_S, gmem_tmp_d0, tid_in_warpgroup, threads_per_warpgroup,
smem_S_produce, gmem_tmp_d0, tid_in_warpgroup, threads_per_warpgroup,
warpgroup_id_in_cluster);
} else if (tile_k == 1) {
thread_block_copy_tile<B_ROW, B_COL, GEMMINI_DMA>(
smem_S, gmem_tmp_d1, tid_in_warpgroup, threads_per_warpgroup,
smem_S_produce, gmem_tmp_d1, tid_in_warpgroup, threads_per_warpgroup,
warpgroup_id_in_cluster);
}
@@ -388,39 +491,76 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
}
}
// inter-warpgroup barrier before online softmax
threadblock_barrier(global_barrier_id, warps_per_threadblock_per_core);
if (tile_k >= 1) // delay by 1 iters for pipelining
{
const uint32_t tile_k_ = tile_k - 1;
// Online softmax
//
thread_block_online_softmax</*block_row_major=*/GEMMINI_DMA>(
smem_S, smem_P, tid_in_warpgroup, threads_per_warpgroup,
warpgroup_id_in_cluster, smem_scratchpad, smem_rowmax, smem_rowsum,
smem_O_row_scale);
// Online softmax
//
thread_block_online_softmax</*block_row_major=*/GEMMINI_DMA>(
smem_S_consume, smem_P_produce, tid_in_warpgroup,
threads_per_warpgroup, warpgroup_id_in_cluster, smem_scratchpad,
smem_rowmax, smem_rowsum, smem_O_row_scale);
// FIXME: unnecessary?
threadblock_barrier(warpgroup_id_in_cluster, warps_per_warpgroup_per_core);
threadblock_barrier(warpgroup_id_in_cluster,
warps_per_warpgroup_per_core);
if constexpr (DEBUG) {
if (warpgroup_id == 0) {
if (tile_k == 0) {
thread_block_copy_rowmax(smem_rowmax, gmem_tmp_e0, tid_in_warpgroup,
threads_per_warpgroup,
warpgroup_id_in_cluster);
thread_block_copy_rowmax(smem_rowsum, gmem_tmp_e2, tid_in_warpgroup,
threads_per_warpgroup,
warpgroup_id_in_cluster);
} else if (tile_k == 1) {
thread_block_copy_rowmax(smem_rowmax, gmem_tmp_e1, tid_in_warpgroup,
threads_per_warpgroup,
warpgroup_id_in_cluster);
thread_block_copy_rowmax(smem_rowsum, gmem_tmp_e3, tid_in_warpgroup,
threads_per_warpgroup,
warpgroup_id_in_cluster);
if constexpr (DEBUG) {
if (warpgroup_id == 0) {
if (tile_k_ == 0) {
thread_block_copy_rowmax(smem_rowmax, gmem_tmp_e0, tid_in_warpgroup,
threads_per_warpgroup,
warpgroup_id_in_cluster);
thread_block_copy_rowmax(smem_rowsum, gmem_tmp_e2, tid_in_warpgroup,
threads_per_warpgroup,
warpgroup_id_in_cluster);
} else if (tile_k_ == 1) {
thread_block_copy_rowmax(smem_rowmax, gmem_tmp_e1, tid_in_warpgroup,
threads_per_warpgroup,
warpgroup_id_in_cluster);
thread_block_copy_rowmax(smem_rowsum, gmem_tmp_e3, tid_in_warpgroup,
threads_per_warpgroup,
warpgroup_id_in_cluster);
}
threadblock_barrier(warpgroup_id_in_cluster,
warps_per_warpgroup_per_core);
}
}
threadblock_barrier(warpgroup_id_in_cluster,
warps_per_warpgroup_per_core);
// TODO: put a synchronization here with GEMM-II
// Oi rescale
thread_block_O_rescale</*block_row_major=*/GEMMINI_DMA>(
smem_O, smem_O /*in-place*/, smem_O_row_scale, tid_in_warpgroup,
threads_per_warpgroup, warpgroup_id_in_cluster);
// rescale-to-PV-GEMM barrier
threadblock_barrier(warpgroup_id_in_cluster,
warps_per_warpgroup_per_core);
if constexpr (DEBUG) {
if (warpgroup_id == 0) {
// P and O before PV
if (tile_k_ == 0) {
thread_block_copy_tile<B_ROW, B_COL, GEMMINI_DMA>(
smem_P_produce, gmem_tmp_d2, tid_in_warpgroup,
threads_per_warpgroup, warpgroup_id_in_cluster);
thread_block_copy_tile<B_ROW, HEADDIM, GEMMINI_DMA>(
smem_O, gmem_tmp_d4, tid_in_warpgroup, threads_per_warpgroup,
warpgroup_id_in_cluster);
} else if (tile_k_ == 1) {
thread_block_copy_tile<B_ROW, B_COL, GEMMINI_DMA>(
smem_P_produce, gmem_tmp_d3, tid_in_warpgroup,
threads_per_warpgroup, warpgroup_id_in_cluster);
thread_block_copy_tile<B_ROW, HEADDIM, GEMMINI_DMA>(
smem_O, gmem_tmp_d5, tid_in_warpgroup, threads_per_warpgroup,
warpgroup_id_in_cluster);
}
threadblock_barrier(warpgroup_id_in_cluster,
warps_per_warpgroup_per_core);
}
}
}
@@ -428,171 +568,64 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
//
// Q stays in SMEM for the entire loop
asm volatile("move_k_v_start_%=:" ::);
if constexpr (GEMMINI_DMA) {
// NOTE: Beware of race conditions; with warp specialization, we need to
// make sure below command code to DMA is not executed simultaneously
// from the two warpgroups (which will result in hardware fault).
// Currently the ping-pong scheduling scheme prevents that.
if (tid_in_warpgroup == 0) {
// configure GMEM addresses for K and V tiles
// load K for the next iteration
const float *gmem_K_tile = gmem_K + (B_COL * (tile_k + 1));
// load V for the current iteration
const float *gmem_V_tile = gmem_V + (HEADDIM * B_COL * tile_k);
ROCC_INSTRUCTION_RS1_RS2(XCUSTOM_ACC, (uint64_t)(gmem_K_tile),
(uint64_t)(gmem_V_tile),
k_LOOP_WS_CONFIG_ADDRS_AB)
// configure address strides for the DMA
// FIXME: unnecessary?
GEMMINI_CISC_CMD_R((HEADDIM /*V*/ << 20) | (dim_seqlen /*KT*/ << 8) |
8 /*k_LOOP_WS_CONFIG_STRIDES_AB*/);
gemmini_fence();
// do DMA
// NOTE: Beware of race conditions; with warp specialization, we need to
// make sure below command code to DMA is not executed simultaneously
// from the two warpgroups (which will result in hardware fault).
// Currently the ping-pong scheduling scheme prevents that.
if (tid_in_warpgroup == 0) {
// configure GMEM addresses for K and V tiles
// load K for the next iteration
const float *gmem_K_tile = gmem_K + (B_COL * (tile_k + 1 /*runahead*/));
// load V for the *previous* iteration; this will be consumed 2
// iterations later
const float *gmem_V_tile =
gmem_V + (HEADDIM * B_COL * (tile_k - 1 /*dragbehind*/));
ROCC_INSTRUCTION_RS1_RS2(XCUSTOM_ACC, (uint64_t)(gmem_K_tile),
(uint64_t)(gmem_V_tile),
k_LOOP_WS_CONFIG_ADDRS_AB)
// configure address strides for the DMA
// FIXME: unnecessary?
GEMMINI_CISC_CMD_R((HEADDIM /*V*/ << 20) | (dim_seqlen /*KT*/ << 8) |
8 /*k_LOOP_WS_CONFIG_STRIDES_AB*/);
gemmini_fence();
// do DMA
if (tile_k == 0) {
// V lags behind: we load the (k-1)th tile for V, so there is no V tile to
// load on the 1st iteration; load only K (operand A) here.
sp_tiled_matmul_full_spad_ws(
spad_addr_K, spad_addr_V,
/*spad_D=*/0, /*spad_C=*/spad_addr_S,
spad_addr_K_produce, spad_addr_V_produce,
/*spad_D=*/0, /*spad_C=*/spad_addr_S_produce /*FIXME:bogus*/,
/*I=*/(HEADDIM / DIM), /*J=*/(HEADDIM / DIM), /*K=*/(B_COL / DIM),
/*pad_I=*/0, /*pad_J=*/0, /*pad_K=*/0,
/*a_transpose=*/0, /*b_transpose=*/0, /*full_C=*/0, /*low_D=*/0,
/*acc=*/0, /*act=*/NO_ACTIVATION, /*skips=*/skips_only_a);
} else {
sp_tiled_matmul_full_spad_ws(
spad_addr_K_produce, spad_addr_V_produce,
/*spad_D=*/0, /*spad_C=*/spad_addr_S_produce /*FIXME:bogus*/,
/*I=*/(HEADDIM / DIM), /*J=*/(HEADDIM / DIM), /*K=*/(B_COL / DIM),
/*pad_I=*/0, /*pad_J=*/0, /*pad_K=*/0,
/*a_transpose=*/0, /*b_transpose=*/0, /*full_C=*/0, /*low_D=*/0,
/*acc=*/0, /*act=*/NO_ACTIVATION, /*skips=*/skips);
gemmini_fence();
}
} else {
// load K for the next iteration
load_tile_to_smem<float, MemLayout::MN_major, MemLayout::MN_major, B_COL,
HEADDIM, threads_per_warpgroup>(
dim_seqlen, tile_k + 1, 0 /* dim_k == headdim */, gmem_K, smem_K,
tid_in_warpgroup);
// load V for the current iteration
// V dimension is [seqlen, headdim], stored N(headdim)-major
load_tile_to_smem<float, MemLayout::MN_major, MemLayout::MN_major, B_COL,
HEADDIM, threads_per_warpgroup>(
HEADDIM, 0 /* full N-dimension */, tile_k, gmem_V, smem_V,
tid_in_warpgroup);
gemmini_fence();
}
threadblock_barrier(warpgroup_id_in_cluster, warps_per_warpgroup_per_core);
asm volatile("move_k_v_finish_%=:" ::);
// inter-warpgroup barrier before GEMM II
threadblock_barrier(global_barrier_id, warps_per_threadblock_per_core);
// Oi rescale
thread_block_O_rescale</*block_row_major=*/GEMMINI_DMA>(
smem_O, smem_O /*in-place*/, smem_O_row_scale, tid_in_warpgroup,
threads_per_warpgroup, warpgroup_id_in_cluster);
// rescale-to-PV-GEMM barrier
threadblock_barrier(warpgroup_id_in_cluster, warps_per_warpgroup_per_core);
if constexpr (DEBUG) {
if (warpgroup_id == 0) {
// O before PV
if (tile_k == 0) {
thread_block_copy_tile<B_ROW, B_COL, GEMMINI_DMA>(
smem_P, gmem_tmp_d2, tid_in_warpgroup, threads_per_warpgroup,
warpgroup_id_in_cluster);
thread_block_copy_tile<B_ROW, HEADDIM, GEMMINI_DMA>(
smem_O, gmem_tmp_d4, tid_in_warpgroup, threads_per_warpgroup,
warpgroup_id_in_cluster);
} else if (tile_k == 1) {
thread_block_copy_tile<B_ROW, B_COL, GEMMINI_DMA>(
smem_P, gmem_tmp_d3, tid_in_warpgroup, threads_per_warpgroup,
warpgroup_id_in_cluster);
thread_block_copy_tile<B_ROW, HEADDIM, GEMMINI_DMA>(
smem_O, gmem_tmp_d5, tid_in_warpgroup, threads_per_warpgroup,
warpgroup_id_in_cluster);
}
threadblock_barrier(warpgroup_id_in_cluster,
warps_per_warpgroup_per_core);
}
}
// GEMM II: O = O + P*V
asm volatile("gemm_pv_start_%=:" ::);
if (tid_in_warpgroup == 0) {
#if 0
if (tile_k == 0) {
gemmini_fence();
GEMMINI_CISC_CMD_I(0);
} else if (tile_k & 1) {
gemmini_fence();
GEMMINI_CISC_CMD_I(2);
} else {
gemmini_fence();
GEMMINI_CISC_CMD_I(1);
}
#else
// do matmul
// among other things, this also configures CONFIG_BOUNDS so that the
// DMA knows the full matrix dimensions
sp_tiled_matmul_full_spad_ws(
spad_addr_P, spad_addr_V,
/*spad_D=*/0, /*spad_C=*/spad_addr_O,
/*I=*/(B_ROW / DIM), /*J=*/(HEADDIM / DIM), /*K=*/(B_COL / DIM),
/*pad_I=*/0, /*pad_J=*/0, /*pad_K=*/0,
/*a_transpose=*/0, /*b_transpose=*/0, /*full_C=*/0, /*low_D=*/0,
/*acc=*/0, /*act=*/NO_ACTIVATION, /*skips=*/skips_matmul);
#endif
gemmini_fence();
gemmini_fence();
gemmini_fence();
gemmini_fence();
// mvout to SMEM
// GEMMINI_CISC_CMD_I(9);
sp_tiled_matmul_full_spad_ws(
/*spad_A=*/spad_addr_P, /*spad_B=*/spad_addr_V,
/*spad_D=*/0, /*spad_C=*/spad_addr_O,
/*I=*/(B_ROW / DIM), /*J=*/(HEADDIM / DIM), /*K=*/(B_COL/ DIM),
/*pad_I=*/0, /*pad_J=*/0, /*pad_K=*/0,
/*a_transpose=*/0, /*b_transpose=*/0, /*full_C=*/0, /*low_D=*/0,
/*acc=*/0, /*act=*/NO_ACTIVATION, /*skips=*/skips_mvout_spad);
gemmini_fence();
if constexpr (DEBUG) {
// for copy-out to GMEM
gemmini_fence();
}
}
// reconverge from mmio divergence
threadblock_barrier(warpgroup_id_in_cluster, warps_per_warpgroup_per_core);
asm volatile("gemm_pv_finish_%=:" ::);
if constexpr (DEBUG) {
if (warpgroup_id == 0) {
// O after PV
if (tile_k == 0) {
thread_block_copy_tile<B_ROW, HEADDIM, GEMMINI_DMA>(
smem_O, gmem_tmp_d6, tid_in_warpgroup, threads_per_warpgroup,
warpgroup_id_in_cluster);
} else if (tile_k == 1) {
thread_block_copy_tile<B_ROW, HEADDIM, GEMMINI_DMA>(
smem_O, gmem_tmp_d7, tid_in_warpgroup, threads_per_warpgroup,
warpgroup_id_in_cluster);
}
threadblock_barrier(warpgroup_id_in_cluster,
warps_per_warpgroup_per_core);
}
}
#if 0
#endif
}
asm volatile ("tile_loop_finish_%=:" :: );
// wait for warpgroup 1 to finish, which called the global barrier before
// entering the loop
if (warpgroup_id == 0) {
threadblock_barrier(global_barrier_id, warps_per_threadblock_per_core);
}
// // wait for warpgroup 1 to finish, which called the global barrier before
// // entering the loop
// if (warpgroup_id == 0) {
// threadblock_barrier(global_barrier_id, warps_per_threadblock_per_core);
// }
}
int main() {

View File

@@ -73,7 +73,7 @@ static_assert(WMITER * WNITER * TCM * TCN * NUM_WARPS * CORES_PER_CLUSTER ==
#define TRANSPOSE_AT_CONSUME 0
#define GEMMINI_DMA 1
#define GEMMINI_DMA_FAST 1
#define GEMMINI_DMA_FAST 0
#define GEMMINI_DMA_FLEXIBLE_LAYOUT 1
#if SMEM_SIZE == 0x4000
#define SMEM_ADDR_Q0 ((float * const) 0xff000000)