sgemm: Specify A/B tile SMEM address via template args

& split single-time GEMM into a separate function.
This commit is contained in:
Hansung Kim
2024-08-16 16:27:35 -07:00
parent 64b9717064
commit d0809d292a

View File

@@ -235,7 +235,6 @@ inline void vx_wmma_load_a(volatile const T *smem_A, const int local_k,
// int A_offset = (WM * warp_row + TCM * wm_iter + row) * smem_A_cols;
// @perf: bank conflicts
// f8-f15 stores a single row of A
const volatile uint8_t *smem_addr;
smem_addr = reinterpret_cast<const volatile uint8_t *>(
@@ -243,7 +242,7 @@ inline void vx_wmma_load_a(volatile const T *smem_A, const int local_k,
smem_A)[(WM * warp_row + TCM * wm_iter + row) * smem_A_cols +
local_k /* FIXME: adjust for fp16? */]);
// step to the next column
// threads read from different rows; bank conflicts
// @perf: bank conflicts; threads read from different rows
asm volatile("flw f0, %0(%1)" ::"i"(0 * sizeof(float)), "r"(smem_addr));
asm volatile("flw f1, %0(%1)" ::"i"(1 * sizeof(float)), "r"(smem_addr));
asm volatile("flw f2, %0(%1)" ::"i"(2 * sizeof(float)), "r"(smem_addr));
@@ -408,6 +407,7 @@ inline void threadblock_barrier(const uint32_t barrier_id, const uint32_t count)
// TODO: reduce args by passing leading A/B dimensions
template <typename T>
__attribute__((always_inline))
inline void global_dmem_load(const uint32_t dim_m, const uint32_t dim_n, const uint32_t dim_k,
const uint32_t k, const T *A, const T *B,
volatile T *local_a, volatile T *local_b,
@@ -646,10 +646,66 @@ inline void global_dmem_load(const uint32_t dim_m, const uint32_t dim_n, const u
asm volatile ("global_dmem_load_finish_%=:" :: );
}
template <typename T, bool write_to_gmem = true>
// Do a single tile*tile matrix multiplication using the matrix data stored in
// SMEM. Useful in fused kernels where GEMMs are done at a per-tile scope.
template <typename T>
__attribute__((always_inline)) inline void
thread_block_gemm_single_tile(const T *local_a, const T *local_b,
const uint32_t tid_in_threadblock,
const uint32_t threads_per_threadblock) {
// no double-buffering
// FIXME: duplicated from thread_block_gemm
const uint32_t threads_per_warpgroup = threads_per_threadblock;
const uint32_t warp_id_in_warpgroup = tid_in_threadblock / NUM_THREADS;
const uint32_t warp_row = warp_id_in_warpgroup / (BN / WN);
const uint32_t warp_col = warp_id_in_warpgroup % (BN / WN);
const uint32_t tid_in_warp = tid_in_threadblock % NUM_THREADS;
#pragma GCC unroll 1
for (int i = 0; i < BK_LOOP; i++) {
#pragma GCC unroll 4
for (uint32_t local_k = 0; local_k < BK; local_k += TCK) {
#pragma GCC unroll 2
for (int wn_iter = 0; wn_iter < WNITER; wn_iter++) {
// SMEM -> RF
vx_wmma_load_b<T>(local_b, local_k, warp_col, wn_iter, tid_in_warp);
#pragma GCC unroll 2
for (int wm_iter = 0; wm_iter < WMITER; wm_iter++) {
// SMEM -> RF
vx_wmma_load_a<T>(local_a, local_k, warp_row, wm_iter, tid_in_warp);
// perform mma
vx_wmma(wm_iter);
}
}
}
}
if constexpr (GEMMINI_DMA) {
// Call gemmini fence at the end of the loop to overlap dma & wmma.
// Usually, by this time, dma has finished the copy so that this
// becomes a no-op.
if (tid_in_threadblock == 0) {
gemmini_fence();
}
}
}
template <typename T, bool write_to_gmem = true,
// by default, A/B tiles are placed at the start of the smem
uint32_t smem_a_offset = 0, // byte offset of A tile in shared
// memory
uint32_t smem_a_dbuf_offset = 0, // byte offset of A
// double-buffer tile in shared
// memory
uint32_t smem_b_offset = sizeof(float) * BM *
BK, // byte offset of B tile
// in shared memory
uint32_t smem_b_dbuf_offset = sizeof(float) * BM *
BK // byte offset of B double-buffer
// tile in shared memory
>
inline void thread_block_gemm(const T *A, const T *B, float *C,
const uint32_t dim_m,
const uint32_t dim_n,
const uint32_t dim_m, const uint32_t dim_n,
const uint32_t dim_k,
const uint32_t tid_in_threadblock,
const uint32_t threads_per_threadblock,
@@ -672,13 +728,14 @@ inline void thread_block_gemm(const T *A, const T *B, float *C,
const uint32_t warps_per_threadblock_per_core =
NUM_WARPS / threads_per_threadblock;
volatile T *local_a = reinterpret_cast<T *>(sharedmem_per_threadblock);
constexpr size_t local_a_elems = (BM * BK);
volatile T *local_a_buf = local_a + local_a_elems;
volatile T *local_b = local_a_buf + local_a_elems;
constexpr size_t local_b_elems = (BK * BN);
volatile T *local_b_buf = local_a_buf + local_b_elems;
volatile T *local_a =
reinterpret_cast<T *>(sharedmem_per_threadblock + smem_a_offset);
volatile T *local_a_buf =
reinterpret_cast<T *>(sharedmem_per_threadblock + smem_a_dbuf_offset);
volatile T *local_b =
reinterpret_cast<T *>(sharedmem_per_threadblock + smem_b_offset);
volatile T *local_b_buf =
reinterpret_cast<T *>(sharedmem_per_threadblock + smem_b_dbuf_offset);
constexpr uint32_t skips =
loop_matmul_skips(/*skip_lda=*/0, /*skip_ldb=*/0, /*skip_ldd=*/1,
@@ -849,34 +906,17 @@ inline void thread_block_gemm(const T *A, const T *B, float *C,
// local_b_consume = reinterpret_cast<volatile T *>(
// (mask_odd & reinterpret_cast<uintmax_t>(local_b_buf)) |
// (mask_even & reinterpret_cast<uintmax_t>(local_b)));
local_a_consume = local_a + (block_k & 1) * (local_a_elems);
local_b_consume = local_b + (block_k & 1) * (local_b_elems);
local_a_consume = local_a + (block_k & 1) * (BM * BK);
local_b_consume = local_b + (block_k & 1) * (BK * BN);
} else {
// no double-buffering without DMA
local_a_consume = local_a;
local_b_consume = local_b;
}
#pragma GCC unroll 1
for (int i = 0; i < BK_LOOP; i++) {
#pragma GCC unroll 4
for (uint32_t local_k = 0; local_k < BK; local_k += TCK) {
#pragma GCC unroll 2
for (int wn_iter = 0; wn_iter < WNITER; wn_iter++) {
// SMEM -> RF
vx_wmma_load_b<T>(local_b_consume, local_k, warp_col, wn_iter,
tid_in_warp);
#pragma GCC unroll 2
for (int wm_iter = 0; wm_iter < WMITER; wm_iter++) {
// SMEM -> RF
vx_wmma_load_a<T>(local_a_consume, local_k, warp_row, wm_iter,
tid_in_warp);
// perform mma
vx_wmma(wm_iter);
}
}
}
}
thread_block_gemm_single_tile(local_a_consume, local_b_consume,
tid_in_threadblock,
threads_per_threadblock);
if constexpr (GEMMINI_DMA) {
// Call gemmini fence at the end of the loop to overlap dma & wmma.