diff --git a/tests/regression/sgemm_gemmini/kernel.cpp b/tests/regression/sgemm_gemmini/kernel.cpp index e1a33df6..c609eec1 100644 --- a/tests/regression/sgemm_gemmini/kernel.cpp +++ b/tests/regression/sgemm_gemmini/kernel.cpp @@ -6,25 +6,25 @@ #include "include/gemmini.h" #include "gemmini_mmio.h" -#define TILE_M 32 -#define TILE_N 32 -#define TILE_K 32 -#define TILE_MN 1024 -#define TILE_MK 1024 -#define TILE_NK 1024 +#define TILE_M 64 +#define TILE_N 64 +#define TILE_K 64 +#define TILE_MN 4096 +#define TILE_MK 4096 +#define TILE_NK 4096 #define NUM_CLUSTERS 1 -#define NUM_THREADS_IN_CLUSTER 128 +#define NUM_THREADS_IN_CLUSTER 256 #define SMEM_ADDR_Q0 ((float * const) 0xff000000) -#define SMEM_ADDR_Q1 ((float * const) 0xff001000) -#define SMEM_ADDR_Q2 ((float * const) 0xff002000) -#define SMEM_ADDR_Q3 ((float * const) 0xff003000) +#define SMEM_ADDR_Q1 ((float * const) 0xff004000) +#define SMEM_ADDR_Q2 ((float * const) 0xff008000) +#define SMEM_ADDR_Q3 ((float * const) 0xff00c000) #define SPAD_ADDR_Q0 0x0 -#define SPAD_ADDR_Q1 0x80 -#define SPAD_ADDR_Q2 0x100 -#define SPAD_ADDR_Q3 0x180 -#define SPAD_ADDR_Q4 0x200 +#define SPAD_ADDR_Q1 0x200 +#define SPAD_ADDR_Q2 0x400 +#define SPAD_ADDR_Q3 0x600 +#define SPAD_ADDR_Q4 0x800 #define HARDCODE #define REGBLOCK @@ -61,7 +61,6 @@ inline void threadblock_barrier(unsigned int barrier_id, unsigned int count) { void thread_block_matmul_gemmini(kernel_arg_t *__UNIFORM__ arg, const uint32_t threadblock_id, const uint32_t tid_in_threadblock) { - __asm__("matmul_start:"); const float * const A = (const float * const) arg->addr_a; const float * const B = (const float * const) arg->addr_b; float * const C = (float * const) arg->addr_c; @@ -71,7 +70,9 @@ void thread_block_matmul_gemmini(kernel_arg_t *__UNIFORM__ arg, // gemmini_extended_config_ex(dataflow, act & 3, 0, 1, a_transpose, b_transpose); // gemmini_extended_config_st(stride_C * sizeof_C, act & 3, scale); + #ifndef POWER PRINTF("start\n"); + #endif } vx_fence(); @@ -121,9 +122,7 @@ void thread_block_matmul_gemmini(kernel_arg_t *__UNIFORM__ arg, for (uint32_t tile_i = num_tile_rows_per_tb * threadblock_id; tile_i < num_tile_rows_per_tb * (threadblock_id + 1); tile_i += 1) { - __asm__("i_loop:"); for (int tile_j = 0; tile_j < num_tiles_n; tile_j += 1) { - __asm__("j_loop:"); float * const smem_c_tile_start = SMEM_ADDR_Q1; #ifdef OFFLOAD_ACCUMULATE float * const smem_acc_tile_start = SMEM_ADDR_Q0 + HW_TID(); @@ -131,15 +130,14 @@ void thread_block_matmul_gemmini(kernel_arg_t *__UNIFORM__ arg, float * const smem_acc_tile_start = SMEM_ADDR_Q2 + hw_tid; #endif - __asm__("k_loop:"); for (int tile_k = 0; tile_k < num_tiles_k; tile_k += 1) { // TODO: double buffer rd_cycles(marker1); #ifdef HARDCODE - #if (TILE_MK / NUM_THREADS / NUM_WARPS / CORES_PER_CLUSTER) != 8 - #error CANNOT UNROLL - #endif + // #if (TILE_MK / NUM_THREADS / NUM_WARPS / CORES_PER_CLUSTER) != 8 + // #error CANNOT UNROLL + // #endif constexpr uint32_t every_iter = j1_stride; const uint32_t every_2iters_a = i1_stride * dim_k; @@ -228,6 +226,42 @@ void thread_block_matmul_gemmini(kernel_arg_t *__UNIFORM__ arg, smem_b_tile_start[5 * num_threads_in_cluster] = v1; smem_b_tile_start[6 * num_threads_in_cluster] = v2; smem_b_tile_start[7 * num_threads_in_cluster] = v3; + + v0 = dram_a_tile_start[every_iter * 0 + every_2iters_a * 4]; + v1 = dram_a_tile_start[every_iter * 1 + every_2iters_a * 4]; + v2 = dram_a_tile_start[every_iter * 0 + every_2iters_a * 5]; + v3 = dram_a_tile_start[every_iter * 1 + every_2iters_a * 5]; + smem_a_tile_start[8 * num_threads_in_cluster] = v0; + smem_a_tile_start[9 * num_threads_in_cluster] = v1; + smem_a_tile_start[10 * num_threads_in_cluster] = v2; + smem_a_tile_start[11 * num_threads_in_cluster] = v3; + + v0 = dram_b_tile_start[every_iter * 0 + every_2iters_b * 4]; + v1 = dram_b_tile_start[every_iter * 1 + every_2iters_b * 4]; + v2 = dram_b_tile_start[every_iter * 0 + every_2iters_b * 5]; + v3 = dram_b_tile_start[every_iter * 1 + every_2iters_b * 5]; + smem_b_tile_start[8 * num_threads_in_cluster] = v0; + smem_b_tile_start[9 * num_threads_in_cluster] = v1; + smem_b_tile_start[10 * num_threads_in_cluster] = v2; + smem_b_tile_start[11 * num_threads_in_cluster] = v3; + + v0 = dram_a_tile_start[every_iter * 0 + every_2iters_a * 6]; + v1 = dram_a_tile_start[every_iter * 1 + every_2iters_a * 6]; + v2 = dram_a_tile_start[every_iter * 0 + every_2iters_a * 7]; + v3 = dram_a_tile_start[every_iter * 1 + every_2iters_a * 7]; + smem_a_tile_start[12 * num_threads_in_cluster] = v0; + smem_a_tile_start[13 * num_threads_in_cluster] = v1; + smem_a_tile_start[14 * num_threads_in_cluster] = v2; + smem_a_tile_start[15 * num_threads_in_cluster] = v3; + + v0 = dram_b_tile_start[every_iter * 0 + every_2iters_b * 6]; + v1 = dram_b_tile_start[every_iter * 1 + every_2iters_b * 6]; + v2 = dram_b_tile_start[every_iter * 0 + every_2iters_b * 7]; + v3 = dram_b_tile_start[every_iter * 1 + every_2iters_b * 7]; + smem_b_tile_start[12 * num_threads_in_cluster] = v0; + smem_b_tile_start[13 * num_threads_in_cluster] = v1; + smem_b_tile_start[14 * num_threads_in_cluster] = v2; + smem_b_tile_start[15 * num_threads_in_cluster] = v3; #endif } #else @@ -398,34 +432,29 @@ void thread_block_matmul_gemmini(kernel_arg_t *__UNIFORM__ arg, #ifdef OFFLOAD_ACCUMULATE threadblock_barrier(/*barrier_id=*/0, /*count=*/NUM_WARPS); rd_cycles(marker6); - __asm__("mvout_spad_ser:"); // mvout to scratchpad for activation if (HW_TID() == 0) { - __asm__("mvout_spad:"); // #ifdef DBUF // gemmini_fence(); // #endif #ifdef CISC GEMMINI_CISC_CMD_I(9); #else - ROCC_INSTRUCTION_RS1_RS2(XCUSTOM_ACC, 0, (4ULL << 32) | (4ULL << 16) | 4ULL, k_LOOP_WS_CONFIG_BOUNDS) + ROCC_INSTRUCTION_RS1_RS2(XCUSTOM_ACC, 0, (((uint64_t) TILE_M / DIM) << 32) | + (((uint64_t) TILE_K / DIM) << 16) | ((uint64_t) TILE_N / DIM), k_LOOP_WS_CONFIG_BOUNDS) ROCC_INSTRUCTION_RS1_RS2(XCUSTOM_ACC, 0, 0x278U, k_LOOP_WS) #endif - __asm__("mvout_spad_fence:"); gemmini_fence(); } - __asm__("mvout_spad_bar:"); threadblock_barrier(/*barrier_id=*/0, /*count=*/NUM_WARPS); - __asm__("end_mvout_spad:"); #endif rd_cycles(marker7); // move out to dram - __asm__("mvout_dram:"); #ifdef HARDCODE - #if (TILE_MN / NUM_THREADS / NUM_WARPS / CORES_PER_CLUSTER) != 8 - #error CANNOT UNROLL - #endif + // #if (TILE_MN / NUM_THREADS / NUM_WARPS / CORES_PER_CLUSTER) != 8 + // #error CANNOT UNROLL + // #endif constexpr uint32_t every_iter = j1_stride; const uint32_t every_2iters = i1_stride * dim_n; const uint32_t runtime_const = i0 * dim_n + j1_idx + j0; @@ -468,6 +497,24 @@ void thread_block_matmul_gemmini(kernel_arg_t *__UNIFORM__ arg, dram_c_tile_start[every_iter * 1 + every_2iters * 2] = v1; dram_c_tile_start[every_iter * 0 + every_2iters * 3] = v2; dram_c_tile_start[every_iter * 1 + every_2iters * 3] = v3; + + v0 = smem_acc_tile_start[8 * num_threads_in_cluster]; + v1 = smem_acc_tile_start[9 * num_threads_in_cluster]; + v2 = smem_acc_tile_start[10 * num_threads_in_cluster]; + v3 = smem_acc_tile_start[11 * num_threads_in_cluster]; + dram_c_tile_start[every_iter * 0 + every_2iters * 4] = v0; + dram_c_tile_start[every_iter * 1 + every_2iters * 4] = v1; + dram_c_tile_start[every_iter * 0 + every_2iters * 5] = v2; + dram_c_tile_start[every_iter * 1 + every_2iters * 5] = v3; + + v0 = smem_acc_tile_start[12 * num_threads_in_cluster]; + v1 = smem_acc_tile_start[13 * num_threads_in_cluster]; + v2 = smem_acc_tile_start[14 * num_threads_in_cluster]; + v3 = smem_acc_tile_start[15 * num_threads_in_cluster]; + dram_c_tile_start[every_iter * 0 + every_2iters * 6] = v0; + dram_c_tile_start[every_iter * 1 + every_2iters * 6] = v1; + dram_c_tile_start[every_iter * 0 + every_2iters * 7] = v2; + dram_c_tile_start[every_iter * 1 + every_2iters * 7] = v3; #else dram_c_tile_start[every_iter * 0 + every_2iters * 0] = \ smem_acc_tile_start[0 * num_threads_in_cluster]; @@ -496,7 +543,6 @@ void thread_block_matmul_gemmini(kernel_arg_t *__UNIFORM__ arg, *(SMEM_ADDR_Q2 + SMEM_MAT_OFFSET(elem_offset / TILE_N, elem_offset % TILE_N, TILE_N)); } #endif - __asm__("end_mvout_dram:"); // rd_cycles_force(marker8); } @@ -507,7 +553,7 @@ void thread_block_matmul_gemmini(kernel_arg_t *__UNIFORM__ arg, rd_cycles_force(marker9); #ifdef POWER if (HW_TID() == 0) { - PRINTF("\nstart %d end %d\n", marker0, marker9); + PRINTF("%d\n", marker9 - marker0); } #else if (HW_TID() == 0) { diff --git a/tests/regression/sgemm_gemmini/sgemm_gemmini b/tests/regression/sgemm_gemmini/sgemm_gemmini index 67ade61b..2204a038 100755 Binary files a/tests/regression/sgemm_gemmini/sgemm_gemmini and b/tests/regression/sgemm_gemmini/sgemm_gemmini differ