From 635da96154ac883ce45a39edc32c0e47300f7800 Mon Sep 17 00:00:00 2001 From: Hansung Kim Date: Tue, 11 Jun 2024 22:49:59 -0700 Subject: [PATCH] sgemm_tcore: Constify smem pointer for wmma_load --- tests/regression/sgemm_tcore/kernel.warpspecial.cpp | 4 ++-- tests/regression/sgemm_tcore/util.hpp | 10 +++++----- 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/tests/regression/sgemm_tcore/kernel.warpspecial.cpp b/tests/regression/sgemm_tcore/kernel.warpspecial.cpp index 7da03a99..5cb8c7fb 100644 --- a/tests/regression/sgemm_tcore/kernel.warpspecial.cpp +++ b/tests/regression/sgemm_tcore/kernel.warpspecial.cpp @@ -381,8 +381,8 @@ inline void thread_block_gemm(kernel_arg_t *__UNIFORM__ arg, int32_t k_index = 0; #pragma GCC unroll 1 for (uint32_t k = 0; k < (dim_k); k += BK) { - volatile float *local_a_consume; - volatile float *local_b_consume; + const volatile float *local_a_consume; + const volatile float *local_b_consume; if constexpr (DOUBLE_BUFFER) { // local_a_consume = (k_index % 2) ? local_a_buf : local_a; // local_b_consume = (k_index % 2) ? local_b_buf : local_b; diff --git a/tests/regression/sgemm_tcore/util.hpp b/tests/regression/sgemm_tcore/util.hpp index b4634b2b..5d54cd4b 100644 --- a/tests/regression/sgemm_tcore/util.hpp +++ b/tests/regression/sgemm_tcore/util.hpp @@ -147,7 +147,7 @@ inline void vx_wmma(const int dest_reg) { } // `local_k` is assumed to be multiple of TCK -inline void vx_wmma_load_a(volatile float *smem_A, const int local_k, +inline void vx_wmma_load_a(volatile const float *smem_A, const int local_k, const int warp_row, const int wm_iter, const int thread_in_warp) { const int tid = thread_in_warp; const int tg = tid / 4; @@ -167,7 +167,7 @@ inline void vx_wmma_load_a(volatile float *smem_A, const int local_k, // @perf: bank conflicts // f8-f15 stores a single row of A - volatile float *smem_addr; + const volatile float *smem_addr; smem_addr = &smem_A[(WM * warp_row + TCM * wm_iter + row) * smem_A_cols + local_k]; asm volatile("flw f0, %0(%1)" ::"i"(0 * sizeof(float)), "r"(smem_addr)); asm volatile("flw f1, %0(%1)" ::"i"(1 * sizeof(float)), "r"(smem_addr)); @@ -188,7 +188,7 @@ inline void vx_wmma_load_a(volatile float *smem_A, const int local_k, } else { // read smem A tile as-is; bank-conflict-free AS load // f8-f15 stores a single row of A - volatile float *smem_addr; + const volatile float *smem_addr; smem_addr = &smem_A[((local_k + 0) * smem_AS_cols) + (WM * warp_row + TCM * wm_iter) + row]; asm volatile("flw f0, %0(%1)" :: "i"(smem_AS_cols * 0 * sizeof(float)), "r"(smem_addr)); asm volatile("flw f1, %0(%1)" :: "i"(smem_AS_cols * 1 * sizeof(float)), "r"(smem_addr)); @@ -211,7 +211,7 @@ inline void vx_wmma_load_a(volatile float *smem_A, const int local_k, } // `local_k` is assumed to be multiple of TCK -inline void vx_wmma_load_b(volatile float *smem_B, const int local_k, +inline void vx_wmma_load_b(const volatile float *smem_B, const int local_k, const int warp_col, const int wn_iter, const int thread_in_warp) { const int tid = thread_in_warp; @@ -225,7 +225,7 @@ inline void vx_wmma_load_b(volatile float *smem_B, const int local_k, constexpr int smem_B_cols = BN; // f8-f15 stores a single column of B - volatile float *smem_addr; + const volatile float *smem_addr; smem_addr = &smem_B[((local_k + 0) * smem_B_cols) + (WN * warp_col + TCN * wn_iter) + col]; asm volatile("flw f8, %0(%1)" :: "i"(smem_B_cols * 0 * sizeof(float)), "r"(smem_addr)); asm volatile("flw f9, %0(%1)" :: "i"(smem_B_cols * 1 * sizeof(float)), "r"(smem_addr));