flash: Write DMA code for warp-specialized
TODO: result unverified
This commit is contained in:
@@ -14,7 +14,7 @@
|
|||||||
|
|
||||||
constexpr uint32_t ROWMAX_SETS = 3;
|
constexpr uint32_t ROWMAX_SETS = 3;
|
||||||
constexpr bool DEBUG = true;
|
constexpr bool DEBUG = true;
|
||||||
constexpr bool WARP_SPECIALIZED = false;
|
constexpr bool WARP_SPECIALIZED = true;
|
||||||
|
|
||||||
constexpr uint32_t DEV_FAKE_SMEM_START_ADDR = 0xf0000000;
|
constexpr uint32_t DEV_FAKE_SMEM_START_ADDR = 0xf0000000;
|
||||||
|
|
||||||
@@ -492,11 +492,6 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
|
|||||||
warpgroup_id % warpgroups_per_cluster;
|
warpgroup_id % warpgroups_per_cluster;
|
||||||
const uint32_t tid_in_warpgroup = tid_in_threadblock % threads_per_warpgroup;
|
const uint32_t tid_in_warpgroup = tid_in_threadblock % threads_per_warpgroup;
|
||||||
|
|
||||||
// FIXME do proper software pipelining
|
|
||||||
// if (WARP_SPECIALIZED && warpgroup_id_in_cluster != 1) {
|
|
||||||
// return;
|
|
||||||
// }
|
|
||||||
|
|
||||||
const uint32_t dim_seqlen = arg->dim_seqlen;
|
const uint32_t dim_seqlen = arg->dim_seqlen;
|
||||||
const uint32_t dim_headdim = arg->dim_headdim;
|
const uint32_t dim_headdim = arg->dim_headdim;
|
||||||
|
|
||||||
@@ -597,7 +592,6 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
|
|||||||
// sharedmem "scratchpad" area to put temporary data, e.g. for tree reduction
|
// sharedmem "scratchpad" area to put temporary data, e.g. for tree reduction
|
||||||
// in rowsum
|
// in rowsum
|
||||||
// NOTE: out-of bounds is not checked
|
// NOTE: out-of bounds is not checked
|
||||||
// TODO: reduce this from B_ROW to NUM_WARPS
|
|
||||||
constexpr uint32_t smem_scratchpad_size =
|
constexpr uint32_t smem_scratchpad_size =
|
||||||
threads_per_warpgroup * 2 /*arbitrary slack*/;
|
threads_per_warpgroup * 2 /*arbitrary slack*/;
|
||||||
float *smem_scratchpad_0 = smem_cursor;
|
float *smem_scratchpad_0 = smem_cursor;
|
||||||
@@ -619,6 +613,11 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
|
|||||||
float *smem_scratchpad =
|
float *smem_scratchpad =
|
||||||
(warpgroup_id % 2) ? smem_scratchpad_1 : smem_scratchpad_0;
|
(warpgroup_id % 2) ? smem_scratchpad_1 : smem_scratchpad_0;
|
||||||
|
|
||||||
|
const auto spad_addr_Q = (warpgroup_id % 2) ? spad_addr_Q1 : spad_addr_Q0;
|
||||||
|
const auto spad_addr_K = (warpgroup_id % 2) ? spad_addr_K1 : spad_addr_K0;
|
||||||
|
const auto spad_addr_V = (warpgroup_id % 2) ? spad_addr_V1 : spad_addr_V0;
|
||||||
|
const auto spad_addr_S = (warpgroup_id % 2) ? spad_addr_S1 : spad_addr_S0;
|
||||||
|
|
||||||
// initialize rowmax/rowsum values in sharedmem
|
// initialize rowmax/rowsum values in sharedmem
|
||||||
thread_block_init_sharedmem(tid_in_warpgroup, threads_per_warpgroup, smem_O,
|
thread_block_init_sharedmem(tid_in_warpgroup, threads_per_warpgroup, smem_O,
|
||||||
smem_rowmax, smem_rowsum, smem_O_row_scale);
|
smem_rowmax, smem_rowsum, smem_O_row_scale);
|
||||||
@@ -626,7 +625,7 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
|
|||||||
constexpr uint32_t global_barrier_id = NUM_WARPS - 1; // arbitrary
|
constexpr uint32_t global_barrier_id = NUM_WARPS - 1; // arbitrary
|
||||||
|
|
||||||
// delay warpgroup 0 by 1 iteration to do ping-pong scheduling
|
// delay warpgroup 0 by 1 iteration to do ping-pong scheduling
|
||||||
if (warpgroup_id == 1) {
|
if (WARP_SPECIALIZED && warpgroup_id == 1) {
|
||||||
threadblock_barrier(global_barrier_id, warps_per_threadblock_per_core);
|
threadblock_barrier(global_barrier_id, warps_per_threadblock_per_core);
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -667,15 +666,16 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
|
|||||||
//
|
//
|
||||||
static_assert(B_ROW == B_COL, "currently only supports square tiles");
|
static_assert(B_ROW == B_COL, "currently only supports square tiles");
|
||||||
|
|
||||||
static_assert(warps_per_warpgroup_per_core == 8); // FIXME nocheckin
|
|
||||||
|
|
||||||
if constexpr (GEMMINI_DMA) {
|
if constexpr (GEMMINI_DMA) {
|
||||||
asm volatile("dma_move_start_%=:" ::);
|
asm volatile("dma_move_start_%=:" ::);
|
||||||
|
|
||||||
if (tid_in_threadblock == 0) {
|
if (tid_in_warpgroup == 0) {
|
||||||
|
const float *gmem_Q_tile = gmem_Q + HEADDIM * B_ROW * warpgroup_id;
|
||||||
|
const float *gmem_K_tile = gmem_K;
|
||||||
// configure the GMEM addresses for the DMA to read from
|
// configure the GMEM addresses for the DMA to read from
|
||||||
ROCC_INSTRUCTION_RS1_RS2(XCUSTOM_ACC, (uint64_t)(gmem_Q),
|
ROCC_INSTRUCTION_RS1_RS2(XCUSTOM_ACC, (uint64_t)(gmem_Q_tile),
|
||||||
(uint64_t)(gmem_K), k_LOOP_WS_CONFIG_ADDRS_AB)
|
(uint64_t)(gmem_K_tile),
|
||||||
|
k_LOOP_WS_CONFIG_ADDRS_AB)
|
||||||
// configure address strides for the DMA
|
// configure address strides for the DMA
|
||||||
GEMMINI_CISC_CMD_R((dim_seqlen << 16) | (HEADDIM << 8) |
|
GEMMINI_CISC_CMD_R((dim_seqlen << 16) | (HEADDIM << 8) |
|
||||||
8 /*k_LOOP_WS_CONFIG_STRIDES_AB*/);
|
8 /*k_LOOP_WS_CONFIG_STRIDES_AB*/);
|
||||||
@@ -691,8 +691,8 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
|
|||||||
// among other things, this also configures CONFIG_BOUNDS so that the
|
// among other things, this also configures CONFIG_BOUNDS so that the
|
||||||
// DMA knows the full matrix dimensions
|
// DMA knows the full matrix dimensions
|
||||||
sp_tiled_matmul_full_spad_ws(
|
sp_tiled_matmul_full_spad_ws(
|
||||||
spad_addr_Q0, spad_addr_K0,
|
spad_addr_Q, spad_addr_K,
|
||||||
/*spad_D=*/0, /*spad_C=*/spad_addr_S0,
|
/*spad_D=*/0, /*spad_C=*/spad_addr_S,
|
||||||
/*I=*/(B_ROW / DIM), /*J=*/(B_COL / DIM), /*K=*/(HEADDIM / DIM),
|
/*I=*/(B_ROW / DIM), /*J=*/(B_COL / DIM), /*K=*/(HEADDIM / DIM),
|
||||||
/*pad_I=*/0, /*pad_J=*/0, /*pad_K=*/0,
|
/*pad_I=*/0, /*pad_J=*/0, /*pad_K=*/0,
|
||||||
/*a_transpose=*/0, /*b_transpose=*/0, /*full_C=*/0, /*low_D=*/0,
|
/*a_transpose=*/0, /*b_transpose=*/0, /*full_C=*/0, /*low_D=*/0,
|
||||||
@@ -803,8 +803,9 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
|
|||||||
"tile size assumption for warp-specialization not met");
|
"tile size assumption for warp-specialization not met");
|
||||||
|
|
||||||
float *smem_Q_half0 = smem_Q;
|
float *smem_Q_half0 = smem_Q;
|
||||||
float *smem_Q_half1 = Q_IS_K_MAJOR ? smem_Q + (B_ROW / 2) * HEADDIM
|
float *smem_Q_half1 = (Q_IS_K_MAJOR || GEMMINI_DMA)
|
||||||
: smem_Q + (B_ROW / 2);
|
? smem_Q + (B_ROW / 2) * HEADDIM
|
||||||
|
: smem_Q + (B_ROW / 2);
|
||||||
float *smem_S_half0 = smem_S;
|
float *smem_S_half0 = smem_S;
|
||||||
float *smem_S_half1 = smem_S + (B_ROW / 2) * B_COL;
|
float *smem_S_half1 = smem_S + (B_ROW / 2) * B_COL;
|
||||||
|
|
||||||
@@ -813,8 +814,17 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
|
|||||||
initialize_accum_regs<1>();
|
initialize_accum_regs<1>();
|
||||||
|
|
||||||
// split by rows into 2 chunks
|
// split by rows into 2 chunks
|
||||||
// TODO: GEMMINI_DMA
|
if constexpr (GEMMINI_DMA) {
|
||||||
if constexpr (Q_IS_K_MAJOR) {
|
thread_block_gemm_single_tile<float, MemLayout::block_row_major,
|
||||||
|
MemLayout::block_row_major, B_ROW / 2,
|
||||||
|
B_COL, HEADDIM, /*leading_dim_a=*/0,
|
||||||
|
/*leading_dim_b=*/0,
|
||||||
|
/*load_accum=*/false,
|
||||||
|
/*write_to_smem=*/true>(
|
||||||
|
smem_Q_half0, smem_K, nullptr /*ignore accum*/, smem_S_half0,
|
||||||
|
tid_in_warpgroup, threads_per_warpgroup, warpgroups_per_cluster,
|
||||||
|
warpgroup_id_in_cluster);
|
||||||
|
} else if constexpr (Q_IS_K_MAJOR) {
|
||||||
thread_block_gemm_single_tile<
|
thread_block_gemm_single_tile<
|
||||||
float, MemLayout::K_major, MemLayout::MN_major, B_ROW / 2, B_COL,
|
float, MemLayout::K_major, MemLayout::MN_major, B_ROW / 2, B_COL,
|
||||||
HEADDIM, /*leading_dim_a=*/0, /*leading_dim_b=*/0,
|
HEADDIM, /*leading_dim_a=*/0, /*leading_dim_b=*/0,
|
||||||
@@ -837,8 +847,17 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
|
|||||||
initialize_accum_regs<0>();
|
initialize_accum_regs<0>();
|
||||||
initialize_accum_regs<1>();
|
initialize_accum_regs<1>();
|
||||||
|
|
||||||
// TODO: GEMMINI_DMA
|
if constexpr (GEMMINI_DMA) {
|
||||||
if constexpr (Q_IS_K_MAJOR) {
|
thread_block_gemm_single_tile<float, MemLayout::block_row_major,
|
||||||
|
MemLayout::block_row_major, B_ROW / 2,
|
||||||
|
B_COL, HEADDIM, /*leading_dim_a=*/0,
|
||||||
|
/*leading_dim_b=*/0,
|
||||||
|
/*load_accum=*/false,
|
||||||
|
/*write_to_smem=*/true>(
|
||||||
|
smem_Q_half1, smem_K, nullptr /*ignore accum*/, smem_S_half1,
|
||||||
|
tid_in_warpgroup, threads_per_warpgroup, warpgroups_per_cluster,
|
||||||
|
warpgroup_id_in_cluster);
|
||||||
|
} else if constexpr (Q_IS_K_MAJOR) {
|
||||||
thread_block_gemm_single_tile<
|
thread_block_gemm_single_tile<
|
||||||
float, MemLayout::K_major, MemLayout::MN_major, B_ROW / 2, B_COL,
|
float, MemLayout::K_major, MemLayout::MN_major, B_ROW / 2, B_COL,
|
||||||
HEADDIM, /*leading_dim_a=*/0, /*leading_dim_b=*/0,
|
HEADDIM, /*leading_dim_a=*/0, /*leading_dim_b=*/0,
|
||||||
@@ -903,7 +922,11 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
|
|||||||
//
|
//
|
||||||
// Q stays in SMEM for the entire loop
|
// Q stays in SMEM for the entire loop
|
||||||
if constexpr (GEMMINI_DMA) {
|
if constexpr (GEMMINI_DMA) {
|
||||||
if (tid_in_threadblock == 0) {
|
// NOTE: Beware of race conditions; with warp specialization, we need to
|
||||||
|
// make sure below command code to DMA is not executed simultaneously
|
||||||
|
// from the two warpgroups (which will result in hardware fault).
|
||||||
|
// Currently the ping-pong scheduling scheme prevents that.
|
||||||
|
if (tid_in_warpgroup == 0) {
|
||||||
// configure GMEM addresses for K and V tiles
|
// configure GMEM addresses for K and V tiles
|
||||||
// load K for the next iteration
|
// load K for the next iteration
|
||||||
const float *gmem_K_tile = gmem_K + (B_COL * (tile_k + 1));
|
const float *gmem_K_tile = gmem_K + (B_COL * (tile_k + 1));
|
||||||
@@ -920,8 +943,8 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
|
|||||||
|
|
||||||
// do DMA
|
// do DMA
|
||||||
sp_tiled_matmul_full_spad_ws(
|
sp_tiled_matmul_full_spad_ws(
|
||||||
spad_addr_K0, spad_addr_V0,
|
spad_addr_K, spad_addr_V,
|
||||||
/*spad_D=*/0, /*spad_C=*/spad_addr_S0,
|
/*spad_D=*/0, /*spad_C=*/spad_addr_S,
|
||||||
/*I=*/(HEADDIM / DIM), /*J=*/(HEADDIM / DIM), /*K=*/(B_COL / DIM),
|
/*I=*/(HEADDIM / DIM), /*J=*/(HEADDIM / DIM), /*K=*/(B_COL / DIM),
|
||||||
/*pad_I=*/0, /*pad_J=*/0, /*pad_K=*/0,
|
/*pad_I=*/0, /*pad_J=*/0, /*pad_K=*/0,
|
||||||
/*a_transpose=*/0, /*b_transpose=*/0, /*full_C=*/0, /*low_D=*/0,
|
/*a_transpose=*/0, /*b_transpose=*/0, /*full_C=*/0, /*low_D=*/0,
|
||||||
@@ -1044,9 +1067,6 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
|
|||||||
// warpgroups_per_cluster, warpgroup_id_in_cluster);
|
// warpgroups_per_cluster, warpgroup_id_in_cluster);
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
static_assert(!WARP_SPECIALIZED || !GEMMINI_DMA,
|
|
||||||
"warp specialization unimplemented for dma");
|
|
||||||
|
|
||||||
// when warp-specialized, there's only enough warps to do 64x32 tile
|
// when warp-specialized, there's only enough warps to do 64x32 tile
|
||||||
// size so we need to do 2 GEMM calls
|
// size so we need to do 2 GEMM calls
|
||||||
static_assert(B_ROW / 2 == 32,
|
static_assert(B_ROW / 2 == 32,
|
||||||
@@ -1063,27 +1083,52 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
|
|||||||
initialize_accum_regs<1>();
|
initialize_accum_regs<1>();
|
||||||
|
|
||||||
// split by rows into 2 chunks
|
// split by rows into 2 chunks
|
||||||
// TODO: GEMMINI_DMA
|
if constexpr (GEMMINI_DMA) {
|
||||||
thread_block_gemm_single_tile<
|
thread_block_gemm_single_tile<
|
||||||
float, MemLayout::K_major, MemLayout::MN_major, B_ROW / 2, HEADDIM,
|
float, MemLayout::K_major /* P matrix is row-major */,
|
||||||
B_COL, /*leading_dim_a=*/0, /*leading_dim_b=*/0,
|
MemLayout::block_row_major, B_ROW / 2, HEADDIM, B_COL,
|
||||||
/*load_accum=*/true,
|
/*leading_dim_a=*/0,
|
||||||
/*write_to_smem=*/true>(
|
/*leading_dim_b=*/0,
|
||||||
smem_P_half0, smem_V, smem_O_half0 /*load accum*/, smem_O_half0,
|
/*load_accum=*/true,
|
||||||
tid_in_warpgroup, threads_per_warpgroup, warpgroups_per_cluster,
|
/*write_to_smem=*/true>(
|
||||||
warpgroup_id_in_cluster);
|
smem_P_half0, smem_V, smem_O_half0 /*load accum*/, smem_O_half0,
|
||||||
|
tid_in_warpgroup, threads_per_warpgroup, warpgroups_per_cluster,
|
||||||
|
warpgroup_id_in_cluster);
|
||||||
|
} else {
|
||||||
|
thread_block_gemm_single_tile<
|
||||||
|
float, MemLayout::K_major, MemLayout::MN_major, B_ROW / 2, HEADDIM,
|
||||||
|
B_COL, /*leading_dim_a=*/0, /*leading_dim_b=*/0,
|
||||||
|
/*load_accum=*/true,
|
||||||
|
/*write_to_smem=*/true>(
|
||||||
|
smem_P_half0, smem_V, smem_O_half0 /*load accum*/, smem_O_half0,
|
||||||
|
tid_in_warpgroup, threads_per_warpgroup, warpgroups_per_cluster,
|
||||||
|
warpgroup_id_in_cluster);
|
||||||
|
}
|
||||||
|
|
||||||
initialize_accum_regs<0>();
|
initialize_accum_regs<0>();
|
||||||
initialize_accum_regs<1>();
|
initialize_accum_regs<1>();
|
||||||
|
|
||||||
thread_block_gemm_single_tile<
|
if constexpr (GEMMINI_DMA) {
|
||||||
float, MemLayout::K_major, MemLayout::MN_major, B_ROW / 2, HEADDIM,
|
thread_block_gemm_single_tile<
|
||||||
B_COL, /*leading_dim_a=*/0, /*leading_dim_b=*/0,
|
float, MemLayout::K_major /* P matrix is row-major */,
|
||||||
/*load_accum=*/true,
|
MemLayout::block_row_major, B_ROW / 2, HEADDIM, B_COL,
|
||||||
/*write_to_smem=*/true>(
|
/*leading_dim_a=*/0,
|
||||||
smem_P_half1, smem_V, smem_O_half1 /*load accum*/, smem_O_half1,
|
/*leading_dim_b=*/0,
|
||||||
tid_in_warpgroup, threads_per_warpgroup, warpgroups_per_cluster,
|
/*load_accum=*/true,
|
||||||
warpgroup_id_in_cluster);
|
/*write_to_smem=*/true>(
|
||||||
|
smem_P_half1, smem_V, smem_O_half1 /*load accum*/, smem_O_half1,
|
||||||
|
tid_in_warpgroup, threads_per_warpgroup, warpgroups_per_cluster,
|
||||||
|
warpgroup_id_in_cluster);
|
||||||
|
} else {
|
||||||
|
thread_block_gemm_single_tile<
|
||||||
|
float, MemLayout::K_major, MemLayout::MN_major, B_ROW / 2, HEADDIM,
|
||||||
|
B_COL, /*leading_dim_a=*/0, /*leading_dim_b=*/0,
|
||||||
|
/*load_accum=*/true,
|
||||||
|
/*write_to_smem=*/true>(
|
||||||
|
smem_P_half1, smem_V, smem_O_half1 /*load accum*/, smem_O_half1,
|
||||||
|
tid_in_warpgroup, threads_per_warpgroup, warpgroups_per_cluster,
|
||||||
|
warpgroup_id_in_cluster);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
threadblock_barrier(warpgroup_id_in_cluster, warps_per_warpgroup_per_core);
|
threadblock_barrier(warpgroup_id_in_cluster, warps_per_warpgroup_per_core);
|
||||||
|
|||||||
Reference in New Issue
Block a user