sgemm_impl: Accept layout template param at gemm_single_tile and wmma_load

This commit is contained in:
Hansung Kim
2024-08-19 13:16:22 -07:00
parent 1b133e7b5c
commit 42ddb9a48e

View File

@@ -69,7 +69,7 @@ using float_type = float16_t;
// generates the NN kernel where both A and B are stored row-major in GMEM.
// To model the case where the A matrix is already stored column-major in GMEM,
// set both to 0.
#define TRANSPOSE_AT_PRODUCE 1
#define TRANSPOSE_AT_PRODUCE 0
#define TRANSPOSE_AT_CONSUME 0
#define GEMMINI_DMA 0
@@ -97,6 +97,11 @@ using float_type = float16_t;
#error Unsupported smem size
#endif
// Memory layout of a matrix operand tile. Passed as a template parameter to
// the tile load/GEMM routines so the correct addressing is chosen at compile
// time.
enum class MemLayout {
// M-major for the A tile, N-major for the B tile.
MN_major,
// K-major: elements along the K dimension are adjacent in memory
// (NOTE(review): inferred from the K-major load path of the A tile — confirm).
K_major,
};
inline constexpr void map_operand_32lanes(const int tid, int &row, int &col) {
const int tg = tid / 4;
@@ -195,10 +200,10 @@ inline void vx_wmma(const int dest_reg) {
}
// `local_k` is assumed to be a multiple of TCK
template <typename T>
template <typename T, MemLayout layout>
inline void wmma_load_a(volatile const T *smem_A, const int local_k,
const int warp_row, const int wm_iter,
const int thread_in_warp) {
const int warp_row, const int wm_iter,
const int thread_in_warp) {
asm volatile ("wmma_load_a_start_%=:" :: );
const int tid = thread_in_warp;
@@ -219,8 +224,7 @@ inline void wmma_load_a(volatile const T *smem_A, const int local_k,
constexpr int BK_adjusted = BK / packed_factor;
const int local_k_adjusted = local_k / packed_factor;
if constexpr (TRANSPOSE_AT_CONSUME) {
// A is stored K-major in smem
if constexpr (layout == MemLayout::K_major) {
constexpr int smem_A_rows = BM;
constexpr int smem_A_cols = BK_adjusted;
@@ -242,8 +246,7 @@ inline void wmma_load_a(volatile const T *smem_A, const int local_k,
asm volatile("flw f5, %0(%1)" ::"i"(5 * sizeof(float)), "r"(smem_addr));
asm volatile("flw f6, %0(%1)" ::"i"(6 * sizeof(float)), "r"(smem_addr));
asm volatile("flw f7, %0(%1)" ::"i"(7 * sizeof(float)), "r"(smem_addr));
} else {
// A is stored M-major in smem
} else if (layout == MemLayout::MN_major) {
constexpr int smem_AS_rows = BK_adjusted;
constexpr int smem_AS_cols = BM;
@@ -262,18 +265,25 @@ inline void wmma_load_a(volatile const T *smem_A, const int local_k,
asm volatile("flw f5, %0(%1)" :: "i"(smem_AS_cols * 5 * sizeof(float)), "r"(smem_addr));
asm volatile("flw f6, %0(%1)" :: "i"(smem_AS_cols * 6 * sizeof(float)), "r"(smem_addr));
asm volatile("flw f7, %0(%1)" :: "i"(smem_AS_cols * 7 * sizeof(float)), "r"(smem_addr));
} else {
static_assert(layout ==
MemLayout::K_major /* fake cond that is always false */,
"unsupported memory layout");
}
asm volatile ("wmma_load_a_finish_%=:" :: );
}
// `local_k` is assumed to be a multiple of TCK
template <typename T>
template <typename T, MemLayout layout>
inline void wmma_load_b(const volatile T *smem_B, const int local_k,
const int warp_col, const int wn_iter,
const int thread_in_warp) {
asm volatile ("wmma_load_b_start_%=:" :: );
static_assert(layout == MemLayout::MN_major,
"only N-major layout for the B tile is supported");
const int tid = thread_in_warp;
const int tg = tid / 4;
@@ -384,11 +394,6 @@ inline void threadblock_barrier(const uint32_t barrier_id, const uint32_t count)
vx_barrier(barrier_id, count);
}
// Memory layout of a matrix operand tile. Passed as a template parameter to
// the tile load/GEMM routines so the correct addressing is chosen at compile
// time.
enum class MemLayout {
// M-major for the A tile, N-major for the B tile.
MN_major,
// K-major: elements along the K dimension are adjacent in memory
// (NOTE(review): inferred from the K-major load path of the A tile — confirm).
K_major,
};
// Move a single matrix tile from global memory (GMEM) to shared memory (SMEM).
// `dim_col`: column dimension of the global matrix.
template <typename T,
@@ -404,11 +409,10 @@ load_tile_to_smem(const uint32_t dim_col, const uint32_t mn_index,
asm volatile("global_dmem_load_start_new_%=:" ::);
// In fp16 mode, bit-pack two fp16 elements into each fp32 element, and do
// data movement at the fp32 granularity. Assuming that the matrix is stored
// row-major in GMEM, the packed fp16 pairs belong to the same row,
// neighboring columns; therefore, it essentially becomes equivalent to
// moving a fp32 matrix whose column dimensions are compressed by a factor of
// two.
// data movement at the fp32 granularity. The tensor core hardware assumes
// the fp16 elements are contiguously stored along the K-dimension;
// therefore, this essentially becomes equivalent to a fp32 GEMM where the
// K-dimension is shrunk by a factor of two.
constexpr uint32_t packed_factor = (std::is_same_v<T, float16_t> ? 2 : 1);
constexpr uint32_t tile_dim_k_packed = tile_dim_k / packed_factor;
@@ -555,6 +559,8 @@ load_tile_to_smem(const uint32_t dim_col, const uint32_t mn_index,
// Do a single tile*tile matrix multiplication using the matrix data stored in
// SMEM. Useful in fused kernels where GEMMs are done at a per-tile scope.
template <typename T,
MemLayout layout_a, // memory layout of `local_a`
MemLayout layout_b, // memory layout of `local_b`
bool write_to_smem = false // if true, write result tile to SMEM at a
// given address
>
@@ -577,11 +583,13 @@ thread_block_gemm_single_tile(const T *local_a, const T *local_b, T *local_c,
#pragma GCC unroll 2
for (int wn_iter = 0; wn_iter < WNITER; wn_iter++) {
// SMEM -> RF
wmma_load_b<T>(local_b, local_k, warp_col, wn_iter, tid_in_warp);
wmma_load_b<T, layout_b>(local_b, local_k, warp_col, wn_iter,
tid_in_warp);
#pragma GCC unroll 2
for (int wm_iter = 0; wm_iter < WMITER; wm_iter++) {
// SMEM -> RF
wmma_load_a<T>(local_a, local_k, warp_row, wm_iter, tid_in_warp);
wmma_load_a<T, layout_a>(local_a, local_k, warp_row, wm_iter,
tid_in_warp);
// perform mma
vx_wmma(wm_iter);
}
@@ -845,7 +853,10 @@ inline void thread_block_gemm(const T *A, const T *B, float *C,
local_b_consume = local_b;
}
thread_block_gemm_single_tile(
constexpr MemLayout layout_a =
TRANSPOSE_AT_CONSUME ? MemLayout::K_major : MemLayout::MN_major;
thread_block_gemm_single_tile<T, layout_a, MemLayout::MN_major,
/*write_to_smem=*/false>(
local_a_consume, local_b_consume,
static_cast<volatile T *>(nullptr) /*ignore*/, tid_in_threadblock,
threads_per_threadblock);