flash: Fix DMA addr stride, stop at S=Q*K
This commit is contained in:
@@ -8,10 +8,9 @@
|
|||||||
#include "include/gemmini.h"
|
#include "include/gemmini.h"
|
||||||
#include "gemmini_mmio.h"
|
#include "gemmini_mmio.h"
|
||||||
|
|
||||||
#define B_ROW BM
|
#define B_ROW 64
|
||||||
#define B_COL BN
|
#define B_COL 64
|
||||||
// FIXME
|
#define HEADDIM 64
|
||||||
#define HEADDIM B_COL
|
|
||||||
|
|
||||||
constexpr uint32_t ROWMAX_SETS = 3;
|
constexpr uint32_t ROWMAX_SETS = 3;
|
||||||
constexpr bool DEBUG = true;
|
constexpr bool DEBUG = true;
|
||||||
@@ -19,6 +18,8 @@ constexpr bool WARP_SPECIALIZED = false;
|
|||||||
|
|
||||||
constexpr uint32_t DEV_FAKE_SMEM_START_ADDR = 0xf0000000;
|
constexpr uint32_t DEV_FAKE_SMEM_START_ADDR = 0xf0000000;
|
||||||
|
|
||||||
|
constexpr bool Q_IS_K_MAJOR = true;
|
||||||
|
|
||||||
// temporary safety stop for wrong configs
|
// temporary safety stop for wrong configs
|
||||||
static_assert(NUM_CORES == 4);
|
static_assert(NUM_CORES == 4);
|
||||||
static_assert(NUM_THREADS == 8);
|
static_assert(NUM_THREADS == 8);
|
||||||
@@ -99,6 +100,7 @@ inline void thread_block_copy_rowmax(const float *src, float *dest,
|
|||||||
asm volatile("threadblock_copy_rowmax_finish_%=:" ::);
|
asm volatile("threadblock_copy_rowmax_finish_%=:" ::);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
template <uint32_t dim_row, uint32_t dim_col>
|
||||||
inline void thread_block_copy_tile(const float *src, float *dest,
|
inline void thread_block_copy_tile(const float *src, float *dest,
|
||||||
const uint32_t tid_in_threadblock,
|
const uint32_t tid_in_threadblock,
|
||||||
const uint32_t threads_per_threadblock,
|
const uint32_t threads_per_threadblock,
|
||||||
@@ -113,12 +115,12 @@ inline void thread_block_copy_tile(const float *src, float *dest,
|
|||||||
|
|
||||||
// FIXME: dedup this pattern
|
// FIXME: dedup this pattern
|
||||||
#pragma GCC unroll 1
|
#pragma GCC unroll 1
|
||||||
for (int row_offset = 0; row_offset < B_ROW;
|
for (int row_offset = 0; row_offset < dim_row;
|
||||||
row_offset += warps_in_threadblock) {
|
row_offset += warps_in_threadblock) {
|
||||||
const uint32_t row = row_offset + warp_id;
|
const uint32_t row = row_offset + warp_id;
|
||||||
const uint32_t first_thread_offset = B_COL * row;
|
const uint32_t first_thread_offset = dim_col * row;
|
||||||
|
|
||||||
constexpr uint32_t per_row_iter = B_COL / NUM_THREADS;
|
constexpr uint32_t per_row_iter = dim_col / NUM_THREADS;
|
||||||
uint32_t thread_offset = first_thread_offset + tid_in_warp;
|
uint32_t thread_offset = first_thread_offset + tid_in_warp;
|
||||||
#pragma GCC unroll
|
#pragma GCC unroll
|
||||||
for (int i = 0; i < per_row_iter; i++) {
|
for (int i = 0; i < per_row_iter; i++) {
|
||||||
@@ -533,12 +535,12 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
|
|||||||
constexpr uint32_t smem_QK_size = B_ROW * B_COL;
|
constexpr uint32_t smem_QK_size = B_ROW * B_COL;
|
||||||
constexpr uint32_t smem_V_size = B_COL * HEADDIM;
|
constexpr uint32_t smem_V_size = B_COL * HEADDIM;
|
||||||
constexpr uint32_t smem_O_size = B_COL * HEADDIM;
|
constexpr uint32_t smem_O_size = B_COL * HEADDIM;
|
||||||
|
static_assert(
|
||||||
|
threads_per_threadblock == NUM_WARPS * NUM_THREADS * CORES_PER_CLUSTER,
|
||||||
|
"flashattention kernel assumes 1 threadblock occupancy per cluster");
|
||||||
uint8_t *smem_per_threadblock = reinterpret_cast<uint8_t *>(
|
uint8_t *smem_per_threadblock = reinterpret_cast<uint8_t *>(
|
||||||
DEV_SMEM_START_ADDR +
|
DEV_SMEM_START_ADDR);
|
||||||
sizeof(float_type) *
|
float *smem_cursor = reinterpret_cast<float *>(smem_per_threadblock);
|
||||||
(smem_QK_size + smem_V_size + smem_O_size) *
|
|
||||||
threadblock_id_in_cluster);
|
|
||||||
float *smem_cursor = reinterpret_cast<float *>(DEV_FAKE_SMEM_START_ADDR);
|
|
||||||
float *smem_Q0 = smem_cursor;
|
float *smem_Q0 = smem_cursor;
|
||||||
smem_cursor += smem_Q_size;
|
smem_cursor += smem_Q_size;
|
||||||
float *smem_Q1 = smem_cursor;
|
float *smem_Q1 = smem_cursor;
|
||||||
@@ -563,6 +565,7 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
|
|||||||
smem_cursor += smem_O_size;
|
smem_cursor += smem_O_size;
|
||||||
|
|
||||||
// NOTE: this has to match with smem_*
|
// NOTE: this has to match with smem_*
|
||||||
|
static_assert(sizeof(elem_t) == sizeof(float));
|
||||||
constexpr uint32_t spad_addr_factor = DIM * sizeof(elem_t);
|
constexpr uint32_t spad_addr_factor = DIM * sizeof(elem_t);
|
||||||
constexpr uint32_t spad_addr_Q0 = 0;
|
constexpr uint32_t spad_addr_Q0 = 0;
|
||||||
constexpr uint32_t spad_addr_Q1 =
|
constexpr uint32_t spad_addr_Q1 =
|
||||||
@@ -635,15 +638,18 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
|
|||||||
// threadblock_barrier(global_barrier_id, warps_per_threadblock_per_core);
|
// threadblock_barrier(global_barrier_id, warps_per_threadblock_per_core);
|
||||||
// }
|
// }
|
||||||
|
|
||||||
|
static_assert(!GEMMINI_DMA || Q_IS_K_MAJOR,
|
||||||
|
"DMA code assumes Q matrix is stored K-major");
|
||||||
|
|
||||||
if constexpr (GEMMINI_DMA) {
|
if constexpr (GEMMINI_DMA) {
|
||||||
if (tid_in_warpgroup == 0) {
|
if (tid_in_warpgroup == 0) {
|
||||||
gemmini_extended_config_ex(WEIGHT_STATIONARY, 0, 0, 1, 0, 0);
|
gemmini_extended_config_ex(WEIGHT_STATIONARY, 0, 0, 1, 0, 0);
|
||||||
|
|
||||||
// configure DMA for Q tile
|
// configure DMA for the full Q matrix
|
||||||
gemmini_extended3_config_ld(HEADDIM * sizeof(elem_t), MVIN_SCALE_IDENTITY,
|
gemmini_extended3_config_ld(HEADDIM * sizeof(elem_t), MVIN_SCALE_IDENTITY,
|
||||||
false, 0);
|
false, 0);
|
||||||
// configure DMA for K tile
|
// configure DMA for the full K matrix
|
||||||
gemmini_extended3_config_ld(B_COL * sizeof(elem_t), MVIN_SCALE_IDENTITY,
|
gemmini_extended3_config_ld(dim_seqlen * sizeof(elem_t), MVIN_SCALE_IDENTITY,
|
||||||
false, 1);
|
false, 1);
|
||||||
// configure DMA for Q*K store
|
// configure DMA for Q*K store
|
||||||
gemmini_extended_config_st(B_COL * sizeof(elem_t), 0,
|
gemmini_extended_config_st(B_COL * sizeof(elem_t), 0,
|
||||||
@@ -652,12 +658,12 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// NOTE about barriers: placing barriers around thread-divergent branches may
|
// NOTE about barriers: Placing barriers around thread-divergent branches may
|
||||||
// cause bugs, since the core doesn't check tmask for barriers. The compiler
|
// cause bugs, because the Vortex core doesn't check tmask for barriers.
|
||||||
// might decide to replicate vx_bar into both paths of a conditional branch,
|
// The compiler might decide to duplicate vx_bar into both paths of a
|
||||||
// which will get evaluated twice along the split/join process and result in
|
// conditional branch, which will get evaluated twice because of the way
|
||||||
// a different number of calls w.r.t other non-divergent warps and therefore
|
// branches are handled in SIMT; this might result in stalls, especially when
|
||||||
// stalls.
|
// other warps behave differently on the branch condition.
|
||||||
// threadblock_barrier(warpgroup_id_in_cluster, warps_per_warpgroup_per_core);
|
// threadblock_barrier(warpgroup_id_in_cluster, warps_per_warpgroup_per_core);
|
||||||
|
|
||||||
// move Q and K into SMEM before the loop starts
|
// move Q and K into SMEM before the loop starts
|
||||||
@@ -674,13 +680,15 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
|
|||||||
ROCC_INSTRUCTION_RS1_RS2(XCUSTOM_ACC, (uint64_t)(gmem_Q),
|
ROCC_INSTRUCTION_RS1_RS2(XCUSTOM_ACC, (uint64_t)(gmem_Q),
|
||||||
(uint64_t)(gmem_K), k_LOOP_WS_CONFIG_ADDRS_AB)
|
(uint64_t)(gmem_K), k_LOOP_WS_CONFIG_ADDRS_AB)
|
||||||
// configure address strides for the DMA
|
// configure address strides for the DMA
|
||||||
GEMMINI_CISC_CMD_R((B_COL << 16) | (HEADDIM << 8) |
|
// GEMMINI_CISC_CMD_R((B_COL << 16) | (HEADDIM << 8) |
|
||||||
|
// 8 /*k_LOOP_WS_CONFIG_STRIDES_AB*/);
|
||||||
|
GEMMINI_CISC_CMD_R((dim_seqlen << 16) | (HEADDIM << 8) |
|
||||||
8 /*k_LOOP_WS_CONFIG_STRIDES_AB*/);
|
8 /*k_LOOP_WS_CONFIG_STRIDES_AB*/);
|
||||||
gemmini_fence();
|
gemmini_fence();
|
||||||
|
|
||||||
#define GEMMINI_DMA_CISC
|
// #define GEMMINI_DMA_CISC
|
||||||
#ifdef GEMMINI_DMA_CISC
|
#ifdef GEMMINI_DMA_CISC
|
||||||
GEMMINI_CISC_CMD_I(10);
|
GEMMINI_CISC_CMD_I(9);
|
||||||
gemmini_fence();
|
gemmini_fence();
|
||||||
#else
|
#else
|
||||||
// skip everything except DMA in the loop FSM
|
// skip everything except DMA in the loop FSM
|
||||||
@@ -693,7 +701,7 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
|
|||||||
sp_tiled_matmul_full_spad_ws(
|
sp_tiled_matmul_full_spad_ws(
|
||||||
spad_addr_Q0, spad_addr_K0,
|
spad_addr_Q0, spad_addr_K0,
|
||||||
/*spad_D=*/0, /*spad_C=*/spad_addr_S0,
|
/*spad_D=*/0, /*spad_C=*/spad_addr_S0,
|
||||||
/*I=*/(B_ROW / DIM), /*J=*/(HEADDIM / DIM), /*K=*/(B_COL / DIM),
|
/*I=*/(B_ROW / DIM), /*J=*/(B_COL / DIM), /*K=*/(HEADDIM / DIM),
|
||||||
/*pad_I=*/0, /*pad_J=*/0, /*pad_K=*/0,
|
/*pad_I=*/0, /*pad_J=*/0, /*pad_K=*/0,
|
||||||
/*a_transpose=*/0, /*b_transpose=*/0, /*full_C=*/0, /*low_D=*/0,
|
/*a_transpose=*/0, /*b_transpose=*/0, /*full_C=*/0, /*low_D=*/0,
|
||||||
/*acc=*/0, /*act=*/NO_ACTIVATION, /*skips=*/skips);
|
/*acc=*/0, /*act=*/NO_ACTIVATION, /*skips=*/skips);
|
||||||
@@ -704,10 +712,17 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
|
|||||||
asm volatile("dma_move_end_%=:" ::);
|
asm volatile("dma_move_end_%=:" ::);
|
||||||
} else {
|
} else {
|
||||||
// load Q; this stays in SMEM for the entire loop
|
// load Q; this stays in SMEM for the entire loop
|
||||||
load_tile_to_smem<float, MemLayout::MN_major, MemLayout::MN_major, B_ROW,
|
if constexpr (Q_IS_K_MAJOR) {
|
||||||
HEADDIM, threads_per_warpgroup>(
|
load_tile_to_smem<float, MemLayout::K_major, MemLayout::K_major, B_ROW,
|
||||||
dim_seqlen, warpgroup_id, 0 /* dim_k == headdim */, gmem_Q, smem_Q,
|
HEADDIM, threads_per_warpgroup>(
|
||||||
tid_in_warpgroup);
|
HEADDIM, warpgroup_id, 0 /* dim_k == headdim */, gmem_Q, smem_Q,
|
||||||
|
tid_in_warpgroup);
|
||||||
|
} else {
|
||||||
|
load_tile_to_smem<float, MemLayout::MN_major, MemLayout::MN_major, B_ROW,
|
||||||
|
HEADDIM, threads_per_warpgroup>(
|
||||||
|
dim_seqlen, warpgroup_id, 0 /* dim_k == headdim */, gmem_Q, smem_Q,
|
||||||
|
tid_in_warpgroup);
|
||||||
|
}
|
||||||
|
|
||||||
// load K
|
// load K
|
||||||
load_tile_to_smem<float, MemLayout::MN_major, MemLayout::MN_major, B_COL,
|
load_tile_to_smem<float, MemLayout::MN_major, MemLayout::MN_major, B_COL,
|
||||||
@@ -719,14 +734,15 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
|
|||||||
// protect write to SMEM
|
// protect write to SMEM
|
||||||
threadblock_barrier(warpgroup_id_in_cluster, warps_per_warpgroup_per_core);
|
threadblock_barrier(warpgroup_id_in_cluster, warps_per_warpgroup_per_core);
|
||||||
|
|
||||||
if constexpr (DEBUG) {
|
// if constexpr (DEBUG) {
|
||||||
thread_block_copy_tile(smem_Q, gmem_tmp_d0, tid_in_warpgroup,
|
// thread_block_copy_tile<B_ROW, HEADDIM>(smem_Q0, gmem_tmp_d0, tid_in_warpgroup,
|
||||||
threads_per_warpgroup, warpgroup_id_in_cluster);
|
// threads_per_warpgroup, warpgroup_id_in_cluster);
|
||||||
|
// thread_block_copy_tile<HEADDIM, B_COL>(smem_K0, gmem_tmp_d1, tid_in_warpgroup,
|
||||||
|
// threads_per_warpgroup, warpgroup_id_in_cluster);
|
||||||
|
|
||||||
// threadblock_barrier(warpgroup_id_in_cluster, warps_per_warpgroup_per_core);
|
// threadblock_barrier(warpgroup_id_in_cluster, warps_per_warpgroup_per_core);
|
||||||
}
|
// }
|
||||||
|
|
||||||
#if 0
|
|
||||||
asm volatile ("tile_loop_start_%=:" :: );
|
asm volatile ("tile_loop_start_%=:" :: );
|
||||||
|
|
||||||
// "inner loop" along the columns of K^T
|
// "inner loop" along the columns of K^T
|
||||||
@@ -751,25 +767,34 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
|
|||||||
initialize_accum_regs<0>();
|
initialize_accum_regs<0>();
|
||||||
initialize_accum_regs<1>();
|
initialize_accum_regs<1>();
|
||||||
|
|
||||||
thread_block_gemm_single_tile<
|
if constexpr (Q_IS_K_MAJOR) {
|
||||||
float, MemLayout::MN_major, MemLayout::MN_major, B_ROW, B_COL,
|
thread_block_gemm_single_tile<
|
||||||
HEADDIM, /*leading_dim_a=*/0, /*leading_dim_b=*/0,
|
float, MemLayout::K_major, MemLayout::MN_major, B_ROW, B_COL,
|
||||||
/*load_accum=*/false,
|
HEADDIM, /*leading_dim_a=*/0, /*leading_dim_b=*/0,
|
||||||
/*write_to_smem=*/true>(
|
/*load_accum=*/false,
|
||||||
smem_Q, smem_K, nullptr /*ignore accum*/, smem_S, tid_in_warpgroup,
|
/*write_to_smem=*/true>(
|
||||||
threads_per_warpgroup, warpgroups_per_cluster,
|
smem_Q, smem_K, nullptr /*ignore accum*/, smem_S,
|
||||||
warpgroup_id_in_cluster);
|
tid_in_warpgroup, threads_per_warpgroup, warpgroups_per_cluster,
|
||||||
|
warpgroup_id_in_cluster);
|
||||||
|
} else {
|
||||||
|
thread_block_gemm_single_tile<
|
||||||
|
float, MemLayout::MN_major, MemLayout::MN_major, B_ROW, B_COL,
|
||||||
|
HEADDIM, /*leading_dim_a=*/0, /*leading_dim_b=*/0,
|
||||||
|
/*load_accum=*/false,
|
||||||
|
/*write_to_smem=*/true>(
|
||||||
|
smem_Q, smem_K, nullptr /*ignore accum*/, smem_S,
|
||||||
|
tid_in_warpgroup, threads_per_warpgroup, warpgroups_per_cluster,
|
||||||
|
warpgroup_id_in_cluster);
|
||||||
|
}
|
||||||
} else {
|
} else {
|
||||||
// when warp-specialized, there's only enough warps to do 64x32 tile
|
// when warp-specialized, there's only enough warps to do 64x32 tile
|
||||||
// size so we need to do 2 GEMM calls
|
// size so we need to do 2 GEMM calls
|
||||||
static_assert(B_ROW / 2 == 32,
|
static_assert(B_ROW / 2 == 32,
|
||||||
"tile size assumption for warp-specialization not met");
|
"tile size assumption for warp-specialization not met");
|
||||||
|
|
||||||
// assumes smem_Q is K-major
|
|
||||||
// FIXME: fix this to MN-major
|
|
||||||
float *smem_Q_half0 = smem_Q;
|
float *smem_Q_half0 = smem_Q;
|
||||||
float *smem_Q_half1 = smem_Q + (B_ROW / 2); // MN-major
|
float *smem_Q_half1 = Q_IS_K_MAJOR ? smem_Q + (B_ROW / 2) * HEADDIM
|
||||||
// float *smem_Q_half1 = smem_Q + (B_ROW / 2) * HEADDIM; // K-major
|
: smem_Q + (B_ROW / 2);
|
||||||
float *smem_S_half0 = smem_S;
|
float *smem_S_half0 = smem_S;
|
||||||
float *smem_S_half1 = smem_S + (B_ROW / 2) * B_COL;
|
float *smem_S_half1 = smem_S + (B_ROW / 2) * B_COL;
|
||||||
|
|
||||||
@@ -778,26 +803,48 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
|
|||||||
initialize_accum_regs<1>();
|
initialize_accum_regs<1>();
|
||||||
|
|
||||||
// split by rows into 2 chunks
|
// split by rows into 2 chunks
|
||||||
thread_block_gemm_single_tile<
|
if constexpr (Q_IS_K_MAJOR) {
|
||||||
float, MemLayout::MN_major, MemLayout::MN_major, B_ROW / 2,
|
thread_block_gemm_single_tile<
|
||||||
B_COL, HEADDIM, /*leading_dim_a=*/B_ROW, /*leading_dim_b=*/0,
|
float, MemLayout::K_major, MemLayout::MN_major, B_ROW / 2, B_COL,
|
||||||
/*load_accum=*/false,
|
HEADDIM, /*leading_dim_a=*/0, /*leading_dim_b=*/0,
|
||||||
/*write_to_smem=*/true>(
|
/*load_accum=*/false,
|
||||||
smem_Q_half0, smem_K, nullptr /*ignore accum*/, smem_S_half0,
|
/*write_to_smem=*/true>(
|
||||||
tid_in_warpgroup, threads_per_warpgroup, warpgroups_per_cluster,
|
smem_Q_half0, smem_K, nullptr /*ignore accum*/, smem_S_half0,
|
||||||
warpgroup_id_in_cluster);
|
tid_in_warpgroup, threads_per_warpgroup, warpgroups_per_cluster,
|
||||||
|
warpgroup_id_in_cluster);
|
||||||
|
} else {
|
||||||
|
thread_block_gemm_single_tile<
|
||||||
|
float, MemLayout::MN_major, MemLayout::MN_major, B_ROW / 2, B_COL,
|
||||||
|
HEADDIM, /*leading_dim_a=*/B_ROW, /*leading_dim_b=*/0,
|
||||||
|
/*load_accum=*/false,
|
||||||
|
/*write_to_smem=*/true>(
|
||||||
|
smem_Q_half0, smem_K, nullptr /*ignore accum*/, smem_S_half0,
|
||||||
|
tid_in_warpgroup, threads_per_warpgroup, warpgroups_per_cluster,
|
||||||
|
warpgroup_id_in_cluster);
|
||||||
|
}
|
||||||
|
|
||||||
initialize_accum_regs<0>();
|
initialize_accum_regs<0>();
|
||||||
initialize_accum_regs<1>();
|
initialize_accum_regs<1>();
|
||||||
|
|
||||||
thread_block_gemm_single_tile<
|
if constexpr (Q_IS_K_MAJOR) {
|
||||||
float, MemLayout::MN_major, MemLayout::MN_major, B_ROW / 2,
|
thread_block_gemm_single_tile<
|
||||||
B_COL, HEADDIM, /*leading_dim_a=*/B_ROW, /*leading_dim_b=*/0,
|
float, MemLayout::K_major, MemLayout::MN_major, B_ROW / 2, B_COL,
|
||||||
/*load_accum=*/false,
|
HEADDIM, /*leading_dim_a=*/0, /*leading_dim_b=*/0,
|
||||||
/*write_to_smem=*/true>(
|
/*load_accum=*/false,
|
||||||
smem_Q_half1, smem_K, nullptr /*ignore accum*/, smem_S_half1,
|
/*write_to_smem=*/true>(
|
||||||
tid_in_warpgroup, threads_per_warpgroup, warpgroups_per_cluster,
|
smem_Q_half1, smem_K, nullptr /*ignore accum*/, smem_S_half1,
|
||||||
warpgroup_id_in_cluster);
|
tid_in_warpgroup, threads_per_warpgroup, warpgroups_per_cluster,
|
||||||
|
warpgroup_id_in_cluster);
|
||||||
|
} else {
|
||||||
|
thread_block_gemm_single_tile<
|
||||||
|
float, MemLayout::MN_major, MemLayout::MN_major, B_ROW / 2, B_COL,
|
||||||
|
HEADDIM, /*leading_dim_a=*/B_ROW, /*leading_dim_b=*/0,
|
||||||
|
/*load_accum=*/false,
|
||||||
|
/*write_to_smem=*/true>(
|
||||||
|
smem_Q_half1, smem_K, nullptr /*ignore accum*/, smem_S_half1,
|
||||||
|
tid_in_warpgroup, threads_per_warpgroup, warpgroups_per_cluster,
|
||||||
|
warpgroup_id_in_cluster);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
// load Q*K
|
// load Q*K
|
||||||
@@ -813,11 +860,11 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
|
|||||||
if constexpr (DEBUG) {
|
if constexpr (DEBUG) {
|
||||||
if (warpgroup_id == 0) {
|
if (warpgroup_id == 0) {
|
||||||
if (tile_k == 0) {
|
if (tile_k == 0) {
|
||||||
thread_block_copy_tile(smem_S, gmem_tmp_d0,
|
thread_block_copy_tile<B_ROW, B_COL>(smem_S, gmem_tmp_d0,
|
||||||
tid_in_warpgroup, threads_per_warpgroup,
|
tid_in_warpgroup, threads_per_warpgroup,
|
||||||
warpgroup_id_in_cluster);
|
warpgroup_id_in_cluster);
|
||||||
} else if (tile_k == 1) {
|
} else if (tile_k == 1) {
|
||||||
thread_block_copy_tile(smem_S, gmem_tmp_d1,
|
thread_block_copy_tile<B_ROW, B_COL>(smem_S, gmem_tmp_d1,
|
||||||
tid_in_warpgroup, threads_per_warpgroup,
|
tid_in_warpgroup, threads_per_warpgroup,
|
||||||
warpgroup_id_in_cluster);
|
warpgroup_id_in_cluster);
|
||||||
}
|
}
|
||||||
@@ -830,6 +877,7 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
|
|||||||
// inter-warpgroup barrier before online softmax
|
// inter-warpgroup barrier before online softmax
|
||||||
threadblock_barrier(global_barrier_id, warps_per_threadblock_per_core);
|
threadblock_barrier(global_barrier_id, warps_per_threadblock_per_core);
|
||||||
|
|
||||||
|
#if 0
|
||||||
// Online softmax
|
// Online softmax
|
||||||
//
|
//
|
||||||
thread_block_online_softmax(smem_S, smem_P, tid_in_warpgroup,
|
thread_block_online_softmax(smem_S, smem_P, tid_in_warpgroup,
|
||||||
@@ -897,17 +945,17 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
|
|||||||
if (warpgroup_id == 0) {
|
if (warpgroup_id == 0) {
|
||||||
// O before PV
|
// O before PV
|
||||||
if (tile_k == 0) {
|
if (tile_k == 0) {
|
||||||
thread_block_copy_tile(smem_P, gmem_tmp_d2, tid_in_warpgroup,
|
thread_block_copy_tile<B_ROW, B_COL>(smem_P, gmem_tmp_d2, tid_in_warpgroup,
|
||||||
threads_per_warpgroup,
|
threads_per_warpgroup,
|
||||||
warpgroup_id_in_cluster);
|
warpgroup_id_in_cluster);
|
||||||
thread_block_copy_tile(smem_O, gmem_tmp_d4, tid_in_warpgroup,
|
thread_block_copy_tile<B_ROW, HEADDIM>(smem_O, gmem_tmp_d4, tid_in_warpgroup,
|
||||||
threads_per_warpgroup,
|
threads_per_warpgroup,
|
||||||
warpgroup_id_in_cluster);
|
warpgroup_id_in_cluster);
|
||||||
} else if (tile_k == 1) {
|
} else if (tile_k == 1) {
|
||||||
thread_block_copy_tile(smem_P, gmem_tmp_d3, tid_in_warpgroup,
|
thread_block_copy_tile<B_ROW, B_COL>(smem_P, gmem_tmp_d3, tid_in_warpgroup,
|
||||||
threads_per_warpgroup,
|
threads_per_warpgroup,
|
||||||
warpgroup_id_in_cluster);
|
warpgroup_id_in_cluster);
|
||||||
thread_block_copy_tile(smem_O, gmem_tmp_d5, tid_in_warpgroup,
|
thread_block_copy_tile<B_ROW, HEADDIM>(smem_O, gmem_tmp_d5, tid_in_warpgroup,
|
||||||
threads_per_warpgroup,
|
threads_per_warpgroup,
|
||||||
warpgroup_id_in_cluster);
|
warpgroup_id_in_cluster);
|
||||||
}
|
}
|
||||||
@@ -986,11 +1034,11 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
|
|||||||
if (warpgroup_id == 0) {
|
if (warpgroup_id == 0) {
|
||||||
// O after PV
|
// O after PV
|
||||||
if (tile_k == 0) {
|
if (tile_k == 0) {
|
||||||
thread_block_copy_tile(smem_O, gmem_tmp_d6, tid_in_warpgroup,
|
thread_block_copy_tile<B_ROW, HEADDIM>(smem_O, gmem_tmp_d6, tid_in_warpgroup,
|
||||||
threads_per_warpgroup,
|
threads_per_warpgroup,
|
||||||
warpgroup_id_in_cluster);
|
warpgroup_id_in_cluster);
|
||||||
} else if (tile_k == 1) {
|
} else if (tile_k == 1) {
|
||||||
thread_block_copy_tile(smem_O, gmem_tmp_d7, tid_in_warpgroup,
|
thread_block_copy_tile<B_ROW, HEADDIM>(smem_O, gmem_tmp_d7, tid_in_warpgroup,
|
||||||
threads_per_warpgroup,
|
threads_per_warpgroup,
|
||||||
warpgroup_id_in_cluster);
|
warpgroup_id_in_cluster);
|
||||||
}
|
}
|
||||||
@@ -1006,6 +1054,7 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
|
|||||||
// warps_per_threadblock_per_core);
|
// warps_per_threadblock_per_core);
|
||||||
// threadblock_barrier(3, // FIXME
|
// threadblock_barrier(3, // FIXME
|
||||||
// NUM_WARPS);
|
// NUM_WARPS);
|
||||||
|
#endif
|
||||||
}
|
}
|
||||||
|
|
||||||
asm volatile ("tile_loop_finish_%=:" :: );
|
asm volatile ("tile_loop_finish_%=:" :: );
|
||||||
@@ -1015,7 +1064,6 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
|
|||||||
if (warpgroup_id == 0) {
|
if (warpgroup_id == 0) {
|
||||||
threadblock_barrier(global_barrier_id, warps_per_threadblock_per_core);
|
threadblock_barrier(global_barrier_id, warps_per_threadblock_per_core);
|
||||||
}
|
}
|
||||||
#endif
|
|
||||||
}
|
}
|
||||||
|
|
||||||
int main() {
|
int main() {
|
||||||
|
|||||||
@@ -6,7 +6,7 @@
|
|||||||
#include "include/gemmini.h"
|
#include "include/gemmini.h"
|
||||||
#include "gemmini_mmio.h"
|
#include "gemmini_mmio.h"
|
||||||
|
|
||||||
#define FP_SIZE 16
|
#define FP_SIZE 32
|
||||||
|
|
||||||
// "fake" fp16 type that only has the correct data width.
|
// "fake" fp16 type that only has the correct data width.
|
||||||
using float16_t = uint16_t;
|
using float16_t = uint16_t;
|
||||||
@@ -29,7 +29,7 @@ using float_type = float16_t;
|
|||||||
// (BM*BN) / (TM*TN) == threadblock size >= NT * CORES_PER_CLUSTER
|
// (BM*BN) / (TM*TN) == threadblock size >= NT * CORES_PER_CLUSTER
|
||||||
// * Combining BM * BK >= (BM*BN) / (TM*TN) == threadblock yields
|
// * Combining BM * BK >= (BM*BN) / (TM*TN) == threadblock yields
|
||||||
// BM <= BK*TM*TN
|
// BM <= BK*TM*TN
|
||||||
#define BM 128
|
#define BM 64
|
||||||
#define BN 64
|
#define BN 64
|
||||||
#if (FP_SIZE == 32)
|
#if (FP_SIZE == 32)
|
||||||
#define BK 64
|
#define BK 64
|
||||||
@@ -62,18 +62,18 @@ static_assert(WMITER * WNITER * TCM * TCN * NUM_WARPS * CORES_PER_CLUSTER ==
|
|||||||
#define BK_LOOP 1
|
#define BK_LOOP 1
|
||||||
// Whether to transpose smem A tile at GMEM->SMEM (produce), or SMEM->RF
|
// Whether to transpose smem A tile at GMEM->SMEM (produce), or SMEM->RF
|
||||||
// (consume). This is because the tensor core expects the A tile to be stored
|
// (consume). This is because the tensor core expects the A tile to be stored
|
||||||
// in column-major order in SMEM, whereas it will be ultimately stored in
|
// in column-major order in SMEM, so a transpose is necessary if A was stored
|
||||||
// row-major in the RF.
|
// row-major in GMEM.
|
||||||
//
|
//
|
||||||
// For correctness, only one of either should be 1. E.g., PRODUCE 1 CONSUME 0
|
// For correctness, only one of either should be 1. E.g., PRODUCE 1 CONSUME 0
|
||||||
// generates the NN kernel where both A and B are stored row-major in GMEM.
|
// generates the NN kernel where both A and B are stored row-major in GMEM.
|
||||||
// To model the case where the A matrix is already stored column-major in GMEM,
|
// To model the case where the A matrix is already stored column-major in GMEM,
|
||||||
// set both to 0.
|
// set both to 0.
|
||||||
#define TRANSPOSE_AT_PRODUCE 0
|
#define TRANSPOSE_AT_PRODUCE 0
|
||||||
#define TRANSPOSE_AT_CONSUME 0
|
#define TRANSPOSE_AT_CONSUME 1
|
||||||
|
|
||||||
#define GEMMINI_DMA 0
|
#define GEMMINI_DMA 1
|
||||||
#define GEMMINI_DMA_MN_MAJOR 1
|
#define GEMMINI_DMA_MN_MAJOR 0
|
||||||
#if SMEM_SIZE == 0x4000
|
#if SMEM_SIZE == 0x4000
|
||||||
#define SMEM_ADDR_Q0 ((float * const) 0xff000000)
|
#define SMEM_ADDR_Q0 ((float * const) 0xff000000)
|
||||||
#define SMEM_ADDR_Q1 ((float * const) 0xff001000)
|
#define SMEM_ADDR_Q1 ((float * const) 0xff001000)
|
||||||
|
|||||||
Reference in New Issue
Block a user