sgemm_impl: Specify leading dimension to wmma load

This is necessary when loading a subtile of a full SMEM tile into the RF
and that subtile is split along the non-major dimension.
This commit is contained in:
Hansung Kim
2024-09-02 00:14:35 -07:00
parent 602fe4a400
commit bdd955836d

View File

@@ -200,7 +200,14 @@ inline void vx_wmma(const int dest_reg) {
}
// `local_k` is assumed to be a multiple of TCK
template <typename T, MemLayout layout>
template <typename T, MemLayout layout,
uint32_t leading_dim // stride in sizeof(T) between consecutive
// "rows" in the memory. What a row is
// corresponds to whatever `layout` specifies.
// E.g., if layout == MN_major, leading_dim
// becomes the stride between the 1st M-dim
// vector and the 2nd M-dim vector.
>
inline void wmma_load_a(volatile const T *smem_A, const int local_k,
const int warp_row, const int wm_iter,
const int thread_in_warp) {
@@ -221,12 +228,10 @@ inline void wmma_load_a(volatile const T *smem_A, const int local_k,
// moving a fp32 matrix whose column dimensions (dim_k/BK/k) are compressed
// by a factor of two.
constexpr int packed_factor = (std::is_same_v<T, float16_t> ? 2 : 1);
constexpr int BK_adjusted = BK / packed_factor;
const int local_k_adjusted = local_k / packed_factor;
if constexpr (layout == MemLayout::K_major) {
constexpr int smem_A_rows = BM;
constexpr int smem_A_cols = BK_adjusted;
constexpr int smem_A_cols = leading_dim;
// int A_offset = (WM * warp_row + TCM * wm_iter + row) * smem_A_cols;
@@ -247,8 +252,7 @@ inline void wmma_load_a(volatile const T *smem_A, const int local_k,
asm volatile("flw f6, %0(%1)" ::"i"(6 * sizeof(float)), "r"(smem_addr));
asm volatile("flw f7, %0(%1)" ::"i"(7 * sizeof(float)), "r"(smem_addr));
} else if constexpr (layout == MemLayout::MN_major) {
constexpr int smem_AS_rows = BK_adjusted;
constexpr int smem_AS_cols = BM;
constexpr int smem_AS_cols = leading_dim;
const volatile uint8_t *smem_addr;
smem_addr = reinterpret_cast<const volatile uint8_t *>(
@@ -274,11 +278,35 @@ inline void wmma_load_a(volatile const T *smem_A, const int local_k,
asm volatile ("wmma_load_a_finish_%=:" :: );
}
// Overload of wmma_load_a for packed tiles: rather than taking an explicit
// leading_dim, it derives one from the tile dimensions and forwards every
// runtime argument unchanged to the <T, layout, leading_dim> overload.
//
//   layout == K_major : leading_dim = tile_dim_k / packed_factor
//   any other layout  : leading_dim = tile_dim_m
//
// packed_factor is 2 when T is float16_t and 1 otherwise.  tile_dim_n is
// accepted but not used by this overload.
template <typename T, MemLayout layout, uint32_t tile_dim_m,
          uint32_t tile_dim_n, uint32_t tile_dim_k>
inline void wmma_load_a(volatile const T *smem_A, const int local_k,
                        const int warp_row, const int wm_iter,
                        const int thread_in_warp) {
  if constexpr (layout == MemLayout::K_major) {
    // In fp16 mode two fp16 elements are bit-packed into each fp32 element
    // and data is moved at fp32 granularity.  Assuming the matrix is stored
    // row-major in GMEM, each packed pair sits in the same row at
    // neighboring columns, so this is equivalent to moving a fp32 matrix
    // whose K extent is halved.
    constexpr int elems_per_word = std::is_same_v<T, float16_t> ? 2 : 1;
    constexpr int row_stride = tile_dim_k / elems_per_word;
    wmma_load_a<T, layout, row_stride>(smem_A, local_k, warp_row, wm_iter,
                                       thread_in_warp);
  } else {
    // The K extent does not enter the stride for the remaining layout(s);
    // the M extent is forwarded as-is.
    constexpr int row_stride = tile_dim_m;
    wmma_load_a<T, layout, row_stride>(smem_A, local_k, warp_row, wm_iter,
                                       thread_in_warp);
  }
}
// `local_k` is assumed to be a multiple of TCK
template <typename T, MemLayout layout>
template <typename T, MemLayout layout, uint32_t tile_dim_m,
uint32_t tile_dim_n, uint32_t tile_dim_k>
inline void wmma_load_b(const volatile T *smem_B, const int local_k,
const int warp_col, const int wn_iter,
const int thread_in_warp) {
const int warp_col, const int wn_iter,
const int thread_in_warp) {
asm volatile ("wmma_load_b_start_%=:" :: );
static_assert(layout == MemLayout::MN_major,
@@ -293,12 +321,12 @@ inline void wmma_load_b(const volatile T *smem_B, const int local_k,
// see comment in wmma_load_a
constexpr int packed_factor = (std::is_same_v<T, float16_t> ? 2 : 1);
constexpr int BK_adjusted = BK / packed_factor;
constexpr int tile_dim_k_adjusted = tile_dim_k / packed_factor;
const int local_k_adjusted = local_k / packed_factor;
// B is stored N-major in smem
constexpr int smem_B_rows = BK_adjusted;
constexpr int smem_B_cols = BN;
constexpr int smem_B_rows = tile_dim_k_adjusted;
constexpr int smem_B_cols = tile_dim_n;
const volatile uint8_t *smem_addr;
smem_addr = reinterpret_cast<const volatile uint8_t *>(
@@ -445,6 +473,11 @@ inline void threadblock_barrier(const uint32_t barrier_id, const uint32_t count)
// Move a single matrix tile from global memory (GMEM) to shared memory (SMEM).
// `dim_major`: major dimension of the matrix in GMEM, e.g. if K-major, K; or
// MN-major, M/N.
//
// Note that there is more than one way to specify the layout of a matrix:
// identifying a matrix as K-major and specifying the mn_index of a tile is
// equivalent to identifying it as MN-major and specifying the k_index
// (provided `dim_major` is set accordingly).
template <typename T,
MemLayout gmem_layout, // memory layout of the GMEM tile
MemLayout smem_layout, // memory layout of the SMEM tile
@@ -606,11 +639,13 @@ load_tile_to_smem(const uint32_t dim_major, const uint32_t mn_index,
// Do a single tile*tile matrix multiplication using the matrix data stored in
// SMEM. Useful in fused kernels where GEMMs are done at a per-tile scope.
template <typename T,
MemLayout layout_a, // memory layout of `local_a`
MemLayout layout_b, // memory layout of `local_b`
uint32_t tile_dim_m,
uint32_t tile_dim_n,
uint32_t tile_dim_k,
MemLayout layout_a, // memory layout of `local_a`
MemLayout layout_b, // memory layout of `local_b`
uint32_t tile_dim_m, uint32_t tile_dim_n, uint32_t tile_dim_k,
uint32_t leading_dim_a, // if zero, assumes packed layout, i.e. row
// stride == col.
uint32_t leading_dim_b, // if zero, assumes packed layout, i.e. row
// stride == col.
bool load_accum = false, // if true, load the accumulation registers
// with `local_c`. used for the (C + A*B)
// operation
@@ -640,7 +675,6 @@ __attribute__((always_inline)) inline void thread_block_gemm_single_tile(
for (int wm_iter = 0; wm_iter < WMITER; wm_iter++) {
#pragma GCC unroll
for (int wn_iter = 0; wn_iter < WNITER; wn_iter++) {
// FIXME: template parameter-ize BM
wmma_load_accum(tid_in_warp, warp_col, warp_row, wn_iter, wm_iter,
tile_dim_n, local_c);
}
@@ -654,13 +688,21 @@ __attribute__((always_inline)) inline void thread_block_gemm_single_tile(
#pragma GCC unroll 2
for (int wn_iter = 0; wn_iter < WNITER; wn_iter++) {
// SMEM -> RF
wmma_load_b<T, layout_b>(local_b, local_k, warp_col, wn_iter,
tid_in_warp);
static_assert(leading_dim_b == 0,
"leading_dim for wmma_load_b is not implemented yet");
wmma_load_b<T, layout_b, tile_dim_m, tile_dim_n,
tile_dim_k /*leading_dim_b is TODO */>(
local_b, local_k, warp_col, wn_iter, tid_in_warp);
#pragma GCC unroll 2
for (int wm_iter = 0; wm_iter < WMITER; wm_iter++) {
// SMEM -> RF
wmma_load_a<T, layout_a>(local_a, local_k, warp_row, wm_iter,
tid_in_warp);
if constexpr (leading_dim_a == 0) {
wmma_load_a<T, layout_a, tile_dim_m, tile_dim_n, tile_dim_k>(
local_a, local_k, warp_row, wm_iter, tid_in_warp);
} else {
wmma_load_a<T, layout_a, leading_dim_a>(local_a, local_k, warp_row,
wm_iter, tid_in_warp);
}
// perform mma
vx_wmma(wm_iter);
}
@@ -925,7 +967,7 @@ inline void thread_block_gemm(const T *A, const T *B, float *C,
constexpr MemLayout layout_a =
TRANSPOSE_AT_CONSUME ? MemLayout::K_major : MemLayout::MN_major;
thread_block_gemm_single_tile<T, layout_a, MemLayout::MN_major,
BM, BN, BK,
BM, BN, BK, 0, 0,
/*load_accum=*/false,
/*write_to_mem=*/false>(
local_a_consume, local_b_consume,