sgemm_impl: Refactor dmem_load into one unified logic

Replace the confusing logic that had slightly different use of BM/BN/BK
for A and B, into one logic that accepts matrix memory layout as a
proper argument & does compile-time logic to determine the right
dimensions.

TODO: !GMEM_COALESCED_A is not updated yet
This commit is contained in:
Hansung Kim
2024-08-18 20:21:23 -07:00
parent b44b202a21
commit 04643fa64d

View File

@@ -14,7 +14,6 @@ using float16_t = uint16_t;
#if (FP_SIZE == 32) #if (FP_SIZE == 32)
using float_type = float; using float_type = float;
#elif (FP_SIZE == 16) #elif (FP_SIZE == 16)
using float_type = float16_t; using float_type = float16_t;
#endif #endif
@@ -70,7 +69,7 @@ using float_type = float16_t;
// generates the NN kernel where both A and B are stored row-major in GMEM. // generates the NN kernel where both A and B are stored row-major in GMEM.
// To model the case where the A matrix is already stored column-major in GMEM, // To model the case where the A matrix is already stored column-major in GMEM,
// set both to 0. // set both to 0.
#define TRANSPOSE_AT_PRODUCE 0 #define TRANSPOSE_AT_PRODUCE 1
#define TRANSPOSE_AT_CONSUME 0 #define TRANSPOSE_AT_CONSUME 0
// GMEM_COALESCED: When TRANSPOSE_AT_PRODUCE == 1 (i.e. transpose at // GMEM_COALESCED: When TRANSPOSE_AT_PRODUCE == 1 (i.e. transpose at
// GMEM->SMEM), determines whether we do bank-conflict-free accesses for // GMEM->SMEM), determines whether we do bank-conflict-free accesses for
@@ -393,6 +392,182 @@ inline void threadblock_barrier(const uint32_t barrier_id, const uint32_t count)
vx_barrier(barrier_id, count); vx_barrier(barrier_id, count);
} }
enum class MemLayout {
MN_major,
K_major,
};
// Move a single matrix tile from global memory (GMEM) to shared memory (SMEM).
// `dim_col`: column dimension of the global matrix.
template <typename T,
MemLayout gmem_layout, // memory layout of the GMEM tile
MemLayout smem_layout, // memory layout of the GMEM tile
uint32_t tile_dim_mn, // row dimension of the SMEM tile
uint32_t tile_dim_k, // column dimension of the SMEM tile
bool gmem_contiguous = true
>
__attribute__((always_inline)) inline void
global_dmem_load_new(const uint32_t dim_col, const uint32_t mn_index,
const uint32_t k, const T *global_addr,
volatile T *local_addr,
const uint32_t tid_in_threadblock) {
asm volatile("global_dmem_load_start_new_%=:" ::);
// In fp16 mode, bit-pack two fp16 elements into each fp32 element, and do
// data movement at the fp32 granularity. Assuming that the matrix is stored
// row-major in GMEM, the packed fp16 pairs belong to the same row,
// neighboring columns; therefore, it essentially becomes equivalent to
// moving a fp32 matrix whose column dimensions are compressed by a factor of
// two.
constexpr uint32_t packed_factor = (std::is_same_v<T, float16_t> ? 2 : 1);
constexpr uint32_t tile_dim_k_packed = tile_dim_k / packed_factor;
constexpr uint32_t gmem_dim_row =
(gmem_layout == MemLayout::K_major) ? tile_dim_mn : tile_dim_k_packed;
constexpr uint32_t gmem_dim_col =
(gmem_layout == MemLayout::K_major) ? tile_dim_k_packed : tile_dim_mn;
constexpr uint32_t smem_dim_row =
(smem_layout == MemLayout::K_major) ? tile_dim_mn : tile_dim_k_packed;
constexpr uint32_t smem_dim_col =
(smem_layout == MemLayout::K_major) ? tile_dim_k_packed : tile_dim_mn;
const uint32_t dim_col_ =
(gmem_layout == MemLayout::K_major) ? dim_col / packed_factor : dim_col;
// FIXME: unsure about this
const uint32_t k_ = k / packed_factor;
// threads in the threadblock always do contiguous accesses in the gmem
const uint32_t local_row_gmem = tid_in_threadblock / gmem_dim_col;
const uint32_t local_col_gmem = tid_in_threadblock % gmem_dim_col;
constexpr bool transposed_write = (gmem_layout != smem_layout);
// if transposed, threads write to smem in reversed col/row
const uint32_t local_row_smem =
transposed_write ? local_col_gmem : local_row_gmem;
const uint32_t local_col_smem =
transposed_write ? local_row_gmem : local_col_gmem;
// FIXME: don't hardcode this here
constexpr uint32_t threads_per_threadblock = (BM * BN) / ELEM_PER_THREAD;
static_assert(gmem_contiguous == true,
"currently only supports contiguous accesses in GMEM");
const uint32_t global_row_mn_major = k_ + local_row_gmem;
const uint32_t global_col_mn_major = smem_dim_col * mn_index + local_col_gmem;
const uint32_t global_row_k_major = gmem_dim_row * mn_index + local_row_gmem;
const uint32_t global_col_k_major = k_ + local_col_gmem;
const uint32_t global_row = (gmem_layout == MemLayout::K_major)
? global_row_k_major
: global_row_mn_major;
const uint32_t global_col = (gmem_layout == MemLayout::K_major)
? global_col_k_major
: global_col_mn_major;
const float *global = reinterpret_cast<const float *>(global_addr) +
dim_col_ * global_row + global_col;
volatile float *local = reinterpret_cast<volatile float *>(local_addr) +
smem_dim_col * local_row_smem + local_col_smem;
constexpr uint32_t row_stride = threads_per_threadblock / gmem_dim_col;
static_assert(row_stride * 8 <= gmem_dim_row,
"manual loop unrolling condition not met; tile row dimension "
"is too shallow");
static_assert((gmem_dim_row % (row_stride * 8)) == 0,
"manual loop unrolling condition not met; tile row dimension "
"should be power-of-two");
#pragma GCC unroll 1
// loop-unrolled flw/fsw to increase reuse distance and IPC
for (uint32_t load_offset = 0; load_offset < gmem_dim_row;
load_offset += row_stride * 8) {
// equivalent code:
//
// *local = *global;
// global += dim_col * row_stride;
// local += BN * row_stride;
// read same-column elements into fp registers
asm volatile("flw ft0, (%0)" ::"r"(global));
global += dim_col_ * row_stride;
asm volatile("flw ft1, (%0)" ::"r"(global));
global += dim_col_ * row_stride;
asm volatile("flw ft2, (%0)" ::"r"(global));
global += dim_col_ * row_stride;
asm volatile("flw ft3, (%0)" ::"r"(global));
global += dim_col_ * row_stride;
asm volatile("flw ft4, (%0)" ::"r"(global));
global += dim_col_ * row_stride;
asm volatile("flw ft5, (%0)" ::"r"(global));
global += dim_col_ * row_stride;
asm volatile("flw ft6, (%0)" ::"r"(global));
global += dim_col_ * row_stride;
asm volatile("flw ft7, (%0)" ::"r"(global));
global += dim_col_ * row_stride;
// do we need to do transposed write?
if constexpr (!transposed_write) {
static_assert(gmem_layout == MemLayout::MN_major);
// if not, do the same along-the-column accesses for registers as we did
// for gmem
asm volatile("fsw ft0, %0(%1)" ::"i"(smem_dim_col * row_stride * 0 *
sizeof(float)),
"r"(local));
asm volatile("fsw ft1, %0(%1)" ::"i"(smem_dim_col * row_stride * 1 *
sizeof(float)),
"r"(local));
local += smem_dim_col * row_stride * 2;
asm volatile("fsw ft2, %0(%1)" ::"i"(smem_dim_col * row_stride * 0 *
sizeof(float)),
"r"(local));
asm volatile("fsw ft3, %0(%1)" ::"i"(smem_dim_col * row_stride * 1 *
sizeof(float)),
"r"(local));
local += smem_dim_col * row_stride * 2;
asm volatile("fsw ft4, %0(%1)" ::"i"(smem_dim_col * row_stride * 0 *
sizeof(float)),
"r"(local));
asm volatile("fsw ft5, %0(%1)" ::"i"(smem_dim_col * row_stride * 1 *
sizeof(float)),
"r"(local));
local += smem_dim_col * row_stride * 2;
asm volatile("fsw ft6, %0(%1)" ::"i"(smem_dim_col * row_stride * 0 *
sizeof(float)),
"r"(local));
asm volatile("fsw ft7, %0(%1)" ::"i"(smem_dim_col * row_stride * 1 *
sizeof(float)),
"r"(local));
local += smem_dim_col * row_stride * 2;
} else {
static_assert(gmem_layout == MemLayout::K_major);
static_assert(smem_layout == MemLayout::MN_major);
// if yes, write the registers along the row, doing a transpose
// @perf: this will incur bank conflicts in smem
asm volatile("fsw ft0, %0(%1)" ::"i"(row_stride * 0 * sizeof(float)),
"r"(local));
asm volatile("fsw ft1, %0(%1)" ::"i"(row_stride * 1 * sizeof(float)),
"r"(local));
asm volatile("fsw ft2, %0(%1)" ::"i"(row_stride * 2 * sizeof(float)),
"r"(local));
asm volatile("fsw ft3, %0(%1)" ::"i"(row_stride * 3 * sizeof(float)),
"r"(local));
asm volatile("fsw ft4, %0(%1)" ::"i"(row_stride * 4 * sizeof(float)),
"r"(local));
asm volatile("fsw ft5, %0(%1)" ::"i"(row_stride * 5 * sizeof(float)),
"r"(local));
asm volatile("fsw ft6, %0(%1)" ::"i"(row_stride * 6 * sizeof(float)),
"r"(local));
asm volatile("fsw ft7, %0(%1)" ::"i"(row_stride * 7 * sizeof(float)),
"r"(local));
local += row_stride * 8;
}
}
asm volatile("global_dmem_load_finish_new_%=:" ::);
}
// TODO: reduce args by passing leading A/B dimensions // TODO: reduce args by passing leading A/B dimensions
template <typename T> template <typename T>
__attribute__((always_inline)) __attribute__((always_inline))
@@ -413,16 +588,14 @@ inline void global_dmem_load(const uint32_t dim_m, const uint32_t dim_n, const u
constexpr uint32_t packed_factor = (std::is_same_v<T, float16_t> ? 2 : 1); constexpr uint32_t packed_factor = (std::is_same_v<T, float16_t> ? 2 : 1);
constexpr uint32_t BK_adjusted = BK / packed_factor; constexpr uint32_t BK_adjusted = BK / packed_factor;
const uint32_t dim_k_adjusted = dim_k / packed_factor; const uint32_t dim_k_adjusted = dim_k / packed_factor;
constexpr uint32_t BN_adjusted = BN / packed_factor;
const uint32_t dim_n_adjusted = dim_n / packed_factor;
const uint32_t k_adjusted = k / packed_factor; const uint32_t k_adjusted = k / packed_factor;
const uint32_t local_a_row = tid_in_threadblock / BK_adjusted; const uint32_t local_a_row = tid_in_threadblock / BK_adjusted;
const uint32_t local_a_col = tid_in_threadblock % BK_adjusted; const uint32_t local_a_col = tid_in_threadblock % BK_adjusted;
const uint32_t local_as_row = tid_in_threadblock / BM; const uint32_t local_as_row = tid_in_threadblock / BM;
const uint32_t local_as_col = tid_in_threadblock % BM; const uint32_t local_as_col = tid_in_threadblock % BM;
const uint32_t local_b_row = tid_in_threadblock / BN_adjusted; const uint32_t local_b_row = tid_in_threadblock / BN;
const uint32_t local_b_col = tid_in_threadblock % BN_adjusted; const uint32_t local_b_col = tid_in_threadblock % BN;
// FIXME: need fix for fp16? // FIXME: need fix for fp16?
constexpr uint32_t threads_per_threadblock = (BM * BN) / ELEM_PER_THREAD; constexpr uint32_t threads_per_threadblock = (BM * BN) / ELEM_PER_THREAD;
@@ -436,27 +609,8 @@ inline void global_dmem_load(const uint32_t dim_m, const uint32_t dim_n, const u
// move A // move A
if constexpr (!TRANSPOSE_AT_PRODUCE) { if constexpr (!TRANSPOSE_AT_PRODUCE) {
// A is stored M-major in GMEM; global_dmem_load_new<T, MemLayout::MN_major, MemLayout::MN_major, BM, BK>(
// no transpose at GMEM->SMEM movement dim_m, threadblock_id_y, k, A, local_a, tid_in_threadblock);
const uint32_t block_m = threadblock_id_y;
const uint32_t global_a_row = k_adjusted + local_as_row;
const uint32_t global_a_col = BM * block_m + local_as_col;
// number of rows a full TB can read at a time
constexpr uint32_t row_stride_as = threads_per_threadblock / BM;
const float *global_a = reinterpret_cast<const float *>(A) +
dim_m * global_a_row + global_a_col;
volatile float *local_a_tmp = reinterpret_cast<volatile float *>(local_a) +
BM * local_as_row + local_as_col;
#pragma GCC unroll 1
for (uint32_t local_row_offset = 0; local_row_offset < BK_adjusted;
local_row_offset += row_stride_as) {
// TODO: the code GCC generates for below seems fine atm, but unroll to
// assembly to be absolutely sure
*local_a_tmp = *global_a;
global_a += dim_m * row_stride_as;
local_a_tmp += BM * row_stride_as;
}
} else { } else {
if constexpr (!GMEM_COALESCED_A) { if constexpr (!GMEM_COALESCED_A) {
// !GMEM_COALESCED_A: threads do uncoalesced read from neighboring row in // !GMEM_COALESCED_A: threads do uncoalesced read from neighboring row in
@@ -478,13 +632,12 @@ inline void global_dmem_load(const uint32_t dim_m, const uint32_t dim_n, const u
#pragma GCC unroll 1 #pragma GCC unroll 1
for (uint32_t local_row_offset = 0; local_row_offset < BK_adjusted; for (uint32_t local_row_offset = 0; local_row_offset < BK_adjusted;
local_row_offset += row_stride_as * 8) { local_row_offset += row_stride_as * 8) {
// @perf: bank conflicts here
// const uint32_t global_a_offset = // const uint32_t global_a_offset =
// dim_k_adjusted * (global_a_row) + (k + local_as_row + local_row_offset); // dim_k_adjusted * (global_a_row) + (k + local_as_row + local_row_offset);
// local_a[BM * (local_as_row + local_row_offset) + local_as_col] = // local_a[BM * (local_as_row + local_row_offset) + local_as_col] =
// A[global_a_offset]; // A[global_a_offset];
// *local_a_tmp = *global_a; // @perf: bank conflicts
asm volatile ("flw ft0, (%0)" :: "r"(global_a)); asm volatile ("flw ft0, (%0)" :: "r"(global_a));
global_a += row_stride_as; global_a += row_stride_as;
asm volatile ("flw ft1, (%0)" :: "r"(global_a)); asm volatile ("flw ft1, (%0)" :: "r"(global_a));
@@ -517,119 +670,15 @@ inline void global_dmem_load(const uint32_t dim_m, const uint32_t dim_n, const u
local_a_tmp += BM * row_stride_as * 8; local_a_tmp += BM * row_stride_as * 8;
} }
} else { } else {
constexpr uint32_t row_stride_a = threads_per_threadblock / BK_adjusted; global_dmem_load_new<T, MemLayout::K_major, MemLayout::MN_major, BM,
const uint32_t global_a_row = BM * threadblock_id_y + local_a_row; BK>(dim_k, threadblock_id_y, k, A, local_a,
const float *global_a = reinterpret_cast<const float *>(A) + tid_in_threadblock);
dim_k_adjusted * global_a_row +
(k_adjusted + local_a_col);
// NOTE that SMEM writes are transposed
volatile float *local_a_tmp =
reinterpret_cast<volatile float *>(local_a) + BM * local_a_col +
local_a_row;
static_assert(
row_stride_a * 8 <= BM,
"manual loop unrolling condition not met; consider increasing BM");
static_assert(
(BM % (row_stride_a * 8)) == 0,
"manual loop unrolling condition not met; BM should be power-of-two");
#pragma GCC unroll 1
for (uint32_t local_row_offset = 0; local_row_offset < BM;
local_row_offset += row_stride_a * 8) {
// const uint32_t global_a_offset =
// dim_k_adjusted * (global_a_row + local_row_offset) + (k + local_a_col);
// NOTE that SMEM writes are transposed
// local_a[BM * (local_a_col) + local_a_row + local_row_offset] =
// A[global_a_offset];
asm volatile ("flw ft0, (%0)" :: "r"(global_a));
global_a += dim_k_adjusted * row_stride_a;
asm volatile ("flw ft1, (%0)" :: "r"(global_a));
global_a += dim_k_adjusted * row_stride_a;
asm volatile ("flw ft2, (%0)" :: "r"(global_a));
global_a += dim_k_adjusted * row_stride_a;
asm volatile ("flw ft3, (%0)" :: "r"(global_a));
global_a += dim_k_adjusted * row_stride_a;
asm volatile ("flw ft4, (%0)" :: "r"(global_a));
global_a += dim_k_adjusted * row_stride_a;
asm volatile ("flw ft5, (%0)" :: "r"(global_a));
global_a += dim_k_adjusted * row_stride_a;
asm volatile ("flw ft6, (%0)" :: "r"(global_a));
global_a += dim_k_adjusted * row_stride_a;
asm volatile ("flw ft7, (%0)" :: "r"(global_a));
global_a += dim_k_adjusted * row_stride_a;
// stride along columns
asm volatile ("fsw ft0, %0(%1)" :: "i"(row_stride_a * 0 * sizeof(float)), "r"(local_a_tmp));
asm volatile ("fsw ft1, %0(%1)" :: "i"(row_stride_a * 1 * sizeof(float)), "r"(local_a_tmp));
asm volatile ("fsw ft2, %0(%1)" :: "i"(row_stride_a * 2 * sizeof(float)), "r"(local_a_tmp));
asm volatile ("fsw ft3, %0(%1)" :: "i"(row_stride_a * 3 * sizeof(float)), "r"(local_a_tmp));
asm volatile ("fsw ft4, %0(%1)" :: "i"(row_stride_a * 4 * sizeof(float)), "r"(local_a_tmp));
asm volatile ("fsw ft5, %0(%1)" :: "i"(row_stride_a * 5 * sizeof(float)), "r"(local_a_tmp));
asm volatile ("fsw ft6, %0(%1)" :: "i"(row_stride_a * 6 * sizeof(float)), "r"(local_a_tmp));
asm volatile ("fsw ft7, %0(%1)" :: "i"(row_stride_a * 7 * sizeof(float)), "r"(local_a_tmp));
local_a_tmp += row_stride_a * 8;
}
} }
} // end move A } // end move A
// move B // move B
constexpr uint32_t row_stride_b = threads_per_threadblock / BN_adjusted; global_dmem_load_new<T, MemLayout::MN_major, MemLayout::MN_major, BN, BK>(
const uint32_t global_b_col = BN_adjusted * threadblock_id_x + local_b_col; dim_n, threadblock_id_x, k, B, local_b, tid_in_threadblock);
// NOTE: not k_adjusted here; k is along the row dimension which is not
// compressed for fp16
const float *global_b = reinterpret_cast<const float *>(B) +
dim_n_adjusted * (k + local_b_row) + global_b_col;
volatile float *local_b_tmp = reinterpret_cast<volatile float *>(local_b) +
BN_adjusted * local_b_row + local_b_col;
static_assert(
row_stride_b * 8 <= BK_adjusted,
"manual loop unrolling condition not met; consider increasing BK");
static_assert(
(BK_adjusted % (row_stride_b * 8)) == 0,
"manual loop unrolling condition not met; BK should be power-of-two");
#pragma GCC unroll 1
for (uint32_t load_offset = 0; load_offset < BK;
load_offset += row_stride_b * 8) {
// equivalent code:
//
// *local_b_tmp = *global_b;
// global_b += dim_n * row_stride_b;
// local_b_tmp += BN * row_stride_b;
asm volatile ("flw ft0, (%0)" :: "r"(global_b));
global_b += dim_n_adjusted * row_stride_b;
asm volatile ("flw ft1, (%0)" :: "r"(global_b));
global_b += dim_n_adjusted * row_stride_b;
asm volatile ("flw ft2, (%0)" :: "r"(global_b));
global_b += dim_n_adjusted * row_stride_b;
asm volatile ("flw ft3, (%0)" :: "r"(global_b));
global_b += dim_n_adjusted * row_stride_b;
asm volatile ("flw ft4, (%0)" :: "r"(global_b));
global_b += dim_n_adjusted * row_stride_b;
asm volatile ("flw ft5, (%0)" :: "r"(global_b));
global_b += dim_n_adjusted * row_stride_b;
asm volatile ("flw ft6, (%0)" :: "r"(global_b));
global_b += dim_n_adjusted * row_stride_b;
asm volatile ("flw ft7, (%0)" :: "r"(global_b));
global_b += dim_n_adjusted * row_stride_b;
asm volatile ("fsw ft0, %0(%1)" :: "i"(BN_adjusted * row_stride_b * 0 * sizeof(float)), "r"(local_b_tmp));
asm volatile ("fsw ft1, %0(%1)" :: "i"(BN_adjusted * row_stride_b * 1 * sizeof(float)), "r"(local_b_tmp));
local_b_tmp += BN_adjusted * row_stride_b * 2;
asm volatile ("fsw ft2, %0(%1)" :: "i"(BN_adjusted * row_stride_b * 0 * sizeof(float)), "r"(local_b_tmp));
asm volatile ("fsw ft3, %0(%1)" :: "i"(BN_adjusted * row_stride_b * 1 * sizeof(float)), "r"(local_b_tmp));
local_b_tmp += BN_adjusted * row_stride_b * 2;
asm volatile ("fsw ft4, %0(%1)" :: "i"(BN_adjusted * row_stride_b * 0 * sizeof(float)), "r"(local_b_tmp));
asm volatile ("fsw ft5, %0(%1)" :: "i"(BN_adjusted * row_stride_b * 1 * sizeof(float)), "r"(local_b_tmp));
local_b_tmp += BN_adjusted * row_stride_b * 2;
asm volatile ("fsw ft6, %0(%1)" :: "i"(BN_adjusted * row_stride_b * 0 * sizeof(float)), "r"(local_b_tmp));
asm volatile ("fsw ft7, %0(%1)" :: "i"(BN_adjusted * row_stride_b * 1 * sizeof(float)), "r"(local_b_tmp));
local_b_tmp += BN_adjusted * row_stride_b * 2;
}
asm volatile ("global_dmem_load_finish_%=:" :: ); asm volatile ("global_dmem_load_finish_%=:" :: );
} }