sgemm_impl: Rename initialize_C
This commit is contained in:
@@ -320,9 +320,10 @@ inline void wmma_load_b(const volatile T *smem_B, const int local_k,
|
||||
asm volatile ("wmma_load_b_finish_%=:" :: );
|
||||
}
|
||||
|
||||
inline void initialize_C(const int dest_reg) {
|
||||
// initialize C to zeros
|
||||
if (dest_reg == 0) {
|
||||
// Initialize the accumulator registers to zero before starting FMA operations
|
||||
// with the tensor cores.
|
||||
template <int accum_reg_set> inline void initialize_accum_regs() {
|
||||
if constexpr (accum_reg_set == 0) {
|
||||
asm volatile("fmv.w.x f16, x0");
|
||||
asm volatile("fmv.w.x f17, x0");
|
||||
asm volatile("fmv.w.x f18, x0");
|
||||
@@ -650,13 +651,6 @@ inline void thread_block_gemm(const T *A, const T *B, float *C,
|
||||
const uint32_t threadblocks_per_cluster,
|
||||
const uint32_t threadblock_id_in_cluster,
|
||||
uint8_t *sharedmem_per_threadblock) {
|
||||
const uint32_t local_a_row = tid_in_threadblock / BK;
|
||||
const uint32_t local_a_col = tid_in_threadblock % BK;
|
||||
const uint32_t local_as_row = tid_in_threadblock / BM;
|
||||
const uint32_t local_as_col = tid_in_threadblock % BM;
|
||||
const uint32_t local_b_row = tid_in_threadblock / BN;
|
||||
const uint32_t local_b_col = tid_in_threadblock % BN;
|
||||
|
||||
// no double-buffering
|
||||
const uint32_t threads_per_warpgroup = threads_per_threadblock;
|
||||
const uint32_t warp_id_in_warpgroup = tid_in_threadblock / NUM_THREADS;
|
||||
@@ -703,9 +697,9 @@ inline void thread_block_gemm(const T *A, const T *B, float *C,
|
||||
for (uint32_t block_m = block_m_start; block_m < block_m_end; block_m++) {
|
||||
#pragma GCC unroll 1
|
||||
for (uint32_t block_n = 0; (block_n * BN) < dim_n; block_n++) {
|
||||
// clear out C
|
||||
initialize_C(0);
|
||||
initialize_C(1);
|
||||
// clear out accumulators
|
||||
initialize_accum_regs<0>();
|
||||
initialize_accum_regs<1>();
|
||||
|
||||
if constexpr (GEMMINI_DMA) {
|
||||
// pipeline initiation
|
||||
|
||||
Reference in New Issue
Block a user