sgemm_tcore: Blocksize 64; Fix kernel launch on larger dim

& fix addrgen assembly too large offset error
This commit is contained in:
Hansung Kim
2024-06-11 22:27:12 -07:00
parent 03d1df8f53
commit 32e31c51a4
3 changed files with 80 additions and 57 deletions

View File

@@ -7,6 +7,36 @@
#include "include/gemmini.h"
#include "gemmini_mmio.h"
#define GEMMINI_DMA 1
#if SMEM_SIZE == 0x4000
#define SMEM_ADDR_Q0 ((float * const) 0xff000000)
#define SMEM_ADDR_Q1 ((float * const) 0xff001000)
#define SMEM_ADDR_Q2 ((float * const) 0xff002000)
#define SMEM_ADDR_Q3 ((float * const) 0xff003000)
#define SPAD_ADDR_Q0 0x0
#define SPAD_ADDR_Q1 0x80
#define SPAD_ADDR_Q2 0x100
#define SPAD_ADDR_Q3 0x180
#define BOUND_INST 0x400040004ULL
#elif SMEM_SIZE == 0x10000
#define SMEM_ADDR_Q0 ((float * const) 0xff000000)
#define SMEM_ADDR_Q1 ((float * const) 0xff004000)
#define SMEM_ADDR_Q2 ((float * const) 0xff008000)
#define SMEM_ADDR_Q3 ((float * const) 0xff00c000)
#define SPAD_ADDR_Q0 0x0
#define SPAD_ADDR_Q1 0x200
#define SPAD_ADDR_Q2 0x400
#define SPAD_ADDR_Q3 0x600
#define BOUND_INST 0x800080008ULL
#else
#error Unsupported smem size
#endif
// FIXME: NUM_THREADS and NUM_WARPS hardcoded
#if ((BM * BN / ELEM_PER_THREAD) > (CORES_PER_CLUSTER * 8 * 8))
#error "threadblock size too big for cluster"
#endif
inline void global_dmem_load(const uint32_t dim_n, const uint32_t dim_k,
const uint32_t k, const float *A, const float *B,
volatile float *local_a, volatile float *local_b,
@@ -204,14 +234,16 @@ inline void global_dmem_load(const uint32_t dim_n, const uint32_t dim_k,
asm volatile ("fsw ft0, %0(%1)" :: "i"(BN * row_stride_b * 0 * sizeof(float)), "r"(local_b_tmp));
asm volatile ("fsw ft1, %0(%1)" :: "i"(BN * row_stride_b * 1 * sizeof(float)), "r"(local_b_tmp));
asm volatile ("fsw ft2, %0(%1)" :: "i"(BN * row_stride_b * 2 * sizeof(float)), "r"(local_b_tmp));
asm volatile ("fsw ft3, %0(%1)" :: "i"(BN * row_stride_b * 3 * sizeof(float)), "r"(local_b_tmp));
local_b_tmp += BN * row_stride_b * 4;
local_b_tmp += BN * row_stride_b * 2;
asm volatile ("fsw ft2, %0(%1)" :: "i"(BN * row_stride_b * 0 * sizeof(float)), "r"(local_b_tmp));
asm volatile ("fsw ft3, %0(%1)" :: "i"(BN * row_stride_b * 1 * sizeof(float)), "r"(local_b_tmp));
local_b_tmp += BN * row_stride_b * 2;
asm volatile ("fsw ft4, %0(%1)" :: "i"(BN * row_stride_b * 0 * sizeof(float)), "r"(local_b_tmp));
asm volatile ("fsw ft5, %0(%1)" :: "i"(BN * row_stride_b * 1 * sizeof(float)), "r"(local_b_tmp));
asm volatile ("fsw ft6, %0(%1)" :: "i"(BN * row_stride_b * 2 * sizeof(float)), "r"(local_b_tmp));
asm volatile ("fsw ft7, %0(%1)" :: "i"(BN * row_stride_b * 3 * sizeof(float)), "r"(local_b_tmp));
local_b_tmp += BN * row_stride_b * 4;
local_b_tmp += BN * row_stride_b * 2;
asm volatile ("fsw ft6, %0(%1)" :: "i"(BN * row_stride_b * 0 * sizeof(float)), "r"(local_b_tmp));
asm volatile ("fsw ft7, %0(%1)" :: "i"(BN * row_stride_b * 1 * sizeof(float)), "r"(local_b_tmp));
local_b_tmp += BN * row_stride_b * 2;
}
}
@@ -221,8 +253,7 @@ inline void thread_block_gemm(kernel_arg_t *__UNIFORM__ arg,
const uint32_t threadblock_dim_y,
/*const uint32_t threadblock_id_x,
const uint32_t threadblock_id_y,*/
const uint32_t num_threadblocks,
const uint32_t threadblock_id,
const uint32_t threadblocks_per_cluster,
const uint32_t threadblock_id_in_cluster,
float *sharedmem_per_threadblock) {
const float *A = (const float *)arg->addr_a;
@@ -276,8 +307,8 @@ inline void thread_block_gemm(kernel_arg_t *__UNIFORM__ arg,
#endif
// divide rows (M) by the number of threadblocks
const uint32_t dim_m_range = (dim_m / num_threadblocks);
const uint32_t dim_m_start = dim_m_range * threadblock_id;
const uint32_t dim_m_range = (dim_m / threadblocks_per_cluster);
const uint32_t dim_m_start = dim_m_range * threadblock_id_in_cluster;
const uint32_t block_m_start = dim_m_start / BM;
const uint32_t block_m_end = (dim_m_start + dim_m_range) / BM;
@@ -303,9 +334,10 @@ inline void thread_block_gemm(kernel_arg_t *__UNIFORM__ arg,
GEMMINI_CISC_CMD_R((dim_n << 16) | (dim_k << 8) | 8);
gemmini_fence();
// GEMMINI_CISC_CMD_I(12);
// gemmini_fence();
GEMMINI_CISC_CMD_I(12);
gemmini_fence();
#if 0
// sp_tiled_matmul_full_spad_ws includes CONFIG_BOUNDS
// FIXME: block_k is 0 for two times
sp_tiled_matmul_full_spad_ws(
@@ -321,6 +353,7 @@ inline void thread_block_gemm(kernel_arg_t *__UNIFORM__ arg,
/*a_transpose=*/0, /*b_transpose=*/0, /*full_C=*/0, /*low_D=*/0,
/*acc=*/0, /*act=*/NO_ACTIVATION, /*skips=*/skips)
gemmini_fence();
#endif
}
threadblock_barrier(threadblock_id_in_cluster, threadblock_dim_y);
@@ -340,23 +373,22 @@ inline void thread_block_gemm(kernel_arg_t *__UNIFORM__ arg,
// FIXME: block_k is wrong
ROCC_INSTRUCTION_RS1_RS2(
XCUSTOM_ACC,
(uint64_t)(A + block_m * BM * dim_k + block_k * BK),
(uint64_t)(B + block_k * BK * dim_n + block_n * BN),
(uint64_t)(A + block_m * BM * dim_k + (block_k + 1/*runahead*/) * BK),
(uint64_t)(B + (block_k + 1/*runahead*/) * BK * dim_n + block_n * BN),
k_LOOP_WS_CONFIG_ADDRS_AB)
// GEMMINI_CISC(8) does k_LOOP_WS_CONFIG_STRIDES_AB
GEMMINI_CISC_CMD_R((dim_n << 16) | (dim_k << 8) | 8);
// gemmini_fence();
// TODO: this is probably slow
// if (block_k & 1) {
// GEMMINI_CISC_CMD_I(12);
// } else { // block_k == 0 is here
// GEMMINI_CISC_CMD_I(13);
// }
// TODO: branch is probably slow
if (block_k & 1) {
GEMMINI_CISC_CMD_I(12);
} else { // block_k == 0 is here
GEMMINI_CISC_CMD_I(13);
}
// configure loop iteration bounds
// FIXME: shouldn't be necessary
// #define BOUND_INST 0x400040004ULL
// ROCC_INSTRUCTION_RS1_RS2(XCUSTOM_ACC, 0, BOUND_INST,
// k_LOOP_WS_CONFIG_BOUNDS) ROCC_INSTRUCTION_RS1_RS2(XCUSTOM_ACC,
// SPAD_ADDR_Q0, SPAD_ADDR_Q1, k_LOOP_WS_CONFIG_SPAD_AB)
@@ -483,12 +515,11 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
float *sharedmem_per_threadblock =
(float *)DEV_SMEM_START_ADDR + (2 * BM * BK) * threadblock_id_in_cluster;
const int warp_id = vx_warp_id();
thread_block_gemm(arg, tid_in_threadblock, threads_per_threadblock,
threadblock_dim_y,
/*threadblock_id_x, threadblock_id_y,*/
num_threadblocks,
threadblock_id,
threadblocks_per_cluster,
// threadblock_id,
threadblock_id_in_cluster,
sharedmem_per_threadblock);
}

View File

@@ -11,6 +11,11 @@
#undef ELEM_PER_THREAD
#define ELEM_PER_THREAD (WMITER * WNITER * ((TCM * TCN) / NUM_LANES) / (DOUBLE_BUFFER ? 2 : 1))
// FIXME: NUM_THREADS and NUM_WARPS hardcoded
#if ((BM * BN / ELEM_PER_THREAD) > (CORES_PER_CLUSTER * 8 * 8))
#error "threadblock size too big for cluster"
#endif
inline void global_dmem_load(const uint32_t dim_n, const uint32_t dim_k,
const uint32_t k, const float *A, const float *B,
volatile float *local_a, volatile float *local_b,
@@ -85,11 +90,12 @@ inline void global_dmem_load(const uint32_t dim_n, const uint32_t dim_k,
asm volatile ("fsw ft1, %0(%1)" :: "i"(BK * row_stride_a * 1 * sizeof(float)), "r"(local_a_tmp));
asm volatile ("fsw ft2, %0(%1)" :: "i"(BK * row_stride_a * 2 * sizeof(float)), "r"(local_a_tmp));
asm volatile ("fsw ft3, %0(%1)" :: "i"(BK * row_stride_a * 3 * sizeof(float)), "r"(local_a_tmp));
asm volatile ("fsw ft4, %0(%1)" :: "i"(BK * row_stride_a * 4 * sizeof(float)), "r"(local_a_tmp));
asm volatile ("fsw ft5, %0(%1)" :: "i"(BK * row_stride_a * 5 * sizeof(float)), "r"(local_a_tmp));
asm volatile ("fsw ft6, %0(%1)" :: "i"(BK * row_stride_a * 6 * sizeof(float)), "r"(local_a_tmp));
asm volatile ("fsw ft7, %0(%1)" :: "i"(BK * row_stride_a * 7 * sizeof(float)), "r"(local_a_tmp));
local_a_tmp += BK * row_stride_a * 8;
local_a_tmp += BK * row_stride_a * 4;
asm volatile ("fsw ft4, %0(%1)" :: "i"(BK * row_stride_a * 0 * sizeof(float)), "r"(local_a_tmp));
asm volatile ("fsw ft5, %0(%1)" :: "i"(BK * row_stride_a * 1 * sizeof(float)), "r"(local_a_tmp));
asm volatile ("fsw ft6, %0(%1)" :: "i"(BK * row_stride_a * 2 * sizeof(float)), "r"(local_a_tmp));
asm volatile ("fsw ft7, %0(%1)" :: "i"(BK * row_stride_a * 3 * sizeof(float)), "r"(local_a_tmp));
local_a_tmp += BK * row_stride_a * 4;
}
} else {
if constexpr (!GMEM_COALESCED_A) {
@@ -245,13 +251,16 @@ inline void global_dmem_load(const uint32_t dim_n, const uint32_t dim_k,
asm volatile ("fsw ft0, %0(%1)" :: "i"(BN * row_stride_b * 0 * sizeof(float)), "r"(local_b_tmp));
asm volatile ("fsw ft1, %0(%1)" :: "i"(BN * row_stride_b * 1 * sizeof(float)), "r"(local_b_tmp));
asm volatile ("fsw ft2, %0(%1)" :: "i"(BN * row_stride_b * 2 * sizeof(float)), "r"(local_b_tmp));
asm volatile ("fsw ft3, %0(%1)" :: "i"(BN * row_stride_b * 3 * sizeof(float)), "r"(local_b_tmp));
asm volatile ("fsw ft4, %0(%1)" :: "i"(BN * row_stride_b * 4 * sizeof(float)), "r"(local_b_tmp));
asm volatile ("fsw ft5, %0(%1)" :: "i"(BN * row_stride_b * 5 * sizeof(float)), "r"(local_b_tmp));
asm volatile ("fsw ft6, %0(%1)" :: "i"(BN * row_stride_b * 6 * sizeof(float)), "r"(local_b_tmp));
asm volatile ("fsw ft7, %0(%1)" :: "i"(BN * row_stride_b * 7 * sizeof(float)), "r"(local_b_tmp));
local_b_tmp += BN * row_stride_b * 8;
local_b_tmp += BN * row_stride_b * 2;
asm volatile ("fsw ft2, %0(%1)" :: "i"(BN * row_stride_b * 0 * sizeof(float)), "r"(local_b_tmp));
asm volatile ("fsw ft3, %0(%1)" :: "i"(BN * row_stride_b * 1 * sizeof(float)), "r"(local_b_tmp));
local_b_tmp += BN * row_stride_b * 2;
asm volatile ("fsw ft4, %0(%1)" :: "i"(BN * row_stride_b * 0 * sizeof(float)), "r"(local_b_tmp));
asm volatile ("fsw ft5, %0(%1)" :: "i"(BN * row_stride_b * 1 * sizeof(float)), "r"(local_b_tmp));
local_b_tmp += BN * row_stride_b * 2;
asm volatile ("fsw ft6, %0(%1)" :: "i"(BN * row_stride_b * 0 * sizeof(float)), "r"(local_b_tmp));
asm volatile ("fsw ft7, %0(%1)" :: "i"(BN * row_stride_b * 1 * sizeof(float)), "r"(local_b_tmp));
local_b_tmp += BN * row_stride_b * 2;
}
}

View File

@@ -20,9 +20,9 @@
// (BM*BN) / (TM*TN) == threadblock size >= NT * CORES_PER_CLUSTER
// * Combining BM * BK >= (BM*BN) / (TM*TN) == threadblock yields
// BM <= BK*TM*TN
#define BM 32
#define BN 32
#define BK 32
#define BM 64
#define BN 64
#define BK 64
#define WM 16
#define WN 8
#define TCM 8
@@ -42,29 +42,12 @@
// For correctness, only one of either should be 1. To model the case where
// the entire A matrix is already stored transposed in GMEM ("TN" kernel), set
// both to 0.
#define TRANSPOSE_AT_PRODUCE 0
#define TRANSPOSE_AT_PRODUCE 1
#define TRANSPOSE_AT_CONSUME 0
// GMEM_COALESCED sets bank conflict-free accesses for
// 1: GMEM loads of A matrix
// 0: SMEM stores of A matrix
#define GMEM_COALESCED_A 1
#define GEMMINI_DMA 0
#if SMEM_SIZE != 0x4000
#error Currently only supports 16K spad
#endif
#define SMEM_ADDR_Q0 ((float * const) 0xff000000)
#define SMEM_ADDR_Q1 ((float * const) 0xff001000)
#define SMEM_ADDR_Q2 ((float * const) 0xff002000)
#define SMEM_ADDR_Q3 ((float * const) 0xff003000)
#define SPAD_ADDR_Q0 0x0
#define SPAD_ADDR_Q1 0x80
#define SPAD_ADDR_Q2 0x100
#define SPAD_ADDR_Q3 0x180
// FIXME: NUM_THREADS and NUM_WARPS hardcoded
#if ((BM * BN / ELEM_PER_THREAD) > (CORES_PER_CLUSTER * 8 * 8))
#error "threadblock size too big for cluster"
#endif
inline constexpr void map_operand_32lanes(const int tid, int &row, int &col) {
const int tg = tid / 4;