Merge branch 'kernels' of https://github.com/hansungk/vortex-private into kernels

This commit is contained in:
Richard Yan
2024-10-24 17:31:01 -07:00
14 changed files with 1856 additions and 708 deletions

View File

@@ -84,7 +84,7 @@
#endif
#ifndef NUM_CORES
#define NUM_CORES 8
#define NUM_CORES 4
#endif
#ifndef NUM_WARPS

View File

@@ -9,6 +9,8 @@
// #define SMEM_SIZE 0x4000
// 64KB
// #define SMEM_SIZE 0x10000
// 128KB
// #define SMEM_SIZE 0x20000
// 256KB
#define SMEM_SIZE 0x40000

View File

@@ -149,6 +149,7 @@ inline void vx_join(unsigned stack_ptr) {
}
// Warp Barrier
__attribute__((convergent))
inline void vx_barrier(unsigned barried_id, unsigned num_warps) {
asm volatile (".insn r %0, 4, 0, x0, %1, %2" :: "i"(RISCV_CUSTOM0), "r"(barried_id), "r"(num_warps));
}

View File

@@ -18,7 +18,7 @@
#include <stdio.h>
#ifndef CORES_PER_CLUSTER
#define CORES_PER_CLUSTER 8
#define CORES_PER_CLUSTER 4
#endif
#ifdef __cplusplus

View File

@@ -80,9 +80,13 @@ if __name__ == "__main__":
fp16 = True
if fp16:
A_packed = pack_fp16_by_row(A_array)
A_swizzled = A_packed.reshape([-1, M * 2])
A_swizzled.astype('float16').tofile("input.a.row.bin")
AT_packed = A_packed.transpose([1, 0, 2])
AT_swizzled = AT_packed.reshape([-1, M * 2])
AT_swizzled.astype('float16').tofile("input.a.col.bin")
print('A:')
print(A_swizzled)
print('AT:')
print(AT_swizzled)
B_packed = pack_fp16_by_column(B_array)

View File

