flash: Update tcore kernel to use new CISC
This commit is contained in:
@@ -8,6 +8,8 @@
|
||||
#include "gemmini_mmio.h"
|
||||
#include "flash_impl.hpp"
|
||||
|
||||
// Build-time switch for the new Gemmini CISC interface. Code in kernel_body
// guarded by `#ifdef GEMMINI_NEW_CISC` issues gemmini_tile_load_ab() with
// shared-memory hexadecile indices; the `#else` branches keep the older
// sp_tiled_matmul_full_spad_ws() path that uses spad_addr_* values.
// NOTE: the guards only test whether the macro is defined, so its value (1)
// is never inspected -- defining it as 0 would NOT select the old path;
// remove or comment out this line instead.
#define GEMMINI_NEW_CISC 1
|
||||
|
||||
constexpr bool DEBUG = false;
|
||||
constexpr bool Q_IS_K_MAJOR = true;
|
||||
|
||||
@@ -94,6 +96,14 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
|
||||
uint8_t *smem_per_threadblock = reinterpret_cast<uint8_t *>(
|
||||
DEV_SMEM_START_ADDR);
|
||||
constexpr uint32_t smem_start = DEV_SMEM_START_ADDR;
|
||||
constexpr uint32_t smem_hexadecile_size = (SMEM_SIZE / 16);
|
||||
// currently assumes the Q/K/QK/V/O tile sizes (in bytes) each exactly match the hexadecile size
|
||||
static_assert(smem_hexadecile_size == smem_Q_size * sizeof(float));
|
||||
static_assert(smem_hexadecile_size == smem_K_size * sizeof(float));
|
||||
static_assert(smem_hexadecile_size == smem_QK_size * sizeof(float));
|
||||
static_assert(smem_hexadecile_size == smem_V_size * sizeof(float));
|
||||
static_assert(smem_hexadecile_size == smem_O_size * sizeof(float));
|
||||
|
||||
constexpr uint32_t smem_octet0 = 0 * (SMEM_SIZE / 8);
|
||||
constexpr uint32_t smem_octet1 = 1 * (SMEM_SIZE / 8);
|
||||
constexpr uint32_t smem_octet2 = 2 * (SMEM_SIZE / 8);
|
||||
@@ -108,31 +118,31 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
|
||||
// half
|
||||
// at the same time, make sure Q and K are in different banks so that they
|
||||
// can be accessed in parallel for GEMM; same for P and V
|
||||
constexpr uint32_t smem_Q0_offset = smem_octet0;
|
||||
constexpr uint32_t smem_Q1_offset = smem_octet4;
|
||||
constexpr uint32_t smem_K0_offset = smem_octet1;
|
||||
constexpr uint32_t smem_K1_offset = smem_octet5;
|
||||
constexpr uint32_t smem_V0_offset = smem_K0_offset + smem_K_size * sizeof(float);
|
||||
constexpr uint32_t smem_V1_offset = smem_K1_offset + smem_K_size * sizeof(float);
|
||||
constexpr uint32_t smem_S0_offset = smem_octet2;
|
||||
constexpr uint32_t smem_S1_offset = smem_octet6;
|
||||
constexpr uint32_t smem_P0_offset = smem_Q0_offset + smem_Q_size * sizeof(float);
|
||||
constexpr uint32_t smem_P1_offset = smem_Q1_offset + smem_Q_size * sizeof(float);
|
||||
constexpr uint32_t smem_O0_offset = smem_octet3;
|
||||
constexpr uint32_t smem_O1_offset = smem_octet7;
|
||||
constexpr uint32_t smem_Q0_hexadecile = 2 * 0; // octet0
|
||||
constexpr uint32_t smem_Q1_hexadecile = 2 * 4; // octet4
|
||||
constexpr uint32_t smem_K0_hexadecile = 2 * 1; // octet1
|
||||
constexpr uint32_t smem_K1_hexadecile = 2 * 5; // octet5
|
||||
constexpr uint32_t smem_V0_hexadecile = smem_K0_hexadecile + 1;
|
||||
constexpr uint32_t smem_V1_hexadecile = smem_K1_hexadecile + 1;
|
||||
constexpr uint32_t smem_S0_hexadecile = 2 * 2; // octet2
|
||||
constexpr uint32_t smem_S1_hexadecile = 2 * 6; // octet6
|
||||
constexpr uint32_t smem_P0_hexadecile = smem_Q0_hexadecile + 1;
|
||||
constexpr uint32_t smem_P1_hexadecile = smem_Q1_hexadecile + 1;
|
||||
constexpr uint32_t smem_O0_hexadecile = 2 * 3; // octet3
|
||||
constexpr uint32_t smem_O1_hexadecile = 2 * 7; // octet7
|
||||
|
||||
float *smem_Q0 = reinterpret_cast<float *>(smem_start + smem_Q0_offset);
|
||||
float *smem_Q1 = reinterpret_cast<float *>(smem_start + smem_Q1_offset);
|
||||
float *smem_K0 = reinterpret_cast<float *>(smem_start + smem_K0_offset);
|
||||
float *smem_K1 = reinterpret_cast<float *>(smem_start + smem_K1_offset);
|
||||
float *smem_V0 = reinterpret_cast<float *>(smem_start + smem_V0_offset);
|
||||
float *smem_V1 = reinterpret_cast<float *>(smem_start + smem_V1_offset);
|
||||
float *smem_S0 = reinterpret_cast<float *>(smem_start + smem_S0_offset);
|
||||
float *smem_S1 = reinterpret_cast<float *>(smem_start + smem_S1_offset);
|
||||
float *smem_P0 = reinterpret_cast<float *>(smem_start + smem_P0_offset);
|
||||
float *smem_P1 = reinterpret_cast<float *>(smem_start + smem_P1_offset);
|
||||
float *smem_O0 = reinterpret_cast<float *>(smem_start + smem_O0_offset);
|
||||
float *smem_O1 = reinterpret_cast<float *>(smem_start + smem_O1_offset);
|
||||
float *smem_Q0 = reinterpret_cast<float *>(smem_start + smem_Q0_hexadecile * smem_hexadecile_size);
|
||||
float *smem_Q1 = reinterpret_cast<float *>(smem_start + smem_Q1_hexadecile * smem_hexadecile_size);
|
||||
float *smem_K0 = reinterpret_cast<float *>(smem_start + smem_K0_hexadecile * smem_hexadecile_size);
|
||||
float *smem_K1 = reinterpret_cast<float *>(smem_start + smem_K1_hexadecile * smem_hexadecile_size);
|
||||
float *smem_V0 = reinterpret_cast<float *>(smem_start + smem_V0_hexadecile * smem_hexadecile_size);
|
||||
float *smem_V1 = reinterpret_cast<float *>(smem_start + smem_V1_hexadecile * smem_hexadecile_size);
|
||||
float *smem_S0 = reinterpret_cast<float *>(smem_start + smem_S0_hexadecile * smem_hexadecile_size);
|
||||
float *smem_S1 = reinterpret_cast<float *>(smem_start + smem_S1_hexadecile * smem_hexadecile_size);
|
||||
float *smem_P0 = reinterpret_cast<float *>(smem_start + smem_P0_hexadecile * smem_hexadecile_size);
|
||||
float *smem_P1 = reinterpret_cast<float *>(smem_start + smem_P1_hexadecile * smem_hexadecile_size);
|
||||
float *smem_O0 = reinterpret_cast<float *>(smem_start + smem_O0_hexadecile * smem_hexadecile_size);
|
||||
float *smem_O1 = reinterpret_cast<float *>(smem_start + smem_O1_hexadecile * smem_hexadecile_size);
|
||||
|
||||
// allocate rowmax/rowsum storage at the end of the sharedmem address space
|
||||
constexpr uint32_t smem_rowmax_size = B_ROW * ROWMAX_SETS;
|
||||
@@ -180,25 +190,32 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
|
||||
float *smem_scratchpad =
|
||||
(warpgroup_id % 2) ? smem_scratchpad_1 : smem_scratchpad_0;
|
||||
|
||||
#ifdef GEMMINI_NEW_CISC
|
||||
const auto spad_hex_Q = (warpgroup_id % 2) ? smem_Q1_hexadecile : smem_Q0_hexadecile;
|
||||
const auto spad_hex_K = (warpgroup_id % 2) ? smem_K1_hexadecile : smem_K0_hexadecile;
|
||||
const auto spad_hex_V = (warpgroup_id % 2) ? smem_V1_hexadecile : smem_V0_hexadecile;
|
||||
const auto spad_hex_S = (warpgroup_id % 2) ? smem_S1_hexadecile : smem_S0_hexadecile;
|
||||
#else
|
||||
static_assert(sizeof(elem_t) == sizeof(float));
|
||||
constexpr uint32_t spad_addr_factor = DIM * sizeof(elem_t);
|
||||
constexpr uint32_t spad_addr_Q0 = smem_Q0_offset / spad_addr_factor;
|
||||
constexpr uint32_t spad_addr_Q1 = smem_Q1_offset / spad_addr_factor;
|
||||
constexpr uint32_t spad_addr_K0 = smem_K0_offset / spad_addr_factor;
|
||||
constexpr uint32_t spad_addr_K1 = smem_K1_offset / spad_addr_factor;
|
||||
constexpr uint32_t spad_addr_V0 = smem_V0_offset / spad_addr_factor;
|
||||
constexpr uint32_t spad_addr_V1 = smem_V1_offset / spad_addr_factor;
|
||||
constexpr uint32_t spad_addr_S0 = smem_S0_offset / spad_addr_factor;
|
||||
constexpr uint32_t spad_addr_S1 = smem_S1_offset / spad_addr_factor;
|
||||
constexpr uint32_t spad_addr_P0 = smem_P0_offset / spad_addr_factor;
|
||||
constexpr uint32_t spad_addr_P1 = smem_P1_offset / spad_addr_factor;
|
||||
constexpr uint32_t spad_addr_O0 = smem_O0_offset / spad_addr_factor;
|
||||
constexpr uint32_t spad_addr_O1 = smem_O1_offset / spad_addr_factor;
|
||||
constexpr uint32_t spad_addr_Q0 = smem_Q0_hexadecile * smem_hexadecile_size / spad_addr_factor;
|
||||
constexpr uint32_t spad_addr_Q1 = smem_Q1_hexadecile * smem_hexadecile_size / spad_addr_factor;
|
||||
constexpr uint32_t spad_addr_K0 = smem_K0_hexadecile * smem_hexadecile_size / spad_addr_factor;
|
||||
constexpr uint32_t spad_addr_K1 = smem_K1_hexadecile * smem_hexadecile_size / spad_addr_factor;
|
||||
constexpr uint32_t spad_addr_V0 = smem_V0_hexadecile * smem_hexadecile_size / spad_addr_factor;
|
||||
constexpr uint32_t spad_addr_V1 = smem_V1_hexadecile * smem_hexadecile_size / spad_addr_factor;
|
||||
constexpr uint32_t spad_addr_S0 = smem_S0_hexadecile * smem_hexadecile_size / spad_addr_factor;
|
||||
constexpr uint32_t spad_addr_S1 = smem_S1_hexadecile * smem_hexadecile_size / spad_addr_factor;
|
||||
constexpr uint32_t spad_addr_P0 = smem_P0_hexadecile * smem_hexadecile_size / spad_addr_factor;
|
||||
constexpr uint32_t spad_addr_P1 = smem_P1_hexadecile * smem_hexadecile_size / spad_addr_factor;
|
||||
constexpr uint32_t spad_addr_O0 = smem_O0_hexadecile * smem_hexadecile_size / spad_addr_factor;
|
||||
constexpr uint32_t spad_addr_O1 = smem_O1_hexadecile * smem_hexadecile_size / spad_addr_factor;
|
||||
|
||||
const auto spad_addr_Q = (warpgroup_id % 2) ? spad_addr_Q1 : spad_addr_Q0;
|
||||
const auto spad_addr_K = (warpgroup_id % 2) ? spad_addr_K1 : spad_addr_K0;
|
||||
const auto spad_addr_V = (warpgroup_id % 2) ? spad_addr_V1 : spad_addr_V0;
|
||||
const auto spad_addr_S = (warpgroup_id % 2) ? spad_addr_S1 : spad_addr_S0;
|
||||
#endif
|
||||
|
||||
// initialize rowmax/rowsum values in sharedmem
|
||||
thread_block_init_sharedmem(tid_in_warpgroup, threads_per_warpgroup, smem_O,
|
||||
@@ -246,8 +263,6 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
|
||||
// other warps behave differently on the branch condition.
|
||||
// threadblock_barrier(warpgroup_id_in_cluster, warps_per_warpgroup_per_core);
|
||||
|
||||
// move Q and K into SMEM before the loop starts
|
||||
//
|
||||
static_assert(B_ROW == B_COL, "currently only supports square tiles");
|
||||
|
||||
if constexpr (GEMMINI_DMA) {
|
||||
@@ -256,6 +271,19 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
|
||||
if (tid_in_warpgroup == 0) {
|
||||
const float *gmem_Q_tile = gmem_Q + HEADDIM * B_ROW * warpgroup_id;
|
||||
const float *gmem_K_tile = gmem_K;
|
||||
// do DMA
|
||||
//
|
||||
// move Q and K into SMEM before the loop starts. Note this will be done
|
||||
// separately for the two warpgroups
|
||||
//
|
||||
#ifdef GEMMINI_NEW_CISC
|
||||
// the target hexadeciles passed here (spad_hex_Q and spad_hex_K) should
|
||||
// match the smem_Q*/smem_K* hexadecile placement set in this kernel
|
||||
gemmini_tile_load_ab(gmem_Q_tile, gmem_K_tile, spad_hex_Q,
|
||||
spad_hex_K, 0 /*tile_idx_i*/,
|
||||
0 /*tile_idx_j*/, 0 /*tile_idx_k*/, dim_seqlen,
|
||||
dim_seqlen, HEADDIM, B_ROW, B_COL, HEADDIM);
|
||||
#else
|
||||
// configure the GMEM addresses for the DMA to read from
|
||||
ROCC_INSTRUCTION_RS1_RS2(XCUSTOM_ACC, (uint64_t)(gmem_Q_tile),
|
||||
(uint64_t)(gmem_K_tile),
|
||||
@@ -265,13 +293,6 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
|
||||
8 /*k_LOOP_WS_CONFIG_STRIDES_AB*/);
|
||||
gemmini_fence();
|
||||
|
||||
// #define GEMMINI_DMA_CISC
|
||||
#ifdef GEMMINI_DMA_CISC
|
||||
GEMMINI_CISC_CMD_I(9);
|
||||
gemmini_fence();
|
||||
#else
|
||||
// do DMA
|
||||
//
|
||||
// among other things, this also configures CONFIG_BOUNDS so that the
|
||||
// DMA knows the full matrix dimensions
|
||||
sp_tiled_matmul_full_spad_ws(
|
||||
@@ -281,9 +302,11 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
|
||||
/*pad_I=*/0, /*pad_J=*/0, /*pad_K=*/0,
|
||||
/*a_transpose=*/0, /*b_transpose=*/0, /*full_C=*/0, /*low_D=*/0,
|
||||
/*acc=*/0, /*act=*/NO_ACTIVATION, /*skips=*/skips);
|
||||
gemmini_fence();
|
||||
#endif
|
||||
|
||||
// block until DMA complete
|
||||
gemmini_fence();
|
||||
|
||||
// re-configure DMA for K and V load that will later happen in the loop
|
||||
// GMEM addr stride for K
|
||||
gemmini_extended3_config_ld(dim_seqlen * sizeof(elem_t), MVIN_SCALE_IDENTITY,
|
||||
@@ -549,13 +572,20 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
|
||||
ROCC_INSTRUCTION_RS1_RS2(XCUSTOM_ACC, (uint64_t)(gmem_K_tile),
|
||||
(uint64_t)(gmem_V_tile),
|
||||
k_LOOP_WS_CONFIG_ADDRS_AB)
|
||||
// do DMA
|
||||
#ifdef GEMMINI_NEW_CISC
|
||||
gemmini_tile_load_ab(gmem_K_tile, gmem_V_tile, spad_hex_K, spad_hex_V,
|
||||
0 /*tile_idx_i*/, 0 /*tile_idx_j*/,
|
||||
0 /*tile_idx_k*/, HEADDIM /*dim_m of KT*/,
|
||||
HEADDIM /*dim_n of V*/, dim_seqlen /*dim_k of KT*/,
|
||||
B_ROW, HEADDIM, B_COL);
|
||||
#else
|
||||
// configure address strides for the DMA
|
||||
// FIXME: unnecessary?
|
||||
GEMMINI_CISC_CMD_R((HEADDIM /*V*/ << 20) | (dim_seqlen /*KT*/ << 8) |
|
||||
8 /*k_LOOP_WS_CONFIG_STRIDES_AB*/);
|
||||
gemmini_fence();
|
||||
|
||||
// do DMA
|
||||
sp_tiled_matmul_full_spad_ws(
|
||||
spad_addr_K, spad_addr_V,
|
||||
/*spad_D=*/0, /*spad_C=*/spad_addr_S,
|
||||
@@ -563,6 +593,8 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
|
||||
/*pad_I=*/0, /*pad_J=*/0, /*pad_K=*/0,
|
||||
/*a_transpose=*/0, /*b_transpose=*/0, /*full_C=*/0, /*low_D=*/0,
|
||||
/*acc=*/0, /*act=*/NO_ACTIVATION, /*skips=*/skips);
|
||||
#endif
|
||||
// FIXME: necessary?
|
||||
gemmini_fence();
|
||||
}
|
||||
} else {
|
||||
|
||||
Reference in New Issue
Block a user