sgemm_impl: Rename to wmma
This commit is contained in:
@@ -78,7 +78,6 @@ inline void thread_block_online_softmax(
|
||||
|
||||
volatile float *gmem_tmp0 = reinterpret_cast<volatile float *>(0xd0000000UL);
|
||||
volatile float *gmem_tmp1 = reinterpret_cast<volatile float *>(0xe0000000UL);
|
||||
volatile float *gmem_tmp2 = reinterpret_cast<volatile float *>(0xf0000000UL);
|
||||
|
||||
float *smem_rowmax_prev = smem_rowmax;
|
||||
float *smem_rowmax_new = smem_rowmax + B_ROW;
|
||||
@@ -310,35 +309,33 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
|
||||
(smem_QK_size + smem_V_size + smem_O_size) *
|
||||
threadblock_id_in_cluster);
|
||||
|
||||
uint8_t *smem_S = smem_per_threadblock;
|
||||
uint8_t *smem_P = smem_S; // in-place update from S to P
|
||||
uint8_t *smem_V = smem_per_threadblock + sizeof(float) * smem_QK_size;
|
||||
uint8_t *smem_O =
|
||||
smem_per_threadblock + sizeof(float) * (smem_QK_size + smem_V_size);
|
||||
float *smem_S = reinterpret_cast<float *>(smem_per_threadblock);
|
||||
float *smem_P = smem_S; // in-place update from S to P
|
||||
float *smem_V =
|
||||
reinterpret_cast<float *>(smem_per_threadblock) + smem_QK_size;
|
||||
float *smem_O = reinterpret_cast<float *>(smem_per_threadblock) +
|
||||
smem_QK_size + smem_V_size;
|
||||
|
||||
// allocate rowmax/rowsum storage at the end of the sharedmem address space
|
||||
constexpr uint32_t smem_rowmax_size = sizeof(float) * B_ROW * 3 /* mi, mi~, minew */;
|
||||
constexpr uint32_t smem_rowsum_size = sizeof(float) * B_ROW;
|
||||
uint8_t *smem_rowmax =
|
||||
reinterpret_cast<uint8_t *>(SMEM_ADDR_END) - smem_rowmax_size;
|
||||
uint8_t *smem_rowsum = smem_rowmax - smem_rowsum_size;
|
||||
constexpr uint32_t smem_rowmax_size = B_ROW * 3 /* mi, mi~, minew */;
|
||||
constexpr uint32_t smem_rowsum_size = B_ROW;
|
||||
float *smem_rowmax =
|
||||
reinterpret_cast<float *>(SMEM_ADDR_END) - smem_rowmax_size;
|
||||
float *smem_rowsum = smem_rowmax - smem_rowsum_size;
|
||||
|
||||
// sharedmem "scratchpad" area to put temporary data, e.g. for tree reduction
|
||||
// in rowsum
|
||||
// NOTE: out-of bounds is not checked
|
||||
constexpr uint32_t smem_scratchpad_size =
|
||||
sizeof(float) * B_ROW * NUM_THREADS * 2 /*arbitrary slack*/;
|
||||
uint8_t *smem_scratchpad =
|
||||
smem_rowmax - smem_scratchpad_size;
|
||||
B_ROW * NUM_THREADS * 2 /*arbitrary slack*/;
|
||||
float *smem_scratchpad = smem_rowmax - smem_scratchpad_size;
|
||||
|
||||
const uint32_t warps_per_threadblock_per_core =
|
||||
NUM_WARPS / threads_per_threadblock;
|
||||
|
||||
// initialize rowmax/rowsum values in sharedmem
|
||||
thread_block_init_sharedmem(tid_in_threadblock, threads_per_threadblock,
|
||||
(float *)smem_O,
|
||||
(float *)smem_rowmax,
|
||||
(float *)smem_rowsum);
|
||||
smem_O, smem_rowmax, smem_rowsum);
|
||||
|
||||
#define SKIP_GEMM
|
||||
#ifndef SKIP_GEMM
|
||||
@@ -359,16 +356,18 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
|
||||
#endif
|
||||
|
||||
thread_block_online_softmax(
|
||||
tile_S, (float *)smem_O, (float *)smem_P, tid_in_threadblock,
|
||||
threads_per_threadblock, threadblock_id_in_cluster,
|
||||
(float *)smem_scratchpad, (float *)smem_rowmax, (float *)smem_rowsum);
|
||||
tile_S, smem_O, smem_P, tid_in_threadblock, threads_per_threadblock,
|
||||
threadblock_id_in_cluster, smem_scratchpad, smem_rowmax, smem_rowsum);
|
||||
|
||||
// FIXME unnecessary?
|
||||
threadblock_barrier(threadblock_id_in_cluster,
|
||||
warps_per_threadblock_per_core);
|
||||
|
||||
thread_block_gemm_single_tile(smem_P, smem_V, tid_in_threadblock,
|
||||
threads_per_threadblock);
|
||||
float *gmem_tmp2 = reinterpret_cast<float *>(0xf0000000UL);
|
||||
|
||||
thread_block_gemm_single_tile<float, /*write_to_smem=*/true>(
|
||||
smem_P, smem_V, gmem_tmp2 /*smem_O*/, tid_in_threadblock,
|
||||
threads_per_threadblock);
|
||||
}
|
||||
|
||||
int main() {
|
||||
|
||||
@@ -205,15 +205,15 @@ inline void vx_wmma(const int dest_reg) {
|
||||
|
||||
// `local_k` is assumed to be multiple of TCK
|
||||
template <typename T>
|
||||
inline void vx_wmma_load_a(volatile const T *smem_A, const int local_k,
|
||||
inline void wmma_load_a(volatile const T *smem_A, const int local_k,
|
||||
const int warp_row, const int wm_iter,
|
||||
const int thread_in_warp) {
|
||||
asm volatile ("vx_wmma_load_a_start_%=:" :: );
|
||||
asm volatile ("wmma_load_a_start_%=:" :: );
|
||||
|
||||
const int tid = thread_in_warp;
|
||||
const int tg = tid / 4;
|
||||
|
||||
// @perf: this is duplicately computed in vx_wmma_load_a and vx_wmma_load_b
|
||||
// @perf: this is duplicately computed in wmma_load_a and wmma_load_b
|
||||
int row = 0;
|
||||
int col = 0;
|
||||
map_operand(tid, row, col);
|
||||
@@ -273,15 +273,15 @@ inline void vx_wmma_load_a(volatile const T *smem_A, const int local_k,
|
||||
asm volatile("flw f7, %0(%1)" :: "i"(smem_AS_cols * 7 * sizeof(float)), "r"(smem_addr));
|
||||
}
|
||||
|
||||
asm volatile ("vx_wmma_load_a_finish_%=:" :: );
|
||||
asm volatile ("wmma_load_a_finish_%=:" :: );
|
||||
}
|
||||
|
||||
// `local_k` is assumed to be multiple of TCK
|
||||
template <typename T>
|
||||
inline void vx_wmma_load_b(const volatile T *smem_B, const int local_k,
|
||||
inline void wmma_load_b(const volatile T *smem_B, const int local_k,
|
||||
const int warp_col, const int wn_iter,
|
||||
const int thread_in_warp) {
|
||||
asm volatile ("vx_wmma_load_b_start_%=:" :: );
|
||||
asm volatile ("wmma_load_b_start_%=:" :: );
|
||||
|
||||
const int tid = thread_in_warp;
|
||||
const int tg = tid / 4;
|
||||
@@ -290,7 +290,7 @@ inline void vx_wmma_load_b(const volatile T *smem_B, const int local_k,
|
||||
int col = 0;
|
||||
map_operand(tid, row, col);
|
||||
|
||||
// see comment in vx_wmma_load_a
|
||||
// see comment in wmma_load_a
|
||||
constexpr int packed_factor = (std::is_same_v<T, float16_t> ? 2 : 1);
|
||||
constexpr int BK_adjusted = BN / packed_factor;
|
||||
constexpr int BN_adjusted = BN / packed_factor;
|
||||
@@ -316,7 +316,7 @@ inline void vx_wmma_load_b(const volatile T *smem_B, const int local_k,
|
||||
asm volatile("flw f14, %0(%1)" :: "i"(smem_B_cols * 6 * sizeof(float)), "r"(smem_addr));
|
||||
asm volatile("flw f15, %0(%1)" :: "i"(smem_B_cols * 7 * sizeof(float)), "r"(smem_addr));
|
||||
|
||||
asm volatile ("vx_wmma_load_b_finish_%=:" :: );
|
||||
asm volatile ("wmma_load_b_finish_%=:" :: );
|
||||
}
|
||||
|
||||
inline void initialize_C(const int dest_reg) {
|
||||
@@ -659,11 +659,11 @@ thread_block_gemm_single_tile(const T *local_a, const T *local_b, T *local_c,
|
||||
#pragma GCC unroll 2
|
||||
for (int wn_iter = 0; wn_iter < WNITER; wn_iter++) {
|
||||
// SMEM -> RF
|
||||
vx_wmma_load_b<T>(local_b, local_k, warp_col, wn_iter, tid_in_warp);
|
||||
wmma_load_b<T>(local_b, local_k, warp_col, wn_iter, tid_in_warp);
|
||||
#pragma GCC unroll 2
|
||||
for (int wm_iter = 0; wm_iter < WMITER; wm_iter++) {
|
||||
// SMEM -> RF
|
||||
vx_wmma_load_a<T>(local_a, local_k, warp_row, wm_iter, tid_in_warp);
|
||||
wmma_load_a<T>(local_a, local_k, warp_row, wm_iter, tid_in_warp);
|
||||
// perform mma
|
||||
vx_wmma(wm_iter);
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user