diff --git a/tests/regression/sgemm_tcore/sgemm_impl.hpp b/tests/regression/sgemm_tcore/sgemm_impl.hpp index db6800fc..f3cfa1a3 100644 --- a/tests/regression/sgemm_tcore/sgemm_impl.hpp +++ b/tests/regression/sgemm_tcore/sgemm_impl.hpp @@ -320,9 +320,10 @@ inline void wmma_load_b(const volatile T *smem_B, const int local_k, asm volatile ("wmma_load_b_finish_%=:" :: ); } -inline void initialize_C(const int dest_reg) { - // initialize C to zeros - if (dest_reg == 0) { +// Initialize the accumulator registers to zero before starting FMA operations +// with the tensor cores. +template inline void initialize_accum_regs() { + if constexpr (accum_reg_set == 0) { asm volatile("fmv.w.x f16, x0"); asm volatile("fmv.w.x f17, x0"); asm volatile("fmv.w.x f18, x0"); @@ -650,13 +651,6 @@ inline void thread_block_gemm(const T *A, const T *B, float *C, const uint32_t threadblocks_per_cluster, const uint32_t threadblock_id_in_cluster, uint8_t *sharedmem_per_threadblock) { - const uint32_t local_a_row = tid_in_threadblock / BK; - const uint32_t local_a_col = tid_in_threadblock % BK; - const uint32_t local_as_row = tid_in_threadblock / BM; - const uint32_t local_as_col = tid_in_threadblock % BM; - const uint32_t local_b_row = tid_in_threadblock / BN; - const uint32_t local_b_col = tid_in_threadblock % BN; - // no double-buffering const uint32_t threads_per_warpgroup = threads_per_threadblock; const uint32_t warp_id_in_warpgroup = tid_in_threadblock / NUM_THREADS; @@ -703,9 +697,9 @@ inline void thread_block_gemm(const T *A, const T *B, float *C, for (uint32_t block_m = block_m_start; block_m < block_m_end; block_m++) { #pragma GCC unroll 1 for (uint32_t block_n = 0; (block_n * BN) < dim_n; block_n++) { - // clear out C - initialize_C(0); - initialize_C(1); + // clear out accumulators + initialize_accum_regs<0>(); + initialize_accum_regs<1>(); if constexpr (GEMMINI_DMA) { // pipeline initiation