flash: Write fast config for DMA

MAC utilization is 20-25% for the loop.
This commit is contained in:
Hansung Kim
2024-09-07 20:46:58 -07:00
parent 8d32a03d09
commit 03308f8033

View File

@@ -13,11 +13,12 @@
#define HEADDIM 64
constexpr uint32_t ROWMAX_SETS = 3;
constexpr bool DEBUG = true;
constexpr bool DEBUG = false;
constexpr bool WARP_SPECIALIZED = true;
constexpr uint32_t DEV_FAKE_SMEM_START_ADDR = 0xf0000000;
constexpr bool GEMMINI_DMA_FAST = true;
constexpr bool Q_IS_K_MAJOR = true;
// temporary safety stop for wrong configs
@@ -763,6 +764,7 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
// GEMM I: S = Q*K
//
// FIXME: deduplicate this between GEMM II
asm volatile("gemm_qk_start_%=:" ::);
if constexpr (!WARP_SPECIALIZED) {
// clear out accumulators before GEMM
initialize_accum_regs<0>();
@@ -815,15 +817,27 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
// split by rows into 2 chunks
if constexpr (GEMMINI_DMA) {
thread_block_gemm_single_tile<float, MemLayout::block_row_major,
MemLayout::block_row_major, B_ROW / 2,
B_COL, HEADDIM, /*leading_dim_a=*/0,
/*leading_dim_b=*/0,
/*load_accum=*/false,
/*write_to_smem=*/true>(
smem_Q_half0, smem_K, nullptr /*ignore accum*/, smem_S_half0,
tid_in_warpgroup, threads_per_warpgroup, warpgroups_per_cluster,
warpgroup_id_in_cluster);
if constexpr (GEMMINI_DMA_FAST) {
thread_block_gemm_single_tile<float, MemLayout::MN_major,
MemLayout::MN_major, B_ROW / 2,
B_COL, HEADDIM, /*leading_dim_a=*/0,
/*leading_dim_b=*/0,
/*load_accum=*/false,
/*write_to_smem=*/true>(
smem_Q_half0, smem_K, nullptr /*ignore accum*/, smem_S_half0,
tid_in_warpgroup, threads_per_warpgroup, warpgroups_per_cluster,
warpgroup_id_in_cluster);
} else {
thread_block_gemm_single_tile<float, MemLayout::block_row_major,
MemLayout::block_row_major, B_ROW / 2,
B_COL, HEADDIM, /*leading_dim_a=*/0,
/*leading_dim_b=*/0,
/*load_accum=*/false,
/*write_to_smem=*/true>(
smem_Q_half0, smem_K, nullptr /*ignore accum*/, smem_S_half0,
tid_in_warpgroup, threads_per_warpgroup, warpgroups_per_cluster,
warpgroup_id_in_cluster);
}
} else if constexpr (Q_IS_K_MAJOR) {
thread_block_gemm_single_tile<
float, MemLayout::K_major, MemLayout::MN_major, B_ROW / 2, B_COL,
@@ -848,15 +862,27 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
initialize_accum_regs<1>();
if constexpr (GEMMINI_DMA) {
thread_block_gemm_single_tile<float, MemLayout::block_row_major,
MemLayout::block_row_major, B_ROW / 2,
B_COL, HEADDIM, /*leading_dim_a=*/0,
/*leading_dim_b=*/0,
/*load_accum=*/false,
/*write_to_smem=*/true>(
smem_Q_half1, smem_K, nullptr /*ignore accum*/, smem_S_half1,
tid_in_warpgroup, threads_per_warpgroup, warpgroups_per_cluster,
warpgroup_id_in_cluster);
if constexpr (GEMMINI_DMA_FAST) {
thread_block_gemm_single_tile<float, MemLayout::MN_major,
MemLayout::MN_major, B_ROW / 2,
B_COL, HEADDIM, /*leading_dim_a=*/0,
/*leading_dim_b=*/0,
/*load_accum=*/false,
/*write_to_smem=*/true>(
smem_Q_half1, smem_K, nullptr /*ignore accum*/, smem_S_half1,
tid_in_warpgroup, threads_per_warpgroup, warpgroups_per_cluster,
warpgroup_id_in_cluster);
} else {
thread_block_gemm_single_tile<float, MemLayout::block_row_major,
MemLayout::block_row_major, B_ROW / 2,
B_COL, HEADDIM, /*leading_dim_a=*/0,
/*leading_dim_b=*/0,
/*load_accum=*/false,
/*write_to_smem=*/true>(
smem_Q_half1, smem_K, nullptr /*ignore accum*/, smem_S_half1,
tid_in_warpgroup, threads_per_warpgroup, warpgroups_per_cluster,
warpgroup_id_in_cluster);
}
} else if constexpr (Q_IS_K_MAJOR) {
thread_block_gemm_single_tile<
float, MemLayout::K_major, MemLayout::MN_major, B_ROW / 2, B_COL,
@@ -888,6 +914,8 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
// protect write to SMEM (smem_S) before softmax
threadblock_barrier(warpgroup_id_in_cluster, warps_per_warpgroup_per_core);
asm volatile("gemm_qk_finish_%=:" ::);
if constexpr (DEBUG) {
if (warpgroup_id == 0) {
if (tile_k == 0) {
@@ -921,6 +949,7 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
// data movement for K and V
//
// Q stays in SMEM for the entire loop
asm volatile("move_k_v_start_%=:" ::);
if constexpr (GEMMINI_DMA) {
// NOTE: Beware of race conditions; with warp specialization, we need to
// make sure below command code to DMA is not executed simultaneously
@@ -965,6 +994,7 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
HEADDIM, 0 /* full N-dimension */, tile_k, gmem_V, smem_V,
tid_in_warpgroup);
}
asm volatile("move_k_v_finish_%=:" ::);
// protect write to SMEM
threadblock_barrier(warpgroup_id_in_cluster, warps_per_warpgroup_per_core);
@@ -995,8 +1025,6 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
// inter-warpgroup barrier before GEMM II
threadblock_barrier(global_barrier_id, warps_per_threadblock_per_core);
// GEMM II: O = O + P*V
// Oi rescale
thread_block_O_rescale(smem_O, smem_O /*in-place*/,
smem_O_row_scale, tid_in_warpgroup,
@@ -1029,6 +1057,10 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
}
}
// GEMM II: O = O + P*V
asm volatile("gemm_pv_start_%=:" ::);
if constexpr (!WARP_SPECIALIZED) {
// clear out accumulators before GEMM
initialize_accum_regs<0>();
@@ -1084,16 +1116,29 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
// split by rows into 2 chunks
if constexpr (GEMMINI_DMA) {
thread_block_gemm_single_tile<
float, MemLayout::K_major /* P matrix is row-major */,
MemLayout::block_row_major, B_ROW / 2, HEADDIM, B_COL,
/*leading_dim_a=*/0,
/*leading_dim_b=*/0,
/*load_accum=*/true,
/*write_to_smem=*/true>(
smem_P_half0, smem_V, smem_O_half0 /*load accum*/, smem_O_half0,
tid_in_warpgroup, threads_per_warpgroup, warpgroups_per_cluster,
warpgroup_id_in_cluster);
if constexpr (GEMMINI_DMA_FAST) {
thread_block_gemm_single_tile<float, MemLayout::MN_major,
MemLayout::MN_major, B_ROW / 2, HEADDIM,
B_COL,
/*leading_dim_a=*/0,
/*leading_dim_b=*/0,
/*load_accum=*/true,
/*write_to_smem=*/true>(
smem_P_half0, smem_V, smem_O_half0 /*load accum*/, smem_O_half0,
tid_in_warpgroup, threads_per_warpgroup, warpgroups_per_cluster,
warpgroup_id_in_cluster);
} else {
thread_block_gemm_single_tile<
float, MemLayout::K_major /* P matrix is row-major */,
MemLayout::block_row_major, B_ROW / 2, HEADDIM, B_COL,
/*leading_dim_a=*/0,
/*leading_dim_b=*/0,
/*load_accum=*/true,
/*write_to_smem=*/true>(
smem_P_half0, smem_V, smem_O_half0 /*load accum*/, smem_O_half0,
tid_in_warpgroup, threads_per_warpgroup, warpgroups_per_cluster,
warpgroup_id_in_cluster);
}
} else {
thread_block_gemm_single_tile<
float, MemLayout::K_major, MemLayout::MN_major, B_ROW / 2, HEADDIM,
@@ -1109,16 +1154,29 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
initialize_accum_regs<1>();
if constexpr (GEMMINI_DMA) {
thread_block_gemm_single_tile<
float, MemLayout::K_major /* P matrix is row-major */,
MemLayout::block_row_major, B_ROW / 2, HEADDIM, B_COL,
/*leading_dim_a=*/0,
/*leading_dim_b=*/0,
/*load_accum=*/true,
/*write_to_smem=*/true>(
smem_P_half1, smem_V, smem_O_half1 /*load accum*/, smem_O_half1,
tid_in_warpgroup, threads_per_warpgroup, warpgroups_per_cluster,
warpgroup_id_in_cluster);
if constexpr (GEMMINI_DMA_FAST) {
thread_block_gemm_single_tile<float, MemLayout::MN_major,
MemLayout::MN_major, B_ROW / 2, HEADDIM,
B_COL,
/*leading_dim_a=*/0,
/*leading_dim_b=*/0,
/*load_accum=*/true,
/*write_to_smem=*/true>(
smem_P_half1, smem_V, smem_O_half1 /*load accum*/, smem_O_half1,
tid_in_warpgroup, threads_per_warpgroup, warpgroups_per_cluster,
warpgroup_id_in_cluster);
} else {
thread_block_gemm_single_tile<
float, MemLayout::K_major /* P matrix is row-major */,
MemLayout::block_row_major, B_ROW / 2, HEADDIM, B_COL,
/*leading_dim_a=*/0,
/*leading_dim_b=*/0,
/*load_accum=*/true,
/*write_to_smem=*/true>(
smem_P_half1, smem_V, smem_O_half1 /*load accum*/, smem_O_half1,
tid_in_warpgroup, threads_per_warpgroup, warpgroups_per_cluster,
warpgroup_id_in_cluster);
}
} else {
thread_block_gemm_single_tile<
float, MemLayout::K_major, MemLayout::MN_major, B_ROW / 2, HEADDIM,
@@ -1133,6 +1191,8 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
threadblock_barrier(warpgroup_id_in_cluster, warps_per_warpgroup_per_core);
asm volatile("gemm_pv_finish_%=:" ::);
if constexpr (DEBUG) {
if (warpgroup_id == 0) {
// O after PV