diff --git a/tests/regression/sgemm_gemmini_dma/kernel.cpp b/tests/regression/sgemm_gemmini_dma/kernel.cpp index 89e3600c..6391e500 100644 --- a/tests/regression/sgemm_gemmini_dma/kernel.cpp +++ b/tests/regression/sgemm_gemmini_dma/kernel.cpp @@ -6,19 +6,29 @@ #include "include/gemmini.h" #include "gemmini_mmio.h" -#define TILE_M 64 -#define TILE_N 64 -#define TILE_K 64 -#define SMEM_ADDR_Q0 ((float * const) 0xff000000) -#define SMEM_ADDR_Q1 ((float * const) 0xff004000) -#define SMEM_ADDR_Q2 ((float * const) 0xff008000) -#define SMEM_ADDR_Q3 ((float * const) 0xff00c000) -#define SPAD_ADDR_Q0 0x0 -#define SPAD_ADDR_Q1 0x200 -#define SPAD_ADDR_Q2 0x400 -#define SPAD_ADDR_Q3 0x600 +// fp16 16x16 +#define TILE_M 128 +#define TILE_N 128 +#define TILE_K 128 #define BOUND_INST 0x800080008ULL +#define NUM_THREADS_IN_CLUSTER 512 +// fp32 8x8 +// #define TILE_M 64 +// #define TILE_N 64 +// #define TILE_K 64 +// #define SMEM_ADDR_Q0 ((float * const) 0xff000000) +// #define SMEM_ADDR_Q1 ((float * const) 0xff004000) +// #define SMEM_ADDR_Q2 ((float * const) 0xff008000) +// #define SMEM_ADDR_Q3 ((float * const) 0xff00c000) +// #define SPAD_ADDR_Q0 0x0 +// #define SPAD_ADDR_Q1 0x200 +// #define SPAD_ADDR_Q2 0x400 +// #define SPAD_ADDR_Q3 0x600 +// #define BOUND_INST 0x800080008ULL +// #define NUM_THREADS_IN_CLUSTER 256 + +// fp32 4x4 // #define TILE_M 32 // #define TILE_N 32 // #define TILE_K 32 @@ -31,9 +41,9 @@ // #define SPAD_ADDR_Q2 0x100 // #define SPAD_ADDR_Q3 0x180 // #define BOUND_INST 0x400040004ULL +// #define NUM_THREADS_IN_CLUSTER 256 #define NUM_CLUSTERS 1 -#define NUM_THREADS_IN_CLUSTER 256 \ // (NUM_CORES * NUM_WARPS * NUM_THREADS) #define rd_cycles_force(x) asm volatile ("csrr %0, mcycle" : "=r" (x)) @@ -42,7 +52,7 @@ #define PRINTF(...) sprintf(PRINT_BUF, __VA_ARGS__) // #define PRINTF(...) vx_printf(__VA_ARGS__) #define SWISH(beta, x) ((x) / (1 + exp(-(beta) * (x)))) -// #define POWER +#define POWER typedef uint16_t smem_elem_t; // typedef float smem_elem_t; @@ -83,13 +93,11 @@ void thread_block_matmul_gemmini(kernel_arg_t *__UNIFORM__ arg, const uint32_t num_tile_rows_per_tb = num_tiles_m / NUM_CLUSTERS; - constexpr scale_t MVIN_SCALE_IDENTITY_HEX = 0x3c00; - if (HW_TID() == 0) { - gemmini_extended3_config_ld(dim_k * sizeof(elem_t), MVIN_SCALE_IDENTITY_HEX, false, 0); - gemmini_extended3_config_ld(dim_n * sizeof(elem_t), MVIN_SCALE_IDENTITY_HEX, false, 1); + gemmini_extended3_config_ld(dim_k * sizeof(elem_t), MVIN_SCALE_IDENTITY, false, 0); + gemmini_extended3_config_ld(dim_n * sizeof(elem_t), MVIN_SCALE_IDENTITY, false, 1); // gemmini_extended3_config_ld(repeating_bias ? 0 : (stride_D * sizeof_D), D_scale_factor, low_D, 2); - gemmini_extended_config_st(dim_n * sizeof(elem_t), 0, MVIN_SCALE_IDENTITY_HEX); + gemmini_extended_config_st(dim_n * sizeof(elem_t), 0, MVIN_SCALE_IDENTITY); // gemmini_extended_config_st(stride_C * sizeof_C, act & 3, scale); } @@ -187,4 +195,4 @@ int main() { vx_spawn_tasks_contiguous(grid_size, (vx_spawn_tasks_cb)kernel_body, arg); #endif return 0; -} +} \ No newline at end of file