sgemm_tcore: Fix wrong double-buf addr for wmma_load
This commit is contained in:
@@ -280,11 +280,11 @@ inline void thread_block_gemm(kernel_arg_t *__UNIFORM__ arg,
|
||||
|
||||
volatile float *local_a = sharedmem_per_threadblock;
|
||||
constexpr size_t local_a_elems = (BM * BK);
|
||||
volatile float *local_b = sharedmem_per_threadblock + local_a_elems;
|
||||
constexpr size_t local_b_elems = (BK * BN);
|
||||
volatile float *local_a_buf = local_a + local_a_elems;
|
||||
|
||||
volatile float *local_a_buf = local_b + local_b_elems;
|
||||
volatile float *local_b_buf = local_a_buf + local_a_elems;
|
||||
volatile float *local_b = local_a_buf + local_a_elems;
|
||||
constexpr size_t local_b_elems = (BK * BN);
|
||||
volatile float *local_b_buf = local_a_buf + local_b_elems;
|
||||
|
||||
constexpr uint32_t skips =
|
||||
loop_matmul_skips(/*skip_lda=*/0, /*skip_ldb=*/0, /*skip_ldd=*/1,
|
||||
@@ -453,8 +453,8 @@ inline void thread_block_gemm(kernel_arg_t *__UNIFORM__ arg,
|
||||
// local_b_consume = reinterpret_cast<volatile float *>(
|
||||
// (mask_odd & reinterpret_cast<uintmax_t>(local_b_buf)) |
|
||||
// (mask_even & reinterpret_cast<uintmax_t>(local_b)));
|
||||
local_a_consume = local_a + (block_k & 1) * (local_a_elems + local_b_elems);
|
||||
local_b_consume = local_b + (block_k & 1) * (local_a_elems + local_b_elems);
|
||||
local_a_consume = local_a + (block_k & 1) * (local_a_elems);
|
||||
local_b_consume = local_b + (block_k & 1) * (local_b_elems);
|
||||
} else {
|
||||
local_a_consume = local_a;
|
||||
local_b_consume = local_b;
|
||||
|
||||
Reference in New Issue
Block a user