sgemm_impl: Rename dmem load function
This commit is contained in:
@@ -395,13 +395,12 @@ template <typename T,
|
||||
MemLayout gmem_layout, // memory layout of the GMEM tile
|
||||
MemLayout smem_layout, // memory layout of the SMEM tile
|
||||
uint32_t tile_dim_mn, // row dimension of the SMEM tile
|
||||
uint32_t tile_dim_k // column dimension of the SMEM tile
|
||||
uint32_t tile_dim_k // column dimension of the SMEM tile
|
||||
>
|
||||
__attribute__((always_inline)) inline void
|
||||
global_dmem_load_new(const uint32_t dim_col, const uint32_t mn_index,
|
||||
const uint32_t k, const T *global_addr,
|
||||
volatile T *local_addr,
|
||||
const uint32_t tid_in_threadblock) {
|
||||
load_tile_to_smem(const uint32_t dim_col, const uint32_t mn_index,
|
||||
const uint32_t k, const T *global_addr,
|
||||
volatile T *local_addr, const uint32_t tid_in_threadblock) {
|
||||
asm volatile("global_dmem_load_start_new_%=:" ::);
|
||||
|
||||
// In fp16 mode, bit-pack two fp16 elements into each fp32 element, and do
|
||||
@@ -805,19 +804,17 @@ inline void thread_block_gemm(const T *A, const T *B, float *C,
|
||||
#else
|
||||
// move A
|
||||
if constexpr (!TRANSPOSE_AT_PRODUCE) {
|
||||
global_dmem_load_new<T, MemLayout::MN_major, MemLayout::MN_major, BM,
|
||||
BK>(dim_m, block_m, block_k * BK, A, local_a,
|
||||
tid_in_threadblock);
|
||||
load_tile_to_smem<T, MemLayout::MN_major, MemLayout::MN_major, BM,
|
||||
BK>(dim_m, block_m, block_k * BK, A, local_a,
|
||||
tid_in_threadblock);
|
||||
} else {
|
||||
global_dmem_load_new<T, MemLayout::K_major, MemLayout::MN_major, BM,
|
||||
BK>(dim_k, block_m, block_k * BK, A, local_a,
|
||||
tid_in_threadblock);
|
||||
load_tile_to_smem<T, MemLayout::K_major, MemLayout::MN_major, BM, BK>(
|
||||
dim_k, block_m, block_k * BK, A, local_a, tid_in_threadblock);
|
||||
}
|
||||
|
||||
// move B
|
||||
global_dmem_load_new<T, MemLayout::MN_major, MemLayout::MN_major, BN,
|
||||
BK>(dim_n, block_n, block_k * BK, B, local_b,
|
||||
tid_in_threadblock);
|
||||
load_tile_to_smem<T, MemLayout::MN_major, MemLayout::MN_major, BN, BK>(
|
||||
dim_n, block_n, block_k * BK, B, local_b, tid_in_threadblock);
|
||||
|
||||
threadblock_barrier(threadblock_id_in_cluster,
|
||||
warps_per_threadblock_per_core);
|
||||
|
||||
Reference in New Issue
Block a user