sgemm_tcore: Fix wrong double-buf addr for wmma_load

This commit is contained in:
Hansung Kim
2024-06-15 00:51:35 -07:00
parent 9d6ff196b3
commit cfb6ae4a91

View File

@@ -280,11 +280,11 @@ inline void thread_block_gemm(kernel_arg_t *__UNIFORM__ arg,
volatile float *local_a = sharedmem_per_threadblock;
constexpr size_t local_a_elems = (BM * BK);
volatile float *local_b = sharedmem_per_threadblock + local_a_elems;
constexpr size_t local_b_elems = (BK * BN);
volatile float *local_a_buf = local_a + local_a_elems;
volatile float *local_a_buf = local_b + local_b_elems;
volatile float *local_b_buf = local_a_buf + local_a_elems;
volatile float *local_b = local_a_buf + local_a_elems;
constexpr size_t local_b_elems = (BK * BN);
volatile float *local_b_buf = local_a_buf + local_b_elems;
constexpr uint32_t skips =
loop_matmul_skips(/*skip_lda=*/0, /*skip_ldb=*/0, /*skip_ldd=*/1,
@@ -453,8 +453,8 @@ inline void thread_block_gemm(kernel_arg_t *__UNIFORM__ arg,
// local_b_consume = reinterpret_cast<volatile float *>(
// (mask_odd & reinterpret_cast<uintmax_t>(local_b_buf)) |
// (mask_even & reinterpret_cast<uintmax_t>(local_b)));
local_a_consume = local_a + (block_k & 1) * (local_a_elems + local_b_elems);
local_b_consume = local_b + (block_k & 1) * (local_a_elems + local_b_elems);
local_a_consume = local_a + (block_k & 1) * (local_a_elems);
local_b_consume = local_b + (block_k & 1) * (local_b_elems);
} else {
local_a_consume = local_a;
local_b_consume = local_b;