From deb6e5eba27703444d3e897442ad61cdb7c0f8a2 Mon Sep 17 00:00:00 2001 From: Hansung Kim Date: Thu, 6 Jun 2024 12:43:08 -0700 Subject: [PATCH] sgemm_tcore: Move bank-conflicts to SMEM stores from GMEM loads --- tests/regression/sgemm_tcore/kernel.cpp | 63 ++++++++++++++++++++++++- 1 file changed, 61 insertions(+), 2 deletions(-) diff --git a/tests/regression/sgemm_tcore/kernel.cpp b/tests/regression/sgemm_tcore/kernel.cpp index 899afd8a..fd2e73da 100644 --- a/tests/regression/sgemm_tcore/kernel.cpp +++ b/tests/regression/sgemm_tcore/kernel.cpp @@ -14,6 +14,10 @@ // scenario #define BK_LOOP 1 #define TRANSPOSE_AS 1 +// GMEM_COALESCED sets bank conflict-free accesses for +// 1: GMEM loads of A matrix +// 0: SMEM stores of A matrix +#define GMEM_COALESCED_A 1 #define DOUBLE_BUFFER 1 @@ -31,7 +35,7 @@ // BM <= BK*TM*TN #define BM 32 #define BN 32 -#define BK 8 +#define BK 32 #define WM 16 #define WN 8 #define TCM 8 @@ -326,6 +330,7 @@ inline void write_results(const int thread_in_warp, const int warp_col, inline void threadblock_barrier(const uint32_t barrier_id, const uint32_t count) { vx_fence(); vx_barrier(barrier_id, count); + // vx_barrier(0, count); } inline void @@ -373,6 +378,7 @@ global_dmem_load(const uint32_t dim_n, const uint32_t dim_k, const uint32_t k, local_a_tmp += BK * row_stride_a; } } else { +#if !GMEM_COALESCED_A constexpr uint32_t row_stride_as = threads_in_warpgroup / BM_d; const uint32_t global_a_row = BM_d * threadblock_id_y + local_as_col; const float *global_a = A + dim_k * global_a_row + (k + local_as_row); @@ -388,7 +394,6 @@ global_dmem_load(const uint32_t dim_n, const uint32_t dim_k, const uint32_t k, (BK % (row_stride_as * 8)) == 0, "manual loop unrolling condition not met; BK should be power-of-two"); - #pragma GCC unroll 2 for (uint32_t local_row_offset = 0; local_row_offset < BK; local_row_offset += row_stride_as * 8) { @@ -429,6 +434,59 @@ global_dmem_load(const uint32_t dim_n, const uint32_t dim_k, const uint32_t k, asm volatile ("fsw ft7, %0(%1)" :: "i"(BM * row_stride_as * 7 * sizeof(float)), "r"(local_a_tmp)); local_a_tmp += BM * row_stride_as * 8; } +#else + constexpr uint32_t row_stride_a = threads_in_warpgroup / BK; + const uint32_t global_a_row = BM_d * threadblock_id_y + local_a_row; + const float *global_a = A + dim_k * global_a_row + (k + local_a_col); + // NOTE that SMEM writes are transposed + volatile float *local_a_tmp = local_a + BM_d * local_a_col + local_a_row; + + static_assert( + row_stride_a * 8 <= BM_d, + "manual loop unrolling condition not met; consider increasing BM"); + static_assert( + (BM_d % (row_stride_a * 8)) == 0, + "manual loop unrolling condition not met; BM should be power-of-two"); + +#pragma GCC unroll 2 + for (uint32_t local_row_offset = 0; local_row_offset < BM_d; + local_row_offset += row_stride_a * 8) { + // const uint32_t global_a_offset = + // dim_k * (global_a_row + local_row_offset) + (k + local_a_col); + // NOTE that SMEM writes are transposed + // local_a[BM_d * (local_a_col) + local_a_row + local_row_offset] = + // A[global_a_offset]; + + // *local_a_tmp = *global_a; + asm volatile ("flw ft0, (%0)" :: "r"(global_a)); + global_a += dim_k * row_stride_a; + asm volatile ("flw ft1, (%0)" :: "r"(global_a)); + global_a += dim_k * row_stride_a; + asm volatile ("flw ft2, (%0)" :: "r"(global_a)); + global_a += dim_k * row_stride_a; + asm volatile ("flw ft3, (%0)" :: "r"(global_a)); + global_a += dim_k * row_stride_a; + asm volatile ("flw ft4, (%0)" :: "r"(global_a)); + global_a += dim_k * row_stride_a; + asm volatile ("flw ft5, (%0)" :: "r"(global_a)); + global_a += dim_k * row_stride_a; + asm volatile ("flw ft6, (%0)" :: "r"(global_a)); + global_a += dim_k * row_stride_a; + asm volatile ("flw ft7, (%0)" :: "r"(global_a)); + global_a += dim_k * row_stride_a; + + // stride along columns + asm volatile ("fsw ft0, %0(%1)" :: "i"(row_stride_a * 0 * sizeof(float)), "r"(local_a_tmp)); + asm volatile ("fsw ft1, %0(%1)" :: "i"(row_stride_a * 1 * sizeof(float)), "r"(local_a_tmp)); + asm volatile ("fsw ft2, %0(%1)" :: "i"(row_stride_a * 2 * sizeof(float)), "r"(local_a_tmp)); + asm volatile ("fsw ft3, %0(%1)" :: "i"(row_stride_a * 3 * sizeof(float)), "r"(local_a_tmp)); + asm volatile ("fsw ft4, %0(%1)" :: "i"(row_stride_a * 4 * sizeof(float)), "r"(local_a_tmp)); + asm volatile ("fsw ft5, %0(%1)" :: "i"(row_stride_a * 5 * sizeof(float)), "r"(local_a_tmp)); + asm volatile ("fsw ft6, %0(%1)" :: "i"(row_stride_a * 6 * sizeof(float)), "r"(local_a_tmp)); + asm volatile ("fsw ft7, %0(%1)" :: "i"(row_stride_a * 7 * sizeof(float)), "r"(local_a_tmp)); + local_a_tmp += row_stride_a * 8; + } +#endif } constexpr uint32_t row_stride_b = threads_in_warpgroup / BN_d; @@ -585,6 +643,7 @@ void thread_block_gemm(kernel_arg_t *__UNIFORM__ arg, threadblock_barrier(threadblock_id_in_cluster, threadblock_dim_y); } else { + // NOTE: this *should* be signed integer to trigger arithmetic right-shift int32_t k_index = 0; #pragma GCC unroll 1 for (uint32_t k = 0; k < dim_k; k += BK) {