sgemm_tcore: Move bank-conflicts to SMEM stores from GMEM loads
This commit is contained in:
@@ -14,6 +14,10 @@
|
||||
// scenario
|
||||
#define BK_LOOP 1
|
||||
#define TRANSPOSE_AS 1
|
||||
// GMEM_COALESCED sets bank conflict-free accesses for
|
||||
// 1: GMEM loads of A matrix
|
||||
// 0: SMEM stores of A matrix
|
||||
#define GMEM_COALESCED_A 1
|
||||
|
||||
#define DOUBLE_BUFFER 1
|
||||
|
||||
@@ -31,7 +35,7 @@
|
||||
// BM <= BK*TM*TN
|
||||
#define BM 32
|
||||
#define BN 32
|
||||
#define BK 8
|
||||
#define BK 32
|
||||
#define WM 16
|
||||
#define WN 8
|
||||
#define TCM 8
|
||||
@@ -326,6 +330,7 @@ inline void write_results(const int thread_in_warp, const int warp_col,
|
||||
inline void threadblock_barrier(const uint32_t barrier_id, const uint32_t count) {
|
||||
vx_fence();
|
||||
vx_barrier(barrier_id, count);
|
||||
// vx_barrier(0, count);
|
||||
}
|
||||
|
||||
inline void
|
||||
@@ -373,6 +378,7 @@ global_dmem_load(const uint32_t dim_n, const uint32_t dim_k, const uint32_t k,
|
||||
local_a_tmp += BK * row_stride_a;
|
||||
}
|
||||
} else {
|
||||
#if !GMEM_COALESCED_A
|
||||
constexpr uint32_t row_stride_as = threads_in_warpgroup / BM_d;
|
||||
const uint32_t global_a_row = BM_d * threadblock_id_y + local_as_col;
|
||||
const float *global_a = A + dim_k * global_a_row + (k + local_as_row);
|
||||
@@ -388,7 +394,6 @@ global_dmem_load(const uint32_t dim_n, const uint32_t dim_k, const uint32_t k,
|
||||
(BK % (row_stride_as * 8)) == 0,
|
||||
"manual loop unrolling condition not met; BK should be power-of-two");
|
||||
|
||||
|
||||
#pragma GCC unroll 2
|
||||
for (uint32_t local_row_offset = 0; local_row_offset < BK;
|
||||
local_row_offset += row_stride_as * 8) {
|
||||
@@ -429,6 +434,59 @@ global_dmem_load(const uint32_t dim_n, const uint32_t dim_k, const uint32_t k,
|
||||
asm volatile ("fsw ft7, %0(%1)" :: "i"(BM * row_stride_as * 7 * sizeof(float)), "r"(local_a_tmp));
|
||||
local_a_tmp += BM * row_stride_as * 8;
|
||||
}
|
||||
#else
|
||||
constexpr uint32_t row_stride_a = threads_in_warpgroup / BK;
|
||||
const uint32_t global_a_row = BM_d * threadblock_id_y + local_a_row;
|
||||
const float *global_a = A + dim_k * global_a_row + (k + local_a_col);
|
||||
// NOTE that SMEM writes are transposed
|
||||
volatile float *local_a_tmp = local_a + BM_d * local_a_col + local_a_row;
|
||||
|
||||
static_assert(
|
||||
row_stride_a * 8 <= BM_d,
|
||||
"manual loop unrolling condition not met; consider increasing BM");
|
||||
static_assert(
|
||||
(BM_d % (row_stride_a * 8)) == 0,
|
||||
"manual loop unrolling condition not met; BM should be power-of-two");
|
||||
|
||||
#pragma GCC unroll 2
|
||||
for (uint32_t local_row_offset = 0; local_row_offset < BM_d;
|
||||
local_row_offset += row_stride_a * 8) {
|
||||
// const uint32_t global_a_offset =
|
||||
// dim_k * (global_a_row + local_row_offset) + (k + local_a_col);
|
||||
// NOTE that SMEM writes are transposed
|
||||
// local_a[BM_d * (local_a_col) + local_a_row + local_row_offset] =
|
||||
// A[global_a_offset];
|
||||
|
||||
// *local_a_tmp = *global_a;
|
||||
asm volatile ("flw ft0, (%0)" :: "r"(global_a));
|
||||
global_a += dim_k * row_stride_a;
|
||||
asm volatile ("flw ft1, (%0)" :: "r"(global_a));
|
||||
global_a += dim_k * row_stride_a;
|
||||
asm volatile ("flw ft2, (%0)" :: "r"(global_a));
|
||||
global_a += dim_k * row_stride_a;
|
||||
asm volatile ("flw ft3, (%0)" :: "r"(global_a));
|
||||
global_a += dim_k * row_stride_a;
|
||||
asm volatile ("flw ft4, (%0)" :: "r"(global_a));
|
||||
global_a += dim_k * row_stride_a;
|
||||
asm volatile ("flw ft5, (%0)" :: "r"(global_a));
|
||||
global_a += dim_k * row_stride_a;
|
||||
asm volatile ("flw ft6, (%0)" :: "r"(global_a));
|
||||
global_a += dim_k * row_stride_a;
|
||||
asm volatile ("flw ft7, (%0)" :: "r"(global_a));
|
||||
global_a += dim_k * row_stride_a;
|
||||
|
||||
// stride along columns
|
||||
asm volatile ("fsw ft0, %0(%1)" :: "i"(row_stride_a * 0 * sizeof(float)), "r"(local_a_tmp));
|
||||
asm volatile ("fsw ft1, %0(%1)" :: "i"(row_stride_a * 1 * sizeof(float)), "r"(local_a_tmp));
|
||||
asm volatile ("fsw ft2, %0(%1)" :: "i"(row_stride_a * 2 * sizeof(float)), "r"(local_a_tmp));
|
||||
asm volatile ("fsw ft3, %0(%1)" :: "i"(row_stride_a * 3 * sizeof(float)), "r"(local_a_tmp));
|
||||
asm volatile ("fsw ft4, %0(%1)" :: "i"(row_stride_a * 4 * sizeof(float)), "r"(local_a_tmp));
|
||||
asm volatile ("fsw ft5, %0(%1)" :: "i"(row_stride_a * 5 * sizeof(float)), "r"(local_a_tmp));
|
||||
asm volatile ("fsw ft6, %0(%1)" :: "i"(row_stride_a * 6 * sizeof(float)), "r"(local_a_tmp));
|
||||
asm volatile ("fsw ft7, %0(%1)" :: "i"(row_stride_a * 7 * sizeof(float)), "r"(local_a_tmp));
|
||||
local_a_tmp += row_stride_a * 8;
|
||||
}
|
||||
#endif
|
||||
}
|
||||
|
||||
constexpr uint32_t row_stride_b = threads_in_warpgroup / BN_d;
|
||||
@@ -585,6 +643,7 @@ void thread_block_gemm(kernel_arg_t *__UNIFORM__ arg,
|
||||
|
||||
threadblock_barrier(threadblock_id_in_cluster, threadblock_dim_y);
|
||||
} else {
|
||||
// NOTE: this *should* be signed integer to trigger arithmetic right-shift
|
||||
int32_t k_index = 0;
|
||||
#pragma GCC unroll 1
|
||||
for (uint32_t k = 0; k < dim_k; k += BK) {
|
||||
|
||||
Reference in New Issue
Block a user