sgemm_impl: Specify leading dimension to wmma load
This is necessary when loading a subtile from a full tile in SMEM into the RF, where that subtile is split along the non-major dimension.
This commit is contained in:
@@ -200,7 +200,14 @@ inline void vx_wmma(const int dest_reg) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// `local_k` is assumed to be multiple of TCK
|
// `local_k` is assumed to be multiple of TCK
|
||||||
template <typename T, MemLayout layout>
|
template <typename T, MemLayout layout,
|
||||||
|
uint32_t leading_dim // stride in sizeof(T) between consecutive
|
||||||
|
// "rows" in the memory. What a row is
|
||||||
|
// corresponds to whatever `layout` specifies.
|
||||||
|
// E.g., if layout == MN_major, leading_dim
|
||||||
|
// becomes the stride between the 1st M-dim
|
||||||
|
// vector and the 2nd M-dim vector.
|
||||||
|
>
|
||||||
inline void wmma_load_a(volatile const T *smem_A, const int local_k,
|
inline void wmma_load_a(volatile const T *smem_A, const int local_k,
|
||||||
const int warp_row, const int wm_iter,
|
const int warp_row, const int wm_iter,
|
||||||
const int thread_in_warp) {
|
const int thread_in_warp) {
|
||||||
@@ -221,12 +228,10 @@ inline void wmma_load_a(volatile const T *smem_A, const int local_k,
|
|||||||
// moving a fp32 matrix whose column dimensions (dim_k/BK/k) are compressed
|
// moving a fp32 matrix whose column dimensions (dim_k/BK/k) are compressed
|
||||||
// by a factor of two.
|
// by a factor of two.
|
||||||
constexpr int packed_factor = (std::is_same_v<T, float16_t> ? 2 : 1);
|
constexpr int packed_factor = (std::is_same_v<T, float16_t> ? 2 : 1);
|
||||||
constexpr int BK_adjusted = BK / packed_factor;
|
|
||||||
const int local_k_adjusted = local_k / packed_factor;
|
const int local_k_adjusted = local_k / packed_factor;
|
||||||
|
|
||||||
if constexpr (layout == MemLayout::K_major) {
|
if constexpr (layout == MemLayout::K_major) {
|
||||||
constexpr int smem_A_rows = BM;
|
constexpr int smem_A_cols = leading_dim;
|
||||||
constexpr int smem_A_cols = BK_adjusted;
|
|
||||||
|
|
||||||
// int A_offset = (WM * warp_row + TCM * wm_iter + row) * smem_A_cols;
|
// int A_offset = (WM * warp_row + TCM * wm_iter + row) * smem_A_cols;
|
||||||
|
|
||||||
@@ -247,8 +252,7 @@ inline void wmma_load_a(volatile const T *smem_A, const int local_k,
|
|||||||
asm volatile("flw f6, %0(%1)" ::"i"(6 * sizeof(float)), "r"(smem_addr));
|
asm volatile("flw f6, %0(%1)" ::"i"(6 * sizeof(float)), "r"(smem_addr));
|
||||||
asm volatile("flw f7, %0(%1)" ::"i"(7 * sizeof(float)), "r"(smem_addr));
|
asm volatile("flw f7, %0(%1)" ::"i"(7 * sizeof(float)), "r"(smem_addr));
|
||||||
} else if constexpr (layout == MemLayout::MN_major) {
|
} else if constexpr (layout == MemLayout::MN_major) {
|
||||||
constexpr int smem_AS_rows = BK_adjusted;
|
constexpr int smem_AS_cols = leading_dim;
|
||||||
constexpr int smem_AS_cols = BM;
|
|
||||||
|
|
||||||
const volatile uint8_t *smem_addr;
|
const volatile uint8_t *smem_addr;
|
||||||
smem_addr = reinterpret_cast<const volatile uint8_t *>(
|
smem_addr = reinterpret_cast<const volatile uint8_t *>(
|
||||||
@@ -274,11 +278,35 @@ inline void wmma_load_a(volatile const T *smem_A, const int local_k,
|
|||||||
asm volatile ("wmma_load_a_finish_%=:" :: );
|
asm volatile ("wmma_load_a_finish_%=:" :: );
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Convenience wrapper for wmma_load_a if tile layout is packed, i.e.
|
||||||
|
// leading_dim == col.
|
||||||
|
template <typename T, MemLayout layout, uint32_t tile_dim_m,
|
||||||
|
uint32_t tile_dim_n, uint32_t tile_dim_k>
|
||||||
|
inline void wmma_load_a(volatile const T *smem_A, const int local_k,
|
||||||
|
const int warp_row, const int wm_iter,
|
||||||
|
const int thread_in_warp) {
|
||||||
|
// In fp16 mode, bit-pack two fp16 elements into each fp32 element, and do
|
||||||
|
// data movement at the fp32 granularity. Assuming that the matrix is stored
|
||||||
|
// row-major in GMEM, the packed fp16 pairs belong to the same row,
|
||||||
|
// neighboring columns; therefore, it essentially becomes equivalent to
|
||||||
|
// moving a fp32 matrix whose column dimensions (dim_k/BK/k) are compressed
|
||||||
|
// by a factor of two.
|
||||||
|
constexpr int packed_factor = (std::is_same_v<T, float16_t> ? 2 : 1);
|
||||||
|
constexpr int tile_dim_k_adjusted = tile_dim_k / packed_factor;
|
||||||
|
constexpr int leading_dim = (layout == MemLayout::K_major)
|
||||||
|
? tile_dim_k_adjusted
|
||||||
|
: tile_dim_m;
|
||||||
|
|
||||||
|
wmma_load_a<T, layout, leading_dim>(smem_A, local_k, warp_row, wm_iter,
|
||||||
|
thread_in_warp);
|
||||||
|
}
|
||||||
|
|
||||||
// `local_k` is assumed to be multiple of TCK
|
// `local_k` is assumed to be multiple of TCK
|
||||||
template <typename T, MemLayout layout>
|
template <typename T, MemLayout layout, uint32_t tile_dim_m,
|
||||||
|
uint32_t tile_dim_n, uint32_t tile_dim_k>
|
||||||
inline void wmma_load_b(const volatile T *smem_B, const int local_k,
|
inline void wmma_load_b(const volatile T *smem_B, const int local_k,
|
||||||
const int warp_col, const int wn_iter,
|
const int warp_col, const int wn_iter,
|
||||||
const int thread_in_warp) {
|
const int thread_in_warp) {
|
||||||
asm volatile ("wmma_load_b_start_%=:" :: );
|
asm volatile ("wmma_load_b_start_%=:" :: );
|
||||||
|
|
||||||
static_assert(layout == MemLayout::MN_major,
|
static_assert(layout == MemLayout::MN_major,
|
||||||
@@ -293,12 +321,12 @@ inline void wmma_load_b(const volatile T *smem_B, const int local_k,
|
|||||||
|
|
||||||
// see comment in wmma_load_a
|
// see comment in wmma_load_a
|
||||||
constexpr int packed_factor = (std::is_same_v<T, float16_t> ? 2 : 1);
|
constexpr int packed_factor = (std::is_same_v<T, float16_t> ? 2 : 1);
|
||||||
constexpr int BK_adjusted = BK / packed_factor;
|
constexpr int tile_dim_k_adjusted = tile_dim_k / packed_factor;
|
||||||
const int local_k_adjusted = local_k / packed_factor;
|
const int local_k_adjusted = local_k / packed_factor;
|
||||||
|
|
||||||
// B is stored N-major in smem
|
// B is stored N-major in smem
|
||||||
constexpr int smem_B_rows = BK_adjusted;
|
constexpr int smem_B_rows = tile_dim_k_adjusted;
|
||||||
constexpr int smem_B_cols = BN;
|
constexpr int smem_B_cols = tile_dim_n;
|
||||||
|
|
||||||
const volatile uint8_t *smem_addr;
|
const volatile uint8_t *smem_addr;
|
||||||
smem_addr = reinterpret_cast<const volatile uint8_t *>(
|
smem_addr = reinterpret_cast<const volatile uint8_t *>(
|
||||||
@@ -445,6 +473,11 @@ inline void threadblock_barrier(const uint32_t barrier_id, const uint32_t count)
|
|||||||
// Move a single matrix tile from global memory (GMEM) to shared memory (SMEM).
|
// Move a single matrix tile from global memory (GMEM) to shared memory (SMEM).
|
||||||
// `dim_major`: major dimension of the matrix in GMEM, e.g. if K-major, K; or
|
// `dim_major`: major dimension of the matrix in GMEM, e.g. if K-major, K; or
|
||||||
// MN-major, M/N.
|
// MN-major, M/N.
|
||||||
|
//
|
||||||
|
// Note that there is more than one way to specify the layout of a matrix.
|
||||||
|
// Identifying a matrix as K-major and specifying the mn_index of a tile
|
||||||
|
// is equivalent to identifying it as MN-major and specifying the k_index
|
||||||
|
// (provided `dim_major` is set accordingly).
|
||||||
template <typename T,
|
template <typename T,
|
||||||
MemLayout gmem_layout, // memory layout of the GMEM tile
|
MemLayout gmem_layout, // memory layout of the GMEM tile
|
||||||
MemLayout smem_layout, // memory layout of the SMEM tile
|
MemLayout smem_layout, // memory layout of the SMEM tile
|
||||||
@@ -606,11 +639,13 @@ load_tile_to_smem(const uint32_t dim_major, const uint32_t mn_index,
|
|||||||
// Do a single tile*tile matrix multiplication using the matrix data stored in
|
// Do a single tile*tile matrix multiplication using the matrix data stored in
|
||||||
// SMEM. Useful in fused kernels where GEMMs are done at a per-tile scope.
|
// SMEM. Useful in fused kernels where GEMMs are done at a per-tile scope.
|
||||||
template <typename T,
|
template <typename T,
|
||||||
MemLayout layout_a, // memory layout of `local_a`
|
MemLayout layout_a, // memory layout of `local_a`
|
||||||
MemLayout layout_b, // memory layout of `local_b`
|
MemLayout layout_b, // memory layout of `local_b`
|
||||||
uint32_t tile_dim_m,
|
uint32_t tile_dim_m, uint32_t tile_dim_n, uint32_t tile_dim_k,
|
||||||
uint32_t tile_dim_n,
|
uint32_t leading_dim_a, // if zero, assumes packed layout, i.e. row
|
||||||
uint32_t tile_dim_k,
|
// stride == col.
|
||||||
|
uint32_t leading_dim_b, // if zero, assumes packed layout, i.e. row
|
||||||
|
// stride == col.
|
||||||
bool load_accum = false, // if true, load the accumulation registers
|
bool load_accum = false, // if true, load the accumulation registers
|
||||||
// with `local_c`. used for the (C + A*B)
|
// with `local_c`. used for the (C + A*B)
|
||||||
// operation
|
// operation
|
||||||
@@ -640,7 +675,6 @@ __attribute__((always_inline)) inline void thread_block_gemm_single_tile(
|
|||||||
for (int wm_iter = 0; wm_iter < WMITER; wm_iter++) {
|
for (int wm_iter = 0; wm_iter < WMITER; wm_iter++) {
|
||||||
#pragma GCC unroll
|
#pragma GCC unroll
|
||||||
for (int wn_iter = 0; wn_iter < WNITER; wn_iter++) {
|
for (int wn_iter = 0; wn_iter < WNITER; wn_iter++) {
|
||||||
// FIXME: template parameter-ize BM
|
|
||||||
wmma_load_accum(tid_in_warp, warp_col, warp_row, wn_iter, wm_iter,
|
wmma_load_accum(tid_in_warp, warp_col, warp_row, wn_iter, wm_iter,
|
||||||
tile_dim_n, local_c);
|
tile_dim_n, local_c);
|
||||||
}
|
}
|
||||||
@@ -654,13 +688,21 @@ __attribute__((always_inline)) inline void thread_block_gemm_single_tile(
|
|||||||
#pragma GCC unroll 2
|
#pragma GCC unroll 2
|
||||||
for (int wn_iter = 0; wn_iter < WNITER; wn_iter++) {
|
for (int wn_iter = 0; wn_iter < WNITER; wn_iter++) {
|
||||||
// SMEM -> RF
|
// SMEM -> RF
|
||||||
wmma_load_b<T, layout_b>(local_b, local_k, warp_col, wn_iter,
|
static_assert(leading_dim_b == 0,
|
||||||
tid_in_warp);
|
"leading_dim for wmma_load_b is not implemented yet");
|
||||||
|
wmma_load_b<T, layout_b, tile_dim_m, tile_dim_n,
|
||||||
|
tile_dim_k /*leading_dim_b is TODO */>(
|
||||||
|
local_b, local_k, warp_col, wn_iter, tid_in_warp);
|
||||||
#pragma GCC unroll 2
|
#pragma GCC unroll 2
|
||||||
for (int wm_iter = 0; wm_iter < WMITER; wm_iter++) {
|
for (int wm_iter = 0; wm_iter < WMITER; wm_iter++) {
|
||||||
// SMEM -> RF
|
// SMEM -> RF
|
||||||
wmma_load_a<T, layout_a>(local_a, local_k, warp_row, wm_iter,
|
if constexpr (leading_dim_a == 0) {
|
||||||
tid_in_warp);
|
wmma_load_a<T, layout_a, tile_dim_m, tile_dim_n, tile_dim_k>(
|
||||||
|
local_a, local_k, warp_row, wm_iter, tid_in_warp);
|
||||||
|
} else {
|
||||||
|
wmma_load_a<T, layout_a, leading_dim_a>(local_a, local_k, warp_row,
|
||||||
|
wm_iter, tid_in_warp);
|
||||||
|
}
|
||||||
// perform mma
|
// perform mma
|
||||||
vx_wmma(wm_iter);
|
vx_wmma(wm_iter);
|
||||||
}
|
}
|
||||||
@@ -925,7 +967,7 @@ inline void thread_block_gemm(const T *A, const T *B, float *C,
|
|||||||
constexpr MemLayout layout_a =
|
constexpr MemLayout layout_a =
|
||||||
TRANSPOSE_AT_CONSUME ? MemLayout::K_major : MemLayout::MN_major;
|
TRANSPOSE_AT_CONSUME ? MemLayout::K_major : MemLayout::MN_major;
|
||||||
thread_block_gemm_single_tile<T, layout_a, MemLayout::MN_major,
|
thread_block_gemm_single_tile<T, layout_a, MemLayout::MN_major,
|
||||||
BM, BN, BK,
|
BM, BN, BK, 0, 0,
|
||||||
/*load_accum=*/false,
|
/*load_accum=*/false,
|
||||||
/*write_to_mem=*/false>(
|
/*write_to_mem=*/false>(
|
||||||
local_a_consume, local_b_consume,
|
local_a_consume, local_b_consume,
|
||||||
|
|||||||
Reference in New Issue
Block a user