flash: Do GEMM II in Gemmini; verify 1st iteration
This commit is contained in:
@@ -95,7 +95,7 @@ inline void thread_block_copy_rowmax(const float *src, float *dest,
|
||||
asm volatile("threadblock_copy_rowmax_finish_%=:" ::);
|
||||
}
|
||||
|
||||
template <uint32_t dim_row, uint32_t dim_col>
|
||||
template <uint32_t dim_row, uint32_t dim_col, bool block_row_major = false>
|
||||
inline void thread_block_copy_tile(const float *src, float *dest,
|
||||
const uint32_t tid_in_threadblock,
|
||||
const uint32_t threads_per_threadblock,
|
||||
@@ -113,14 +113,18 @@ inline void thread_block_copy_tile(const float *src, float *dest,
|
||||
for (int row_offset = 0; row_offset < dim_row;
|
||||
row_offset += warps_in_threadblock) {
|
||||
const uint32_t row = row_offset + warp_id;
|
||||
const uint32_t first_thread_offset = dim_col * row;
|
||||
|
||||
constexpr uint32_t per_row_iter = dim_col / NUM_THREADS;
|
||||
uint32_t thread_offset = first_thread_offset + tid_in_warp;
|
||||
#pragma GCC unroll
|
||||
for (int i = 0; i < per_row_iter; i++) {
|
||||
dest[thread_offset] = src[thread_offset];
|
||||
thread_offset += NUM_THREADS;
|
||||
const uint32_t col_offset = NUM_THREADS * i;
|
||||
const uint32_t col = col_offset + tid_in_warp;
|
||||
const auto [smem_row, smem_col] =
|
||||
remap_to_gemmini_dma_layout<block_row_major, B_COL>(row, col);
|
||||
const uint32_t smem_offset = B_COL * smem_row + smem_col;
|
||||
const uint32_t gmem_offset = B_COL * row + col;
|
||||
|
||||
dest[gmem_offset] = src[smem_offset];
|
||||
}
|
||||
|
||||
threadblock_barrier(threadblock_id_in_cluster,
|
||||
@@ -415,6 +419,7 @@ __attribute__((always_inline)) inline void thread_block_online_softmax(
|
||||
asm volatile("thread_block_online_softmax_finish_%=:" ::);
|
||||
}
|
||||
|
||||
template <bool block_row_major = false>
|
||||
__attribute__((always_inline)) inline void thread_block_O_rescale(
|
||||
const float *smem_O_in, float *smem_O_out, const float *smem_O_row_scale,
|
||||
const uint32_t tid_in_threadblock, const uint32_t threads_per_threadblock,
|
||||
@@ -431,19 +436,21 @@ __attribute__((always_inline)) inline void thread_block_O_rescale(
|
||||
for (int row_offset = 0; row_offset < B_ROW;
|
||||
row_offset += warps_in_threadblock) {
|
||||
const uint32_t row = row_offset + warp_id;
|
||||
const uint32_t first_thread_offset = B_COL * row;
|
||||
constexpr uint32_t per_row_iter = B_COL / NUM_THREADS;
|
||||
uint32_t thread_offset = first_thread_offset + tid_in_warp;
|
||||
constexpr uint32_t per_row_iter = HEADDIM / NUM_THREADS;
|
||||
|
||||
// Oi rescale
|
||||
//
|
||||
#pragma GCC unroll
|
||||
for (int i = 0; i < per_row_iter; i++) {
|
||||
const float o = smem_O_in[thread_offset];
|
||||
const float scale = smem_O_row_scale[row];
|
||||
smem_O_out[thread_offset] = (o * scale);
|
||||
const uint32_t col_offset = NUM_THREADS * i;
|
||||
const uint32_t col = col_offset + tid_in_warp;
|
||||
const auto [smem_row, smem_col] =
|
||||
remap_to_gemmini_dma_layout<block_row_major, HEADDIM>(row, col);
|
||||
|
||||
thread_offset += NUM_THREADS;
|
||||
const uint32_t offset = HEADDIM * smem_row + smem_col;
|
||||
const float o = smem_O_in[offset];
|
||||
const float scale = smem_O_row_scale[row];
|
||||
smem_O_out[offset] = (o * scale);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -110,8 +110,10 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
|
||||
smem_cursor += smem_QK_size;
|
||||
float *smem_S1 = smem_cursor;
|
||||
smem_cursor += smem_QK_size;
|
||||
float *smem_P0 = smem_S0; // in-place update
|
||||
float *smem_P1 = smem_S1; // in-place update
|
||||
float *smem_P0 = smem_cursor;
|
||||
smem_cursor += smem_QK_size;
|
||||
float *smem_P1 = smem_cursor;
|
||||
smem_cursor += smem_QK_size;
|
||||
float *smem_O0 = smem_cursor;
|
||||
smem_cursor += smem_O_size;
|
||||
float *smem_O1 = smem_cursor;
|
||||
@@ -135,6 +137,14 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
|
||||
spad_addr_V1 + (smem_V_size * sizeof(float) / spad_addr_factor);
|
||||
constexpr uint32_t spad_addr_S1 =
|
||||
spad_addr_S0 + (smem_QK_size * sizeof(float) / spad_addr_factor);
|
||||
constexpr uint32_t spad_addr_P0 =
|
||||
spad_addr_S1 + (smem_QK_size * sizeof(float) / spad_addr_factor);
|
||||
constexpr uint32_t spad_addr_P1 =
|
||||
spad_addr_P0 + (smem_QK_size * sizeof(float) / spad_addr_factor);
|
||||
constexpr uint32_t spad_addr_O0 =
|
||||
spad_addr_P1 + (smem_QK_size * sizeof(float) / spad_addr_factor);
|
||||
constexpr uint32_t spad_addr_O1 =
|
||||
spad_addr_O0 + (smem_O_size * sizeof(float) / spad_addr_factor);
|
||||
|
||||
// allocate rowmax/rowsum storage at the end of the sharedmem address space
|
||||
constexpr uint32_t smem_rowmax_size = B_ROW * ROWMAX_SETS;
|
||||
@@ -189,6 +199,9 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
|
||||
constexpr uint32_t skips_mvout_spad =
|
||||
loop_matmul_skips(/*skip_lda=*/1, /*skip_ldb=*/1, /*skip_ldd=*/1,
|
||||
/*skip_ex=*/1, /*skip_stc=*/0);
|
||||
constexpr uint32_t skips_matmul =
|
||||
loop_matmul_skips(/*skip_lda=*/1, /*skip_ldb=*/1, /*skip_ldd=*/1,
|
||||
/*skip_ex=*/0, /*skip_stc=*/1);
|
||||
|
||||
if constexpr (GEMMINI_DMA) {
|
||||
if (tid_in_warpgroup == 0) {
|
||||
@@ -281,12 +294,13 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
|
||||
const uint32_t k_tiles = (dim_seqlen / B_COL);
|
||||
for (uint32_t tile_k = 0; tile_k < k_tiles; tile_k++) {
|
||||
// select the correct double buffer by tile iteration
|
||||
float *smem_Q = (tile_k & 1) ? smem_Q1 : smem_Q0; // FIXME
|
||||
float *smem_K = (tile_k & 1) ? smem_K1 : smem_K0; // FIXME
|
||||
float *smem_V = (tile_k & 1) ? smem_V1 : smem_V0; // FIXME
|
||||
float *smem_S = (tile_k & 1) ? smem_S1 : smem_S0; // FIXME
|
||||
float *smem_O = (tile_k & 1) ? smem_O1 : smem_O0; // FIXME
|
||||
float *smem_P = smem_S; // FIXME
|
||||
// FIXME do correct double buffering
|
||||
float *smem_Q = (tile_k & 1) ? smem_Q1 : smem_Q0;
|
||||
float *smem_K = (tile_k & 1) ? smem_K1 : smem_K0;
|
||||
float *smem_V = (tile_k & 1) ? smem_V1 : smem_V0;
|
||||
float *smem_S = (tile_k & 1) ? smem_S1 : smem_S0;
|
||||
float *smem_P = (tile_k & 1) ? smem_P1 : smem_P0;
|
||||
float *smem_O = (tile_k & 1) ? smem_O1 : smem_O0;
|
||||
float *smem_O_row_scale =
|
||||
(tile_k & 1) ? smem_O_row_scale_1 : smem_O_row_scale_0;
|
||||
float *smem_rowmax = (tile_k & 1) ? smem_rowmax_1 : smem_rowmax_0;
|
||||
@@ -298,6 +312,8 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
|
||||
const auto spad_addr_K = (tile_k & 1) ? spad_addr_K1 : spad_addr_K0;
|
||||
const auto spad_addr_V = (tile_k & 1) ? spad_addr_V1 : spad_addr_V0;
|
||||
const auto spad_addr_S = (tile_k & 1) ? spad_addr_S1 : spad_addr_S0;
|
||||
const auto spad_addr_P = (tile_k & 1) ? spad_addr_P1 : spad_addr_P0;
|
||||
const auto spad_addr_O = spad_addr_O0; // NOTE: there's only a single O tile
|
||||
|
||||
// GEMM I: S = Q*K
|
||||
//
|
||||
@@ -320,7 +336,7 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
|
||||
gemmini_fence();
|
||||
gemmini_fence();
|
||||
|
||||
#if 0 // TODO
|
||||
#if 0 // TODO: speed up mvout to SMEM
|
||||
// loop_ws variant that skips configuring strides
|
||||
#define gemmini_loop_ws(I, J, K, pad_I, pad_J, pad_K, A, B, D, C, A_stride, B_stride, D_stride, C_stride, A_transpose, B_transpose, full_C, low_D, ex_accumulate, act, a_spad_id, b_spad_id, is_resadd) \
|
||||
{ \
|
||||
@@ -350,7 +366,7 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
|
||||
}
|
||||
}
|
||||
|
||||
// thread reconvergence
|
||||
// reconverge from mmio divergence
|
||||
threadblock_barrier(warpgroup_id_in_cluster, warps_per_warpgroup_per_core);
|
||||
|
||||
asm volatile("gemm_qk_finish_%=:" ::);
|
||||
@@ -358,13 +374,13 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
|
||||
if constexpr (DEBUG) {
|
||||
if (warpgroup_id == 0) {
|
||||
if (tile_k == 0) {
|
||||
thread_block_copy_tile<B_ROW, B_COL>(smem_S, gmem_tmp_d0,
|
||||
tid_in_warpgroup, threads_per_warpgroup,
|
||||
warpgroup_id_in_cluster);
|
||||
thread_block_copy_tile<B_ROW, B_COL, GEMMINI_DMA>(
|
||||
smem_S, gmem_tmp_d0, tid_in_warpgroup, threads_per_warpgroup,
|
||||
warpgroup_id_in_cluster);
|
||||
} else if (tile_k == 1) {
|
||||
thread_block_copy_tile<B_ROW, B_COL>(smem_S, gmem_tmp_d1,
|
||||
tid_in_warpgroup, threads_per_warpgroup,
|
||||
warpgroup_id_in_cluster);
|
||||
thread_block_copy_tile<B_ROW, B_COL, GEMMINI_DMA>(
|
||||
smem_S, gmem_tmp_d1, tid_in_warpgroup, threads_per_warpgroup,
|
||||
warpgroup_id_in_cluster);
|
||||
}
|
||||
|
||||
threadblock_barrier(warpgroup_id_in_cluster,
|
||||
@@ -408,7 +424,6 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
|
||||
}
|
||||
}
|
||||
|
||||
#if 0
|
||||
// data movement for K and V
|
||||
//
|
||||
// Q stays in SMEM for the entire loop
|
||||
@@ -463,9 +478,9 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
|
||||
threadblock_barrier(global_barrier_id, warps_per_threadblock_per_core);
|
||||
|
||||
// Oi rescale
|
||||
thread_block_O_rescale(smem_O, smem_O /*in-place*/,
|
||||
smem_O_row_scale, tid_in_warpgroup,
|
||||
threads_per_warpgroup, warpgroup_id_in_cluster);
|
||||
thread_block_O_rescale</*block_row_major=*/GEMMINI_DMA>(
|
||||
smem_O, smem_O /*in-place*/, smem_O_row_scale, tid_in_warpgroup,
|
||||
threads_per_warpgroup, warpgroup_id_in_cluster);
|
||||
|
||||
// rescale-to-PV-GEMM barrier
|
||||
threadblock_barrier(warpgroup_id_in_cluster, warps_per_warpgroup_per_core);
|
||||
@@ -474,19 +489,19 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
|
||||
if (warpgroup_id == 0) {
|
||||
// O before PV
|
||||
if (tile_k == 0) {
|
||||
thread_block_copy_tile<B_ROW, B_COL>(smem_P, gmem_tmp_d2, tid_in_warpgroup,
|
||||
threads_per_warpgroup,
|
||||
warpgroup_id_in_cluster);
|
||||
thread_block_copy_tile<B_ROW, HEADDIM>(smem_O, gmem_tmp_d4, tid_in_warpgroup,
|
||||
threads_per_warpgroup,
|
||||
warpgroup_id_in_cluster);
|
||||
thread_block_copy_tile<B_ROW, B_COL, GEMMINI_DMA>(
|
||||
smem_P, gmem_tmp_d2, tid_in_warpgroup, threads_per_warpgroup,
|
||||
warpgroup_id_in_cluster);
|
||||
thread_block_copy_tile<B_ROW, HEADDIM, GEMMINI_DMA>(
|
||||
smem_O, gmem_tmp_d4, tid_in_warpgroup, threads_per_warpgroup,
|
||||
warpgroup_id_in_cluster);
|
||||
} else if (tile_k == 1) {
|
||||
thread_block_copy_tile<B_ROW, B_COL>(smem_P, gmem_tmp_d3, tid_in_warpgroup,
|
||||
threads_per_warpgroup,
|
||||
warpgroup_id_in_cluster);
|
||||
thread_block_copy_tile<B_ROW, HEADDIM>(smem_O, gmem_tmp_d5, tid_in_warpgroup,
|
||||
threads_per_warpgroup,
|
||||
warpgroup_id_in_cluster);
|
||||
thread_block_copy_tile<B_ROW, B_COL, GEMMINI_DMA>(
|
||||
smem_P, gmem_tmp_d3, tid_in_warpgroup, threads_per_warpgroup,
|
||||
warpgroup_id_in_cluster);
|
||||
thread_block_copy_tile<B_ROW, HEADDIM, GEMMINI_DMA>(
|
||||
smem_O, gmem_tmp_d5, tid_in_warpgroup, threads_per_warpgroup,
|
||||
warpgroup_id_in_cluster);
|
||||
}
|
||||
|
||||
threadblock_barrier(warpgroup_id_in_cluster,
|
||||
@@ -498,134 +513,54 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
|
||||
|
||||
asm volatile("gemm_pv_start_%=:" ::);
|
||||
|
||||
if constexpr (!WARP_SPECIALIZED) {
|
||||
// clear out accumulators before GEMM
|
||||
initialize_accum_regs<0>();
|
||||
initialize_accum_regs<1>();
|
||||
|
||||
if constexpr (GEMMINI_DMA) {
|
||||
thread_block_gemm_single_tile<
|
||||
float, MemLayout::K_major /* P matrix is row-major */,
|
||||
MemLayout::block_row_major, B_ROW, HEADDIM, B_COL,
|
||||
/*leading_dim_a=*/0, /*leading_dim_b=*/0,
|
||||
/*load_accum=*/true,
|
||||
/*write_to_smem=*/true>(
|
||||
smem_P, smem_V, smem_O /*load accum*/, smem_O, tid_in_warpgroup,
|
||||
threads_per_warpgroup, warpgroups_per_cluster,
|
||||
warpgroup_id_in_cluster);
|
||||
if (tid_in_warpgroup == 0) {
|
||||
#if 0
|
||||
if (tile_k == 0) {
|
||||
gemmini_fence();
|
||||
GEMMINI_CISC_CMD_I(0);
|
||||
} else if (tile_k & 1) {
|
||||
gemmini_fence();
|
||||
GEMMINI_CISC_CMD_I(2);
|
||||
} else {
|
||||
thread_block_gemm_single_tile<float, MemLayout::K_major,
|
||||
MemLayout::MN_major, B_ROW, HEADDIM,
|
||||
B_COL,
|
||||
/*leading_dim_a=*/0, /*leading_dim_b=*/0,
|
||||
/*load_accum=*/true,
|
||||
/*write_to_smem=*/true>(
|
||||
smem_P, smem_V, smem_O /*load accum*/, smem_O, tid_in_warpgroup,
|
||||
threads_per_warpgroup, warpgroups_per_cluster,
|
||||
warpgroup_id_in_cluster);
|
||||
// FIXME: wrong but fast
|
||||
// thread_block_gemm_single_tile<float, MemLayout::MN_major,
|
||||
// MemLayout::MN_major,
|
||||
// B_ROW, HEADDIM, B_COL,
|
||||
// /*leading_dim_a=*/0,
|
||||
// /*leading_dim_b=*/0,
|
||||
// /*load_accum=*/true,
|
||||
// /*write_to_smem=*/true>(
|
||||
// smem_P, smem_V, smem_O /*load accum*/, smem_O,
|
||||
// tid_in_warpgroup, threads_per_warpgroup,
|
||||
// warpgroups_per_cluster, warpgroup_id_in_cluster);
|
||||
gemmini_fence();
|
||||
GEMMINI_CISC_CMD_I(1);
|
||||
}
|
||||
} else {
|
||||
// when warp-specialized, there are only enough warps to do a 64x32 tile
|
||||
// size so we need to do 2 GEMM calls
|
||||
static_assert(B_ROW / 2 == 32,
|
||||
"tile size assumption for warp-specialization not met");
|
||||
#else
|
||||
// do matmul
|
||||
// among other things, this also configures CONFIG_BOUNDS so that the
|
||||
// DMA knows the full matrix dimensions
|
||||
sp_tiled_matmul_full_spad_ws(
|
||||
spad_addr_P, spad_addr_V,
|
||||
/*spad_D=*/0, /*spad_C=*/spad_addr_O,
|
||||
/*I=*/(B_ROW / DIM), /*J=*/(HEADDIM / DIM), /*K=*/(B_COL / DIM),
|
||||
/*pad_I=*/0, /*pad_J=*/0, /*pad_K=*/0,
|
||||
/*a_transpose=*/0, /*b_transpose=*/0, /*full_C=*/0, /*low_D=*/0,
|
||||
/*acc=*/0, /*act=*/NO_ACTIVATION, /*skips=*/skips_matmul);
|
||||
#endif
|
||||
|
||||
// assumes smem_P is K-major
|
||||
float *smem_P_half0 = smem_P;
|
||||
float *smem_P_half1 = smem_P + (B_ROW / 2) * B_COL;
|
||||
float *smem_O_half0 = smem_O;
|
||||
float *smem_O_half1 = smem_O + (B_ROW / 2) * HEADDIM;
|
||||
gemmini_fence();
|
||||
gemmini_fence();
|
||||
gemmini_fence();
|
||||
gemmini_fence();
|
||||
|
||||
// clear out accumulators before GEMM
|
||||
initialize_accum_regs<0>();
|
||||
initialize_accum_regs<1>();
|
||||
// mvout to SMEM
|
||||
// GEMMINI_CISC_CMD_I(9);
|
||||
sp_tiled_matmul_full_spad_ws(
|
||||
/*spad_A=*/spad_addr_P, /*spad_B=*/spad_addr_V,
|
||||
/*spad_D=*/0, /*spad_C=*/spad_addr_O,
|
||||
/*I=*/(B_ROW / DIM), /*J=*/(HEADDIM / DIM), /*K=*/(B_COL/ DIM),
|
||||
/*pad_I=*/0, /*pad_J=*/0, /*pad_K=*/0,
|
||||
/*a_transpose=*/0, /*b_transpose=*/0, /*full_C=*/0, /*low_D=*/0,
|
||||
/*acc=*/0, /*act=*/NO_ACTIVATION, /*skips=*/skips_mvout_spad);
|
||||
gemmini_fence();
|
||||
|
||||
// split by rows into 2 chunks
|
||||
if constexpr (GEMMINI_DMA) {
|
||||
if constexpr (GEMMINI_DMA_FAST) {
|
||||
thread_block_gemm_single_tile<float, MemLayout::MN_major,
|
||||
MemLayout::MN_major, B_ROW / 2, HEADDIM,
|
||||
B_COL,
|
||||
/*leading_dim_a=*/0,
|
||||
/*leading_dim_b=*/0,
|
||||
/*load_accum=*/true,
|
||||
/*write_to_smem=*/true>(
|
||||
smem_P_half0, smem_V, smem_O_half0 /*load accum*/, smem_O_half0,
|
||||
tid_in_warpgroup, threads_per_warpgroup, warpgroups_per_cluster,
|
||||
warpgroup_id_in_cluster);
|
||||
} else {
|
||||
thread_block_gemm_single_tile<
|
||||
float, MemLayout::K_major /* P matrix is row-major */,
|
||||
MemLayout::block_row_major, B_ROW / 2, HEADDIM, B_COL,
|
||||
/*leading_dim_a=*/0,
|
||||
/*leading_dim_b=*/0,
|
||||
/*load_accum=*/true,
|
||||
/*write_to_smem=*/true>(
|
||||
smem_P_half0, smem_V, smem_O_half0 /*load accum*/, smem_O_half0,
|
||||
tid_in_warpgroup, threads_per_warpgroup, warpgroups_per_cluster,
|
||||
warpgroup_id_in_cluster);
|
||||
}
|
||||
} else {
|
||||
thread_block_gemm_single_tile<
|
||||
float, MemLayout::K_major, MemLayout::MN_major, B_ROW / 2, HEADDIM,
|
||||
B_COL, /*leading_dim_a=*/0, /*leading_dim_b=*/0,
|
||||
/*load_accum=*/true,
|
||||
/*write_to_smem=*/true>(
|
||||
smem_P_half0, smem_V, smem_O_half0 /*load accum*/, smem_O_half0,
|
||||
tid_in_warpgroup, threads_per_warpgroup, warpgroups_per_cluster,
|
||||
warpgroup_id_in_cluster);
|
||||
}
|
||||
|
||||
initialize_accum_regs<0>();
|
||||
initialize_accum_regs<1>();
|
||||
|
||||
if constexpr (GEMMINI_DMA) {
|
||||
if constexpr (GEMMINI_DMA_FAST) {
|
||||
thread_block_gemm_single_tile<float, MemLayout::MN_major,
|
||||
MemLayout::MN_major, B_ROW / 2, HEADDIM,
|
||||
B_COL,
|
||||
/*leading_dim_a=*/0,
|
||||
/*leading_dim_b=*/0,
|
||||
/*load_accum=*/true,
|
||||
/*write_to_smem=*/true>(
|
||||
smem_P_half1, smem_V, smem_O_half1 /*load accum*/, smem_O_half1,
|
||||
tid_in_warpgroup, threads_per_warpgroup, warpgroups_per_cluster,
|
||||
warpgroup_id_in_cluster);
|
||||
} else {
|
||||
thread_block_gemm_single_tile<
|
||||
float, MemLayout::K_major /* P matrix is row-major */,
|
||||
MemLayout::block_row_major, B_ROW / 2, HEADDIM, B_COL,
|
||||
/*leading_dim_a=*/0,
|
||||
/*leading_dim_b=*/0,
|
||||
/*load_accum=*/true,
|
||||
/*write_to_smem=*/true>(
|
||||
smem_P_half1, smem_V, smem_O_half1 /*load accum*/, smem_O_half1,
|
||||
tid_in_warpgroup, threads_per_warpgroup, warpgroups_per_cluster,
|
||||
warpgroup_id_in_cluster);
|
||||
}
|
||||
} else {
|
||||
thread_block_gemm_single_tile<
|
||||
float, MemLayout::K_major, MemLayout::MN_major, B_ROW / 2, HEADDIM,
|
||||
B_COL, /*leading_dim_a=*/0, /*leading_dim_b=*/0,
|
||||
/*load_accum=*/true,
|
||||
/*write_to_smem=*/true>(
|
||||
smem_P_half1, smem_V, smem_O_half1 /*load accum*/, smem_O_half1,
|
||||
tid_in_warpgroup, threads_per_warpgroup, warpgroups_per_cluster,
|
||||
warpgroup_id_in_cluster);
|
||||
if constexpr (DEBUG) {
|
||||
// for copy-out to GMEM
|
||||
gemmini_fence();
|
||||
}
|
||||
}
|
||||
|
||||
// reconverge from mmio divergence
|
||||
threadblock_barrier(warpgroup_id_in_cluster, warps_per_warpgroup_per_core);
|
||||
|
||||
asm volatile("gemm_pv_finish_%=:" ::);
|
||||
@@ -634,19 +569,20 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
|
||||
if (warpgroup_id == 0) {
|
||||
// O after PV
|
||||
if (tile_k == 0) {
|
||||
thread_block_copy_tile<B_ROW, HEADDIM>(smem_O, gmem_tmp_d6, tid_in_warpgroup,
|
||||
threads_per_warpgroup,
|
||||
warpgroup_id_in_cluster);
|
||||
thread_block_copy_tile<B_ROW, HEADDIM, GEMMINI_DMA>(
|
||||
smem_O, gmem_tmp_d6, tid_in_warpgroup, threads_per_warpgroup,
|
||||
warpgroup_id_in_cluster);
|
||||
} else if (tile_k == 1) {
|
||||
thread_block_copy_tile<B_ROW, HEADDIM>(smem_O, gmem_tmp_d7, tid_in_warpgroup,
|
||||
threads_per_warpgroup,
|
||||
warpgroup_id_in_cluster);
|
||||
thread_block_copy_tile<B_ROW, HEADDIM, GEMMINI_DMA>(
|
||||
smem_O, gmem_tmp_d7, tid_in_warpgroup, threads_per_warpgroup,
|
||||
warpgroup_id_in_cluster);
|
||||
}
|
||||
|
||||
threadblock_barrier(warpgroup_id_in_cluster,
|
||||
warps_per_warpgroup_per_core);
|
||||
}
|
||||
}
|
||||
#if 0
|
||||
#endif
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user