sgemm_impl: Remove GMEM_COALESCED_A option

Uncoalesced GMEM accesses is verified to yield slow performance and the
relevant code is not used anymore; remove the cruft
This commit is contained in:
Hansung Kim
2024-08-18 22:21:17 -07:00
parent 04643fa64d
commit 46b5047775

View File

@@ -71,14 +71,6 @@ using float_type = float16_t;
// set both to 0.
#define TRANSPOSE_AT_PRODUCE 1
#define TRANSPOSE_AT_CONSUME 0
// GMEM_COALESCED: When TRANSPOSE_AT_PRODUCE == 1 (i.e. transpose at
// GMEM->SMEM), determines whether we do bank-conflict-free accesses for
// 1: GMEM loads of A matrix, or
// 0: SMEM stores of A matrix.
//
// Usually, GMEM_COALESCED==1 yields better performance since the memory
// behavior of GMEM is more sensitive to bank conflicts.
#define GMEM_COALESCED_A 1
#define GEMMINI_DMA 0
#if SMEM_SIZE == 0x4000
@@ -403,8 +395,7 @@ template <typename T,
MemLayout gmem_layout, // memory layout of the GMEM tile
MemLayout smem_layout, // memory layout of the GMEM tile
uint32_t tile_dim_mn, // row dimension of the SMEM tile
uint32_t tile_dim_k, // column dimension of the SMEM tile
bool gmem_contiguous = true
uint32_t tile_dim_k // column dimension of the SMEM tile
>
__attribute__((always_inline)) inline void
global_dmem_load_new(const uint32_t dim_col, const uint32_t mn_index,
@@ -450,9 +441,6 @@ global_dmem_load_new(const uint32_t dim_col, const uint32_t mn_index,
// FIXME: don't hardcode this here
constexpr uint32_t threads_per_threadblock = (BM * BN) / ELEM_PER_THREAD;
static_assert(gmem_contiguous == true,
"currently only supports contiguous accesses in GMEM");
const uint32_t global_row_mn_major = k_ + local_row_gmem;
const uint32_t global_col_mn_major = smem_dim_col * mn_index + local_col_gmem;
const uint32_t global_row_k_major = gmem_dim_row * mn_index + local_row_gmem;
@@ -505,12 +493,9 @@ global_dmem_load_new(const uint32_t dim_col, const uint32_t mn_index,
asm volatile("flw ft7, (%0)" ::"r"(global));
global += dim_col_ * row_stride;
// do we need to do transposed write?
// need to branch because address offset constant in the inline assembly
// cannot be larger than a certain limit
if constexpr (!transposed_write) {
static_assert(gmem_layout == MemLayout::MN_major);
// if not, do the same along-the-column accesses for registers as we did
// for gmem
asm volatile("fsw ft0, %0(%1)" ::"i"(smem_dim_col * row_stride * 0 *
sizeof(float)),
"r"(local));
@@ -540,11 +525,11 @@ global_dmem_load_new(const uint32_t dim_col, const uint32_t mn_index,
"r"(local));
local += smem_dim_col * row_stride * 2;
} else {
// currently, tensor core hardware only supports MN-major SMEM tile
// layout for correct results
static_assert(gmem_layout == MemLayout::K_major);
static_assert(smem_layout == MemLayout::MN_major);
// if yes, write the registers along the row, doing a transpose
// @perf: this will incur bank conflicts in smem
asm volatile("fsw ft0, %0(%1)" ::"i"(row_stride * 0 * sizeof(float)),
"r"(local));
asm volatile("fsw ft1, %0(%1)" ::"i"(row_stride * 1 * sizeof(float)),
@@ -568,121 +553,6 @@ global_dmem_load_new(const uint32_t dim_col, const uint32_t mn_index,
asm volatile("global_dmem_load_finish_new_%=:" ::);
}
// TODO: reduce args by passing leading A/B dimensions
template <typename T>
__attribute__((always_inline))
inline void global_dmem_load(const uint32_t dim_m, const uint32_t dim_n, const uint32_t dim_k,
const uint32_t k, const T *A, const T *B,
volatile T *local_a, volatile T *local_b,
const uint32_t tid_in_threadblock,
const uint32_t threadblock_id_x,
const uint32_t threadblock_id_y) {
asm volatile ("global_dmem_load_start_%=:" :: );
// In fp16 mode, bit-pack two fp16 elements into each fp32 element, and do
// data movement at the fp32 granularity. Assuming that the matrix is stored
// row-major in GMEM, the packed fp16 pairs belong to the same row,
// neighboring columns; therefore, it essentially becomes equivalent to
// moving a fp32 matrix whose column dimensions (dim_k/BK/k) are compressed
// by a factor of two.
constexpr uint32_t packed_factor = (std::is_same_v<T, float16_t> ? 2 : 1);
constexpr uint32_t BK_adjusted = BK / packed_factor;
const uint32_t dim_k_adjusted = dim_k / packed_factor;
const uint32_t k_adjusted = k / packed_factor;
const uint32_t local_a_row = tid_in_threadblock / BK_adjusted;
const uint32_t local_a_col = tid_in_threadblock % BK_adjusted;
const uint32_t local_as_row = tid_in_threadblock / BM;
const uint32_t local_as_col = tid_in_threadblock % BM;
const uint32_t local_b_row = tid_in_threadblock / BN;
const uint32_t local_b_col = tid_in_threadblock % BN;
// FIXME: need fix for fp16?
constexpr uint32_t threads_per_threadblock = (BM * BN) / ELEM_PER_THREAD;
// Data move from GMEM to SMEM
//
// Make sure global offset values for A and B are contiguous between
// neighboring threads to ensure GMEM coalescing.
//
// TODO: Sharedmem swizzling is important here
// move A
if constexpr (!TRANSPOSE_AT_PRODUCE) {
global_dmem_load_new<T, MemLayout::MN_major, MemLayout::MN_major, BM, BK>(
dim_m, threadblock_id_y, k, A, local_a, tid_in_threadblock);
} else {
if constexpr (!GMEM_COALESCED_A) {
// !GMEM_COALESCED_A: threads do uncoalesced read from neighboring row in
// GMEM, writes to neighboring cols in SMEM
constexpr uint32_t row_stride_as = threads_per_threadblock / BM;
const uint32_t global_a_row = BM * threadblock_id_y + local_as_col;
const float *global_a =
reinterpret_cast<float *>(A) + dim_k_adjusted * global_a_row + (k_adjusted + local_as_row);
volatile float *local_a_tmp =
reinterpret_cast<float *>(local_a) + BM * local_as_row + local_as_col;
static_assert(
row_stride_as * 8 <= BK_adjusted,
"manual loop unrolling condition not met; consider increasing BK");
static_assert(
(BK_adjusted % (row_stride_as * 8)) == 0,
"manual loop unrolling condition not met; BK should be power-of-two");
#pragma GCC unroll 1
for (uint32_t local_row_offset = 0; local_row_offset < BK_adjusted;
local_row_offset += row_stride_as * 8) {
// const uint32_t global_a_offset =
// dim_k_adjusted * (global_a_row) + (k + local_as_row + local_row_offset);
// local_a[BM * (local_as_row + local_row_offset) + local_as_col] =
// A[global_a_offset];
// @perf: bank conflicts
asm volatile ("flw ft0, (%0)" :: "r"(global_a));
global_a += row_stride_as;
asm volatile ("flw ft1, (%0)" :: "r"(global_a));
global_a += row_stride_as;
asm volatile ("flw ft2, (%0)" :: "r"(global_a));
global_a += row_stride_as;
asm volatile ("flw ft3, (%0)" :: "r"(global_a));
global_a += row_stride_as;
asm volatile ("flw ft4, (%0)" :: "r"(global_a));
global_a += row_stride_as;
asm volatile ("flw ft5, (%0)" :: "r"(global_a));
global_a += row_stride_as;
asm volatile ("flw ft6, (%0)" :: "r"(global_a));
global_a += row_stride_as;
asm volatile ("flw ft7, (%0)" :: "r"(global_a));
global_a += row_stride_as;
// NOTE: stride is fixed to word size , i.e. sizeof(float) = 4,
// regardless of fp16 or fp32. Since Vortex core does not support fp16,
// load things at word granularity and reinterpret bits inside the
// tensor core.
asm volatile ("fsw ft0, %0(%1)" :: "i"(BM * row_stride_as * 0 * sizeof(float)), "r"(local_a_tmp));
asm volatile ("fsw ft1, %0(%1)" :: "i"(BM * row_stride_as * 1 * sizeof(float)), "r"(local_a_tmp));
asm volatile ("fsw ft2, %0(%1)" :: "i"(BM * row_stride_as * 2 * sizeof(float)), "r"(local_a_tmp));
asm volatile ("fsw ft3, %0(%1)" :: "i"(BM * row_stride_as * 3 * sizeof(float)), "r"(local_a_tmp));
asm volatile ("fsw ft4, %0(%1)" :: "i"(BM * row_stride_as * 4 * sizeof(float)), "r"(local_a_tmp));
asm volatile ("fsw ft5, %0(%1)" :: "i"(BM * row_stride_as * 5 * sizeof(float)), "r"(local_a_tmp));
asm volatile ("fsw ft6, %0(%1)" :: "i"(BM * row_stride_as * 6 * sizeof(float)), "r"(local_a_tmp));
asm volatile ("fsw ft7, %0(%1)" :: "i"(BM * row_stride_as * 7 * sizeof(float)), "r"(local_a_tmp));
local_a_tmp += BM * row_stride_as * 8;
}
} else {
global_dmem_load_new<T, MemLayout::K_major, MemLayout::MN_major, BM,
BK>(dim_k, threadblock_id_y, k, A, local_a,
tid_in_threadblock);
}
} // end move A
// move B
global_dmem_load_new<T, MemLayout::MN_major, MemLayout::MN_major, BN, BK>(
dim_n, threadblock_id_x, k, B, local_b, tid_in_threadblock);
asm volatile ("global_dmem_load_finish_%=:" :: );
}
// Do a single tile*tile matrix multiplication using the matrix data stored in
// SMEM. Useful in fused kernels where GEMMs are done at a per-tile scope.
template <typename T,
@@ -933,8 +803,21 @@ inline void thread_block_gemm(const T *A, const T *B, float *C,
#endif
}
#else
global_dmem_load<T>(dim_m, dim_n, dim_k, block_k * BK, A, B, local_a,
local_b, tid_in_threadblock, block_n, block_m);
// move A
if constexpr (!TRANSPOSE_AT_PRODUCE) {
global_dmem_load_new<T, MemLayout::MN_major, MemLayout::MN_major, BM,
BK>(dim_m, block_m, block_k * BK, A, local_a,
tid_in_threadblock);
} else {
global_dmem_load_new<T, MemLayout::K_major, MemLayout::MN_major, BM,
BK>(dim_k, block_m, block_k * BK, A, local_a,
tid_in_threadblock);
}
// move B
global_dmem_load_new<T, MemLayout::MN_major, MemLayout::MN_major, BN,
BK>(dim_n, block_n, block_k * BK, B, local_b,
tid_in_threadblock);
threadblock_barrier(threadblock_id_in_cluster,
warps_per_threadblock_per_core);