sgemm_tcore: Improve address generation (agen) for !transpose_as smem load

This commit is contained in:
Hansung Kim
2024-06-10 22:08:37 -07:00
parent dc7bd6b248
commit e3c4a4d2f5

View File

@@ -173,18 +173,28 @@ inline void vx_wmma_load_a(volatile float *smem_A, const int local_k,
constexpr int smem_AS_cols = BM;
if constexpr (!TRANSPOSE_AS) {
int A_offset = (WM * warp_row + TCM * wm_iter + row) * smem_A_cols;
// int A_offset = (WM * warp_row + TCM * wm_iter + row) * smem_A_cols;
// @perf: bank conflicts
// f0-f7 store a single row of A
asm volatile("flw f0, %0" ::"m"(smem_A[A_offset + (local_k + 0)]));
asm volatile("flw f1, %0" ::"m"(smem_A[A_offset + (local_k + 1)]));
asm volatile("flw f2, %0" ::"m"(smem_A[A_offset + (local_k + 2)]));
asm volatile("flw f3, %0" ::"m"(smem_A[A_offset + (local_k + 3)]));
asm volatile("flw f4, %0" ::"m"(smem_A[A_offset + (local_k + 4)]));
asm volatile("flw f5, %0" ::"m"(smem_A[A_offset + (local_k + 5)]));
asm volatile("flw f6, %0" ::"m"(smem_A[A_offset + (local_k + 6)]));
asm volatile("flw f7, %0" ::"m"(smem_A[A_offset + (local_k + 7)]));
volatile float *smem_addr;
smem_addr = &smem_A[(WM * warp_row + TCM * wm_iter + row) * smem_A_cols + local_k];
asm volatile("flw f0, %0(%1)" ::"i"(0 * sizeof(float)), "r"(smem_addr));
asm volatile("flw f1, %0(%1)" ::"i"(1 * sizeof(float)), "r"(smem_addr));
asm volatile("flw f2, %0(%1)" ::"i"(2 * sizeof(float)), "r"(smem_addr));
asm volatile("flw f3, %0(%1)" ::"i"(3 * sizeof(float)), "r"(smem_addr));
asm volatile("flw f4, %0(%1)" ::"i"(4 * sizeof(float)), "r"(smem_addr));
asm volatile("flw f5, %0(%1)" ::"i"(5 * sizeof(float)), "r"(smem_addr));
asm volatile("flw f6, %0(%1)" ::"i"(6 * sizeof(float)), "r"(smem_addr));
asm volatile("flw f7, %0(%1)" ::"i"(7 * sizeof(float)), "r"(smem_addr));
// asm volatile("flw f0, %0" ::"m"(smem_A[A_offset + (local_k + 0)]));
// asm volatile("flw f1, %0" ::"m"(smem_A[A_offset + (local_k + 1)]));
// asm volatile("flw f2, %0" ::"m"(smem_A[A_offset + (local_k + 2)]));
// asm volatile("flw f3, %0" ::"m"(smem_A[A_offset + (local_k + 3)]));
// asm volatile("flw f4, %0" ::"m"(smem_A[A_offset + (local_k + 4)]));
// asm volatile("flw f5, %0" ::"m"(smem_A[A_offset + (local_k + 5)]));
// asm volatile("flw f6, %0" ::"m"(smem_A[A_offset + (local_k + 6)]));
// asm volatile("flw f7, %0" ::"m"(smem_A[A_offset + (local_k + 7)]));
} else {
// transposed A
// f8-f15 stores a single row of A
@@ -610,7 +620,10 @@ inline void thread_block_gemm(kernel_arg_t *__UNIFORM__ arg,
k_index++;
// producer code: GMEM->SMEM memory movement
// ----------------------------------------------------------------------
// ---------------------------------------------------------------------
//
// this is either done using DMA or SIMT cores depending on GEMMINI_DMA
#if (GEMMINI_DMA == 1)
if (tid_in_threadblock == 0) {
// configure dma gmem address to load from