sgemm_tcore: Improve write_result addr gen

This commit is contained in:
Hansung Kim
2024-06-10 19:34:00 -07:00
parent a22762db94
commit 3b2f5a31de

View File

@@ -292,20 +292,16 @@ inline void write_results(const int thread_in_warp, const int warp_col,
// @perf: this likely causes a lot of gmem bank conflicts
if (wm_iter == 0) {
volatile float *gmem_addr;
volatile float *gmem_addr_tmp;
gmem_addr = &global_offset_C[dim_n * (local_row + 0) + (local_col + 0)];
asm volatile ("fsw f16, %0" :: "m"(*(gmem_addr + 0)));
asm volatile ("fsw f17, %0" :: "m"(*(gmem_addr + 1)));
gmem_addr_tmp = gmem_addr + (2 * dim_n);
asm volatile ("fsw f18, %0" :: "m"(*(gmem_addr_tmp + 0)));
asm volatile ("fsw f19, %0" :: "m"(*(gmem_addr_tmp + 1)));
gmem_addr += 4;
asm volatile ("fsw f20, %0" :: "m"(*(gmem_addr + 0)));
asm volatile ("fsw f21, %0" :: "m"(*(gmem_addr + 1)));
gmem_addr_tmp = gmem_addr + (2 * dim_n);
asm volatile ("fsw f22, %0" :: "m"(*(gmem_addr_tmp + 0)));
asm volatile ("fsw f23, %0" :: "m"(*(gmem_addr_tmp + 1)));
volatile float *gmem_addr = &global_offset_C[dim_n * (local_row + 0) + (local_col + 0)];
volatile float *gmem_addr_tmp = gmem_addr + (2 * dim_n);
asm volatile ("fsw f16, %0(%1)" :: "i"(0 * sizeof(float)), "r"(gmem_addr));
asm volatile ("fsw f17, %0(%1)" :: "i"(1 * sizeof(float)), "r"(gmem_addr));
asm volatile ("fsw f18, %0(%1)" :: "i"(0 * sizeof(float)), "r"(gmem_addr_tmp));
asm volatile ("fsw f19, %0(%1)" :: "i"(1 * sizeof(float)), "r"(gmem_addr_tmp));
asm volatile ("fsw f20, %0(%1)" :: "i"(4 * sizeof(float)), "r"(gmem_addr));
asm volatile ("fsw f21, %0(%1)" :: "i"(5 * sizeof(float)), "r"(gmem_addr));
asm volatile ("fsw f22, %0(%1)" :: "i"(4 * sizeof(float)), "r"(gmem_addr_tmp));
asm volatile ("fsw f23, %0(%1)" :: "i"(5 * sizeof(float)), "r"(gmem_addr_tmp));
// asm volatile ("fsw f16, %0" :: "m"(global_offset_C[dim_n * (local_row + 0) + (local_col + 0)]));
// asm volatile ("fsw f17, %0" :: "m"(global_offset_C[dim_n * (local_row + 0) + (local_col + 1)]));
// asm volatile ("fsw f18, %0" :: "m"(global_offset_C[dim_n * (local_row + 2) + (local_col + 0)]));
@@ -315,20 +311,16 @@ inline void write_results(const int thread_in_warp, const int warp_col,
// asm volatile ("fsw f22, %0" :: "m"(global_offset_C[dim_n * (local_row + 2) + (local_col + 4)]));
// asm volatile ("fsw f23, %0" :: "m"(global_offset_C[dim_n * (local_row + 2) + (local_col + 5)]));
} else {
volatile float *gmem_addr;
volatile float *gmem_addr_tmp;
gmem_addr = &global_offset_C[dim_n * (local_row + 0) + (local_col + 0)];
gmem_addr_tmp = gmem_addr + (2 * dim_n);
asm volatile ("fsw f24, %0" :: "m"(*(gmem_addr + 0)));
asm volatile ("fsw f25, %0" :: "m"(*(gmem_addr + 1)));
asm volatile ("fsw f26, %0" :: "m"(*(gmem_addr_tmp + 0)));
asm volatile ("fsw f27, %0" :: "m"(*(gmem_addr_tmp + 1)));
gmem_addr += 4;
gmem_addr_tmp = gmem_addr + (2 * dim_n);
asm volatile ("fsw f28, %0" :: "m"(*(gmem_addr + 0)));
asm volatile ("fsw f29, %0" :: "m"(*(gmem_addr + 1)));
asm volatile ("fsw f30, %0" :: "m"(*(gmem_addr_tmp + 0)));
asm volatile ("fsw f31, %0" :: "m"(*(gmem_addr_tmp + 1)));
volatile float *gmem_addr = &global_offset_C[dim_n * (local_row + 0) + (local_col + 0)];
volatile float *gmem_addr_tmp = gmem_addr + (2 * dim_n);
asm volatile ("fsw f24, %0(%1)" :: "i"(0 * sizeof(float)), "r"(gmem_addr));
asm volatile ("fsw f25, %0(%1)" :: "i"(1 * sizeof(float)), "r"(gmem_addr));
asm volatile ("fsw f26, %0(%1)" :: "i"(0 * sizeof(float)), "r"(gmem_addr_tmp));
asm volatile ("fsw f27, %0(%1)" :: "i"(1 * sizeof(float)), "r"(gmem_addr_tmp));
asm volatile ("fsw f28, %0(%1)" :: "i"(4 * sizeof(float)), "r"(gmem_addr));
asm volatile ("fsw f29, %0(%1)" :: "i"(5 * sizeof(float)), "r"(gmem_addr));
asm volatile ("fsw f30, %0(%1)" :: "i"(4 * sizeof(float)), "r"(gmem_addr_tmp));
asm volatile ("fsw f31, %0(%1)" :: "i"(5 * sizeof(float)), "r"(gmem_addr_tmp));
}
}