sgemm: Specify A/B tile SMEM address via template args
& split single-time GEMM into a separate function.
This commit is contained in:
@@ -235,7 +235,6 @@ inline void vx_wmma_load_a(volatile const T *smem_A, const int local_k,
|
||||
|
||||
// int A_offset = (WM * warp_row + TCM * wm_iter + row) * smem_A_cols;
|
||||
|
||||
// @perf: bank conflicts
|
||||
// f8-f15 stores a single row of A
|
||||
const volatile uint8_t *smem_addr;
|
||||
smem_addr = reinterpret_cast<const volatile uint8_t *>(
|
||||
@@ -243,7 +242,7 @@ inline void vx_wmma_load_a(volatile const T *smem_A, const int local_k,
|
||||
smem_A)[(WM * warp_row + TCM * wm_iter + row) * smem_A_cols +
|
||||
local_k /* FIXME: adjust for fp16? */]);
|
||||
// step to the next column
|
||||
// threads read from different rows; bank conflicts
|
||||
// @perf: bank conflicts; threads read from different rows
|
||||
asm volatile("flw f0, %0(%1)" ::"i"(0 * sizeof(float)), "r"(smem_addr));
|
||||
asm volatile("flw f1, %0(%1)" ::"i"(1 * sizeof(float)), "r"(smem_addr));
|
||||
asm volatile("flw f2, %0(%1)" ::"i"(2 * sizeof(float)), "r"(smem_addr));
|
||||
@@ -408,6 +407,7 @@ inline void threadblock_barrier(const uint32_t barrier_id, const uint32_t count)
|
||||
|
||||
// TODO: reduce args by passing leading A/B dimensions
|
||||
template <typename T>
|
||||
__attribute__((always_inline))
|
||||
inline void global_dmem_load(const uint32_t dim_m, const uint32_t dim_n, const uint32_t dim_k,
|
||||
const uint32_t k, const T *A, const T *B,
|
||||
volatile T *local_a, volatile T *local_b,
|
||||
@@ -646,10 +646,66 @@ inline void global_dmem_load(const uint32_t dim_m, const uint32_t dim_n, const u
|
||||
asm volatile ("global_dmem_load_finish_%=:" :: );
|
||||
}
|
||||
|
||||
template <typename T, bool write_to_gmem = true>
|
||||
// Do a single tile*tile matrix multiplication using the matrix data stored in
|
||||
// SMEM. Useful in fused kernels where GEMMs are done at a per-tile scope.
|
||||
template <typename T>
|
||||
__attribute__((always_inline)) inline void
|
||||
thread_block_gemm_single_tile(const T *local_a, const T *local_b,
|
||||
const uint32_t tid_in_threadblock,
|
||||
const uint32_t threads_per_threadblock) {
|
||||
// no double-buffering
|
||||
// FIXME: duplicated from thread_block_gemm
|
||||
const uint32_t threads_per_warpgroup = threads_per_threadblock;
|
||||
const uint32_t warp_id_in_warpgroup = tid_in_threadblock / NUM_THREADS;
|
||||
const uint32_t warp_row = warp_id_in_warpgroup / (BN / WN);
|
||||
const uint32_t warp_col = warp_id_in_warpgroup % (BN / WN);
|
||||
const uint32_t tid_in_warp = tid_in_threadblock % NUM_THREADS;
|
||||
|
||||
#pragma GCC unroll 1
|
||||
for (int i = 0; i < BK_LOOP; i++) {
|
||||
#pragma GCC unroll 4
|
||||
for (uint32_t local_k = 0; local_k < BK; local_k += TCK) {
|
||||
#pragma GCC unroll 2
|
||||
for (int wn_iter = 0; wn_iter < WNITER; wn_iter++) {
|
||||
// SMEM -> RF
|
||||
vx_wmma_load_b<T>(local_b, local_k, warp_col, wn_iter, tid_in_warp);
|
||||
#pragma GCC unroll 2
|
||||
for (int wm_iter = 0; wm_iter < WMITER; wm_iter++) {
|
||||
// SMEM -> RF
|
||||
vx_wmma_load_a<T>(local_a, local_k, warp_row, wm_iter, tid_in_warp);
|
||||
// perform mma
|
||||
vx_wmma(wm_iter);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if constexpr (GEMMINI_DMA) {
|
||||
// Call gemmini fence at the end of the loop to overlap dma & wmma.
|
||||
// Usually, by this time, dma has finished the copy so that this
|
||||
// becomes a no-op.
|
||||
if (tid_in_threadblock == 0) {
|
||||
gemmini_fence();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T, bool write_to_gmem = true,
|
||||
// by default, A/B tiles are placed at the start of the smem
|
||||
uint32_t smem_a_offset = 0, // byte offset of A tile in shared
|
||||
// memory
|
||||
uint32_t smem_a_dbuf_offset = 0, // byte offset of A
|
||||
// double-buffer tile in shared
|
||||
// memory
|
||||
uint32_t smem_b_offset = sizeof(float) * BM *
|
||||
BK, // byte offset of B tile
|
||||
// in shared memory
|
||||
uint32_t smem_b_dbuf_offset = sizeof(float) * BM *
|
||||
BK // byte offset of B double-buffer
|
||||
// tile in shared memory
|
||||
>
|
||||
inline void thread_block_gemm(const T *A, const T *B, float *C,
|
||||
const uint32_t dim_m,
|
||||
const uint32_t dim_n,
|
||||
const uint32_t dim_m, const uint32_t dim_n,
|
||||
const uint32_t dim_k,
|
||||
const uint32_t tid_in_threadblock,
|
||||
const uint32_t threads_per_threadblock,
|
||||
@@ -672,13 +728,14 @@ inline void thread_block_gemm(const T *A, const T *B, float *C,
|
||||
const uint32_t warps_per_threadblock_per_core =
|
||||
NUM_WARPS / threads_per_threadblock;
|
||||
|
||||
volatile T *local_a = reinterpret_cast<T *>(sharedmem_per_threadblock);
|
||||
constexpr size_t local_a_elems = (BM * BK);
|
||||
volatile T *local_a_buf = local_a + local_a_elems;
|
||||
|
||||
volatile T *local_b = local_a_buf + local_a_elems;
|
||||
constexpr size_t local_b_elems = (BK * BN);
|
||||
volatile T *local_b_buf = local_a_buf + local_b_elems;
|
||||
volatile T *local_a =
|
||||
reinterpret_cast<T *>(sharedmem_per_threadblock + smem_a_offset);
|
||||
volatile T *local_a_buf =
|
||||
reinterpret_cast<T *>(sharedmem_per_threadblock + smem_a_dbuf_offset);
|
||||
volatile T *local_b =
|
||||
reinterpret_cast<T *>(sharedmem_per_threadblock + smem_b_offset);
|
||||
volatile T *local_b_buf =
|
||||
reinterpret_cast<T *>(sharedmem_per_threadblock + smem_b_dbuf_offset);
|
||||
|
||||
constexpr uint32_t skips =
|
||||
loop_matmul_skips(/*skip_lda=*/0, /*skip_ldb=*/0, /*skip_ldd=*/1,
|
||||
@@ -849,34 +906,17 @@ inline void thread_block_gemm(const T *A, const T *B, float *C,
|
||||
// local_b_consume = reinterpret_cast<volatile T *>(
|
||||
// (mask_odd & reinterpret_cast<uintmax_t>(local_b_buf)) |
|
||||
// (mask_even & reinterpret_cast<uintmax_t>(local_b)));
|
||||
local_a_consume = local_a + (block_k & 1) * (local_a_elems);
|
||||
local_b_consume = local_b + (block_k & 1) * (local_b_elems);
|
||||
local_a_consume = local_a + (block_k & 1) * (BM * BK);
|
||||
local_b_consume = local_b + (block_k & 1) * (BK * BN);
|
||||
} else {
|
||||
// no double-buffering without DMA
|
||||
local_a_consume = local_a;
|
||||
local_b_consume = local_b;
|
||||
}
|
||||
|
||||
#pragma GCC unroll 1
|
||||
for (int i = 0; i < BK_LOOP; i++) {
|
||||
#pragma GCC unroll 4
|
||||
for (uint32_t local_k = 0; local_k < BK; local_k += TCK) {
|
||||
#pragma GCC unroll 2
|
||||
for (int wn_iter = 0; wn_iter < WNITER; wn_iter++) {
|
||||
// SMEM -> RF
|
||||
vx_wmma_load_b<T>(local_b_consume, local_k, warp_col, wn_iter,
|
||||
tid_in_warp);
|
||||
#pragma GCC unroll 2
|
||||
for (int wm_iter = 0; wm_iter < WMITER; wm_iter++) {
|
||||
// SMEM -> RF
|
||||
vx_wmma_load_a<T>(local_a_consume, local_k, warp_row, wm_iter,
|
||||
tid_in_warp);
|
||||
// perform mma
|
||||
vx_wmma(wm_iter);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
thread_block_gemm_single_tile(local_a_consume, local_b_consume,
|
||||
tid_in_threadblock,
|
||||
threads_per_threadblock);
|
||||
|
||||
if constexpr (GEMMINI_DMA) {
|
||||
// Call gemmini fence at the end of the loop to overlap dma & wmma.
|
||||
|
||||
Reference in New Issue
Block a user