flash: Write fast config for DMA
MAC utilization is 20-25% for the loop.
This commit is contained in:
@@ -13,11 +13,12 @@
|
||||
#define HEADDIM 64
|
||||
|
||||
constexpr uint32_t ROWMAX_SETS = 3;
|
||||
constexpr bool DEBUG = true;
|
||||
constexpr bool DEBUG = false;
|
||||
constexpr bool WARP_SPECIALIZED = true;
|
||||
|
||||
constexpr uint32_t DEV_FAKE_SMEM_START_ADDR = 0xf0000000;
|
||||
|
||||
constexpr bool GEMMINI_DMA_FAST = true;
|
||||
constexpr bool Q_IS_K_MAJOR = true;
|
||||
|
||||
// temporary safety stop for wrong configs
|
||||
@@ -763,6 +764,7 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
|
||||
// GEMM I: S = Q*K
|
||||
//
|
||||
// FIXME: deduplicate this between GEMM II
|
||||
asm volatile("gemm_qk_start_%=:" ::);
|
||||
if constexpr (!WARP_SPECIALIZED) {
|
||||
// clear out accumulators before GEMM
|
||||
initialize_accum_regs<0>();
|
||||
@@ -815,6 +817,17 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
|
||||
|
||||
// split by rows into 2 chunks
|
||||
if constexpr (GEMMINI_DMA) {
|
||||
if constexpr (GEMMINI_DMA_FAST) {
|
||||
thread_block_gemm_single_tile<float, MemLayout::MN_major,
|
||||
MemLayout::MN_major, B_ROW / 2,
|
||||
B_COL, HEADDIM, /*leading_dim_a=*/0,
|
||||
/*leading_dim_b=*/0,
|
||||
/*load_accum=*/false,
|
||||
/*write_to_smem=*/true>(
|
||||
smem_Q_half0, smem_K, nullptr /*ignore accum*/, smem_S_half0,
|
||||
tid_in_warpgroup, threads_per_warpgroup, warpgroups_per_cluster,
|
||||
warpgroup_id_in_cluster);
|
||||
} else {
|
||||
thread_block_gemm_single_tile<float, MemLayout::block_row_major,
|
||||
MemLayout::block_row_major, B_ROW / 2,
|
||||
B_COL, HEADDIM, /*leading_dim_a=*/0,
|
||||
@@ -824,6 +837,7 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
|
||||
smem_Q_half0, smem_K, nullptr /*ignore accum*/, smem_S_half0,
|
||||
tid_in_warpgroup, threads_per_warpgroup, warpgroups_per_cluster,
|
||||
warpgroup_id_in_cluster);
|
||||
}
|
||||
} else if constexpr (Q_IS_K_MAJOR) {
|
||||
thread_block_gemm_single_tile<
|
||||
float, MemLayout::K_major, MemLayout::MN_major, B_ROW / 2, B_COL,
|
||||
@@ -848,6 +862,17 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
|
||||
initialize_accum_regs<1>();
|
||||
|
||||
if constexpr (GEMMINI_DMA) {
|
||||
if constexpr (GEMMINI_DMA_FAST) {
|
||||
thread_block_gemm_single_tile<float, MemLayout::MN_major,
|
||||
MemLayout::MN_major, B_ROW / 2,
|
||||
B_COL, HEADDIM, /*leading_dim_a=*/0,
|
||||
/*leading_dim_b=*/0,
|
||||
/*load_accum=*/false,
|
||||
/*write_to_smem=*/true>(
|
||||
smem_Q_half1, smem_K, nullptr /*ignore accum*/, smem_S_half1,
|
||||
tid_in_warpgroup, threads_per_warpgroup, warpgroups_per_cluster,
|
||||
warpgroup_id_in_cluster);
|
||||
} else {
|
||||
thread_block_gemm_single_tile<float, MemLayout::block_row_major,
|
||||
MemLayout::block_row_major, B_ROW / 2,
|
||||
B_COL, HEADDIM, /*leading_dim_a=*/0,
|
||||
@@ -857,6 +882,7 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
|
||||
smem_Q_half1, smem_K, nullptr /*ignore accum*/, smem_S_half1,
|
||||
tid_in_warpgroup, threads_per_warpgroup, warpgroups_per_cluster,
|
||||
warpgroup_id_in_cluster);
|
||||
}
|
||||
} else if constexpr (Q_IS_K_MAJOR) {
|
||||
thread_block_gemm_single_tile<
|
||||
float, MemLayout::K_major, MemLayout::MN_major, B_ROW / 2, B_COL,
|
||||
@@ -888,6 +914,8 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
|
||||
// protect write to SMEM (smem_S) before softmax
|
||||
threadblock_barrier(warpgroup_id_in_cluster, warps_per_warpgroup_per_core);
|
||||
|
||||
asm volatile("gemm_qk_finish_%=:" ::);
|
||||
|
||||
if constexpr (DEBUG) {
|
||||
if (warpgroup_id == 0) {
|
||||
if (tile_k == 0) {
|
||||
@@ -921,6 +949,7 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
|
||||
// data movement for K and V
|
||||
//
|
||||
// Q stays in SMEM for the entire loop
|
||||
asm volatile("move_k_v_start_%=:" ::);
|
||||
if constexpr (GEMMINI_DMA) {
|
||||
// NOTE: Beware of race conditions; with warp specialization, we need to
|
||||
// make sure below command code to DMA is not executed simultaneously
|
||||
@@ -965,6 +994,7 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
|
||||
HEADDIM, 0 /* full N-dimension */, tile_k, gmem_V, smem_V,
|
||||
tid_in_warpgroup);
|
||||
}
|
||||
asm volatile("move_k_v_finish_%=:" ::);
|
||||
|
||||
// protect write to SMEM
|
||||
threadblock_barrier(warpgroup_id_in_cluster, warps_per_warpgroup_per_core);
|
||||
@@ -995,8 +1025,6 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
|
||||
// inter-warpgroup barrier before GEMM II
|
||||
threadblock_barrier(global_barrier_id, warps_per_threadblock_per_core);
|
||||
|
||||
// GEMM II: O = O + P*V
|
||||
|
||||
// Oi rescale
|
||||
thread_block_O_rescale(smem_O, smem_O /*in-place*/,
|
||||
smem_O_row_scale, tid_in_warpgroup,
|
||||
@@ -1029,6 +1057,10 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
|
||||
}
|
||||
}
|
||||
|
||||
// GEMM II: O = O + P*V
|
||||
|
||||
asm volatile("gemm_pv_start_%=:" ::);
|
||||
|
||||
if constexpr (!WARP_SPECIALIZED) {
|
||||
// clear out accumulators before GEMM
|
||||
initialize_accum_regs<0>();
|
||||
@@ -1084,6 +1116,18 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
|
||||
|
||||
// split by rows into 2 chunks
|
||||
if constexpr (GEMMINI_DMA) {
|
||||
if constexpr (GEMMINI_DMA_FAST) {
|
||||
thread_block_gemm_single_tile<float, MemLayout::MN_major,
|
||||
MemLayout::MN_major, B_ROW / 2, HEADDIM,
|
||||
B_COL,
|
||||
/*leading_dim_a=*/0,
|
||||
/*leading_dim_b=*/0,
|
||||
/*load_accum=*/true,
|
||||
/*write_to_smem=*/true>(
|
||||
smem_P_half0, smem_V, smem_O_half0 /*load accum*/, smem_O_half0,
|
||||
tid_in_warpgroup, threads_per_warpgroup, warpgroups_per_cluster,
|
||||
warpgroup_id_in_cluster);
|
||||
} else {
|
||||
thread_block_gemm_single_tile<
|
||||
float, MemLayout::K_major /* P matrix is row-major */,
|
||||
MemLayout::block_row_major, B_ROW / 2, HEADDIM, B_COL,
|
||||
@@ -1094,6 +1138,7 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
|
||||
smem_P_half0, smem_V, smem_O_half0 /*load accum*/, smem_O_half0,
|
||||
tid_in_warpgroup, threads_per_warpgroup, warpgroups_per_cluster,
|
||||
warpgroup_id_in_cluster);
|
||||
}
|
||||
} else {
|
||||
thread_block_gemm_single_tile<
|
||||
float, MemLayout::K_major, MemLayout::MN_major, B_ROW / 2, HEADDIM,
|
||||
@@ -1109,6 +1154,18 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
|
||||
initialize_accum_regs<1>();
|
||||
|
||||
if constexpr (GEMMINI_DMA) {
|
||||
if constexpr (GEMMINI_DMA_FAST) {
|
||||
thread_block_gemm_single_tile<float, MemLayout::MN_major,
|
||||
MemLayout::MN_major, B_ROW / 2, HEADDIM,
|
||||
B_COL,
|
||||
/*leading_dim_a=*/0,
|
||||
/*leading_dim_b=*/0,
|
||||
/*load_accum=*/true,
|
||||
/*write_to_smem=*/true>(
|
||||
smem_P_half1, smem_V, smem_O_half1 /*load accum*/, smem_O_half1,
|
||||
tid_in_warpgroup, threads_per_warpgroup, warpgroups_per_cluster,
|
||||
warpgroup_id_in_cluster);
|
||||
} else {
|
||||
thread_block_gemm_single_tile<
|
||||
float, MemLayout::K_major /* P matrix is row-major */,
|
||||
MemLayout::block_row_major, B_ROW / 2, HEADDIM, B_COL,
|
||||
@@ -1119,6 +1176,7 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
|
||||
smem_P_half1, smem_V, smem_O_half1 /*load accum*/, smem_O_half1,
|
||||
tid_in_warpgroup, threads_per_warpgroup, warpgroups_per_cluster,
|
||||
warpgroup_id_in_cluster);
|
||||
}
|
||||
} else {
|
||||
thread_block_gemm_single_tile<
|
||||
float, MemLayout::K_major, MemLayout::MN_major, B_ROW / 2, HEADDIM,
|
||||
@@ -1133,6 +1191,8 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
|
||||
|
||||
threadblock_barrier(warpgroup_id_in_cluster, warps_per_warpgroup_per_core);
|
||||
|
||||
asm volatile("gemm_pv_finish_%=:" ::);
|
||||
|
||||
if constexpr (DEBUG) {
|
||||
if (warpgroup_id == 0) {
|
||||
// O after PV
|
||||
|
||||
Reference in New Issue
Block a user