@@ -93,6 +93,23 @@ inline constexpr void map_c_8lanes(const int tid, int &row, int &col) {
col += ((tid % 4) / 2) * 2;
}
inline constexpr void map_c_rowmajor_8lanes(const int tid, int &row, int &col) {
const int tg = tid / 4;
// A (row major)
// row 0~ 3: threadgroup 0
// row 4~ 7: threadgroup 1
row = tid % 4;
row += tg * 4;
// B (column major)
// col 0~ 3: threadgroup 0
// col 4~ 7: threadgroup 1
col = tid % 4;
col += tg * 4;
}
void vx_wmma_load() {
int tid = vx_thread_id();
int tg = tid / 4;
@@ -174,11 +191,31 @@ void store_wmma_result() {
int row = 0;
int col = 0;
map_c_8lanes(tid, row, col);
// map_c_8lanes(tid, row, col);
map_c_rowmajor_8lanes(tid, row, col);
// store C
float *const results_wid = results + (DIM_M * DIM_N * wid);
// uncomment to have two accum buffers in rf
// asm volatile("fsw f16, %0" ::"m"(results_wid[DIM_N * 0 + col]));
// asm volatile("fsw f17, %0" ::"m"(results_wid[DIM_N * 1 + col]));
// asm volatile("fsw f18, %0" ::"m"(results_wid[DIM_N * 2 + col]));
// asm volatile("fsw f19, %0" ::"m"(results_wid[DIM_N * 3 + col]));
// asm volatile("fsw f20, %0" ::"m"(results_wid[DIM_N * 4 + col]));
// asm volatile("fsw f21, %0" ::"m"(results_wid[DIM_N * 5 + col]));
// asm volatile("fsw f22, %0" ::"m"(results_wid[DIM_N * 6 + col]));
// asm volatile("fsw f23, %0" ::"m"(results_wid[DIM_N * 7 + col]));
asm volatile("fsw f24, %0" ::"m"(results_wid[DIM_N * 0 + col]));
asm volatile("fsw f25, %0" ::"m"(results_wid[DIM_N * 1 + col]));
asm volatile("fsw f26, %0" ::"m"(results_wid[DIM_N * 2 + col]));
asm volatile("fsw f27, %0" ::"m"(results_wid[DIM_N * 3 + col]));
asm volatile("fsw f28, %0" ::"m"(results_wid[DIM_N * 4 + col]));
asm volatile("fsw f29, %0" ::"m"(results_wid[DIM_N * 5 + col]));
asm volatile("fsw f30, %0" ::"m"(results_wid[DIM_N * 6 + col]));
asm volatile("fsw f31, %0" ::"m"(results_wid[DIM_N * 7 + col]));
// 1x2 jagged mapping
// asm volatile("fsw f16, %0" ::"m"(results_wid[DIM_N * (row + 0) + (col + 0)]));
// asm volatile("fsw f17, %0" ::"m"(results_wid[DIM_N * (row + 0) + (col + 1)]));
// asm volatile("fsw f18, %0" ::"m"(results_wid[DIM_N * (row + 2) + (col + 0)]));
@@ -187,14 +224,14 @@ void store_wmma_result() {
// asm volatile("fsw f21, %0" ::"m"(results_wid[DIM_N * (row + 0) + (col + 5)]));
// asm volatile("fsw f22, %0" ::"m"(results_wid[DIM_N * (row + 2) + (col + 4)]));
// asm volatile("fsw f23, %0" ::"m"(results_wid[DIM_N * (row + 2) + (col + 5)]));
asm volatile("fsw f24, %0" ::"m"(results_wid[DIM_N * (row + 0) + (col + 0)]));
asm volatile("fsw f25, %0" ::"m"(results_wid[DIM_N * (row + 0) + (col + 1)]));
asm volatile("fsw f26, %0" ::"m"(results_wid[DIM_N * (row + 2) + (col + 0)]));
asm volatile("fsw f27, %0" ::"m"(results_wid[DIM_N * (row + 2) + (col + 1)]));
asm volatile("fsw f28, %0" ::"m"(results_wid[DIM_N * (row + 0) + (col + 4)]));
asm volatile("fsw f29, %0" ::"m"(results_wid[DIM_N * (row + 0) + (col + 5)]));
asm volatile("fsw f30, %0" ::"m"(results_wid[DIM_N * (row + 2) + (col + 4)]));
asm volatile("fsw f31, %0" ::"m"(results_wid[DIM_N * (row + 2) + (col + 5)]));
// asm volatile("fsw f24, %0" ::"m"(results_wid[DIM_N * (row + 0) + (col + 0)]));
// asm volatile("fsw f25, %0" ::"m"(results_wid[DIM_N * (row + 0) + (col + 1)]));
// asm volatile("fsw f26, %0" ::"m"(results_wid[DIM_N * (row + 2) + (col + 0)]));
// asm volatile("fsw f27, %0" ::"m"(results_wid[DIM_N * (row + 2) + (col + 1)]));
// asm volatile("fsw f28, %0" ::"m"(results_wid[DIM_N * (row + 0) + (col + 4)]));
// asm volatile("fsw f29, %0" ::"m"(results_wid[DIM_N * (row + 0) + (col + 5)]));
// asm volatile("fsw f30, %0" ::"m"(results_wid[DIM_N * (row + 2) + (col + 4)]));
// asm volatile("fsw f31, %0" ::"m"(results_wid[DIM_N * (row + 2) + (col + 5)]));
}
void print_wmma_result() {

View File

@@ -48,7 +48,7 @@ VX_CP = $(LLVM_VORTEX)/bin/llvm-objcopy
#VX_DP = $(RISCV_TOOLCHAIN_PATH)/bin/$(RISCV_PREFIX)-objdump
#VX_CP = $(RISCV_TOOLCHAIN_PATH)/bin/$(RISCV_PREFIX)-objcopy
VX_CFLAGS += -v -O3 -std=c++17
VX_CFLAGS += -v -Os -std=c++17
VX_CFLAGS += -mcmodel=medany -fno-rtti -fno-exceptions -nostartfiles -fdata-sections -ffunction-sections
# comment out below for regression/basic, which uses GCC that doesn't
# understand these flags

View File

@@ -2,8 +2,8 @@ PROJECT = flash_attention
SRCS = main.cpp common.h
VX_SRCS = kernel.cpp
VX_INCLUDES = ../sgemm_tcore/sgemm_impl.hpp
VX_SRCS = kernel.gemmini.cpp
VX_INCLUDES = flash_impl.hpp ../sgemm_tcore/sgemm_impl.hpp
OPTS ?= -n16

View File

@@ -0,0 +1,559 @@
#ifndef _FLASH_IMPL_H_
#define _FLASH_IMPL_H_
#include <vx_spawn.h>
#include <float.h>
#define B_ROW 64
#define B_COL 64
#define HEADDIM 64
#define ROW_REMAINDER_LOGIC
constexpr uint32_t ROWMAX_SETS = 3;
constexpr bool WARP_SPECIALIZED = true;
constexpr bool TENSOR_CORE = true;
// temporary safety stop for wrong configs
static_assert(NUM_CORES == 4);
static_assert(NUM_THREADS == 8);
static_assert(NUM_WARPS == 8);
inline void thread_block_init_sharedmem(const uint32_t tid_in_threadblock,
const uint32_t threads_per_threadblock,
float *smem_O, float *smem_rowmax,
float *smem_rowsum,
float *smem_O_row_scale) {
asm volatile("threadblock_init_sharedmem_start_%=:" ::);
const uint32_t tid_in_warp = tid_in_threadblock % NUM_THREADS;
const uint32_t warp_id = tid_in_threadblock / NUM_THREADS;
const uint32_t warps_in_threadblock = threads_per_threadblock / NUM_THREADS;
static_assert((B_ROW % NUM_THREADS) == 0,
"B_ROW must be a multiple of NUM_THREADS");
static_assert(B_ROW < (NUM_THREADS * CORES_PER_CLUSTER *
(NUM_WARPS / (WARP_SPECIALIZED ? 2 : 1))),
"not enough warps to initialize rowmax/rowsum");
// each thread initializes one element in rowmax/rowsum
// multiple warps participate for the whole vector
constexpr uint32_t needed_warps = B_ROW / NUM_THREADS;
if (warp_id < needed_warps /* more warps in HW than needed? */) {
uint32_t offset = NUM_THREADS * warp_id + tid_in_warp;
#pragma GCC unroll
for (int i = 0; i < ROWMAX_SETS; i++) {
smem_rowmax[offset + i * ROWMAX_SETS] = FLT_MIN;
}
smem_rowsum[offset] = 0.0f;
smem_O_row_scale[offset] = 0.0f;
}
// each warp clears out a row of smem_O
// FIXME: dedup this pattern
#pragma GCC unroll 1
for (int row_offset = 0; row_offset < B_COL;
row_offset += warps_in_threadblock) {
const uint32_t row = row_offset + warp_id;
#ifdef ROW_REMAINDER_LOGIC
if (row >= B_ROW) {
// WARNING: the number of barrier calls have to exactly match that in the
// outside of the branch to prevent stalls!! FIXME better proof this.
continue;
}
#endif
uint32_t thread_offset = HEADDIM * row + tid_in_warp;
constexpr uint32_t per_row_iter = HEADDIM / NUM_THREADS;
const float one = 0.0f;
#pragma GCC unroll
for (int i = 0; i < per_row_iter; i++) {
smem_O[thread_offset] = 0.0f;
thread_offset += NUM_THREADS;
}
}
asm volatile("threadblock_init_sharedmem_finish_%=:" ::);
}
inline void thread_block_copy_rowmax(const float *src, float *dest,
const uint32_t tid_in_threadblock,
const uint32_t threads_per_threadblock,
const uint32_t threadblock_id_in_cluster) {
asm volatile("threadblock_copy_rowmax_start_%=:" ::);
const uint32_t tid_in_warp = tid_in_threadblock % NUM_THREADS;
const uint32_t warp_id = tid_in_threadblock / NUM_THREADS;
const uint32_t warps_in_threadblock = threads_per_threadblock / NUM_THREADS;
const uint32_t warps_per_threadblock_per_core =
warps_in_threadblock / CORES_PER_CLUSTER;
// each thread copies one element in rowmax
// multiple warps participate for the whole vector
constexpr uint32_t num_warps = B_ROW / NUM_THREADS;
if (warp_id < num_warps) {
uint32_t offset = NUM_THREADS * warp_id + tid_in_warp;
dest[offset] = src[offset];
}
if constexpr (!TENSOR_CORE) {
threadblock_barrier(1, 7);
} else {
threadblock_barrier(threadblock_id_in_cluster,
warps_per_threadblock_per_core);
}
asm volatile("threadblock_copy_rowmax_finish_%=:" ::);
}
template <uint32_t dim_row, uint32_t dim_col, bool block_row_major = false>
inline void thread_block_copy_tile(const float *src, float *dest,
const uint32_t tid_in_threadblock,
const uint32_t threads_per_threadblock,
const uint32_t threadblock_id_in_cluster) {
asm volatile("threadblock_copy_tile_start_%=:" ::);
const uint32_t tid_in_warp = tid_in_threadblock % NUM_THREADS;
const uint32_t warp_id = tid_in_threadblock / NUM_THREADS;
const uint32_t warps_in_threadblock = threads_per_threadblock / NUM_THREADS;
const uint32_t warps_per_threadblock_per_core =
warps_in_threadblock / CORES_PER_CLUSTER;
// FIXME: dedup this pattern
#pragma GCC unroll 1
for (int row_offset = 0; row_offset < dim_row;
row_offset += warps_in_threadblock) {
const uint32_t row = row_offset + warp_id;
#ifdef ROW_REMAINDER_LOGIC
if (row >= B_ROW) {
// WARNING: the number of barrier calls have to exactly match that in the
// outside of the branch to prevent stalls!! FIXME better proof this.
if constexpr (!TENSOR_CORE) {
threadblock_barrier(1, 7);
} else {
threadblock_barrier(threadblock_id_in_cluster,
warps_per_threadblock_per_core);
}
continue;
}
#endif
constexpr uint32_t per_row_iter = dim_col / NUM_THREADS;
#pragma GCC unroll
for (int i = 0; i < per_row_iter; i++) {
const uint32_t col_offset = NUM_THREADS * i;
const uint32_t col = col_offset + tid_in_warp;
const auto [smem_row, smem_col] =
remap_to_gemmini_dma_layout<block_row_major, B_COL>(row, col);
const uint32_t smem_offset = B_COL * smem_row + smem_col;
const uint32_t gmem_offset = B_COL * row + col;
dest[gmem_offset] = src[smem_offset];
}
if constexpr (!TENSOR_CORE) {
threadblock_barrier(1, 7);
} else {
threadblock_barrier(threadblock_id_in_cluster,
warps_per_threadblock_per_core);
}
}
asm volatile("threadblock_copy_tile_finish_%=:" ::);
}
template <int order>
inline float exponential_taylor_term(const float x) {
asm volatile("exponential_taylor_term_start_%=:" ::);
float res = 1.0f;
if constexpr (order == 1) {
res = x;
} else if constexpr (order == 2) {
res = x * x;
res /= 2.0f;
} else if constexpr (order == 3) {
res = x * x * x;
res /= 6.0f;
}
asm volatile("exponential_taylor_term_end_%=:" ::);
return res;
}
template <bool block_row_major = false>
__attribute__((always_inline)) inline void thread_block_online_softmax(
const float *smem_S, float *smem_P, const uint32_t tid_in_threadblock,
const uint32_t threads_per_threadblock,
const uint32_t threadblock_id_in_cluster, float *smem_scratchpad,
float *smem_rowmax, float *smem_rowsum, float *smem_O_row_scale) {
asm volatile("thread_block_online_softmax_start_%=:" ::);
const uint32_t tid_in_warp = tid_in_threadblock % NUM_THREADS;
const uint32_t warp_id = tid_in_threadblock / NUM_THREADS;
const uint32_t warps_in_threadblock = threads_per_threadblock / NUM_THREADS;
const uint32_t warps_per_threadblock_per_core =
warps_in_threadblock / CORES_PER_CLUSTER;
float *smem_rowmax_this = smem_rowmax + B_ROW;
#pragma GCC unroll 1
for (int row_offset = 0; row_offset < B_ROW;
row_offset += warps_in_threadblock) {
const uint32_t row = row_offset + warp_id;
#ifdef ROW_REMAINDER_LOGIC
// if the number of warps doesn't exactly divide the number of rows,
// early-exit to prevent out-of-bounds access
if (row >= B_ROW) {
// WARNING: the number of barrier calls have to exactly match that in the
// outside of the branch to prevent stalls!! FIXME better proof this.
if constexpr (!TENSOR_CORE) {
threadblock_barrier(1, 7);
threadblock_barrier(1, 7);
threadblock_barrier(1, 7);
threadblock_barrier(1, 7);
threadblock_barrier(1, 7);
threadblock_barrier(1, 7);
} else {
threadblock_barrier(threadblock_id_in_cluster,
warps_per_threadblock_per_core);
threadblock_barrier(threadblock_id_in_cluster,
warps_per_threadblock_per_core);
threadblock_barrier(threadblock_id_in_cluster,
warps_per_threadblock_per_core);
threadblock_barrier(threadblock_id_in_cluster,
warps_per_threadblock_per_core);
threadblock_barrier(threadblock_id_in_cluster,
warps_per_threadblock_per_core);
threadblock_barrier(threadblock_id_in_cluster,
warps_per_threadblock_per_core);
}
continue;
}
#endif
const uint32_t first_thread_offset = B_COL * row;
// rowmax
//
// two-level tree reduction: reduce each row into NUM_THREADS intermediate
// maxes, then reduce it down to one row max
// one warp handles one row in tile
constexpr uint32_t per_row_iter = B_COL / NUM_THREADS;
// FIXME: threadblock_id needs to be in here too
float *warp_smem = smem_scratchpad + (warp_id * NUM_THREADS);
// #define DUMB_ROWMAX
#ifdef DUMB_ROWMAX
// FIXME remove
threadblock_barrier(threadblock_id_in_cluster,
warps_per_threadblock_per_core);
// no tree reduction; a single thread in a warp does serialized max across
// the entire row
if (tid_in_warp == 0) {
float rowmax = smem_S[first_thread_offset];
#pragma GCC unroll 16
for (int i = 0; i < B_COL; i++) {
asm volatile("fmax.s %0, %1, %2"
: "=f"(rowmax)
: "f"(rowmax), "f"(smem_S[first_thread_offset + i]));
}
smem_rowmax_this[row] = rowmax;
// update previous rowmax
// i.e. mi_new = max(mi, mij)
float prev_rowmax = smem_rowmax[row];
// stage prev rowmax in scratchpad for warp-wide broadcast
warp_smem[0] = prev_rowmax;
asm volatile("fmax.s %0, %1, %2"
: "=f"(rowmax)
: "f"(rowmax), "f"(prev_rowmax));
smem_rowmax[row] = rowmax;
}
#else
static_assert((B_COL % NUM_THREADS) == 0,
"B_COL must be a multiple of NUM_THREADS");
float per_thread_max = FLT_MIN;
#pragma GCC unroll
for (int i = 0; i < per_row_iter; i++) {
const uint32_t col_offset = NUM_THREADS * i;
const uint32_t col = col_offset + tid_in_warp;
const auto [smem_row, smem_col] =
remap_to_gemmini_dma_layout<block_row_major, B_COL>(row, col);
const uint32_t offset = B_COL * smem_row + smem_col;
const float next = smem_S[offset];
asm volatile("fmax.s %0, %1, %2"
: "=f"(per_thread_max)
: "f"(per_thread_max), "f"(next));
}
// stage per-thread max value in smem
warp_smem[tid_in_warp] = per_thread_max;
// sync writes to warp_smem
if constexpr (!TENSOR_CORE) {
threadblock_barrier(1, 7);
} else {
threadblock_barrier(threadblock_id_in_cluster,
warps_per_threadblock_per_core);
}
// #define PARALLEL_ROWMAX
#ifndef PARALLEL_ROWMAX
// elect 0-th thread to reduce all other thread's values in the warp
if (tid_in_warp == 0) {
float rowmax = per_thread_max;
for (int i = 1; i < NUM_THREADS; i++) {
float other = warp_smem[i];
asm volatile("fmax.s %0, %1, %2"
: "=f"(rowmax)
: "f"(rowmax), "f"(other));
}
smem_rowmax_this[row] = rowmax;
// update previous rowmax
// i.e. mi_new = max(mi, mij)
float prev_rowmax = smem_rowmax[row];
// stage prev rowmax in scratchpad for warp-wide broadcast
warp_smem[0] = prev_rowmax;
asm volatile("fmax.s %0, %1, %2"
: "=f"(rowmax)
: "f"(rowmax), "f"(prev_rowmax));
smem_rowmax[row] = rowmax;
}
#else
if (warp_id < warps_in_threadblock / NUM_THREADS) {
const uint32_t row = row_offset + NUM_THREADS * warp_id + tid_in_warp;
float *const thread_smem = smem_scratchpad + (tid_in_warp * NUM_THREADS);
float rowmax = FLT_MIN;
#pragma GCC unroll
for (int i = 0; i < NUM_THREADS; i++) {
const float f = thread_smem[i];
asm volatile("fmax.s %0, %1, %2" : "=f"(rowmax) : "f"(rowmax), "f"(f));
}
smem_rowmax_this[row] = rowmax;
// update previous rowmax
// i.e. mi_new = max(mi, mij)
float prev_rowmax = smem_rowmax[row];
// stage prev rowmax in scratchpad for warp-wide broadcast
thread_smem[0] = prev_rowmax;
asm volatile("fmax.s %0, %1, %2"
: "=f"(rowmax)
: "f"(rowmax), "f"(prev_rowmax));
smem_rowmax[row] = rowmax;
}
#endif // PARALLEL_ROWMAX
#endif // DUMB_ROWMAX
if constexpr (!TENSOR_CORE) {
threadblock_barrier(1, 7);
} else {
threadblock_barrier(threadblock_id_in_cluster,
warps_per_threadblock_per_core);
}
// broadcast prev rowmax to all threads in the warp
// NOTE: memory consistency is a little sketchy here
const float rowmax_prev = warp_smem[0];
const float rowmax_this = smem_rowmax_this[row];
// exponential
//
// B_ROW / (B_ROW * B_COL / (exp_elem * threads_per_threadblock))
// const uint32_t row_stride =
// (exp_elem_per_thread * threads_per_threadblock) / B_COL;
// broadcast updated rowmax to all threads in the warp
const float rowmax_new = smem_rowmax[row];
asm volatile("flashattn_exp_p_start_%=:" ::);
#pragma GCC unroll
for (int i = 0; i < per_row_iter; i++) {
const uint32_t col_offset = NUM_THREADS * i;
const uint32_t col = col_offset + tid_in_warp;
const auto [smem_row, smem_col] =
remap_to_gemmini_dma_layout<block_row_major, B_COL>(row, col);
const uint32_t offset = B_COL * smem_row + smem_col;
float f0 = smem_S[offset];
f0 -= rowmax_new;
// 2nd-order Taylor approximation
float exp = 1.0f;
exp += exponential_taylor_term<1>(f0);
exp += exponential_taylor_term<2>(f0);
// Store S transposed to the shared memory
smem_P[offset] = exp;
}
asm volatile("flashattn_exp_p_end_%=:" ::);
if constexpr (!TENSOR_CORE) {
threadblock_barrier(1, 7);
} else {
threadblock_barrier(threadblock_id_in_cluster,
warps_per_threadblock_per_core);
}
// rowsum
//
// two-level tree reduction, similar to rowmax
asm volatile("flashattn_rowsum_start_%=:" ::);
float per_thread_sum = 0.0f;
#pragma GCC unroll
for (int i = 0; i < per_row_iter; i++) {
const uint32_t col_offset = NUM_THREADS * i;
const uint32_t col = col_offset + tid_in_warp;
const auto [smem_row, smem_col] =
remap_to_gemmini_dma_layout<block_row_major, B_COL>(row, col);
const uint32_t offset = B_COL * smem_row + smem_col;
per_thread_sum += smem_P[offset];
}
// stage per-thread sum value in smem
// FIXME: threadblock_id needs to be in here too
warp_smem = smem_scratchpad + (warp_id * NUM_THREADS);
warp_smem[tid_in_warp] = per_thread_sum;
// sync writes to warp_smem
if constexpr (!TENSOR_CORE) {
threadblock_barrier(1, 7);
} else {
threadblock_barrier(threadblock_id_in_cluster,
warps_per_threadblock_per_core);
}
// 0-th thread collects all other thread's values in the warp
if (tid_in_warp == 0) {
float rowsum = per_thread_sum;
for (int iter = 1; iter < NUM_THREADS; iter++) {
float other = warp_smem[iter];
rowsum += other;
}
const float mi_prev = rowmax_prev;
const float mi_this = rowmax_this;
const float x = mi_prev - mi_this;
// 2nd-order Taylor approximation
float exp = 1.0f;
exp += exponential_taylor_term<1>(x);
exp += exponential_taylor_term<2>(x);
// update rowsum
const float rowsum_prev = smem_rowsum[row];
float rowsum_new = exp * rowsum_prev + rowsum;
smem_rowsum[row] = rowsum_new;
}
asm volatile("flashattn_rowsum_end_%=:" ::);
if constexpr (!TENSOR_CORE) {
threadblock_barrier(1, 7);
} else {
threadblock_barrier(threadblock_id_in_cluster,
warps_per_threadblock_per_core);
}
// compute Oi rescale factor
// FIXME: parallelize this across threads
//
asm volatile("flashattn_rescale_factor_start_%=:" ::);
#pragma GCC unroll
for (int i = 0; i < per_row_iter; i++) {
const float mi_prev = rowmax_prev;
const float mi_new = rowmax_new;
const float x = mi_prev - mi_new;
// 2nd-order Taylor approximation
float exp = 1.0f;
exp += exponential_taylor_term<1>(x);
exp += exponential_taylor_term<2>(x);
// @perf: div vs. expansion on e(-x)?
smem_O_row_scale[row] = 1.0f / exp;
}
asm volatile("flashattn_rescale_factor_end_%=:" ::);
if constexpr (!TENSOR_CORE) {
threadblock_barrier(1, 7);
} else {
threadblock_barrier(threadblock_id_in_cluster,
warps_per_threadblock_per_core);
}
}
asm volatile("thread_block_online_softmax_finish_%=:" ::);
}
template <bool block_row_major = false>
__attribute__((always_inline)) inline void thread_block_O_rescale(
const float *smem_O_in, float *smem_O_out, const float *smem_O_row_scale,
const uint32_t tid_in_threadblock, const uint32_t threads_per_threadblock,
const uint32_t threadblock_id_in_cluster) {
asm volatile("thread_block_O_rescale_start_%=:" ::);
const uint32_t tid_in_warp = tid_in_threadblock % NUM_THREADS;
const uint32_t warp_id = tid_in_threadblock / NUM_THREADS;
const uint32_t warps_in_threadblock = threads_per_threadblock / NUM_THREADS;
const uint32_t warps_per_threadblock_per_core =
warps_in_threadblock / CORES_PER_CLUSTER;
#pragma GCC unroll 1
for (int row_offset = 0; row_offset < B_ROW;
row_offset += warps_in_threadblock) {
const uint32_t row = row_offset + warp_id;
#ifdef ROW_REMAINDER_LOGIC
if (row >= B_ROW) {
// WARNING: the number of barrier calls have to exactly match that in the
// outside of the branch to prevent stalls!! FIXME better proof this.
continue;
}
#endif
constexpr uint32_t per_row_iter = HEADDIM / NUM_THREADS;
// Oi rescale
//
#pragma GCC unroll
for (int i = 0; i < per_row_iter; i++) {
const uint32_t col_offset = NUM_THREADS * i;
const uint32_t col = col_offset + tid_in_warp;
const auto [smem_row, smem_col] =
remap_to_gemmini_dma_layout<block_row_major, HEADDIM>(row, col);
const uint32_t offset = HEADDIM * smem_row + smem_col;
const float o = smem_O_in[offset];
const float scale = smem_O_row_scale[row];
smem_O_out[offset] = (o * scale);
}
}
// reconverge after warp divergence
if constexpr (!TENSOR_CORE) {
threadblock_barrier(1, 7);
} else {
threadblock_barrier(threadblock_id_in_cluster,
warps_per_threadblock_per_core);
}
asm volatile("thread_block_O_rescale_finish_%=:" ::);
}
#endif

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,698 @@
#include <stdint.h>
#include <vx_intrinsics.h>
#include <vx_print.h>
#include <vx_spawn.h>
#include "common.h"
#include "sgemm_impl.hpp"
#include "include/gemmini.h"
#include "gemmini_mmio.h"
#include "flash_impl.hpp"
#define FENCE_GEMM_II
constexpr bool DEBUG = false;
static_assert(GEMMINI_DMA && !WARP_SPECIALIZED,
"GEMMINI_DMA should be set and WARP_SPECIALIZED unset");
void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
// @perf: All threads are running these compute whose result is mostly same
// across the threadblock
#ifdef RADIANCE
constexpr uint32_t cores_per_cluster = CORES_PER_CLUSTER;
#else
constexpr uint32_t cores_per_cluster = 1;
#endif
// FIXME: headdim not considered
constexpr uint32_t threads_per_threadblock_theoretical =
(B_ROW * B_COL) / (ELEM_PER_THREAD);
constexpr uint32_t hw_threads_per_cluster =
CORES_PER_CLUSTER * NUM_THREADS * NUM_WARPS;
// cap maximum threadblock size to # of HW threads in cluster, to prevent
// multiple "wave" invocations which slows down the kernel
constexpr uint32_t threads_per_threadblock =
(threads_per_threadblock_theoretical > hw_threads_per_cluster)
? hw_threads_per_cluster
: threads_per_threadblock_theoretical;
constexpr uint32_t threadblocks_per_cluster =
hw_threads_per_cluster / threads_per_threadblock;
constexpr uint32_t warps_per_threadblock_per_core =
NUM_WARPS / threadblocks_per_cluster;
const uint32_t threadblock_id = task_id / threads_per_threadblock;
const uint32_t threadblock_id_in_cluster =
threadblock_id % threadblocks_per_cluster;
const uint32_t tid_in_threadblock = task_id % threads_per_threadblock;
const uint32_t warp_id = tid_in_threadblock / NUM_THREADS;
constexpr uint32_t warps_in_threadblock =
threads_per_threadblock / NUM_THREADS;
// warpgroup context
constexpr uint32_t threads_per_warpgroup =
threads_per_threadblock / (WARP_SPECIALIZED ? 2 : 1);
constexpr uint32_t warpgroups_per_cluster =
threadblocks_per_cluster * (WARP_SPECIALIZED ? 2 : 1);
const uint32_t warps_per_warpgroup_per_core =
NUM_WARPS / warpgroups_per_cluster;
const uint32_t warpgroup_id = task_id / threads_per_warpgroup;
const uint32_t warpgroup_id_in_cluster =
warpgroup_id % warpgroups_per_cluster;
const uint32_t tid_in_warpgroup = tid_in_threadblock % threads_per_warpgroup;
// // warpgroup 0: warp 0
// // warpgroup 1: warp 1~7
// const uint32_t warpgroup_id = (warp_id != 0);
const uint32_t dim_seqlen = arg->dim_seqlen;
const uint32_t dim_headdim = arg->dim_headdim;
// get global memory addresses from kernel arguments
const float *gmem_Q = reinterpret_cast<float *>(arg->addr_q);
const float *gmem_K = reinterpret_cast<float *>(arg->addr_k);
const float *gmem_V = reinterpret_cast<float *>(arg->addr_v);
float *gmem_O = reinterpret_cast<float *>(arg->addr_o);
float *gmem_tmp_d0 = reinterpret_cast<float *>(0xd0000000UL);
float *gmem_tmp_d1 = reinterpret_cast<float *>(0xd1000000UL);
float *gmem_tmp_d2 = reinterpret_cast<float *>(0xd2000000UL);
float *gmem_tmp_d3 = reinterpret_cast<float *>(0xd3000000UL);
float *gmem_tmp_d4 = reinterpret_cast<float *>(0xd4000000UL);
float *gmem_tmp_d5 = reinterpret_cast<float *>(0xd5000000UL);
float *gmem_tmp_d6 = reinterpret_cast<float *>(0xd6000000UL);
float *gmem_tmp_d7 = reinterpret_cast<float *>(0xd7000000UL);
float *gmem_tmp_e0 = reinterpret_cast<float *>(0xe0000000UL);
float *gmem_tmp_e1 = reinterpret_cast<float *>(0xe1000000UL);
float *gmem_tmp_e2 = reinterpret_cast<float *>(0xe2000000UL);
float *gmem_tmp_e3 = reinterpret_cast<float *>(0xe3000000UL);
// static shared memory allocation
// these are in float elements, not bytes
constexpr uint32_t smem_Q_size = B_ROW * HEADDIM;
constexpr uint32_t smem_K_size = B_COL * HEADDIM;
constexpr uint32_t smem_QK_size = B_ROW * B_COL;
constexpr uint32_t smem_V_size = B_COL * HEADDIM;
constexpr uint32_t smem_O_size = B_COL * HEADDIM;
static_assert(
threads_per_threadblock == NUM_WARPS * NUM_THREADS * CORES_PER_CLUSTER,
"flashattention kernel assumes 1 threadblock occupancy per cluster");
uint8_t *smem_per_threadblock = reinterpret_cast<uint8_t *>(DEV_SMEM_START_ADDR);
constexpr uint32_t smem_start = DEV_SMEM_START_ADDR;
constexpr uint32_t smem_quart0 = 0 * (SMEM_SIZE / 4);
constexpr uint32_t smem_quart1 = 1 * (SMEM_SIZE / 4);
constexpr uint32_t smem_quart2 = 2 * (SMEM_SIZE / 4);
constexpr uint32_t smem_quart3 = 3 * (SMEM_SIZE / 4);
// Q/V/S in quart0/1, K/P/O in quart2/3
constexpr uint32_t smem_Q0_offset = smem_quart0;
constexpr uint32_t smem_Q1_offset = smem_quart1;
constexpr uint32_t smem_K0_offset = smem_quart2;
constexpr uint32_t smem_K1_offset = smem_quart3;
constexpr uint32_t smem_V0_offset = smem_Q0_offset + smem_Q_size * sizeof(float);
constexpr uint32_t smem_V1_offset = smem_Q1_offset + smem_Q_size * sizeof(float);
// put S1/S0 with V0/V1 so that softmax and GEMM-II doesn't cause bank
// conflicts
constexpr uint32_t smem_S0_offset = smem_V1_offset + smem_V_size * sizeof(float);
constexpr uint32_t smem_S1_offset = smem_V0_offset + smem_V_size * sizeof(float);
constexpr uint32_t smem_P0_offset = smem_K0_offset + smem_K_size * sizeof(float);
constexpr uint32_t smem_P1_offset = smem_K1_offset + smem_K_size * sizeof(float);
// reversed!
constexpr uint32_t smem_O0_offset = smem_P1_offset + smem_QK_size * sizeof(float);
constexpr uint32_t smem_O1_offset = smem_P0_offset + smem_QK_size * sizeof(float); // unused
float *smem_Q0 = reinterpret_cast<float *>(smem_start + smem_Q0_offset);
float *smem_Q1 = reinterpret_cast<float *>(smem_start + smem_Q1_offset);
float *smem_K0 = reinterpret_cast<float *>(smem_start + smem_K0_offset);
float *smem_K1 = reinterpret_cast<float *>(smem_start + smem_K1_offset);
float *smem_V0 = reinterpret_cast<float *>(smem_start + smem_V0_offset);
float *smem_V1 = reinterpret_cast<float *>(smem_start + smem_V1_offset);
float *smem_S0 = reinterpret_cast<float *>(smem_start + smem_S0_offset);
float *smem_S1 = reinterpret_cast<float *>(smem_start + smem_S1_offset);
float *smem_P0 = reinterpret_cast<float *>(smem_start + smem_P0_offset);
float *smem_P1 = reinterpret_cast<float *>(smem_start + smem_P1_offset);
float *smem_O0 = reinterpret_cast<float *>(smem_start + smem_O0_offset);
float *smem_O1 = reinterpret_cast<float *>(smem_start + smem_O1_offset);
// allocate rowmax/rowsum storage at the end of the sharedmem address space
constexpr uint32_t smem_rowmax_size = B_ROW * ROWMAX_SETS;
constexpr uint32_t smem_rowsum_size = B_ROW;
constexpr uint32_t smem_O_row_scale_size = B_ROW;
float *smem_cursor = smem_O1 + smem_O_size;
// // FIXME: dangerous
// smem_cursor = reinterpret_cast<float *>(0xff038000);
float *smem_rowmax_0 = smem_cursor;
smem_cursor += smem_rowmax_size;
float *smem_rowmax_1 = smem_cursor;
smem_cursor += smem_rowmax_size;
float *smem_rowsum_0 = smem_cursor;
smem_cursor += smem_rowsum_size;
float *smem_rowsum_1 = smem_cursor;
smem_cursor += smem_rowsum_size;
float *smem_O_row_scale_0 = smem_cursor;
smem_cursor += smem_O_row_scale_size;
float *smem_O_row_scale_1 = smem_cursor;
smem_cursor += smem_O_row_scale_size;
// sharedmem "scratchpad" area to put temporary data, e.g. for tree reduction
// in rowsum
// NOTE: out-of bounds is not checked
constexpr uint32_t smem_scratchpad_size =
threads_per_warpgroup * 2 /*arbitrary slack*/;
float *smem_scratchpad_0 = smem_cursor;
smem_cursor += smem_scratchpad_size;
float *smem_scratchpad_1 = smem_cursor;
smem_cursor += smem_scratchpad_size;
uint32_t *smem_O_flag = reinterpret_cast<uint32_t *>(smem_cursor);
smem_cursor += 1 /* 4Byte */;
static_assert(sizeof(elem_t) == sizeof(float));
constexpr uint32_t spad_addr_factor = DIM * sizeof(elem_t);
constexpr uint32_t spad_addr_Q0 = smem_Q0_offset / spad_addr_factor;
constexpr uint32_t spad_addr_Q1 = smem_Q1_offset / spad_addr_factor;
constexpr uint32_t spad_addr_K0 = smem_K0_offset / spad_addr_factor;
constexpr uint32_t spad_addr_K1 = smem_K1_offset / spad_addr_factor;
constexpr uint32_t spad_addr_V0 = smem_V0_offset / spad_addr_factor;
constexpr uint32_t spad_addr_V1 = smem_V1_offset / spad_addr_factor;
constexpr uint32_t spad_addr_S0 = smem_S0_offset / spad_addr_factor;
constexpr uint32_t spad_addr_S1 = smem_S1_offset / spad_addr_factor;
constexpr uint32_t spad_addr_P0 = smem_P0_offset / spad_addr_factor;
constexpr uint32_t spad_addr_P1 = smem_P1_offset / spad_addr_factor;
constexpr uint32_t spad_addr_O0 = smem_O0_offset / spad_addr_factor;
constexpr uint32_t spad_addr_O1 = smem_O1_offset / spad_addr_factor;
constexpr uint32_t global_barrier_id = NUM_WARPS - 1; // arbitrary
static_assert(warps_per_threadblock_per_core == NUM_WARPS);
// initialize rowmax/rowsum values in sharedmem
thread_block_init_sharedmem(tid_in_warpgroup, threads_per_warpgroup, smem_O0,
smem_rowmax_0, smem_rowsum_0, smem_O_row_scale_0);
thread_block_init_sharedmem(tid_in_warpgroup, threads_per_warpgroup, smem_O1,
smem_rowmax_1, smem_rowsum_1, smem_O_row_scale_1);
threadblock_barrier(global_barrier_id, warps_per_threadblock_per_core);
// skip everything except DMA in the loop FSM
constexpr uint32_t skips =
loop_matmul_skips(/*skip_lda=*/0, /*skip_ldb=*/0, /*skip_ldd=*/1,
/*skip_ex=*/1, /*skip_stc=*/1);
constexpr uint32_t skips_only_a =
loop_matmul_skips(/*skip_lda=*/0, /*skip_ldb=*/1, /*skip_ldd=*/1,
/*skip_ex=*/1, /*skip_stc=*/1);
constexpr uint32_t skips_only_b =
loop_matmul_skips(/*skip_lda=*/1, /*skip_ldb=*/0, /*skip_ldd=*/1,
/*skip_ex=*/1, /*skip_stc=*/1);
constexpr uint32_t skips_mvout_spad =
loop_matmul_skips(/*skip_lda=*/1, /*skip_ldb=*/1, /*skip_ldd=*/1,
/*skip_ex=*/1, /*skip_stc=*/0);
constexpr uint32_t skips_matmul =
loop_matmul_skips(/*skip_lda=*/1, /*skip_ldb=*/1, /*skip_ldd=*/1,
/*skip_ex=*/0, /*skip_stc=*/0);
constexpr uint32_t skips_matmul_preload =
loop_matmul_skips(/*skip_lda=*/1, /*skip_ldb=*/1, /*skip_ldd=*/0,
/*skip_ex=*/0, /*skip_stc=*/1);
if (tid_in_warpgroup == 0) {
gemmini_extended_config_ex(WEIGHT_STATIONARY, 0, 0, 1, 0, 0);
// configure DMA with GMEM address strides
// Q matrix
gemmini_extended3_config_ld(HEADDIM * sizeof(elem_t), MVIN_SCALE_IDENTITY,
false, 0);
// K matrix
gemmini_extended3_config_ld(dim_seqlen * sizeof(elem_t),
MVIN_SCALE_IDENTITY, false, 1);
// configure DMA for Q*K store
gemmini_extended_config_st(B_COL * sizeof(elem_t), 0, MVIN_SCALE_IDENTITY);
gemmini_fence();
}
// NOTE about barriers: Placing barriers around thread-divergent branches may
// cause bugs, because the Vortex core doesn't check for tmask for barriers.
// The compiler might decide to duplicate vx_bar into both paths of a
// conditional branch, which will get evaluated twice because of the way
// branches are handled in SIMT; this might result in stalls especially when
// other warps behave differently on the branch condition.
// threadblock_barrier(warpgroup_id_in_cluster, warps_per_warpgroup_per_core);
static_assert(B_ROW == B_COL, "currently only supports square tiles");
// move Q and K into SMEM before the loop starts
//
asm volatile("dma_move_start_%=:" ::);
if (tid_in_warpgroup == 0) {
// make sure to read from the correct row of Q
const float *gmem_Q_tile = gmem_Q + HEADDIM * B_ROW * warpgroup_id;
const float *gmem_K_tile = gmem_K;
// configure the GMEM addresses for the DMA to read from
ROCC_INSTRUCTION_RS1_RS2(XCUSTOM_ACC, (uint64_t)(gmem_Q_tile),
(uint64_t)(gmem_K_tile), k_LOOP_WS_CONFIG_ADDRS_AB)
// configure address strides for the DMA
GEMMINI_CISC_CMD_R((dim_seqlen << 20) | (HEADDIM << 8) |
8 /*k_LOOP_WS_CONFIG_STRIDES_AB*/);
gemmini_fence();
// #define GEMMINI_DMA_CISC
#ifdef GEMMINI_DMA_CISC
// the target addresses of this should match with spad_addr_Q0 and
// spad_addr_K0 set in this kernel
GEMMINI_CISC_CMD_I(10);
#else
// do DMA
//
// among other things, this also configures CONFIG_BOUNDS so that the
// DMA knows the full matrix dimensions
sp_tiled_matmul_full_spad_ws(
spad_addr_Q0, spad_addr_K0,
/*spad_D=*/0, /*spad_C=*/spad_addr_S0/*bogus*/,
/*I=*/(B_ROW / DIM), /*J=*/(B_COL / DIM), /*K=*/(HEADDIM / DIM),
/*pad_I=*/0, /*pad_J=*/0, /*pad_K=*/0,
/*a_transpose=*/0, /*b_transpose=*/0, /*full_C=*/0, /*low_D=*/0,
/*acc=*/0, /*act=*/NO_ACTIVATION, /*skips=*/skips);
#endif
gemmini_fence();
// need to also move Q to spad_addr_Q1 for the next iteration
// FIXME: re-configure necessary?
gmem_K_tile = gmem_K + (B_COL * 1);
ROCC_INSTRUCTION_RS1_RS2(XCUSTOM_ACC, (uint64_t)(gmem_Q_tile),
(uint64_t)(gmem_K_tile), k_LOOP_WS_CONFIG_ADDRS_AB)
GEMMINI_CISC_CMD_R((dim_seqlen << 20) | (HEADDIM << 8) |
8 /*k_LOOP_WS_CONFIG_STRIDES_AB*/);
gemmini_fence();
#ifdef GEMMINI_DMA_CISC
// GEMMINI_CISC_CMD_I(11);
#else
sp_tiled_matmul_full_spad_ws(
spad_addr_Q1, spad_addr_K1/*bogus*/,
/*spad_D=*/0, /*spad_C=*/spad_addr_S0/*bogus*/,
/*I=*/(B_ROW / DIM), /*J=*/(B_COL / DIM), /*K=*/(HEADDIM / DIM),
/*pad_I=*/0, /*pad_J=*/0, /*pad_K=*/0,
/*a_transpose=*/0, /*b_transpose=*/0, /*full_C=*/0, /*low_D=*/0,
/*acc=*/0, /*act=*/NO_ACTIVATION, /*skips=*/skips);
#endif
gemmini_fence();
gemmini_fence();
gemmini_fence();
gemmini_fence();
// re-configure DMA for K and V load that will later happen in the loop
// GMEM addr stride for K
gemmini_extended3_config_ld(dim_seqlen * sizeof(elem_t),
MVIN_SCALE_IDENTITY, false, 0);
// GMEM addr stride for V
gemmini_extended3_config_ld(HEADDIM * sizeof(elem_t), MVIN_SCALE_IDENTITY,
false, 1);
gemmini_fence();
}
asm volatile("dma_move_end_%=:" ::);
// protect write to SMEM
// threadblock_barrier(warpgroup_id_in_cluster, warps_per_warpgroup_per_core);
threadblock_barrier(global_barrier_id, warps_per_threadblock_per_core);
// if constexpr (DEBUG) {
// thread_block_copy_tile<B_ROW, HEADDIM>(smem_Q0, gmem_tmp_d0, tid_in_warpgroup,
// threads_per_warpgroup, warpgroup_id_in_cluster);
// thread_block_copy_tile<HEADDIM, B_COL>(smem_K0, gmem_tmp_d1, tid_in_warpgroup,
// threads_per_warpgroup, warpgroup_id_in_cluster);
// threadblock_barrier(warpgroup_id_in_cluster, warps_per_warpgroup_per_core);
// }
constexpr uint32_t threads_per_warpgroup_simt =
threads_per_warpgroup -
CORES_PER_CLUSTER * NUM_THREADS /*warp 0, 4, 8, 12*/;
constexpr uint32_t warpgroup_id_simt = 1;
constexpr uint32_t barrier_id_simt = 1;
constexpr uint32_t barrier_count_simt = NUM_WARPS - 1;
const uint32_t tid_in_warpgroup_simt =
tid_in_warpgroup - (CORES_PER_CLUSTER * NUM_THREADS);
static_assert(barrier_id_simt == 1 && barrier_count_simt == 7);
asm volatile ("tile_loop_start_%=:" :: );
// "inner loop" along the columns of K^T
const uint32_t k_tiles = (dim_seqlen / B_COL);
for (uint32_t tile_k = 0;
tile_k < (4 /*for perf measurement*/ * k_tiles) + 2 /*pipeline latency*/;
tile_k++) {
if constexpr (DEBUG || true) {
threadblock_barrier(global_barrier_id, warps_per_threadblock_per_core);
}
// select the correct double buffer by tile iteration
// all iterations work on the same Q row tile; no ping-pong necessary
asm volatile ("dbuf_sel_start_%=:" :: );
// FIXME speedup by doing arithmetic
float *smem_Q = smem_Q0;
float *smem_K_consume = (tile_k & 1) ? smem_K1 : smem_K0;
float *smem_K_produce = (tile_k & 1) ? smem_K0 : smem_K1;
float *smem_V_consume = (tile_k & 1) ? smem_V1 : smem_V0;
float *smem_V_produce = (tile_k & 1) ? smem_V0 : smem_V1;
float *smem_S_consume = (tile_k & 1) ? smem_S1 : smem_S0;
float *smem_S_produce = (tile_k & 1) ? smem_S0 : smem_S1;
float *smem_P_consume = (tile_k & 1) ? smem_P1 : smem_P0;
float *smem_P_produce = (tile_k & 1) ? smem_P0 : smem_P1;
// O, rowmax/rowsum etc. is sequentially updated at every iteration; no
// ping-pong necessary
float *smem_O = smem_O0;
float *smem_O_row_scale = smem_O_row_scale_0;
float *smem_rowmax = smem_rowmax_0;
float *smem_rowsum = smem_rowsum_0;
float *smem_scratchpad = smem_scratchpad_0;
const auto spad_addr_Q = spad_addr_Q0;
const auto spad_addr_K_consume = (tile_k & 1) ? spad_addr_K1 : spad_addr_K0;
const auto spad_addr_K_produce = (tile_k & 1) ? spad_addr_K0 : spad_addr_K1;
const auto spad_addr_V_consume = (tile_k & 1) ? spad_addr_V1 : spad_addr_V0;
const auto spad_addr_V_produce = (tile_k & 1) ? spad_addr_V0 : spad_addr_V1;
const auto spad_addr_S_consume = (tile_k & 1) ? spad_addr_S1 : spad_addr_S0;
const auto spad_addr_S_produce = (tile_k & 1) ? spad_addr_S0 : spad_addr_S1;
const auto spad_addr_P_consume = (tile_k & 1) ? spad_addr_P1 : spad_addr_P0;
const auto spad_addr_P_produce = (tile_k & 1) ? spad_addr_P0 : spad_addr_P1;
const auto spad_addr_O = spad_addr_O0; // NOTE: there's only single O tile
asm volatile ("dbuf_sel_end_%=:" :: );
if (vx_warp_id() == 0 /* warp 0 in every core */) {
if (tile_k >= 2) // delay by 2 iters for pipelining
{
const uint32_t tile_k_ = tile_k - 2;
// GEMM II: O = O + P*V
// --------------------
// This is done *before* GEMM I in the software pipeline, working on the
// online softmax result tile from the previous iteration
asm volatile("gemm_pv_start_%=:" ::);
if (tid_in_warpgroup == 0) {
#if 0
if (tile_k_ == 0) {
gemmini_fence();
GEMMINI_CISC_CMD_I(0);
} else if (tile_k_ & 1) {
gemmini_fence();
GEMMINI_CISC_CMD_I(2);
} else {
gemmini_fence();
GEMMINI_CISC_CMD_I(1);
}
#else
// kickoff matmul
// among other things, this also configures CONFIG_BOUNDS so that the
// DMA knows the full matrix dimensions
// FIXME: perf: prevent GMEM->SMEM load for O tile
gemmini_fence();
gemmini_fence();
gemmini_fence();
gemmini_fence();
sp_tiled_matmul_full_spad_ws(
spad_addr_P_consume, spad_addr_V_consume,
/*spad_D=*/spad_addr_O, /*spad_C=*/spad_addr_O,
/*I=*/(B_ROW / DIM), /*J=*/(HEADDIM / DIM), /*K=*/(B_COL / DIM),
/*pad_I=*/0, /*pad_J=*/0, /*pad_K=*/0,
/*a_transpose=*/0, /*b_transpose=*/0, /*full_C=*/0, /*low_D=*/0,
/*acc=*/0, /*act=*/NO_ACTIVATION, /*skips=*/skips_matmul);
#endif
}
// // reconverge from mmio divergence
// threadblock_barrier(warpgroup_id_in_cluster,
// warps_per_warpgroup_per_core);
asm volatile("gemm_pv_finish_%=:" ::);
}
// GEMM I: S = Q*K
//
// kick off asynchronously; fence later
asm volatile("gemm_qk_start_%=:" ::);
if (tid_in_warpgroup == 0) {
// fence to GEMM II completion
gemmini_fence();
gemmini_fence();
gemmini_fence();
gemmini_fence();
#ifdef FENCE_GEMM_II
// signal that GEMM II is finished to O rescale step
*smem_O_flag = 1;
vx_fence();
#endif
// 0,2,.: opcode 0 (quartile 0/2, no accum)
// 1,3,.: opcode 3 (quartile 1/3, no accum)
// const uint32_t opcode = 3 * (tile_k & 1);
//GEMMINI_CISC_CMD_I(opcode);
sp_tiled_matmul_full_spad_ws(
spad_addr_Q, spad_addr_K_consume,
/*spad_D=*/0, /*spad_C=*/spad_addr_S_produce,
/*I=*/(B_ROW / DIM), /*J=*/(B_COL / DIM), /*K=*/(HEADDIM / DIM),
/*pad_I=*/0, /*pad_J=*/0, /*pad_K=*/0,
/*a_transpose=*/0, /*b_transpose=*/0, /*full_C=*/0, /*low_D=*/0,
/*acc=*/0, /*act=*/NO_ACTIVATION, /*skips=*/skips_matmul);
// gemmini_fence();
// gemmini_fence();
// gemmini_fence();
// gemmini_fence();
asm volatile("gemm_qk_finish_%=:" ::);
// data move for K and V
//
// Q stays in SMEM for the entire loop
asm volatile("move_k_v_start_%=:" ::);
// configure GMEM addresses for K and V tiles
// load K for the next iteration
const float *gmem_K_tile = gmem_K + (B_COL * (tile_k + 1 /*runahead*/));
// load V for the *previous* iteration; this will be consumed 2
// iterations later
const float *gmem_V_tile =
gmem_V + (HEADDIM * B_COL * (tile_k - 1 /*dragbehind*/));
#if 0
// fence mvout S to SMEM
gemmini_fence();
ROCC_INSTRUCTION_RS1_RS2(XCUSTOM_ACC, (uint64_t)(gmem_K_tile),
(uint64_t)(gmem_V_tile),
k_LOOP_WS_CONFIG_ADDRS_AB)
#endif
// configure address strides for the DMA
// FIXME: unnecessary?
GEMMINI_CISC_CMD_R((HEADDIM /*V*/ << 20) | (dim_seqlen /*KT*/ << 8) |
8 /*k_LOOP_WS_CONFIG_STRIDES_AB*/);
// gemmini_fence();
// do DMA
if (tile_k == 0) {
// we load (k-1)th tile for V; skip V for the 1st iteration,
// sp_tiled_matmul_full_spad_ws(
// spad_addr_K_produce, spad_addr_V_produce,
// /*spad_D=*/0, /*spad_C=*/0,
// /*I=*/(B_ROW / DIM), /*J=*/(HEADDIM / DIM), /*K=*/(B_COL / DIM),
// /*pad_I=*/0, /*pad_J=*/0, /*pad_K=*/0,
// /*a_transpose=*/0, /*b_transpose=*/0, /*full_C=*/0, /*low_D=*/0,
// /*acc=*/0, /*act=*/NO_ACTIVATION, /*skips=*/skips_only_a);
} else {
sp_tiled_matmul_full_spad_ws(
spad_addr_K_produce, spad_addr_V_produce,
/*spad_D=*/0, /*spad_C=*/0,
/*I=*/(B_ROW / DIM), /*J=*/(HEADDIM / DIM), /*K=*/(B_COL / DIM),
/*pad_I=*/0, /*pad_J=*/0, /*pad_K=*/0,
/*a_transpose=*/0, /*b_transpose=*/0, /*full_C=*/0, /*low_D=*/0,
/*acc=*/0, /*act=*/NO_ACTIVATION, /*skips=*/skips);
}
// fence everything before going to the next tile
gemmini_fence();
gemmini_fence();
gemmini_fence();
gemmini_fence();
}
// threadblock_barrier(warpgroup_id_in_cluster,
// warps_per_warpgroup_per_core);
asm volatile("move_k_v_finish_%=:" ::);
// NOTE: cannot put barrier here; thread 1-7 in warp 0 will skip the
// branch and call this barrier earlier than when thread 0 finishes.
// Since tmask is not considered, that will be a barrier resolve done too
// early
// threadblock_barrier(0, 1);
} else /* warp_id != 0 */ {
if (tile_k >= 1) // delay by 1 iters for pipelining
{
const uint32_t tile_k_ = tile_k - 1;
if constexpr (DEBUG) {
// verify S = Q*K before softmax
if (warpgroup_id == 0) {
if (tile_k_ == 0) {
thread_block_copy_tile<B_ROW, B_COL, GEMMINI_DMA>(
smem_S_consume, gmem_tmp_d0, tid_in_warpgroup_simt,
threads_per_warpgroup_simt, warpgroup_id_simt);
} else if (tile_k_ == 1) {
thread_block_copy_tile<B_ROW, B_COL, GEMMINI_DMA>(
smem_S_consume, gmem_tmp_d1, tid_in_warpgroup_simt,
threads_per_warpgroup_simt, warpgroup_id_simt);
}
threadblock_barrier(barrier_id_simt, barrier_count_simt);
}
}
// Online softmax
//
thread_block_online_softmax</*block_row_major=*/GEMMINI_DMA>(
smem_S_consume, smem_P_produce, tid_in_warpgroup_simt,
threads_per_warpgroup_simt, warpgroup_id_simt, smem_scratchpad,
smem_rowmax, smem_rowsum, smem_O_row_scale);
threadblock_barrier(barrier_id_simt, barrier_count_simt);
if constexpr (DEBUG) {
if (warpgroup_id == 0) {
if (tile_k_ == 0) {
thread_block_copy_rowmax(
smem_rowmax, gmem_tmp_e0, tid_in_warpgroup_simt,
threads_per_warpgroup_simt, warpgroup_id_simt);
thread_block_copy_rowmax(
smem_rowsum, gmem_tmp_e2, tid_in_warpgroup_simt,
threads_per_warpgroup_simt, warpgroup_id_simt);
} else if (tile_k_ == 1) {
thread_block_copy_rowmax(smem_rowmax, gmem_tmp_e1,
tid_in_warpgroup_simt, threads_per_warpgroup_simt,
warpgroup_id_simt);
thread_block_copy_rowmax(smem_rowsum, gmem_tmp_e3,
tid_in_warpgroup_simt, threads_per_warpgroup_simt,
warpgroup_id_simt);
}
threadblock_barrier(barrier_id_simt, barrier_count_simt);
}
}
#ifdef FENCE_GEMM_II
// check flag to make sure GEMM II finished and read-after-write
// dependency on O tile is settled for rescale
if (tid_in_warpgroup_simt == 0) {
while ((*smem_O_flag) != 1)
;
// set it back to 0 for the next tile iteration
*smem_O_flag = 0;
vx_fence();
}
#endif
#if 0
if (tid_in_warpgroup == 0) {
gemmini_fence();
gemmini_fence();
gemmini_fence();
gemmini_fence();
}
// reconverge from mmio divergence
threadblock_barrier(warpgroup_id_in_cluster,
warps_per_warpgroup_per_core);
#endif
if constexpr (DEBUG) {
if (warpgroup_id == 0) {
gemmini_fence();
gemmini_fence();
// O after PV
if (tile_k_ == 1 /*wait until GEMM II finshes */) {
thread_block_copy_tile<B_ROW, HEADDIM, GEMMINI_DMA>(
smem_O, gmem_tmp_d6, tid_in_warpgroup_simt, threads_per_warpgroup_simt,
warpgroup_id_simt);
} else if (tile_k_ == 2) {
thread_block_copy_tile<B_ROW, HEADDIM, GEMMINI_DMA>(
smem_O, gmem_tmp_d7, tid_in_warpgroup_simt, threads_per_warpgroup_simt,
warpgroup_id_simt);
}
threadblock_barrier(barrier_id_simt, barrier_count_simt);
}
}
// Oi rescale
thread_block_O_rescale</*block_row_major=*/GEMMINI_DMA>(
smem_O, smem_O /*in-place*/, smem_O_row_scale,
tid_in_warpgroup_simt, threads_per_warpgroup_simt,
warpgroup_id_simt);
// rescale-to-PV-GEMM barrier
threadblock_barrier(barrier_id_simt, barrier_count_simt);
if constexpr (DEBUG) {
if (warpgroup_id == 0) {
// O before PV
if (tile_k_ == 0) {
thread_block_copy_tile<B_ROW, B_COL, GEMMINI_DMA>(
smem_P_produce, gmem_tmp_d2, tid_in_warpgroup_simt,
threads_per_warpgroup_simt, warpgroup_id_simt);
thread_block_copy_tile<B_ROW, HEADDIM, GEMMINI_DMA>(
smem_O, gmem_tmp_d4, tid_in_warpgroup_simt,
threads_per_warpgroup_simt, warpgroup_id_simt);
} else if (tile_k_ == 1) {
thread_block_copy_tile<B_ROW, B_COL, GEMMINI_DMA>(
smem_P_produce, gmem_tmp_d3, tid_in_warpgroup_simt,
threads_per_warpgroup_simt, warpgroup_id_simt);
thread_block_copy_tile<B_ROW, HEADDIM, GEMMINI_DMA>(
smem_O, gmem_tmp_d5, tid_in_warpgroup_simt,
threads_per_warpgroup_simt, warpgroup_id_simt);
}
threadblock_barrier(barrier_id_simt, barrier_count_simt);
}
}
}
#if 0
// fence GEMM I after Oi rescale
if (tid_in_warpgroup == 0) {
gemmini_fence();
gemmini_fence();
gemmini_fence();
gemmini_fence();
}
// reconverge from mmio divergence
threadblock_barrier(warpgroup_id_in_cluster,
warps_per_warpgroup_per_core);
#endif
// intra-warpgroup barrier
threadblock_barrier(barrier_id_simt, barrier_count_simt);
}
}
asm volatile ("tile_loop_finish_%=:" :: );
}
int main() {
kernel_arg_t *arg = (kernel_arg_t *)KERNEL_ARG_DEV_MEM_ADDR;
const uint32_t hw_threads_per_cluster =
CORES_PER_CLUSTER * vx_num_threads() * vx_num_warps();
// fix to 1 threadblock per cluster
const uint32_t grid_size = hw_threads_per_cluster;
#ifdef RADIANCE
vx_spawn_tasks_cluster(grid_size, (vx_spawn_tasks_cb)kernel_body, arg);
#else
// NOTE: This kernel assumes contiguous thread scheduling for efficient shared
// memory allocation, and therefore does not work with original vx_spawn_tasks
vx_spawn_tasks_contiguous(grid_size, (vx_spawn_tasks_cb)kernel_body, arg);
#endif
return 0;
}

View File

@@ -141,8 +141,10 @@ void thread_block_matmul_gemmini(kernel_arg_t *__UNIFORM__ arg,
}
if (tile_k == 0) {
asm volatile("cisc_start_%=:" ::);
gemmini_fence();
GEMMINI_CISC_CMD_I(0);
asm volatile("cisc_end_%=:" ::);
} else if (tile_k & 1) {
gemmini_fence();
GEMMINI_CISC_CMD_I(2);
@@ -218,4 +220,4 @@ int main() {
vx_spawn_tasks_contiguous(grid_size, (vx_spawn_tasks_cb)kernel_body, arg);
#endif
return 0;
}
}

View File

@@ -7,7 +7,7 @@
#include "include/gemmini.h"
#include "gemmini_mmio.h"
constexpr bool DEBUG = true;
constexpr bool DEBUG = false;
template <uint32_t tile_dim_row, uint32_t tile_dim_col>
inline void thread_block_copy_tile(const float *src, float *dest,
@@ -90,19 +90,29 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
thread_block_gemm<float_type, threads_per_threadblock,
/*write_to_gmem=*/true,
/*smem_a_offset=*/0,
/*smem_a_dbuf_offset=*/0,
/*smem_b_offset=*/2 * BM * BK * sizeof(float),
/*smem_b_dbuf_offset=*/2 * BM * BK * sizeof(float)>(
(const float_type *)arg->addr_a, (const float_type *)arg->addr_b,
(float *)arg->addr_c, arg->dim_m, arg->dim_n, arg->dim_k,
tid_in_threadblock, threadblocks_per_cluster, threadblock_id_in_cluster,
sharedmem_per_threadblock);
#ifdef GEMMINI_DMA
/*smem_a_dbuf_offset=*/1 * 128 * 128 * sizeof(float_type),
/*smem_b_offset=*/2 * 128 * 128 * sizeof(float_type),
/*smem_b_dbuf_offset=*/3 * 128 * 128 * sizeof(float_type)
// FIXME: above offsets are hardcoded to agree with CISC
// spadQuartile
#else
/*smem_a_dbuf_offset=*/1 * BM * BK * sizeof(float_type),
/*smem_b_offset=*/2 * BM * BK * sizeof(float_type),
/*smem_b_dbuf_offset=*/(2 * BM * BK + BK * BN) * sizeof(float_type)
#endif
>((const float_type *)arg->addr_a,
(const float_type *)arg->addr_b, (float *)arg->addr_c,
arg->dim_m, arg->dim_n, arg->dim_k, tid_in_threadblock,
threadblocks_per_cluster, threadblock_id_in_cluster,
sharedmem_per_threadblock);
float *gmem_tmp_d0 = reinterpret_cast<float *>(0xd0000000UL);
float *gmem_tmp_d1 = reinterpret_cast<float *>(0xd1000000UL);
const float *smem_A = reinterpret_cast<float *>(sharedmem_per_threadblock);
const float *smem_B = smem_A + 2 * BM * BK;
const float *smem_B = reinterpret_cast<float *>(
sharedmem_per_threadblock + 2 * BM * BK * sizeof(float_type));
if constexpr (DEBUG) {
threadblock_barrier(threadblock_id_in_cluster,

View File

@@ -29,7 +29,7 @@ using float_type = float16_t;
// (BM*BN) / (TM*TN) == threadblock size >= NT * CORES_PER_CLUSTER
// * Combining BM * BK >= (BM*BN) / (TM*TN) == threadblock yields
// BM <= BK*TM*TN
#define BM 128
#define BM ((NUM_CORES == 8) ? 128 : 64)
#define BN 64
#if (FP_SIZE == 32)
#define BK 64
@@ -62,8 +62,8 @@ static_assert(WMITER * WNITER * TCM * TCN * NUM_WARPS * CORES_PER_CLUSTER ==
#define BK_LOOP 1
// Whether to transpose smem A tile at GMEM->SMEM (produce), or SMEM->RF
// (consume). This is because the tensor core expects the A tile to be stored
// in column-major order in SMEM, whereas it will be ultimately stored in
// row-major in the RF.
// in column-major order in SMEM, so a transpose is necessary if A was stored
// row-major in GMEM.
//
// For correctness, only one of either should be 1. E.g., PRODUCE 1 CONSUME 0
// generates the NN kernel where both A and B are stored row-major in GMEM.
@@ -72,8 +72,8 @@ static_assert(WMITER * WNITER * TCM * TCN * NUM_WARPS * CORES_PER_CLUSTER ==
#define TRANSPOSE_AT_PRODUCE 0
#define TRANSPOSE_AT_CONSUME 0
#define GEMMINI_DMA 0
#define GEMMINI_DMA_MN_MAJOR 1
#define GEMMINI_DMA 1
#define GEMMINI_DMA_FAST 1
#if SMEM_SIZE == 0x4000
#define SMEM_ADDR_Q0 ((float * const) 0xff000000)
#define SMEM_ADDR_Q1 ((float * const) 0xff001000)
@@ -101,6 +101,7 @@ static_assert(WMITER * WNITER * TCM * TCN * NUM_WARPS * CORES_PER_CLUSTER ==
enum class MemLayout {
MN_major,
K_major,
block_row_major, // Gemmini DMA
};
inline constexpr void map_operand_32lanes(const int tid, int &row, int &col) {
@@ -206,7 +207,7 @@ template <bool use_dma, uint32_t dim_col>
inline constexpr std::pair<uint32_t, uint32_t>
remap_to_gemmini_dma_layout(const uint32_t logical_row,
const uint32_t logical_col) {
static_assert(DIM == 8,
static_assert(!use_dma || DIM == 8,
"GEMMINI_DMA layout remapping code only written for DIM == 8");
if constexpr (use_dma) {
@@ -253,13 +254,11 @@ inline void wmma_load_a(volatile const T *smem_A, const int local_k,
constexpr int packed_factor = (std::is_same_v<T, float16_t> ? 2 : 1);
const int local_k_adjusted = local_k / packed_factor;
static_assert(!GEMMINI_DMA || (layout == MemLayout::K_major) ||
GEMMINI_DMA_MN_MAJOR,
"GEMMINI_DMA only supported for K-major A tile");
static_assert((layout != MemLayout::K_major) || (FP_SIZE == 32),
"fp16 is not really tested for K-major A layout");
if constexpr (layout == MemLayout::K_major) {
if constexpr (layout == MemLayout::K_major ||
layout == MemLayout::block_row_major) {
constexpr int smem_A_cols = leading_dim;
// f8-f15 stores a single row of A
@@ -269,8 +268,9 @@ inline void wmma_load_a(volatile const T *smem_A, const int local_k,
// if using Gemmini DMA, remap logical row/col to Gemmini's 2-level
// block-row-major layout
const auto [smem_row, smem_col] =
remap_to_gemmini_dma_layout<GEMMINI_DMA, smem_A_cols>(smem_logical_row,
smem_logical_col);
remap_to_gemmini_dma_layout<layout == MemLayout::block_row_major,
smem_A_cols>(smem_logical_row,
smem_logical_col);
const volatile uint8_t *smem_addr;
smem_addr = reinterpret_cast<const volatile uint8_t *>(
@@ -356,8 +356,9 @@ inline void wmma_load_b(const volatile T *smem_B, const int local_k,
const int thread_in_warp) {
asm volatile ("wmma_load_b_start_%=:" :: );
static_assert(layout == MemLayout::MN_major,
"only N-major layout for the B tile is supported");
static_assert(
layout == MemLayout::MN_major || layout == MemLayout::block_row_major,
"only N-major or block-row-major layout are supported for the B tile");
const int tid = thread_in_warp;
const int tg = tid / 4;
@@ -379,8 +380,9 @@ inline void wmma_load_b(const volatile T *smem_B, const int local_k,
// if using Gemmini DMA, remap logical row/col to Gemmini's 2-level
// block-row-major layout
const auto [smem_row, smem_col] =
remap_to_gemmini_dma_layout<GEMMINI_DMA, smem_B_cols>(smem_logical_row,
smem_logical_col);
remap_to_gemmini_dma_layout<layout == MemLayout::block_row_major,
smem_B_cols>(smem_logical_row,
smem_logical_col);
const volatile uint8_t *smem_addr;
smem_addr = reinterpret_cast<const volatile uint8_t *>(
@@ -388,10 +390,10 @@ inline void wmma_load_b(const volatile T *smem_B, const int local_k,
smem_B)[smem_B_cols * smem_row + smem_col]);
// f8-f15 stores a single column of B
// threads read from different columns; no bank conflicts
if constexpr (GEMMINI_DMA) {
// for GEMMINI_DMA, moving rows for the next 7 elements in the same column
// is the same as moving DIM elements forward in the memory because of the
// block-row-major layout
if constexpr (layout == MemLayout::block_row_major) {
// for the block-row-major layout, moving rows for the next 7 elements in
// the same column is the same as moving DIM elements forward in the memory
// because of the block-row-major layout
asm volatile("flw f8, %0(%1)" :: "i"(DIM * 0 * sizeof(float)), "r"(smem_addr));
asm volatile("flw f9, %0(%1)" :: "i"(DIM * 1 * sizeof(float)), "r"(smem_addr));
asm volatile("flw f10, %0(%1)" :: "i"(DIM * 2 * sizeof(float)), "r"(smem_addr));
@@ -533,7 +535,9 @@ wmma_store(const int thread_in_warp, const int warp_col, const int warp_row,
asm volatile ("wmma_store_finish_%=:" :: );
}
inline void threadblock_barrier(const uint32_t barrier_id, const uint32_t count) {
__attribute__((convergent)) inline void
threadblock_barrier(const uint32_t barrier_id, const uint32_t count) {
asm volatile("" ::: "memory");
vx_fence();
vx_barrier(barrier_id, count);
}
@@ -818,6 +822,10 @@ __attribute__((always_inline)) inline void thread_block_gemm_single_tile(
if (tid_in_threadblock == 0) {
gemmini_fence();
}
// reconverge after mmio
threadblock_barrier(threadblock_id_in_cluster,
warps_per_threadblock_per_core);
}
if constexpr (write_to_mem) {
@@ -845,9 +853,9 @@ template <
uint32_t smem_a_dbuf_offset = 0, // byte offset of A
// double-buffer tile in shared
// memory
uint32_t smem_b_offset = sizeof(float) * BM * BK, // byte offset of B tile
uint32_t smem_b_offset = sizeof(T) * BM * BK, // byte offset of B tile
// in shared memory
uint32_t smem_b_dbuf_offset = sizeof(float) * BM *
uint32_t smem_b_dbuf_offset = sizeof(T) * BM *
BK // byte offset of B double-buffer
// tile in shared memory
>
@@ -903,6 +911,8 @@ inline void thread_block_gemm(const T *A, const T *B, float *C,
for (uint32_t block_m = block_m_start; block_m < block_m_end; block_m++) {
#pragma GCC unroll 1
for (uint32_t block_n = 0; (block_n * BN) < dim_n; block_n++) {
asm volatile ("loop_mn_start_%=:" :: );
// clear out accumulators
initialize_accum_regs<0>();
initialize_accum_regs<1>();
@@ -911,14 +921,13 @@ inline void thread_block_gemm(const T *A, const T *B, float *C,
// pipeline initiation
if (tid_in_threadblock == 0) {
// configure dma gmem address to load from
// FIXME: block_k is wrong
ROCC_INSTRUCTION_RS1_RS2(
XCUSTOM_ACC,
(uint64_t)(A + block_m * BM * dim_k + /*block_k:*/0 * BK),
(uint64_t)(B + /*block_k:*/0 * BK * dim_n + block_n * BN),
k_LOOP_WS_CONFIG_ADDRS_AB)
// GEMMINI_CISC(8) does k_LOOP_WS_CONFIG_STRIDES_AB
GEMMINI_CISC_CMD_R((dim_n << 16) | (dim_k << 8) | 8);
GEMMINI_CISC_CMD_R((dim_n << 20) | (dim_k << 8) | 8);
gemmini_fence();
GEMMINI_CISC_CMD_I(10);
@@ -949,6 +958,7 @@ inline void thread_block_gemm(const T *A, const T *B, float *C,
#pragma GCC unroll 1
for (uint32_t block_k = 0; (block_k * BK) < dim_k; block_k++) {
asm volatile("loop_k_start_%=:" ::);
// producer code: GMEM->SMEM memory movement
// ---------------------------------------------------------------------
@@ -958,20 +968,19 @@ inline void thread_block_gemm(const T *A, const T *B, float *C,
#if (GEMMINI_DMA == 1)
if ((tid_in_threadblock == 0) && ((block_k * BK) != (dim_k - BK))) {
// configure dma gmem address to load from
// FIXME: block_k is wrong
ROCC_INSTRUCTION_RS1_RS2(
XCUSTOM_ACC,
(uint64_t)(A + block_m * BM * dim_k + (block_k + 1/*runahead*/) * BK),
(uint64_t)(B + (block_k + 1/*runahead*/) * BK * dim_n + block_n * BN),
k_LOOP_WS_CONFIG_ADDRS_AB)
// GEMMINI_CISC(8) does k_LOOP_WS_CONFIG_STRIDES_AB
GEMMINI_CISC_CMD_R((dim_n << 16) | (dim_k << 8) | 8);
// gemmini_fence();
GEMMINI_CISC_CMD_R((dim_n << 20) | (dim_k << 8) | 8);
gemmini_fence();
// block_k is even: opcode 11 (write to local_a_buf)
// block_k is odd: opcode 10 (write to local_a)
const uint32_t opcode = 11 - (block_k & 1);
GEMMINI_CISC_CMD_R(opcode);
GEMMINI_CISC_CMD_I(opcode);
// // TODO: branch is probably slow
// if (block_k & 1) {
// GEMMINI_CISC_CMD_I(12);
@@ -1017,6 +1026,10 @@ inline void thread_block_gemm(const T *A, const T *B, float *C,
/*acc=*/0, /*act=*/NO_ACTIVATION, /*skips=*/skips)
#endif
}
// reconverge after mmio divergence
threadblock_barrier(threadblock_id_in_cluster,
warps_per_threadblock_per_core);
#else
// move A
if constexpr (!TRANSPOSE_AT_PRODUCE) {
@@ -1038,10 +1051,10 @@ inline void thread_block_gemm(const T *A, const T *B, float *C,
warps_per_threadblock_per_core);
#endif
#if 0
// consumer code: SMEM->RF and compute
// ----------------------------------------------------------------------
// @perf: this loop spills to stack a lot because of all the flws in
asm volatile("dbuf_sel_start_%=:" ::);
const T *local_a_consume;
const T *local_b_consume;
if constexpr (GEMMINI_DMA) {
@@ -1056,17 +1069,29 @@ inline void thread_block_gemm(const T *A, const T *B, float *C,
// local_b_consume = reinterpret_cast<T *>(
// (mask_odd & reinterpret_cast<uintmax_t>(local_b_buf)) |
// (mask_even & reinterpret_cast<uintmax_t>(local_b)));
local_a_consume = local_a + (block_k & 1) * (BM * BK);
local_b_consume = local_b + (block_k & 1) * (BK * BN);
local_a_consume = local_a + (block_k & 1) *
(smem_a_dbuf_offset - smem_a_offset) /
sizeof(T);
local_b_consume = local_b + (block_k & 1) *
(smem_b_dbuf_offset - smem_b_offset) /
sizeof(T);
} else {
// no double-buffering without DMA
local_a_consume = local_a;
local_b_consume = local_b;
}
asm volatile("dbuf_sel_end_%=:" ::);
constexpr MemLayout layout_a =
TRANSPOSE_AT_CONSUME ? MemLayout::K_major : MemLayout::MN_major;
thread_block_gemm_single_tile<T, layout_a, MemLayout::MN_major,
GEMMINI_DMA ? (GEMMINI_DMA_FAST ? MemLayout::MN_major
: MemLayout::block_row_major)
: (TRANSPOSE_AT_CONSUME ? MemLayout::K_major
: MemLayout::MN_major);
constexpr MemLayout layout_b =
GEMMINI_DMA ? (GEMMINI_DMA_FAST ? MemLayout::MN_major
: MemLayout::block_row_major)
: MemLayout::MN_major;
thread_block_gemm_single_tile<T, layout_a, layout_b,
BM, BN, BK, 0, 0,
/*load_accum=*/false,
/*write_to_mem=*/false>(
@@ -1087,7 +1112,8 @@ inline void thread_block_gemm(const T *A, const T *B, float *C,
threadblock_barrier(threadblock_id_in_cluster,
warps_per_threadblock_per_core);
#endif
asm volatile("loop_k_end_%=:" ::);
}
if constexpr (write_to_gmem) {
@@ -1102,6 +1128,7 @@ inline void thread_block_gemm(const T *A, const T *B, float *C,
}
}
}
asm volatile("loop_mn_end_%=:" ::);
}
}