sgemm_impl: load_tile: accept k_index for consistency and fix gmem address generation
This commit is contained in:
@@ -453,7 +453,7 @@ template <typename T,
|
||||
>
|
||||
__attribute__((always_inline)) inline void
|
||||
load_tile_to_smem(const uint32_t dim_major, const uint32_t mn_index,
|
||||
const uint32_t k, const T *global_addr,
|
||||
const uint32_t k_index, const T *global_addr,
|
||||
volatile T *local_addr, const uint32_t tid_in_threadblock) {
|
||||
asm volatile("global_dmem_load_start_new_%=:" ::);
|
||||
|
||||
@@ -469,15 +469,11 @@ load_tile_to_smem(const uint32_t dim_major, const uint32_t mn_index,
|
||||
(gmem_layout == MemLayout::K_major) ? tile_dim_mn : tile_dim_k_packed;
|
||||
constexpr uint32_t gmem_dim_col =
|
||||
(gmem_layout == MemLayout::K_major) ? tile_dim_k_packed : tile_dim_mn;
|
||||
constexpr uint32_t smem_dim_row =
|
||||
(smem_layout == MemLayout::K_major) ? tile_dim_mn : tile_dim_k_packed;
|
||||
constexpr uint32_t smem_dim_col =
|
||||
(smem_layout == MemLayout::K_major) ? tile_dim_k_packed : tile_dim_mn;
|
||||
|
||||
const uint32_t dim_major_ =
|
||||
(gmem_layout == MemLayout::K_major) ? dim_major / packed_factor : dim_major;
|
||||
// FIXME: unsure about this
|
||||
const uint32_t k_ = k / packed_factor;
|
||||
|
||||
// threads in the threadblock always do contiguous accesses in the gmem
|
||||
const uint32_t local_row_gmem = tid_in_threadblock / gmem_dim_col;
|
||||
@@ -493,10 +489,10 @@ load_tile_to_smem(const uint32_t dim_major, const uint32_t mn_index,
|
||||
// FIXME: don't hardcode this here
|
||||
constexpr uint32_t threads_per_threadblock = (BM * BN) / ELEM_PER_THREAD;
|
||||
|
||||
const uint32_t global_row_mn_major = k_ + local_row_gmem;
|
||||
const uint32_t global_col_mn_major = smem_dim_col * mn_index + local_col_gmem;
|
||||
const uint32_t global_row_mn_major = tile_dim_k_packed * k_index + local_row_gmem;
|
||||
const uint32_t global_col_mn_major = gmem_dim_col * mn_index + local_col_gmem;
|
||||
const uint32_t global_row_k_major = gmem_dim_row * mn_index + local_row_gmem;
|
||||
const uint32_t global_col_k_major = k_ + local_col_gmem;
|
||||
const uint32_t global_col_k_major = tile_dim_k_packed * k_index + local_col_gmem;
|
||||
const uint32_t global_row = (gmem_layout == MemLayout::K_major)
|
||||
? global_row_k_major
|
||||
: global_row_mn_major;
|
||||
@@ -879,16 +875,16 @@ inline void thread_block_gemm(const T *A, const T *B, float *C,
|
||||
// move A
|
||||
if constexpr (!TRANSPOSE_AT_PRODUCE) {
|
||||
load_tile_to_smem<T, MemLayout::MN_major, MemLayout::MN_major, BM,
|
||||
BK>(dim_m, block_m, block_k * BK, A, local_a,
|
||||
BK>(dim_m, block_m, block_k, A, local_a,
|
||||
tid_in_threadblock);
|
||||
} else {
|
||||
load_tile_to_smem<T, MemLayout::K_major, MemLayout::MN_major, BM, BK>(
|
||||
dim_k, block_m, block_k * BK, A, local_a, tid_in_threadblock);
|
||||
dim_k, block_m, block_k, A, local_a, tid_in_threadblock);
|
||||
}
|
||||
|
||||
// move B
|
||||
load_tile_to_smem<T, MemLayout::MN_major, MemLayout::MN_major, BN, BK>(
|
||||
dim_n, block_n, block_k * BK, B, local_b, tid_in_threadblock);
|
||||
dim_n, block_n, block_k, B, local_b, tid_in_threadblock);
|
||||
|
||||
threadblock_barrier(threadblock_id_in_cluster,
|
||||
warps_per_threadblock_per_core);
|
||||
|
||||
Reference in New Issue
Block a user