sgemm_tcore: Fix GMEM->SMEM address generation for M-major A

This fixes correctness for the TRANSPOSE_AT_PRODUCE=0 / TRANSPOSE_AT_CONSUME=0
configuration, provided the matrices are already stored in the correct layout
in GMEM.
This commit is contained in:
Hansung Kim
2024-08-14 15:28:52 -07:00
parent 409424b032
commit 0534e5d1f6
2 changed files with 45 additions and 36 deletions

View File

@@ -40,13 +40,16 @@
// using float_type = float; // using float_type = float;
using float_type = float16_t; using float_type = float16_t;
// TODO: reduce args by passing leading A/B dimensions
template <typename T> template <typename T>
inline void global_dmem_load(const uint32_t dim_n, const uint32_t dim_k, inline void global_dmem_load(const uint32_t dim_m, const uint32_t dim_n, const uint32_t dim_k,
const uint32_t k, const T *A, const T *B, const uint32_t k, const T *A, const T *B,
volatile T *local_a, volatile T *local_b, volatile T *local_a, volatile T *local_b,
const uint32_t tid_in_threadblock, const uint32_t tid_in_threadblock,
const uint32_t threadblock_id_x, const uint32_t threadblock_id_x,
const uint32_t threadblock_id_y) { const uint32_t threadblock_id_y) {
asm volatile ("global_dmem_load_start_%=:" :: );
// In fp16 mode, bit-pack two fp16 elements into each fp32 element, and do // In fp16 mode, bit-pack two fp16 elements into each fp32 element, and do
// data movement at the fp32 granularity. Assuming that the matrix is stored // data movement at the fp32 granularity. Assuming that the matrix is stored
// row-major in GMEM, the packed fp16 pairs belong to the same row, // row-major in GMEM, the packed fp16 pairs belong to the same row,
@@ -79,28 +82,28 @@ inline void global_dmem_load(const uint32_t dim_n, const uint32_t dim_k,
// move A // move A
if constexpr (!TRANSPOSE_AT_PRODUCE) { if constexpr (!TRANSPOSE_AT_PRODUCE) {
// No transpose at GMEM->SMEM movement // A is stored M-major in GMEM;
// FIXME: !TRANSPOSE_AS code is old // no transpose at GMEM->SMEM movement
const uint32_t block_m = threadblock_id_y;
const uint32_t global_a_row = BM * threadblock_id_y + local_a_row; const uint32_t global_a_row = k_adjusted + local_as_row;
const uint32_t global_a_col = BM * block_m + local_as_col;
// number of rows a full TB can read at a time // number of rows a full TB can read at a time
// this is equivalent to threadblock_dim_y (assuming threadblock_dim_x == // this is equivalent to threadblock_dim_y (assuming threadblock_dim_x ==
// BK) // BK)
constexpr uint32_t row_stride_a = threads_in_threadblock / BK_adjusted; constexpr uint32_t row_stride_as = threads_in_threadblock / BM;
const float *global_a = reinterpret_cast<const float *>(A) + const float *global_a = reinterpret_cast<const float *>(A) +
dim_k_adjusted * global_a_row + dim_m * global_a_row + global_a_col;
(k_adjusted + local_a_col);
volatile float *local_a_tmp = reinterpret_cast<volatile float *>(local_a) + volatile float *local_a_tmp = reinterpret_cast<volatile float *>(local_a) +
BK_adjusted * local_a_row + local_a_col; BM * local_as_row + local_as_col;
#pragma GCC unroll 1 #pragma GCC unroll 1
for (uint32_t local_row_offset = 0; local_row_offset < BM; for (uint32_t local_row_offset = 0; local_row_offset < BK_adjusted;
local_row_offset += row_stride_a) { local_row_offset += row_stride_as) {
// TODO: the code GCC generates for below seems fine atm, but unroll to
// assembly to be absolutely sure
*local_a_tmp = *global_a; *local_a_tmp = *global_a;
global_a += dim_m * row_stride_as;
// move to the next "row-chunk", when threadblock is smaller than BM*BK local_a_tmp += BM * row_stride_as;
global_a += dim_k_adjusted * row_stride_a;
local_a_tmp += BK_adjusted * row_stride_a;
} }
} else { } else {
if constexpr (!GMEM_COALESCED_A) { if constexpr (!GMEM_COALESCED_A) {
@@ -126,9 +129,6 @@ inline void global_dmem_load(const uint32_t dim_n, const uint32_t dim_k,
// @perf: bank conflicts here // @perf: bank conflicts here
// const uint32_t global_a_offset = // const uint32_t global_a_offset =
// dim_k_adjusted * (global_a_row) + (k + local_as_row + local_row_offset); // dim_k_adjusted * (global_a_row) + (k + local_as_row + local_row_offset);
// FIXME experimenting with global coalescing
// const uint32_t global_a_offset =
// dim_k_adjusted * (global_a_row + local_row_offset) + (k + local_as_col);
// local_a[BM * (local_as_row + local_row_offset) + local_as_col] = // local_a[BM * (local_as_row + local_row_offset) + local_as_col] =
// A[global_a_offset]; // A[global_a_offset];
@@ -278,6 +278,8 @@ inline void global_dmem_load(const uint32_t dim_n, const uint32_t dim_k,
asm volatile ("fsw ft7, %0(%1)" :: "i"(BN_adjusted * row_stride_b * 1 * sizeof(float)), "r"(local_b_tmp)); asm volatile ("fsw ft7, %0(%1)" :: "i"(BN_adjusted * row_stride_b * 1 * sizeof(float)), "r"(local_b_tmp));
local_b_tmp += BN_adjusted * row_stride_b * 2; local_b_tmp += BN_adjusted * row_stride_b * 2;
} }
asm volatile ("global_dmem_load_finish_%=:" :: );
} }
template <typename T> template <typename T>
@@ -464,8 +466,8 @@ inline void thread_block_gemm(kernel_arg_t *__UNIFORM__ arg,
#endif #endif
} }
#else #else
global_dmem_load<T>(dim_n, dim_k, block_k * BK, A, B, local_a, local_b, global_dmem_load<T>(dim_m, dim_n, dim_k, block_k * BK, A, B, local_a,
tid_in_threadblock, block_n, block_m); local_b, tid_in_threadblock, block_n, block_m);
threadblock_barrier(threadblock_id_in_cluster, threadblock_dim_y); threadblock_barrier(threadblock_id_in_cluster, threadblock_dim_y);
#endif #endif

View File

@@ -158,6 +158,8 @@ template <typename T>
inline void vx_wmma_load_a(volatile const T *smem_A, const int local_k, inline void vx_wmma_load_a(volatile const T *smem_A, const int local_k,
const int warp_row, const int wm_iter, const int warp_row, const int wm_iter,
const int thread_in_warp) { const int thread_in_warp) {
asm volatile ("vx_wmma_load_a_start_%=:" :: );
const int tid = thread_in_warp; const int tid = thread_in_warp;
const int tg = tid / 4; const int tg = tid / 4;
@@ -174,17 +176,13 @@ inline void vx_wmma_load_a(volatile const T *smem_A, const int local_k,
// by a factor of two. // by a factor of two.
constexpr int packed_factor = (std::is_same_v<T, float16_t> ? 2 : 1); constexpr int packed_factor = (std::is_same_v<T, float16_t> ? 2 : 1);
constexpr int BK_adjusted = BK / packed_factor; constexpr int BK_adjusted = BK / packed_factor;
constexpr int BM_adjusted = BM / packed_factor;
const int local_k_adjusted = local_k / packed_factor; const int local_k_adjusted = local_k / packed_factor;
constexpr int smem_A_rows = BM;
constexpr int smem_A_cols = BK_adjusted;
constexpr int smem_AS_rows = BK_adjusted;
constexpr int smem_AS_cols = BM;
// constexpr int smem_AS_rows = BK;
// constexpr int smem_AS_cols = BM_adjusted;
if constexpr (TRANSPOSE_AT_CONSUME) { if constexpr (TRANSPOSE_AT_CONSUME) {
// A is stored K-major in smem
constexpr int smem_A_rows = BM;
constexpr int smem_A_cols = BK_adjusted;
// int A_offset = (WM * warp_row + TCM * wm_iter + row) * smem_A_cols; // int A_offset = (WM * warp_row + TCM * wm_iter + row) * smem_A_cols;
// @perf: bank conflicts // @perf: bank conflicts
@@ -205,15 +203,16 @@ inline void vx_wmma_load_a(volatile const T *smem_A, const int local_k,
asm volatile("flw f6, %0(%1)" ::"i"(6 * sizeof(float)), "r"(smem_addr)); asm volatile("flw f6, %0(%1)" ::"i"(6 * sizeof(float)), "r"(smem_addr));
asm volatile("flw f7, %0(%1)" ::"i"(7 * sizeof(float)), "r"(smem_addr)); asm volatile("flw f7, %0(%1)" ::"i"(7 * sizeof(float)), "r"(smem_addr));
} else { } else {
// read smem A tile as-is; bank-conflict-free AS load // A is stored M-major in smem
// smem A tile is stored column-major constexpr int smem_AS_rows = BK_adjusted;
// f8-f15 stores a single row of A constexpr int smem_AS_cols = BM;
const volatile uint8_t *smem_addr; const volatile uint8_t *smem_addr;
smem_addr = reinterpret_cast<const volatile uint8_t *>( smem_addr = reinterpret_cast<const volatile uint8_t *>(
&reinterpret_cast<const volatile float *>( &reinterpret_cast<const volatile float *>(
smem_A)[((local_k_adjusted + 0) * smem_AS_cols) + smem_A)[((local_k_adjusted + 0) * smem_AS_cols) +
(WM * warp_row + TCM * wm_iter) + row]); (WM * warp_row + TCM * wm_iter) + row]);
// step to the next row // f8-f15 stores a single row of A
// threads read from different columns; no bank conflicts // threads read from different columns; no bank conflicts
asm volatile("flw f0, %0(%1)" :: "i"(smem_AS_cols * 0 * sizeof(float)), "r"(smem_addr)); asm volatile("flw f0, %0(%1)" :: "i"(smem_AS_cols * 0 * sizeof(float)), "r"(smem_addr));
asm volatile("flw f1, %0(%1)" :: "i"(smem_AS_cols * 1 * sizeof(float)), "r"(smem_addr)); asm volatile("flw f1, %0(%1)" :: "i"(smem_AS_cols * 1 * sizeof(float)), "r"(smem_addr));
@@ -224,6 +223,8 @@ inline void vx_wmma_load_a(volatile const T *smem_A, const int local_k,
asm volatile("flw f6, %0(%1)" :: "i"(smem_AS_cols * 6 * sizeof(float)), "r"(smem_addr)); asm volatile("flw f6, %0(%1)" :: "i"(smem_AS_cols * 6 * sizeof(float)), "r"(smem_addr));
asm volatile("flw f7, %0(%1)" :: "i"(smem_AS_cols * 7 * sizeof(float)), "r"(smem_addr)); asm volatile("flw f7, %0(%1)" :: "i"(smem_AS_cols * 7 * sizeof(float)), "r"(smem_addr));
} }
asm volatile ("vx_wmma_load_a_finish_%=:" :: );
} }
// `local_k` is assumed to be multiple of TCK // `local_k` is assumed to be multiple of TCK
@@ -231,6 +232,8 @@ template <typename T>
inline void vx_wmma_load_b(const volatile T *smem_B, const int local_k, inline void vx_wmma_load_b(const volatile T *smem_B, const int local_k,
const int warp_col, const int wn_iter, const int warp_col, const int wn_iter,
const int thread_in_warp) { const int thread_in_warp) {
asm volatile ("vx_wmma_load_b_start_%=:" :: );
const int tid = thread_in_warp; const int tid = thread_in_warp;
const int tg = tid / 4; const int tg = tid / 4;
@@ -244,18 +247,16 @@ inline void vx_wmma_load_b(const volatile T *smem_B, const int local_k,
constexpr int BN_adjusted = BN / packed_factor; constexpr int BN_adjusted = BN / packed_factor;
const int local_k_adjusted = local_k / packed_factor; const int local_k_adjusted = local_k / packed_factor;
// constexpr int smem_B_rows = BK; // B is stored N-major in smem
// constexpr int smem_B_cols = BN_adjusted;
constexpr int smem_B_rows = BK_adjusted; constexpr int smem_B_rows = BK_adjusted;
constexpr int smem_B_cols = BN; constexpr int smem_B_cols = BN;
// f8-f15 stores a single column of B
const volatile uint8_t *smem_addr; const volatile uint8_t *smem_addr;
smem_addr = reinterpret_cast<const volatile uint8_t *>( smem_addr = reinterpret_cast<const volatile uint8_t *>(
&reinterpret_cast<const volatile float *>( &reinterpret_cast<const volatile float *>(
smem_B)[((local_k_adjusted + 0) * smem_B_cols) + smem_B)[((local_k_adjusted + 0) * smem_B_cols) +
(WN * warp_col + TCN * wn_iter) + col]); (WN * warp_col + TCN * wn_iter) + col]);
// step to the next row // f8-f15 stores a single column of B
// threads read from different columns; no bank conflicts // threads read from different columns; no bank conflicts
asm volatile("flw f8, %0(%1)" :: "i"(smem_B_cols * 0 * sizeof(float)), "r"(smem_addr)); asm volatile("flw f8, %0(%1)" :: "i"(smem_B_cols * 0 * sizeof(float)), "r"(smem_addr));
asm volatile("flw f9, %0(%1)" :: "i"(smem_B_cols * 1 * sizeof(float)), "r"(smem_addr)); asm volatile("flw f9, %0(%1)" :: "i"(smem_B_cols * 1 * sizeof(float)), "r"(smem_addr));
@@ -265,6 +266,8 @@ inline void vx_wmma_load_b(const volatile T *smem_B, const int local_k,
asm volatile("flw f13, %0(%1)" :: "i"(smem_B_cols * 5 * sizeof(float)), "r"(smem_addr)); asm volatile("flw f13, %0(%1)" :: "i"(smem_B_cols * 5 * sizeof(float)), "r"(smem_addr));
asm volatile("flw f14, %0(%1)" :: "i"(smem_B_cols * 6 * sizeof(float)), "r"(smem_addr)); asm volatile("flw f14, %0(%1)" :: "i"(smem_B_cols * 6 * sizeof(float)), "r"(smem_addr));
asm volatile("flw f15, %0(%1)" :: "i"(smem_B_cols * 7 * sizeof(float)), "r"(smem_addr)); asm volatile("flw f15, %0(%1)" :: "i"(smem_B_cols * 7 * sizeof(float)), "r"(smem_addr));
asm volatile ("vx_wmma_load_b_finish_%=:" :: );
} }
inline void initialize_C(const int dest_reg) { inline void initialize_C(const int dest_reg) {
@@ -295,6 +298,8 @@ inline void write_results(const int thread_in_warp, const int warp_col,
const int wm_iter, const int dim_n, const int wm_iter, const int dim_n,
float *C, const int threadblock_id_x, float *C, const int threadblock_id_x,
const int threadblock_id_y) { const int threadblock_id_y) {
asm volatile ("write_results_start_%=:" :: );
int tid = thread_in_warp; int tid = thread_in_warp;
// these are [0, TCM/TCN) // these are [0, TCM/TCN)
@@ -342,6 +347,8 @@ inline void write_results(const int thread_in_warp, const int warp_col,
asm volatile ("fsw f30, %0(%1)" :: "i"(4 * sizeof(float)), "r"(gmem_addr_tmp)); asm volatile ("fsw f30, %0(%1)" :: "i"(4 * sizeof(float)), "r"(gmem_addr_tmp));
asm volatile ("fsw f31, %0(%1)" :: "i"(5 * sizeof(float)), "r"(gmem_addr_tmp)); asm volatile ("fsw f31, %0(%1)" :: "i"(5 * sizeof(float)), "r"(gmem_addr_tmp));
} }
asm volatile ("write_results_finish_%=:" :: );
} }
inline void threadblock_barrier(const uint32_t barrier_id, const uint32_t count) { inline void threadblock_barrier(const uint32_t barrier_id, const uint32_t count) {