Merge branch 'new-cisc' into kernels-asplos-ae
This commit is contained in:
@@ -4,6 +4,8 @@
|
|||||||
#error INCLUDE GEMMINI.H FIRST
|
#error INCLUDE GEMMINI.H FIRST
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
|
/* shared memory constants and helpers */
|
||||||
|
/* =================================== */
|
||||||
#define SMEM_BASE 0xff000000
|
#define SMEM_BASE 0xff000000
|
||||||
// 16KB
|
// 16KB
|
||||||
// #define SMEM_SIZE 0x4000
|
// #define SMEM_SIZE 0x4000
|
||||||
@@ -17,46 +19,16 @@
|
|||||||
#define SMEM_MASK (SMEM_SIZE - 1)
|
#define SMEM_MASK (SMEM_SIZE - 1)
|
||||||
#define SMEM_ADDR_END (SMEM_BASE + SMEM_SIZE)
|
#define SMEM_ADDR_END (SMEM_BASE + SMEM_SIZE)
|
||||||
|
|
||||||
static size_t gemmini_tile_idx[NUM_THREADS * NUM_WARPS * NUM_CORES * NUM_CLUSTERS] = {0};
|
|
||||||
|
|
||||||
#define HW_TID() ({uint32_t gtid; asm volatile ("csrr %0, mhartid" : "=r" (gtid)); gtid;})
|
|
||||||
#define use_gemmini(i) {gemmini_tile_idx[HW_TID()] = (i);}
|
|
||||||
#define GEMMINI_TILE_IDX() (gemmini_tile_idx[HW_TID()])
|
|
||||||
#define GEMMINI_CTRL (SMEM_BASE + SMEM_SIZE + 0x3000 + 0x100 * GEMMINI_TILE_IDX())
|
|
||||||
#define GEMMINI_CISC_IMM(x, i) ((x) + 32 * (i))
|
|
||||||
|
|
||||||
#define SPAD_BASE 0x0
|
#define SPAD_BASE 0x0
|
||||||
#define SPAD_ROW_SIZE (DIM * sizeof(elem_t))
|
#define SPAD_ROW_SIZE (DIM * sizeof(elem_t))
|
||||||
#define SPAD_NUM_ROWS (SMEM_SIZE / SPAD_ROW_SIZE)
|
#define SPAD_NUM_ROWS (SMEM_SIZE / SPAD_ROW_SIZE)
|
||||||
#define SPAD_MASK (SPAD_NUM_ROWS - 1)
|
#define SPAD_MASK (SPAD_NUM_ROWS - 1)
|
||||||
|
|
||||||
#define PRINT_BUF ((char *) (SMEM_ADDR_END))
|
#define PRINT_BUF ((char *) (SMEM_ADDR_END))
|
||||||
#define GEMMINI_RS1_ADDR (GEMMINI_CTRL + 0x10)
|
#define HW_TID() ({uint32_t gtid; asm volatile ("csrr %0, mhartid" : "=r" (gtid)); gtid;})
|
||||||
#define GEMMINI_RS2_ADDR (GEMMINI_CTRL + 0x18)
|
|
||||||
#define GEMMINI_INST_ADDR (GEMMINI_CTRL + 0x0)
|
|
||||||
#define GEMMINI_BUSY_ADDR (GEMMINI_CTRL + 0x20)
|
|
||||||
|
|
||||||
#define SMEM_TO_SPAD(smem_addr) (SPAD_BASE + ((smem_addr) & SMEM_MASK) / SPAD_ROW_SIZE)
|
#define SMEM_TO_SPAD(smem_addr) (SPAD_BASE + ((smem_addr) & SMEM_MASK) / SPAD_ROW_SIZE)
|
||||||
#define SPAD_TO_SMEM(spad_addr) (SMEM_BASE + ((spad_addr) & SPAD_MASK) * SPAD_ROW_SIZE)
|
#define SPAD_TO_SMEM(spad_addr) (SMEM_BASE + ((spad_addr) & SPAD_MASK) * SPAD_ROW_SIZE)
|
||||||
|
|
||||||
|
|
||||||
// CISC instructions:
|
|
||||||
// 0, 1, 2: tile-sized matmuls
|
|
||||||
// 0: k = 0, no accumulation
|
|
||||||
// 1: k % 2 = 0, buffer regions 0
|
|
||||||
// 2: k % 2 = 1, buffer regions 1
|
|
||||||
// 8, 9, 10, 11: memory ops
|
|
||||||
// 8: tile-sized move-in stride
|
|
||||||
// 9: tile-sized move-out
|
|
||||||
// 10: tile-sized move-in, buffer regions 0
|
|
||||||
// 11: tile-sized move-in, buffer regions 1
|
|
||||||
// bits [4:0] is the opcode
|
|
||||||
// bits [7:5] is the target gemmini id, zero-indexed
|
|
||||||
// #define GEMMINI_CISC_CMD_I(x) asm("csrwi 0xacc, %0" :: "i" (x))
|
|
||||||
#define GEMMINI_CISC_CMD_I(x) asm("csrw 0xacc, %0" :: "r" (x))
|
|
||||||
#define GEMMINI_CISC_CMD_R(x) asm("csrw 0xacc, %0" :: "r" (x))
|
|
||||||
#define GEMMINI_STATUS() ({uint32_t status; asm volatile ("csrr %0, 0xacc" : "=r" (status)); status;})
|
|
||||||
|
|
||||||
// convert normal matrix i,j into tiled smem offset
|
// convert normal matrix i,j into tiled smem offset
|
||||||
// top_in_tiles = i / DIM
|
// top_in_tiles = i / DIM
|
||||||
// left_in_tiles = j / DIM
|
// left_in_tiles = j / DIM
|
||||||
@@ -65,11 +37,18 @@ static size_t gemmini_tile_idx[NUM_THREADS * NUM_WARPS * NUM_CORES * NUM_CLUSTER
|
|||||||
#define SMEM_MAT_OFFSET(i, j, J) \
|
#define SMEM_MAT_OFFSET(i, j, J) \
|
||||||
(((i) / DIM * (J) / DIM + (j) / DIM) * DIM * DIM + ((i) % DIM) * DIM + ((j) % DIM))
|
(((i) / DIM * (J) / DIM + (j) / DIM) * DIM * DIM + ((i) % DIM) * DIM + ((j) % DIM))
|
||||||
|
|
||||||
// #define fence() { for (int i = 0; i < 10; i++) *((volatile uint32_t *) (0xFFFF0000)) = 0xdeadbeef; }
|
/* gemmini mmio interface */
|
||||||
#undef gemmini_fence
|
/* ====================== */
|
||||||
//#define gemmini_fence() { while (GEMMINI_STATUS()); }
|
static size_t gemmini_tile_idx[NUM_THREADS * NUM_WARPS * NUM_CORES * NUM_CLUSTERS] = {0};
|
||||||
#define gemmini_fence() { while (*((volatile uint32_t *) GEMMINI_BUSY_ADDR)) asm volatile ("nop"); }
|
#define use_gemmini(i) {gemmini_tile_idx[HW_TID()] = (i);}
|
||||||
|
#define GEMMINI_TILE_IDX() (gemmini_tile_idx[HW_TID()])
|
||||||
|
#define GEMMINI_CISC_IMM(x, i) ((x) + 32 * (i))
|
||||||
|
#define GEMMINI_CTRL (SMEM_BASE + SMEM_SIZE + 0x3000 + 0x100 * GEMMINI_TILE_IDX())
|
||||||
|
#define GEMMINI_RS1_ADDR (GEMMINI_CTRL + 0x10)
|
||||||
|
#define GEMMINI_RS2_ADDR (GEMMINI_CTRL + 0x18)
|
||||||
|
#define GEMMINI_INST_ADDR (GEMMINI_CTRL + 0x0)
|
||||||
|
#define GEMMINI_BUSY_ADDR (GEMMINI_CTRL + 0x20)
|
||||||
|
#define GEMMINI_OCCUPANCY_ADDR (GEMMINI_CTRL + 0x28)
|
||||||
#undef ROCC_INSTRUCTION_RS1_RS2
|
#undef ROCC_INSTRUCTION_RS1_RS2
|
||||||
#define ROCC_INSTRUCTION_RS1_RS2(x, rs1, rs2, funct) { \
|
#define ROCC_INSTRUCTION_RS1_RS2(x, rs1, rs2, funct) { \
|
||||||
*((volatile uint64_t *) GEMMINI_RS1_ADDR) = (rs1); \
|
*((volatile uint64_t *) GEMMINI_RS1_ADDR) = (rs1); \
|
||||||
@@ -77,6 +56,8 @@ static size_t gemmini_tile_idx[NUM_THREADS * NUM_WARPS * NUM_CORES * NUM_CLUSTER
|
|||||||
*((volatile uint32_t*) GEMMINI_INST_ADDR) = (0x7B) | (0 << 7) | (3 << 12) | (1 << 15) | (2 << 20) | ((funct) << 25); \
|
*((volatile uint32_t*) GEMMINI_INST_ADDR) = (0x7B) | (0 << 7) | (3 << 12) | (1 << 15) | (2 << 20) | ((funct) << 25); \
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/* additional intrinsics */
|
||||||
|
/* ===================== */
|
||||||
#define loop_matmul_skips(skip_lda, skip_ldb, skip_ldd, skip_ex, skip_stc) \
|
#define loop_matmul_skips(skip_lda, skip_ldb, skip_ldd, skip_ex, skip_stc) \
|
||||||
(((skip_lda) | ((skip_ldb) << 1) | ((skip_ldd) << 2) | ((skip_ex) << 3) | ((skip_stc) << 4)) << 3)
|
(((skip_lda) | ((skip_ldb) << 1) | ((skip_ldd) << 2) | ((skip_ex) << 3) | ((skip_stc) << 4)) << 3)
|
||||||
|
|
||||||
@@ -85,16 +66,34 @@ static size_t gemmini_tile_idx[NUM_THREADS * NUM_WARPS * NUM_CORES * NUM_CLUSTER
|
|||||||
gemmini_loop_ws_spad(I, J, K, pad_I, pad_J, pad_K, A_sp_addr_start, (B_sp_addr_start) + (K) * (J) * DIM, NULL, \
|
gemmini_loop_ws_spad(I, J, K, pad_I, pad_J, pad_K, A_sp_addr_start, (B_sp_addr_start) + (K) * (J) * DIM, NULL, \
|
||||||
C_dst_sp_addr_start, a_transpose, b_transpose, full_C, low_D, acc, act, 0, 0, false, skips)
|
C_dst_sp_addr_start, a_transpose, b_transpose, full_C, low_D, acc, act, 0, 0, false, skips)
|
||||||
|
|
||||||
|
#define gemmini_status() ({uint32_t status; asm volatile ("csrr %0, 0xacc" : "=r" (status)); status;})
|
||||||
|
|
||||||
|
#undef gemmini_fence
|
||||||
|
//#define gemmini_fence() { while (gemmini_status()); }
|
||||||
|
#define gemmini_fence() { while (*((volatile uint32_t *) GEMMINI_BUSY_ADDR)) asm volatile ("nop"); }
|
||||||
|
|
||||||
|
#define virgo_fence(n) { while (*((volatile uint32_t *) GEMMINI_OCCUPANCY_ADDR) > n) asm volatile ("nop"); }
|
||||||
|
|
||||||
|
/* cisc instructions */
|
||||||
|
/* ================= */
|
||||||
|
|
||||||
|
// bits [4:0] is the opcode
|
||||||
|
// bits [7:5] is the target gemmini id, zero-indexed
|
||||||
|
// #define GEMMINI_CISC_CMD_I(x) asm("csrwi 0xacc, %0" :: "i" (x))
|
||||||
|
#define GEMMINI_CISC_CMD_I(x) asm("csrw 0xacc, %0" :: "r" (x)) // use registers even for immediate calls for now
|
||||||
|
#define GEMMINI_CISC_CMD_R(x) asm("csrw 0xacc, %0" :: "r" (x))
|
||||||
|
|
||||||
#define GEMMINI_CISC_COMPUTE_HEXADECILES 0
|
#define GEMMINI_CISC_COMPUTE_HEXADECILES 0
|
||||||
#define GEMMINI_CISC_COMPUTE_AND_STORE_TO_SPAD 1
|
#define GEMMINI_CISC_COMPUTE_AND_STORE_TO_SPAD 1
|
||||||
|
#define GEMMINI_CISC_MANUAL 2
|
||||||
#define GEMMINI_CISC_SET_AB_STRIDE 8
|
#define GEMMINI_CISC_SET_AB_STRIDE 8
|
||||||
#define GEMMINI_CISC_STORE_TO_SPAD 9
|
#define GEMMINI_CISC_STORE_TO_SPAD 9
|
||||||
#define GEMMINI_CISC_LOAD_TO_HEXADECILES 10
|
#define GEMMINI_CISC_LOAD_TO_HEXADECILES 10
|
||||||
#define GEMMINI_CISC_SET_DC_STRIDE 11
|
#define GEMMINI_CISC_SET_DC_STRIDE 11
|
||||||
#define GEMMINI_CISC_STORE_TO_GMEM 12
|
#define GEMMINI_CISC_STORE_TO_GMEM 12
|
||||||
|
|
||||||
// cisc instruction wrappers
|
/* high level virgo routines */
|
||||||
|
/* ========================= */
|
||||||
inline void gemmini_tile_load_ab(const elem_t * const a_addr, const elem_t * const b_addr,
|
inline void gemmini_tile_load_ab(const elem_t * const a_addr, const elem_t * const b_addr,
|
||||||
const uint32_t a_hexadecile, const uint32_t b_hexadecile,
|
const uint32_t a_hexadecile, const uint32_t b_hexadecile,
|
||||||
const uint32_t tile_idx_i, const uint32_t tile_idx_j, const uint32_t tile_idx_k,
|
const uint32_t tile_idx_i, const uint32_t tile_idx_j, const uint32_t tile_idx_k,
|
||||||
@@ -142,6 +141,11 @@ inline void gemmini_tile_store_c_gmem(elem_t * const c_addr,
|
|||||||
inline void gemmini_tile_store_c_spad(const uint32_t c_hexadecile) {
|
inline void gemmini_tile_store_c_spad(const uint32_t c_hexadecile) {
|
||||||
GEMMINI_CISC_CMD_R(((uint32_t) (c_hexadecile << 8)) | GEMMINI_CISC_STORE_TO_SPAD);
|
GEMMINI_CISC_CMD_R(((uint32_t) (c_hexadecile << 8)) | GEMMINI_CISC_STORE_TO_SPAD);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
inline void gemmini_manual_job() {
|
||||||
|
GEMMINI_CISC_CMD_I(GEMMINI_CISC_MANUAL);
|
||||||
|
}
|
||||||
|
|
||||||
/* inline static void sp_tiled_matmul_full_spad_ws(const uint32_t A_sp_addr_start, const uint32_t B_sp_addr_start,
|
/* inline static void sp_tiled_matmul_full_spad_ws(const uint32_t A_sp_addr_start, const uint32_t B_sp_addr_start,
|
||||||
const uint32_t D_sp_addr_start, const uint32_t C_dst_sp_addr_start,
|
const uint32_t D_sp_addr_start, const uint32_t C_dst_sp_addr_start,
|
||||||
size_t I, size_t J, size_t K, size_t pad_I, size_t pad_J, size_t pad_K,
|
size_t I, size_t J, size_t K, size_t pad_I, size_t pad_J, size_t pad_K,
|
||||||
|
|||||||
BIN
tests/regression/sgemm_gemmini_dma/args/1024
Normal file
BIN
tests/regression/sgemm_gemmini_dma/args/1024
Normal file
Binary file not shown.
BIN
tests/regression/sgemm_gemmini_dma/args/128
Normal file
BIN
tests/regression/sgemm_gemmini_dma/args/128
Normal file
Binary file not shown.
BIN
tests/regression/sgemm_gemmini_dma/args/256
Normal file
BIN
tests/regression/sgemm_gemmini_dma/args/256
Normal file
Binary file not shown.
BIN
tests/regression/sgemm_gemmini_dma/args/512
Normal file
BIN
tests/regression/sgemm_gemmini_dma/args/512
Normal file
Binary file not shown.
11
tests/regression/sgemm_gemmini_dma/compile_ampere.sh
Executable file
11
tests/regression/sgemm_gemmini_dma/compile_ampere.sh
Executable file
@@ -0,0 +1,11 @@
|
|||||||
|
rm kernel.radiance.elf
|
||||||
|
rm -rf binaries
|
||||||
|
mkdir binaries
|
||||||
|
for a in args/*; do
|
||||||
|
cp -f $a args.bin
|
||||||
|
aa=$(basename "$a")
|
||||||
|
cp -f input.a/"$aa" input.a.bin
|
||||||
|
cp -f input.b/"$aa" input.b.bin
|
||||||
|
make > /dev/null
|
||||||
|
mv kernel.radiance.elf binaries/gemmini_fp16dma"$aa".elf
|
||||||
|
done
|
||||||
11
tests/regression/sgemm_gemmini_dma/compile_hopper.sh
Executable file
11
tests/regression/sgemm_gemmini_dma/compile_hopper.sh
Executable file
@@ -0,0 +1,11 @@
|
|||||||
|
rm kernel.radiance.elf
|
||||||
|
rm -rf binaries
|
||||||
|
mkdir binaries
|
||||||
|
for a in args/*; do
|
||||||
|
cp -f $a args.bin
|
||||||
|
aa=$(basename "$a")
|
||||||
|
cp -f input.a/"$aa" input.a.bin
|
||||||
|
cp -f input.b/"$aa" input.b.bin
|
||||||
|
make > /dev/null
|
||||||
|
mv kernel.radiance.elf binaries/gemmini_hopper_dma"$aa".elf
|
||||||
|
done
|
||||||
@@ -15,20 +15,20 @@ def truncated_matrix_multiplication(matrix_a, matrix_b, size):
|
|||||||
result = np.matmul(truncated_a, truncated_b)
|
result = np.matmul(truncated_a, truncated_b)
|
||||||
return result.astype(np.float16)
|
return result.astype(np.float16)
|
||||||
|
|
||||||
# Generate the 512x512 matrices
|
|
||||||
size = 512
|
|
||||||
matrix_a = generate_fp16_matrix(size)
|
|
||||||
matrix_b = generate_fp16_matrix(size)
|
|
||||||
|
|
||||||
# Save the operand matrices to binary files
|
|
||||||
save_matrix_to_bin("input.a.bin", matrix_a)
|
|
||||||
save_matrix_to_bin("input.b.bin", matrix_b)
|
|
||||||
|
|
||||||
# Generate and save the reference matrices for 128x128, 256x256, and 512x512 sizes
|
# Generate and save the reference matrices for 128x128, 256x256, and 512x512 sizes
|
||||||
sizes = [128, 256, 512]
|
sizes = [128, 256, 512, 1024]
|
||||||
for s in sizes:
|
for s in sizes:
|
||||||
|
np.random.seed(0)
|
||||||
|
matrix_a = generate_fp16_matrix(s)
|
||||||
|
matrix_b = generate_fp16_matrix(s)
|
||||||
|
|
||||||
|
# Save the operand matrices to binary files
|
||||||
|
save_matrix_to_bin("input.a.bin", matrix_a)
|
||||||
|
save_matrix_to_bin(f"input.a/{s}", matrix_a)
|
||||||
|
save_matrix_to_bin("input.b.bin", matrix_b)
|
||||||
|
save_matrix_to_bin(f"input.b/{s}", matrix_b)
|
||||||
|
|
||||||
ref_matrix = truncated_matrix_multiplication(matrix_a, matrix_b, s)
|
ref_matrix = truncated_matrix_multiplication(matrix_a, matrix_b, s)
|
||||||
print(ref_matrix)
|
|
||||||
save_matrix_to_bin(f"ref{s}.bin", ref_matrix)
|
save_matrix_to_bin(f"ref{s}.bin", ref_matrix)
|
||||||
|
|
||||||
print("All files generated successfully.")
|
print("All files generated successfully.")
|
||||||
|
|||||||
@@ -107,7 +107,7 @@ void thread_block_matmul_gemmini(kernel_arg_t *__UNIFORM__ arg,
|
|||||||
for (uint32_t tile_j = 0; tile_j < num_tiles_n; tile_j += 1) {
|
for (uint32_t tile_j = 0; tile_j < num_tiles_n; tile_j += 1) {
|
||||||
for (uint32_t tile_k = 0; tile_k < num_tiles_k; tile_k += 1) {
|
for (uint32_t tile_k = 0; tile_k < num_tiles_k; tile_k += 1) {
|
||||||
uint32_t a_hexadecile = (tile_k & 1) << 2;
|
uint32_t a_hexadecile = (tile_k & 1) << 2;
|
||||||
uint32_t b_hexadecile = a_hexadecile + 8;
|
uint32_t b_hexadecile = a_hexadecile + 11;
|
||||||
gemmini_tile_load_ab(A, B,
|
gemmini_tile_load_ab(A, B,
|
||||||
a_hexadecile, b_hexadecile, tile_i, tile_j, tile_k,
|
a_hexadecile, b_hexadecile, tile_i, tile_j, tile_k,
|
||||||
dim_m, dim_n, dim_k, TILE_M, TILE_N, TILE_K);
|
dim_m, dim_n, dim_k, TILE_M, TILE_N, TILE_K);
|
||||||
|
|||||||
Reference in New Issue
Block a user