flash: Do GEMM II in Gemmini; verify 1st iteration

This commit is contained in:
Hansung Kim
2024-09-08 16:09:06 -07:00
parent 3f50ac57ee
commit cdb8377b62
2 changed files with 113 additions and 170 deletions

View File

@@ -95,7 +95,7 @@ inline void thread_block_copy_rowmax(const float *src, float *dest,
asm volatile("threadblock_copy_rowmax_finish_%=:" ::);
}
template <uint32_t dim_row, uint32_t dim_col>
template <uint32_t dim_row, uint32_t dim_col, bool block_row_major = false>
inline void thread_block_copy_tile(const float *src, float *dest,
const uint32_t tid_in_threadblock,
const uint32_t threads_per_threadblock,
@@ -113,14 +113,18 @@ inline void thread_block_copy_tile(const float *src, float *dest,
for (int row_offset = 0; row_offset < dim_row;
row_offset += warps_in_threadblock) {
const uint32_t row = row_offset + warp_id;
const uint32_t first_thread_offset = dim_col * row;
constexpr uint32_t per_row_iter = dim_col / NUM_THREADS;
uint32_t thread_offset = first_thread_offset + tid_in_warp;
#pragma GCC unroll
for (int i = 0; i < per_row_iter; i++) {
dest[thread_offset] = src[thread_offset];
thread_offset += NUM_THREADS;
const uint32_t col_offset = NUM_THREADS * i;
const uint32_t col = col_offset + tid_in_warp;
const auto [smem_row, smem_col] =
remap_to_gemmini_dma_layout<block_row_major, B_COL>(row, col);
const uint32_t smem_offset = B_COL * smem_row + smem_col;
const uint32_t gmem_offset = B_COL * row + col;
dest[gmem_offset] = src[smem_offset];
}
threadblock_barrier(threadblock_id_in_cluster,
@@ -415,6 +419,7 @@ __attribute__((always_inline)) inline void thread_block_online_softmax(
asm volatile("thread_block_online_softmax_finish_%=:" ::);
}
template <bool block_row_major = false>
__attribute__((always_inline)) inline void thread_block_O_rescale(
const float *smem_O_in, float *smem_O_out, const float *smem_O_row_scale,
const uint32_t tid_in_threadblock, const uint32_t threads_per_threadblock,
@@ -431,19 +436,21 @@ __attribute__((always_inline)) inline void thread_block_O_rescale(
for (int row_offset = 0; row_offset < B_ROW;
row_offset += warps_in_threadblock) {
const uint32_t row = row_offset + warp_id;
const uint32_t first_thread_offset = B_COL * row;
constexpr uint32_t per_row_iter = B_COL / NUM_THREADS;
uint32_t thread_offset = first_thread_offset + tid_in_warp;
constexpr uint32_t per_row_iter = HEADDIM / NUM_THREADS;
// Oi rescale
//
#pragma GCC unroll
for (int i = 0; i < per_row_iter; i++) {
const float o = smem_O_in[thread_offset];
const float scale = smem_O_row_scale[row];
smem_O_out[thread_offset] = (o * scale);
const uint32_t col_offset = NUM_THREADS * i;
const uint32_t col = col_offset + tid_in_warp;
const auto [smem_row, smem_col] =
remap_to_gemmini_dma_layout<block_row_major, HEADDIM>(row, col);
thread_offset += NUM_THREADS;
const uint32_t offset = HEADDIM * smem_row + smem_col;
const float o = smem_O_in[offset];
const float scale = smem_O_row_scale[row];
smem_O_out[offset] = (o * scale);
}
}

View File

@@ -110,8 +110,10 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
smem_cursor += smem_QK_size;
float *smem_S1 = smem_cursor;
smem_cursor += smem_QK_size;
float *smem_P0 = smem_S0; // in-place update
float *smem_P1 = smem_S1; // in-place update
float *smem_P0 = smem_cursor;
smem_cursor += smem_QK_size;
float *smem_P1 = smem_cursor;
smem_cursor += smem_QK_size;
float *smem_O0 = smem_cursor;
smem_cursor += smem_O_size;
float *smem_O1 = smem_cursor;
@@ -135,6 +137,14 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
spad_addr_V1 + (smem_V_size * sizeof(float) / spad_addr_factor);
constexpr uint32_t spad_addr_S1 =
spad_addr_S0 + (smem_QK_size * sizeof(float) / spad_addr_factor);
constexpr uint32_t spad_addr_P0 =
spad_addr_S1 + (smem_QK_size * sizeof(float) / spad_addr_factor);
constexpr uint32_t spad_addr_P1 =
spad_addr_P0 + (smem_QK_size * sizeof(float) / spad_addr_factor);
constexpr uint32_t spad_addr_O0 =
spad_addr_P1 + (smem_QK_size * sizeof(float) / spad_addr_factor);
constexpr uint32_t spad_addr_O1 =
spad_addr_O0 + (smem_O_size * sizeof(float) / spad_addr_factor);
// allocate rowmax/rowsum storage at the end of the sharedmem address space
constexpr uint32_t smem_rowmax_size = B_ROW * ROWMAX_SETS;
@@ -189,6 +199,9 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
constexpr uint32_t skips_mvout_spad =
loop_matmul_skips(/*skip_lda=*/1, /*skip_ldb=*/1, /*skip_ldd=*/1,
/*skip_ex=*/1, /*skip_stc=*/0);
constexpr uint32_t skips_matmul =
loop_matmul_skips(/*skip_lda=*/1, /*skip_ldb=*/1, /*skip_ldd=*/1,
/*skip_ex=*/0, /*skip_stc=*/1);
if constexpr (GEMMINI_DMA) {
if (tid_in_warpgroup == 0) {
@@ -281,12 +294,13 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
const uint32_t k_tiles = (dim_seqlen / B_COL);
for (uint32_t tile_k = 0; tile_k < k_tiles; tile_k++) {
// select the correct double buffer by tile iteration
float *smem_Q = (tile_k & 1) ? smem_Q1 : smem_Q0; // FIXME
float *smem_K = (tile_k & 1) ? smem_K1 : smem_K0; // FIXME
float *smem_V = (tile_k & 1) ? smem_V1 : smem_V0; // FIXME
float *smem_S = (tile_k & 1) ? smem_S1 : smem_S0; // FIXME
float *smem_O = (tile_k & 1) ? smem_O1 : smem_O0; // FIXME
float *smem_P = smem_S; // FIXME
// FIXME do correct double buffering
float *smem_Q = (tile_k & 1) ? smem_Q1 : smem_Q0;
float *smem_K = (tile_k & 1) ? smem_K1 : smem_K0;
float *smem_V = (tile_k & 1) ? smem_V1 : smem_V0;
float *smem_S = (tile_k & 1) ? smem_S1 : smem_S0;
float *smem_P = (tile_k & 1) ? smem_P1 : smem_P0;
float *smem_O = (tile_k & 1) ? smem_O1 : smem_O0;
float *smem_O_row_scale =
(tile_k & 1) ? smem_O_row_scale_1 : smem_O_row_scale_0;
float *smem_rowmax = (tile_k & 1) ? smem_rowmax_1 : smem_rowmax_0;
@@ -298,6 +312,8 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
const auto spad_addr_K = (tile_k & 1) ? spad_addr_K1 : spad_addr_K0;
const auto spad_addr_V = (tile_k & 1) ? spad_addr_V1 : spad_addr_V0;
const auto spad_addr_S = (tile_k & 1) ? spad_addr_S1 : spad_addr_S0;
const auto spad_addr_P = (tile_k & 1) ? spad_addr_P1 : spad_addr_P0;
const auto spad_addr_O = spad_addr_O0; // NOTE: there's only single O tile
// GEMM I: S = Q*K
//
@@ -320,7 +336,7 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
gemmini_fence();
gemmini_fence();
#if 0 // TODO
#if 0 // TODO: speed up mvout to SMEM
// loop_ws variant that skips configuring strides
#define gemmini_loop_ws(I, J, K, pad_I, pad_J, pad_K, A, B, D, C, A_stride, B_stride, D_stride, C_stride, A_transpose, B_transpose, full_C, low_D, ex_accumulate, act, a_spad_id, b_spad_id, is_resadd) \
{ \
@@ -350,7 +366,7 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
}
}
// thread reconvergence
// reconverge from mmio divergence
threadblock_barrier(warpgroup_id_in_cluster, warps_per_warpgroup_per_core);
asm volatile("gemm_qk_finish_%=:" ::);
@@ -358,13 +374,13 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
if constexpr (DEBUG) {
if (warpgroup_id == 0) {
if (tile_k == 0) {
thread_block_copy_tile<B_ROW, B_COL>(smem_S, gmem_tmp_d0,
tid_in_warpgroup, threads_per_warpgroup,
warpgroup_id_in_cluster);
thread_block_copy_tile<B_ROW, B_COL, GEMMINI_DMA>(
smem_S, gmem_tmp_d0, tid_in_warpgroup, threads_per_warpgroup,
warpgroup_id_in_cluster);
} else if (tile_k == 1) {
thread_block_copy_tile<B_ROW, B_COL>(smem_S, gmem_tmp_d1,
tid_in_warpgroup, threads_per_warpgroup,
warpgroup_id_in_cluster);
thread_block_copy_tile<B_ROW, B_COL, GEMMINI_DMA>(
smem_S, gmem_tmp_d1, tid_in_warpgroup, threads_per_warpgroup,
warpgroup_id_in_cluster);
}
threadblock_barrier(warpgroup_id_in_cluster,
@@ -408,7 +424,6 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
}
}
#if 0
// data movement for K and V
//
// Q stays in SMEM for the entire loop
@@ -463,9 +478,9 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
threadblock_barrier(global_barrier_id, warps_per_threadblock_per_core);
// Oi rescale
thread_block_O_rescale(smem_O, smem_O /*in-place*/,
smem_O_row_scale, tid_in_warpgroup,
threads_per_warpgroup, warpgroup_id_in_cluster);
thread_block_O_rescale</*block_row_major=*/GEMMINI_DMA>(
smem_O, smem_O /*in-place*/, smem_O_row_scale, tid_in_warpgroup,
threads_per_warpgroup, warpgroup_id_in_cluster);
// rescale-to-PV-GEMM barrier
threadblock_barrier(warpgroup_id_in_cluster, warps_per_warpgroup_per_core);
@@ -474,19 +489,19 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
if (warpgroup_id == 0) {
// O before PV
if (tile_k == 0) {
thread_block_copy_tile<B_ROW, B_COL>(smem_P, gmem_tmp_d2, tid_in_warpgroup,
threads_per_warpgroup,
warpgroup_id_in_cluster);
thread_block_copy_tile<B_ROW, HEADDIM>(smem_O, gmem_tmp_d4, tid_in_warpgroup,
threads_per_warpgroup,
warpgroup_id_in_cluster);
thread_block_copy_tile<B_ROW, B_COL, GEMMINI_DMA>(
smem_P, gmem_tmp_d2, tid_in_warpgroup, threads_per_warpgroup,
warpgroup_id_in_cluster);
thread_block_copy_tile<B_ROW, HEADDIM, GEMMINI_DMA>(
smem_O, gmem_tmp_d4, tid_in_warpgroup, threads_per_warpgroup,
warpgroup_id_in_cluster);
} else if (tile_k == 1) {
thread_block_copy_tile<B_ROW, B_COL>(smem_P, gmem_tmp_d3, tid_in_warpgroup,
threads_per_warpgroup,
warpgroup_id_in_cluster);
thread_block_copy_tile<B_ROW, HEADDIM>(smem_O, gmem_tmp_d5, tid_in_warpgroup,
threads_per_warpgroup,
warpgroup_id_in_cluster);
thread_block_copy_tile<B_ROW, B_COL, GEMMINI_DMA>(
smem_P, gmem_tmp_d3, tid_in_warpgroup, threads_per_warpgroup,
warpgroup_id_in_cluster);
thread_block_copy_tile<B_ROW, HEADDIM, GEMMINI_DMA>(
smem_O, gmem_tmp_d5, tid_in_warpgroup, threads_per_warpgroup,
warpgroup_id_in_cluster);
}
threadblock_barrier(warpgroup_id_in_cluster,
@@ -498,134 +513,54 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
asm volatile("gemm_pv_start_%=:" ::);
if constexpr (!WARP_SPECIALIZED) {
// clear out accumulators before GEMM
initialize_accum_regs<0>();
initialize_accum_regs<1>();
if constexpr (GEMMINI_DMA) {
thread_block_gemm_single_tile<
float, MemLayout::K_major /* P matrix is row-major */,
MemLayout::block_row_major, B_ROW, HEADDIM, B_COL,
/*leading_dim_a=*/0, /*leading_dim_b=*/0,
/*load_accum=*/true,
/*write_to_smem=*/true>(
smem_P, smem_V, smem_O /*load accum*/, smem_O, tid_in_warpgroup,
threads_per_warpgroup, warpgroups_per_cluster,
warpgroup_id_in_cluster);
if (tid_in_warpgroup == 0) {
#if 0
if (tile_k == 0) {
gemmini_fence();
GEMMINI_CISC_CMD_I(0);
} else if (tile_k & 1) {
gemmini_fence();
GEMMINI_CISC_CMD_I(2);
} else {
thread_block_gemm_single_tile<float, MemLayout::K_major,
MemLayout::MN_major, B_ROW, HEADDIM,
B_COL,
/*leading_dim_a=*/0, /*leading_dim_b=*/0,
/*load_accum=*/true,
/*write_to_smem=*/true>(
smem_P, smem_V, smem_O /*load accum*/, smem_O, tid_in_warpgroup,
threads_per_warpgroup, warpgroups_per_cluster,
warpgroup_id_in_cluster);
// FIXME: wrong but fast
// thread_block_gemm_single_tile<float, MemLayout::MN_major,
// MemLayout::MN_major,
// B_ROW, HEADDIM, B_COL,
// /*leading_dim_a=*/0,
// /*leading_dim_b=*/0,
// /*load_accum=*/true,
// /*write_to_smem=*/true>(
// smem_P, smem_V, smem_O /*load accum*/, smem_O,
// tid_in_warpgroup, threads_per_warpgroup,
// warpgroups_per_cluster, warpgroup_id_in_cluster);
gemmini_fence();
GEMMINI_CISC_CMD_I(1);
}
} else {
// when warp-specialized, there's only enough warps to do 64x32 tile
// size so we need to do 2 GEMM calls
static_assert(B_ROW / 2 == 32,
"tile size assumption for warp-specialization not met");
#else
// do matmul
// among other things, this also configures CONFIG_BOUNDS so that the
// DMA knows the full matrix dimensions
sp_tiled_matmul_full_spad_ws(
spad_addr_P, spad_addr_V,
/*spad_D=*/0, /*spad_C=*/spad_addr_O,
/*I=*/(B_ROW / DIM), /*J=*/(HEADDIM / DIM), /*K=*/(B_COL / DIM),
/*pad_I=*/0, /*pad_J=*/0, /*pad_K=*/0,
/*a_transpose=*/0, /*b_transpose=*/0, /*full_C=*/0, /*low_D=*/0,
/*acc=*/0, /*act=*/NO_ACTIVATION, /*skips=*/skips_matmul);
#endif
// assumes smem_P is K-major
float *smem_P_half0 = smem_P;
float *smem_P_half1 = smem_P + (B_ROW / 2) * B_COL;
float *smem_O_half0 = smem_O;
float *smem_O_half1 = smem_O + (B_ROW / 2) * HEADDIM;
gemmini_fence();
gemmini_fence();
gemmini_fence();
gemmini_fence();
// clear out accumulators before GEMM
initialize_accum_regs<0>();
initialize_accum_regs<1>();
// mvout to SMEM
// GEMMINI_CISC_CMD_I(9);
sp_tiled_matmul_full_spad_ws(
/*spad_A=*/spad_addr_P, /*spad_B=*/spad_addr_V,
/*spad_D=*/0, /*spad_C=*/spad_addr_O,
/*I=*/(B_ROW / DIM), /*J=*/(HEADDIM / DIM), /*K=*/(B_COL/ DIM),
/*pad_I=*/0, /*pad_J=*/0, /*pad_K=*/0,
/*a_transpose=*/0, /*b_transpose=*/0, /*full_C=*/0, /*low_D=*/0,
/*acc=*/0, /*act=*/NO_ACTIVATION, /*skips=*/skips_mvout_spad);
gemmini_fence();
// split by rows into 2 chunks
if constexpr (GEMMINI_DMA) {
if constexpr (GEMMINI_DMA_FAST) {
thread_block_gemm_single_tile<float, MemLayout::MN_major,
MemLayout::MN_major, B_ROW / 2, HEADDIM,
B_COL,
/*leading_dim_a=*/0,
/*leading_dim_b=*/0,
/*load_accum=*/true,
/*write_to_smem=*/true>(
smem_P_half0, smem_V, smem_O_half0 /*load accum*/, smem_O_half0,
tid_in_warpgroup, threads_per_warpgroup, warpgroups_per_cluster,
warpgroup_id_in_cluster);
} else {
thread_block_gemm_single_tile<
float, MemLayout::K_major /* P matrix is row-major */,
MemLayout::block_row_major, B_ROW / 2, HEADDIM, B_COL,
/*leading_dim_a=*/0,
/*leading_dim_b=*/0,
/*load_accum=*/true,
/*write_to_smem=*/true>(
smem_P_half0, smem_V, smem_O_half0 /*load accum*/, smem_O_half0,
tid_in_warpgroup, threads_per_warpgroup, warpgroups_per_cluster,
warpgroup_id_in_cluster);
}
} else {
thread_block_gemm_single_tile<
float, MemLayout::K_major, MemLayout::MN_major, B_ROW / 2, HEADDIM,
B_COL, /*leading_dim_a=*/0, /*leading_dim_b=*/0,
/*load_accum=*/true,
/*write_to_smem=*/true>(
smem_P_half0, smem_V, smem_O_half0 /*load accum*/, smem_O_half0,
tid_in_warpgroup, threads_per_warpgroup, warpgroups_per_cluster,
warpgroup_id_in_cluster);
}
initialize_accum_regs<0>();
initialize_accum_regs<1>();
if constexpr (GEMMINI_DMA) {
if constexpr (GEMMINI_DMA_FAST) {
thread_block_gemm_single_tile<float, MemLayout::MN_major,
MemLayout::MN_major, B_ROW / 2, HEADDIM,
B_COL,
/*leading_dim_a=*/0,
/*leading_dim_b=*/0,
/*load_accum=*/true,
/*write_to_smem=*/true>(
smem_P_half1, smem_V, smem_O_half1 /*load accum*/, smem_O_half1,
tid_in_warpgroup, threads_per_warpgroup, warpgroups_per_cluster,
warpgroup_id_in_cluster);
} else {
thread_block_gemm_single_tile<
float, MemLayout::K_major /* P matrix is row-major */,
MemLayout::block_row_major, B_ROW / 2, HEADDIM, B_COL,
/*leading_dim_a=*/0,
/*leading_dim_b=*/0,
/*load_accum=*/true,
/*write_to_smem=*/true>(
smem_P_half1, smem_V, smem_O_half1 /*load accum*/, smem_O_half1,
tid_in_warpgroup, threads_per_warpgroup, warpgroups_per_cluster,
warpgroup_id_in_cluster);
}
} else {
thread_block_gemm_single_tile<
float, MemLayout::K_major, MemLayout::MN_major, B_ROW / 2, HEADDIM,
B_COL, /*leading_dim_a=*/0, /*leading_dim_b=*/0,
/*load_accum=*/true,
/*write_to_smem=*/true>(
smem_P_half1, smem_V, smem_O_half1 /*load accum*/, smem_O_half1,
tid_in_warpgroup, threads_per_warpgroup, warpgroups_per_cluster,
warpgroup_id_in_cluster);
if constexpr (DEBUG) {
// for copy-out to GMEM
gemmini_fence();
}
}
// reconverge from mmio divergence
threadblock_barrier(warpgroup_id_in_cluster, warps_per_warpgroup_per_core);
asm volatile("gemm_pv_finish_%=:" ::);
@@ -634,19 +569,20 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
if (warpgroup_id == 0) {
// O after PV
if (tile_k == 0) {
thread_block_copy_tile<B_ROW, HEADDIM>(smem_O, gmem_tmp_d6, tid_in_warpgroup,
threads_per_warpgroup,
warpgroup_id_in_cluster);
thread_block_copy_tile<B_ROW, HEADDIM, GEMMINI_DMA>(
smem_O, gmem_tmp_d6, tid_in_warpgroup, threads_per_warpgroup,
warpgroup_id_in_cluster);
} else if (tile_k == 1) {
thread_block_copy_tile<B_ROW, HEADDIM>(smem_O, gmem_tmp_d7, tid_in_warpgroup,
threads_per_warpgroup,
warpgroup_id_in_cluster);
thread_block_copy_tile<B_ROW, HEADDIM, GEMMINI_DMA>(
smem_O, gmem_tmp_d7, tid_in_warpgroup, threads_per_warpgroup,
warpgroup_id_in_cluster);
}
threadblock_barrier(warpgroup_id_in_cluster,
warps_per_warpgroup_per_core);
}
}
#if 0
#endif
}