flash: Fix DMA for up to GEMM II
yeah
This commit is contained in:
@@ -168,16 +168,6 @@ __attribute__((always_inline)) inline void thread_block_online_softmax(
|
|||||||
const uint32_t warps_per_threadblock_per_core =
|
const uint32_t warps_per_threadblock_per_core =
|
||||||
warps_in_threadblock / CORES_PER_CLUSTER;
|
warps_in_threadblock / CORES_PER_CLUSTER;
|
||||||
|
|
||||||
// float ft[8];
|
|
||||||
// asm volatile("fmv.s %0, f16" : "=f"(ft[0]));
|
|
||||||
// asm volatile("fmv.s %0, f17" : "=f"(ft[1]));
|
|
||||||
// asm volatile("fmv.s %0, f18" : "=f"(ft[2]));
|
|
||||||
// asm volatile("fmv.s %0, f19" : "=f"(ft[3]));
|
|
||||||
// asm volatile("fmv.s %0, f20" : "=f"(ft[4]));
|
|
||||||
// asm volatile("fmv.s %0, f21" : "=f"(ft[5]));
|
|
||||||
// asm volatile("fmv.s %0, f22" : "=f"(ft[6]));
|
|
||||||
// asm volatile("fmv.s %0, f23" : "=f"(ft[7]));
|
|
||||||
|
|
||||||
float *smem_rowmax_this = smem_rowmax + B_ROW;
|
float *smem_rowmax_this = smem_rowmax + B_ROW;
|
||||||
|
|
||||||
#pragma GCC unroll 1
|
#pragma GCC unroll 1
|
||||||
@@ -541,6 +531,7 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
|
|||||||
uint8_t *smem_per_threadblock = reinterpret_cast<uint8_t *>(
|
uint8_t *smem_per_threadblock = reinterpret_cast<uint8_t *>(
|
||||||
DEV_SMEM_START_ADDR);
|
DEV_SMEM_START_ADDR);
|
||||||
float *smem_cursor = reinterpret_cast<float *>(smem_per_threadblock);
|
float *smem_cursor = reinterpret_cast<float *>(smem_per_threadblock);
|
||||||
|
// float *smem_cursor = reinterpret_cast<float *>(DEV_FAKE_SMEM_START_ADDR);
|
||||||
float *smem_Q0 = smem_cursor;
|
float *smem_Q0 = smem_cursor;
|
||||||
smem_cursor += smem_Q_size;
|
smem_cursor += smem_Q_size;
|
||||||
float *smem_Q1 = smem_cursor;
|
float *smem_Q1 = smem_cursor;
|
||||||
@@ -587,31 +578,33 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
|
|||||||
constexpr uint32_t smem_rowmax_size = B_ROW * ROWMAX_SETS;
|
constexpr uint32_t smem_rowmax_size = B_ROW * ROWMAX_SETS;
|
||||||
constexpr uint32_t smem_rowsum_size = B_ROW;
|
constexpr uint32_t smem_rowsum_size = B_ROW;
|
||||||
constexpr uint32_t smem_O_row_scale_size = B_ROW;
|
constexpr uint32_t smem_O_row_scale_size = B_ROW;
|
||||||
smem_cursor = reinterpret_cast<float *>(SMEM_ADDR_END);
|
// smem_cursor = reinterpret_cast<float *>(DEV_FAKE_SMEM_START_ADDR + SMEM_SIZE);
|
||||||
|
smem_cursor = reinterpret_cast<float *>(0xff038000);
|
||||||
|
|
||||||
smem_cursor -= smem_rowmax_size;
|
|
||||||
float *smem_rowmax_0 = smem_cursor;
|
float *smem_rowmax_0 = smem_cursor;
|
||||||
smem_cursor -= smem_rowmax_size;
|
smem_cursor += smem_rowmax_size;
|
||||||
float *smem_rowmax_1 = smem_cursor;
|
float *smem_rowmax_1 = smem_cursor;
|
||||||
smem_cursor -= smem_rowsum_size;
|
smem_cursor += smem_rowmax_size;
|
||||||
float *smem_rowsum_0 = smem_cursor;
|
float *smem_rowsum_0 = smem_cursor;
|
||||||
smem_cursor -= smem_rowsum_size;
|
smem_cursor += smem_rowsum_size;
|
||||||
float *smem_rowsum_1 = smem_cursor;
|
float *smem_rowsum_1 = smem_cursor;
|
||||||
smem_cursor -= smem_O_row_scale_size;
|
smem_cursor += smem_rowsum_size;
|
||||||
float *smem_O_row_scale_0 = smem_cursor;
|
float *smem_O_row_scale_0 = smem_cursor;
|
||||||
smem_cursor -= smem_O_row_scale_size;
|
smem_cursor += smem_O_row_scale_size;
|
||||||
float *smem_O_row_scale_1 = smem_cursor;
|
float *smem_O_row_scale_1 = smem_cursor;
|
||||||
|
smem_cursor += smem_O_row_scale_size;
|
||||||
|
|
||||||
// sharedmem "scratchpad" area to put temporary data, e.g. for tree reduction
|
// sharedmem "scratchpad" area to put temporary data, e.g. for tree reduction
|
||||||
// in rowsum
|
// in rowsum
|
||||||
// NOTE: out-of bounds is not checked
|
// NOTE: out-of bounds is not checked
|
||||||
// TODO: reduce this from B_ROW to NUM_WARPS
|
// TODO: reduce this from B_ROW to NUM_WARPS
|
||||||
constexpr uint32_t smem_scratchpad_size =
|
constexpr uint32_t smem_scratchpad_size =
|
||||||
threads_per_warpgroup * 2 /*arbitrary slack*/;
|
B_ROW * NUM_THREADS * 2 /*arbitrary slack*/;
|
||||||
smem_cursor -= smem_scratchpad_size;
|
// threads_per_warpgroup * 2 /*arbitrary slack*/;
|
||||||
float *smem_scratchpad_0 = smem_cursor;
|
float *smem_scratchpad_0 = smem_cursor;
|
||||||
smem_cursor -= smem_scratchpad_size;
|
smem_cursor += smem_scratchpad_size;
|
||||||
float *smem_scratchpad_1 = smem_cursor;
|
float *smem_scratchpad_1 = smem_cursor;
|
||||||
|
smem_cursor += smem_scratchpad_size;
|
||||||
|
|
||||||
// select the correct buffer by warpgroup
|
// select the correct buffer by warpgroup
|
||||||
float *smem_Q = (warpgroup_id % 2) ? smem_Q1 : smem_Q0;
|
float *smem_Q = (warpgroup_id % 2) ? smem_Q1 : smem_Q0;
|
||||||
@@ -628,19 +621,24 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
|
|||||||
(warpgroup_id % 2) ? smem_scratchpad_1 : smem_scratchpad_0;
|
(warpgroup_id % 2) ? smem_scratchpad_1 : smem_scratchpad_0;
|
||||||
|
|
||||||
// initialize rowmax/rowsum values in sharedmem
|
// initialize rowmax/rowsum values in sharedmem
|
||||||
// thread_block_init_sharedmem(tid_in_warpgroup, threads_per_warpgroup, smem_O,
|
thread_block_init_sharedmem(tid_in_warpgroup, threads_per_warpgroup, smem_O,
|
||||||
// smem_rowmax, smem_rowsum, smem_O_row_scale);
|
smem_rowmax, smem_rowsum, smem_O_row_scale);
|
||||||
|
|
||||||
constexpr uint32_t global_barrier_id = NUM_WARPS - 1; // arbitrary
|
constexpr uint32_t global_barrier_id = NUM_WARPS - 1; // arbitrary
|
||||||
|
|
||||||
// delay warpgroup 0 by 1 iteration to do ping-pong scheduling
|
// delay warpgroup 0 by 1 iteration to do ping-pong scheduling
|
||||||
// if (warpgroup_id == 1) {
|
if (warpgroup_id == 1) {
|
||||||
// threadblock_barrier(global_barrier_id, warps_per_threadblock_per_core);
|
threadblock_barrier(global_barrier_id, warps_per_threadblock_per_core);
|
||||||
// }
|
}
|
||||||
|
|
||||||
static_assert(!GEMMINI_DMA || Q_IS_K_MAJOR,
|
static_assert(!GEMMINI_DMA || Q_IS_K_MAJOR,
|
||||||
"DMA code assumes Q matrix is stored K-major");
|
"DMA code assumes Q matrix is stored K-major");
|
||||||
|
|
||||||
|
// skip everything except DMA in the loop FSM
|
||||||
|
constexpr uint32_t skips =
|
||||||
|
loop_matmul_skips(/*skip_lda=*/0, /*skip_ldb=*/0, /*skip_ldd=*/1,
|
||||||
|
/*skip_ex=*/1, /*skip_stc=*/1);
|
||||||
|
|
||||||
if constexpr (GEMMINI_DMA) {
|
if constexpr (GEMMINI_DMA) {
|
||||||
if (tid_in_warpgroup == 0) {
|
if (tid_in_warpgroup == 0) {
|
||||||
gemmini_extended_config_ex(WEIGHT_STATIONARY, 0, 0, 1, 0, 0);
|
gemmini_extended_config_ex(WEIGHT_STATIONARY, 0, 0, 1, 0, 0);
|
||||||
@@ -680,8 +678,6 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
|
|||||||
ROCC_INSTRUCTION_RS1_RS2(XCUSTOM_ACC, (uint64_t)(gmem_Q),
|
ROCC_INSTRUCTION_RS1_RS2(XCUSTOM_ACC, (uint64_t)(gmem_Q),
|
||||||
(uint64_t)(gmem_K), k_LOOP_WS_CONFIG_ADDRS_AB)
|
(uint64_t)(gmem_K), k_LOOP_WS_CONFIG_ADDRS_AB)
|
||||||
// configure address strides for the DMA
|
// configure address strides for the DMA
|
||||||
// GEMMINI_CISC_CMD_R((B_COL << 16) | (HEADDIM << 8) |
|
|
||||||
// 8 /*k_LOOP_WS_CONFIG_STRIDES_AB*/);
|
|
||||||
GEMMINI_CISC_CMD_R((dim_seqlen << 16) | (HEADDIM << 8) |
|
GEMMINI_CISC_CMD_R((dim_seqlen << 16) | (HEADDIM << 8) |
|
||||||
8 /*k_LOOP_WS_CONFIG_STRIDES_AB*/);
|
8 /*k_LOOP_WS_CONFIG_STRIDES_AB*/);
|
||||||
gemmini_fence();
|
gemmini_fence();
|
||||||
@@ -691,11 +687,8 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
|
|||||||
GEMMINI_CISC_CMD_I(9);
|
GEMMINI_CISC_CMD_I(9);
|
||||||
gemmini_fence();
|
gemmini_fence();
|
||||||
#else
|
#else
|
||||||
// skip everything except DMA in the loop FSM
|
// do DMA
|
||||||
constexpr uint32_t skips =
|
//
|
||||||
loop_matmul_skips(/*skip_lda=*/0, /*skip_ldb=*/0, /*skip_ldd=*/1,
|
|
||||||
/*skip_ex=*/1, /*skip_stc=*/1);
|
|
||||||
|
|
||||||
// among other things, this also configures CONFIG_BOUNDS so that the
|
// among other things, this also configures CONFIG_BOUNDS so that the
|
||||||
// DMA knows the full matrix dimensions
|
// DMA knows the full matrix dimensions
|
||||||
sp_tiled_matmul_full_spad_ws(
|
sp_tiled_matmul_full_spad_ws(
|
||||||
@@ -707,6 +700,15 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
|
|||||||
/*acc=*/0, /*act=*/NO_ACTIVATION, /*skips=*/skips);
|
/*acc=*/0, /*act=*/NO_ACTIVATION, /*skips=*/skips);
|
||||||
gemmini_fence();
|
gemmini_fence();
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
|
// re-configure DMA for K and V load that will later happen in the loop
|
||||||
|
// GMEM addr stride for K
|
||||||
|
gemmini_extended3_config_ld(dim_seqlen * sizeof(elem_t), MVIN_SCALE_IDENTITY,
|
||||||
|
false, 0);
|
||||||
|
// GMEM addr stride for V
|
||||||
|
gemmini_extended3_config_ld(HEADDIM * sizeof(elem_t), MVIN_SCALE_IDENTITY,
|
||||||
|
false, 1);
|
||||||
|
gemmini_fence();
|
||||||
}
|
}
|
||||||
|
|
||||||
asm volatile("dma_move_end_%=:" ::);
|
asm volatile("dma_move_end_%=:" ::);
|
||||||
@@ -767,7 +769,16 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
|
|||||||
initialize_accum_regs<0>();
|
initialize_accum_regs<0>();
|
||||||
initialize_accum_regs<1>();
|
initialize_accum_regs<1>();
|
||||||
|
|
||||||
if constexpr (Q_IS_K_MAJOR) {
|
if constexpr (GEMMINI_DMA) {
|
||||||
|
thread_block_gemm_single_tile<
|
||||||
|
float, MemLayout::block_row_major, MemLayout::block_row_major,
|
||||||
|
B_ROW, B_COL, HEADDIM, /*leading_dim_a=*/0, /*leading_dim_b=*/0,
|
||||||
|
/*load_accum=*/false,
|
||||||
|
/*write_to_smem=*/true>(
|
||||||
|
smem_Q, smem_K, nullptr /*ignore accum*/, smem_S,
|
||||||
|
tid_in_warpgroup, threads_per_warpgroup, warpgroups_per_cluster,
|
||||||
|
warpgroup_id_in_cluster);
|
||||||
|
} else if constexpr (Q_IS_K_MAJOR) {
|
||||||
thread_block_gemm_single_tile<
|
thread_block_gemm_single_tile<
|
||||||
float, MemLayout::K_major, MemLayout::MN_major, B_ROW, B_COL,
|
float, MemLayout::K_major, MemLayout::MN_major, B_ROW, B_COL,
|
||||||
HEADDIM, /*leading_dim_a=*/0, /*leading_dim_b=*/0,
|
HEADDIM, /*leading_dim_a=*/0, /*leading_dim_b=*/0,
|
||||||
@@ -803,6 +814,7 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
|
|||||||
initialize_accum_regs<1>();
|
initialize_accum_regs<1>();
|
||||||
|
|
||||||
// split by rows into 2 chunks
|
// split by rows into 2 chunks
|
||||||
|
// TODO: GEMMINI_DMA
|
||||||
if constexpr (Q_IS_K_MAJOR) {
|
if constexpr (Q_IS_K_MAJOR) {
|
||||||
thread_block_gemm_single_tile<
|
thread_block_gemm_single_tile<
|
||||||
float, MemLayout::K_major, MemLayout::MN_major, B_ROW / 2, B_COL,
|
float, MemLayout::K_major, MemLayout::MN_major, B_ROW / 2, B_COL,
|
||||||
@@ -826,6 +838,7 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
|
|||||||
initialize_accum_regs<0>();
|
initialize_accum_regs<0>();
|
||||||
initialize_accum_regs<1>();
|
initialize_accum_regs<1>();
|
||||||
|
|
||||||
|
// TODO: GEMMINI_DMA
|
||||||
if constexpr (Q_IS_K_MAJOR) {
|
if constexpr (Q_IS_K_MAJOR) {
|
||||||
thread_block_gemm_single_tile<
|
thread_block_gemm_single_tile<
|
||||||
float, MemLayout::K_major, MemLayout::MN_major, B_ROW / 2, B_COL,
|
float, MemLayout::K_major, MemLayout::MN_major, B_ROW / 2, B_COL,
|
||||||
@@ -877,7 +890,6 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
|
|||||||
// inter-warpgroup barrier before online softmax
|
// inter-warpgroup barrier before online softmax
|
||||||
threadblock_barrier(global_barrier_id, warps_per_threadblock_per_core);
|
threadblock_barrier(global_barrier_id, warps_per_threadblock_per_core);
|
||||||
|
|
||||||
#if 0
|
|
||||||
// Online softmax
|
// Online softmax
|
||||||
//
|
//
|
||||||
thread_block_online_softmax(smem_S, smem_P, tid_in_warpgroup,
|
thread_block_online_softmax(smem_S, smem_P, tid_in_warpgroup,
|
||||||
@@ -885,10 +897,39 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
|
|||||||
smem_scratchpad, smem_rowmax, smem_rowsum,
|
smem_scratchpad, smem_rowmax, smem_rowsum,
|
||||||
smem_O_row_scale);
|
smem_O_row_scale);
|
||||||
|
|
||||||
|
// FIXME: unnecessary?
|
||||||
|
threadblock_barrier(warpgroup_id_in_cluster, warps_per_warpgroup_per_core);
|
||||||
|
|
||||||
// data movement for K and V
|
// data movement for K and V
|
||||||
//
|
//
|
||||||
// Q stays in SMEM for the entire loop
|
// Q stays in SMEM for the entire loop
|
||||||
//
|
if constexpr (GEMMINI_DMA) {
|
||||||
|
if (tid_in_threadblock == 0) {
|
||||||
|
// configure GMEM addresses for K and V tiles
|
||||||
|
// load K for the next iteration
|
||||||
|
const float *gmem_K_tile = gmem_K + (B_COL * (tile_k + 1));
|
||||||
|
// load V for the current iteration
|
||||||
|
const float *gmem_V_tile = gmem_V + (HEADDIM * B_COL * tile_k);
|
||||||
|
ROCC_INSTRUCTION_RS1_RS2(XCUSTOM_ACC, (uint64_t)(gmem_K_tile),
|
||||||
|
(uint64_t)(gmem_V_tile),
|
||||||
|
k_LOOP_WS_CONFIG_ADDRS_AB)
|
||||||
|
// configure address strides for the DMA
|
||||||
|
// FIXME: unnecessary?
|
||||||
|
GEMMINI_CISC_CMD_R((HEADDIM /*V*/ << 16) | (dim_seqlen /*KT*/ << 8) |
|
||||||
|
8 /*k_LOOP_WS_CONFIG_STRIDES_AB*/);
|
||||||
|
gemmini_fence();
|
||||||
|
|
||||||
|
// do DMA
|
||||||
|
sp_tiled_matmul_full_spad_ws(
|
||||||
|
spad_addr_K0, spad_addr_V0,
|
||||||
|
/*spad_D=*/0, /*spad_C=*/spad_addr_S0,
|
||||||
|
/*I=*/(HEADDIM / DIM), /*J=*/(HEADDIM / DIM), /*K=*/(B_COL / DIM),
|
||||||
|
/*pad_I=*/0, /*pad_J=*/0, /*pad_K=*/0,
|
||||||
|
/*a_transpose=*/0, /*b_transpose=*/0, /*full_C=*/0, /*low_D=*/0,
|
||||||
|
/*acc=*/0, /*act=*/NO_ACTIVATION, /*skips=*/skips);
|
||||||
|
gemmini_fence();
|
||||||
|
}
|
||||||
|
} else {
|
||||||
// load K for the next iteration
|
// load K for the next iteration
|
||||||
load_tile_to_smem<float, MemLayout::MN_major, MemLayout::MN_major, B_COL,
|
load_tile_to_smem<float, MemLayout::MN_major, MemLayout::MN_major, B_COL,
|
||||||
HEADDIM, threads_per_warpgroup>(
|
HEADDIM, threads_per_warpgroup>(
|
||||||
@@ -901,6 +942,7 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
|
|||||||
HEADDIM, threads_per_warpgroup>(
|
HEADDIM, threads_per_warpgroup>(
|
||||||
HEADDIM, 0 /* full N-dimension */, tile_k, gmem_V, smem_V,
|
HEADDIM, 0 /* full N-dimension */, tile_k, gmem_V, smem_V,
|
||||||
tid_in_warpgroup);
|
tid_in_warpgroup);
|
||||||
|
}
|
||||||
|
|
||||||
// protect write to SMEM
|
// protect write to SMEM
|
||||||
threadblock_barrier(warpgroup_id_in_cluster, warps_per_warpgroup_per_core);
|
threadblock_barrier(warpgroup_id_in_cluster, warps_per_warpgroup_per_core);
|
||||||
@@ -970,25 +1012,38 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
|
|||||||
initialize_accum_regs<0>();
|
initialize_accum_regs<0>();
|
||||||
initialize_accum_regs<1>();
|
initialize_accum_regs<1>();
|
||||||
|
|
||||||
thread_block_gemm_single_tile<float, MemLayout::K_major,
|
if constexpr (GEMMINI_DMA) {
|
||||||
MemLayout::MN_major, B_ROW, HEADDIM, B_COL,
|
thread_block_gemm_single_tile<float, MemLayout::block_row_major,
|
||||||
|
MemLayout::block_row_major, B_ROW,
|
||||||
|
HEADDIM, B_COL,
|
||||||
|
/*leading_dim_a=*/0, /*leading_dim_b=*/0,
|
||||||
|
/*load_accum=*/true,
|
||||||
|
/*write_to_smem=*/true>(
|
||||||
|
smem_P, smem_V, smem_O /*load accum*/, smem_O, tid_in_warpgroup,
|
||||||
|
threads_per_warpgroup, warpgroups_per_cluster,
|
||||||
|
warpgroup_id_in_cluster);
|
||||||
|
} else {
|
||||||
|
thread_block_gemm_single_tile<float, MemLayout::K_major,
|
||||||
|
MemLayout::MN_major, B_ROW, HEADDIM,
|
||||||
|
B_COL,
|
||||||
/*leading_dim_a=*/0, /*leading_dim_b=*/0,
|
/*leading_dim_a=*/0, /*leading_dim_b=*/0,
|
||||||
/*load_accum=*/true,
|
/*load_accum=*/true,
|
||||||
/*write_to_smem=*/true>(
|
/*write_to_smem=*/true>(
|
||||||
smem_P, smem_V, smem_O /*load accum*/, smem_O, tid_in_warpgroup,
|
smem_P, smem_V, smem_O /*load accum*/, smem_O, tid_in_warpgroup,
|
||||||
threads_per_warpgroup, warpgroups_per_cluster,
|
threads_per_warpgroup, warpgroups_per_cluster,
|
||||||
warpgroup_id_in_cluster);
|
warpgroup_id_in_cluster);
|
||||||
|
|
||||||
// FIXME: wrong but fast
|
// FIXME: wrong but fast
|
||||||
// thread_block_gemm_single_tile<float, MemLayout::MN_major,
|
// thread_block_gemm_single_tile<float, MemLayout::MN_major,
|
||||||
// MemLayout::MN_major,
|
// MemLayout::MN_major,
|
||||||
// B_ROW, HEADDIM, B_COL,
|
// B_ROW, HEADDIM, B_COL,
|
||||||
// /*leading_dim_a=*/0, /*leading_dim_b=*/0,
|
// /*leading_dim_a=*/0,
|
||||||
|
// /*leading_dim_b=*/0,
|
||||||
// /*load_accum=*/true,
|
// /*load_accum=*/true,
|
||||||
// /*write_to_smem=*/true>(
|
// /*write_to_smem=*/true>(
|
||||||
// smem_P, smem_V, smem_O /*load accum*/, smem_O,
|
// smem_P, smem_V, smem_O /*load accum*/, smem_O,
|
||||||
// tid_in_warpgroup, threads_per_warpgroup,
|
// tid_in_warpgroup, threads_per_warpgroup,
|
||||||
// warpgroups_per_cluster, warpgroup_id_in_cluster);
|
// warpgroups_per_cluster, warpgroup_id_in_cluster);
|
||||||
|
}
|
||||||
} else {
|
} else {
|
||||||
// when warp-specialized, there's only enough warps to do 64x32 tile
|
// when warp-specialized, there's only enough warps to do 64x32 tile
|
||||||
// size so we need to do 2 GEMM calls
|
// size so we need to do 2 GEMM calls
|
||||||
@@ -1006,6 +1061,7 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
|
|||||||
initialize_accum_regs<1>();
|
initialize_accum_regs<1>();
|
||||||
|
|
||||||
// split by rows into 2 chunks
|
// split by rows into 2 chunks
|
||||||
|
// TODO: GEMMINI_DMA
|
||||||
thread_block_gemm_single_tile<
|
thread_block_gemm_single_tile<
|
||||||
float, MemLayout::K_major, MemLayout::MN_major, B_ROW / 2, HEADDIM,
|
float, MemLayout::K_major, MemLayout::MN_major, B_ROW / 2, HEADDIM,
|
||||||
B_COL, /*leading_dim_a=*/0, /*leading_dim_b=*/0,
|
B_COL, /*leading_dim_a=*/0, /*leading_dim_b=*/0,
|
||||||
@@ -1047,13 +1103,7 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
|
|||||||
warps_per_warpgroup_per_core);
|
warps_per_warpgroup_per_core);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
#if 0
|
||||||
tile_iter_end:
|
|
||||||
// synchronize progress of two warpgroups
|
|
||||||
// threadblock_barrier(threadblock_id_in_cluster,
|
|
||||||
// warps_per_threadblock_per_core);
|
|
||||||
// threadblock_barrier(3, // FIXME
|
|
||||||
// NUM_WARPS);
|
|
||||||
#endif
|
#endif
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user