From 42913c00c410541bccada1321202b42a5d024232 Mon Sep 17 00:00:00 2001 From: Hansung Kim Date: Sun, 8 Sep 2024 14:28:27 -0700 Subject: [PATCH] sgemm_impl: Use 12-bit cmd interface, allow DIM=16 --- tests/regression/sgemm_tcore/sgemm_impl.hpp | 16 ++++++++++++---- 1 file changed, 12 insertions(+), 4 deletions(-) diff --git a/tests/regression/sgemm_tcore/sgemm_impl.hpp b/tests/regression/sgemm_tcore/sgemm_impl.hpp index 0c6274a2..d2e88ace 100644 --- a/tests/regression/sgemm_tcore/sgemm_impl.hpp +++ b/tests/regression/sgemm_tcore/sgemm_impl.hpp @@ -207,7 +207,7 @@ template inline constexpr std::pair remap_to_gemmini_dma_layout(const uint32_t logical_row, const uint32_t logical_col) { - static_assert(DIM == 8, + static_assert(GEMMINI_DMA_FLEXIBLE_LAYOUT || DIM == 8, "GEMMINI_DMA layout remapping code only written for DIM == 8"); if constexpr (use_dma) { @@ -905,6 +905,8 @@ inline void thread_block_gemm(const T *A, const T *B, float *C, for (uint32_t block_m = block_m_start; block_m < block_m_end; block_m++) { #pragma GCC unroll 1 for (uint32_t block_n = 0; (block_n * BN) < dim_n; block_n++) { + asm volatile ("loop_mn_start_%=:" :: ); + // clear out accumulators initialize_accum_regs<0>(); initialize_accum_regs<1>(); @@ -920,7 +922,7 @@ inline void thread_block_gemm(const T *A, const T *B, float *C, (uint64_t)(B + /*block_k:*/0 * BK * dim_n + block_n * BN), k_LOOP_WS_CONFIG_ADDRS_AB) // GEMMINI_CISC(8) does k_LOOP_WS_CONFIG_STRIDES_AB - GEMMINI_CISC_CMD_R((dim_n << 16) | (dim_k << 8) | 8); + GEMMINI_CISC_CMD_R((dim_n << 20) | (dim_k << 8) | 8); gemmini_fence(); GEMMINI_CISC_CMD_I(10); @@ -951,6 +953,7 @@ inline void thread_block_gemm(const T *A, const T *B, float *C, #pragma GCC unroll 1 for (uint32_t block_k = 0; (block_k * BK) < dim_k; block_k++) { + asm volatile("loop_k_start_%=:" ::); // producer code: GMEM->SMEM memory movement // --------------------------------------------------------------------- @@ -967,8 +970,8 @@ inline void thread_block_gemm(const T *A, const T *B, float *C, (uint64_t)(B + (block_k + 1/*runahead*/) * BK * dim_n + block_n * BN), k_LOOP_WS_CONFIG_ADDRS_AB) // GEMMINI_CISC(8) does k_LOOP_WS_CONFIG_STRIDES_AB - GEMMINI_CISC_CMD_R((dim_n << 16) | (dim_k << 8) | 8); - // gemmini_fence(); + GEMMINI_CISC_CMD_R((dim_n << 20) | (dim_k << 8) | 8); + gemmini_fence(); // block_k is even: opcode 11 (write to local_a_buf) // block_k is odd: opcode 10 (write to local_a) @@ -1043,6 +1046,7 @@ inline void thread_block_gemm(const T *A, const T *B, float *C, // consumer code: SMEM->RF and compute // ---------------------------------------------------------------------- // @perf: this loop spills to stack a lot because of all the flws in + asm volatile("dbuf_sel_start_%=:" ::); const T *local_a_consume; const T *local_b_consume; if constexpr (GEMMINI_DMA) { @@ -1064,6 +1068,7 @@ inline void thread_block_gemm(const T *A, const T *B, float *C, local_a_consume = local_a; local_b_consume = local_b; } + asm volatile("dbuf_sel_end_%=:" ::); constexpr MemLayout layout_a = GEMMINI_DMA ? MemLayout::block_row_major @@ -1092,6 +1097,8 @@ inline void thread_block_gemm(const T *A, const T *B, float *C, threadblock_barrier(threadblock_id_in_cluster, warps_per_threadblock_per_core); + + asm volatile("loop_k_end_%=:" ::); } if constexpr (write_to_gmem) { @@ -1106,6 +1113,7 @@ inline void thread_block_gemm(const T *A, const T *B, float *C, } } } + asm volatile("loop_mn_end_%=:" ::); } }