flash: Update tcore kernel to use new CISC

This commit is contained in:
Hansung Kim
2024-11-09 19:49:20 -08:00
parent 76a6aaf085
commit 8fe6d918f2

View File

@@ -8,6 +8,8 @@
#include "gemmini_mmio.h"
#include "flash_impl.hpp"
#define GEMMINI_NEW_CISC 1
constexpr bool DEBUG = false;
constexpr bool Q_IS_K_MAJOR = true;
@@ -94,6 +96,14 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
uint8_t *smem_per_threadblock = reinterpret_cast<uint8_t *>(
DEV_SMEM_START_ADDR);
constexpr uint32_t smem_start = DEV_SMEM_START_ADDR;
constexpr uint32_t smem_hexadecile_size = (SMEM_SIZE / 16);
// currently assumes the Q/K/V tile sizes exactly match the hexadecile size
static_assert(smem_hexadecile_size == smem_Q_size * sizeof(float));
static_assert(smem_hexadecile_size == smem_K_size * sizeof(float));
static_assert(smem_hexadecile_size == smem_QK_size * sizeof(float));
static_assert(smem_hexadecile_size == smem_V_size * sizeof(float));
static_assert(smem_hexadecile_size == smem_O_size * sizeof(float));
constexpr uint32_t smem_octet0 = 0 * (SMEM_SIZE / 8);
constexpr uint32_t smem_octet1 = 1 * (SMEM_SIZE / 8);
constexpr uint32_t smem_octet2 = 2 * (SMEM_SIZE / 8);
@@ -108,31 +118,31 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
// half
// at the same time, make sure Q and K are in different banks so that they
// can be accessed in parallel for GEMM; same for P and V
constexpr uint32_t smem_Q0_offset = smem_octet0;
constexpr uint32_t smem_Q1_offset = smem_octet4;
constexpr uint32_t smem_K0_offset = smem_octet1;
constexpr uint32_t smem_K1_offset = smem_octet5;
constexpr uint32_t smem_V0_offset = smem_K0_offset + smem_K_size * sizeof(float);
constexpr uint32_t smem_V1_offset = smem_K1_offset + smem_K_size * sizeof(float);
constexpr uint32_t smem_S0_offset = smem_octet2;
constexpr uint32_t smem_S1_offset = smem_octet6;
constexpr uint32_t smem_P0_offset = smem_Q0_offset + smem_Q_size * sizeof(float);
constexpr uint32_t smem_P1_offset = smem_Q1_offset + smem_Q_size * sizeof(float);
constexpr uint32_t smem_O0_offset = smem_octet3;
constexpr uint32_t smem_O1_offset = smem_octet7;
constexpr uint32_t smem_Q0_hexadecile = 2 * 0; // octet0
constexpr uint32_t smem_Q1_hexadecile = 2 * 4; // octet4
constexpr uint32_t smem_K0_hexadecile = 2 * 1; // octet1
constexpr uint32_t smem_K1_hexadecile = 2 * 5; // octet5
constexpr uint32_t smem_V0_hexadecile = smem_K0_hexadecile + 1;
constexpr uint32_t smem_V1_hexadecile = smem_K1_hexadecile + 1;
constexpr uint32_t smem_S0_hexadecile = 2 * 2; // octet2
constexpr uint32_t smem_S1_hexadecile = 2 * 6; // octet6
constexpr uint32_t smem_P0_hexadecile = smem_Q0_hexadecile + 1;
constexpr uint32_t smem_P1_hexadecile = smem_Q1_hexadecile + 1;
constexpr uint32_t smem_O0_hexadecile = 2 * 3; // octet3
constexpr uint32_t smem_O1_hexadecile = 2 * 7; // octet7
float *smem_Q0 = reinterpret_cast<float *>(smem_start + smem_Q0_offset);
float *smem_Q1 = reinterpret_cast<float *>(smem_start + smem_Q1_offset);
float *smem_K0 = reinterpret_cast<float *>(smem_start + smem_K0_offset);
float *smem_K1 = reinterpret_cast<float *>(smem_start + smem_K1_offset);
float *smem_V0 = reinterpret_cast<float *>(smem_start + smem_V0_offset);
float *smem_V1 = reinterpret_cast<float *>(smem_start + smem_V1_offset);
float *smem_S0 = reinterpret_cast<float *>(smem_start + smem_S0_offset);
float *smem_S1 = reinterpret_cast<float *>(smem_start + smem_S1_offset);
float *smem_P0 = reinterpret_cast<float *>(smem_start + smem_P0_offset);
float *smem_P1 = reinterpret_cast<float *>(smem_start + smem_P1_offset);
float *smem_O0 = reinterpret_cast<float *>(smem_start + smem_O0_offset);
float *smem_O1 = reinterpret_cast<float *>(smem_start + smem_O1_offset);
float *smem_Q0 = reinterpret_cast<float *>(smem_start + smem_Q0_hexadecile * smem_hexadecile_size);
float *smem_Q1 = reinterpret_cast<float *>(smem_start + smem_Q1_hexadecile * smem_hexadecile_size);
float *smem_K0 = reinterpret_cast<float *>(smem_start + smem_K0_hexadecile * smem_hexadecile_size);
float *smem_K1 = reinterpret_cast<float *>(smem_start + smem_K1_hexadecile * smem_hexadecile_size);
float *smem_V0 = reinterpret_cast<float *>(smem_start + smem_V0_hexadecile * smem_hexadecile_size);
float *smem_V1 = reinterpret_cast<float *>(smem_start + smem_V1_hexadecile * smem_hexadecile_size);
float *smem_S0 = reinterpret_cast<float *>(smem_start + smem_S0_hexadecile * smem_hexadecile_size);
float *smem_S1 = reinterpret_cast<float *>(smem_start + smem_S1_hexadecile * smem_hexadecile_size);
float *smem_P0 = reinterpret_cast<float *>(smem_start + smem_P0_hexadecile * smem_hexadecile_size);
float *smem_P1 = reinterpret_cast<float *>(smem_start + smem_P1_hexadecile * smem_hexadecile_size);
float *smem_O0 = reinterpret_cast<float *>(smem_start + smem_O0_hexadecile * smem_hexadecile_size);
float *smem_O1 = reinterpret_cast<float *>(smem_start + smem_O1_hexadecile * smem_hexadecile_size);
// allocate rowmax/rowsum storage at the end of the sharedmem address space
constexpr uint32_t smem_rowmax_size = B_ROW * ROWMAX_SETS;
@@ -180,25 +190,32 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
float *smem_scratchpad =
(warpgroup_id % 2) ? smem_scratchpad_1 : smem_scratchpad_0;
#ifdef GEMMINI_NEW_CISC
const auto spad_hex_Q = (warpgroup_id % 2) ? smem_Q1_hexadecile : smem_Q0_hexadecile;
const auto spad_hex_K = (warpgroup_id % 2) ? smem_K1_hexadecile : smem_K0_hexadecile;
const auto spad_hex_V = (warpgroup_id % 2) ? smem_V1_hexadecile : smem_V0_hexadecile;
const auto spad_hex_S = (warpgroup_id % 2) ? smem_S1_hexadecile : smem_S0_hexadecile;
#else
static_assert(sizeof(elem_t) == sizeof(float));
constexpr uint32_t spad_addr_factor = DIM * sizeof(elem_t);
constexpr uint32_t spad_addr_Q0 = smem_Q0_offset / spad_addr_factor;
constexpr uint32_t spad_addr_Q1 = smem_Q1_offset / spad_addr_factor;
constexpr uint32_t spad_addr_K0 = smem_K0_offset / spad_addr_factor;
constexpr uint32_t spad_addr_K1 = smem_K1_offset / spad_addr_factor;
constexpr uint32_t spad_addr_V0 = smem_V0_offset / spad_addr_factor;
constexpr uint32_t spad_addr_V1 = smem_V1_offset / spad_addr_factor;
constexpr uint32_t spad_addr_S0 = smem_S0_offset / spad_addr_factor;
constexpr uint32_t spad_addr_S1 = smem_S1_offset / spad_addr_factor;
constexpr uint32_t spad_addr_P0 = smem_P0_offset / spad_addr_factor;
constexpr uint32_t spad_addr_P1 = smem_P1_offset / spad_addr_factor;
constexpr uint32_t spad_addr_O0 = smem_O0_offset / spad_addr_factor;
constexpr uint32_t spad_addr_O1 = smem_O1_offset / spad_addr_factor;
constexpr uint32_t spad_addr_Q0 = smem_Q0_hexadecile * smem_hexadecile_size / spad_addr_factor;
constexpr uint32_t spad_addr_Q1 = smem_Q1_hexadecile * smem_hexadecile_size / spad_addr_factor;
constexpr uint32_t spad_addr_K0 = smem_K0_hexadecile * smem_hexadecile_size / spad_addr_factor;
constexpr uint32_t spad_addr_K1 = smem_K1_hexadecile * smem_hexadecile_size / spad_addr_factor;
constexpr uint32_t spad_addr_V0 = smem_V0_hexadecile * smem_hexadecile_size / spad_addr_factor;
constexpr uint32_t spad_addr_V1 = smem_V1_hexadecile * smem_hexadecile_size / spad_addr_factor;
constexpr uint32_t spad_addr_S0 = smem_S0_hexadecile * smem_hexadecile_size / spad_addr_factor;
constexpr uint32_t spad_addr_S1 = smem_S1_hexadecile * smem_hexadecile_size / spad_addr_factor;
constexpr uint32_t spad_addr_P0 = smem_P0_hexadecile * smem_hexadecile_size / spad_addr_factor;
constexpr uint32_t spad_addr_P1 = smem_P1_hexadecile * smem_hexadecile_size / spad_addr_factor;
constexpr uint32_t spad_addr_O0 = smem_O0_hexadecile * smem_hexadecile_size / spad_addr_factor;
constexpr uint32_t spad_addr_O1 = smem_O1_hexadecile * smem_hexadecile_size / spad_addr_factor;
const auto spad_addr_Q = (warpgroup_id % 2) ? spad_addr_Q1 : spad_addr_Q0;
const auto spad_addr_K = (warpgroup_id % 2) ? spad_addr_K1 : spad_addr_K0;
const auto spad_addr_V = (warpgroup_id % 2) ? spad_addr_V1 : spad_addr_V0;
const auto spad_addr_S = (warpgroup_id % 2) ? spad_addr_S1 : spad_addr_S0;
#endif
// initialize rowmax/rowsum values in sharedmem
thread_block_init_sharedmem(tid_in_warpgroup, threads_per_warpgroup, smem_O,
@@ -246,8 +263,6 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
// other warps behave differently on the branch condition.
// threadblock_barrier(warpgroup_id_in_cluster, warps_per_warpgroup_per_core);
// move Q and K into SMEM before the loop starts
//
static_assert(B_ROW == B_COL, "currently only supports square tiles");
if constexpr (GEMMINI_DMA) {
@@ -256,6 +271,19 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
if (tid_in_warpgroup == 0) {
const float *gmem_Q_tile = gmem_Q + HEADDIM * B_ROW * warpgroup_id;
const float *gmem_K_tile = gmem_K;
// do DMA
//
// move Q and K into SMEM before the loop starts. Note this will be done
// separately for the two warpgroups
//
#ifdef GEMMINI_NEW_CISC
// the target addresses of this should match with spad_addr_Q0 and
// spad_addr_K0 set in this kernel
gemmini_tile_load_ab(gmem_Q_tile, gmem_K_tile, spad_hex_Q,
spad_hex_K, 0 /*tile_idx_i*/,
0 /*tile_idx_j*/, 0 /*tile_idx_k*/, dim_seqlen,
dim_seqlen, HEADDIM, B_ROW, B_COL, HEADDIM);
#else
// configure the GMEM addresses for the DMA to read from
ROCC_INSTRUCTION_RS1_RS2(XCUSTOM_ACC, (uint64_t)(gmem_Q_tile),
(uint64_t)(gmem_K_tile),
@@ -265,13 +293,6 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
8 /*k_LOOP_WS_CONFIG_STRIDES_AB*/);
gemmini_fence();
// #define GEMMINI_DMA_CISC
#ifdef GEMMINI_DMA_CISC
GEMMINI_CISC_CMD_I(9);
gemmini_fence();
#else
// do DMA
//
// among other things, this also configures CONFIG_BOUNDS so that the
// DMA knows the full matrix dimensions
sp_tiled_matmul_full_spad_ws(
@@ -281,9 +302,11 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
/*pad_I=*/0, /*pad_J=*/0, /*pad_K=*/0,
/*a_transpose=*/0, /*b_transpose=*/0, /*full_C=*/0, /*low_D=*/0,
/*acc=*/0, /*act=*/NO_ACTIVATION, /*skips=*/skips);
gemmini_fence();
#endif
// block until DMA complete
gemmini_fence();
// re-configure DMA for K and V load that will later happen in the loop
// GMEM addr stride for K
gemmini_extended3_config_ld(dim_seqlen * sizeof(elem_t), MVIN_SCALE_IDENTITY,
@@ -549,13 +572,20 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
ROCC_INSTRUCTION_RS1_RS2(XCUSTOM_ACC, (uint64_t)(gmem_K_tile),
(uint64_t)(gmem_V_tile),
k_LOOP_WS_CONFIG_ADDRS_AB)
// do DMA
#ifdef GEMMINI_NEW_CISC
gemmini_tile_load_ab(gmem_K_tile, gmem_V_tile, spad_hex_K, spad_hex_V,
0 /*tile_idx_i*/, 0 /*tile_idx_j*/,
0 /*tile_idx_k*/, HEADDIM /*dim_m of KT*/,
HEADDIM /*dim_n of V*/, dim_seqlen /*dim_k of KT*/,
B_ROW, HEADDIM, B_COL);
#else
// configure address strides for the DMA
// FIXME: unnecessary?
GEMMINI_CISC_CMD_R((HEADDIM /*V*/ << 20) | (dim_seqlen /*KT*/ << 8) |
8 /*k_LOOP_WS_CONFIG_STRIDES_AB*/);
gemmini_fence();
// do DMA
sp_tiled_matmul_full_spad_ws(
spad_addr_K, spad_addr_V,
/*spad_D=*/0, /*spad_C=*/spad_addr_S,
@@ -563,6 +593,8 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
/*pad_I=*/0, /*pad_J=*/0, /*pad_K=*/0,
/*a_transpose=*/0, /*b_transpose=*/0, /*full_C=*/0, /*low_D=*/0,
/*acc=*/0, /*act=*/NO_ACTIVATION, /*skips=*/skips);
#endif
// FIXME: necessary?
gemmini_fence();
}
} else {