sgemm_tcore: Constify smem pointer for wmma_load
This commit is contained in:
@@ -381,8 +381,8 @@ inline void thread_block_gemm(kernel_arg_t *__UNIFORM__ arg,
|
||||
int32_t k_index = 0;
|
||||
#pragma GCC unroll 1
|
||||
for (uint32_t k = 0; k < (dim_k); k += BK) {
|
||||
volatile float *local_a_consume;
|
||||
volatile float *local_b_consume;
|
||||
const volatile float *local_a_consume;
|
||||
const volatile float *local_b_consume;
|
||||
if constexpr (DOUBLE_BUFFER) {
|
||||
// local_a_consume = (k_index % 2) ? local_a_buf : local_a;
|
||||
// local_b_consume = (k_index % 2) ? local_b_buf : local_b;
|
||||
|
||||
@@ -147,7 +147,7 @@ inline void vx_wmma(const int dest_reg) {
|
||||
}
|
||||
|
||||
// `local_k` is assumed to be multiple of TCK
|
||||
inline void vx_wmma_load_a(volatile float *smem_A, const int local_k,
|
||||
inline void vx_wmma_load_a(volatile const float *smem_A, const int local_k,
|
||||
const int warp_row, const int wm_iter, const int thread_in_warp) {
|
||||
const int tid = thread_in_warp;
|
||||
const int tg = tid / 4;
|
||||
@@ -167,7 +167,7 @@ inline void vx_wmma_load_a(volatile float *smem_A, const int local_k,
|
||||
|
||||
// @perf: bank conflicts
|
||||
// f8-f15 stores a single row of A
|
||||
volatile float *smem_addr;
|
||||
const volatile float *smem_addr;
|
||||
smem_addr = &smem_A[(WM * warp_row + TCM * wm_iter + row) * smem_A_cols + local_k];
|
||||
asm volatile("flw f0, %0(%1)" ::"i"(0 * sizeof(float)), "r"(smem_addr));
|
||||
asm volatile("flw f1, %0(%1)" ::"i"(1 * sizeof(float)), "r"(smem_addr));
|
||||
@@ -188,7 +188,7 @@ inline void vx_wmma_load_a(volatile float *smem_A, const int local_k,
|
||||
} else {
|
||||
// read smem A tile as-is; bank-conflict-free AS load
|
||||
// f8-f15 stores a single row of A
|
||||
volatile float *smem_addr;
|
||||
const volatile float *smem_addr;
|
||||
smem_addr = &smem_A[((local_k + 0) * smem_AS_cols) + (WM * warp_row + TCM * wm_iter) + row];
|
||||
asm volatile("flw f0, %0(%1)" :: "i"(smem_AS_cols * 0 * sizeof(float)), "r"(smem_addr));
|
||||
asm volatile("flw f1, %0(%1)" :: "i"(smem_AS_cols * 1 * sizeof(float)), "r"(smem_addr));
|
||||
@@ -211,7 +211,7 @@ inline void vx_wmma_load_a(volatile float *smem_A, const int local_k,
|
||||
}
|
||||
|
||||
// `local_k` is assumed to be multiple of TCK
|
||||
inline void vx_wmma_load_b(volatile float *smem_B, const int local_k,
|
||||
inline void vx_wmma_load_b(const volatile float *smem_B, const int local_k,
|
||||
const int warp_col, const int wn_iter,
|
||||
const int thread_in_warp) {
|
||||
const int tid = thread_in_warp;
|
||||
@@ -225,7 +225,7 @@ inline void vx_wmma_load_b(volatile float *smem_B, const int local_k,
|
||||
constexpr int smem_B_cols = BN;
|
||||
|
||||
// f8-f15 stores a single column of B
|
||||
volatile float *smem_addr;
|
||||
const volatile float *smem_addr;
|
||||
smem_addr = &smem_B[((local_k + 0) * smem_B_cols) + (WN * warp_col + TCN * wn_iter) + col];
|
||||
asm volatile("flw f8, %0(%1)" :: "i"(smem_B_cols * 0 * sizeof(float)), "r"(smem_addr));
|
||||
asm volatile("flw f9, %0(%1)" :: "i"(smem_B_cols * 1 * sizeof(float)), "r"(smem_addr));
|
||||
|
||||
Reference in New Issue
Block a user