sgemm_tcore: Fix addr gen for GMEM->SMEM for M-major A
This fixes correctness for the TRANSPOSE_AT_PRODUCE/COLUMN=0/0 configuration, provided that the matrices are already stored in the correct layout in GMEM (A in M-major order).
This commit is contained in:
@@ -40,13 +40,16 @@
|
||||
// using float_type = float;
|
||||
using float_type = float16_t;
|
||||
|
||||
// TODO: reduce args by passing leading A/B dimensions
|
||||
template <typename T>
|
||||
inline void global_dmem_load(const uint32_t dim_n, const uint32_t dim_k,
|
||||
inline void global_dmem_load(const uint32_t dim_m, const uint32_t dim_n, const uint32_t dim_k,
|
||||
const uint32_t k, const T *A, const T *B,
|
||||
volatile T *local_a, volatile T *local_b,
|
||||
const uint32_t tid_in_threadblock,
|
||||
const uint32_t threadblock_id_x,
|
||||
const uint32_t threadblock_id_y) {
|
||||
asm volatile ("global_dmem_load_start_%=:" :: );
|
||||
|
||||
// In fp16 mode, bit-pack two fp16 elements into each fp32 element, and do
|
||||
// data movement at the fp32 granularity. Assuming that the matrix is stored
|
||||
// row-major in GMEM, the packed fp16 pairs belong to the same row,
|
||||
@@ -79,28 +82,28 @@ inline void global_dmem_load(const uint32_t dim_n, const uint32_t dim_k,
|
||||
|
||||
// move A
|
||||
if constexpr (!TRANSPOSE_AT_PRODUCE) {
|
||||
// No transpose at GMEM->SMEM movement
|
||||
// FIXME: !TRANSPOSE_AS code is old
|
||||
|
||||
const uint32_t global_a_row = BM * threadblock_id_y + local_a_row;
|
||||
// A is stored M-major in GMEM;
|
||||
// no transpose at GMEM->SMEM movement
|
||||
const uint32_t block_m = threadblock_id_y;
|
||||
const uint32_t global_a_row = k_adjusted + local_as_row;
|
||||
const uint32_t global_a_col = BM * block_m + local_as_col;
|
||||
// number of rows a full TB can read at a time
|
||||
// this is equivalent to threadblock_dim_y (assuming threadblock_dim_x ==
|
||||
// BK)
|
||||
constexpr uint32_t row_stride_a = threads_in_threadblock / BK_adjusted;
|
||||
constexpr uint32_t row_stride_as = threads_in_threadblock / BM;
|
||||
const float *global_a = reinterpret_cast<const float *>(A) +
|
||||
dim_k_adjusted * global_a_row +
|
||||
(k_adjusted + local_a_col);
|
||||
dim_m * global_a_row + global_a_col;
|
||||
volatile float *local_a_tmp = reinterpret_cast<volatile float *>(local_a) +
|
||||
BK_adjusted * local_a_row + local_a_col;
|
||||
BM * local_as_row + local_as_col;
|
||||
|
||||
#pragma GCC unroll 1
|
||||
for (uint32_t local_row_offset = 0; local_row_offset < BM;
|
||||
local_row_offset += row_stride_a) {
|
||||
for (uint32_t local_row_offset = 0; local_row_offset < BK_adjusted;
|
||||
local_row_offset += row_stride_as) {
|
||||
// TODO: the code GCC generates for the copy below seems fine at the moment,
|
||||
// but hand-unroll it into assembly to be absolutely sure
|
||||
*local_a_tmp = *global_a;
|
||||
|
||||
// move to the next "row-chunk", when threadblock is smaller than BM*BK
|
||||
global_a += dim_k_adjusted * row_stride_a;
|
||||
local_a_tmp += BK_adjusted * row_stride_a;
|
||||
global_a += dim_m * row_stride_as;
|
||||
local_a_tmp += BM * row_stride_as;
|
||||
}
|
||||
} else {
|
||||
if constexpr (!GMEM_COALESCED_A) {
|
||||
@@ -126,9 +129,6 @@ inline void global_dmem_load(const uint32_t dim_n, const uint32_t dim_k,
|
||||
// @perf: bank conflicts here
|
||||
// const uint32_t global_a_offset =
|
||||
// dim_k_adjusted * (global_a_row) + (k + local_as_row + local_row_offset);
|
||||
// FIXME experimenting with global coalescing
|
||||
// const uint32_t global_a_offset =
|
||||
// dim_k_adjusted * (global_a_row + local_row_offset) + (k + local_as_col);
|
||||
// local_a[BM * (local_as_row + local_row_offset) + local_as_col] =
|
||||
// A[global_a_offset];
|
||||
|
||||
@@ -278,6 +278,8 @@ inline void global_dmem_load(const uint32_t dim_n, const uint32_t dim_k,
|
||||
asm volatile ("fsw ft7, %0(%1)" :: "i"(BN_adjusted * row_stride_b * 1 * sizeof(float)), "r"(local_b_tmp));
|
||||
local_b_tmp += BN_adjusted * row_stride_b * 2;
|
||||
}
|
||||
|
||||
asm volatile ("global_dmem_load_finish_%=:" :: );
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
@@ -464,8 +466,8 @@ inline void thread_block_gemm(kernel_arg_t *__UNIFORM__ arg,
|
||||
#endif
|
||||
}
|
||||
#else
|
||||
global_dmem_load<T>(dim_n, dim_k, block_k * BK, A, B, local_a, local_b,
|
||||
tid_in_threadblock, block_n, block_m);
|
||||
global_dmem_load<T>(dim_m, dim_n, dim_k, block_k * BK, A, B, local_a,
|
||||
local_b, tid_in_threadblock, block_n, block_m);
|
||||
|
||||
threadblock_barrier(threadblock_id_in_cluster, threadblock_dim_y);
|
||||
#endif
|
||||
|
||||
Reference in New Issue
Block a user