sgemm_impl: Use 12-bit cmd interface, allow DIM=16

This commit is contained in:
Hansung Kim
2024-09-08 14:28:27 -07:00
parent adcd0a9d49
commit 42913c00c4

View File

@@ -207,7 +207,7 @@ template <bool use_dma, uint32_t dim_col>
inline constexpr std::pair<uint32_t, uint32_t>
remap_to_gemmini_dma_layout(const uint32_t logical_row,
const uint32_t logical_col) {
static_assert(DIM == 8,
static_assert(GEMMINI_DMA_FLEXIBLE_LAYOUT || DIM == 8,
"GEMMINI_DMA layout remapping code only written for DIM == 8");
if constexpr (use_dma) {
@@ -905,6 +905,8 @@ inline void thread_block_gemm(const T *A, const T *B, float *C,
for (uint32_t block_m = block_m_start; block_m < block_m_end; block_m++) {
#pragma GCC unroll 1
for (uint32_t block_n = 0; (block_n * BN) < dim_n; block_n++) {
asm volatile ("loop_mn_start_%=:" :: );
// clear out accumulators
initialize_accum_regs<0>();
initialize_accum_regs<1>();
@@ -920,7 +922,7 @@ inline void thread_block_gemm(const T *A, const T *B, float *C,
(uint64_t)(B + /*block_k:*/0 * BK * dim_n + block_n * BN),
k_LOOP_WS_CONFIG_ADDRS_AB)
// GEMMINI_CISC(8) does k_LOOP_WS_CONFIG_STRIDES_AB
GEMMINI_CISC_CMD_R((dim_n << 16) | (dim_k << 8) | 8);
GEMMINI_CISC_CMD_R((dim_n << 20) | (dim_k << 8) | 8);
gemmini_fence();
GEMMINI_CISC_CMD_I(10);
@@ -951,6 +953,7 @@ inline void thread_block_gemm(const T *A, const T *B, float *C,
#pragma GCC unroll 1
for (uint32_t block_k = 0; (block_k * BK) < dim_k; block_k++) {
asm volatile("loop_k_start_%=:" ::);
// producer code: GMEM->SMEM memory movement
// ---------------------------------------------------------------------
@@ -967,8 +970,8 @@ inline void thread_block_gemm(const T *A, const T *B, float *C,
(uint64_t)(B + (block_k + 1/*runahead*/) * BK * dim_n + block_n * BN),
k_LOOP_WS_CONFIG_ADDRS_AB)
// GEMMINI_CISC(8) does k_LOOP_WS_CONFIG_STRIDES_AB
GEMMINI_CISC_CMD_R((dim_n << 16) | (dim_k << 8) | 8);
// gemmini_fence();
GEMMINI_CISC_CMD_R((dim_n << 20) | (dim_k << 8) | 8);
gemmini_fence();
// block_k is even: opcode 11 (write to local_a_buf)
// block_k is odd: opcode 10 (write to local_a)
@@ -1043,6 +1046,7 @@ inline void thread_block_gemm(const T *A, const T *B, float *C,
// consumer code: SMEM->RF and compute
// ----------------------------------------------------------------------
// @perf: this loop spills to stack a lot because of all the flws in
asm volatile("dbuf_sel_start_%=:" ::);
const T *local_a_consume;
const T *local_b_consume;
if constexpr (GEMMINI_DMA) {
@@ -1064,6 +1068,7 @@ inline void thread_block_gemm(const T *A, const T *B, float *C,
local_a_consume = local_a;
local_b_consume = local_b;
}
asm volatile("dbuf_sel_end_%=:" ::);
constexpr MemLayout layout_a =
GEMMINI_DMA ? MemLayout::block_row_major
@@ -1092,6 +1097,8 @@ inline void thread_block_gemm(const T *A, const T *B, float *C,
threadblock_barrier(threadblock_id_in_cluster,
warps_per_threadblock_per_core);
asm volatile("loop_k_end_%=:" ::);
}
if constexpr (write_to_gmem) {
@@ -1106,6 +1113,7 @@ inline void thread_block_gemm(const T *A, const T *B, float *C,
}
}
}
asm volatile("loop_mn_end_%=:" ::);
}
}