sgemm_impl: Accept layout template param at gemm_single_tile and wmma_load
This commit is contained in:
@@ -69,7 +69,7 @@ using float_type = float16_t;
|
||||
// generates the NN kernel where both A and B are stored row-major in GMEM.
|
||||
// To model the case where the A matrix is already stored column-major in GMEM,
|
||||
// set both to 0.
|
||||
#define TRANSPOSE_AT_PRODUCE 1
|
||||
#define TRANSPOSE_AT_PRODUCE 0
|
||||
#define TRANSPOSE_AT_CONSUME 0
|
||||
|
||||
#define GEMMINI_DMA 0
|
||||
@@ -97,6 +97,11 @@ using float_type = float16_t;
|
||||
#error Unsupported smem size
|
||||
#endif
|
||||
|
||||
enum class MemLayout {
|
||||
MN_major,
|
||||
K_major,
|
||||
};
|
||||
|
||||
inline constexpr void map_operand_32lanes(const int tid, int &row, int &col) {
|
||||
const int tg = tid / 4;
|
||||
|
||||
@@ -195,10 +200,10 @@ inline void vx_wmma(const int dest_reg) {
|
||||
}
|
||||
|
||||
// `local_k` is assumed to be multiple of TCK
|
||||
template <typename T>
|
||||
template <typename T, MemLayout layout>
|
||||
inline void wmma_load_a(volatile const T *smem_A, const int local_k,
|
||||
const int warp_row, const int wm_iter,
|
||||
const int thread_in_warp) {
|
||||
const int warp_row, const int wm_iter,
|
||||
const int thread_in_warp) {
|
||||
asm volatile ("wmma_load_a_start_%=:" :: );
|
||||
|
||||
const int tid = thread_in_warp;
|
||||
@@ -219,8 +224,7 @@ inline void wmma_load_a(volatile const T *smem_A, const int local_k,
|
||||
constexpr int BK_adjusted = BK / packed_factor;
|
||||
const int local_k_adjusted = local_k / packed_factor;
|
||||
|
||||
if constexpr (TRANSPOSE_AT_CONSUME) {
|
||||
// A is stored K-major in smem
|
||||
if constexpr (layout == MemLayout::K_major) {
|
||||
constexpr int smem_A_rows = BM;
|
||||
constexpr int smem_A_cols = BK_adjusted;
|
||||
|
||||
@@ -242,8 +246,7 @@ inline void wmma_load_a(volatile const T *smem_A, const int local_k,
|
||||
asm volatile("flw f5, %0(%1)" ::"i"(5 * sizeof(float)), "r"(smem_addr));
|
||||
asm volatile("flw f6, %0(%1)" ::"i"(6 * sizeof(float)), "r"(smem_addr));
|
||||
asm volatile("flw f7, %0(%1)" ::"i"(7 * sizeof(float)), "r"(smem_addr));
|
||||
} else {
|
||||
// A is stored M-major in smem
|
||||
} else if (layout == MemLayout::MN_major) {
|
||||
constexpr int smem_AS_rows = BK_adjusted;
|
||||
constexpr int smem_AS_cols = BM;
|
||||
|
||||
@@ -262,18 +265,25 @@ inline void wmma_load_a(volatile const T *smem_A, const int local_k,
|
||||
asm volatile("flw f5, %0(%1)" :: "i"(smem_AS_cols * 5 * sizeof(float)), "r"(smem_addr));
|
||||
asm volatile("flw f6, %0(%1)" :: "i"(smem_AS_cols * 6 * sizeof(float)), "r"(smem_addr));
|
||||
asm volatile("flw f7, %0(%1)" :: "i"(smem_AS_cols * 7 * sizeof(float)), "r"(smem_addr));
|
||||
} else {
|
||||
static_assert(layout ==
|
||||
MemLayout::K_major /* fake cond that is always false */,
|
||||
"unsupported memory layout");
|
||||
}
|
||||
|
||||
asm volatile ("wmma_load_a_finish_%=:" :: );
|
||||
}
|
||||
|
||||
// `local_k` is assumed to be multiple of TCK
|
||||
template <typename T>
|
||||
template <typename T, MemLayout layout>
|
||||
inline void wmma_load_b(const volatile T *smem_B, const int local_k,
|
||||
const int warp_col, const int wn_iter,
|
||||
const int thread_in_warp) {
|
||||
asm volatile ("wmma_load_b_start_%=:" :: );
|
||||
|
||||
static_assert(layout == MemLayout::MN_major,
|
||||
"only N-major layout for the B tile is supported");
|
||||
|
||||
const int tid = thread_in_warp;
|
||||
const int tg = tid / 4;
|
||||
|
||||
@@ -384,11 +394,6 @@ inline void threadblock_barrier(const uint32_t barrier_id, const uint32_t count)
|
||||
vx_barrier(barrier_id, count);
|
||||
}
|
||||
|
||||
enum class MemLayout {
|
||||
MN_major,
|
||||
K_major,
|
||||
};
|
||||
|
||||
// Move a single matrix tile from global memory (GMEM) to shared memory (SMEM).
|
||||
// `dim_col`: column dimension of the global matrix.
|
||||
template <typename T,
|
||||
@@ -404,11 +409,10 @@ load_tile_to_smem(const uint32_t dim_col, const uint32_t mn_index,
|
||||
asm volatile("global_dmem_load_start_new_%=:" ::);
|
||||
|
||||
// In fp16 mode, bit-pack two fp16 elements into each fp32 element, and do
|
||||
// data movement at the fp32 granularity. Assuming that the matrix is stored
|
||||
// row-major in GMEM, the packed fp16 pairs belong to the same row,
|
||||
// neighboring columns; therefore, it essentially becomes equivalent to
|
||||
// moving a fp32 matrix whose column dimensions are compressed by a factor of
|
||||
// two.
|
||||
// data movement at the fp32 granularity. The tensor core hardware assumes
|
||||
// the fp16 elements are contiguously stored along the K-dimension;
|
||||
// therefore, this essentially becomes equivalent to a fp32 GEMM where the
|
||||
// K-dimension is shrinked by the factor of two.
|
||||
constexpr uint32_t packed_factor = (std::is_same_v<T, float16_t> ? 2 : 1);
|
||||
|
||||
constexpr uint32_t tile_dim_k_packed = tile_dim_k / packed_factor;
|
||||
@@ -555,6 +559,8 @@ load_tile_to_smem(const uint32_t dim_col, const uint32_t mn_index,
|
||||
// Do a single tile*tile matrix multiplication using the matrix data stored in
|
||||
// SMEM. Useful in fused kernels where GEMMs are done at a per-tile scope.
|
||||
template <typename T,
|
||||
MemLayout layout_a, // memory layout of `local_a`
|
||||
MemLayout layout_b, // memory layout of `local_b`
|
||||
bool write_to_smem = false // if true, write result tile to SMEM at a
|
||||
// given address
|
||||
>
|
||||
@@ -577,11 +583,13 @@ thread_block_gemm_single_tile(const T *local_a, const T *local_b, T *local_c,
|
||||
#pragma GCC unroll 2
|
||||
for (int wn_iter = 0; wn_iter < WNITER; wn_iter++) {
|
||||
// SMEM -> RF
|
||||
wmma_load_b<T>(local_b, local_k, warp_col, wn_iter, tid_in_warp);
|
||||
wmma_load_b<T, layout_b>(local_b, local_k, warp_col, wn_iter,
|
||||
tid_in_warp);
|
||||
#pragma GCC unroll 2
|
||||
for (int wm_iter = 0; wm_iter < WMITER; wm_iter++) {
|
||||
// SMEM -> RF
|
||||
wmma_load_a<T>(local_a, local_k, warp_row, wm_iter, tid_in_warp);
|
||||
wmma_load_a<T, layout_a>(local_a, local_k, warp_row, wm_iter,
|
||||
tid_in_warp);
|
||||
// perform mma
|
||||
vx_wmma(wm_iter);
|
||||
}
|
||||
@@ -845,7 +853,10 @@ inline void thread_block_gemm(const T *A, const T *B, float *C,
|
||||
local_b_consume = local_b;
|
||||
}
|
||||
|
||||
thread_block_gemm_single_tile(
|
||||
constexpr MemLayout layout_a =
|
||||
TRANSPOSE_AT_CONSUME ? MemLayout::K_major : MemLayout::MN_major;
|
||||
thread_block_gemm_single_tile<T, layout_a, MemLayout::MN_major,
|
||||
/*write_to_smem=*/false>(
|
||||
local_a_consume, local_b_consume,
|
||||
static_cast<volatile T *>(nullptr) /*ignore*/, tid_in_threadblock,
|
||||
threads_per_threadblock);
|
||||
|
||||
Reference in New Issue
Block a user