flash: Fix DMA addr stride, stop at S=Q*K

This commit is contained in:
Hansung Kim
2024-09-07 15:48:37 -07:00
parent 9f067acdb9
commit d2f086344d
2 changed files with 127 additions and 79 deletions

View File

@@ -8,10 +8,9 @@
#include "include/gemmini.h" #include "include/gemmini.h"
#include "gemmini_mmio.h" #include "gemmini_mmio.h"
#define B_ROW BM #define B_ROW 64
#define B_COL BN #define B_COL 64
// FIXME #define HEADDIM 64
#define HEADDIM B_COL
constexpr uint32_t ROWMAX_SETS = 3; constexpr uint32_t ROWMAX_SETS = 3;
constexpr bool DEBUG = true; constexpr bool DEBUG = true;
@@ -19,6 +18,8 @@ constexpr bool WARP_SPECIALIZED = false;
constexpr uint32_t DEV_FAKE_SMEM_START_ADDR = 0xf0000000; constexpr uint32_t DEV_FAKE_SMEM_START_ADDR = 0xf0000000;
constexpr bool Q_IS_K_MAJOR = true;
// temporary safety stop for wrong configs // temporary safety stop for wrong configs
static_assert(NUM_CORES == 4); static_assert(NUM_CORES == 4);
static_assert(NUM_THREADS == 8); static_assert(NUM_THREADS == 8);
@@ -99,6 +100,7 @@ inline void thread_block_copy_rowmax(const float *src, float *dest,
asm volatile("threadblock_copy_rowmax_finish_%=:" ::); asm volatile("threadblock_copy_rowmax_finish_%=:" ::);
} }
template <uint32_t dim_row, uint32_t dim_col>
inline void thread_block_copy_tile(const float *src, float *dest, inline void thread_block_copy_tile(const float *src, float *dest,
const uint32_t tid_in_threadblock, const uint32_t tid_in_threadblock,
const uint32_t threads_per_threadblock, const uint32_t threads_per_threadblock,
@@ -113,12 +115,12 @@ inline void thread_block_copy_tile(const float *src, float *dest,
// FIXME: dedup this pattern // FIXME: dedup this pattern
#pragma GCC unroll 1 #pragma GCC unroll 1
for (int row_offset = 0; row_offset < B_ROW; for (int row_offset = 0; row_offset < dim_row;
row_offset += warps_in_threadblock) { row_offset += warps_in_threadblock) {
const uint32_t row = row_offset + warp_id; const uint32_t row = row_offset + warp_id;
const uint32_t first_thread_offset = B_COL * row; const uint32_t first_thread_offset = dim_col * row;
constexpr uint32_t per_row_iter = B_COL / NUM_THREADS; constexpr uint32_t per_row_iter = dim_col / NUM_THREADS;
uint32_t thread_offset = first_thread_offset + tid_in_warp; uint32_t thread_offset = first_thread_offset + tid_in_warp;
#pragma GCC unroll #pragma GCC unroll
for (int i = 0; i < per_row_iter; i++) { for (int i = 0; i < per_row_iter; i++) {
@@ -533,12 +535,12 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
constexpr uint32_t smem_QK_size = B_ROW * B_COL; constexpr uint32_t smem_QK_size = B_ROW * B_COL;
constexpr uint32_t smem_V_size = B_COL * HEADDIM; constexpr uint32_t smem_V_size = B_COL * HEADDIM;
constexpr uint32_t smem_O_size = B_COL * HEADDIM; constexpr uint32_t smem_O_size = B_COL * HEADDIM;
static_assert(
threads_per_threadblock == NUM_WARPS * NUM_THREADS * CORES_PER_CLUSTER,
"flashattention kernel assumes 1 threadblock occupancy per cluster");
uint8_t *smem_per_threadblock = reinterpret_cast<uint8_t *>( uint8_t *smem_per_threadblock = reinterpret_cast<uint8_t *>(
DEV_SMEM_START_ADDR + DEV_SMEM_START_ADDR);
sizeof(float_type) * float *smem_cursor = reinterpret_cast<float *>(smem_per_threadblock);
(smem_QK_size + smem_V_size + smem_O_size) *
threadblock_id_in_cluster);
float *smem_cursor = reinterpret_cast<float *>(DEV_FAKE_SMEM_START_ADDR);
float *smem_Q0 = smem_cursor; float *smem_Q0 = smem_cursor;
smem_cursor += smem_Q_size; smem_cursor += smem_Q_size;
float *smem_Q1 = smem_cursor; float *smem_Q1 = smem_cursor;
@@ -563,6 +565,7 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
smem_cursor += smem_O_size; smem_cursor += smem_O_size;
// NOTE: this has to match with smem_* // NOTE: this has to match with smem_*
static_assert(sizeof(elem_t) == sizeof(float));
constexpr uint32_t spad_addr_factor = DIM * sizeof(elem_t); constexpr uint32_t spad_addr_factor = DIM * sizeof(elem_t);
constexpr uint32_t spad_addr_Q0 = 0; constexpr uint32_t spad_addr_Q0 = 0;
constexpr uint32_t spad_addr_Q1 = constexpr uint32_t spad_addr_Q1 =
@@ -635,15 +638,18 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
// threadblock_barrier(global_barrier_id, warps_per_threadblock_per_core); // threadblock_barrier(global_barrier_id, warps_per_threadblock_per_core);
// } // }
static_assert(!GEMMINI_DMA || Q_IS_K_MAJOR,
"DMA code assumes Q matrix is stored K-major");
if constexpr (GEMMINI_DMA) { if constexpr (GEMMINI_DMA) {
if (tid_in_warpgroup == 0) { if (tid_in_warpgroup == 0) {
gemmini_extended_config_ex(WEIGHT_STATIONARY, 0, 0, 1, 0, 0); gemmini_extended_config_ex(WEIGHT_STATIONARY, 0, 0, 1, 0, 0);
// configure DMA for Q tile // configure DMA for the full Q matrix
gemmini_extended3_config_ld(HEADDIM * sizeof(elem_t), MVIN_SCALE_IDENTITY, gemmini_extended3_config_ld(HEADDIM * sizeof(elem_t), MVIN_SCALE_IDENTITY,
false, 0); false, 0);
// configure DMA for K tile // configure DMA for the full K matrix
gemmini_extended3_config_ld(B_COL * sizeof(elem_t), MVIN_SCALE_IDENTITY, gemmini_extended3_config_ld(dim_seqlen * sizeof(elem_t), MVIN_SCALE_IDENTITY,
false, 1); false, 1);
// configure DMA for Q*K store // configure DMA for Q*K store
gemmini_extended_config_st(B_COL * sizeof(elem_t), 0, gemmini_extended_config_st(B_COL * sizeof(elem_t), 0,
@@ -652,12 +658,12 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
} }
} }
// NOTE about barriers: placing barriers around thread-divergent branches may // NOTE about barriers: Placing barriers around thread-divergent branches may
// cause bugs, since the core doesn't check tmask for barriers. The compiler // cause bugs, because the Vortex core doesn't check for tmask for barriers.
// might decide to replicate vx_bar into both paths of a conditional branch, // The compiler might decide to duplicate vx_bar into both paths of a
// which will get evaluated twice along the split/join process and result in // conditional branch, which will get evaluated twice because of the way
// a different number of calls w.r.t other non-divergent warps and therefore // branches are handled in SIMT; this might result in stalls especially when
// stalls. // other warps behave differently on the branch condition.
// threadblock_barrier(warpgroup_id_in_cluster, warps_per_warpgroup_per_core); // threadblock_barrier(warpgroup_id_in_cluster, warps_per_warpgroup_per_core);
// move Q and K into SMEM before the loop starts // move Q and K into SMEM before the loop starts
@@ -674,13 +680,15 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
ROCC_INSTRUCTION_RS1_RS2(XCUSTOM_ACC, (uint64_t)(gmem_Q), ROCC_INSTRUCTION_RS1_RS2(XCUSTOM_ACC, (uint64_t)(gmem_Q),
(uint64_t)(gmem_K), k_LOOP_WS_CONFIG_ADDRS_AB) (uint64_t)(gmem_K), k_LOOP_WS_CONFIG_ADDRS_AB)
// configure address strides for the DMA // configure address strides for the DMA
GEMMINI_CISC_CMD_R((B_COL << 16) | (HEADDIM << 8) | // GEMMINI_CISC_CMD_R((B_COL << 16) | (HEADDIM << 8) |
// 8 /*k_LOOP_WS_CONFIG_STRIDES_AB*/);
GEMMINI_CISC_CMD_R((dim_seqlen << 16) | (HEADDIM << 8) |
8 /*k_LOOP_WS_CONFIG_STRIDES_AB*/); 8 /*k_LOOP_WS_CONFIG_STRIDES_AB*/);
gemmini_fence(); gemmini_fence();
#define GEMMINI_DMA_CISC // #define GEMMINI_DMA_CISC
#ifdef GEMMINI_DMA_CISC #ifdef GEMMINI_DMA_CISC
GEMMINI_CISC_CMD_I(10); GEMMINI_CISC_CMD_I(9);
gemmini_fence(); gemmini_fence();
#else #else
// skip everything except DMA in the loop FSM // skip everything except DMA in the loop FSM
@@ -693,7 +701,7 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
sp_tiled_matmul_full_spad_ws( sp_tiled_matmul_full_spad_ws(
spad_addr_Q0, spad_addr_K0, spad_addr_Q0, spad_addr_K0,
/*spad_D=*/0, /*spad_C=*/spad_addr_S0, /*spad_D=*/0, /*spad_C=*/spad_addr_S0,
/*I=*/(B_ROW / DIM), /*J=*/(HEADDIM / DIM), /*K=*/(B_COL / DIM), /*I=*/(B_ROW / DIM), /*J=*/(B_COL / DIM), /*K=*/(HEADDIM / DIM),
/*pad_I=*/0, /*pad_J=*/0, /*pad_K=*/0, /*pad_I=*/0, /*pad_J=*/0, /*pad_K=*/0,
/*a_transpose=*/0, /*b_transpose=*/0, /*full_C=*/0, /*low_D=*/0, /*a_transpose=*/0, /*b_transpose=*/0, /*full_C=*/0, /*low_D=*/0,
/*acc=*/0, /*act=*/NO_ACTIVATION, /*skips=*/skips); /*acc=*/0, /*act=*/NO_ACTIVATION, /*skips=*/skips);
@@ -704,10 +712,17 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
asm volatile("dma_move_end_%=:" ::); asm volatile("dma_move_end_%=:" ::);
} else { } else {
// load Q; this stays in SMEM for the entire loop // load Q; this stays in SMEM for the entire loop
load_tile_to_smem<float, MemLayout::MN_major, MemLayout::MN_major, B_ROW, if constexpr (Q_IS_K_MAJOR) {
HEADDIM, threads_per_warpgroup>( load_tile_to_smem<float, MemLayout::K_major, MemLayout::K_major, B_ROW,
dim_seqlen, warpgroup_id, 0 /* dim_k == headdim */, gmem_Q, smem_Q, HEADDIM, threads_per_warpgroup>(
tid_in_warpgroup); HEADDIM, warpgroup_id, 0 /* dim_k == headdim */, gmem_Q, smem_Q,
tid_in_warpgroup);
} else {
load_tile_to_smem<float, MemLayout::MN_major, MemLayout::MN_major, B_ROW,
HEADDIM, threads_per_warpgroup>(
dim_seqlen, warpgroup_id, 0 /* dim_k == headdim */, gmem_Q, smem_Q,
tid_in_warpgroup);
}
// load K // load K
load_tile_to_smem<float, MemLayout::MN_major, MemLayout::MN_major, B_COL, load_tile_to_smem<float, MemLayout::MN_major, MemLayout::MN_major, B_COL,
@@ -719,14 +734,15 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
// protect write to SMEM // protect write to SMEM
threadblock_barrier(warpgroup_id_in_cluster, warps_per_warpgroup_per_core); threadblock_barrier(warpgroup_id_in_cluster, warps_per_warpgroup_per_core);
if constexpr (DEBUG) { // if constexpr (DEBUG) {
thread_block_copy_tile(smem_Q, gmem_tmp_d0, tid_in_warpgroup, // thread_block_copy_tile<B_ROW, HEADDIM>(smem_Q0, gmem_tmp_d0, tid_in_warpgroup,
threads_per_warpgroup, warpgroup_id_in_cluster); // threads_per_warpgroup, warpgroup_id_in_cluster);
// thread_block_copy_tile<HEADDIM, B_COL>(smem_K0, gmem_tmp_d1, tid_in_warpgroup,
// threads_per_warpgroup, warpgroup_id_in_cluster);
// threadblock_barrier(warpgroup_id_in_cluster, warps_per_warpgroup_per_core); // threadblock_barrier(warpgroup_id_in_cluster, warps_per_warpgroup_per_core);
} // }
#if 0
asm volatile ("tile_loop_start_%=:" :: ); asm volatile ("tile_loop_start_%=:" :: );
// "inner loop" along the columns of K^T // "inner loop" along the columns of K^T
@@ -751,25 +767,34 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
initialize_accum_regs<0>(); initialize_accum_regs<0>();
initialize_accum_regs<1>(); initialize_accum_regs<1>();
thread_block_gemm_single_tile< if constexpr (Q_IS_K_MAJOR) {
float, MemLayout::MN_major, MemLayout::MN_major, B_ROW, B_COL, thread_block_gemm_single_tile<
HEADDIM, /*leading_dim_a=*/0, /*leading_dim_b=*/0, float, MemLayout::K_major, MemLayout::MN_major, B_ROW, B_COL,
/*load_accum=*/false, HEADDIM, /*leading_dim_a=*/0, /*leading_dim_b=*/0,
/*write_to_smem=*/true>( /*load_accum=*/false,
smem_Q, smem_K, nullptr /*ignore accum*/, smem_S, tid_in_warpgroup, /*write_to_smem=*/true>(
threads_per_warpgroup, warpgroups_per_cluster, smem_Q, smem_K, nullptr /*ignore accum*/, smem_S,
warpgroup_id_in_cluster); tid_in_warpgroup, threads_per_warpgroup, warpgroups_per_cluster,
warpgroup_id_in_cluster);
} else {
thread_block_gemm_single_tile<
float, MemLayout::MN_major, MemLayout::MN_major, B_ROW, B_COL,
HEADDIM, /*leading_dim_a=*/0, /*leading_dim_b=*/0,
/*load_accum=*/false,
/*write_to_smem=*/true>(
smem_Q, smem_K, nullptr /*ignore accum*/, smem_S,
tid_in_warpgroup, threads_per_warpgroup, warpgroups_per_cluster,
warpgroup_id_in_cluster);
}
} else { } else {
// when warp-specialized, there's only enough warps to do 64x32 tile // when warp-specialized, there's only enough warps to do 64x32 tile
// size so we need to do 2 GEMM calls // size so we need to do 2 GEMM calls
static_assert(B_ROW / 2 == 32, static_assert(B_ROW / 2 == 32,
"tile size assumption for warp-specialization not met"); "tile size assumption for warp-specialization not met");
// assumes smem_Q is K-major
// FIXME: fix this to MN-major
float *smem_Q_half0 = smem_Q; float *smem_Q_half0 = smem_Q;
float *smem_Q_half1 = smem_Q + (B_ROW / 2); // MN-major float *smem_Q_half1 = Q_IS_K_MAJOR ? smem_Q + (B_ROW / 2) * HEADDIM
// float *smem_Q_half1 = smem_Q + (B_ROW / 2) * HEADDIM; // K-major : smem_Q + (B_ROW / 2);
float *smem_S_half0 = smem_S; float *smem_S_half0 = smem_S;
float *smem_S_half1 = smem_S + (B_ROW / 2) * B_COL; float *smem_S_half1 = smem_S + (B_ROW / 2) * B_COL;
@@ -778,26 +803,48 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
initialize_accum_regs<1>(); initialize_accum_regs<1>();
// split by rows into 2 chunks // split by rows into 2 chunks
thread_block_gemm_single_tile< if constexpr (Q_IS_K_MAJOR) {
float, MemLayout::MN_major, MemLayout::MN_major, B_ROW / 2, thread_block_gemm_single_tile<
B_COL, HEADDIM, /*leading_dim_a=*/B_ROW, /*leading_dim_b=*/0, float, MemLayout::K_major, MemLayout::MN_major, B_ROW / 2, B_COL,
/*load_accum=*/false, HEADDIM, /*leading_dim_a=*/0, /*leading_dim_b=*/0,
/*write_to_smem=*/true>( /*load_accum=*/false,
smem_Q_half0, smem_K, nullptr /*ignore accum*/, smem_S_half0, /*write_to_smem=*/true>(
tid_in_warpgroup, threads_per_warpgroup, warpgroups_per_cluster, smem_Q_half0, smem_K, nullptr /*ignore accum*/, smem_S_half0,
warpgroup_id_in_cluster); tid_in_warpgroup, threads_per_warpgroup, warpgroups_per_cluster,
warpgroup_id_in_cluster);
} else {
thread_block_gemm_single_tile<
float, MemLayout::MN_major, MemLayout::MN_major, B_ROW / 2, B_COL,
HEADDIM, /*leading_dim_a=*/B_ROW, /*leading_dim_b=*/0,
/*load_accum=*/false,
/*write_to_smem=*/true>(
smem_Q_half0, smem_K, nullptr /*ignore accum*/, smem_S_half0,
tid_in_warpgroup, threads_per_warpgroup, warpgroups_per_cluster,
warpgroup_id_in_cluster);
}
initialize_accum_regs<0>(); initialize_accum_regs<0>();
initialize_accum_regs<1>(); initialize_accum_regs<1>();
thread_block_gemm_single_tile< if constexpr (Q_IS_K_MAJOR) {
float, MemLayout::MN_major, MemLayout::MN_major, B_ROW / 2, thread_block_gemm_single_tile<
B_COL, HEADDIM, /*leading_dim_a=*/B_ROW, /*leading_dim_b=*/0, float, MemLayout::K_major, MemLayout::MN_major, B_ROW / 2, B_COL,
/*load_accum=*/false, HEADDIM, /*leading_dim_a=*/0, /*leading_dim_b=*/0,
/*write_to_smem=*/true>( /*load_accum=*/false,
smem_Q_half1, smem_K, nullptr /*ignore accum*/, smem_S_half1, /*write_to_smem=*/true>(
tid_in_warpgroup, threads_per_warpgroup, warpgroups_per_cluster, smem_Q_half1, smem_K, nullptr /*ignore accum*/, smem_S_half1,
warpgroup_id_in_cluster); tid_in_warpgroup, threads_per_warpgroup, warpgroups_per_cluster,
warpgroup_id_in_cluster);
} else {
thread_block_gemm_single_tile<
float, MemLayout::MN_major, MemLayout::MN_major, B_ROW / 2, B_COL,
HEADDIM, /*leading_dim_a=*/B_ROW, /*leading_dim_b=*/0,
/*load_accum=*/false,
/*write_to_smem=*/true>(
smem_Q_half1, smem_K, nullptr /*ignore accum*/, smem_S_half1,
tid_in_warpgroup, threads_per_warpgroup, warpgroups_per_cluster,
warpgroup_id_in_cluster);
}
} }
} else { } else {
// load Q*K // load Q*K
@@ -813,11 +860,11 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
if constexpr (DEBUG) { if constexpr (DEBUG) {
if (warpgroup_id == 0) { if (warpgroup_id == 0) {
if (tile_k == 0) { if (tile_k == 0) {
thread_block_copy_tile(smem_S, gmem_tmp_d0, thread_block_copy_tile<B_ROW, B_COL>(smem_S, gmem_tmp_d0,
tid_in_warpgroup, threads_per_warpgroup, tid_in_warpgroup, threads_per_warpgroup,
warpgroup_id_in_cluster); warpgroup_id_in_cluster);
} else if (tile_k == 1) { } else if (tile_k == 1) {
thread_block_copy_tile(smem_S, gmem_tmp_d1, thread_block_copy_tile<B_ROW, B_COL>(smem_S, gmem_tmp_d1,
tid_in_warpgroup, threads_per_warpgroup, tid_in_warpgroup, threads_per_warpgroup,
warpgroup_id_in_cluster); warpgroup_id_in_cluster);
} }
@@ -830,6 +877,7 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
// inter-warpgroup barrier before online softmax // inter-warpgroup barrier before online softmax
threadblock_barrier(global_barrier_id, warps_per_threadblock_per_core); threadblock_barrier(global_barrier_id, warps_per_threadblock_per_core);
#if 0
// Online softmax // Online softmax
// //
thread_block_online_softmax(smem_S, smem_P, tid_in_warpgroup, thread_block_online_softmax(smem_S, smem_P, tid_in_warpgroup,
@@ -897,17 +945,17 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
if (warpgroup_id == 0) { if (warpgroup_id == 0) {
// O before PV // O before PV
if (tile_k == 0) { if (tile_k == 0) {
thread_block_copy_tile(smem_P, gmem_tmp_d2, tid_in_warpgroup, thread_block_copy_tile<B_ROW, B_COL>(smem_P, gmem_tmp_d2, tid_in_warpgroup,
threads_per_warpgroup, threads_per_warpgroup,
warpgroup_id_in_cluster); warpgroup_id_in_cluster);
thread_block_copy_tile(smem_O, gmem_tmp_d4, tid_in_warpgroup, thread_block_copy_tile<B_ROW, HEADDIM>(smem_O, gmem_tmp_d4, tid_in_warpgroup,
threads_per_warpgroup, threads_per_warpgroup,
warpgroup_id_in_cluster); warpgroup_id_in_cluster);
} else if (tile_k == 1) { } else if (tile_k == 1) {
thread_block_copy_tile(smem_P, gmem_tmp_d3, tid_in_warpgroup, thread_block_copy_tile<B_ROW, B_COL>(smem_P, gmem_tmp_d3, tid_in_warpgroup,
threads_per_warpgroup, threads_per_warpgroup,
warpgroup_id_in_cluster); warpgroup_id_in_cluster);
thread_block_copy_tile(smem_O, gmem_tmp_d5, tid_in_warpgroup, thread_block_copy_tile<B_ROW, HEADDIM>(smem_O, gmem_tmp_d5, tid_in_warpgroup,
threads_per_warpgroup, threads_per_warpgroup,
warpgroup_id_in_cluster); warpgroup_id_in_cluster);
} }
@@ -986,11 +1034,11 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
if (warpgroup_id == 0) { if (warpgroup_id == 0) {
// O after PV // O after PV
if (tile_k == 0) { if (tile_k == 0) {
thread_block_copy_tile(smem_O, gmem_tmp_d6, tid_in_warpgroup, thread_block_copy_tile<B_ROW, HEADDIM>(smem_O, gmem_tmp_d6, tid_in_warpgroup,
threads_per_warpgroup, threads_per_warpgroup,
warpgroup_id_in_cluster); warpgroup_id_in_cluster);
} else if (tile_k == 1) { } else if (tile_k == 1) {
thread_block_copy_tile(smem_O, gmem_tmp_d7, tid_in_warpgroup, thread_block_copy_tile<B_ROW, HEADDIM>(smem_O, gmem_tmp_d7, tid_in_warpgroup,
threads_per_warpgroup, threads_per_warpgroup,
warpgroup_id_in_cluster); warpgroup_id_in_cluster);
} }
@@ -1006,6 +1054,7 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
// warps_per_threadblock_per_core); // warps_per_threadblock_per_core);
// threadblock_barrier(3, // FIXME // threadblock_barrier(3, // FIXME
// NUM_WARPS); // NUM_WARPS);
#endif
} }
asm volatile ("tile_loop_finish_%=:" :: ); asm volatile ("tile_loop_finish_%=:" :: );
@@ -1015,7 +1064,6 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
if (warpgroup_id == 0) { if (warpgroup_id == 0) {
threadblock_barrier(global_barrier_id, warps_per_threadblock_per_core); threadblock_barrier(global_barrier_id, warps_per_threadblock_per_core);
} }
#endif
} }
int main() { int main() {

View File

@@ -6,7 +6,7 @@
#include "include/gemmini.h" #include "include/gemmini.h"
#include "gemmini_mmio.h" #include "gemmini_mmio.h"
#define FP_SIZE 16 #define FP_SIZE 32
// "fake" fp16 type that only has the correct data width. // "fake" fp16 type that only has the correct data width.
using float16_t = uint16_t; using float16_t = uint16_t;
@@ -29,7 +29,7 @@ using float_type = float16_t;
// (BM*BN) / (TM*TN) == threadblock size >= NT * CORES_PER_CLUSTER // (BM*BN) / (TM*TN) == threadblock size >= NT * CORES_PER_CLUSTER
// * Combining BM * BK >= (BM*BN) / (TM*TN) == threadblock yields // * Combining BM * BK >= (BM*BN) / (TM*TN) == threadblock yields
// BM <= BK*TM*TN // BM <= BK*TM*TN
#define BM 128 #define BM 64
#define BN 64 #define BN 64
#if (FP_SIZE == 32) #if (FP_SIZE == 32)
#define BK 64 #define BK 64
@@ -62,18 +62,18 @@ static_assert(WMITER * WNITER * TCM * TCN * NUM_WARPS * CORES_PER_CLUSTER ==
#define BK_LOOP 1 #define BK_LOOP 1
// Whether to transpose smem A tile at GMEM->SMEM (produce), or SMEM->RF // Whether to transpose smem A tile at GMEM->SMEM (produce), or SMEM->RF
// (consume). This is because the tensor core expects the A tile to be stored // (consume). This is because the tensor core expects the A tile to be stored
// in column-major order in SMEM, whereas it will be ultimately stored in // in column-major order in SMEM, so a transpose is necessary if A was stored
// row-major in the RF. // row-major in GMEM.
// //
// For correctness, only one of either should be 1. E.g., PRODUCE 1 CONSUME 0 // For correctness, only one of either should be 1. E.g., PRODUCE 1 CONSUME 0
// generates the NN kernel where both A and B are stored row-major in GMEM. // generates the NN kernel where both A and B are stored row-major in GMEM.
// To model the case where the A matrix is already stored column-major in GMEM, // To model the case where the A matrix is already stored column-major in GMEM,
// set both to 0. // set both to 0.
#define TRANSPOSE_AT_PRODUCE 0 #define TRANSPOSE_AT_PRODUCE 0
#define TRANSPOSE_AT_CONSUME 0 #define TRANSPOSE_AT_CONSUME 1
#define GEMMINI_DMA 0 #define GEMMINI_DMA 1
#define GEMMINI_DMA_MN_MAJOR 1 #define GEMMINI_DMA_MN_MAJOR 0
#if SMEM_SIZE == 0x4000 #if SMEM_SIZE == 0x4000
#define SMEM_ADDR_Q0 ((float * const) 0xff000000) #define SMEM_ADDR_Q0 ((float * const) 0xff000000)
#define SMEM_ADDR_Q1 ((float * const) 0xff001000) #define SMEM_ADDR_Q1 ((float * const) 0xff001000)