flash: Update to use new CISC interface
This commit is contained in:
@@ -10,6 +10,8 @@
|
||||
|
||||
#define FENCE_GEMM_II
|
||||
|
||||
#define GEMMINI_NEW_CISC
|
||||
|
||||
constexpr bool DEBUG = false;
|
||||
|
||||
static_assert(GEMMINI_DMA && !WARP_SPECIALIZED,
|
||||
@@ -98,40 +100,43 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
|
||||
"flashattention kernel assumes 1 threadblock occupancy per cluster");
|
||||
uint8_t *smem_per_threadblock = reinterpret_cast<uint8_t *>(DEV_SMEM_START_ADDR);
|
||||
constexpr uint32_t smem_start = DEV_SMEM_START_ADDR;
|
||||
constexpr uint32_t smem_quart0 = 0 * (SMEM_SIZE / 4);
|
||||
constexpr uint32_t smem_quart1 = 1 * (SMEM_SIZE / 4);
|
||||
constexpr uint32_t smem_quart2 = 2 * (SMEM_SIZE / 4);
|
||||
constexpr uint32_t smem_quart3 = 3 * (SMEM_SIZE / 4);
|
||||
constexpr uint32_t smem_hexadecile_size = (SMEM_SIZE / 16);
|
||||
// currently assumes the Q/K/V tile sizes exactly match the hexadecile size
|
||||
static_assert(smem_hexadecile_size == smem_Q_size * sizeof(float));
|
||||
static_assert(smem_hexadecile_size == smem_K_size * sizeof(float));
|
||||
static_assert(smem_hexadecile_size == smem_QK_size * sizeof(float));
|
||||
static_assert(smem_hexadecile_size == smem_V_size * sizeof(float));
|
||||
static_assert(smem_hexadecile_size == smem_O_size * sizeof(float));
|
||||
|
||||
// Q/V/S in quart0/1, K/P/O in quart2/3
|
||||
constexpr uint32_t smem_Q0_offset = smem_quart0;
|
||||
constexpr uint32_t smem_Q1_offset = smem_quart1;
|
||||
constexpr uint32_t smem_K0_offset = smem_quart2;
|
||||
constexpr uint32_t smem_K1_offset = smem_quart3;
|
||||
constexpr uint32_t smem_V0_offset = smem_Q0_offset + smem_Q_size * sizeof(float);
|
||||
constexpr uint32_t smem_V1_offset = smem_Q1_offset + smem_Q_size * sizeof(float);
|
||||
constexpr uint32_t smem_Q0_hexadecile = 4 * 0;
|
||||
constexpr uint32_t smem_Q1_hexadecile = 4 * 1;
|
||||
constexpr uint32_t smem_K0_hexadecile = 4 * 2;
|
||||
constexpr uint32_t smem_K1_hexadecile = 4 * 3;
|
||||
constexpr uint32_t smem_V0_hexadecile = smem_Q0_hexadecile + 1;
|
||||
constexpr uint32_t smem_V1_hexadecile = smem_Q1_hexadecile + 1;
|
||||
// put S1/S0 with V0/V1 so that softmax and GEMM-II doesn't cause bank
|
||||
// conflicts
|
||||
constexpr uint32_t smem_S0_offset = smem_V1_offset + smem_V_size * sizeof(float);
|
||||
constexpr uint32_t smem_S1_offset = smem_V0_offset + smem_V_size * sizeof(float);
|
||||
constexpr uint32_t smem_P0_offset = smem_K0_offset + smem_K_size * sizeof(float);
|
||||
constexpr uint32_t smem_P1_offset = smem_K1_offset + smem_K_size * sizeof(float);
|
||||
constexpr uint32_t smem_S0_hexadecile = smem_V1_hexadecile + 1;
|
||||
constexpr uint32_t smem_S1_hexadecile = smem_V0_hexadecile + 1;
|
||||
constexpr uint32_t smem_P0_hexadecile = smem_K0_hexadecile + 1;
|
||||
constexpr uint32_t smem_P1_hexadecile = smem_K1_hexadecile + 1;
|
||||
// reversed!
|
||||
constexpr uint32_t smem_O0_offset = smem_P1_offset + smem_QK_size * sizeof(float);
|
||||
constexpr uint32_t smem_O1_offset = smem_P0_offset + smem_QK_size * sizeof(float); // unused
|
||||
constexpr uint32_t smem_O0_hexadecile = smem_P1_hexadecile + 1;
|
||||
constexpr uint32_t smem_O1_hexadecile = smem_P0_hexadecile + 1; // unused
|
||||
|
||||
float *smem_Q0 = reinterpret_cast<float *>(smem_start + smem_Q0_offset);
|
||||
float *smem_Q1 = reinterpret_cast<float *>(smem_start + smem_Q1_offset);
|
||||
float *smem_K0 = reinterpret_cast<float *>(smem_start + smem_K0_offset);
|
||||
float *smem_K1 = reinterpret_cast<float *>(smem_start + smem_K1_offset);
|
||||
float *smem_V0 = reinterpret_cast<float *>(smem_start + smem_V0_offset);
|
||||
float *smem_V1 = reinterpret_cast<float *>(smem_start + smem_V1_offset);
|
||||
float *smem_S0 = reinterpret_cast<float *>(smem_start + smem_S0_offset);
|
||||
float *smem_S1 = reinterpret_cast<float *>(smem_start + smem_S1_offset);
|
||||
float *smem_P0 = reinterpret_cast<float *>(smem_start + smem_P0_offset);
|
||||
float *smem_P1 = reinterpret_cast<float *>(smem_start + smem_P1_offset);
|
||||
float *smem_O0 = reinterpret_cast<float *>(smem_start + smem_O0_offset);
|
||||
float *smem_O1 = reinterpret_cast<float *>(smem_start + smem_O1_offset);
|
||||
float *smem_Q0 = reinterpret_cast<float *>(smem_start + smem_Q0_hexadecile * smem_hexadecile_size);
|
||||
float *smem_Q1 = reinterpret_cast<float *>(smem_start + smem_Q1_hexadecile * smem_hexadecile_size);
|
||||
float *smem_K0 = reinterpret_cast<float *>(smem_start + smem_K0_hexadecile * smem_hexadecile_size);
|
||||
float *smem_K1 = reinterpret_cast<float *>(smem_start + smem_K1_hexadecile * smem_hexadecile_size);
|
||||
float *smem_V0 = reinterpret_cast<float *>(smem_start + smem_V0_hexadecile * smem_hexadecile_size);
|
||||
float *smem_V1 = reinterpret_cast<float *>(smem_start + smem_V1_hexadecile * smem_hexadecile_size);
|
||||
float *smem_S0 = reinterpret_cast<float *>(smem_start + smem_S0_hexadecile * smem_hexadecile_size);
|
||||
float *smem_S1 = reinterpret_cast<float *>(smem_start + smem_S1_hexadecile * smem_hexadecile_size);
|
||||
float *smem_P0 = reinterpret_cast<float *>(smem_start + smem_P0_hexadecile * smem_hexadecile_size);
|
||||
float *smem_P1 = reinterpret_cast<float *>(smem_start + smem_P1_hexadecile * smem_hexadecile_size);
|
||||
float *smem_O0 = reinterpret_cast<float *>(smem_start + smem_O0_hexadecile * smem_hexadecile_size);
|
||||
float *smem_O1 = reinterpret_cast<float *>(smem_start + smem_O1_hexadecile * smem_hexadecile_size);
|
||||
|
||||
// allocate rowmax/rowsum storage at the end of the sharedmem address space
|
||||
constexpr uint32_t smem_rowmax_size = B_ROW * ROWMAX_SETS;
|
||||
@@ -168,18 +173,18 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
|
||||
|
||||
static_assert(sizeof(elem_t) == sizeof(float));
|
||||
constexpr uint32_t spad_addr_factor = DIM * sizeof(elem_t);
|
||||
constexpr uint32_t spad_addr_Q0 = smem_Q0_offset / spad_addr_factor;
|
||||
constexpr uint32_t spad_addr_Q1 = smem_Q1_offset / spad_addr_factor;
|
||||
constexpr uint32_t spad_addr_K0 = smem_K0_offset / spad_addr_factor;
|
||||
constexpr uint32_t spad_addr_K1 = smem_K1_offset / spad_addr_factor;
|
||||
constexpr uint32_t spad_addr_V0 = smem_V0_offset / spad_addr_factor;
|
||||
constexpr uint32_t spad_addr_V1 = smem_V1_offset / spad_addr_factor;
|
||||
constexpr uint32_t spad_addr_S0 = smem_S0_offset / spad_addr_factor;
|
||||
constexpr uint32_t spad_addr_S1 = smem_S1_offset / spad_addr_factor;
|
||||
constexpr uint32_t spad_addr_P0 = smem_P0_offset / spad_addr_factor;
|
||||
constexpr uint32_t spad_addr_P1 = smem_P1_offset / spad_addr_factor;
|
||||
constexpr uint32_t spad_addr_O0 = smem_O0_offset / spad_addr_factor;
|
||||
constexpr uint32_t spad_addr_O1 = smem_O1_offset / spad_addr_factor;
|
||||
constexpr uint32_t spad_addr_Q0 = smem_Q0_hexadecile * smem_hexadecile_size / spad_addr_factor;
|
||||
constexpr uint32_t spad_addr_Q1 = smem_Q1_hexadecile * smem_hexadecile_size / spad_addr_factor;
|
||||
constexpr uint32_t spad_addr_K0 = smem_K0_hexadecile * smem_hexadecile_size / spad_addr_factor;
|
||||
constexpr uint32_t spad_addr_K1 = smem_K1_hexadecile * smem_hexadecile_size / spad_addr_factor;
|
||||
constexpr uint32_t spad_addr_V0 = smem_V0_hexadecile * smem_hexadecile_size / spad_addr_factor;
|
||||
constexpr uint32_t spad_addr_V1 = smem_V1_hexadecile * smem_hexadecile_size / spad_addr_factor;
|
||||
constexpr uint32_t spad_addr_S0 = smem_S0_hexadecile * smem_hexadecile_size / spad_addr_factor;
|
||||
constexpr uint32_t spad_addr_S1 = smem_S1_hexadecile * smem_hexadecile_size / spad_addr_factor;
|
||||
constexpr uint32_t spad_addr_P0 = smem_P0_hexadecile * smem_hexadecile_size / spad_addr_factor;
|
||||
constexpr uint32_t spad_addr_P1 = smem_P1_hexadecile * smem_hexadecile_size / spad_addr_factor;
|
||||
constexpr uint32_t spad_addr_O0 = smem_O0_hexadecile * smem_hexadecile_size / spad_addr_factor;
|
||||
constexpr uint32_t spad_addr_O1 = smem_O1_hexadecile * smem_hexadecile_size / spad_addr_factor;
|
||||
|
||||
constexpr uint32_t global_barrier_id = NUM_WARPS - 1; // arbitrary
|
||||
static_assert(warps_per_threadblock_per_core == NUM_WARPS);
|
||||
@@ -246,22 +251,27 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
|
||||
// make sure to read from the correct row of Q
|
||||
const float *gmem_Q_tile = gmem_Q + HEADDIM * B_ROW * warpgroup_id;
|
||||
const float *gmem_K_tile = gmem_K;
|
||||
|
||||
// do DMA
|
||||
//
|
||||
// move Q to spad_addr_Q0 for the first iteration
|
||||
//
|
||||
#ifdef GEMMINI_NEW_CISC
|
||||
// the target addresses of this should match with spad_addr_Q0 and
|
||||
// spad_addr_K0 set in this kernel
|
||||
gemmini_tile_load_ab(gmem_Q_tile, gmem_K_tile, smem_Q0_hexadecile,
|
||||
smem_K0_hexadecile, 0 /*tile_idx_i*/, 0 /*tile_idx_j*/,
|
||||
0 /*tile_idx_k*/, dim_seqlen, dim_seqlen, HEADDIM,
|
||||
B_ROW, B_COL, HEADDIM);
|
||||
#else
|
||||
// configure the GMEM addresses for the DMA to read from
|
||||
ROCC_INSTRUCTION_RS1_RS2(XCUSTOM_ACC, (uint64_t)(gmem_Q_tile),
|
||||
(uint64_t)(gmem_K_tile), k_LOOP_WS_CONFIG_ADDRS_AB)
|
||||
// configure address strides for the DMA
|
||||
GEMMINI_CISC_CMD_R((dim_seqlen << 20) | (HEADDIM << 8) |
|
||||
8 /*k_LOOP_WS_CONFIG_STRIDES_AB*/);
|
||||
GEMMINI_CISC_SET_AB_STRIDE);
|
||||
gemmini_fence();
|
||||
|
||||
// #define GEMMINI_DMA_CISC
|
||||
#ifdef GEMMINI_DMA_CISC
|
||||
// the target addresses of this should match with spad_addr_Q0 and
|
||||
// spad_addr_K0 set in this kernel
|
||||
GEMMINI_CISC_CMD_I(10);
|
||||
#else
|
||||
// do DMA
|
||||
//
|
||||
// among other things, this also configures CONFIG_BOUNDS so that the
|
||||
// DMA knows the full matrix dimensions
|
||||
sp_tiled_matmul_full_spad_ws(
|
||||
@@ -274,7 +284,14 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
|
||||
#endif
|
||||
gemmini_fence();
|
||||
|
||||
// need to also move Q to spad_addr_Q1 for the next iteration
|
||||
// also move Q to spad_addr_Q1 for the second iteration
|
||||
//
|
||||
#ifdef GEMMINI_NEW_CISC
|
||||
gemmini_tile_load_ab(gmem_Q_tile, gmem_K_tile, smem_Q1_hexadecile,
|
||||
smem_K1_hexadecile, 0 /*tile_idx_i*/, 0 /*tile_idx_j*/,
|
||||
0 /*tile_idx_k*/, dim_seqlen, dim_seqlen, HEADDIM,
|
||||
B_ROW, B_COL, HEADDIM);
|
||||
#else
|
||||
// FIXME: re-configure necessary?
|
||||
gmem_K_tile = gmem_K + (B_COL * 1);
|
||||
ROCC_INSTRUCTION_RS1_RS2(XCUSTOM_ACC, (uint64_t)(gmem_Q_tile),
|
||||
@@ -282,9 +299,7 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
|
||||
GEMMINI_CISC_CMD_R((dim_seqlen << 20) | (HEADDIM << 8) |
|
||||
8 /*k_LOOP_WS_CONFIG_STRIDES_AB*/);
|
||||
gemmini_fence();
|
||||
#ifdef GEMMINI_DMA_CISC
|
||||
// GEMMINI_CISC_CMD_I(11);
|
||||
#else
|
||||
|
||||
sp_tiled_matmul_full_spad_ws(
|
||||
spad_addr_Q1, spad_addr_K1/*bogus*/,
|
||||
/*spad_D=*/0, /*spad_C=*/spad_addr_S0/*bogus*/,
|
||||
@@ -379,6 +394,17 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
|
||||
const auto spad_addr_P_consume = (tile_k & 1) ? spad_addr_P1 : spad_addr_P0;
|
||||
const auto spad_addr_P_produce = (tile_k & 1) ? spad_addr_P0 : spad_addr_P1;
|
||||
const auto spad_addr_O = spad_addr_O0; // NOTE: there's only single O tile
|
||||
|
||||
const auto spad_hex_Q = smem_Q0_hexadecile;
|
||||
const auto spad_hex_K_consume = (tile_k & 1) ? smem_K1_hexadecile : smem_K0_hexadecile;
|
||||
const auto spad_hex_K_produce = (tile_k & 1) ? smem_K0_hexadecile : smem_K1_hexadecile;
|
||||
const auto spad_hex_V_consume = (tile_k & 1) ? smem_V1_hexadecile : smem_V0_hexadecile;
|
||||
const auto spad_hex_V_produce = (tile_k & 1) ? smem_V0_hexadecile : smem_V1_hexadecile;
|
||||
const auto spad_hex_S_consume = (tile_k & 1) ? smem_S1_hexadecile : smem_S0_hexadecile;
|
||||
const auto spad_hex_S_produce = (tile_k & 1) ? smem_S0_hexadecile : smem_S1_hexadecile;
|
||||
const auto spad_hex_P_consume = (tile_k & 1) ? smem_P1_hexadecile : smem_P0_hexadecile;
|
||||
const auto spad_hex_P_produce = (tile_k & 1) ? smem_P0_hexadecile : smem_P1_hexadecile;
|
||||
const auto spad_hex_O = smem_O0_hexadecile; // NOTE: there's only single O tile
|
||||
asm volatile ("dbuf_sel_end_%=:" :: );
|
||||
|
||||
if (vx_warp_id() == 0 /* warp 0 in every core */) {
|
||||
@@ -394,26 +420,19 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
|
||||
asm volatile("gemm_pv_start_%=:" ::);
|
||||
|
||||
if (tid_in_warpgroup == 0) {
|
||||
#if 0
|
||||
if (tile_k_ == 0) {
|
||||
gemmini_fence();
|
||||
GEMMINI_CISC_CMD_I(0);
|
||||
} else if (tile_k_ & 1) {
|
||||
gemmini_fence();
|
||||
GEMMINI_CISC_CMD_I(2);
|
||||
} else {
|
||||
gemmini_fence();
|
||||
GEMMINI_CISC_CMD_I(1);
|
||||
}
|
||||
#else
|
||||
// kickoff matmul
|
||||
// among other things, this also configures CONFIG_BOUNDS so that the
|
||||
// DMA knows the full matrix dimensions
|
||||
|
||||
// FIXME: perf: prevent GMEM->SMEM load for O tile
|
||||
gemmini_fence();
|
||||
gemmini_fence();
|
||||
gemmini_fence();
|
||||
gemmini_fence();
|
||||
#ifdef GEMMINI_NEW_CISC
|
||||
gemmini_tile_compute</*store_to_spad=*/true>(
|
||||
spad_hex_P_consume, spad_hex_V_consume, spad_hex_O,
|
||||
0 /*accumulate.
|
||||
FIXME: Gemmini doens't support accumulation from a spad tile*/);
|
||||
#else
|
||||
sp_tiled_matmul_full_spad_ws(
|
||||
spad_addr_P_consume, spad_addr_V_consume,
|
||||
/*spad_D=*/spad_addr_O, /*spad_C=*/spad_addr_O,
|
||||
@@ -452,7 +471,11 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
|
||||
// 0,2,.: opcode 0 (quartile 0/2, no accum)
|
||||
// 1,3,.: opcode 3 (quartile 1/3, no accum)
|
||||
// const uint32_t opcode = 3 * (tile_k & 1);
|
||||
//GEMMINI_CISC_CMD_I(opcode);
|
||||
#ifdef GEMMINI_NEW_CISC
|
||||
gemmini_tile_compute</*store_to_spad=*/true>(
|
||||
spad_hex_Q, spad_hex_K_consume, spad_hex_S_produce,
|
||||
0 /*accumulate*/);
|
||||
#else
|
||||
sp_tiled_matmul_full_spad_ws(
|
||||
spad_addr_Q, spad_addr_K_consume,
|
||||
/*spad_D=*/0, /*spad_C=*/spad_addr_S_produce,
|
||||
@@ -460,6 +483,7 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
|
||||
/*pad_I=*/0, /*pad_J=*/0, /*pad_K=*/0,
|
||||
/*a_transpose=*/0, /*b_transpose=*/0, /*full_C=*/0, /*low_D=*/0,
|
||||
/*acc=*/0, /*act=*/NO_ACTIVATION, /*skips=*/skips_matmul);
|
||||
#endif
|
||||
|
||||
// gemmini_fence();
|
||||
// gemmini_fence();
|
||||
@@ -487,14 +511,16 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
|
||||
(uint64_t)(gmem_V_tile),
|
||||
k_LOOP_WS_CONFIG_ADDRS_AB)
|
||||
#endif
|
||||
// configure address strides for the DMA
|
||||
// FIXME: unnecessary?
|
||||
GEMMINI_CISC_CMD_R((HEADDIM /*V*/ << 20) | (dim_seqlen /*KT*/ << 8) |
|
||||
8 /*k_LOOP_WS_CONFIG_STRIDES_AB*/);
|
||||
// gemmini_fence();
|
||||
|
||||
// do DMA
|
||||
if (tile_k == 0) {
|
||||
// // configure address strides for the DMA
|
||||
// // FIXME: unnecessary?
|
||||
// GEMMINI_CISC_CMD_R((HEADDIM /*V*/ << 20) | (dim_seqlen /*KT*/ << 8) |
|
||||
// 8 /*k_LOOP_WS_CONFIG_STRIDES_AB*/);
|
||||
// gemmini_fence();
|
||||
//
|
||||
// we load (k-1)th tile for V; skip V for the 1st iteration,
|
||||
// sp_tiled_matmul_full_spad_ws(
|
||||
// spad_addr_K_produce, spad_addr_V_produce,
|
||||
@@ -504,6 +530,18 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
|
||||
// /*a_transpose=*/0, /*b_transpose=*/0, /*full_C=*/0, /*low_D=*/0,
|
||||
// /*acc=*/0, /*act=*/NO_ACTIVATION, /*skips=*/skips_only_a);
|
||||
} else {
|
||||
#ifdef GEMMINI_NEW_CISC
|
||||
gemmini_tile_load_ab(
|
||||
gmem_K_tile, gmem_V_tile, spad_hex_K_produce, spad_hex_V_produce,
|
||||
0 /*tile_idx_i*/, 0 /*tile_idx_j*/, 0 /*tile_idx_k*/,
|
||||
HEADDIM /*dim_m of KT*/, HEADDIM /*dim_n of V*/,
|
||||
dim_seqlen /*dim_k of KT*/, B_ROW, HEADDIM, B_COL);
|
||||
#else
|
||||
// configure address strides for the DMA
|
||||
// FIXME: unnecessary?
|
||||
GEMMINI_CISC_CMD_R((HEADDIM /*V*/ << 20) | (dim_seqlen /*KT*/ << 8) |
|
||||
8 /*k_LOOP_WS_CONFIG_STRIDES_AB*/);
|
||||
gemmini_fence();
|
||||
sp_tiled_matmul_full_spad_ws(
|
||||
spad_addr_K_produce, spad_addr_V_produce,
|
||||
/*spad_D=*/0, /*spad_C=*/0,
|
||||
@@ -511,6 +549,7 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
|
||||
/*pad_I=*/0, /*pad_J=*/0, /*pad_K=*/0,
|
||||
/*a_transpose=*/0, /*b_transpose=*/0, /*full_C=*/0, /*low_D=*/0,
|
||||
/*acc=*/0, /*act=*/NO_ACTIVATION, /*skips=*/skips);
|
||||
#endif
|
||||
}
|
||||
|
||||
// fence everything before going to the next tile
|
||||
|
||||
Reference in New Issue
Block a user