sgemm_impl: Add skeleton wgmma routine for single_tile
This commit is contained in:
@@ -233,6 +233,18 @@ inline void vx_wmma(const int dest_reg) {
|
||||
}
|
||||
}
|
||||
|
||||
inline void vx_wgmma() {
|
||||
// .insn r opcode6, func3, func7, rd, rs1, rs2
|
||||
// https://www.rowleydownload.co.uk/arm/documentation/gnu/as/RISC_002dV_002dFormats.html#RISC_002dV_002dFormats
|
||||
asm volatile (".insn r %0, 0, 0, x0, x0, x0" :: "i"(RISCV_CUSTOM3));
|
||||
}
|
||||
|
||||
inline void vx_wgmma_wait() {
|
||||
// .insn r opcode6, func3, func7, rd, rs1, rs2
|
||||
// func3 == 1 encodes wait
|
||||
asm volatile (".insn r %0, 1, 0, x0, x0, x0" :: "i"(RISCV_CUSTOM3));
|
||||
}
|
||||
|
||||
// Remap logical row/col coordinate of a matrix element to a memory index that
|
||||
// follows the 2-level block-row-major layout that Gemmini DMA uses
|
||||
template <bool use_dma, uint32_t dim_col>
|
||||
@@ -804,9 +816,19 @@ __attribute__((always_inline)) inline void thread_block_gemm_single_tile(
|
||||
const uint32_t warps_per_threadblock_per_core =
|
||||
NUM_WARPS / threadblocks_per_cluster;
|
||||
|
||||
// TODO: it would be useful if this bit is split out into a function, so that
|
||||
// preloading accumulation tile can be used for full GEMMs at the start of
|
||||
// the K-loop.
|
||||
if constexpr (TENSOR_HOPPER) {
|
||||
#pragma GCC unroll 1
|
||||
for (int i = 0; i < BK_LOOP; i++) {
|
||||
#pragma GCC unroll 4
|
||||
for (uint32_t local_k = 0; local_k < tile_dim_k; local_k += TCK) {
|
||||
// FIXME: use local_a and local_b here
|
||||
vx_wgmma();
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// TODO: it would be useful if this bit is split out into a function, so
|
||||
// that preloading accumulation tile can be used for full GEMMs at the start
|
||||
// of the K-loop.
|
||||
if constexpr (load_accum) {
|
||||
#pragma GCC unroll
|
||||
for (int wm_iter = 0; wm_iter < WMITER; wm_iter++) {
|
||||
@@ -837,8 +859,8 @@ __attribute__((always_inline)) inline void thread_block_gemm_single_tile(
|
||||
wmma_load_a<T, layout_a, tile_dim_m, tile_dim_n, tile_dim_k>(
|
||||
local_a, local_k, warp_row, wm_iter, tid_in_warp);
|
||||
} else {
|
||||
wmma_load_a<T, layout_a, leading_dim_a>(local_a, local_k, warp_row,
|
||||
wm_iter, tid_in_warp);
|
||||
wmma_load_a<T, layout_a, leading_dim_a>(
|
||||
local_a, local_k, warp_row, wm_iter, tid_in_warp);
|
||||
}
|
||||
// perform mma
|
||||
vx_wmma(wm_iter);
|
||||
@@ -846,6 +868,7 @@ __attribute__((always_inline)) inline void thread_block_gemm_single_tile(
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if constexpr (GEMMINI_DMA) {
|
||||
// Call gemmini fence at the end of the loop to overlap dma & wmma.
|
||||
|
||||
Reference in New Issue
Block a user