sgemm_impl: Refactor dmem_load into one unified logic
Replace the confusing logic that had slightly different use of BM/BN/BK for A and B, into one logic that accepts matrix memory layout as a proper argument & does compile-time logic to determine the right dimensions. TODO: !GMEM_COALESCED_A is not updated yet
This commit is contained in:
@@ -14,7 +14,6 @@ using float16_t = uint16_t;
|
||||
#if (FP_SIZE == 32)
|
||||
using float_type = float;
|
||||
#elif (FP_SIZE == 16)
|
||||
|
||||
using float_type = float16_t;
|
||||
#endif
|
||||
|
||||
@@ -70,7 +69,7 @@ using float_type = float16_t;
|
||||
// generates the NN kernel where both A and B are stored row-major in GMEM.
|
||||
// To model the case where the A matrix is already stored column-major in GMEM,
|
||||
// set both to 0.
|
||||
#define TRANSPOSE_AT_PRODUCE 0
|
||||
#define TRANSPOSE_AT_PRODUCE 1
|
||||
#define TRANSPOSE_AT_CONSUME 0
|
||||
// GMEM_COALESCED: When TRANSPOSE_AT_PRODUCE == 1 (i.e. transpose at
|
||||
// GMEM->SMEM), determines whether we do bank-conflict-free accesses for
|
||||
@@ -393,6 +392,182 @@ inline void threadblock_barrier(const uint32_t barrier_id, const uint32_t count)
|
||||
vx_barrier(barrier_id, count);
|
||||
}
|
||||
|
||||
enum class MemLayout {
|
||||
MN_major,
|
||||
K_major,
|
||||
};
|
||||
|
||||
// Move a single matrix tile from global memory (GMEM) to shared memory (SMEM).
|
||||
// `dim_col`: column dimension of the global matrix.
|
||||
template <typename T,
|
||||
MemLayout gmem_layout, // memory layout of the GMEM tile
|
||||
MemLayout smem_layout, // memory layout of the GMEM tile
|
||||
uint32_t tile_dim_mn, // row dimension of the SMEM tile
|
||||
uint32_t tile_dim_k, // column dimension of the SMEM tile
|
||||
bool gmem_contiguous = true
|
||||
>
|
||||
__attribute__((always_inline)) inline void
|
||||
global_dmem_load_new(const uint32_t dim_col, const uint32_t mn_index,
|
||||
const uint32_t k, const T *global_addr,
|
||||
volatile T *local_addr,
|
||||
const uint32_t tid_in_threadblock) {
|
||||
asm volatile("global_dmem_load_start_new_%=:" ::);
|
||||
|
||||
// In fp16 mode, bit-pack two fp16 elements into each fp32 element, and do
|
||||
// data movement at the fp32 granularity. Assuming that the matrix is stored
|
||||
// row-major in GMEM, the packed fp16 pairs belong to the same row,
|
||||
// neighboring columns; therefore, it essentially becomes equivalent to
|
||||
// moving a fp32 matrix whose column dimensions are compressed by a factor of
|
||||
// two.
|
||||
constexpr uint32_t packed_factor = (std::is_same_v<T, float16_t> ? 2 : 1);
|
||||
|
||||
constexpr uint32_t tile_dim_k_packed = tile_dim_k / packed_factor;
|
||||
constexpr uint32_t gmem_dim_row =
|
||||
(gmem_layout == MemLayout::K_major) ? tile_dim_mn : tile_dim_k_packed;
|
||||
constexpr uint32_t gmem_dim_col =
|
||||
(gmem_layout == MemLayout::K_major) ? tile_dim_k_packed : tile_dim_mn;
|
||||
constexpr uint32_t smem_dim_row =
|
||||
(smem_layout == MemLayout::K_major) ? tile_dim_mn : tile_dim_k_packed;
|
||||
constexpr uint32_t smem_dim_col =
|
||||
(smem_layout == MemLayout::K_major) ? tile_dim_k_packed : tile_dim_mn;
|
||||
|
||||
const uint32_t dim_col_ =
|
||||
(gmem_layout == MemLayout::K_major) ? dim_col / packed_factor : dim_col;
|
||||
// FIXME: unsure about this
|
||||
const uint32_t k_ = k / packed_factor;
|
||||
|
||||
// threads in the threadblock always do contiguous accesses in the gmem
|
||||
const uint32_t local_row_gmem = tid_in_threadblock / gmem_dim_col;
|
||||
const uint32_t local_col_gmem = tid_in_threadblock % gmem_dim_col;
|
||||
|
||||
constexpr bool transposed_write = (gmem_layout != smem_layout);
|
||||
// if transposed, threads write to smem in reversed col/row
|
||||
const uint32_t local_row_smem =
|
||||
transposed_write ? local_col_gmem : local_row_gmem;
|
||||
const uint32_t local_col_smem =
|
||||
transposed_write ? local_row_gmem : local_col_gmem;
|
||||
|
||||
// FIXME: don't hardcode this here
|
||||
constexpr uint32_t threads_per_threadblock = (BM * BN) / ELEM_PER_THREAD;
|
||||
|
||||
static_assert(gmem_contiguous == true,
|
||||
"currently only supports contiguous accesses in GMEM");
|
||||
|
||||
const uint32_t global_row_mn_major = k_ + local_row_gmem;
|
||||
const uint32_t global_col_mn_major = smem_dim_col * mn_index + local_col_gmem;
|
||||
const uint32_t global_row_k_major = gmem_dim_row * mn_index + local_row_gmem;
|
||||
const uint32_t global_col_k_major = k_ + local_col_gmem;
|
||||
const uint32_t global_row = (gmem_layout == MemLayout::K_major)
|
||||
? global_row_k_major
|
||||
: global_row_mn_major;
|
||||
const uint32_t global_col = (gmem_layout == MemLayout::K_major)
|
||||
? global_col_k_major
|
||||
: global_col_mn_major;
|
||||
|
||||
const float *global = reinterpret_cast<const float *>(global_addr) +
|
||||
dim_col_ * global_row + global_col;
|
||||
volatile float *local = reinterpret_cast<volatile float *>(local_addr) +
|
||||
smem_dim_col * local_row_smem + local_col_smem;
|
||||
|
||||
constexpr uint32_t row_stride = threads_per_threadblock / gmem_dim_col;
|
||||
static_assert(row_stride * 8 <= gmem_dim_row,
|
||||
"manual loop unrolling condition not met; tile row dimension "
|
||||
"is too shallow");
|
||||
static_assert((gmem_dim_row % (row_stride * 8)) == 0,
|
||||
"manual loop unrolling condition not met; tile row dimension "
|
||||
"should be power-of-two");
|
||||
|
||||
#pragma GCC unroll 1
|
||||
// loop-unrolled flw/fsw to increase reuse distance and IPC
|
||||
for (uint32_t load_offset = 0; load_offset < gmem_dim_row;
|
||||
load_offset += row_stride * 8) {
|
||||
// equivalent code:
|
||||
//
|
||||
// *local = *global;
|
||||
// global += dim_col * row_stride;
|
||||
// local += BN * row_stride;
|
||||
|
||||
// read same-column elements into fp registers
|
||||
asm volatile("flw ft0, (%0)" ::"r"(global));
|
||||
global += dim_col_ * row_stride;
|
||||
asm volatile("flw ft1, (%0)" ::"r"(global));
|
||||
global += dim_col_ * row_stride;
|
||||
asm volatile("flw ft2, (%0)" ::"r"(global));
|
||||
global += dim_col_ * row_stride;
|
||||
asm volatile("flw ft3, (%0)" ::"r"(global));
|
||||
global += dim_col_ * row_stride;
|
||||
asm volatile("flw ft4, (%0)" ::"r"(global));
|
||||
global += dim_col_ * row_stride;
|
||||
asm volatile("flw ft5, (%0)" ::"r"(global));
|
||||
global += dim_col_ * row_stride;
|
||||
asm volatile("flw ft6, (%0)" ::"r"(global));
|
||||
global += dim_col_ * row_stride;
|
||||
asm volatile("flw ft7, (%0)" ::"r"(global));
|
||||
global += dim_col_ * row_stride;
|
||||
|
||||
// do we need to do transposed write?
|
||||
if constexpr (!transposed_write) {
|
||||
static_assert(gmem_layout == MemLayout::MN_major);
|
||||
|
||||
// if not, do the same along-the-column accesses for registers as we did
|
||||
// for gmem
|
||||
asm volatile("fsw ft0, %0(%1)" ::"i"(smem_dim_col * row_stride * 0 *
|
||||
sizeof(float)),
|
||||
"r"(local));
|
||||
asm volatile("fsw ft1, %0(%1)" ::"i"(smem_dim_col * row_stride * 1 *
|
||||
sizeof(float)),
|
||||
"r"(local));
|
||||
local += smem_dim_col * row_stride * 2;
|
||||
asm volatile("fsw ft2, %0(%1)" ::"i"(smem_dim_col * row_stride * 0 *
|
||||
sizeof(float)),
|
||||
"r"(local));
|
||||
asm volatile("fsw ft3, %0(%1)" ::"i"(smem_dim_col * row_stride * 1 *
|
||||
sizeof(float)),
|
||||
"r"(local));
|
||||
local += smem_dim_col * row_stride * 2;
|
||||
asm volatile("fsw ft4, %0(%1)" ::"i"(smem_dim_col * row_stride * 0 *
|
||||
sizeof(float)),
|
||||
"r"(local));
|
||||
asm volatile("fsw ft5, %0(%1)" ::"i"(smem_dim_col * row_stride * 1 *
|
||||
sizeof(float)),
|
||||
"r"(local));
|
||||
local += smem_dim_col * row_stride * 2;
|
||||
asm volatile("fsw ft6, %0(%1)" ::"i"(smem_dim_col * row_stride * 0 *
|
||||
sizeof(float)),
|
||||
"r"(local));
|
||||
asm volatile("fsw ft7, %0(%1)" ::"i"(smem_dim_col * row_stride * 1 *
|
||||
sizeof(float)),
|
||||
"r"(local));
|
||||
local += smem_dim_col * row_stride * 2;
|
||||
} else {
|
||||
static_assert(gmem_layout == MemLayout::K_major);
|
||||
static_assert(smem_layout == MemLayout::MN_major);
|
||||
|
||||
// if yes, write the registers along the row, doing a transpose
|
||||
// @perf: this will incur bank conflicts in smem
|
||||
asm volatile("fsw ft0, %0(%1)" ::"i"(row_stride * 0 * sizeof(float)),
|
||||
"r"(local));
|
||||
asm volatile("fsw ft1, %0(%1)" ::"i"(row_stride * 1 * sizeof(float)),
|
||||
"r"(local));
|
||||
asm volatile("fsw ft2, %0(%1)" ::"i"(row_stride * 2 * sizeof(float)),
|
||||
"r"(local));
|
||||
asm volatile("fsw ft3, %0(%1)" ::"i"(row_stride * 3 * sizeof(float)),
|
||||
"r"(local));
|
||||
asm volatile("fsw ft4, %0(%1)" ::"i"(row_stride * 4 * sizeof(float)),
|
||||
"r"(local));
|
||||
asm volatile("fsw ft5, %0(%1)" ::"i"(row_stride * 5 * sizeof(float)),
|
||||
"r"(local));
|
||||
asm volatile("fsw ft6, %0(%1)" ::"i"(row_stride * 6 * sizeof(float)),
|
||||
"r"(local));
|
||||
asm volatile("fsw ft7, %0(%1)" ::"i"(row_stride * 7 * sizeof(float)),
|
||||
"r"(local));
|
||||
local += row_stride * 8;
|
||||
}
|
||||
}
|
||||
|
||||
asm volatile("global_dmem_load_finish_new_%=:" ::);
|
||||
}
|
||||
|
||||
// TODO: reduce args by passing leading A/B dimensions
|
||||
template <typename T>
|
||||
__attribute__((always_inline))
|
||||
@@ -413,16 +588,14 @@ inline void global_dmem_load(const uint32_t dim_m, const uint32_t dim_n, const u
|
||||
constexpr uint32_t packed_factor = (std::is_same_v<T, float16_t> ? 2 : 1);
|
||||
constexpr uint32_t BK_adjusted = BK / packed_factor;
|
||||
const uint32_t dim_k_adjusted = dim_k / packed_factor;
|
||||
constexpr uint32_t BN_adjusted = BN / packed_factor;
|
||||
const uint32_t dim_n_adjusted = dim_n / packed_factor;
|
||||
const uint32_t k_adjusted = k / packed_factor;
|
||||
|
||||
const uint32_t local_a_row = tid_in_threadblock / BK_adjusted;
|
||||
const uint32_t local_a_col = tid_in_threadblock % BK_adjusted;
|
||||
const uint32_t local_as_row = tid_in_threadblock / BM;
|
||||
const uint32_t local_as_col = tid_in_threadblock % BM;
|
||||
const uint32_t local_b_row = tid_in_threadblock / BN_adjusted;
|
||||
const uint32_t local_b_col = tid_in_threadblock % BN_adjusted;
|
||||
const uint32_t local_b_row = tid_in_threadblock / BN;
|
||||
const uint32_t local_b_col = tid_in_threadblock % BN;
|
||||
|
||||
// FIXME: need fix for fp16?
|
||||
constexpr uint32_t threads_per_threadblock = (BM * BN) / ELEM_PER_THREAD;
|
||||
@@ -436,27 +609,8 @@ inline void global_dmem_load(const uint32_t dim_m, const uint32_t dim_n, const u
|
||||
|
||||
// move A
|
||||
if constexpr (!TRANSPOSE_AT_PRODUCE) {
|
||||
// A is stored M-major in GMEM;
|
||||
// no transpose at GMEM->SMEM movement
|
||||
const uint32_t block_m = threadblock_id_y;
|
||||
const uint32_t global_a_row = k_adjusted + local_as_row;
|
||||
const uint32_t global_a_col = BM * block_m + local_as_col;
|
||||
// number of rows a full TB can read at a time
|
||||
constexpr uint32_t row_stride_as = threads_per_threadblock / BM;
|
||||
const float *global_a = reinterpret_cast<const float *>(A) +
|
||||
dim_m * global_a_row + global_a_col;
|
||||
volatile float *local_a_tmp = reinterpret_cast<volatile float *>(local_a) +
|
||||
BM * local_as_row + local_as_col;
|
||||
|
||||
#pragma GCC unroll 1
|
||||
for (uint32_t local_row_offset = 0; local_row_offset < BK_adjusted;
|
||||
local_row_offset += row_stride_as) {
|
||||
// TODO: the code GCC generates for below seems fine atm, but unroll to
|
||||
// assembly to be absolutely sure
|
||||
*local_a_tmp = *global_a;
|
||||
global_a += dim_m * row_stride_as;
|
||||
local_a_tmp += BM * row_stride_as;
|
||||
}
|
||||
global_dmem_load_new<T, MemLayout::MN_major, MemLayout::MN_major, BM, BK>(
|
||||
dim_m, threadblock_id_y, k, A, local_a, tid_in_threadblock);
|
||||
} else {
|
||||
if constexpr (!GMEM_COALESCED_A) {
|
||||
// !GMEM_COALESCED_A: threads do uncoalesced read from neighboring row in
|
||||
@@ -478,13 +632,12 @@ inline void global_dmem_load(const uint32_t dim_m, const uint32_t dim_n, const u
|
||||
#pragma GCC unroll 1
|
||||
for (uint32_t local_row_offset = 0; local_row_offset < BK_adjusted;
|
||||
local_row_offset += row_stride_as * 8) {
|
||||
// @perf: bank conflicts here
|
||||
// const uint32_t global_a_offset =
|
||||
// dim_k_adjusted * (global_a_row) + (k + local_as_row + local_row_offset);
|
||||
// local_a[BM * (local_as_row + local_row_offset) + local_as_col] =
|
||||
// A[global_a_offset];
|
||||
|
||||
// *local_a_tmp = *global_a;
|
||||
// @perf: bank conflicts
|
||||
asm volatile ("flw ft0, (%0)" :: "r"(global_a));
|
||||
global_a += row_stride_as;
|
||||
asm volatile ("flw ft1, (%0)" :: "r"(global_a));
|
||||
@@ -517,119 +670,15 @@ inline void global_dmem_load(const uint32_t dim_m, const uint32_t dim_n, const u
|
||||
local_a_tmp += BM * row_stride_as * 8;
|
||||
}
|
||||
} else {
|
||||
constexpr uint32_t row_stride_a = threads_per_threadblock / BK_adjusted;
|
||||
const uint32_t global_a_row = BM * threadblock_id_y + local_a_row;
|
||||
const float *global_a = reinterpret_cast<const float *>(A) +
|
||||
dim_k_adjusted * global_a_row +
|
||||
(k_adjusted + local_a_col);
|
||||
// NOTE that SMEM writes are transposed
|
||||
volatile float *local_a_tmp =
|
||||
reinterpret_cast<volatile float *>(local_a) + BM * local_a_col +
|
||||
local_a_row;
|
||||
|
||||
static_assert(
|
||||
row_stride_a * 8 <= BM,
|
||||
"manual loop unrolling condition not met; consider increasing BM");
|
||||
static_assert(
|
||||
(BM % (row_stride_a * 8)) == 0,
|
||||
"manual loop unrolling condition not met; BM should be power-of-two");
|
||||
|
||||
#pragma GCC unroll 1
|
||||
for (uint32_t local_row_offset = 0; local_row_offset < BM;
|
||||
local_row_offset += row_stride_a * 8) {
|
||||
// const uint32_t global_a_offset =
|
||||
// dim_k_adjusted * (global_a_row + local_row_offset) + (k + local_a_col);
|
||||
// NOTE that SMEM writes are transposed
|
||||
// local_a[BM * (local_a_col) + local_a_row + local_row_offset] =
|
||||
// A[global_a_offset];
|
||||
|
||||
asm volatile ("flw ft0, (%0)" :: "r"(global_a));
|
||||
global_a += dim_k_adjusted * row_stride_a;
|
||||
asm volatile ("flw ft1, (%0)" :: "r"(global_a));
|
||||
global_a += dim_k_adjusted * row_stride_a;
|
||||
asm volatile ("flw ft2, (%0)" :: "r"(global_a));
|
||||
global_a += dim_k_adjusted * row_stride_a;
|
||||
asm volatile ("flw ft3, (%0)" :: "r"(global_a));
|
||||
global_a += dim_k_adjusted * row_stride_a;
|
||||
asm volatile ("flw ft4, (%0)" :: "r"(global_a));
|
||||
global_a += dim_k_adjusted * row_stride_a;
|
||||
asm volatile ("flw ft5, (%0)" :: "r"(global_a));
|
||||
global_a += dim_k_adjusted * row_stride_a;
|
||||
asm volatile ("flw ft6, (%0)" :: "r"(global_a));
|
||||
global_a += dim_k_adjusted * row_stride_a;
|
||||
asm volatile ("flw ft7, (%0)" :: "r"(global_a));
|
||||
global_a += dim_k_adjusted * row_stride_a;
|
||||
|
||||
// stride along columns
|
||||
asm volatile ("fsw ft0, %0(%1)" :: "i"(row_stride_a * 0 * sizeof(float)), "r"(local_a_tmp));
|
||||
asm volatile ("fsw ft1, %0(%1)" :: "i"(row_stride_a * 1 * sizeof(float)), "r"(local_a_tmp));
|
||||
asm volatile ("fsw ft2, %0(%1)" :: "i"(row_stride_a * 2 * sizeof(float)), "r"(local_a_tmp));
|
||||
asm volatile ("fsw ft3, %0(%1)" :: "i"(row_stride_a * 3 * sizeof(float)), "r"(local_a_tmp));
|
||||
asm volatile ("fsw ft4, %0(%1)" :: "i"(row_stride_a * 4 * sizeof(float)), "r"(local_a_tmp));
|
||||
asm volatile ("fsw ft5, %0(%1)" :: "i"(row_stride_a * 5 * sizeof(float)), "r"(local_a_tmp));
|
||||
asm volatile ("fsw ft6, %0(%1)" :: "i"(row_stride_a * 6 * sizeof(float)), "r"(local_a_tmp));
|
||||
asm volatile ("fsw ft7, %0(%1)" :: "i"(row_stride_a * 7 * sizeof(float)), "r"(local_a_tmp));
|
||||
local_a_tmp += row_stride_a * 8;
|
||||
}
|
||||
global_dmem_load_new<T, MemLayout::K_major, MemLayout::MN_major, BM,
|
||||
BK>(dim_k, threadblock_id_y, k, A, local_a,
|
||||
tid_in_threadblock);
|
||||
}
|
||||
} // end move A
|
||||
|
||||
// move B
|
||||
constexpr uint32_t row_stride_b = threads_per_threadblock / BN_adjusted;
|
||||
const uint32_t global_b_col = BN_adjusted * threadblock_id_x + local_b_col;
|
||||
// NOTE: not k_adjusted here; k is along the row dimension which is not
|
||||
// compressed for fp16
|
||||
const float *global_b = reinterpret_cast<const float *>(B) +
|
||||
dim_n_adjusted * (k + local_b_row) + global_b_col;
|
||||
volatile float *local_b_tmp = reinterpret_cast<volatile float *>(local_b) +
|
||||
BN_adjusted * local_b_row + local_b_col;
|
||||
|
||||
static_assert(
|
||||
row_stride_b * 8 <= BK_adjusted,
|
||||
"manual loop unrolling condition not met; consider increasing BK");
|
||||
static_assert(
|
||||
(BK_adjusted % (row_stride_b * 8)) == 0,
|
||||
"manual loop unrolling condition not met; BK should be power-of-two");
|
||||
|
||||
#pragma GCC unroll 1
|
||||
for (uint32_t load_offset = 0; load_offset < BK;
|
||||
load_offset += row_stride_b * 8) {
|
||||
// equivalent code:
|
||||
//
|
||||
// *local_b_tmp = *global_b;
|
||||
// global_b += dim_n * row_stride_b;
|
||||
// local_b_tmp += BN * row_stride_b;
|
||||
|
||||
asm volatile ("flw ft0, (%0)" :: "r"(global_b));
|
||||
global_b += dim_n_adjusted * row_stride_b;
|
||||
asm volatile ("flw ft1, (%0)" :: "r"(global_b));
|
||||
global_b += dim_n_adjusted * row_stride_b;
|
||||
asm volatile ("flw ft2, (%0)" :: "r"(global_b));
|
||||
global_b += dim_n_adjusted * row_stride_b;
|
||||
asm volatile ("flw ft3, (%0)" :: "r"(global_b));
|
||||
global_b += dim_n_adjusted * row_stride_b;
|
||||
asm volatile ("flw ft4, (%0)" :: "r"(global_b));
|
||||
global_b += dim_n_adjusted * row_stride_b;
|
||||
asm volatile ("flw ft5, (%0)" :: "r"(global_b));
|
||||
global_b += dim_n_adjusted * row_stride_b;
|
||||
asm volatile ("flw ft6, (%0)" :: "r"(global_b));
|
||||
global_b += dim_n_adjusted * row_stride_b;
|
||||
asm volatile ("flw ft7, (%0)" :: "r"(global_b));
|
||||
global_b += dim_n_adjusted * row_stride_b;
|
||||
|
||||
asm volatile ("fsw ft0, %0(%1)" :: "i"(BN_adjusted * row_stride_b * 0 * sizeof(float)), "r"(local_b_tmp));
|
||||
asm volatile ("fsw ft1, %0(%1)" :: "i"(BN_adjusted * row_stride_b * 1 * sizeof(float)), "r"(local_b_tmp));
|
||||
local_b_tmp += BN_adjusted * row_stride_b * 2;
|
||||
asm volatile ("fsw ft2, %0(%1)" :: "i"(BN_adjusted * row_stride_b * 0 * sizeof(float)), "r"(local_b_tmp));
|
||||
asm volatile ("fsw ft3, %0(%1)" :: "i"(BN_adjusted * row_stride_b * 1 * sizeof(float)), "r"(local_b_tmp));
|
||||
local_b_tmp += BN_adjusted * row_stride_b * 2;
|
||||
asm volatile ("fsw ft4, %0(%1)" :: "i"(BN_adjusted * row_stride_b * 0 * sizeof(float)), "r"(local_b_tmp));
|
||||
asm volatile ("fsw ft5, %0(%1)" :: "i"(BN_adjusted * row_stride_b * 1 * sizeof(float)), "r"(local_b_tmp));
|
||||
local_b_tmp += BN_adjusted * row_stride_b * 2;
|
||||
asm volatile ("fsw ft6, %0(%1)" :: "i"(BN_adjusted * row_stride_b * 0 * sizeof(float)), "r"(local_b_tmp));
|
||||
asm volatile ("fsw ft7, %0(%1)" :: "i"(BN_adjusted * row_stride_b * 1 * sizeof(float)), "r"(local_b_tmp));
|
||||
local_b_tmp += BN_adjusted * row_stride_b * 2;
|
||||
}
|
||||
global_dmem_load_new<T, MemLayout::MN_major, MemLayout::MN_major, BN, BK>(
|
||||
dim_n, threadblock_id_x, k, B, local_b, tid_in_threadblock);
|
||||
|
||||
asm volatile ("global_dmem_load_finish_%=:" :: );
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user