sgemm_impl: Specify leading dimension to wmma load
This is necessary when loading a subtile of a full SMEM tile into the RF, where the subtile is split along the non-major dimension.
This commit is contained in:
@@ -200,7 +200,14 @@ inline void vx_wmma(const int dest_reg) {
|
||||
}
|
||||
|
||||
// `local_k` is assumed to be multiple of TCK
|
||||
template <typename T, MemLayout layout>
|
||||
template <typename T, MemLayout layout,
|
||||
uint32_t leading_dim // stride in sizeof(T) between consecutive
|
||||
// "rows" in the memory. What a row is
|
||||
// corresponds to whatever `layout` specifies.
|
||||
// E.g., if layout == MN_major, leading_dim
|
||||
// becomes the stride between the 1st M-dim
|
||||
// vector and the 2nd M-dim vector.
|
||||
>
|
||||
inline void wmma_load_a(volatile const T *smem_A, const int local_k,
|
||||
const int warp_row, const int wm_iter,
|
||||
const int thread_in_warp) {
|
||||
@@ -221,12 +228,10 @@ inline void wmma_load_a(volatile const T *smem_A, const int local_k,
|
||||
// moving a fp32 matrix whose column dimensions (dim_k/BK/k) are compressed
|
||||
// by a factor of two.
|
||||
constexpr int packed_factor = (std::is_same_v<T, float16_t> ? 2 : 1);
|
||||
constexpr int BK_adjusted = BK / packed_factor;
|
||||
const int local_k_adjusted = local_k / packed_factor;
|
||||
|
||||
if constexpr (layout == MemLayout::K_major) {
|
||||
constexpr int smem_A_rows = BM;
|
||||
constexpr int smem_A_cols = BK_adjusted;
|
||||
constexpr int smem_A_cols = leading_dim;
|
||||
|
||||
// int A_offset = (WM * warp_row + TCM * wm_iter + row) * smem_A_cols;
|
||||
|
||||
@@ -247,8 +252,7 @@ inline void wmma_load_a(volatile const T *smem_A, const int local_k,
|
||||
asm volatile("flw f6, %0(%1)" ::"i"(6 * sizeof(float)), "r"(smem_addr));
|
||||
asm volatile("flw f7, %0(%1)" ::"i"(7 * sizeof(float)), "r"(smem_addr));
|
||||
} else if constexpr (layout == MemLayout::MN_major) {
|
||||
constexpr int smem_AS_rows = BK_adjusted;
|
||||
constexpr int smem_AS_cols = BM;
|
||||
constexpr int smem_AS_cols = leading_dim;
|
||||
|
||||
const volatile uint8_t *smem_addr;
|
||||
smem_addr = reinterpret_cast<const volatile uint8_t *>(
|
||||
@@ -274,11 +278,35 @@ inline void wmma_load_a(volatile const T *smem_A, const int local_k,
|
||||
asm volatile ("wmma_load_a_finish_%=:" :: );
|
||||
}
|
||||
|
||||
// Convenience wrapper for wmma_load_a when the tile layout is packed, i.e.
|
||||
// the leading dimension (row stride) equals the tile's column count.
// The leading dimension is derived from the tile dimensions and forwarded to
// the wmma_load_a<T, layout, leading_dim> overload:
//   - layout == K_major: tile_dim_k / packed_factor
//   - any other layout:  tile_dim_m
// `tile_dim_n` is accepted as a template parameter but is not read in this
// body. All runtime arguments are passed through unchanged.
|
||||
template <typename T, MemLayout layout, uint32_t tile_dim_m,
|
||||
uint32_t tile_dim_n, uint32_t tile_dim_k>
|
||||
inline void wmma_load_a(volatile const T *smem_A, const int local_k,
|
||||
const int warp_row, const int wm_iter,
|
||||
const int thread_in_warp) {
|
||||
// In fp16 mode, bit-pack two fp16 elements into each fp32 element, and do
|
||||
// data movement at the fp32 granularity. Assuming that the matrix is stored
|
||||
// row-major in GMEM, the packed fp16 pairs belong to the same row,
|
||||
// neighboring columns; therefore, it essentially becomes equivalent to
|
||||
// moving a fp32 matrix whose column dimensions (dim_k/BK/k) are compressed
|
||||
// by a factor of two.
|
||||
// packed_factor is 2 when T is float16_t and 1 for every other T.
|
||||
constexpr int packed_factor = (std::is_same_v<T, float16_t> ? 2 : 1);
|
||||
// K extent of the tile, measured in packed elements.
|
||||
constexpr int tile_dim_k_adjusted = tile_dim_k / packed_factor;
|
||||
// Packed layout: the row stride is the tile's own extent, which is the
|
||||
// (packed) K extent for K_major and the M extent otherwise.
|
||||
constexpr int leading_dim = (layout == MemLayout::K_major)
|
||||
? tile_dim_k_adjusted
|
||||
: tile_dim_m;
|
||||
|
||||
// Delegate to the overload that takes the leading dimension explicitly.
|
||||
wmma_load_a<T, layout, leading_dim>(smem_A, local_k, warp_row, wm_iter,
|
||||
thread_in_warp);
|
||||
}
|
||||
|
||||
// `local_k` is assumed to be multiple of TCK
|
||||
template <typename T, MemLayout layout>
|
||||
template <typename T, MemLayout layout, uint32_t tile_dim_m,
|
||||
uint32_t tile_dim_n, uint32_t tile_dim_k>
|
||||
inline void wmma_load_b(const volatile T *smem_B, const int local_k,
|
||||
const int warp_col, const int wn_iter,
|
||||
const int thread_in_warp) {
|
||||
const int warp_col, const int wn_iter,
|
||||
const int thread_in_warp) {
|
||||
asm volatile ("wmma_load_b_start_%=:" :: );
|
||||
|
||||
static_assert(layout == MemLayout::MN_major,
|
||||
@@ -293,12 +321,12 @@ inline void wmma_load_b(const volatile T *smem_B, const int local_k,
|
||||
|
||||
// see comment in wmma_load_a
|
||||
constexpr int packed_factor = (std::is_same_v<T, float16_t> ? 2 : 1);
|
||||
constexpr int BK_adjusted = BK / packed_factor;
|
||||
constexpr int tile_dim_k_adjusted = tile_dim_k / packed_factor;
|
||||
const int local_k_adjusted = local_k / packed_factor;
|
||||
|
||||
// B is stored N-major in smem
|
||||
constexpr int smem_B_rows = BK_adjusted;
|
||||
constexpr int smem_B_cols = BN;
|
||||
constexpr int smem_B_rows = tile_dim_k_adjusted;
|
||||
constexpr int smem_B_cols = tile_dim_n;
|
||||
|
||||
const volatile uint8_t *smem_addr;
|
||||
smem_addr = reinterpret_cast<const volatile uint8_t *>(
|
||||
@@ -445,6 +473,11 @@ inline void threadblock_barrier(const uint32_t barrier_id, const uint32_t count)
|
||||
// Move a single matrix tile from global memory (GMEM) to shared memory (SMEM).
|
||||
// `dim_major`: major dimension of the matrix in GMEM, e.g. if K-major, K; or
|
||||
// MN-major, M/N.
|
||||
//
|
||||
// Note that there's not a single way to specify a layout of the matrix.
|
||||
// Identifying a matrix to be K-major and specifying the mn_index of a tile,
|
||||
// is equivalent to identifying it as MN-major and specifying the k_index
|
||||
// (provided `dim_major` is set accordingly).
|
||||
template <typename T,
|
||||
MemLayout gmem_layout, // memory layout of the GMEM tile
|
||||
MemLayout smem_layout, // memory layout of the SMEM tile
|
||||
@@ -606,11 +639,13 @@ load_tile_to_smem(const uint32_t dim_major, const uint32_t mn_index,
|
||||
// Do a single tile*tile matrix multiplication using the matrix data stored in
|
||||
// SMEM. Useful in fused kernels where GEMMs are done at a per-tile scope.
|
||||
template <typename T,
|
||||
MemLayout layout_a, // memory layout of `local_a`
|
||||
MemLayout layout_b, // memory layout of `local_b`
|
||||
uint32_t tile_dim_m,
|
||||
uint32_t tile_dim_n,
|
||||
uint32_t tile_dim_k,
|
||||
MemLayout layout_a, // memory layout of `local_a`
|
||||
MemLayout layout_b, // memory layout of `local_b`
|
||||
uint32_t tile_dim_m, uint32_t tile_dim_n, uint32_t tile_dim_k,
|
||||
uint32_t leading_dim_a, // if zero, assumes packed layout, i.e. row
|
||||
// stride == col.
|
||||
uint32_t leading_dim_b, // if zero, assumes packed layout, i.e. row
|
||||
// stride == col.
|
||||
bool load_accum = false, // if true, load the accumulation registers
|
||||
// with `local_c`. used for the (C + A*B)
|
||||
// operation
|
||||
@@ -640,7 +675,6 @@ __attribute__((always_inline)) inline void thread_block_gemm_single_tile(
|
||||
for (int wm_iter = 0; wm_iter < WMITER; wm_iter++) {
|
||||
#pragma GCC unroll
|
||||
for (int wn_iter = 0; wn_iter < WNITER; wn_iter++) {
|
||||
// FIXME: template parameter-ize BM
|
||||
wmma_load_accum(tid_in_warp, warp_col, warp_row, wn_iter, wm_iter,
|
||||
tile_dim_n, local_c);
|
||||
}
|
||||
@@ -654,13 +688,21 @@ __attribute__((always_inline)) inline void thread_block_gemm_single_tile(
|
||||
#pragma GCC unroll 2
|
||||
for (int wn_iter = 0; wn_iter < WNITER; wn_iter++) {
|
||||
// SMEM -> RF
|
||||
wmma_load_b<T, layout_b>(local_b, local_k, warp_col, wn_iter,
|
||||
tid_in_warp);
|
||||
static_assert(leading_dim_b == 0,
|
||||
"leading_dim for wmma_load_b is not implemented yet");
|
||||
wmma_load_b<T, layout_b, tile_dim_m, tile_dim_n,
|
||||
tile_dim_k /*leading_dim_b is TODO */>(
|
||||
local_b, local_k, warp_col, wn_iter, tid_in_warp);
|
||||
#pragma GCC unroll 2
|
||||
for (int wm_iter = 0; wm_iter < WMITER; wm_iter++) {
|
||||
// SMEM -> RF
|
||||
wmma_load_a<T, layout_a>(local_a, local_k, warp_row, wm_iter,
|
||||
tid_in_warp);
|
||||
if constexpr (leading_dim_a == 0) {
|
||||
wmma_load_a<T, layout_a, tile_dim_m, tile_dim_n, tile_dim_k>(
|
||||
local_a, local_k, warp_row, wm_iter, tid_in_warp);
|
||||
} else {
|
||||
wmma_load_a<T, layout_a, leading_dim_a>(local_a, local_k, warp_row,
|
||||
wm_iter, tid_in_warp);
|
||||
}
|
||||
// perform mma
|
||||
vx_wmma(wm_iter);
|
||||
}
|
||||
@@ -925,7 +967,7 @@ inline void thread_block_gemm(const T *A, const T *B, float *C,
|
||||
constexpr MemLayout layout_a =
|
||||
TRANSPOSE_AT_CONSUME ? MemLayout::K_major : MemLayout::MN_major;
|
||||
thread_block_gemm_single_tile<T, layout_a, MemLayout::MN_major,
|
||||
BM, BN, BK,
|
||||
BM, BN, BK, 0, 0,
|
||||
/*load_accum=*/false,
|
||||
/*write_to_mem=*/false>(
|
||||
local_a_consume, local_b_consume,
|
||||
|
||||
Reference in New Issue
Block a user