sgemm_tcore: Addr gen for local_k; add SIMT-only for reference

This commit is contained in:
Hansung Kim
2024-05-16 14:09:55 -07:00
parent df1aa62916
commit 8f64fae7a7
2 changed files with 121 additions and 49 deletions

View File

@@ -6,6 +6,9 @@
#include <vx_spawn.h>
#include "common.h"
#define USE_TENSOR_CORE 1
#define TC_SINGLE_WARP 0
#define NUM_LANES 8
// Constraints on parameters:
@@ -20,18 +23,19 @@
// (BM*BN) / (TM*TN) == threadblock size >= NT * CORES_PER_CLUSTER
// * Combining BM * BK >= (BM*BN) / (TM*TN) == threadblock yields
// BM <= BK*TM*TN
#define BM 8
#define BN BM
#define BK 8
#define BM 16
#define BN 16
#define BK 32
#define TCM 8
#define TCN 8
#define TCK 8
#define WM 8
#define WN 8
#define WMITER (WM / TCM)
#define WNITER (WN / TCN)
#define TM 1
// #define TN ((TCM * TCN) / NUM_LANES / TM)
#define TN 1
#define TN ((TCM * TCN) / NUM_LANES / TM)
// #define TN 1
inline constexpr void map_operand_32lanes(const int tid, int &row, int &col) {
@@ -125,9 +129,10 @@ inline void vx_wmma() {
asm volatile (".insn r %0, 0, 0, x0, x0, x0" :: "i"(RISCV_CUSTOM3));
}
void vx_wmma_load(volatile float *smem_A, volatile float *smem_B, int warp_col,
int warp_row, int wn_iter, int wm_iter,
int thread_in_warp) {
// `local_k` is assumed to be multiple of TCK
void vx_wmma_load(volatile float *smem_A, volatile float *smem_B, const int local_k,
const int warp_col, const int warp_row, const int wn_iter,
const int wm_iter, const int thread_in_warp) {
int tid = thread_in_warp;
int tg = tid / 4;
@@ -142,23 +147,24 @@ void vx_wmma_load(volatile float *smem_A, volatile float *smem_B, int warp_col,
int A_offset = (row + WM * warp_row + TCM * wm_iter) * smem_A_cols;
asm volatile("flw f0, %0" ::"m"(smem_A[A_offset + 0]));
asm volatile("flw f1, %0" ::"m"(smem_A[A_offset + 1]));
asm volatile("flw f2, %0" ::"m"(smem_A[A_offset + 2]));
asm volatile("flw f3, %0" ::"m"(smem_A[A_offset + 3]));
asm volatile("flw f4, %0" ::"m"(smem_A[A_offset + 4]));
asm volatile("flw f5, %0" ::"m"(smem_A[A_offset + 5]));
asm volatile("flw f6, %0" ::"m"(smem_A[A_offset + 6]));
asm volatile("flw f7, %0" ::"m"(smem_A[A_offset + 7]));
// @perf: bank conflicts
asm volatile("flw f0, %0" ::"m"(smem_A[A_offset + (local_k + 0)]));
asm volatile("flw f1, %0" ::"m"(smem_A[A_offset + (local_k + 1)]));
asm volatile("flw f2, %0" ::"m"(smem_A[A_offset + (local_k + 2)]));
asm volatile("flw f3, %0" ::"m"(smem_A[A_offset + (local_k + 3)]));
asm volatile("flw f4, %0" ::"m"(smem_A[A_offset + (local_k + 4)]));
asm volatile("flw f5, %0" ::"m"(smem_A[A_offset + (local_k + 5)]));
asm volatile("flw f6, %0" ::"m"(smem_A[A_offset + (local_k + 6)]));
asm volatile("flw f7, %0" ::"m"(smem_A[A_offset + (local_k + 7)]));
asm volatile("flw f8 , %0" ::"m"(smem_B[(0 * smem_B_cols) + (WN * warp_col + TCN * wn_iter) + col]));
asm volatile("flw f9 , %0" ::"m"(smem_B[(1 * smem_B_cols) + (WN * warp_col + TCN * wn_iter) + col]));
asm volatile("flw f10, %0" ::"m"(smem_B[(2 * smem_B_cols) + (WN * warp_col + TCN * wn_iter) + col]));
asm volatile("flw f11, %0" ::"m"(smem_B[(3 * smem_B_cols) + (WN * warp_col + TCN * wn_iter) + col]));
asm volatile("flw f12, %0" ::"m"(smem_B[(4 * smem_B_cols) + (WN * warp_col + TCN * wn_iter) + col]));
asm volatile("flw f13, %0" ::"m"(smem_B[(5 * smem_B_cols) + (WN * warp_col + TCN * wn_iter) + col]));
asm volatile("flw f14, %0" ::"m"(smem_B[(6 * smem_B_cols) + (WN * warp_col + TCN * wn_iter) + col]));
asm volatile("flw f15, %0" ::"m"(smem_B[(7 * smem_B_cols) + (WN * warp_col + TCN * wn_iter) + col]));
asm volatile("flw f8 , %0" ::"m"(smem_B[((local_k + 0) * smem_B_cols) + (WN * warp_col + TCN * wn_iter) + col]));
asm volatile("flw f9 , %0" ::"m"(smem_B[((local_k + 1) * smem_B_cols) + (WN * warp_col + TCN * wn_iter) + col]));
asm volatile("flw f10, %0" ::"m"(smem_B[((local_k + 2) * smem_B_cols) + (WN * warp_col + TCN * wn_iter) + col]));
asm volatile("flw f11, %0" ::"m"(smem_B[((local_k + 3) * smem_B_cols) + (WN * warp_col + TCN * wn_iter) + col]));
asm volatile("flw f12, %0" ::"m"(smem_B[((local_k + 4) * smem_B_cols) + (WN * warp_col + TCN * wn_iter) + col]));
asm volatile("flw f13, %0" ::"m"(smem_B[((local_k + 5) * smem_B_cols) + (WN * warp_col + TCN * wn_iter) + col]));
asm volatile("flw f14, %0" ::"m"(smem_B[((local_k + 6) * smem_B_cols) + (WN * warp_col + TCN * wn_iter) + col]));
asm volatile("flw f15, %0" ::"m"(smem_B[((local_k + 7) * smem_B_cols) + (WN * warp_col + TCN * wn_iter) + col]));
}
inline void initialize_C() {
@@ -232,6 +238,14 @@ void thread_block_gemm(kernel_arg_t *__UNIFORM__ arg,
const uint32_t global_a_row = BM * threadblock_id_y + local_a_row;
const uint32_t global_b_col = BN * threadblock_id_x + local_b_col;
const uint32_t local_c_row = tid_in_threadblock / (BN / TN);
const uint32_t local_c_col = tid_in_threadblock % (BN / TN);
// each thread generates TM output element
float reg_c[TM * TN] = { 0.0f };
float reg_a[TM] = { 0.0f };
float reg_b[TN] = { 0.0f };
const uint32_t warp_in_threadblock = tid_in_threadblock / NUM_LANES;
const uint32_t warp_row = warp_in_threadblock / (BN / WN);
const uint32_t warp_col = warp_in_threadblock % (BN / WN);
@@ -239,11 +253,9 @@ void thread_block_gemm(kernel_arg_t *__UNIFORM__ arg,
volatile float *local_a = sharedmem_per_threadblock;
// const size_t local_a_elems = threadblock_dim_x * threadblock_dim_y;
// FIXME: this better be BM * BK, but the GMEM->SMEM load assumes all threads
// in TB participates in the load
const size_t local_a_elems = (BM * BN);
const size_t local_a_elems = (BM * BK);
volatile float *local_b = sharedmem_per_threadblock + local_a_elems;
const size_t local_b_elems = (BM * BN);
const size_t local_b_elems = (BK * BN);
volatile float *local_warp_results =
local_b + local_b_elems + (warp_in_threadblock * TCM * TCN);
@@ -281,36 +293,95 @@ void thread_block_gemm(kernel_arg_t *__UNIFORM__ arg,
threadblock_barrier(tid_in_threadblock, threadblock_id_in_cluster,
threadblock_dim_y);
// perform wmma
// vx_wmma_load(local_a, local_b, warp_x, warp_y, tid_in_warp);
// FIXME: If multiple warps try to issue to Tensor Core at the same time,
// does one stall the other?
// FIXME: this is wrong!! need separate accumulation register for
// WM/WN_ITERS
if (warp_in_threadblock == 0) {
#if USE_TENSOR_CORE
for (uint32_t local_k = 0; local_k < BK; local_k += TCK) {
// perform wmma
// vx_wmma_load(local_a, local_b, warp_x, warp_y, tid_in_warp);
// FIXME: If multiple warps try to issue to Tensor Core at the same time,
// does one stall the other?
// FIXME: this is wrong!! need separate accumulation register for
// WM/WN_ITERS
for (int wm_iter = 0; wm_iter < WMITER; wm_iter++) {
for (int wn_iter = 0; wn_iter < WNITER; wn_iter++) {
vx_wmma_load(local_a, local_b, warp_col, warp_row, wn_iter, wm_iter,
tid_in_warp);
vx_wmma();
#if TC_SINGLE_WARP
if (warp_in_threadblock == 0) {
#endif
vx_wmma_load(local_a, local_b, local_k, warp_col, warp_row, wn_iter,
wm_iter, tid_in_warp);
vx_wmma();
#if TC_SINGLE_WARP
}
#endif
}
}
}
#else
// Compute single tile*tile matmul
#pragma GCC unroll 4
for (uint32_t local_k = 0; local_k < BK; local_k++) {
// First, pump data from SMEM->RF
#pragma GCC unroll TM
for (uint32_t res_idx_m = 0; res_idx_m < TM; res_idx_m++) {
reg_a[res_idx_m] =
local_a[BK * (TM * local_c_row + res_idx_m) + local_k];
}
#pragma GCC unroll TN
for (uint32_t res_idx_n = 0; res_idx_n < TN; res_idx_n++) {
reg_b[res_idx_n] =
local_b[BN * local_k + (TN * local_c_col + res_idx_n)];
}
// Next, compute multiple result elements (TM*TN) by reusing data in RF
#pragma GCC unroll TM
for (uint32_t res_idx_m = 0; res_idx_m < TM; res_idx_m++) {
#pragma GCC unroll TN
for (uint32_t res_idx_n = 0; res_idx_n < TN; res_idx_n++) {
// NOTE use of local_b_row
reg_c[TN * res_idx_m + res_idx_n] +=
reg_a[res_idx_m] * reg_b[res_idx_n];
// reg_c[TN * res_idx_m + res_idx_n] +=
// local_a[BK * (TM * local_c_row + res_idx_m) + local_k] *
// local_b[BN * local_k + (TN * local_c_col + res_idx_n)];
}
}
}
#endif
threadblock_barrier(tid_in_threadblock, threadblock_id_in_cluster,
threadblock_dim_y);
}
if (warp_in_threadblock == 0) {
for (int wm_iter = 0; wm_iter < WMITER; wm_iter++) {
for (int wn_iter = 0; wn_iter < WNITER; wn_iter++) {
write_results(local_warp_results, tid_in_warp,
warp_col, warp_row,
wn_iter, wm_iter,
dim_m, dim_n, C, threadblock_id_x, threadblock_id_y);
#if USE_TENSOR_CORE
for (int wm_iter = 0; wm_iter < WMITER; wm_iter++) {
for (int wn_iter = 0; wn_iter < WNITER; wn_iter++) {
#if TC_SINGLE_WARP
if (warp_in_threadblock == 0) {
#endif
write_results(local_warp_results, tid_in_warp, warp_col, warp_row,
wn_iter, wm_iter, dim_m, dim_n, C, threadblock_id_x,
threadblock_id_y);
#if TC_SINGLE_WARP
}
#endif
}
}
#else
// Store result data from RF to GMEM
#pragma GCC unroll TM
for (uint32_t res_idx_m = 0; res_idx_m < TM; res_idx_m++) {
#pragma GCC unroll TN
for (uint32_t res_idx_n = 0; res_idx_n < TN; res_idx_n++) {
C[dim_n * (BM * threadblock_id_y + TM * local_c_row + res_idx_m) +
(BN * threadblock_id_x + TN * local_c_col + res_idx_n)] =
reg_c[TN * res_idx_m + res_idx_n];
}
}
#endif
}
void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
@@ -340,8 +411,9 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
// "static" shared memory allocation. This would determine threadblock
// occupancy of a single cluster
// FIXME: 4* is unnecessary; being safe for overlaps
float *sharedmem_per_threadblock =
(float *)DEV_SMEM_START_ADDR + (2 * BM * BK) * threadblock_id_in_cluster;
(float *)DEV_SMEM_START_ADDR + (4 * BM * BK) * threadblock_id_in_cluster;
thread_block_gemm(arg, tid_in_threadblock, threadblock_dim_x,
threadblock_dim_y, threadblock_id_x, threadblock_id_y,
threadblock_id_in_cluster, sharedmem_per_threadblock);

View File

@@ -147,9 +147,9 @@ int main(int argc, char *argv[]) {
RT_CHECK(vx_dev_open(&device));
// FIXME: hardcoded
uint32_t dim_m = 64;
uint32_t dim_n = 64;
uint32_t dim_k = 64;
uint32_t dim_m = 32;
uint32_t dim_n = 32;
uint32_t dim_k = 32;
generate_source_matrix(dim_m, dim_n, dim_k);
generate_reference_matmul(dim_m, dim_n, dim_k);