Merge branch 'new-cisc' into kernels-asplos-ae

This commit is contained in:
Hansung Kim
2025-01-28 21:18:12 -08:00
9 changed files with 75 additions and 49 deletions

View File

@@ -4,6 +4,8 @@
#error INCLUDE GEMMINI.H FIRST
#endif
/* shared memory constants and helpers */
/* =================================== */
#define SMEM_BASE 0xff000000
// 16KB
// #define SMEM_SIZE 0x4000
@@ -17,46 +19,16 @@
#define SMEM_MASK (SMEM_SIZE - 1)
#define SMEM_ADDR_END (SMEM_BASE + SMEM_SIZE)
static size_t gemmini_tile_idx[NUM_THREADS * NUM_WARPS * NUM_CORES * NUM_CLUSTERS] = {0};
#define HW_TID() ({uint32_t gtid; asm volatile ("csrr %0, mhartid" : "=r" (gtid)); gtid;})
#define use_gemmini(i) {gemmini_tile_idx[HW_TID()] = (i);}
#define GEMMINI_TILE_IDX() (gemmini_tile_idx[HW_TID()])
#define GEMMINI_CTRL (SMEM_BASE + SMEM_SIZE + 0x3000 + 0x100 * GEMMINI_TILE_IDX())
#define GEMMINI_CISC_IMM(x, i) ((x) + 32 * (i))
#define SPAD_BASE 0x0
#define SPAD_ROW_SIZE (DIM * sizeof(elem_t))
#define SPAD_NUM_ROWS (SMEM_SIZE / SPAD_ROW_SIZE)
#define SPAD_MASK (SPAD_NUM_ROWS - 1)
#define PRINT_BUF ((char *) (SMEM_ADDR_END))
#define GEMMINI_RS1_ADDR (GEMMINI_CTRL + 0x10)
#define GEMMINI_RS2_ADDR (GEMMINI_CTRL + 0x18)
#define GEMMINI_INST_ADDR (GEMMINI_CTRL + 0x0)
#define GEMMINI_BUSY_ADDR (GEMMINI_CTRL + 0x20)
#define HW_TID() ({uint32_t gtid; asm volatile ("csrr %0, mhartid" : "=r" (gtid)); gtid;})
#define SMEM_TO_SPAD(smem_addr) (SPAD_BASE + ((smem_addr) & SMEM_MASK) / SPAD_ROW_SIZE)
#define SPAD_TO_SMEM(spad_addr) (SMEM_BASE + ((spad_addr) & SPAD_MASK) * SPAD_ROW_SIZE)
// CISC instructions:
// 0, 1, 2: tile-sized matmuls
// 0: k = 0, no accumulation
// 1: k % 2 = 0, buffer regions 0
// 2: k % 2 = 1, buffer regions 1
// 8, 9, 10, 11: memory ops
// 8: tile-sized move-in stride
// 9: tile-sized move-out
// 10: tile-sized move-in, buffer regions 0
// 11: tile-sized move-in, buffer regions 1
// bits [4:0] is the opcode
// bits [7:5] is the target gemmini id, zero-indexed
// #define GEMMINI_CISC_CMD_I(x) asm("csrwi 0xacc, %0" :: "i" (x))
#define GEMMINI_CISC_CMD_I(x) asm("csrw 0xacc, %0" :: "r" (x))
#define GEMMINI_CISC_CMD_R(x) asm("csrw 0xacc, %0" :: "r" (x))
#define GEMMINI_STATUS() ({uint32_t status; asm volatile ("csrr %0, 0xacc" : "=r" (status)); status;})
// convert normal matrix i,j into tiled smem offset
// top_in_tiles = i / DIM
// left_in_tiles = j / DIM
@@ -65,11 +37,18 @@ static size_t gemmini_tile_idx[NUM_THREADS * NUM_WARPS * NUM_CORES * NUM_CLUSTER
#define SMEM_MAT_OFFSET(i, j, J) \
(((i) / DIM * (J) / DIM + (j) / DIM) * DIM * DIM + ((i) % DIM) * DIM + ((j) % DIM))
// #define fence() { for (int i = 0; i < 10; i++) *((volatile uint32_t *) (0xFFFF0000)) = 0xdeadbeef; }
#undef gemmini_fence
//#define gemmini_fence() { while (GEMMINI_STATUS()); }
#define gemmini_fence() { while (*((volatile uint32_t *) GEMMINI_BUSY_ADDR)) asm volatile ("nop"); }
/* gemmini mmio interface */
/* ====================== */
static size_t gemmini_tile_idx[NUM_THREADS * NUM_WARPS * NUM_CORES * NUM_CLUSTERS] = {0};
#define use_gemmini(i) {gemmini_tile_idx[HW_TID()] = (i);}
#define GEMMINI_TILE_IDX() (gemmini_tile_idx[HW_TID()])
#define GEMMINI_CISC_IMM(x, i) ((x) + 32 * (i))
#define GEMMINI_CTRL (SMEM_BASE + SMEM_SIZE + 0x3000 + 0x100 * GEMMINI_TILE_IDX())
#define GEMMINI_RS1_ADDR (GEMMINI_CTRL + 0x10)
#define GEMMINI_RS2_ADDR (GEMMINI_CTRL + 0x18)
#define GEMMINI_INST_ADDR (GEMMINI_CTRL + 0x0)
#define GEMMINI_BUSY_ADDR (GEMMINI_CTRL + 0x20)
#define GEMMINI_OCCUPANCY_ADDR (GEMMINI_CTRL + 0x28)
#undef ROCC_INSTRUCTION_RS1_RS2
#define ROCC_INSTRUCTION_RS1_RS2(x, rs1, rs2, funct) { \
*((volatile uint64_t *) GEMMINI_RS1_ADDR) = (rs1); \
@@ -77,6 +56,8 @@ static size_t gemmini_tile_idx[NUM_THREADS * NUM_WARPS * NUM_CORES * NUM_CLUSTER
*((volatile uint32_t*) GEMMINI_INST_ADDR) = (0x7B) | (0 << 7) | (3 << 12) | (1 << 15) | (2 << 20) | ((funct) << 25); \
}
/* additional intrinsics */
/* ===================== */
#define loop_matmul_skips(skip_lda, skip_ldb, skip_ldd, skip_ex, skip_stc) \
(((skip_lda) | ((skip_ldb) << 1) | ((skip_ldd) << 2) | ((skip_ex) << 3) | ((skip_stc) << 4)) << 3)
@@ -85,16 +66,34 @@ static size_t gemmini_tile_idx[NUM_THREADS * NUM_WARPS * NUM_CORES * NUM_CLUSTER
gemmini_loop_ws_spad(I, J, K, pad_I, pad_J, pad_K, A_sp_addr_start, (B_sp_addr_start) + (K) * (J) * DIM, NULL, \
C_dst_sp_addr_start, a_transpose, b_transpose, full_C, low_D, acc, act, 0, 0, false, skips)
#define gemmini_status() ({uint32_t status; asm volatile ("csrr %0, 0xacc" : "=r" (status)); status;})
#undef gemmini_fence
//#define gemmini_fence() { while (gemmini_status()); }
#define gemmini_fence() { while (*((volatile uint32_t *) GEMMINI_BUSY_ADDR)) asm volatile ("nop"); }
#define virgo_fence(n) { while (*((volatile uint32_t *) GEMMINI_OCCUPANCY_ADDR) > n) asm volatile ("nop"); }
/* cisc instructions */
/* ================= */
// bits [4:0] is the opcode
// bits [7:5] is the target gemmini id, zero-indexed
// #define GEMMINI_CISC_CMD_I(x) asm("csrwi 0xacc, %0" :: "i" (x))
#define GEMMINI_CISC_CMD_I(x) asm("csrw 0xacc, %0" :: "r" (x)) // use registers even for immediate calls for now
#define GEMMINI_CISC_CMD_R(x) asm("csrw 0xacc, %0" :: "r" (x))
#define GEMMINI_CISC_COMPUTE_HEXADECILES 0
#define GEMMINI_CISC_COMPUTE_AND_STORE_TO_SPAD 1
#define GEMMINI_CISC_MANUAL 2
#define GEMMINI_CISC_SET_AB_STRIDE 8
#define GEMMINI_CISC_STORE_TO_SPAD 9
#define GEMMINI_CISC_LOAD_TO_HEXADECILES 10
#define GEMMINI_CISC_SET_DC_STRIDE 11
#define GEMMINI_CISC_STORE_TO_GMEM 12
// cisc instruction wrappers
/* high level virgo routines */
/* ========================= */
inline void gemmini_tile_load_ab(const elem_t * const a_addr, const elem_t * const b_addr,
const uint32_t a_hexadecile, const uint32_t b_hexadecile,
const uint32_t tile_idx_i, const uint32_t tile_idx_j, const uint32_t tile_idx_k,
@@ -142,6 +141,11 @@ inline void gemmini_tile_store_c_gmem(elem_t * const c_addr,
inline void gemmini_tile_store_c_spad(const uint32_t c_hexadecile) {
GEMMINI_CISC_CMD_R(((uint32_t) (c_hexadecile << 8)) | GEMMINI_CISC_STORE_TO_SPAD);
}
inline void gemmini_manual_job() {
GEMMINI_CISC_CMD_I(GEMMINI_CISC_MANUAL);
}
/* inline static void sp_tiled_matmul_full_spad_ws(const uint32_t A_sp_addr_start, const uint32_t B_sp_addr_start,
const uint32_t D_sp_addr_start, const uint32_t C_dst_sp_addr_start,
size_t I, size_t J, size_t K, size_t pad_I, size_t pad_J, size_t pad_K,

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

View File

@@ -0,0 +1,11 @@
rm kernel.radiance.elf
rm -rf binaries
mkdir binaries
for a in args/*; do
cp -f $a args.bin
aa=$(basename "$a")
cp -f input.a/"$aa" input.a.bin
cp -f input.b/"$aa" input.b.bin
make > /dev/null
mv kernel.radiance.elf binaries/gemmini_fp16dma"$aa".elf
done

View File

@@ -0,0 +1,11 @@
rm kernel.radiance.elf
rm -rf binaries
mkdir binaries
for a in args/*; do
cp -f $a args.bin
aa=$(basename "$a")
cp -f input.a/"$aa" input.a.bin
cp -f input.b/"$aa" input.b.bin
make > /dev/null
mv kernel.radiance.elf binaries/gemmini_hopper_dma"$aa".elf
done

View File

@@ -15,20 +15,20 @@ def truncated_matrix_multiplication(matrix_a, matrix_b, size):
result = np.matmul(truncated_a, truncated_b)
return result.astype(np.float16)
# Generate the 512x512 matrices
size = 512
matrix_a = generate_fp16_matrix(size)
matrix_b = generate_fp16_matrix(size)
# Save the operand matrices to binary files
save_matrix_to_bin("input.a.bin", matrix_a)
save_matrix_to_bin("input.b.bin", matrix_b)
# Generate and save the reference matrices for 128x128, 256x256, and 512x512 sizes
sizes = [128, 256, 512]
sizes = [128, 256, 512, 1024]
for s in sizes:
np.random.seed(0)
matrix_a = generate_fp16_matrix(s)
matrix_b = generate_fp16_matrix(s)
# Save the operand matrices to binary files
save_matrix_to_bin("input.a.bin", matrix_a)
save_matrix_to_bin(f"input.a/{s}", matrix_a)
save_matrix_to_bin("input.b.bin", matrix_b)
save_matrix_to_bin(f"input.b/{s}", matrix_b)
ref_matrix = truncated_matrix_multiplication(matrix_a, matrix_b, s)
print(ref_matrix)
save_matrix_to_bin(f"ref{s}.bin", ref_matrix)
print("All files generated successfully.")

View File

@@ -107,7 +107,7 @@ void thread_block_matmul_gemmini(kernel_arg_t *__UNIFORM__ arg,
for (uint32_t tile_j = 0; tile_j < num_tiles_n; tile_j += 1) {
for (uint32_t tile_k = 0; tile_k < num_tiles_k; tile_k += 1) {
uint32_t a_hexadecile = (tile_k & 1) << 2;
uint32_t b_hexadecile = a_hexadecile + 8;
uint32_t b_hexadecile = a_hexadecile + 11;
gemmini_tile_load_ab(A, B,
a_hexadecile, b_hexadecile, tile_i, tile_j, tile_k,
dim_m, dim_n, dim_k, TILE_M, TILE_N, TILE_K);