From 3b2f5a31de064a5fc7d7f2a08c707d34629c6535 Mon Sep 17 00:00:00 2001 From: Hansung Kim Date: Mon, 10 Jun 2024 19:34:00 -0700 Subject: [PATCH] sgemm_tcore: Improve write_result addr gen --- tests/regression/sgemm_tcore/kernel.cpp | 48 +++++++++++-------------- 1 file changed, 20 insertions(+), 28 deletions(-) diff --git a/tests/regression/sgemm_tcore/kernel.cpp b/tests/regression/sgemm_tcore/kernel.cpp index a6693605..cb10c14d 100644 --- a/tests/regression/sgemm_tcore/kernel.cpp +++ b/tests/regression/sgemm_tcore/kernel.cpp @@ -292,20 +292,16 @@ inline void write_results(const int thread_in_warp, const int warp_col, // @perf: this likely causes a lot of gmem bank conflicts if (wm_iter == 0) { - volatile float *gmem_addr; - volatile float *gmem_addr_tmp; - gmem_addr = &global_offset_C[dim_n * (local_row + 0) + (local_col + 0)]; - asm volatile ("fsw f16, %0" :: "m"(*(gmem_addr + 0))); - asm volatile ("fsw f17, %0" :: "m"(*(gmem_addr + 1))); - gmem_addr_tmp = gmem_addr + (2 * dim_n); - asm volatile ("fsw f18, %0" :: "m"(*(gmem_addr_tmp + 0))); - asm volatile ("fsw f19, %0" :: "m"(*(gmem_addr_tmp + 1))); - gmem_addr += 4; - asm volatile ("fsw f20, %0" :: "m"(*(gmem_addr + 0))); - asm volatile ("fsw f21, %0" :: "m"(*(gmem_addr + 1))); - gmem_addr_tmp = gmem_addr + (2 * dim_n); - asm volatile ("fsw f22, %0" :: "m"(*(gmem_addr_tmp + 0))); - asm volatile ("fsw f23, %0" :: "m"(*(gmem_addr_tmp + 1))); + volatile float *gmem_addr = &global_offset_C[dim_n * (local_row + 0) + (local_col + 0)]; + volatile float *gmem_addr_tmp = gmem_addr + (2 * dim_n); + asm volatile ("fsw f16, %0(%1)" :: "i"(0 * sizeof(float)), "r"(gmem_addr)); + asm volatile ("fsw f17, %0(%1)" :: "i"(1 * sizeof(float)), "r"(gmem_addr)); + asm volatile ("fsw f18, %0(%1)" :: "i"(0 * sizeof(float)), "r"(gmem_addr_tmp)); + asm volatile ("fsw f19, %0(%1)" :: "i"(1 * sizeof(float)), "r"(gmem_addr_tmp)); + asm volatile ("fsw f20, %0(%1)" :: "i"(4 * sizeof(float)), "r"(gmem_addr)); + asm volatile ("fsw f21, %0(%1)" :: "i"(5 * sizeof(float)), "r"(gmem_addr)); + asm volatile ("fsw f22, %0(%1)" :: "i"(4 * sizeof(float)), "r"(gmem_addr_tmp)); + asm volatile ("fsw f23, %0(%1)" :: "i"(5 * sizeof(float)), "r"(gmem_addr_tmp)); // asm volatile ("fsw f16, %0" :: "m"(global_offset_C[dim_n * (local_row + 0) + (local_col + 0)])); // asm volatile ("fsw f17, %0" :: "m"(global_offset_C[dim_n * (local_row + 0) + (local_col + 1)])); // asm volatile ("fsw f18, %0" :: "m"(global_offset_C[dim_n * (local_row + 2) + (local_col + 0)])); @@ -315,20 +311,16 @@ inline void write_results(const int thread_in_warp, const int warp_col, // asm volatile ("fsw f22, %0" :: "m"(global_offset_C[dim_n * (local_row + 2) + (local_col + 4)])); // asm volatile ("fsw f23, %0" :: "m"(global_offset_C[dim_n * (local_row + 2) + (local_col + 5)])); } else { - volatile float *gmem_addr; - volatile float *gmem_addr_tmp; - gmem_addr = &global_offset_C[dim_n * (local_row + 0) + (local_col + 0)]; - gmem_addr_tmp = gmem_addr + (2 * dim_n); - asm volatile ("fsw f24, %0" :: "m"(*(gmem_addr + 0))); - asm volatile ("fsw f25, %0" :: "m"(*(gmem_addr + 1))); - asm volatile ("fsw f26, %0" :: "m"(*(gmem_addr_tmp + 0))); - asm volatile ("fsw f27, %0" :: "m"(*(gmem_addr_tmp + 1))); - gmem_addr += 4; - gmem_addr_tmp = gmem_addr + (2 * dim_n); - asm volatile ("fsw f28, %0" :: "m"(*(gmem_addr + 0))); - asm volatile ("fsw f29, %0" :: "m"(*(gmem_addr + 1))); - asm volatile ("fsw f30, %0" :: "m"(*(gmem_addr_tmp + 0))); - asm volatile ("fsw f31, %0" :: "m"(*(gmem_addr_tmp + 1))); + volatile float *gmem_addr = &global_offset_C[dim_n * (local_row + 0) + (local_col + 0)]; + volatile float *gmem_addr_tmp = gmem_addr + (2 * dim_n); + asm volatile ("fsw f24, %0(%1)" :: "i"(0 * sizeof(float)), "r"(gmem_addr)); + asm volatile ("fsw f25, %0(%1)" :: "i"(1 * sizeof(float)), "r"(gmem_addr)); + asm volatile ("fsw f26, %0(%1)" :: "i"(0 * sizeof(float)), "r"(gmem_addr_tmp)); + asm volatile ("fsw f27, %0(%1)" :: "i"(1 * sizeof(float)), "r"(gmem_addr_tmp)); + asm volatile ("fsw f28, %0(%1)" :: "i"(4 * sizeof(float)), "r"(gmem_addr)); + asm volatile ("fsw f29, %0(%1)" :: "i"(5 * sizeof(float)), "r"(gmem_addr)); + asm volatile ("fsw f30, %0(%1)" :: "i"(4 * sizeof(float)), "r"(gmem_addr_tmp)); + asm volatile ("fsw f31, %0(%1)" :: "i"(5 * sizeof(float)), "r"(gmem_addr_tmp)); } }