sgemm_impl: Rename dmem load function

This commit is contained in:
Hansung Kim
2024-08-18 22:25:01 -07:00
parent 46b5047775
commit 1b133e7b5c

View File

@@ -395,13 +395,12 @@ template <typename T,
MemLayout gmem_layout, // memory layout of the GMEM tile
MemLayout smem_layout, // memory layout of the SMEM tile
uint32_t tile_dim_mn, // row dimension of the SMEM tile
uint32_t tile_dim_k // column dimension of the SMEM tile
uint32_t tile_dim_k // column dimension of the SMEM tile
>
__attribute__((always_inline)) inline void
global_dmem_load_new(const uint32_t dim_col, const uint32_t mn_index,
const uint32_t k, const T *global_addr,
volatile T *local_addr,
const uint32_t tid_in_threadblock) {
load_tile_to_smem(const uint32_t dim_col, const uint32_t mn_index,
const uint32_t k, const T *global_addr,
volatile T *local_addr, const uint32_t tid_in_threadblock) {
asm volatile("global_dmem_load_start_new_%=:" ::);
// In fp16 mode, bit-pack two fp16 elements into each fp32 element, and do
@@ -805,19 +804,17 @@ inline void thread_block_gemm(const T *A, const T *B, float *C,
#else
// move A
if constexpr (!TRANSPOSE_AT_PRODUCE) {
global_dmem_load_new<T, MemLayout::MN_major, MemLayout::MN_major, BM,
BK>(dim_m, block_m, block_k * BK, A, local_a,
tid_in_threadblock);
load_tile_to_smem<T, MemLayout::MN_major, MemLayout::MN_major, BM,
BK>(dim_m, block_m, block_k * BK, A, local_a,
tid_in_threadblock);
} else {
global_dmem_load_new<T, MemLayout::K_major, MemLayout::MN_major, BM,
BK>(dim_k, block_m, block_k * BK, A, local_a,
tid_in_threadblock);
load_tile_to_smem<T, MemLayout::K_major, MemLayout::MN_major, BM, BK>(
dim_k, block_m, block_k * BK, A, local_a, tid_in_threadblock);
}
// move B
global_dmem_load_new<T, MemLayout::MN_major, MemLayout::MN_major, BN,
BK>(dim_n, block_n, block_k * BK, B, local_b,
tid_in_threadblock);
load_tile_to_smem<T, MemLayout::MN_major, MemLayout::MN_major, BN, BK>(
dim_n, block_n, block_k * BK, B, local_b, tid_in_threadblock);
threadblock_barrier(threadblock_id_in_cluster,
warps_per_threadblock_per_core);