make non dma gemmini use 64x64 tile size

This commit is contained in:
Richard Yan
2024-06-19 17:45:01 -07:00
parent 095ccfd79a
commit c06cc40e59
2 changed files with 79 additions and 33 deletions

View File

@@ -6,25 +6,25 @@
#include "include/gemmini.h" #include "include/gemmini.h"
#include "gemmini_mmio.h" #include "gemmini_mmio.h"
#define TILE_M 32 #define TILE_M 64
#define TILE_N 32 #define TILE_N 64
#define TILE_K 32 #define TILE_K 64
#define TILE_MN 1024 #define TILE_MN 4096
#define TILE_MK 1024 #define TILE_MK 4096
#define TILE_NK 1024 #define TILE_NK 4096
#define NUM_CLUSTERS 1 #define NUM_CLUSTERS 1
#define NUM_THREADS_IN_CLUSTER 128 #define NUM_THREADS_IN_CLUSTER 256
#define SMEM_ADDR_Q0 ((float * const) 0xff000000) #define SMEM_ADDR_Q0 ((float * const) 0xff000000)
#define SMEM_ADDR_Q1 ((float * const) 0xff001000) #define SMEM_ADDR_Q1 ((float * const) 0xff004000)
#define SMEM_ADDR_Q2 ((float * const) 0xff002000) #define SMEM_ADDR_Q2 ((float * const) 0xff008000)
#define SMEM_ADDR_Q3 ((float * const) 0xff003000) #define SMEM_ADDR_Q3 ((float * const) 0xff00c000)
#define SPAD_ADDR_Q0 0x0 #define SPAD_ADDR_Q0 0x0
#define SPAD_ADDR_Q1 0x80 #define SPAD_ADDR_Q1 0x200
#define SPAD_ADDR_Q2 0x100 #define SPAD_ADDR_Q2 0x400
#define SPAD_ADDR_Q3 0x180 #define SPAD_ADDR_Q3 0x600
#define SPAD_ADDR_Q4 0x200 #define SPAD_ADDR_Q4 0x800
#define HARDCODE #define HARDCODE
#define REGBLOCK #define REGBLOCK
@@ -61,7 +61,6 @@ inline void threadblock_barrier(unsigned int barrier_id, unsigned int count) {
void thread_block_matmul_gemmini(kernel_arg_t *__UNIFORM__ arg, void thread_block_matmul_gemmini(kernel_arg_t *__UNIFORM__ arg,
const uint32_t threadblock_id, const uint32_t threadblock_id,
const uint32_t tid_in_threadblock) { const uint32_t tid_in_threadblock) {
__asm__("matmul_start:");
const float * const A = (const float * const) arg->addr_a; const float * const A = (const float * const) arg->addr_a;
const float * const B = (const float * const) arg->addr_b; const float * const B = (const float * const) arg->addr_b;
float * const C = (float * const) arg->addr_c; float * const C = (float * const) arg->addr_c;
@@ -71,7 +70,9 @@ void thread_block_matmul_gemmini(kernel_arg_t *__UNIFORM__ arg,
// gemmini_extended_config_ex(dataflow, act & 3, 0, 1, a_transpose, b_transpose); // gemmini_extended_config_ex(dataflow, act & 3, 0, 1, a_transpose, b_transpose);
// gemmini_extended_config_st(stride_C * sizeof_C, act & 3, scale); // gemmini_extended_config_st(stride_C * sizeof_C, act & 3, scale);
#ifndef POWER
PRINTF("start\n"); PRINTF("start\n");
#endif
} }
vx_fence(); vx_fence();
@@ -121,9 +122,7 @@ void thread_block_matmul_gemmini(kernel_arg_t *__UNIFORM__ arg,
for (uint32_t tile_i = num_tile_rows_per_tb * threadblock_id; for (uint32_t tile_i = num_tile_rows_per_tb * threadblock_id;
tile_i < num_tile_rows_per_tb * (threadblock_id + 1); tile_i < num_tile_rows_per_tb * (threadblock_id + 1);
tile_i += 1) { tile_i += 1) {
__asm__("i_loop:");
for (int tile_j = 0; tile_j < num_tiles_n; tile_j += 1) { for (int tile_j = 0; tile_j < num_tiles_n; tile_j += 1) {
__asm__("j_loop:");
float * const smem_c_tile_start = SMEM_ADDR_Q1; float * const smem_c_tile_start = SMEM_ADDR_Q1;
#ifdef OFFLOAD_ACCUMULATE #ifdef OFFLOAD_ACCUMULATE
float * const smem_acc_tile_start = SMEM_ADDR_Q0 + HW_TID(); float * const smem_acc_tile_start = SMEM_ADDR_Q0 + HW_TID();
@@ -131,15 +130,14 @@ void thread_block_matmul_gemmini(kernel_arg_t *__UNIFORM__ arg,
float * const smem_acc_tile_start = SMEM_ADDR_Q2 + hw_tid; float * const smem_acc_tile_start = SMEM_ADDR_Q2 + hw_tid;
#endif #endif
__asm__("k_loop:");
for (int tile_k = 0; tile_k < num_tiles_k; tile_k += 1) { for (int tile_k = 0; tile_k < num_tiles_k; tile_k += 1) {
// TODO: double buffer // TODO: double buffer
rd_cycles(marker1); rd_cycles(marker1);
#ifdef HARDCODE #ifdef HARDCODE
#if (TILE_MK / NUM_THREADS / NUM_WARPS / CORES_PER_CLUSTER) != 8 // #if (TILE_MK / NUM_THREADS / NUM_WARPS / CORES_PER_CLUSTER) != 8
#error CANNOT UNROLL // #error CANNOT UNROLL
#endif // #endif
constexpr uint32_t every_iter = j1_stride; constexpr uint32_t every_iter = j1_stride;
const uint32_t every_2iters_a = i1_stride * dim_k; const uint32_t every_2iters_a = i1_stride * dim_k;
@@ -228,6 +226,42 @@ void thread_block_matmul_gemmini(kernel_arg_t *__UNIFORM__ arg,
smem_b_tile_start[5 * num_threads_in_cluster] = v1; smem_b_tile_start[5 * num_threads_in_cluster] = v1;
smem_b_tile_start[6 * num_threads_in_cluster] = v2; smem_b_tile_start[6 * num_threads_in_cluster] = v2;
smem_b_tile_start[7 * num_threads_in_cluster] = v3; smem_b_tile_start[7 * num_threads_in_cluster] = v3;
v0 = dram_a_tile_start[every_iter * 0 + every_2iters_a * 4];
v1 = dram_a_tile_start[every_iter * 1 + every_2iters_a * 4];
v2 = dram_a_tile_start[every_iter * 0 + every_2iters_a * 5];
v3 = dram_a_tile_start[every_iter * 1 + every_2iters_a * 5];
smem_a_tile_start[8 * num_threads_in_cluster] = v0;
smem_a_tile_start[9 * num_threads_in_cluster] = v1;
smem_a_tile_start[10 * num_threads_in_cluster] = v2;
smem_a_tile_start[11 * num_threads_in_cluster] = v3;
v0 = dram_b_tile_start[every_iter * 0 + every_2iters_b * 4];
v1 = dram_b_tile_start[every_iter * 1 + every_2iters_b * 4];
v2 = dram_b_tile_start[every_iter * 0 + every_2iters_b * 5];
v3 = dram_b_tile_start[every_iter * 1 + every_2iters_b * 5];
smem_b_tile_start[8 * num_threads_in_cluster] = v0;
smem_b_tile_start[9 * num_threads_in_cluster] = v1;
smem_b_tile_start[10 * num_threads_in_cluster] = v2;
smem_b_tile_start[11 * num_threads_in_cluster] = v3;
v0 = dram_a_tile_start[every_iter * 0 + every_2iters_a * 6];
v1 = dram_a_tile_start[every_iter * 1 + every_2iters_a * 6];
v2 = dram_a_tile_start[every_iter * 0 + every_2iters_a * 7];
v3 = dram_a_tile_start[every_iter * 1 + every_2iters_a * 7];
smem_a_tile_start[12 * num_threads_in_cluster] = v0;
smem_a_tile_start[13 * num_threads_in_cluster] = v1;
smem_a_tile_start[14 * num_threads_in_cluster] = v2;
smem_a_tile_start[15 * num_threads_in_cluster] = v3;
v0 = dram_b_tile_start[every_iter * 0 + every_2iters_b * 6];
v1 = dram_b_tile_start[every_iter * 1 + every_2iters_b * 6];
v2 = dram_b_tile_start[every_iter * 0 + every_2iters_b * 7];
v3 = dram_b_tile_start[every_iter * 1 + every_2iters_b * 7];
smem_b_tile_start[12 * num_threads_in_cluster] = v0;
smem_b_tile_start[13 * num_threads_in_cluster] = v1;
smem_b_tile_start[14 * num_threads_in_cluster] = v2;
smem_b_tile_start[15 * num_threads_in_cluster] = v3;
#endif #endif
} }
#else #else
@@ -398,34 +432,29 @@ void thread_block_matmul_gemmini(kernel_arg_t *__UNIFORM__ arg,
#ifdef OFFLOAD_ACCUMULATE #ifdef OFFLOAD_ACCUMULATE
threadblock_barrier(/*barrier_id=*/0, /*count=*/NUM_WARPS); threadblock_barrier(/*barrier_id=*/0, /*count=*/NUM_WARPS);
rd_cycles(marker6); rd_cycles(marker6);
__asm__("mvout_spad_ser:");
// mvout to scratchpad for activation // mvout to scratchpad for activation
if (HW_TID() == 0) { if (HW_TID() == 0) {
__asm__("mvout_spad:");
// #ifdef DBUF // #ifdef DBUF
// gemmini_fence(); // gemmini_fence();
// #endif // #endif
#ifdef CISC #ifdef CISC
GEMMINI_CISC_CMD_I(9); GEMMINI_CISC_CMD_I(9);
#else #else
ROCC_INSTRUCTION_RS1_RS2(XCUSTOM_ACC, 0, (4ULL << 32) | (4ULL << 16) | 4ULL, k_LOOP_WS_CONFIG_BOUNDS) ROCC_INSTRUCTION_RS1_RS2(XCUSTOM_ACC, 0, (((uint64_t) TILE_M / DIM) << 32) |
(((uint64_t) TILE_K / DIM) << 16) | ((uint64_t) TILE_N / DIM), k_LOOP_WS_CONFIG_BOUNDS)
ROCC_INSTRUCTION_RS1_RS2(XCUSTOM_ACC, 0, 0x278U, k_LOOP_WS) ROCC_INSTRUCTION_RS1_RS2(XCUSTOM_ACC, 0, 0x278U, k_LOOP_WS)
#endif #endif
__asm__("mvout_spad_fence:");
gemmini_fence(); gemmini_fence();
} }
__asm__("mvout_spad_bar:");
threadblock_barrier(/*barrier_id=*/0, /*count=*/NUM_WARPS); threadblock_barrier(/*barrier_id=*/0, /*count=*/NUM_WARPS);
__asm__("end_mvout_spad:");
#endif #endif
rd_cycles(marker7); rd_cycles(marker7);
// move out to dram // move out to dram
__asm__("mvout_dram:");
#ifdef HARDCODE #ifdef HARDCODE
#if (TILE_MN / NUM_THREADS / NUM_WARPS / CORES_PER_CLUSTER) != 8 // #if (TILE_MN / NUM_THREADS / NUM_WARPS / CORES_PER_CLUSTER) != 8
#error CANNOT UNROLL // #error CANNOT UNROLL
#endif // #endif
constexpr uint32_t every_iter = j1_stride; constexpr uint32_t every_iter = j1_stride;
const uint32_t every_2iters = i1_stride * dim_n; const uint32_t every_2iters = i1_stride * dim_n;
const uint32_t runtime_const = i0 * dim_n + j1_idx + j0; const uint32_t runtime_const = i0 * dim_n + j1_idx + j0;
@@ -468,6 +497,24 @@ void thread_block_matmul_gemmini(kernel_arg_t *__UNIFORM__ arg,
dram_c_tile_start[every_iter * 1 + every_2iters * 2] = v1; dram_c_tile_start[every_iter * 1 + every_2iters * 2] = v1;
dram_c_tile_start[every_iter * 0 + every_2iters * 3] = v2; dram_c_tile_start[every_iter * 0 + every_2iters * 3] = v2;
dram_c_tile_start[every_iter * 1 + every_2iters * 3] = v3; dram_c_tile_start[every_iter * 1 + every_2iters * 3] = v3;
v0 = smem_acc_tile_start[8 * num_threads_in_cluster];
v1 = smem_acc_tile_start[9 * num_threads_in_cluster];
v2 = smem_acc_tile_start[10 * num_threads_in_cluster];
v3 = smem_acc_tile_start[11 * num_threads_in_cluster];
dram_c_tile_start[every_iter * 0 + every_2iters * 4] = v0;
dram_c_tile_start[every_iter * 1 + every_2iters * 4] = v1;
dram_c_tile_start[every_iter * 0 + every_2iters * 5] = v2;
dram_c_tile_start[every_iter * 1 + every_2iters * 5] = v3;
v0 = smem_acc_tile_start[12 * num_threads_in_cluster];
v1 = smem_acc_tile_start[13 * num_threads_in_cluster];
v2 = smem_acc_tile_start[14 * num_threads_in_cluster];
v3 = smem_acc_tile_start[15 * num_threads_in_cluster];
dram_c_tile_start[every_iter * 0 + every_2iters * 6] = v0;
dram_c_tile_start[every_iter * 1 + every_2iters * 6] = v1;
dram_c_tile_start[every_iter * 0 + every_2iters * 7] = v2;
dram_c_tile_start[every_iter * 1 + every_2iters * 7] = v3;
#else #else
dram_c_tile_start[every_iter * 0 + every_2iters * 0] = \ dram_c_tile_start[every_iter * 0 + every_2iters * 0] = \
smem_acc_tile_start[0 * num_threads_in_cluster]; smem_acc_tile_start[0 * num_threads_in_cluster];
@@ -496,7 +543,6 @@ void thread_block_matmul_gemmini(kernel_arg_t *__UNIFORM__ arg,
*(SMEM_ADDR_Q2 + SMEM_MAT_OFFSET(elem_offset / TILE_N, elem_offset % TILE_N, TILE_N)); *(SMEM_ADDR_Q2 + SMEM_MAT_OFFSET(elem_offset / TILE_N, elem_offset % TILE_N, TILE_N));
} }
#endif #endif
__asm__("end_mvout_dram:");
// rd_cycles_force(marker8); // rd_cycles_force(marker8);
} }
@@ -507,7 +553,7 @@ void thread_block_matmul_gemmini(kernel_arg_t *__UNIFORM__ arg,
rd_cycles_force(marker9); rd_cycles_force(marker9);
#ifdef POWER #ifdef POWER
if (HW_TID() == 0) { if (HW_TID() == 0) {
PRINTF("\nstart %d end %d\n", marker0, marker9); PRINTF("%d\n", marker9 - marker0);
} }
#else #else
if (HW_TID() == 0) { if (HW_TID() == 0) {