Merge branch 'kernels' of https://github.com/hansungk/vortex-private into kernels
This commit is contained in:
@@ -84,7 +84,7 @@
|
|||||||
#endif
|
#endif
|
||||||
|
|
||||||
#ifndef NUM_CORES
|
#ifndef NUM_CORES
|
||||||
#define NUM_CORES 8
|
#define NUM_CORES 4
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
#ifndef NUM_WARPS
|
#ifndef NUM_WARPS
|
||||||
|
|||||||
@@ -9,6 +9,8 @@
|
|||||||
// #define SMEM_SIZE 0x4000
|
// #define SMEM_SIZE 0x4000
|
||||||
// 64KB
|
// 64KB
|
||||||
// #define SMEM_SIZE 0x10000
|
// #define SMEM_SIZE 0x10000
|
||||||
|
// 128KB
|
||||||
|
// #define SMEM_SIZE 0x20000
|
||||||
// 256KB
|
// 256KB
|
||||||
#define SMEM_SIZE 0x40000
|
#define SMEM_SIZE 0x40000
|
||||||
|
|
||||||
|
|||||||
@@ -149,6 +149,7 @@ inline void vx_join(unsigned stack_ptr) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Warp Barrier
|
// Warp Barrier
|
||||||
|
__attribute__((convergent))
|
||||||
inline void vx_barrier(unsigned barried_id, unsigned num_warps) {
|
inline void vx_barrier(unsigned barried_id, unsigned num_warps) {
|
||||||
asm volatile (".insn r %0, 4, 0, x0, %1, %2" :: "i"(RISCV_CUSTOM0), "r"(barried_id), "r"(num_warps));
|
asm volatile (".insn r %0, 4, 0, x0, %1, %2" :: "i"(RISCV_CUSTOM0), "r"(barried_id), "r"(num_warps));
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -18,7 +18,7 @@
|
|||||||
#include <stdio.h>
|
#include <stdio.h>
|
||||||
|
|
||||||
#ifndef CORES_PER_CLUSTER
|
#ifndef CORES_PER_CLUSTER
|
||||||
#define CORES_PER_CLUSTER 8
|
#define CORES_PER_CLUSTER 4
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
#ifdef __cplusplus
|
#ifdef __cplusplus
|
||||||
|
|||||||
@@ -80,9 +80,13 @@ if __name__ == "__main__":
|
|||||||
fp16 = True
|
fp16 = True
|
||||||
if fp16:
|
if fp16:
|
||||||
A_packed = pack_fp16_by_row(A_array)
|
A_packed = pack_fp16_by_row(A_array)
|
||||||
|
A_swizzled = A_packed.reshape([-1, M * 2])
|
||||||
|
A_swizzled.astype('float16').tofile("input.a.row.bin")
|
||||||
AT_packed = A_packed.transpose([1, 0, 2])
|
AT_packed = A_packed.transpose([1, 0, 2])
|
||||||
AT_swizzled = AT_packed.reshape([-1, M * 2])
|
AT_swizzled = AT_packed.reshape([-1, M * 2])
|
||||||
AT_swizzled.astype('float16').tofile("input.a.col.bin")
|
AT_swizzled.astype('float16').tofile("input.a.col.bin")
|
||||||
|
print('A:')
|
||||||
|
print(A_swizzled)
|
||||||
print('AT:')
|
print('AT:')
|
||||||
print(AT_swizzled)
|
print(AT_swizzled)
|
||||||
B_packed = pack_fp16_by_column(B_array)
|
B_packed = pack_fp16_by_column(B_array)
|
||||||
|
|||||||
@@ -93,6 +93,23 @@ inline constexpr void map_c_8lanes(const int tid, int &row, int &col) {
|
|||||||
col += ((tid % 4) / 2) * 2;
|
col += ((tid % 4) / 2) * 2;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
inline constexpr void map_c_rowmajor_8lanes(const int tid, int &row, int &col) {
|
||||||
|
const int tg = tid / 4;
|
||||||
|
|
||||||
|
// A (row major)
|
||||||
|
// row 0~ 3: threadgroup 0
|
||||||
|
// row 4~ 7: threadgroup 1
|
||||||
|
row = tid % 4;
|
||||||
|
row += tg * 4;
|
||||||
|
|
||||||
|
// B (column major)
|
||||||
|
// col 0~ 3: threadgroup 0
|
||||||
|
// col 4~ 7: threadgroup 1
|
||||||
|
col = tid % 4;
|
||||||
|
col += tg * 4;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
void vx_wmma_load() {
|
void vx_wmma_load() {
|
||||||
int tid = vx_thread_id();
|
int tid = vx_thread_id();
|
||||||
int tg = tid / 4;
|
int tg = tid / 4;
|
||||||
@@ -174,11 +191,31 @@ void store_wmma_result() {
|
|||||||
int row = 0;
|
int row = 0;
|
||||||
int col = 0;
|
int col = 0;
|
||||||
|
|
||||||
map_c_8lanes(tid, row, col);
|
// map_c_8lanes(tid, row, col);
|
||||||
|
map_c_rowmajor_8lanes(tid, row, col);
|
||||||
|
|
||||||
// store C
|
// store C
|
||||||
float *const results_wid = results + (DIM_M * DIM_N * wid);
|
float *const results_wid = results + (DIM_M * DIM_N * wid);
|
||||||
// uncomment to have two accum buffers in rf
|
|
||||||
|
// asm volatile("fsw f16, %0" ::"m"(results_wid[DIM_N * 0 + col]));
|
||||||
|
// asm volatile("fsw f17, %0" ::"m"(results_wid[DIM_N * 1 + col]));
|
||||||
|
// asm volatile("fsw f18, %0" ::"m"(results_wid[DIM_N * 2 + col]));
|
||||||
|
// asm volatile("fsw f19, %0" ::"m"(results_wid[DIM_N * 3 + col]));
|
||||||
|
// asm volatile("fsw f20, %0" ::"m"(results_wid[DIM_N * 4 + col]));
|
||||||
|
// asm volatile("fsw f21, %0" ::"m"(results_wid[DIM_N * 5 + col]));
|
||||||
|
// asm volatile("fsw f22, %0" ::"m"(results_wid[DIM_N * 6 + col]));
|
||||||
|
// asm volatile("fsw f23, %0" ::"m"(results_wid[DIM_N * 7 + col]));
|
||||||
|
asm volatile("fsw f24, %0" ::"m"(results_wid[DIM_N * 0 + col]));
|
||||||
|
asm volatile("fsw f25, %0" ::"m"(results_wid[DIM_N * 1 + col]));
|
||||||
|
asm volatile("fsw f26, %0" ::"m"(results_wid[DIM_N * 2 + col]));
|
||||||
|
asm volatile("fsw f27, %0" ::"m"(results_wid[DIM_N * 3 + col]));
|
||||||
|
asm volatile("fsw f28, %0" ::"m"(results_wid[DIM_N * 4 + col]));
|
||||||
|
asm volatile("fsw f29, %0" ::"m"(results_wid[DIM_N * 5 + col]));
|
||||||
|
asm volatile("fsw f30, %0" ::"m"(results_wid[DIM_N * 6 + col]));
|
||||||
|
asm volatile("fsw f31, %0" ::"m"(results_wid[DIM_N * 7 + col]));
|
||||||
|
|
||||||
|
|
||||||
|
// 1x2 jagged mapping
|
||||||
// asm volatile("fsw f16, %0" ::"m"(results_wid[DIM_N * (row + 0) + (col + 0)]));
|
// asm volatile("fsw f16, %0" ::"m"(results_wid[DIM_N * (row + 0) + (col + 0)]));
|
||||||
// asm volatile("fsw f17, %0" ::"m"(results_wid[DIM_N * (row + 0) + (col + 1)]));
|
// asm volatile("fsw f17, %0" ::"m"(results_wid[DIM_N * (row + 0) + (col + 1)]));
|
||||||
// asm volatile("fsw f18, %0" ::"m"(results_wid[DIM_N * (row + 2) + (col + 0)]));
|
// asm volatile("fsw f18, %0" ::"m"(results_wid[DIM_N * (row + 2) + (col + 0)]));
|
||||||
@@ -187,14 +224,14 @@ void store_wmma_result() {
|
|||||||
// asm volatile("fsw f21, %0" ::"m"(results_wid[DIM_N * (row + 0) + (col + 5)]));
|
// asm volatile("fsw f21, %0" ::"m"(results_wid[DIM_N * (row + 0) + (col + 5)]));
|
||||||
// asm volatile("fsw f22, %0" ::"m"(results_wid[DIM_N * (row + 2) + (col + 4)]));
|
// asm volatile("fsw f22, %0" ::"m"(results_wid[DIM_N * (row + 2) + (col + 4)]));
|
||||||
// asm volatile("fsw f23, %0" ::"m"(results_wid[DIM_N * (row + 2) + (col + 5)]));
|
// asm volatile("fsw f23, %0" ::"m"(results_wid[DIM_N * (row + 2) + (col + 5)]));
|
||||||
asm volatile("fsw f24, %0" ::"m"(results_wid[DIM_N * (row + 0) + (col + 0)]));
|
// asm volatile("fsw f24, %0" ::"m"(results_wid[DIM_N * (row + 0) + (col + 0)]));
|
||||||
asm volatile("fsw f25, %0" ::"m"(results_wid[DIM_N * (row + 0) + (col + 1)]));
|
// asm volatile("fsw f25, %0" ::"m"(results_wid[DIM_N * (row + 0) + (col + 1)]));
|
||||||
asm volatile("fsw f26, %0" ::"m"(results_wid[DIM_N * (row + 2) + (col + 0)]));
|
// asm volatile("fsw f26, %0" ::"m"(results_wid[DIM_N * (row + 2) + (col + 0)]));
|
||||||
asm volatile("fsw f27, %0" ::"m"(results_wid[DIM_N * (row + 2) + (col + 1)]));
|
// asm volatile("fsw f27, %0" ::"m"(results_wid[DIM_N * (row + 2) + (col + 1)]));
|
||||||
asm volatile("fsw f28, %0" ::"m"(results_wid[DIM_N * (row + 0) + (col + 4)]));
|
// asm volatile("fsw f28, %0" ::"m"(results_wid[DIM_N * (row + 0) + (col + 4)]));
|
||||||
asm volatile("fsw f29, %0" ::"m"(results_wid[DIM_N * (row + 0) + (col + 5)]));
|
// asm volatile("fsw f29, %0" ::"m"(results_wid[DIM_N * (row + 0) + (col + 5)]));
|
||||||
asm volatile("fsw f30, %0" ::"m"(results_wid[DIM_N * (row + 2) + (col + 4)]));
|
// asm volatile("fsw f30, %0" ::"m"(results_wid[DIM_N * (row + 2) + (col + 4)]));
|
||||||
asm volatile("fsw f31, %0" ::"m"(results_wid[DIM_N * (row + 2) + (col + 5)]));
|
// asm volatile("fsw f31, %0" ::"m"(results_wid[DIM_N * (row + 2) + (col + 5)]));
|
||||||
}
|
}
|
||||||
|
|
||||||
void print_wmma_result() {
|
void print_wmma_result() {
|
||||||
|
|||||||
@@ -48,7 +48,7 @@ VX_CP = $(LLVM_VORTEX)/bin/llvm-objcopy
|
|||||||
#VX_DP = $(RISCV_TOOLCHAIN_PATH)/bin/$(RISCV_PREFIX)-objdump
|
#VX_DP = $(RISCV_TOOLCHAIN_PATH)/bin/$(RISCV_PREFIX)-objdump
|
||||||
#VX_CP = $(RISCV_TOOLCHAIN_PATH)/bin/$(RISCV_PREFIX)-objcopy
|
#VX_CP = $(RISCV_TOOLCHAIN_PATH)/bin/$(RISCV_PREFIX)-objcopy
|
||||||
|
|
||||||
VX_CFLAGS += -v -O3 -std=c++17
|
VX_CFLAGS += -v -Os -std=c++17
|
||||||
VX_CFLAGS += -mcmodel=medany -fno-rtti -fno-exceptions -nostartfiles -fdata-sections -ffunction-sections
|
VX_CFLAGS += -mcmodel=medany -fno-rtti -fno-exceptions -nostartfiles -fdata-sections -ffunction-sections
|
||||||
# comment out below for regression/basic, which uses GCC that doesn't
|
# comment out below for regression/basic, which uses GCC that doesn't
|
||||||
# understand these flags
|
# understand these flags
|
||||||
|
|||||||
@@ -2,8 +2,8 @@ PROJECT = flash_attention
|
|||||||
|
|
||||||
SRCS = main.cpp common.h
|
SRCS = main.cpp common.h
|
||||||
|
|
||||||
VX_SRCS = kernel.cpp
|
VX_SRCS = kernel.gemmini.cpp
|
||||||
VX_INCLUDES = ../sgemm_tcore/sgemm_impl.hpp
|
VX_INCLUDES = flash_impl.hpp ../sgemm_tcore/sgemm_impl.hpp
|
||||||
|
|
||||||
OPTS ?= -n16
|
OPTS ?= -n16
|
||||||
|
|
||||||
|
|||||||
559
tests/regression/flash_attention/flash_impl.hpp
Normal file
559
tests/regression/flash_attention/flash_impl.hpp
Normal file
@@ -0,0 +1,559 @@
|
|||||||
|
#ifndef _FLASH_IMPL_H_
|
||||||
|
#define _FLASH_IMPL_H_
|
||||||
|
|
||||||
|
#include <vx_spawn.h>
|
||||||
|
#include <float.h>
|
||||||
|
|
||||||
|
#define B_ROW 64
|
||||||
|
#define B_COL 64
|
||||||
|
#define HEADDIM 64
|
||||||
|
|
||||||
|
#define ROW_REMAINDER_LOGIC
|
||||||
|
|
||||||
|
constexpr uint32_t ROWMAX_SETS = 3;
|
||||||
|
constexpr bool WARP_SPECIALIZED = true;
|
||||||
|
constexpr bool TENSOR_CORE = true;
|
||||||
|
|
||||||
|
// temporary safety stop for wrong configs
|
||||||
|
static_assert(NUM_CORES == 4);
|
||||||
|
static_assert(NUM_THREADS == 8);
|
||||||
|
static_assert(NUM_WARPS == 8);
|
||||||
|
|
||||||
|
inline void thread_block_init_sharedmem(const uint32_t tid_in_threadblock,
|
||||||
|
const uint32_t threads_per_threadblock,
|
||||||
|
float *smem_O, float *smem_rowmax,
|
||||||
|
float *smem_rowsum,
|
||||||
|
float *smem_O_row_scale) {
|
||||||
|
asm volatile("threadblock_init_sharedmem_start_%=:" ::);
|
||||||
|
|
||||||
|
const uint32_t tid_in_warp = tid_in_threadblock % NUM_THREADS;
|
||||||
|
const uint32_t warp_id = tid_in_threadblock / NUM_THREADS;
|
||||||
|
const uint32_t warps_in_threadblock = threads_per_threadblock / NUM_THREADS;
|
||||||
|
|
||||||
|
static_assert((B_ROW % NUM_THREADS) == 0,
|
||||||
|
"B_ROW must be a multiple of NUM_THREADS");
|
||||||
|
static_assert(B_ROW < (NUM_THREADS * CORES_PER_CLUSTER *
|
||||||
|
(NUM_WARPS / (WARP_SPECIALIZED ? 2 : 1))),
|
||||||
|
"not enough warps to initialize rowmax/rowsum");
|
||||||
|
|
||||||
|
// each thread initializes one element in rowmax/rowsum
|
||||||
|
// multiple warps participate for the whole vector
|
||||||
|
constexpr uint32_t needed_warps = B_ROW / NUM_THREADS;
|
||||||
|
if (warp_id < needed_warps /* more warps in HW than needed? */) {
|
||||||
|
uint32_t offset = NUM_THREADS * warp_id + tid_in_warp;
|
||||||
|
#pragma GCC unroll
|
||||||
|
for (int i = 0; i < ROWMAX_SETS; i++) {
|
||||||
|
smem_rowmax[offset + i * ROWMAX_SETS] = FLT_MIN;
|
||||||
|
}
|
||||||
|
smem_rowsum[offset] = 0.0f;
|
||||||
|
smem_O_row_scale[offset] = 0.0f;
|
||||||
|
}
|
||||||
|
|
||||||
|
// each warp clears out a row of smem_O
|
||||||
|
// FIXME: dedup this pattern
|
||||||
|
#pragma GCC unroll 1
|
||||||
|
for (int row_offset = 0; row_offset < B_COL;
|
||||||
|
row_offset += warps_in_threadblock) {
|
||||||
|
const uint32_t row = row_offset + warp_id;
|
||||||
|
#ifdef ROW_REMAINDER_LOGIC
|
||||||
|
if (row >= B_ROW) {
|
||||||
|
// WARNING: the number of barrier calls have to exactly match that in the
|
||||||
|
// outside of the branch to prevent stalls!! FIXME better proof this.
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
#endif
|
||||||
|
|
||||||
|
uint32_t thread_offset = HEADDIM * row + tid_in_warp;
|
||||||
|
constexpr uint32_t per_row_iter = HEADDIM / NUM_THREADS;
|
||||||
|
const float one = 0.0f;
|
||||||
|
#pragma GCC unroll
|
||||||
|
for (int i = 0; i < per_row_iter; i++) {
|
||||||
|
smem_O[thread_offset] = 0.0f;
|
||||||
|
thread_offset += NUM_THREADS;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
asm volatile("threadblock_init_sharedmem_finish_%=:" ::);
|
||||||
|
}
|
||||||
|
|
||||||
|
inline void thread_block_copy_rowmax(const float *src, float *dest,
|
||||||
|
const uint32_t tid_in_threadblock,
|
||||||
|
const uint32_t threads_per_threadblock,
|
||||||
|
const uint32_t threadblock_id_in_cluster) {
|
||||||
|
asm volatile("threadblock_copy_rowmax_start_%=:" ::);
|
||||||
|
|
||||||
|
const uint32_t tid_in_warp = tid_in_threadblock % NUM_THREADS;
|
||||||
|
const uint32_t warp_id = tid_in_threadblock / NUM_THREADS;
|
||||||
|
const uint32_t warps_in_threadblock = threads_per_threadblock / NUM_THREADS;
|
||||||
|
const uint32_t warps_per_threadblock_per_core =
|
||||||
|
warps_in_threadblock / CORES_PER_CLUSTER;
|
||||||
|
|
||||||
|
// each thread copies one element in rowmax
|
||||||
|
// multiple warps participate for the whole vector
|
||||||
|
constexpr uint32_t num_warps = B_ROW / NUM_THREADS;
|
||||||
|
if (warp_id < num_warps) {
|
||||||
|
uint32_t offset = NUM_THREADS * warp_id + tid_in_warp;
|
||||||
|
dest[offset] = src[offset];
|
||||||
|
}
|
||||||
|
|
||||||
|
if constexpr (!TENSOR_CORE) {
|
||||||
|
threadblock_barrier(1, 7);
|
||||||
|
} else {
|
||||||
|
threadblock_barrier(threadblock_id_in_cluster,
|
||||||
|
warps_per_threadblock_per_core);
|
||||||
|
}
|
||||||
|
|
||||||
|
asm volatile("threadblock_copy_rowmax_finish_%=:" ::);
|
||||||
|
}
|
||||||
|
|
||||||
|
template <uint32_t dim_row, uint32_t dim_col, bool block_row_major = false>
|
||||||
|
inline void thread_block_copy_tile(const float *src, float *dest,
|
||||||
|
const uint32_t tid_in_threadblock,
|
||||||
|
const uint32_t threads_per_threadblock,
|
||||||
|
const uint32_t threadblock_id_in_cluster) {
|
||||||
|
asm volatile("threadblock_copy_tile_start_%=:" ::);
|
||||||
|
|
||||||
|
const uint32_t tid_in_warp = tid_in_threadblock % NUM_THREADS;
|
||||||
|
const uint32_t warp_id = tid_in_threadblock / NUM_THREADS;
|
||||||
|
const uint32_t warps_in_threadblock = threads_per_threadblock / NUM_THREADS;
|
||||||
|
const uint32_t warps_per_threadblock_per_core =
|
||||||
|
warps_in_threadblock / CORES_PER_CLUSTER;
|
||||||
|
|
||||||
|
// FIXME: dedup this pattern
|
||||||
|
#pragma GCC unroll 1
|
||||||
|
for (int row_offset = 0; row_offset < dim_row;
|
||||||
|
row_offset += warps_in_threadblock) {
|
||||||
|
const uint32_t row = row_offset + warp_id;
|
||||||
|
#ifdef ROW_REMAINDER_LOGIC
|
||||||
|
if (row >= B_ROW) {
|
||||||
|
// WARNING: the number of barrier calls have to exactly match that in the
|
||||||
|
// outside of the branch to prevent stalls!! FIXME better proof this.
|
||||||
|
if constexpr (!TENSOR_CORE) {
|
||||||
|
threadblock_barrier(1, 7);
|
||||||
|
} else {
|
||||||
|
threadblock_barrier(threadblock_id_in_cluster,
|
||||||
|
warps_per_threadblock_per_core);
|
||||||
|
}
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
#endif
|
||||||
|
|
||||||
|
constexpr uint32_t per_row_iter = dim_col / NUM_THREADS;
|
||||||
|
#pragma GCC unroll
|
||||||
|
for (int i = 0; i < per_row_iter; i++) {
|
||||||
|
const uint32_t col_offset = NUM_THREADS * i;
|
||||||
|
const uint32_t col = col_offset + tid_in_warp;
|
||||||
|
const auto [smem_row, smem_col] =
|
||||||
|
remap_to_gemmini_dma_layout<block_row_major, B_COL>(row, col);
|
||||||
|
const uint32_t smem_offset = B_COL * smem_row + smem_col;
|
||||||
|
const uint32_t gmem_offset = B_COL * row + col;
|
||||||
|
|
||||||
|
dest[gmem_offset] = src[smem_offset];
|
||||||
|
}
|
||||||
|
|
||||||
|
if constexpr (!TENSOR_CORE) {
|
||||||
|
threadblock_barrier(1, 7);
|
||||||
|
} else {
|
||||||
|
threadblock_barrier(threadblock_id_in_cluster,
|
||||||
|
warps_per_threadblock_per_core);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
asm volatile("threadblock_copy_tile_finish_%=:" ::);
|
||||||
|
}
|
||||||
|
|
||||||
|
template <int order>
|
||||||
|
inline float exponential_taylor_term(const float x) {
|
||||||
|
asm volatile("exponential_taylor_term_start_%=:" ::);
|
||||||
|
|
||||||
|
float res = 1.0f;
|
||||||
|
|
||||||
|
if constexpr (order == 1) {
|
||||||
|
res = x;
|
||||||
|
} else if constexpr (order == 2) {
|
||||||
|
res = x * x;
|
||||||
|
res /= 2.0f;
|
||||||
|
} else if constexpr (order == 3) {
|
||||||
|
res = x * x * x;
|
||||||
|
res /= 6.0f;
|
||||||
|
}
|
||||||
|
|
||||||
|
asm volatile("exponential_taylor_term_end_%=:" ::);
|
||||||
|
return res;
|
||||||
|
}
|
||||||
|
|
||||||
|
template <bool block_row_major = false>
|
||||||
|
__attribute__((always_inline)) inline void thread_block_online_softmax(
|
||||||
|
const float *smem_S, float *smem_P, const uint32_t tid_in_threadblock,
|
||||||
|
const uint32_t threads_per_threadblock,
|
||||||
|
const uint32_t threadblock_id_in_cluster, float *smem_scratchpad,
|
||||||
|
float *smem_rowmax, float *smem_rowsum, float *smem_O_row_scale) {
|
||||||
|
asm volatile("thread_block_online_softmax_start_%=:" ::);
|
||||||
|
|
||||||
|
const uint32_t tid_in_warp = tid_in_threadblock % NUM_THREADS;
|
||||||
|
const uint32_t warp_id = tid_in_threadblock / NUM_THREADS;
|
||||||
|
const uint32_t warps_in_threadblock = threads_per_threadblock / NUM_THREADS;
|
||||||
|
const uint32_t warps_per_threadblock_per_core =
|
||||||
|
warps_in_threadblock / CORES_PER_CLUSTER;
|
||||||
|
|
||||||
|
float *smem_rowmax_this = smem_rowmax + B_ROW;
|
||||||
|
|
||||||
|
#pragma GCC unroll 1
|
||||||
|
for (int row_offset = 0; row_offset < B_ROW;
|
||||||
|
row_offset += warps_in_threadblock) {
|
||||||
|
const uint32_t row = row_offset + warp_id;
|
||||||
|
#ifdef ROW_REMAINDER_LOGIC
|
||||||
|
// if the number of warps doesn't exactly divide the number of rows,
|
||||||
|
// early-exit to prevent out-of-bounds access
|
||||||
|
if (row >= B_ROW) {
|
||||||
|
// WARNING: the number of barrier calls have to exactly match that in the
|
||||||
|
// outside of the branch to prevent stalls!! FIXME better proof this.
|
||||||
|
if constexpr (!TENSOR_CORE) {
|
||||||
|
threadblock_barrier(1, 7);
|
||||||
|
threadblock_barrier(1, 7);
|
||||||
|
threadblock_barrier(1, 7);
|
||||||
|
threadblock_barrier(1, 7);
|
||||||
|
threadblock_barrier(1, 7);
|
||||||
|
threadblock_barrier(1, 7);
|
||||||
|
} else {
|
||||||
|
threadblock_barrier(threadblock_id_in_cluster,
|
||||||
|
warps_per_threadblock_per_core);
|
||||||
|
threadblock_barrier(threadblock_id_in_cluster,
|
||||||
|
warps_per_threadblock_per_core);
|
||||||
|
threadblock_barrier(threadblock_id_in_cluster,
|
||||||
|
warps_per_threadblock_per_core);
|
||||||
|
threadblock_barrier(threadblock_id_in_cluster,
|
||||||
|
warps_per_threadblock_per_core);
|
||||||
|
threadblock_barrier(threadblock_id_in_cluster,
|
||||||
|
warps_per_threadblock_per_core);
|
||||||
|
threadblock_barrier(threadblock_id_in_cluster,
|
||||||
|
warps_per_threadblock_per_core);
|
||||||
|
}
|
||||||
|
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
#endif
|
||||||
|
const uint32_t first_thread_offset = B_COL * row;
|
||||||
|
|
||||||
|
// rowmax
|
||||||
|
//
|
||||||
|
// two-level tree reduction: reduce each row into NUM_THREADS intermediate
|
||||||
|
// maxes, then reduce it down to one row max
|
||||||
|
// one warp handles one row in tile
|
||||||
|
|
||||||
|
constexpr uint32_t per_row_iter = B_COL / NUM_THREADS;
|
||||||
|
// FIXME: threadblock_id needs to be in here too
|
||||||
|
float *warp_smem = smem_scratchpad + (warp_id * NUM_THREADS);
|
||||||
|
|
||||||
|
// #define DUMB_ROWMAX
|
||||||
|
#ifdef DUMB_ROWMAX
|
||||||
|
// FIXME remove
|
||||||
|
threadblock_barrier(threadblock_id_in_cluster,
|
||||||
|
warps_per_threadblock_per_core);
|
||||||
|
|
||||||
|
// no tree reduction; a single thread in a warp does serialized max across
|
||||||
|
// the entire row
|
||||||
|
if (tid_in_warp == 0) {
|
||||||
|
float rowmax = smem_S[first_thread_offset];
|
||||||
|
#pragma GCC unroll 16
|
||||||
|
for (int i = 0; i < B_COL; i++) {
|
||||||
|
asm volatile("fmax.s %0, %1, %2"
|
||||||
|
: "=f"(rowmax)
|
||||||
|
: "f"(rowmax), "f"(smem_S[first_thread_offset + i]));
|
||||||
|
}
|
||||||
|
smem_rowmax_this[row] = rowmax;
|
||||||
|
|
||||||
|
// update previous rowmax
|
||||||
|
// i.e. mi_new = max(mi, mij)
|
||||||
|
float prev_rowmax = smem_rowmax[row];
|
||||||
|
// stage prev rowmax in scratchpad for warp-wide broadcast
|
||||||
|
warp_smem[0] = prev_rowmax;
|
||||||
|
asm volatile("fmax.s %0, %1, %2"
|
||||||
|
: "=f"(rowmax)
|
||||||
|
: "f"(rowmax), "f"(prev_rowmax));
|
||||||
|
smem_rowmax[row] = rowmax;
|
||||||
|
}
|
||||||
|
|
||||||
|
#else
|
||||||
|
static_assert((B_COL % NUM_THREADS) == 0,
|
||||||
|
"B_COL must be a multiple of NUM_THREADS");
|
||||||
|
float per_thread_max = FLT_MIN;
|
||||||
|
#pragma GCC unroll
|
||||||
|
for (int i = 0; i < per_row_iter; i++) {
|
||||||
|
const uint32_t col_offset = NUM_THREADS * i;
|
||||||
|
const uint32_t col = col_offset + tid_in_warp;
|
||||||
|
const auto [smem_row, smem_col] =
|
||||||
|
remap_to_gemmini_dma_layout<block_row_major, B_COL>(row, col);
|
||||||
|
const uint32_t offset = B_COL * smem_row + smem_col;
|
||||||
|
|
||||||
|
const float next = smem_S[offset];
|
||||||
|
asm volatile("fmax.s %0, %1, %2"
|
||||||
|
: "=f"(per_thread_max)
|
||||||
|
: "f"(per_thread_max), "f"(next));
|
||||||
|
}
|
||||||
|
// stage per-thread max value in smem
|
||||||
|
warp_smem[tid_in_warp] = per_thread_max;
|
||||||
|
|
||||||
|
// sync writes to warp_smem
|
||||||
|
if constexpr (!TENSOR_CORE) {
|
||||||
|
threadblock_barrier(1, 7);
|
||||||
|
} else {
|
||||||
|
threadblock_barrier(threadblock_id_in_cluster,
|
||||||
|
warps_per_threadblock_per_core);
|
||||||
|
}
|
||||||
|
|
||||||
|
// #define PARALLEL_ROWMAX
|
||||||
|
#ifndef PARALLEL_ROWMAX
|
||||||
|
// elect 0-th thread to reduce all other thread's values in the warp
|
||||||
|
if (tid_in_warp == 0) {
|
||||||
|
float rowmax = per_thread_max;
|
||||||
|
for (int i = 1; i < NUM_THREADS; i++) {
|
||||||
|
float other = warp_smem[i];
|
||||||
|
asm volatile("fmax.s %0, %1, %2"
|
||||||
|
: "=f"(rowmax)
|
||||||
|
: "f"(rowmax), "f"(other));
|
||||||
|
}
|
||||||
|
smem_rowmax_this[row] = rowmax;
|
||||||
|
|
||||||
|
// update previous rowmax
|
||||||
|
// i.e. mi_new = max(mi, mij)
|
||||||
|
float prev_rowmax = smem_rowmax[row];
|
||||||
|
// stage prev rowmax in scratchpad for warp-wide broadcast
|
||||||
|
warp_smem[0] = prev_rowmax;
|
||||||
|
asm volatile("fmax.s %0, %1, %2"
|
||||||
|
: "=f"(rowmax)
|
||||||
|
: "f"(rowmax), "f"(prev_rowmax));
|
||||||
|
smem_rowmax[row] = rowmax;
|
||||||
|
}
|
||||||
|
#else
|
||||||
|
if (warp_id < warps_in_threadblock / NUM_THREADS) {
|
||||||
|
const uint32_t row = row_offset + NUM_THREADS * warp_id + tid_in_warp;
|
||||||
|
float *const thread_smem = smem_scratchpad + (tid_in_warp * NUM_THREADS);
|
||||||
|
float rowmax = FLT_MIN;
|
||||||
|
#pragma GCC unroll
|
||||||
|
for (int i = 0; i < NUM_THREADS; i++) {
|
||||||
|
const float f = thread_smem[i];
|
||||||
|
asm volatile("fmax.s %0, %1, %2" : "=f"(rowmax) : "f"(rowmax), "f"(f));
|
||||||
|
}
|
||||||
|
smem_rowmax_this[row] = rowmax;
|
||||||
|
|
||||||
|
// update previous rowmax
|
||||||
|
// i.e. mi_new = max(mi, mij)
|
||||||
|
float prev_rowmax = smem_rowmax[row];
|
||||||
|
// stage prev rowmax in scratchpad for warp-wide broadcast
|
||||||
|
thread_smem[0] = prev_rowmax;
|
||||||
|
asm volatile("fmax.s %0, %1, %2"
|
||||||
|
: "=f"(rowmax)
|
||||||
|
: "f"(rowmax), "f"(prev_rowmax));
|
||||||
|
smem_rowmax[row] = rowmax;
|
||||||
|
}
|
||||||
|
#endif // PARALLEL_ROWMAX
|
||||||
|
#endif // DUMB_ROWMAX
|
||||||
|
|
||||||
|
if constexpr (!TENSOR_CORE) {
|
||||||
|
threadblock_barrier(1, 7);
|
||||||
|
} else {
|
||||||
|
threadblock_barrier(threadblock_id_in_cluster,
|
||||||
|
warps_per_threadblock_per_core);
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
// broadcast prev rowmax to all threads in the warp
|
||||||
|
// NOTE: memory consistency is a little sketchy here
|
||||||
|
const float rowmax_prev = warp_smem[0];
|
||||||
|
const float rowmax_this = smem_rowmax_this[row];
|
||||||
|
|
||||||
|
// exponential
|
||||||
|
//
|
||||||
|
// B_ROW / (B_ROW * B_COL / (exp_elem * threads_per_threadblock))
|
||||||
|
// const uint32_t row_stride =
|
||||||
|
// (exp_elem_per_thread * threads_per_threadblock) / B_COL;
|
||||||
|
|
||||||
|
// broadcast updated rowmax to all threads in the warp
|
||||||
|
const float rowmax_new = smem_rowmax[row];
|
||||||
|
|
||||||
|
asm volatile("flashattn_exp_p_start_%=:" ::);
|
||||||
|
|
||||||
|
#pragma GCC unroll
|
||||||
|
for (int i = 0; i < per_row_iter; i++) {
|
||||||
|
const uint32_t col_offset = NUM_THREADS * i;
|
||||||
|
const uint32_t col = col_offset + tid_in_warp;
|
||||||
|
const auto [smem_row, smem_col] =
|
||||||
|
remap_to_gemmini_dma_layout<block_row_major, B_COL>(row, col);
|
||||||
|
const uint32_t offset = B_COL * smem_row + smem_col;
|
||||||
|
|
||||||
|
float f0 = smem_S[offset];
|
||||||
|
|
||||||
|
f0 -= rowmax_new;
|
||||||
|
|
||||||
|
// 2nd-order Taylor approximation
|
||||||
|
float exp = 1.0f;
|
||||||
|
exp += exponential_taylor_term<1>(f0);
|
||||||
|
exp += exponential_taylor_term<2>(f0);
|
||||||
|
|
||||||
|
// Store S transposed to the shared memory
|
||||||
|
|
||||||
|
smem_P[offset] = exp;
|
||||||
|
}
|
||||||
|
|
||||||
|
asm volatile("flashattn_exp_p_end_%=:" ::);
|
||||||
|
|
||||||
|
if constexpr (!TENSOR_CORE) {
|
||||||
|
threadblock_barrier(1, 7);
|
||||||
|
} else {
|
||||||
|
threadblock_barrier(threadblock_id_in_cluster,
|
||||||
|
warps_per_threadblock_per_core);
|
||||||
|
}
|
||||||
|
|
||||||
|
// rowsum
|
||||||
|
//
|
||||||
|
// two-level tree reduction, similar to rowmax
|
||||||
|
|
||||||
|
asm volatile("flashattn_rowsum_start_%=:" ::);
|
||||||
|
|
||||||
|
float per_thread_sum = 0.0f;
|
||||||
|
|
||||||
|
#pragma GCC unroll
|
||||||
|
for (int i = 0; i < per_row_iter; i++) {
|
||||||
|
const uint32_t col_offset = NUM_THREADS * i;
|
||||||
|
const uint32_t col = col_offset + tid_in_warp;
|
||||||
|
const auto [smem_row, smem_col] =
|
||||||
|
remap_to_gemmini_dma_layout<block_row_major, B_COL>(row, col);
|
||||||
|
const uint32_t offset = B_COL * smem_row + smem_col;
|
||||||
|
|
||||||
|
per_thread_sum += smem_P[offset];
|
||||||
|
}
|
||||||
|
// stage per-thread sum value in smem
|
||||||
|
// FIXME: threadblock_id needs to be in here too
|
||||||
|
warp_smem = smem_scratchpad + (warp_id * NUM_THREADS);
|
||||||
|
warp_smem[tid_in_warp] = per_thread_sum;
|
||||||
|
|
||||||
|
// sync writes to warp_smem
|
||||||
|
if constexpr (!TENSOR_CORE) {
|
||||||
|
threadblock_barrier(1, 7);
|
||||||
|
} else {
|
||||||
|
threadblock_barrier(threadblock_id_in_cluster,
|
||||||
|
warps_per_threadblock_per_core);
|
||||||
|
}
|
||||||
|
|
||||||
|
// 0-th thread collects all other thread's values in the warp
|
||||||
|
if (tid_in_warp == 0) {
|
||||||
|
float rowsum = per_thread_sum;
|
||||||
|
for (int iter = 1; iter < NUM_THREADS; iter++) {
|
||||||
|
float other = warp_smem[iter];
|
||||||
|
rowsum += other;
|
||||||
|
}
|
||||||
|
|
||||||
|
const float mi_prev = rowmax_prev;
|
||||||
|
const float mi_this = rowmax_this;
|
||||||
|
|
||||||
|
const float x = mi_prev - mi_this;
|
||||||
|
// 2nd-order Taylor approximation
|
||||||
|
float exp = 1.0f;
|
||||||
|
exp += exponential_taylor_term<1>(x);
|
||||||
|
exp += exponential_taylor_term<2>(x);
|
||||||
|
|
||||||
|
// update rowsum
|
||||||
|
const float rowsum_prev = smem_rowsum[row];
|
||||||
|
float rowsum_new = exp * rowsum_prev + rowsum;
|
||||||
|
|
||||||
|
smem_rowsum[row] = rowsum_new;
|
||||||
|
}
|
||||||
|
|
||||||
|
asm volatile("flashattn_rowsum_end_%=:" ::);
|
||||||
|
|
||||||
|
if constexpr (!TENSOR_CORE) {
|
||||||
|
threadblock_barrier(1, 7);
|
||||||
|
} else {
|
||||||
|
threadblock_barrier(threadblock_id_in_cluster,
|
||||||
|
warps_per_threadblock_per_core);
|
||||||
|
}
|
||||||
|
|
||||||
|
// compute Oi rescale factor
|
||||||
|
// FIXME: parallelize this across threads
|
||||||
|
//
|
||||||
|
asm volatile("flashattn_rescale_factor_start_%=:" ::);
|
||||||
|
|
||||||
|
#pragma GCC unroll
|
||||||
|
for (int i = 0; i < per_row_iter; i++) {
|
||||||
|
const float mi_prev = rowmax_prev;
|
||||||
|
const float mi_new = rowmax_new;
|
||||||
|
|
||||||
|
const float x = mi_prev - mi_new;
|
||||||
|
// 2nd-order Taylor approximation
|
||||||
|
float exp = 1.0f;
|
||||||
|
exp += exponential_taylor_term<1>(x);
|
||||||
|
exp += exponential_taylor_term<2>(x);
|
||||||
|
|
||||||
|
// @perf: div vs. expansion on e(-x)?
|
||||||
|
smem_O_row_scale[row] = 1.0f / exp;
|
||||||
|
}
|
||||||
|
|
||||||
|
asm volatile("flashattn_rescale_factor_end_%=:" ::);
|
||||||
|
|
||||||
|
if constexpr (!TENSOR_CORE) {
|
||||||
|
threadblock_barrier(1, 7);
|
||||||
|
} else {
|
||||||
|
threadblock_barrier(threadblock_id_in_cluster,
|
||||||
|
warps_per_threadblock_per_core);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
asm volatile("thread_block_online_softmax_finish_%=:" ::);
|
||||||
|
}
|
||||||
|
|
||||||
|
template <bool block_row_major = false>
|
||||||
|
__attribute__((always_inline)) inline void thread_block_O_rescale(
|
||||||
|
const float *smem_O_in, float *smem_O_out, const float *smem_O_row_scale,
|
||||||
|
const uint32_t tid_in_threadblock, const uint32_t threads_per_threadblock,
|
||||||
|
const uint32_t threadblock_id_in_cluster) {
|
||||||
|
asm volatile("thread_block_O_rescale_start_%=:" ::);
|
||||||
|
|
||||||
|
const uint32_t tid_in_warp = tid_in_threadblock % NUM_THREADS;
|
||||||
|
const uint32_t warp_id = tid_in_threadblock / NUM_THREADS;
|
||||||
|
const uint32_t warps_in_threadblock = threads_per_threadblock / NUM_THREADS;
|
||||||
|
const uint32_t warps_per_threadblock_per_core =
|
||||||
|
warps_in_threadblock / CORES_PER_CLUSTER;
|
||||||
|
|
||||||
|
#pragma GCC unroll 1
|
||||||
|
for (int row_offset = 0; row_offset < B_ROW;
|
||||||
|
row_offset += warps_in_threadblock) {
|
||||||
|
const uint32_t row = row_offset + warp_id;
|
||||||
|
#ifdef ROW_REMAINDER_LOGIC
|
||||||
|
if (row >= B_ROW) {
|
||||||
|
// WARNING: the number of barrier calls have to exactly match that in the
|
||||||
|
// outside of the branch to prevent stalls!! FIXME better proof this.
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
#endif
|
||||||
|
|
||||||
|
constexpr uint32_t per_row_iter = HEADDIM / NUM_THREADS;
|
||||||
|
|
||||||
|
// Oi rescale
|
||||||
|
//
|
||||||
|
#pragma GCC unroll
|
||||||
|
for (int i = 0; i < per_row_iter; i++) {
|
||||||
|
const uint32_t col_offset = NUM_THREADS * i;
|
||||||
|
const uint32_t col = col_offset + tid_in_warp;
|
||||||
|
const auto [smem_row, smem_col] =
|
||||||
|
remap_to_gemmini_dma_layout<block_row_major, HEADDIM>(row, col);
|
||||||
|
|
||||||
|
const uint32_t offset = HEADDIM * smem_row + smem_col;
|
||||||
|
const float o = smem_O_in[offset];
|
||||||
|
const float scale = smem_O_row_scale[row];
|
||||||
|
smem_O_out[offset] = (o * scale);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// reconverge after warp divergence
|
||||||
|
if constexpr (!TENSOR_CORE) {
|
||||||
|
threadblock_barrier(1, 7);
|
||||||
|
} else {
|
||||||
|
threadblock_barrier(threadblock_id_in_cluster,
|
||||||
|
warps_per_threadblock_per_core);
|
||||||
|
}
|
||||||
|
|
||||||
|
asm volatile("thread_block_O_rescale_finish_%=:" ::);
|
||||||
|
}
|
||||||
|
|
||||||
|
#endif
|
||||||
File diff suppressed because it is too large
Load Diff
698
tests/regression/flash_attention/kernel.gemmini.cpp
Normal file
698
tests/regression/flash_attention/kernel.gemmini.cpp
Normal file
@@ -0,0 +1,698 @@
|
|||||||
|
#include <stdint.h>
|
||||||
|
#include <vx_intrinsics.h>
|
||||||
|
#include <vx_print.h>
|
||||||
|
#include <vx_spawn.h>
|
||||||
|
#include "common.h"
|
||||||
|
#include "sgemm_impl.hpp"
|
||||||
|
#include "include/gemmini.h"
|
||||||
|
#include "gemmini_mmio.h"
|
||||||
|
#include "flash_impl.hpp"
|
||||||
|
|
||||||
|
#define FENCE_GEMM_II
|
||||||
|
|
||||||
|
constexpr bool DEBUG = false;
|
||||||
|
|
||||||
|
static_assert(GEMMINI_DMA && !WARP_SPECIALIZED,
|
||||||
|
"GEMMINI_DMA should be set and WARP_SPECIALIZED unset");
|
||||||
|
|
||||||
|
void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
|
||||||
|
// @perf: All threads are running these compute whose result is mostly same
|
||||||
|
// across the threadblock
|
||||||
|
|
||||||
|
#ifdef RADIANCE
|
||||||
|
constexpr uint32_t cores_per_cluster = CORES_PER_CLUSTER;
|
||||||
|
#else
|
||||||
|
constexpr uint32_t cores_per_cluster = 1;
|
||||||
|
#endif
|
||||||
|
|
||||||
|
// FIXME: headdim not considered
|
||||||
|
constexpr uint32_t threads_per_threadblock_theoretical =
|
||||||
|
(B_ROW * B_COL) / (ELEM_PER_THREAD);
|
||||||
|
constexpr uint32_t hw_threads_per_cluster =
|
||||||
|
CORES_PER_CLUSTER * NUM_THREADS * NUM_WARPS;
|
||||||
|
// cap maximum threadblock size to # of HW threads in cluster, to prevent
|
||||||
|
// multiple "wave" invocations which slows down the kernel
|
||||||
|
constexpr uint32_t threads_per_threadblock =
|
||||||
|
(threads_per_threadblock_theoretical > hw_threads_per_cluster)
|
||||||
|
? hw_threads_per_cluster
|
||||||
|
: threads_per_threadblock_theoretical;
|
||||||
|
constexpr uint32_t threadblocks_per_cluster =
|
||||||
|
hw_threads_per_cluster / threads_per_threadblock;
|
||||||
|
constexpr uint32_t warps_per_threadblock_per_core =
|
||||||
|
NUM_WARPS / threadblocks_per_cluster;
|
||||||
|
|
||||||
|
const uint32_t threadblock_id = task_id / threads_per_threadblock;
|
||||||
|
const uint32_t threadblock_id_in_cluster =
|
||||||
|
threadblock_id % threadblocks_per_cluster;
|
||||||
|
const uint32_t tid_in_threadblock = task_id % threads_per_threadblock;
|
||||||
|
const uint32_t warp_id = tid_in_threadblock / NUM_THREADS;
|
||||||
|
constexpr uint32_t warps_in_threadblock =
|
||||||
|
threads_per_threadblock / NUM_THREADS;
|
||||||
|
|
||||||
|
// warpgroup context
|
||||||
|
constexpr uint32_t threads_per_warpgroup =
|
||||||
|
threads_per_threadblock / (WARP_SPECIALIZED ? 2 : 1);
|
||||||
|
constexpr uint32_t warpgroups_per_cluster =
|
||||||
|
threadblocks_per_cluster * (WARP_SPECIALIZED ? 2 : 1);
|
||||||
|
const uint32_t warps_per_warpgroup_per_core =
|
||||||
|
NUM_WARPS / warpgroups_per_cluster;
|
||||||
|
const uint32_t warpgroup_id = task_id / threads_per_warpgroup;
|
||||||
|
const uint32_t warpgroup_id_in_cluster =
|
||||||
|
warpgroup_id % warpgroups_per_cluster;
|
||||||
|
const uint32_t tid_in_warpgroup = tid_in_threadblock % threads_per_warpgroup;
|
||||||
|
// // warpgroup 0: warp 0
|
||||||
|
// // warpgroup 1: warp 1~7
|
||||||
|
// const uint32_t warpgroup_id = (warp_id != 0);
|
||||||
|
|
||||||
|
const uint32_t dim_seqlen = arg->dim_seqlen;
|
||||||
|
const uint32_t dim_headdim = arg->dim_headdim;
|
||||||
|
|
||||||
|
// get global memory addresses from kernel arguments
|
||||||
|
const float *gmem_Q = reinterpret_cast<float *>(arg->addr_q);
|
||||||
|
const float *gmem_K = reinterpret_cast<float *>(arg->addr_k);
|
||||||
|
const float *gmem_V = reinterpret_cast<float *>(arg->addr_v);
|
||||||
|
float *gmem_O = reinterpret_cast<float *>(arg->addr_o);
|
||||||
|
|
||||||
|
float *gmem_tmp_d0 = reinterpret_cast<float *>(0xd0000000UL);
|
||||||
|
float *gmem_tmp_d1 = reinterpret_cast<float *>(0xd1000000UL);
|
||||||
|
float *gmem_tmp_d2 = reinterpret_cast<float *>(0xd2000000UL);
|
||||||
|
float *gmem_tmp_d3 = reinterpret_cast<float *>(0xd3000000UL);
|
||||||
|
float *gmem_tmp_d4 = reinterpret_cast<float *>(0xd4000000UL);
|
||||||
|
float *gmem_tmp_d5 = reinterpret_cast<float *>(0xd5000000UL);
|
||||||
|
float *gmem_tmp_d6 = reinterpret_cast<float *>(0xd6000000UL);
|
||||||
|
float *gmem_tmp_d7 = reinterpret_cast<float *>(0xd7000000UL);
|
||||||
|
float *gmem_tmp_e0 = reinterpret_cast<float *>(0xe0000000UL);
|
||||||
|
float *gmem_tmp_e1 = reinterpret_cast<float *>(0xe1000000UL);
|
||||||
|
float *gmem_tmp_e2 = reinterpret_cast<float *>(0xe2000000UL);
|
||||||
|
float *gmem_tmp_e3 = reinterpret_cast<float *>(0xe3000000UL);
|
||||||
|
|
||||||
|
// static shared memory allocation
|
||||||
|
// these are in float elements, not bytes
|
||||||
|
constexpr uint32_t smem_Q_size = B_ROW * HEADDIM;
|
||||||
|
constexpr uint32_t smem_K_size = B_COL * HEADDIM;
|
||||||
|
constexpr uint32_t smem_QK_size = B_ROW * B_COL;
|
||||||
|
constexpr uint32_t smem_V_size = B_COL * HEADDIM;
|
||||||
|
constexpr uint32_t smem_O_size = B_COL * HEADDIM;
|
||||||
|
static_assert(
|
||||||
|
threads_per_threadblock == NUM_WARPS * NUM_THREADS * CORES_PER_CLUSTER,
|
||||||
|
"flashattention kernel assumes 1 threadblock occupancy per cluster");
|
||||||
|
uint8_t *smem_per_threadblock = reinterpret_cast<uint8_t *>(DEV_SMEM_START_ADDR);
|
||||||
|
constexpr uint32_t smem_start = DEV_SMEM_START_ADDR;
|
||||||
|
constexpr uint32_t smem_quart0 = 0 * (SMEM_SIZE / 4);
|
||||||
|
constexpr uint32_t smem_quart1 = 1 * (SMEM_SIZE / 4);
|
||||||
|
constexpr uint32_t smem_quart2 = 2 * (SMEM_SIZE / 4);
|
||||||
|
constexpr uint32_t smem_quart3 = 3 * (SMEM_SIZE / 4);
|
||||||
|
|
||||||
|
// Q/V/S in quart0/1, K/P/O in quart2/3
|
||||||
|
constexpr uint32_t smem_Q0_offset = smem_quart0;
|
||||||
|
constexpr uint32_t smem_Q1_offset = smem_quart1;
|
||||||
|
constexpr uint32_t smem_K0_offset = smem_quart2;
|
||||||
|
constexpr uint32_t smem_K1_offset = smem_quart3;
|
||||||
|
constexpr uint32_t smem_V0_offset = smem_Q0_offset + smem_Q_size * sizeof(float);
|
||||||
|
constexpr uint32_t smem_V1_offset = smem_Q1_offset + smem_Q_size * sizeof(float);
|
||||||
|
// put S1/S0 with V0/V1 so that softmax and GEMM-II doesn't cause bank
|
||||||
|
// conflicts
|
||||||
|
constexpr uint32_t smem_S0_offset = smem_V1_offset + smem_V_size * sizeof(float);
|
||||||
|
constexpr uint32_t smem_S1_offset = smem_V0_offset + smem_V_size * sizeof(float);
|
||||||
|
constexpr uint32_t smem_P0_offset = smem_K0_offset + smem_K_size * sizeof(float);
|
||||||
|
constexpr uint32_t smem_P1_offset = smem_K1_offset + smem_K_size * sizeof(float);
|
||||||
|
// reversed!
|
||||||
|
constexpr uint32_t smem_O0_offset = smem_P1_offset + smem_QK_size * sizeof(float);
|
||||||
|
constexpr uint32_t smem_O1_offset = smem_P0_offset + smem_QK_size * sizeof(float); // unused
|
||||||
|
|
||||||
|
float *smem_Q0 = reinterpret_cast<float *>(smem_start + smem_Q0_offset);
|
||||||
|
float *smem_Q1 = reinterpret_cast<float *>(smem_start + smem_Q1_offset);
|
||||||
|
float *smem_K0 = reinterpret_cast<float *>(smem_start + smem_K0_offset);
|
||||||
|
float *smem_K1 = reinterpret_cast<float *>(smem_start + smem_K1_offset);
|
||||||
|
float *smem_V0 = reinterpret_cast<float *>(smem_start + smem_V0_offset);
|
||||||
|
float *smem_V1 = reinterpret_cast<float *>(smem_start + smem_V1_offset);
|
||||||
|
float *smem_S0 = reinterpret_cast<float *>(smem_start + smem_S0_offset);
|
||||||
|
float *smem_S1 = reinterpret_cast<float *>(smem_start + smem_S1_offset);
|
||||||
|
float *smem_P0 = reinterpret_cast<float *>(smem_start + smem_P0_offset);
|
||||||
|
float *smem_P1 = reinterpret_cast<float *>(smem_start + smem_P1_offset);
|
||||||
|
float *smem_O0 = reinterpret_cast<float *>(smem_start + smem_O0_offset);
|
||||||
|
float *smem_O1 = reinterpret_cast<float *>(smem_start + smem_O1_offset);
|
||||||
|
|
||||||
|
// allocate rowmax/rowsum storage at the end of the sharedmem address space
|
||||||
|
constexpr uint32_t smem_rowmax_size = B_ROW * ROWMAX_SETS;
|
||||||
|
constexpr uint32_t smem_rowsum_size = B_ROW;
|
||||||
|
constexpr uint32_t smem_O_row_scale_size = B_ROW;
|
||||||
|
|
||||||
|
float *smem_cursor = smem_O1 + smem_O_size;
|
||||||
|
// // FIXME: dangerous
|
||||||
|
// smem_cursor = reinterpret_cast<float *>(0xff038000);
|
||||||
|
float *smem_rowmax_0 = smem_cursor;
|
||||||
|
smem_cursor += smem_rowmax_size;
|
||||||
|
float *smem_rowmax_1 = smem_cursor;
|
||||||
|
smem_cursor += smem_rowmax_size;
|
||||||
|
float *smem_rowsum_0 = smem_cursor;
|
||||||
|
smem_cursor += smem_rowsum_size;
|
||||||
|
float *smem_rowsum_1 = smem_cursor;
|
||||||
|
smem_cursor += smem_rowsum_size;
|
||||||
|
float *smem_O_row_scale_0 = smem_cursor;
|
||||||
|
smem_cursor += smem_O_row_scale_size;
|
||||||
|
float *smem_O_row_scale_1 = smem_cursor;
|
||||||
|
smem_cursor += smem_O_row_scale_size;
|
||||||
|
|
||||||
|
// sharedmem "scratchpad" area to put temporary data, e.g. for tree reduction
|
||||||
|
// in rowsum
|
||||||
|
// NOTE: out-of bounds is not checked
|
||||||
|
constexpr uint32_t smem_scratchpad_size =
|
||||||
|
threads_per_warpgroup * 2 /*arbitrary slack*/;
|
||||||
|
float *smem_scratchpad_0 = smem_cursor;
|
||||||
|
smem_cursor += smem_scratchpad_size;
|
||||||
|
float *smem_scratchpad_1 = smem_cursor;
|
||||||
|
smem_cursor += smem_scratchpad_size;
|
||||||
|
uint32_t *smem_O_flag = reinterpret_cast<uint32_t *>(smem_cursor);
|
||||||
|
smem_cursor += 1 /* 4Byte */;
|
||||||
|
|
||||||
|
static_assert(sizeof(elem_t) == sizeof(float));
|
||||||
|
constexpr uint32_t spad_addr_factor = DIM * sizeof(elem_t);
|
||||||
|
constexpr uint32_t spad_addr_Q0 = smem_Q0_offset / spad_addr_factor;
|
||||||
|
constexpr uint32_t spad_addr_Q1 = smem_Q1_offset / spad_addr_factor;
|
||||||
|
constexpr uint32_t spad_addr_K0 = smem_K0_offset / spad_addr_factor;
|
||||||
|
constexpr uint32_t spad_addr_K1 = smem_K1_offset / spad_addr_factor;
|
||||||
|
constexpr uint32_t spad_addr_V0 = smem_V0_offset / spad_addr_factor;
|
||||||
|
constexpr uint32_t spad_addr_V1 = smem_V1_offset / spad_addr_factor;
|
||||||
|
constexpr uint32_t spad_addr_S0 = smem_S0_offset / spad_addr_factor;
|
||||||
|
constexpr uint32_t spad_addr_S1 = smem_S1_offset / spad_addr_factor;
|
||||||
|
constexpr uint32_t spad_addr_P0 = smem_P0_offset / spad_addr_factor;
|
||||||
|
constexpr uint32_t spad_addr_P1 = smem_P1_offset / spad_addr_factor;
|
||||||
|
constexpr uint32_t spad_addr_O0 = smem_O0_offset / spad_addr_factor;
|
||||||
|
constexpr uint32_t spad_addr_O1 = smem_O1_offset / spad_addr_factor;
|
||||||
|
|
||||||
|
constexpr uint32_t global_barrier_id = NUM_WARPS - 1; // arbitrary
|
||||||
|
static_assert(warps_per_threadblock_per_core == NUM_WARPS);
|
||||||
|
|
||||||
|
// initialize rowmax/rowsum values in sharedmem
|
||||||
|
thread_block_init_sharedmem(tid_in_warpgroup, threads_per_warpgroup, smem_O0,
|
||||||
|
smem_rowmax_0, smem_rowsum_0, smem_O_row_scale_0);
|
||||||
|
thread_block_init_sharedmem(tid_in_warpgroup, threads_per_warpgroup, smem_O1,
|
||||||
|
smem_rowmax_1, smem_rowsum_1, smem_O_row_scale_1);
|
||||||
|
|
||||||
|
threadblock_barrier(global_barrier_id, warps_per_threadblock_per_core);
|
||||||
|
|
||||||
|
// skip everything except DMA in the loop FSM
|
||||||
|
constexpr uint32_t skips =
|
||||||
|
loop_matmul_skips(/*skip_lda=*/0, /*skip_ldb=*/0, /*skip_ldd=*/1,
|
||||||
|
/*skip_ex=*/1, /*skip_stc=*/1);
|
||||||
|
constexpr uint32_t skips_only_a =
|
||||||
|
loop_matmul_skips(/*skip_lda=*/0, /*skip_ldb=*/1, /*skip_ldd=*/1,
|
||||||
|
/*skip_ex=*/1, /*skip_stc=*/1);
|
||||||
|
constexpr uint32_t skips_only_b =
|
||||||
|
loop_matmul_skips(/*skip_lda=*/1, /*skip_ldb=*/0, /*skip_ldd=*/1,
|
||||||
|
/*skip_ex=*/1, /*skip_stc=*/1);
|
||||||
|
constexpr uint32_t skips_mvout_spad =
|
||||||
|
loop_matmul_skips(/*skip_lda=*/1, /*skip_ldb=*/1, /*skip_ldd=*/1,
|
||||||
|
/*skip_ex=*/1, /*skip_stc=*/0);
|
||||||
|
constexpr uint32_t skips_matmul =
|
||||||
|
loop_matmul_skips(/*skip_lda=*/1, /*skip_ldb=*/1, /*skip_ldd=*/1,
|
||||||
|
/*skip_ex=*/0, /*skip_stc=*/0);
|
||||||
|
constexpr uint32_t skips_matmul_preload =
|
||||||
|
loop_matmul_skips(/*skip_lda=*/1, /*skip_ldb=*/1, /*skip_ldd=*/0,
|
||||||
|
/*skip_ex=*/0, /*skip_stc=*/1);
|
||||||
|
|
||||||
|
if (tid_in_warpgroup == 0) {
|
||||||
|
gemmini_extended_config_ex(WEIGHT_STATIONARY, 0, 0, 1, 0, 0);
|
||||||
|
|
||||||
|
// configure DMA with GMEM address strides
|
||||||
|
// Q matrix
|
||||||
|
gemmini_extended3_config_ld(HEADDIM * sizeof(elem_t), MVIN_SCALE_IDENTITY,
|
||||||
|
false, 0);
|
||||||
|
// K matrix
|
||||||
|
gemmini_extended3_config_ld(dim_seqlen * sizeof(elem_t),
|
||||||
|
MVIN_SCALE_IDENTITY, false, 1);
|
||||||
|
// configure DMA for Q*K store
|
||||||
|
gemmini_extended_config_st(B_COL * sizeof(elem_t), 0, MVIN_SCALE_IDENTITY);
|
||||||
|
gemmini_fence();
|
||||||
|
}
|
||||||
|
|
||||||
|
// NOTE about barriers: Placing barriers around thread-divergent branches may
|
||||||
|
// cause bugs, because the Vortex core doesn't check for tmask for barriers.
|
||||||
|
// The compiler might decide to duplicate vx_bar into both paths of a
|
||||||
|
// conditional branch, which will get evaluated twice because of the way
|
||||||
|
// branches are handled in SIMT; this might result in stalls especially when
|
||||||
|
// other warps behave differently on the branch condition.
|
||||||
|
// threadblock_barrier(warpgroup_id_in_cluster, warps_per_warpgroup_per_core);
|
||||||
|
|
||||||
|
static_assert(B_ROW == B_COL, "currently only supports square tiles");
|
||||||
|
|
||||||
|
// move Q and K into SMEM before the loop starts
|
||||||
|
//
|
||||||
|
asm volatile("dma_move_start_%=:" ::);
|
||||||
|
if (tid_in_warpgroup == 0) {
|
||||||
|
// make sure to read from the correct row of Q
|
||||||
|
const float *gmem_Q_tile = gmem_Q + HEADDIM * B_ROW * warpgroup_id;
|
||||||
|
const float *gmem_K_tile = gmem_K;
|
||||||
|
// configure the GMEM addresses for the DMA to read from
|
||||||
|
ROCC_INSTRUCTION_RS1_RS2(XCUSTOM_ACC, (uint64_t)(gmem_Q_tile),
|
||||||
|
(uint64_t)(gmem_K_tile), k_LOOP_WS_CONFIG_ADDRS_AB)
|
||||||
|
// configure address strides for the DMA
|
||||||
|
GEMMINI_CISC_CMD_R((dim_seqlen << 20) | (HEADDIM << 8) |
|
||||||
|
8 /*k_LOOP_WS_CONFIG_STRIDES_AB*/);
|
||||||
|
gemmini_fence();
|
||||||
|
|
||||||
|
// #define GEMMINI_DMA_CISC
|
||||||
|
#ifdef GEMMINI_DMA_CISC
|
||||||
|
// the target addresses of this should match with spad_addr_Q0 and
|
||||||
|
// spad_addr_K0 set in this kernel
|
||||||
|
GEMMINI_CISC_CMD_I(10);
|
||||||
|
#else
|
||||||
|
// do DMA
|
||||||
|
//
|
||||||
|
// among other things, this also configures CONFIG_BOUNDS so that the
|
||||||
|
// DMA knows the full matrix dimensions
|
||||||
|
sp_tiled_matmul_full_spad_ws(
|
||||||
|
spad_addr_Q0, spad_addr_K0,
|
||||||
|
/*spad_D=*/0, /*spad_C=*/spad_addr_S0/*bogus*/,
|
||||||
|
/*I=*/(B_ROW / DIM), /*J=*/(B_COL / DIM), /*K=*/(HEADDIM / DIM),
|
||||||
|
/*pad_I=*/0, /*pad_J=*/0, /*pad_K=*/0,
|
||||||
|
/*a_transpose=*/0, /*b_transpose=*/0, /*full_C=*/0, /*low_D=*/0,
|
||||||
|
/*acc=*/0, /*act=*/NO_ACTIVATION, /*skips=*/skips);
|
||||||
|
#endif
|
||||||
|
gemmini_fence();
|
||||||
|
|
||||||
|
// need to also move Q to spad_addr_Q1 for the next iteration
|
||||||
|
// FIXME: re-configure necessary?
|
||||||
|
gmem_K_tile = gmem_K + (B_COL * 1);
|
||||||
|
ROCC_INSTRUCTION_RS1_RS2(XCUSTOM_ACC, (uint64_t)(gmem_Q_tile),
|
||||||
|
(uint64_t)(gmem_K_tile), k_LOOP_WS_CONFIG_ADDRS_AB)
|
||||||
|
GEMMINI_CISC_CMD_R((dim_seqlen << 20) | (HEADDIM << 8) |
|
||||||
|
8 /*k_LOOP_WS_CONFIG_STRIDES_AB*/);
|
||||||
|
gemmini_fence();
|
||||||
|
#ifdef GEMMINI_DMA_CISC
|
||||||
|
// GEMMINI_CISC_CMD_I(11);
|
||||||
|
#else
|
||||||
|
sp_tiled_matmul_full_spad_ws(
|
||||||
|
spad_addr_Q1, spad_addr_K1/*bogus*/,
|
||||||
|
/*spad_D=*/0, /*spad_C=*/spad_addr_S0/*bogus*/,
|
||||||
|
/*I=*/(B_ROW / DIM), /*J=*/(B_COL / DIM), /*K=*/(HEADDIM / DIM),
|
||||||
|
/*pad_I=*/0, /*pad_J=*/0, /*pad_K=*/0,
|
||||||
|
/*a_transpose=*/0, /*b_transpose=*/0, /*full_C=*/0, /*low_D=*/0,
|
||||||
|
/*acc=*/0, /*act=*/NO_ACTIVATION, /*skips=*/skips);
|
||||||
|
#endif
|
||||||
|
|
||||||
|
gemmini_fence();
|
||||||
|
gemmini_fence();
|
||||||
|
gemmini_fence();
|
||||||
|
gemmini_fence();
|
||||||
|
|
||||||
|
// re-configure DMA for K and V load that will later happen in the loop
|
||||||
|
// GMEM addr stride for K
|
||||||
|
gemmini_extended3_config_ld(dim_seqlen * sizeof(elem_t),
|
||||||
|
MVIN_SCALE_IDENTITY, false, 0);
|
||||||
|
// GMEM addr stride for V
|
||||||
|
gemmini_extended3_config_ld(HEADDIM * sizeof(elem_t), MVIN_SCALE_IDENTITY,
|
||||||
|
false, 1);
|
||||||
|
gemmini_fence();
|
||||||
|
}
|
||||||
|
|
||||||
|
asm volatile("dma_move_end_%=:" ::);
|
||||||
|
|
||||||
|
// protect write to SMEM
|
||||||
|
// threadblock_barrier(warpgroup_id_in_cluster, warps_per_warpgroup_per_core);
|
||||||
|
threadblock_barrier(global_barrier_id, warps_per_threadblock_per_core);
|
||||||
|
|
||||||
|
// if constexpr (DEBUG) {
|
||||||
|
// thread_block_copy_tile<B_ROW, HEADDIM>(smem_Q0, gmem_tmp_d0, tid_in_warpgroup,
|
||||||
|
// threads_per_warpgroup, warpgroup_id_in_cluster);
|
||||||
|
// thread_block_copy_tile<HEADDIM, B_COL>(smem_K0, gmem_tmp_d1, tid_in_warpgroup,
|
||||||
|
// threads_per_warpgroup, warpgroup_id_in_cluster);
|
||||||
|
// threadblock_barrier(warpgroup_id_in_cluster, warps_per_warpgroup_per_core);
|
||||||
|
// }
|
||||||
|
|
||||||
|
constexpr uint32_t threads_per_warpgroup_simt =
|
||||||
|
threads_per_warpgroup -
|
||||||
|
CORES_PER_CLUSTER * NUM_THREADS /*warp 0, 4, 8, 12*/;
|
||||||
|
constexpr uint32_t warpgroup_id_simt = 1;
|
||||||
|
constexpr uint32_t barrier_id_simt = 1;
|
||||||
|
constexpr uint32_t barrier_count_simt = NUM_WARPS - 1;
|
||||||
|
const uint32_t tid_in_warpgroup_simt =
|
||||||
|
tid_in_warpgroup - (CORES_PER_CLUSTER * NUM_THREADS);
|
||||||
|
static_assert(barrier_id_simt == 1 && barrier_count_simt == 7);
|
||||||
|
|
||||||
|
asm volatile ("tile_loop_start_%=:" :: );
|
||||||
|
|
||||||
|
// "inner loop" along the columns of K^T
|
||||||
|
const uint32_t k_tiles = (dim_seqlen / B_COL);
|
||||||
|
for (uint32_t tile_k = 0;
|
||||||
|
tile_k < (4 /*for perf measurement*/ * k_tiles) + 2 /*pipeline latency*/;
|
||||||
|
tile_k++) {
|
||||||
|
if constexpr (DEBUG || true) {
|
||||||
|
threadblock_barrier(global_barrier_id, warps_per_threadblock_per_core);
|
||||||
|
}
|
||||||
|
|
||||||
|
// select the correct double buffer by tile iteration
|
||||||
|
// all iterations work on the same Q row tile; no ping-pong necessary
|
||||||
|
asm volatile ("dbuf_sel_start_%=:" :: );
|
||||||
|
// FIXME speedup by doing arithmetic
|
||||||
|
float *smem_Q = smem_Q0;
|
||||||
|
float *smem_K_consume = (tile_k & 1) ? smem_K1 : smem_K0;
|
||||||
|
float *smem_K_produce = (tile_k & 1) ? smem_K0 : smem_K1;
|
||||||
|
float *smem_V_consume = (tile_k & 1) ? smem_V1 : smem_V0;
|
||||||
|
float *smem_V_produce = (tile_k & 1) ? smem_V0 : smem_V1;
|
||||||
|
float *smem_S_consume = (tile_k & 1) ? smem_S1 : smem_S0;
|
||||||
|
float *smem_S_produce = (tile_k & 1) ? smem_S0 : smem_S1;
|
||||||
|
float *smem_P_consume = (tile_k & 1) ? smem_P1 : smem_P0;
|
||||||
|
float *smem_P_produce = (tile_k & 1) ? smem_P0 : smem_P1;
|
||||||
|
// O, rowmax/rowsum etc. is sequentially updated at every iteration; no
|
||||||
|
// ping-pong necessary
|
||||||
|
float *smem_O = smem_O0;
|
||||||
|
float *smem_O_row_scale = smem_O_row_scale_0;
|
||||||
|
float *smem_rowmax = smem_rowmax_0;
|
||||||
|
float *smem_rowsum = smem_rowsum_0;
|
||||||
|
float *smem_scratchpad = smem_scratchpad_0;
|
||||||
|
|
||||||
|
const auto spad_addr_Q = spad_addr_Q0;
|
||||||
|
const auto spad_addr_K_consume = (tile_k & 1) ? spad_addr_K1 : spad_addr_K0;
|
||||||
|
const auto spad_addr_K_produce = (tile_k & 1) ? spad_addr_K0 : spad_addr_K1;
|
||||||
|
const auto spad_addr_V_consume = (tile_k & 1) ? spad_addr_V1 : spad_addr_V0;
|
||||||
|
const auto spad_addr_V_produce = (tile_k & 1) ? spad_addr_V0 : spad_addr_V1;
|
||||||
|
const auto spad_addr_S_consume = (tile_k & 1) ? spad_addr_S1 : spad_addr_S0;
|
||||||
|
const auto spad_addr_S_produce = (tile_k & 1) ? spad_addr_S0 : spad_addr_S1;
|
||||||
|
const auto spad_addr_P_consume = (tile_k & 1) ? spad_addr_P1 : spad_addr_P0;
|
||||||
|
const auto spad_addr_P_produce = (tile_k & 1) ? spad_addr_P0 : spad_addr_P1;
|
||||||
|
const auto spad_addr_O = spad_addr_O0; // NOTE: there's only single O tile
|
||||||
|
asm volatile ("dbuf_sel_end_%=:" :: );
|
||||||
|
|
||||||
|
if (vx_warp_id() == 0 /* warp 0 in every core */) {
|
||||||
|
if (tile_k >= 2) // delay by 2 iters for pipelining
|
||||||
|
{
|
||||||
|
const uint32_t tile_k_ = tile_k - 2;
|
||||||
|
|
||||||
|
// GEMM II: O = O + P*V
|
||||||
|
// --------------------
|
||||||
|
// This is done *before* GEMM I in the software pipeline, working on the
|
||||||
|
// online softmax result tile from the previous iteration
|
||||||
|
|
||||||
|
asm volatile("gemm_pv_start_%=:" ::);
|
||||||
|
|
||||||
|
if (tid_in_warpgroup == 0) {
|
||||||
|
#if 0
|
||||||
|
if (tile_k_ == 0) {
|
||||||
|
gemmini_fence();
|
||||||
|
GEMMINI_CISC_CMD_I(0);
|
||||||
|
} else if (tile_k_ & 1) {
|
||||||
|
gemmini_fence();
|
||||||
|
GEMMINI_CISC_CMD_I(2);
|
||||||
|
} else {
|
||||||
|
gemmini_fence();
|
||||||
|
GEMMINI_CISC_CMD_I(1);
|
||||||
|
}
|
||||||
|
#else
|
||||||
|
// kickoff matmul
|
||||||
|
// among other things, this also configures CONFIG_BOUNDS so that the
|
||||||
|
// DMA knows the full matrix dimensions
|
||||||
|
// FIXME: perf: prevent GMEM->SMEM load for O tile
|
||||||
|
gemmini_fence();
|
||||||
|
gemmini_fence();
|
||||||
|
gemmini_fence();
|
||||||
|
gemmini_fence();
|
||||||
|
sp_tiled_matmul_full_spad_ws(
|
||||||
|
spad_addr_P_consume, spad_addr_V_consume,
|
||||||
|
/*spad_D=*/spad_addr_O, /*spad_C=*/spad_addr_O,
|
||||||
|
/*I=*/(B_ROW / DIM), /*J=*/(HEADDIM / DIM), /*K=*/(B_COL / DIM),
|
||||||
|
/*pad_I=*/0, /*pad_J=*/0, /*pad_K=*/0,
|
||||||
|
/*a_transpose=*/0, /*b_transpose=*/0, /*full_C=*/0, /*low_D=*/0,
|
||||||
|
/*acc=*/0, /*act=*/NO_ACTIVATION, /*skips=*/skips_matmul);
|
||||||
|
#endif
|
||||||
|
}
|
||||||
|
|
||||||
|
// // reconverge from mmio divergence
|
||||||
|
// threadblock_barrier(warpgroup_id_in_cluster,
|
||||||
|
// warps_per_warpgroup_per_core);
|
||||||
|
|
||||||
|
asm volatile("gemm_pv_finish_%=:" ::);
|
||||||
|
}
|
||||||
|
|
||||||
|
// GEMM I: S = Q*K
|
||||||
|
//
|
||||||
|
// kick off asynchronously; fence later
|
||||||
|
asm volatile("gemm_qk_start_%=:" ::);
|
||||||
|
|
||||||
|
if (tid_in_warpgroup == 0) {
|
||||||
|
// fence to GEMM II completion
|
||||||
|
gemmini_fence();
|
||||||
|
gemmini_fence();
|
||||||
|
gemmini_fence();
|
||||||
|
gemmini_fence();
|
||||||
|
|
||||||
|
#ifdef FENCE_GEMM_II
|
||||||
|
// signal that GEMM II is finished to O rescale step
|
||||||
|
*smem_O_flag = 1;
|
||||||
|
vx_fence();
|
||||||
|
#endif
|
||||||
|
|
||||||
|
// 0,2,.: opcode 0 (quartile 0/2, no accum)
|
||||||
|
// 1,3,.: opcode 3 (quartile 1/3, no accum)
|
||||||
|
// const uint32_t opcode = 3 * (tile_k & 1);
|
||||||
|
//GEMMINI_CISC_CMD_I(opcode);
|
||||||
|
sp_tiled_matmul_full_spad_ws(
|
||||||
|
spad_addr_Q, spad_addr_K_consume,
|
||||||
|
/*spad_D=*/0, /*spad_C=*/spad_addr_S_produce,
|
||||||
|
/*I=*/(B_ROW / DIM), /*J=*/(B_COL / DIM), /*K=*/(HEADDIM / DIM),
|
||||||
|
/*pad_I=*/0, /*pad_J=*/0, /*pad_K=*/0,
|
||||||
|
/*a_transpose=*/0, /*b_transpose=*/0, /*full_C=*/0, /*low_D=*/0,
|
||||||
|
/*acc=*/0, /*act=*/NO_ACTIVATION, /*skips=*/skips_matmul);
|
||||||
|
|
||||||
|
// gemmini_fence();
|
||||||
|
// gemmini_fence();
|
||||||
|
// gemmini_fence();
|
||||||
|
// gemmini_fence();
|
||||||
|
asm volatile("gemm_qk_finish_%=:" ::);
|
||||||
|
|
||||||
|
// data move for K and V
|
||||||
|
//
|
||||||
|
// Q stays in SMEM for the entire loop
|
||||||
|
asm volatile("move_k_v_start_%=:" ::);
|
||||||
|
|
||||||
|
// configure GMEM addresses for K and V tiles
|
||||||
|
// load K for the next iteration
|
||||||
|
const float *gmem_K_tile = gmem_K + (B_COL * (tile_k + 1 /*runahead*/));
|
||||||
|
// load V for the *previous* iteration; this will be consumed 2
|
||||||
|
// iterations later
|
||||||
|
const float *gmem_V_tile =
|
||||||
|
gmem_V + (HEADDIM * B_COL * (tile_k - 1 /*dragbehind*/));
|
||||||
|
|
||||||
|
#if 0
|
||||||
|
// fence mvout S to SMEM
|
||||||
|
gemmini_fence();
|
||||||
|
ROCC_INSTRUCTION_RS1_RS2(XCUSTOM_ACC, (uint64_t)(gmem_K_tile),
|
||||||
|
(uint64_t)(gmem_V_tile),
|
||||||
|
k_LOOP_WS_CONFIG_ADDRS_AB)
|
||||||
|
#endif
|
||||||
|
// configure address strides for the DMA
|
||||||
|
// FIXME: unnecessary?
|
||||||
|
GEMMINI_CISC_CMD_R((HEADDIM /*V*/ << 20) | (dim_seqlen /*KT*/ << 8) |
|
||||||
|
8 /*k_LOOP_WS_CONFIG_STRIDES_AB*/);
|
||||||
|
// gemmini_fence();
|
||||||
|
|
||||||
|
// do DMA
|
||||||
|
if (tile_k == 0) {
|
||||||
|
// we load (k-1)th tile for V; skip V for the 1st iteration,
|
||||||
|
// sp_tiled_matmul_full_spad_ws(
|
||||||
|
// spad_addr_K_produce, spad_addr_V_produce,
|
||||||
|
// /*spad_D=*/0, /*spad_C=*/0,
|
||||||
|
// /*I=*/(B_ROW / DIM), /*J=*/(HEADDIM / DIM), /*K=*/(B_COL / DIM),
|
||||||
|
// /*pad_I=*/0, /*pad_J=*/0, /*pad_K=*/0,
|
||||||
|
// /*a_transpose=*/0, /*b_transpose=*/0, /*full_C=*/0, /*low_D=*/0,
|
||||||
|
// /*acc=*/0, /*act=*/NO_ACTIVATION, /*skips=*/skips_only_a);
|
||||||
|
} else {
|
||||||
|
sp_tiled_matmul_full_spad_ws(
|
||||||
|
spad_addr_K_produce, spad_addr_V_produce,
|
||||||
|
/*spad_D=*/0, /*spad_C=*/0,
|
||||||
|
/*I=*/(B_ROW / DIM), /*J=*/(HEADDIM / DIM), /*K=*/(B_COL / DIM),
|
||||||
|
/*pad_I=*/0, /*pad_J=*/0, /*pad_K=*/0,
|
||||||
|
/*a_transpose=*/0, /*b_transpose=*/0, /*full_C=*/0, /*low_D=*/0,
|
||||||
|
/*acc=*/0, /*act=*/NO_ACTIVATION, /*skips=*/skips);
|
||||||
|
}
|
||||||
|
|
||||||
|
// fence everything before going to the next tile
|
||||||
|
gemmini_fence();
|
||||||
|
gemmini_fence();
|
||||||
|
gemmini_fence();
|
||||||
|
gemmini_fence();
|
||||||
|
}
|
||||||
|
|
||||||
|
// threadblock_barrier(warpgroup_id_in_cluster,
|
||||||
|
// warps_per_warpgroup_per_core);
|
||||||
|
|
||||||
|
asm volatile("move_k_v_finish_%=:" ::);
|
||||||
|
|
||||||
|
// NOTE: cannot put barrier here; thread 1-7 in warp 0 will skip the
|
||||||
|
// branch and call this barrier earlier than when thread 0 finishes.
|
||||||
|
// Since tmask is not considered, that will be a barrier resolve done too
|
||||||
|
// early
|
||||||
|
// threadblock_barrier(0, 1);
|
||||||
|
|
||||||
|
} else /* warp_id != 0 */ {
|
||||||
|
|
||||||
|
if (tile_k >= 1) // delay by 1 iters for pipelining
|
||||||
|
{
|
||||||
|
const uint32_t tile_k_ = tile_k - 1;
|
||||||
|
|
||||||
|
if constexpr (DEBUG) {
|
||||||
|
// verify S = Q*K before softmax
|
||||||
|
if (warpgroup_id == 0) {
|
||||||
|
if (tile_k_ == 0) {
|
||||||
|
thread_block_copy_tile<B_ROW, B_COL, GEMMINI_DMA>(
|
||||||
|
smem_S_consume, gmem_tmp_d0, tid_in_warpgroup_simt,
|
||||||
|
threads_per_warpgroup_simt, warpgroup_id_simt);
|
||||||
|
} else if (tile_k_ == 1) {
|
||||||
|
thread_block_copy_tile<B_ROW, B_COL, GEMMINI_DMA>(
|
||||||
|
smem_S_consume, gmem_tmp_d1, tid_in_warpgroup_simt,
|
||||||
|
threads_per_warpgroup_simt, warpgroup_id_simt);
|
||||||
|
}
|
||||||
|
|
||||||
|
threadblock_barrier(barrier_id_simt, barrier_count_simt);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Online softmax
|
||||||
|
//
|
||||||
|
thread_block_online_softmax</*block_row_major=*/GEMMINI_DMA>(
|
||||||
|
smem_S_consume, smem_P_produce, tid_in_warpgroup_simt,
|
||||||
|
threads_per_warpgroup_simt, warpgroup_id_simt, smem_scratchpad,
|
||||||
|
smem_rowmax, smem_rowsum, smem_O_row_scale);
|
||||||
|
|
||||||
|
threadblock_barrier(barrier_id_simt, barrier_count_simt);
|
||||||
|
|
||||||
|
if constexpr (DEBUG) {
|
||||||
|
if (warpgroup_id == 0) {
|
||||||
|
if (tile_k_ == 0) {
|
||||||
|
thread_block_copy_rowmax(
|
||||||
|
smem_rowmax, gmem_tmp_e0, tid_in_warpgroup_simt,
|
||||||
|
threads_per_warpgroup_simt, warpgroup_id_simt);
|
||||||
|
thread_block_copy_rowmax(
|
||||||
|
smem_rowsum, gmem_tmp_e2, tid_in_warpgroup_simt,
|
||||||
|
threads_per_warpgroup_simt, warpgroup_id_simt);
|
||||||
|
} else if (tile_k_ == 1) {
|
||||||
|
thread_block_copy_rowmax(smem_rowmax, gmem_tmp_e1,
|
||||||
|
tid_in_warpgroup_simt, threads_per_warpgroup_simt,
|
||||||
|
warpgroup_id_simt);
|
||||||
|
thread_block_copy_rowmax(smem_rowsum, gmem_tmp_e3,
|
||||||
|
tid_in_warpgroup_simt, threads_per_warpgroup_simt,
|
||||||
|
warpgroup_id_simt);
|
||||||
|
}
|
||||||
|
|
||||||
|
threadblock_barrier(barrier_id_simt, barrier_count_simt);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#ifdef FENCE_GEMM_II
|
||||||
|
// check flag to make sure GEMM II finished and read-after-write
|
||||||
|
// dependency on O tile is settled for rescale
|
||||||
|
if (tid_in_warpgroup_simt == 0) {
|
||||||
|
while ((*smem_O_flag) != 1)
|
||||||
|
;
|
||||||
|
// set it back to 0 for the next tile iteration
|
||||||
|
*smem_O_flag = 0;
|
||||||
|
vx_fence();
|
||||||
|
}
|
||||||
|
#endif
|
||||||
|
|
||||||
|
#if 0
|
||||||
|
if (tid_in_warpgroup == 0) {
|
||||||
|
gemmini_fence();
|
||||||
|
gemmini_fence();
|
||||||
|
gemmini_fence();
|
||||||
|
gemmini_fence();
|
||||||
|
}
|
||||||
|
|
||||||
|
// reconverge from mmio divergence
|
||||||
|
threadblock_barrier(warpgroup_id_in_cluster,
|
||||||
|
warps_per_warpgroup_per_core);
|
||||||
|
#endif
|
||||||
|
|
||||||
|
if constexpr (DEBUG) {
|
||||||
|
if (warpgroup_id == 0) {
|
||||||
|
gemmini_fence();
|
||||||
|
gemmini_fence();
|
||||||
|
|
||||||
|
// O after PV
|
||||||
|
if (tile_k_ == 1 /*wait until GEMM II finshes */) {
|
||||||
|
thread_block_copy_tile<B_ROW, HEADDIM, GEMMINI_DMA>(
|
||||||
|
smem_O, gmem_tmp_d6, tid_in_warpgroup_simt, threads_per_warpgroup_simt,
|
||||||
|
warpgroup_id_simt);
|
||||||
|
} else if (tile_k_ == 2) {
|
||||||
|
thread_block_copy_tile<B_ROW, HEADDIM, GEMMINI_DMA>(
|
||||||
|
smem_O, gmem_tmp_d7, tid_in_warpgroup_simt, threads_per_warpgroup_simt,
|
||||||
|
warpgroup_id_simt);
|
||||||
|
}
|
||||||
|
|
||||||
|
threadblock_barrier(barrier_id_simt, barrier_count_simt);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Oi rescale
|
||||||
|
thread_block_O_rescale</*block_row_major=*/GEMMINI_DMA>(
|
||||||
|
smem_O, smem_O /*in-place*/, smem_O_row_scale,
|
||||||
|
tid_in_warpgroup_simt, threads_per_warpgroup_simt,
|
||||||
|
warpgroup_id_simt);
|
||||||
|
|
||||||
|
// rescale-to-PV-GEMM barrier
|
||||||
|
threadblock_barrier(barrier_id_simt, barrier_count_simt);
|
||||||
|
|
||||||
|
if constexpr (DEBUG) {
|
||||||
|
if (warpgroup_id == 0) {
|
||||||
|
// O before PV
|
||||||
|
if (tile_k_ == 0) {
|
||||||
|
thread_block_copy_tile<B_ROW, B_COL, GEMMINI_DMA>(
|
||||||
|
smem_P_produce, gmem_tmp_d2, tid_in_warpgroup_simt,
|
||||||
|
threads_per_warpgroup_simt, warpgroup_id_simt);
|
||||||
|
thread_block_copy_tile<B_ROW, HEADDIM, GEMMINI_DMA>(
|
||||||
|
smem_O, gmem_tmp_d4, tid_in_warpgroup_simt,
|
||||||
|
threads_per_warpgroup_simt, warpgroup_id_simt);
|
||||||
|
} else if (tile_k_ == 1) {
|
||||||
|
thread_block_copy_tile<B_ROW, B_COL, GEMMINI_DMA>(
|
||||||
|
smem_P_produce, gmem_tmp_d3, tid_in_warpgroup_simt,
|
||||||
|
threads_per_warpgroup_simt, warpgroup_id_simt);
|
||||||
|
thread_block_copy_tile<B_ROW, HEADDIM, GEMMINI_DMA>(
|
||||||
|
smem_O, gmem_tmp_d5, tid_in_warpgroup_simt,
|
||||||
|
threads_per_warpgroup_simt, warpgroup_id_simt);
|
||||||
|
}
|
||||||
|
|
||||||
|
threadblock_barrier(barrier_id_simt, barrier_count_simt);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#if 0
|
||||||
|
// fence GEMM I after Oi rescale
|
||||||
|
if (tid_in_warpgroup == 0) {
|
||||||
|
gemmini_fence();
|
||||||
|
gemmini_fence();
|
||||||
|
gemmini_fence();
|
||||||
|
gemmini_fence();
|
||||||
|
}
|
||||||
|
|
||||||
|
// reconverge from mmio divergence
|
||||||
|
threadblock_barrier(warpgroup_id_in_cluster,
|
||||||
|
warps_per_warpgroup_per_core);
|
||||||
|
#endif
|
||||||
|
|
||||||
|
// intra-warpgroup barrier
|
||||||
|
threadblock_barrier(barrier_id_simt, barrier_count_simt);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
asm volatile ("tile_loop_finish_%=:" :: );
|
||||||
|
}
|
||||||
|
|
||||||
|
int main() {
|
||||||
|
kernel_arg_t *arg = (kernel_arg_t *)KERNEL_ARG_DEV_MEM_ADDR;
|
||||||
|
|
||||||
|
const uint32_t hw_threads_per_cluster =
|
||||||
|
CORES_PER_CLUSTER * vx_num_threads() * vx_num_warps();
|
||||||
|
// fix to 1 threadblock per cluster
|
||||||
|
const uint32_t grid_size = hw_threads_per_cluster;
|
||||||
|
|
||||||
|
#ifdef RADIANCE
|
||||||
|
vx_spawn_tasks_cluster(grid_size, (vx_spawn_tasks_cb)kernel_body, arg);
|
||||||
|
#else
|
||||||
|
// NOTE: This kernel assumes contiguous thread scheduling for efficient shared
|
||||||
|
// memory allocation, and therefore does not work with original vx_spawn_tasks
|
||||||
|
vx_spawn_tasks_contiguous(grid_size, (vx_spawn_tasks_cb)kernel_body, arg);
|
||||||
|
#endif
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
@@ -141,8 +141,10 @@ void thread_block_matmul_gemmini(kernel_arg_t *__UNIFORM__ arg,
|
|||||||
}
|
}
|
||||||
|
|
||||||
if (tile_k == 0) {
|
if (tile_k == 0) {
|
||||||
|
asm volatile("cisc_start_%=:" ::);
|
||||||
gemmini_fence();
|
gemmini_fence();
|
||||||
GEMMINI_CISC_CMD_I(0);
|
GEMMINI_CISC_CMD_I(0);
|
||||||
|
asm volatile("cisc_end_%=:" ::);
|
||||||
} else if (tile_k & 1) {
|
} else if (tile_k & 1) {
|
||||||
gemmini_fence();
|
gemmini_fence();
|
||||||
GEMMINI_CISC_CMD_I(2);
|
GEMMINI_CISC_CMD_I(2);
|
||||||
@@ -218,4 +220,4 @@ int main() {
|
|||||||
vx_spawn_tasks_contiguous(grid_size, (vx_spawn_tasks_cb)kernel_body, arg);
|
vx_spawn_tasks_contiguous(grid_size, (vx_spawn_tasks_cb)kernel_body, arg);
|
||||||
#endif
|
#endif
|
||||||
return 0;
|
return 0;
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -7,7 +7,7 @@
|
|||||||
#include "include/gemmini.h"
|
#include "include/gemmini.h"
|
||||||
#include "gemmini_mmio.h"
|
#include "gemmini_mmio.h"
|
||||||
|
|
||||||
constexpr bool DEBUG = true;
|
constexpr bool DEBUG = false;
|
||||||
|
|
||||||
template <uint32_t tile_dim_row, uint32_t tile_dim_col>
|
template <uint32_t tile_dim_row, uint32_t tile_dim_col>
|
||||||
inline void thread_block_copy_tile(const float *src, float *dest,
|
inline void thread_block_copy_tile(const float *src, float *dest,
|
||||||
@@ -90,19 +90,29 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
|
|||||||
thread_block_gemm<float_type, threads_per_threadblock,
|
thread_block_gemm<float_type, threads_per_threadblock,
|
||||||
/*write_to_gmem=*/true,
|
/*write_to_gmem=*/true,
|
||||||
/*smem_a_offset=*/0,
|
/*smem_a_offset=*/0,
|
||||||
/*smem_a_dbuf_offset=*/0,
|
#ifdef GEMMINI_DMA
|
||||||
/*smem_b_offset=*/2 * BM * BK * sizeof(float),
|
/*smem_a_dbuf_offset=*/1 * 128 * 128 * sizeof(float_type),
|
||||||
/*smem_b_dbuf_offset=*/2 * BM * BK * sizeof(float)>(
|
/*smem_b_offset=*/2 * 128 * 128 * sizeof(float_type),
|
||||||
(const float_type *)arg->addr_a, (const float_type *)arg->addr_b,
|
/*smem_b_dbuf_offset=*/3 * 128 * 128 * sizeof(float_type)
|
||||||
(float *)arg->addr_c, arg->dim_m, arg->dim_n, arg->dim_k,
|
// FIXME: above offsets are hardcoded to agree with CISC
|
||||||
tid_in_threadblock, threadblocks_per_cluster, threadblock_id_in_cluster,
|
// spadQuartile
|
||||||
sharedmem_per_threadblock);
|
#else
|
||||||
|
/*smem_a_dbuf_offset=*/1 * BM * BK * sizeof(float_type),
|
||||||
|
/*smem_b_offset=*/2 * BM * BK * sizeof(float_type),
|
||||||
|
/*smem_b_dbuf_offset=*/(2 * BM * BK + BK * BN) * sizeof(float_type)
|
||||||
|
#endif
|
||||||
|
>((const float_type *)arg->addr_a,
|
||||||
|
(const float_type *)arg->addr_b, (float *)arg->addr_c,
|
||||||
|
arg->dim_m, arg->dim_n, arg->dim_k, tid_in_threadblock,
|
||||||
|
threadblocks_per_cluster, threadblock_id_in_cluster,
|
||||||
|
sharedmem_per_threadblock);
|
||||||
|
|
||||||
float *gmem_tmp_d0 = reinterpret_cast<float *>(0xd0000000UL);
|
float *gmem_tmp_d0 = reinterpret_cast<float *>(0xd0000000UL);
|
||||||
float *gmem_tmp_d1 = reinterpret_cast<float *>(0xd1000000UL);
|
float *gmem_tmp_d1 = reinterpret_cast<float *>(0xd1000000UL);
|
||||||
|
|
||||||
const float *smem_A = reinterpret_cast<float *>(sharedmem_per_threadblock);
|
const float *smem_A = reinterpret_cast<float *>(sharedmem_per_threadblock);
|
||||||
const float *smem_B = smem_A + 2 * BM * BK;
|
const float *smem_B = reinterpret_cast<float *>(
|
||||||
|
sharedmem_per_threadblock + 2 * BM * BK * sizeof(float_type));
|
||||||
|
|
||||||
if constexpr (DEBUG) {
|
if constexpr (DEBUG) {
|
||||||
threadblock_barrier(threadblock_id_in_cluster,
|
threadblock_barrier(threadblock_id_in_cluster,
|
||||||
|
|||||||
@@ -29,7 +29,7 @@ using float_type = float16_t;
|
|||||||
// (BM*BN) / (TM*TN) == threadblock size >= NT * CORES_PER_CLUSTER
|
// (BM*BN) / (TM*TN) == threadblock size >= NT * CORES_PER_CLUSTER
|
||||||
// * Combining BM * BK >= (BM*BN) / (TM*TN) == threadblock yields
|
// * Combining BM * BK >= (BM*BN) / (TM*TN) == threadblock yields
|
||||||
// BM <= BK*TM*TN
|
// BM <= BK*TM*TN
|
||||||
#define BM 128
|
#define BM ((NUM_CORES == 8) ? 128 : 64)
|
||||||
#define BN 64
|
#define BN 64
|
||||||
#if (FP_SIZE == 32)
|
#if (FP_SIZE == 32)
|
||||||
#define BK 64
|
#define BK 64
|
||||||
@@ -62,8 +62,8 @@ static_assert(WMITER * WNITER * TCM * TCN * NUM_WARPS * CORES_PER_CLUSTER ==
|
|||||||
#define BK_LOOP 1
|
#define BK_LOOP 1
|
||||||
// Whether to transpose smem A tile at GMEM->SMEM (produce), or SMEM->RF
|
// Whether to transpose smem A tile at GMEM->SMEM (produce), or SMEM->RF
|
||||||
// (consume). This is because the tensor core expects the A tile to be stored
|
// (consume). This is because the tensor core expects the A tile to be stored
|
||||||
// in column-major order in SMEM, whereas it will be ultimately stored in
|
// in column-major order in SMEM, so a transpose is necessary if A was stored
|
||||||
// row-major in the RF.
|
// row-major in GMEM.
|
||||||
//
|
//
|
||||||
// For correctness, only one of either should be 1. E.g., PRODUCE 1 CONSUME 0
|
// For correctness, only one of either should be 1. E.g., PRODUCE 1 CONSUME 0
|
||||||
// generates the NN kernel where both A and B are stored row-major in GMEM.
|
// generates the NN kernel where both A and B are stored row-major in GMEM.
|
||||||
@@ -72,8 +72,8 @@ static_assert(WMITER * WNITER * TCM * TCN * NUM_WARPS * CORES_PER_CLUSTER ==
|
|||||||
#define TRANSPOSE_AT_PRODUCE 0
|
#define TRANSPOSE_AT_PRODUCE 0
|
||||||
#define TRANSPOSE_AT_CONSUME 0
|
#define TRANSPOSE_AT_CONSUME 0
|
||||||
|
|
||||||
#define GEMMINI_DMA 0
|
#define GEMMINI_DMA 1
|
||||||
#define GEMMINI_DMA_MN_MAJOR 1
|
#define GEMMINI_DMA_FAST 1
|
||||||
#if SMEM_SIZE == 0x4000
|
#if SMEM_SIZE == 0x4000
|
||||||
#define SMEM_ADDR_Q0 ((float * const) 0xff000000)
|
#define SMEM_ADDR_Q0 ((float * const) 0xff000000)
|
||||||
#define SMEM_ADDR_Q1 ((float * const) 0xff001000)
|
#define SMEM_ADDR_Q1 ((float * const) 0xff001000)
|
||||||
@@ -101,6 +101,7 @@ static_assert(WMITER * WNITER * TCM * TCN * NUM_WARPS * CORES_PER_CLUSTER ==
|
|||||||
enum class MemLayout {
|
enum class MemLayout {
|
||||||
MN_major,
|
MN_major,
|
||||||
K_major,
|
K_major,
|
||||||
|
block_row_major, // Gemmini DMA
|
||||||
};
|
};
|
||||||
|
|
||||||
inline constexpr void map_operand_32lanes(const int tid, int &row, int &col) {
|
inline constexpr void map_operand_32lanes(const int tid, int &row, int &col) {
|
||||||
@@ -206,7 +207,7 @@ template <bool use_dma, uint32_t dim_col>
|
|||||||
inline constexpr std::pair<uint32_t, uint32_t>
|
inline constexpr std::pair<uint32_t, uint32_t>
|
||||||
remap_to_gemmini_dma_layout(const uint32_t logical_row,
|
remap_to_gemmini_dma_layout(const uint32_t logical_row,
|
||||||
const uint32_t logical_col) {
|
const uint32_t logical_col) {
|
||||||
static_assert(DIM == 8,
|
static_assert(!use_dma || DIM == 8,
|
||||||
"GEMMINI_DMA layout remapping code only written for DIM == 8");
|
"GEMMINI_DMA layout remapping code only written for DIM == 8");
|
||||||
|
|
||||||
if constexpr (use_dma) {
|
if constexpr (use_dma) {
|
||||||
@@ -253,13 +254,11 @@ inline void wmma_load_a(volatile const T *smem_A, const int local_k,
|
|||||||
constexpr int packed_factor = (std::is_same_v<T, float16_t> ? 2 : 1);
|
constexpr int packed_factor = (std::is_same_v<T, float16_t> ? 2 : 1);
|
||||||
const int local_k_adjusted = local_k / packed_factor;
|
const int local_k_adjusted = local_k / packed_factor;
|
||||||
|
|
||||||
static_assert(!GEMMINI_DMA || (layout == MemLayout::K_major) ||
|
|
||||||
GEMMINI_DMA_MN_MAJOR,
|
|
||||||
"GEMMINI_DMA only supported for K-major A tile");
|
|
||||||
static_assert((layout != MemLayout::K_major) || (FP_SIZE == 32),
|
static_assert((layout != MemLayout::K_major) || (FP_SIZE == 32),
|
||||||
"fp16 is not really tested for K-major A layout");
|
"fp16 is not really tested for K-major A layout");
|
||||||
|
|
||||||
if constexpr (layout == MemLayout::K_major) {
|
if constexpr (layout == MemLayout::K_major ||
|
||||||
|
layout == MemLayout::block_row_major) {
|
||||||
constexpr int smem_A_cols = leading_dim;
|
constexpr int smem_A_cols = leading_dim;
|
||||||
|
|
||||||
// f8-f15 stores a single row of A
|
// f8-f15 stores a single row of A
|
||||||
@@ -269,8 +268,9 @@ inline void wmma_load_a(volatile const T *smem_A, const int local_k,
|
|||||||
// if using Gemmini DMA, remap logical row/col to Gemmini's 2-level
|
// if using Gemmini DMA, remap logical row/col to Gemmini's 2-level
|
||||||
// block-row-major layout
|
// block-row-major layout
|
||||||
const auto [smem_row, smem_col] =
|
const auto [smem_row, smem_col] =
|
||||||
remap_to_gemmini_dma_layout<GEMMINI_DMA, smem_A_cols>(smem_logical_row,
|
remap_to_gemmini_dma_layout<layout == MemLayout::block_row_major,
|
||||||
smem_logical_col);
|
smem_A_cols>(smem_logical_row,
|
||||||
|
smem_logical_col);
|
||||||
|
|
||||||
const volatile uint8_t *smem_addr;
|
const volatile uint8_t *smem_addr;
|
||||||
smem_addr = reinterpret_cast<const volatile uint8_t *>(
|
smem_addr = reinterpret_cast<const volatile uint8_t *>(
|
||||||
@@ -356,8 +356,9 @@ inline void wmma_load_b(const volatile T *smem_B, const int local_k,
|
|||||||
const int thread_in_warp) {
|
const int thread_in_warp) {
|
||||||
asm volatile ("wmma_load_b_start_%=:" :: );
|
asm volatile ("wmma_load_b_start_%=:" :: );
|
||||||
|
|
||||||
static_assert(layout == MemLayout::MN_major,
|
static_assert(
|
||||||
"only N-major layout for the B tile is supported");
|
layout == MemLayout::MN_major || layout == MemLayout::block_row_major,
|
||||||
|
"only N-major or block-row-major layout are supported for the B tile");
|
||||||
|
|
||||||
const int tid = thread_in_warp;
|
const int tid = thread_in_warp;
|
||||||
const int tg = tid / 4;
|
const int tg = tid / 4;
|
||||||
@@ -379,8 +380,9 @@ inline void wmma_load_b(const volatile T *smem_B, const int local_k,
|
|||||||
// if using Gemmini DMA, remap logical row/col to Gemmini's 2-level
|
// if using Gemmini DMA, remap logical row/col to Gemmini's 2-level
|
||||||
// block-row-major layout
|
// block-row-major layout
|
||||||
const auto [smem_row, smem_col] =
|
const auto [smem_row, smem_col] =
|
||||||
remap_to_gemmini_dma_layout<GEMMINI_DMA, smem_B_cols>(smem_logical_row,
|
remap_to_gemmini_dma_layout<layout == MemLayout::block_row_major,
|
||||||
smem_logical_col);
|
smem_B_cols>(smem_logical_row,
|
||||||
|
smem_logical_col);
|
||||||
|
|
||||||
const volatile uint8_t *smem_addr;
|
const volatile uint8_t *smem_addr;
|
||||||
smem_addr = reinterpret_cast<const volatile uint8_t *>(
|
smem_addr = reinterpret_cast<const volatile uint8_t *>(
|
||||||
@@ -388,10 +390,10 @@ inline void wmma_load_b(const volatile T *smem_B, const int local_k,
|
|||||||
smem_B)[smem_B_cols * smem_row + smem_col]);
|
smem_B)[smem_B_cols * smem_row + smem_col]);
|
||||||
// f8-f15 stores a single column of B
|
// f8-f15 stores a single column of B
|
||||||
// threads read from different columns; no bank conflicts
|
// threads read from different columns; no bank conflicts
|
||||||
if constexpr (GEMMINI_DMA) {
|
if constexpr (layout == MemLayout::block_row_major) {
|
||||||
// for GEMMINI_DMA, moving rows for the next 7 elements in the same column
|
// for the block-row-major layout, moving rows for the next 7 elements in
|
||||||
// is the same as moving DIM elements forward in the memory because of the
|
// the same column is the same as moving DIM elements forward in the memory
|
||||||
// block-row-major layout
|
// because of the block-row-major layout
|
||||||
asm volatile("flw f8, %0(%1)" :: "i"(DIM * 0 * sizeof(float)), "r"(smem_addr));
|
asm volatile("flw f8, %0(%1)" :: "i"(DIM * 0 * sizeof(float)), "r"(smem_addr));
|
||||||
asm volatile("flw f9, %0(%1)" :: "i"(DIM * 1 * sizeof(float)), "r"(smem_addr));
|
asm volatile("flw f9, %0(%1)" :: "i"(DIM * 1 * sizeof(float)), "r"(smem_addr));
|
||||||
asm volatile("flw f10, %0(%1)" :: "i"(DIM * 2 * sizeof(float)), "r"(smem_addr));
|
asm volatile("flw f10, %0(%1)" :: "i"(DIM * 2 * sizeof(float)), "r"(smem_addr));
|
||||||
@@ -533,7 +535,9 @@ wmma_store(const int thread_in_warp, const int warp_col, const int warp_row,
|
|||||||
asm volatile ("wmma_store_finish_%=:" :: );
|
asm volatile ("wmma_store_finish_%=:" :: );
|
||||||
}
|
}
|
||||||
|
|
||||||
inline void threadblock_barrier(const uint32_t barrier_id, const uint32_t count) {
|
__attribute__((convergent)) inline void
|
||||||
|
threadblock_barrier(const uint32_t barrier_id, const uint32_t count) {
|
||||||
|
asm volatile("" ::: "memory");
|
||||||
vx_fence();
|
vx_fence();
|
||||||
vx_barrier(barrier_id, count);
|
vx_barrier(barrier_id, count);
|
||||||
}
|
}
|
||||||
@@ -818,6 +822,10 @@ __attribute__((always_inline)) inline void thread_block_gemm_single_tile(
|
|||||||
if (tid_in_threadblock == 0) {
|
if (tid_in_threadblock == 0) {
|
||||||
gemmini_fence();
|
gemmini_fence();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// reconverge after mmio
|
||||||
|
threadblock_barrier(threadblock_id_in_cluster,
|
||||||
|
warps_per_threadblock_per_core);
|
||||||
}
|
}
|
||||||
|
|
||||||
if constexpr (write_to_mem) {
|
if constexpr (write_to_mem) {
|
||||||
@@ -845,9 +853,9 @@ template <
|
|||||||
uint32_t smem_a_dbuf_offset = 0, // byte offset of A
|
uint32_t smem_a_dbuf_offset = 0, // byte offset of A
|
||||||
// double-buffer tile in shared
|
// double-buffer tile in shared
|
||||||
// memory
|
// memory
|
||||||
uint32_t smem_b_offset = sizeof(float) * BM * BK, // byte offset of B tile
|
uint32_t smem_b_offset = sizeof(T) * BM * BK, // byte offset of B tile
|
||||||
// in shared memory
|
// in shared memory
|
||||||
uint32_t smem_b_dbuf_offset = sizeof(float) * BM *
|
uint32_t smem_b_dbuf_offset = sizeof(T) * BM *
|
||||||
BK // byte offset of B double-buffer
|
BK // byte offset of B double-buffer
|
||||||
// tile in shared memory
|
// tile in shared memory
|
||||||
>
|
>
|
||||||
@@ -903,6 +911,8 @@ inline void thread_block_gemm(const T *A, const T *B, float *C,
|
|||||||
for (uint32_t block_m = block_m_start; block_m < block_m_end; block_m++) {
|
for (uint32_t block_m = block_m_start; block_m < block_m_end; block_m++) {
|
||||||
#pragma GCC unroll 1
|
#pragma GCC unroll 1
|
||||||
for (uint32_t block_n = 0; (block_n * BN) < dim_n; block_n++) {
|
for (uint32_t block_n = 0; (block_n * BN) < dim_n; block_n++) {
|
||||||
|
asm volatile ("loop_mn_start_%=:" :: );
|
||||||
|
|
||||||
// clear out accumulators
|
// clear out accumulators
|
||||||
initialize_accum_regs<0>();
|
initialize_accum_regs<0>();
|
||||||
initialize_accum_regs<1>();
|
initialize_accum_regs<1>();
|
||||||
@@ -911,14 +921,13 @@ inline void thread_block_gemm(const T *A, const T *B, float *C,
|
|||||||
// pipeline initiation
|
// pipeline initiation
|
||||||
if (tid_in_threadblock == 0) {
|
if (tid_in_threadblock == 0) {
|
||||||
// configure dma gmem address to load from
|
// configure dma gmem address to load from
|
||||||
// FIXME: block_k is wrong
|
|
||||||
ROCC_INSTRUCTION_RS1_RS2(
|
ROCC_INSTRUCTION_RS1_RS2(
|
||||||
XCUSTOM_ACC,
|
XCUSTOM_ACC,
|
||||||
(uint64_t)(A + block_m * BM * dim_k + /*block_k:*/0 * BK),
|
(uint64_t)(A + block_m * BM * dim_k + /*block_k:*/0 * BK),
|
||||||
(uint64_t)(B + /*block_k:*/0 * BK * dim_n + block_n * BN),
|
(uint64_t)(B + /*block_k:*/0 * BK * dim_n + block_n * BN),
|
||||||
k_LOOP_WS_CONFIG_ADDRS_AB)
|
k_LOOP_WS_CONFIG_ADDRS_AB)
|
||||||
// GEMMINI_CISC(8) does k_LOOP_WS_CONFIG_STRIDES_AB
|
// GEMMINI_CISC(8) does k_LOOP_WS_CONFIG_STRIDES_AB
|
||||||
GEMMINI_CISC_CMD_R((dim_n << 16) | (dim_k << 8) | 8);
|
GEMMINI_CISC_CMD_R((dim_n << 20) | (dim_k << 8) | 8);
|
||||||
gemmini_fence();
|
gemmini_fence();
|
||||||
|
|
||||||
GEMMINI_CISC_CMD_I(10);
|
GEMMINI_CISC_CMD_I(10);
|
||||||
@@ -949,6 +958,7 @@ inline void thread_block_gemm(const T *A, const T *B, float *C,
|
|||||||
|
|
||||||
#pragma GCC unroll 1
|
#pragma GCC unroll 1
|
||||||
for (uint32_t block_k = 0; (block_k * BK) < dim_k; block_k++) {
|
for (uint32_t block_k = 0; (block_k * BK) < dim_k; block_k++) {
|
||||||
|
asm volatile("loop_k_start_%=:" ::);
|
||||||
|
|
||||||
// producer code: GMEM->SMEM memory movement
|
// producer code: GMEM->SMEM memory movement
|
||||||
// ---------------------------------------------------------------------
|
// ---------------------------------------------------------------------
|
||||||
@@ -958,20 +968,19 @@ inline void thread_block_gemm(const T *A, const T *B, float *C,
|
|||||||
#if (GEMMINI_DMA == 1)
|
#if (GEMMINI_DMA == 1)
|
||||||
if ((tid_in_threadblock == 0) && ((block_k * BK) != (dim_k - BK))) {
|
if ((tid_in_threadblock == 0) && ((block_k * BK) != (dim_k - BK))) {
|
||||||
// configure dma gmem address to load from
|
// configure dma gmem address to load from
|
||||||
// FIXME: block_k is wrong
|
|
||||||
ROCC_INSTRUCTION_RS1_RS2(
|
ROCC_INSTRUCTION_RS1_RS2(
|
||||||
XCUSTOM_ACC,
|
XCUSTOM_ACC,
|
||||||
(uint64_t)(A + block_m * BM * dim_k + (block_k + 1/*runahead*/) * BK),
|
(uint64_t)(A + block_m * BM * dim_k + (block_k + 1/*runahead*/) * BK),
|
||||||
(uint64_t)(B + (block_k + 1/*runahead*/) * BK * dim_n + block_n * BN),
|
(uint64_t)(B + (block_k + 1/*runahead*/) * BK * dim_n + block_n * BN),
|
||||||
k_LOOP_WS_CONFIG_ADDRS_AB)
|
k_LOOP_WS_CONFIG_ADDRS_AB)
|
||||||
// GEMMINI_CISC(8) does k_LOOP_WS_CONFIG_STRIDES_AB
|
// GEMMINI_CISC(8) does k_LOOP_WS_CONFIG_STRIDES_AB
|
||||||
GEMMINI_CISC_CMD_R((dim_n << 16) | (dim_k << 8) | 8);
|
GEMMINI_CISC_CMD_R((dim_n << 20) | (dim_k << 8) | 8);
|
||||||
// gemmini_fence();
|
gemmini_fence();
|
||||||
|
|
||||||
// block_k is even: opcode 11 (write to local_a_buf)
|
// block_k is even: opcode 11 (write to local_a_buf)
|
||||||
// block_k is odd: opcode 10 (write to local_a)
|
// block_k is odd: opcode 10 (write to local_a)
|
||||||
const uint32_t opcode = 11 - (block_k & 1);
|
const uint32_t opcode = 11 - (block_k & 1);
|
||||||
GEMMINI_CISC_CMD_R(opcode);
|
GEMMINI_CISC_CMD_I(opcode);
|
||||||
// // TODO: branch is probably slow
|
// // TODO: branch is probably slow
|
||||||
// if (block_k & 1) {
|
// if (block_k & 1) {
|
||||||
// GEMMINI_CISC_CMD_I(12);
|
// GEMMINI_CISC_CMD_I(12);
|
||||||
@@ -1017,6 +1026,10 @@ inline void thread_block_gemm(const T *A, const T *B, float *C,
|
|||||||
/*acc=*/0, /*act=*/NO_ACTIVATION, /*skips=*/skips)
|
/*acc=*/0, /*act=*/NO_ACTIVATION, /*skips=*/skips)
|
||||||
#endif
|
#endif
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// reconverge after mmio divergence
|
||||||
|
threadblock_barrier(threadblock_id_in_cluster,
|
||||||
|
warps_per_threadblock_per_core);
|
||||||
#else
|
#else
|
||||||
// move A
|
// move A
|
||||||
if constexpr (!TRANSPOSE_AT_PRODUCE) {
|
if constexpr (!TRANSPOSE_AT_PRODUCE) {
|
||||||
@@ -1038,10 +1051,10 @@ inline void thread_block_gemm(const T *A, const T *B, float *C,
|
|||||||
warps_per_threadblock_per_core);
|
warps_per_threadblock_per_core);
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
#if 0
|
|
||||||
// consumer code: SMEM->RF and compute
|
// consumer code: SMEM->RF and compute
|
||||||
// ----------------------------------------------------------------------
|
// ----------------------------------------------------------------------
|
||||||
// @perf: this loop spills to stack a lot because of all the flws in
|
// @perf: this loop spills to stack a lot because of all the flws in
|
||||||
|
asm volatile("dbuf_sel_start_%=:" ::);
|
||||||
const T *local_a_consume;
|
const T *local_a_consume;
|
||||||
const T *local_b_consume;
|
const T *local_b_consume;
|
||||||
if constexpr (GEMMINI_DMA) {
|
if constexpr (GEMMINI_DMA) {
|
||||||
@@ -1056,17 +1069,29 @@ inline void thread_block_gemm(const T *A, const T *B, float *C,
|
|||||||
// local_b_consume = reinterpret_cast<T *>(
|
// local_b_consume = reinterpret_cast<T *>(
|
||||||
// (mask_odd & reinterpret_cast<uintmax_t>(local_b_buf)) |
|
// (mask_odd & reinterpret_cast<uintmax_t>(local_b_buf)) |
|
||||||
// (mask_even & reinterpret_cast<uintmax_t>(local_b)));
|
// (mask_even & reinterpret_cast<uintmax_t>(local_b)));
|
||||||
local_a_consume = local_a + (block_k & 1) * (BM * BK);
|
local_a_consume = local_a + (block_k & 1) *
|
||||||
local_b_consume = local_b + (block_k & 1) * (BK * BN);
|
(smem_a_dbuf_offset - smem_a_offset) /
|
||||||
|
sizeof(T);
|
||||||
|
local_b_consume = local_b + (block_k & 1) *
|
||||||
|
(smem_b_dbuf_offset - smem_b_offset) /
|
||||||
|
sizeof(T);
|
||||||
} else {
|
} else {
|
||||||
// no double-buffering without DMA
|
// no double-buffering without DMA
|
||||||
local_a_consume = local_a;
|
local_a_consume = local_a;
|
||||||
local_b_consume = local_b;
|
local_b_consume = local_b;
|
||||||
}
|
}
|
||||||
|
asm volatile("dbuf_sel_end_%=:" ::);
|
||||||
|
|
||||||
constexpr MemLayout layout_a =
|
constexpr MemLayout layout_a =
|
||||||
TRANSPOSE_AT_CONSUME ? MemLayout::K_major : MemLayout::MN_major;
|
GEMMINI_DMA ? (GEMMINI_DMA_FAST ? MemLayout::MN_major
|
||||||
thread_block_gemm_single_tile<T, layout_a, MemLayout::MN_major,
|
: MemLayout::block_row_major)
|
||||||
|
: (TRANSPOSE_AT_CONSUME ? MemLayout::K_major
|
||||||
|
: MemLayout::MN_major);
|
||||||
|
constexpr MemLayout layout_b =
|
||||||
|
GEMMINI_DMA ? (GEMMINI_DMA_FAST ? MemLayout::MN_major
|
||||||
|
: MemLayout::block_row_major)
|
||||||
|
: MemLayout::MN_major;
|
||||||
|
thread_block_gemm_single_tile<T, layout_a, layout_b,
|
||||||
BM, BN, BK, 0, 0,
|
BM, BN, BK, 0, 0,
|
||||||
/*load_accum=*/false,
|
/*load_accum=*/false,
|
||||||
/*write_to_mem=*/false>(
|
/*write_to_mem=*/false>(
|
||||||
@@ -1087,7 +1112,8 @@ inline void thread_block_gemm(const T *A, const T *B, float *C,
|
|||||||
|
|
||||||
threadblock_barrier(threadblock_id_in_cluster,
|
threadblock_barrier(threadblock_id_in_cluster,
|
||||||
warps_per_threadblock_per_core);
|
warps_per_threadblock_per_core);
|
||||||
#endif
|
|
||||||
|
asm volatile("loop_k_end_%=:" ::);
|
||||||
}
|
}
|
||||||
|
|
||||||
if constexpr (write_to_gmem) {
|
if constexpr (write_to_gmem) {
|
||||||
@@ -1102,6 +1128,7 @@ inline void thread_block_gemm(const T *A, const T *B, float *C,
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
asm volatile("loop_mn_end_%=:" ::);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user