flash: Fix DMA addr stride, stop at S=Q*K
This commit is contained in:
@@ -8,10 +8,9 @@
|
||||
#include "include/gemmini.h"
|
||||
#include "gemmini_mmio.h"
|
||||
|
||||
#define B_ROW BM
|
||||
#define B_COL BN
|
||||
// FIXME
|
||||
#define HEADDIM B_COL
|
||||
#define B_ROW 64
|
||||
#define B_COL 64
|
||||
#define HEADDIM 64
|
||||
|
||||
constexpr uint32_t ROWMAX_SETS = 3;
|
||||
constexpr bool DEBUG = true;
|
||||
@@ -19,6 +18,8 @@ constexpr bool WARP_SPECIALIZED = false;
|
||||
|
||||
constexpr uint32_t DEV_FAKE_SMEM_START_ADDR = 0xf0000000;
|
||||
|
||||
constexpr bool Q_IS_K_MAJOR = true;
|
||||
|
||||
// temporary safety stop for wrong configs
|
||||
static_assert(NUM_CORES == 4);
|
||||
static_assert(NUM_THREADS == 8);
|
||||
@@ -99,6 +100,7 @@ inline void thread_block_copy_rowmax(const float *src, float *dest,
|
||||
asm volatile("threadblock_copy_rowmax_finish_%=:" ::);
|
||||
}
|
||||
|
||||
template <uint32_t dim_row, uint32_t dim_col>
|
||||
inline void thread_block_copy_tile(const float *src, float *dest,
|
||||
const uint32_t tid_in_threadblock,
|
||||
const uint32_t threads_per_threadblock,
|
||||
@@ -113,12 +115,12 @@ inline void thread_block_copy_tile(const float *src, float *dest,
|
||||
|
||||
// FIXME: dedup this pattern
|
||||
#pragma GCC unroll 1
|
||||
for (int row_offset = 0; row_offset < B_ROW;
|
||||
for (int row_offset = 0; row_offset < dim_row;
|
||||
row_offset += warps_in_threadblock) {
|
||||
const uint32_t row = row_offset + warp_id;
|
||||
const uint32_t first_thread_offset = B_COL * row;
|
||||
const uint32_t first_thread_offset = dim_col * row;
|
||||
|
||||
constexpr uint32_t per_row_iter = B_COL / NUM_THREADS;
|
||||
constexpr uint32_t per_row_iter = dim_col / NUM_THREADS;
|
||||
uint32_t thread_offset = first_thread_offset + tid_in_warp;
|
||||
#pragma GCC unroll
|
||||
for (int i = 0; i < per_row_iter; i++) {
|
||||
@@ -533,12 +535,12 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
|
||||
constexpr uint32_t smem_QK_size = B_ROW * B_COL;
|
||||
constexpr uint32_t smem_V_size = B_COL * HEADDIM;
|
||||
constexpr uint32_t smem_O_size = B_COL * HEADDIM;
|
||||
static_assert(
|
||||
threads_per_threadblock == NUM_WARPS * NUM_THREADS * CORES_PER_CLUSTER,
|
||||
"flashattention kernel assumes 1 threadblock occupancy per cluster");
|
||||
uint8_t *smem_per_threadblock = reinterpret_cast<uint8_t *>(
|
||||
DEV_SMEM_START_ADDR +
|
||||
sizeof(float_type) *
|
||||
(smem_QK_size + smem_V_size + smem_O_size) *
|
||||
threadblock_id_in_cluster);
|
||||
float *smem_cursor = reinterpret_cast<float *>(DEV_FAKE_SMEM_START_ADDR);
|
||||
DEV_SMEM_START_ADDR);
|
||||
float *smem_cursor = reinterpret_cast<float *>(smem_per_threadblock);
|
||||
float *smem_Q0 = smem_cursor;
|
||||
smem_cursor += smem_Q_size;
|
||||
float *smem_Q1 = smem_cursor;
|
||||
@@ -563,6 +565,7 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
|
||||
smem_cursor += smem_O_size;
|
||||
|
||||
// NOTE: this has to match with smem_*
|
||||
static_assert(sizeof(elem_t) == sizeof(float));
|
||||
constexpr uint32_t spad_addr_factor = DIM * sizeof(elem_t);
|
||||
constexpr uint32_t spad_addr_Q0 = 0;
|
||||
constexpr uint32_t spad_addr_Q1 =
|
||||
@@ -635,15 +638,18 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
|
||||
// threadblock_barrier(global_barrier_id, warps_per_threadblock_per_core);
|
||||
// }
|
||||
|
||||
static_assert(!GEMMINI_DMA || Q_IS_K_MAJOR,
|
||||
"DMA code assumes Q matrix is stored K-major");
|
||||
|
||||
if constexpr (GEMMINI_DMA) {
|
||||
if (tid_in_warpgroup == 0) {
|
||||
gemmini_extended_config_ex(WEIGHT_STATIONARY, 0, 0, 1, 0, 0);
|
||||
|
||||
// configure DMA for Q tile
|
||||
// configure DMA for the full Q matrix
|
||||
gemmini_extended3_config_ld(HEADDIM * sizeof(elem_t), MVIN_SCALE_IDENTITY,
|
||||
false, 0);
|
||||
// configure DMA for K tile
|
||||
gemmini_extended3_config_ld(B_COL * sizeof(elem_t), MVIN_SCALE_IDENTITY,
|
||||
// configure DMA for the full K matrix
|
||||
gemmini_extended3_config_ld(dim_seqlen * sizeof(elem_t), MVIN_SCALE_IDENTITY,
|
||||
false, 1);
|
||||
// configure DMA for Q*K store
|
||||
gemmini_extended_config_st(B_COL * sizeof(elem_t), 0,
|
||||
@@ -652,12 +658,12 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
|
||||
}
|
||||
}
|
||||
|
||||
// NOTE about barriers: placing barriers around thread-divergent branches may
|
||||
// cause bugs, since the core doesn't check tmask for barriers. The compiler
|
||||
// might decide to replicate vx_bar into both paths of a conditional branch,
|
||||
// which will get evaluated twice along the split/join process and result in
|
||||
// a different number of calls w.r.t other non-divergent warps and therefore
|
||||
// stalls.
|
||||
// NOTE about barriers: Placing barriers around thread-divergent branches may
|
||||
// cause bugs, because the Vortex core doesn't check tmask for barriers.
|
||||
// The compiler might decide to duplicate vx_bar into both paths of a
|
||||
// conditional branch, which will get evaluated twice because of the way
|
||||
// branches are handled in SIMT; this might result in stalls especially when
|
||||
// other warps behave differently on the branch condition.
|
||||
// threadblock_barrier(warpgroup_id_in_cluster, warps_per_warpgroup_per_core);
|
||||
|
||||
// move Q and K into SMEM before the loop starts
|
||||
@@ -674,13 +680,15 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
|
||||
ROCC_INSTRUCTION_RS1_RS2(XCUSTOM_ACC, (uint64_t)(gmem_Q),
|
||||
(uint64_t)(gmem_K), k_LOOP_WS_CONFIG_ADDRS_AB)
|
||||
// configure address strides for the DMA
|
||||
GEMMINI_CISC_CMD_R((B_COL << 16) | (HEADDIM << 8) |
|
||||
// GEMMINI_CISC_CMD_R((B_COL << 16) | (HEADDIM << 8) |
|
||||
// 8 /*k_LOOP_WS_CONFIG_STRIDES_AB*/);
|
||||
GEMMINI_CISC_CMD_R((dim_seqlen << 16) | (HEADDIM << 8) |
|
||||
8 /*k_LOOP_WS_CONFIG_STRIDES_AB*/);
|
||||
gemmini_fence();
|
||||
|
||||
#define GEMMINI_DMA_CISC
|
||||
// #define GEMMINI_DMA_CISC
|
||||
#ifdef GEMMINI_DMA_CISC
|
||||
GEMMINI_CISC_CMD_I(10);
|
||||
GEMMINI_CISC_CMD_I(9);
|
||||
gemmini_fence();
|
||||
#else
|
||||
// skip everything except DMA in the loop FSM
|
||||
@@ -693,7 +701,7 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
|
||||
sp_tiled_matmul_full_spad_ws(
|
||||
spad_addr_Q0, spad_addr_K0,
|
||||
/*spad_D=*/0, /*spad_C=*/spad_addr_S0,
|
||||
/*I=*/(B_ROW / DIM), /*J=*/(HEADDIM / DIM), /*K=*/(B_COL / DIM),
|
||||
/*I=*/(B_ROW / DIM), /*J=*/(B_COL / DIM), /*K=*/(HEADDIM / DIM),
|
||||
/*pad_I=*/0, /*pad_J=*/0, /*pad_K=*/0,
|
||||
/*a_transpose=*/0, /*b_transpose=*/0, /*full_C=*/0, /*low_D=*/0,
|
||||
/*acc=*/0, /*act=*/NO_ACTIVATION, /*skips=*/skips);
|
||||
@@ -704,10 +712,17 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
|
||||
asm volatile("dma_move_end_%=:" ::);
|
||||
} else {
|
||||
// load Q; this stays in SMEM for the entire loop
|
||||
load_tile_to_smem<float, MemLayout::MN_major, MemLayout::MN_major, B_ROW,
|
||||
HEADDIM, threads_per_warpgroup>(
|
||||
dim_seqlen, warpgroup_id, 0 /* dim_k == headdim */, gmem_Q, smem_Q,
|
||||
tid_in_warpgroup);
|
||||
if constexpr (Q_IS_K_MAJOR) {
|
||||
load_tile_to_smem<float, MemLayout::K_major, MemLayout::K_major, B_ROW,
|
||||
HEADDIM, threads_per_warpgroup>(
|
||||
HEADDIM, warpgroup_id, 0 /* dim_k == headdim */, gmem_Q, smem_Q,
|
||||
tid_in_warpgroup);
|
||||
} else {
|
||||
load_tile_to_smem<float, MemLayout::MN_major, MemLayout::MN_major, B_ROW,
|
||||
HEADDIM, threads_per_warpgroup>(
|
||||
dim_seqlen, warpgroup_id, 0 /* dim_k == headdim */, gmem_Q, smem_Q,
|
||||
tid_in_warpgroup);
|
||||
}
|
||||
|
||||
// load K
|
||||
load_tile_to_smem<float, MemLayout::MN_major, MemLayout::MN_major, B_COL,
|
||||
@@ -719,14 +734,15 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
|
||||
// protect write to SMEM
|
||||
threadblock_barrier(warpgroup_id_in_cluster, warps_per_warpgroup_per_core);
|
||||
|
||||
if constexpr (DEBUG) {
|
||||
thread_block_copy_tile(smem_Q, gmem_tmp_d0, tid_in_warpgroup,
|
||||
threads_per_warpgroup, warpgroup_id_in_cluster);
|
||||
// if constexpr (DEBUG) {
|
||||
// thread_block_copy_tile<B_ROW, HEADDIM>(smem_Q0, gmem_tmp_d0, tid_in_warpgroup,
|
||||
// threads_per_warpgroup, warpgroup_id_in_cluster);
|
||||
// thread_block_copy_tile<HEADDIM, B_COL>(smem_K0, gmem_tmp_d1, tid_in_warpgroup,
|
||||
// threads_per_warpgroup, warpgroup_id_in_cluster);
|
||||
|
||||
// threadblock_barrier(warpgroup_id_in_cluster, warps_per_warpgroup_per_core);
|
||||
}
|
||||
// threadblock_barrier(warpgroup_id_in_cluster, warps_per_warpgroup_per_core);
|
||||
// }
|
||||
|
||||
#if 0
|
||||
asm volatile ("tile_loop_start_%=:" :: );
|
||||
|
||||
// "inner loop" along the columns of K^T
|
||||
@@ -751,25 +767,34 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
|
||||
initialize_accum_regs<0>();
|
||||
initialize_accum_regs<1>();
|
||||
|
||||
thread_block_gemm_single_tile<
|
||||
float, MemLayout::MN_major, MemLayout::MN_major, B_ROW, B_COL,
|
||||
HEADDIM, /*leading_dim_a=*/0, /*leading_dim_b=*/0,
|
||||
/*load_accum=*/false,
|
||||
/*write_to_smem=*/true>(
|
||||
smem_Q, smem_K, nullptr /*ignore accum*/, smem_S, tid_in_warpgroup,
|
||||
threads_per_warpgroup, warpgroups_per_cluster,
|
||||
warpgroup_id_in_cluster);
|
||||
if constexpr (Q_IS_K_MAJOR) {
|
||||
thread_block_gemm_single_tile<
|
||||
float, MemLayout::K_major, MemLayout::MN_major, B_ROW, B_COL,
|
||||
HEADDIM, /*leading_dim_a=*/0, /*leading_dim_b=*/0,
|
||||
/*load_accum=*/false,
|
||||
/*write_to_smem=*/true>(
|
||||
smem_Q, smem_K, nullptr /*ignore accum*/, smem_S,
|
||||
tid_in_warpgroup, threads_per_warpgroup, warpgroups_per_cluster,
|
||||
warpgroup_id_in_cluster);
|
||||
} else {
|
||||
thread_block_gemm_single_tile<
|
||||
float, MemLayout::MN_major, MemLayout::MN_major, B_ROW, B_COL,
|
||||
HEADDIM, /*leading_dim_a=*/0, /*leading_dim_b=*/0,
|
||||
/*load_accum=*/false,
|
||||
/*write_to_smem=*/true>(
|
||||
smem_Q, smem_K, nullptr /*ignore accum*/, smem_S,
|
||||
tid_in_warpgroup, threads_per_warpgroup, warpgroups_per_cluster,
|
||||
warpgroup_id_in_cluster);
|
||||
}
|
||||
} else {
|
||||
// when warp-specialized, there are only enough warps to do a 64x32 tile
|
||||
// size so we need to do 2 GEMM calls
|
||||
static_assert(B_ROW / 2 == 32,
|
||||
"tile size assumption for warp-specialization not met");
|
||||
|
||||
// assumes smem_Q is K-major
|
||||
// FIXME: fix this to MN-major
|
||||
float *smem_Q_half0 = smem_Q;
|
||||
float *smem_Q_half1 = smem_Q + (B_ROW / 2); // MN-major
|
||||
// float *smem_Q_half1 = smem_Q + (B_ROW / 2) * HEADDIM; // K-major
|
||||
float *smem_Q_half1 = Q_IS_K_MAJOR ? smem_Q + (B_ROW / 2) * HEADDIM
|
||||
: smem_Q + (B_ROW / 2);
|
||||
float *smem_S_half0 = smem_S;
|
||||
float *smem_S_half1 = smem_S + (B_ROW / 2) * B_COL;
|
||||
|
||||
@@ -778,26 +803,48 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
|
||||
initialize_accum_regs<1>();
|
||||
|
||||
// split by rows into 2 chunks
|
||||
thread_block_gemm_single_tile<
|
||||
float, MemLayout::MN_major, MemLayout::MN_major, B_ROW / 2,
|
||||
B_COL, HEADDIM, /*leading_dim_a=*/B_ROW, /*leading_dim_b=*/0,
|
||||
/*load_accum=*/false,
|
||||
/*write_to_smem=*/true>(
|
||||
smem_Q_half0, smem_K, nullptr /*ignore accum*/, smem_S_half0,
|
||||
tid_in_warpgroup, threads_per_warpgroup, warpgroups_per_cluster,
|
||||
warpgroup_id_in_cluster);
|
||||
if constexpr (Q_IS_K_MAJOR) {
|
||||
thread_block_gemm_single_tile<
|
||||
float, MemLayout::K_major, MemLayout::MN_major, B_ROW / 2, B_COL,
|
||||
HEADDIM, /*leading_dim_a=*/0, /*leading_dim_b=*/0,
|
||||
/*load_accum=*/false,
|
||||
/*write_to_smem=*/true>(
|
||||
smem_Q_half0, smem_K, nullptr /*ignore accum*/, smem_S_half0,
|
||||
tid_in_warpgroup, threads_per_warpgroup, warpgroups_per_cluster,
|
||||
warpgroup_id_in_cluster);
|
||||
} else {
|
||||
thread_block_gemm_single_tile<
|
||||
float, MemLayout::MN_major, MemLayout::MN_major, B_ROW / 2, B_COL,
|
||||
HEADDIM, /*leading_dim_a=*/B_ROW, /*leading_dim_b=*/0,
|
||||
/*load_accum=*/false,
|
||||
/*write_to_smem=*/true>(
|
||||
smem_Q_half0, smem_K, nullptr /*ignore accum*/, smem_S_half0,
|
||||
tid_in_warpgroup, threads_per_warpgroup, warpgroups_per_cluster,
|
||||
warpgroup_id_in_cluster);
|
||||
}
|
||||
|
||||
initialize_accum_regs<0>();
|
||||
initialize_accum_regs<1>();
|
||||
|
||||
thread_block_gemm_single_tile<
|
||||
float, MemLayout::MN_major, MemLayout::MN_major, B_ROW / 2,
|
||||
B_COL, HEADDIM, /*leading_dim_a=*/B_ROW, /*leading_dim_b=*/0,
|
||||
/*load_accum=*/false,
|
||||
/*write_to_smem=*/true>(
|
||||
smem_Q_half1, smem_K, nullptr /*ignore accum*/, smem_S_half1,
|
||||
tid_in_warpgroup, threads_per_warpgroup, warpgroups_per_cluster,
|
||||
warpgroup_id_in_cluster);
|
||||
if constexpr (Q_IS_K_MAJOR) {
|
||||
thread_block_gemm_single_tile<
|
||||
float, MemLayout::K_major, MemLayout::MN_major, B_ROW / 2, B_COL,
|
||||
HEADDIM, /*leading_dim_a=*/0, /*leading_dim_b=*/0,
|
||||
/*load_accum=*/false,
|
||||
/*write_to_smem=*/true>(
|
||||
smem_Q_half1, smem_K, nullptr /*ignore accum*/, smem_S_half1,
|
||||
tid_in_warpgroup, threads_per_warpgroup, warpgroups_per_cluster,
|
||||
warpgroup_id_in_cluster);
|
||||
} else {
|
||||
thread_block_gemm_single_tile<
|
||||
float, MemLayout::MN_major, MemLayout::MN_major, B_ROW / 2, B_COL,
|
||||
HEADDIM, /*leading_dim_a=*/B_ROW, /*leading_dim_b=*/0,
|
||||
/*load_accum=*/false,
|
||||
/*write_to_smem=*/true>(
|
||||
smem_Q_half1, smem_K, nullptr /*ignore accum*/, smem_S_half1,
|
||||
tid_in_warpgroup, threads_per_warpgroup, warpgroups_per_cluster,
|
||||
warpgroup_id_in_cluster);
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// load Q*K
|
||||
@@ -813,11 +860,11 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
|
||||
if constexpr (DEBUG) {
|
||||
if (warpgroup_id == 0) {
|
||||
if (tile_k == 0) {
|
||||
thread_block_copy_tile(smem_S, gmem_tmp_d0,
|
||||
thread_block_copy_tile<B_ROW, B_COL>(smem_S, gmem_tmp_d0,
|
||||
tid_in_warpgroup, threads_per_warpgroup,
|
||||
warpgroup_id_in_cluster);
|
||||
} else if (tile_k == 1) {
|
||||
thread_block_copy_tile(smem_S, gmem_tmp_d1,
|
||||
thread_block_copy_tile<B_ROW, B_COL>(smem_S, gmem_tmp_d1,
|
||||
tid_in_warpgroup, threads_per_warpgroup,
|
||||
warpgroup_id_in_cluster);
|
||||
}
|
||||
@@ -830,6 +877,7 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
|
||||
// inter-warpgroup barrier before online softmax
|
||||
threadblock_barrier(global_barrier_id, warps_per_threadblock_per_core);
|
||||
|
||||
#if 0
|
||||
// Online softmax
|
||||
//
|
||||
thread_block_online_softmax(smem_S, smem_P, tid_in_warpgroup,
|
||||
@@ -897,17 +945,17 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
|
||||
if (warpgroup_id == 0) {
|
||||
// O before PV
|
||||
if (tile_k == 0) {
|
||||
thread_block_copy_tile(smem_P, gmem_tmp_d2, tid_in_warpgroup,
|
||||
thread_block_copy_tile<B_ROW, B_COL>(smem_P, gmem_tmp_d2, tid_in_warpgroup,
|
||||
threads_per_warpgroup,
|
||||
warpgroup_id_in_cluster);
|
||||
thread_block_copy_tile(smem_O, gmem_tmp_d4, tid_in_warpgroup,
|
||||
thread_block_copy_tile<B_ROW, HEADDIM>(smem_O, gmem_tmp_d4, tid_in_warpgroup,
|
||||
threads_per_warpgroup,
|
||||
warpgroup_id_in_cluster);
|
||||
} else if (tile_k == 1) {
|
||||
thread_block_copy_tile(smem_P, gmem_tmp_d3, tid_in_warpgroup,
|
||||
thread_block_copy_tile<B_ROW, B_COL>(smem_P, gmem_tmp_d3, tid_in_warpgroup,
|
||||
threads_per_warpgroup,
|
||||
warpgroup_id_in_cluster);
|
||||
thread_block_copy_tile(smem_O, gmem_tmp_d5, tid_in_warpgroup,
|
||||
thread_block_copy_tile<B_ROW, HEADDIM>(smem_O, gmem_tmp_d5, tid_in_warpgroup,
|
||||
threads_per_warpgroup,
|
||||
warpgroup_id_in_cluster);
|
||||
}
|
||||
@@ -986,11 +1034,11 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
|
||||
if (warpgroup_id == 0) {
|
||||
// O after PV
|
||||
if (tile_k == 0) {
|
||||
thread_block_copy_tile(smem_O, gmem_tmp_d6, tid_in_warpgroup,
|
||||
thread_block_copy_tile<B_ROW, HEADDIM>(smem_O, gmem_tmp_d6, tid_in_warpgroup,
|
||||
threads_per_warpgroup,
|
||||
warpgroup_id_in_cluster);
|
||||
} else if (tile_k == 1) {
|
||||
thread_block_copy_tile(smem_O, gmem_tmp_d7, tid_in_warpgroup,
|
||||
thread_block_copy_tile<B_ROW, HEADDIM>(smem_O, gmem_tmp_d7, tid_in_warpgroup,
|
||||
threads_per_warpgroup,
|
||||
warpgroup_id_in_cluster);
|
||||
}
|
||||
@@ -1006,6 +1054,7 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
|
||||
// warps_per_threadblock_per_core);
|
||||
// threadblock_barrier(3, // FIXME
|
||||
// NUM_WARPS);
|
||||
#endif
|
||||
}
|
||||
|
||||
asm volatile ("tile_loop_finish_%=:" :: );
|
||||
@@ -1015,7 +1064,6 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
|
||||
if (warpgroup_id == 0) {
|
||||
threadblock_barrier(global_barrier_id, warps_per_threadblock_per_core);
|
||||
}
|
||||
#endif
|
||||
}
|
||||
|
||||
int main() {
|
||||
|
||||
@@ -6,7 +6,7 @@
|
||||
#include "include/gemmini.h"
|
||||
#include "gemmini_mmio.h"
|
||||
|
||||
#define FP_SIZE 16
|
||||
#define FP_SIZE 32
|
||||
|
||||
// "fake" fp16 type that only has the correct data width.
|
||||
using float16_t = uint16_t;
|
||||
@@ -29,7 +29,7 @@ using float_type = float16_t;
|
||||
// (BM*BN) / (TM*TN) == threadblock size >= NT * CORES_PER_CLUSTER
|
||||
// * Combining BM * BK >= (BM*BN) / (TM*TN) == threadblock yields
|
||||
// BM <= BK*TM*TN
|
||||
#define BM 128
|
||||
#define BM 64
|
||||
#define BN 64
|
||||
#if (FP_SIZE == 32)
|
||||
#define BK 64
|
||||
@@ -62,18 +62,18 @@ static_assert(WMITER * WNITER * TCM * TCN * NUM_WARPS * CORES_PER_CLUSTER ==
|
||||
#define BK_LOOP 1
|
||||
// Whether to transpose smem A tile at GMEM->SMEM (produce), or SMEM->RF
|
||||
// (consume). This is because the tensor core expects the A tile to be stored
|
||||
// in column-major order in SMEM, whereas it will be ultimately stored in
|
||||
// row-major in the RF.
|
||||
// in column-major order in SMEM, so a transpose is necessary if A was stored
|
||||
// row-major in GMEM.
|
||||
//
|
||||
// For correctness, at most one of the two should be 1. E.g., PRODUCE 1 CONSUME 0
|
||||
// generates the NN kernel where both A and B are stored row-major in GMEM.
|
||||
// To model the case where the A matrix is already stored column-major in GMEM,
|
||||
// set both to 0.
|
||||
#define TRANSPOSE_AT_PRODUCE 0
|
||||
#define TRANSPOSE_AT_CONSUME 0
|
||||
#define TRANSPOSE_AT_CONSUME 1
|
||||
|
||||
#define GEMMINI_DMA 0
|
||||
#define GEMMINI_DMA_MN_MAJOR 1
|
||||
#define GEMMINI_DMA 1
|
||||
#define GEMMINI_DMA_MN_MAJOR 0
|
||||
#if SMEM_SIZE == 0x4000
|
||||
#define SMEM_ADDR_Q0 ((float * const) 0xff000000)
|
||||
#define SMEM_ADDR_Q1 ((float * const) 0xff001000)
|
||||
|
||||
Reference in New Issue
Block a user