updated no dma gemmini kernel

This commit is contained in:
Richard Yan
2025-01-29 16:59:44 -08:00
parent 8c45b8b4b7
commit ec41200845
7 changed files with 338 additions and 55 deletions

View File

@@ -0,0 +1 @@
../sgemm_gemmini_dma/args

View File

@@ -5,6 +5,8 @@
#define KERNEL_ARG_DEV_MEM_ADDR 0x9fff0000
#define DEV_SMEM_START_ADDR 0xff000000
#define MARK_BEG() asm volatile ("slti x0, x1, -1047")
#define MARK_END() asm volatile ("slti x0, x1, -499")
typedef struct {
uint32_t dim_m;

View File

@@ -0,0 +1,11 @@
rm kernel.radiance.elf
rm -rf binaries
mkdir binaries
for a in args/*; do
cp -f $a args.bin
aa=$(basename "$a")
cp -f input.a/"$aa" input.a.bin
cp -f input.b/"$aa" input.b.bin
make > /dev/null
mv kernel.radiance.elf binaries/gemmini_fp16nodma"$aa".elf
done

View File

@@ -0,0 +1,11 @@
rm kernel.radiance.elf
rm -rf binaries
mkdir binaries
for a in args/*; do
cp -f $a args.bin
aa=$(basename "$a")
cp -f input.a/"$aa" input.a.bin
cp -f input.b/"$aa" input.b.bin
make > /dev/null
mv kernel.radiance.elf binaries/gemmini_hopper_nodma"$aa".elf
done

View File

@@ -0,0 +1 @@
../sgemm_gemmini_dma/input.a

View File

@@ -0,0 +1 @@
../sgemm_gemmini_dma/input.b

View File

@@ -8,6 +8,7 @@
#define NUM_CLUSTERS 1
// #define FP32
#define HOPPER
#ifdef FP32
// fp32
@@ -39,7 +40,16 @@ typedef float mem_elem_t;
#define TILE_MN 8192
#define TILE_MK 16384
#define TILE_NK 8192
#define NUM_WARPS_ 8
#define NUM_THREADS_ 8
#ifdef HOPPER
#define NUM_THREADS_IN_CLUSTER 256
#define CORES_PER_CLUSTER_ 4
#else
#define NUM_THREADS_IN_CLUSTER 512
#define CORES_PER_CLUSTER_ 8
#endif
#define SMEM_ADDR_Q0 ((mem_elem_t * const) 0xff000000)
#define SMEM_ADDR_Q1 ((mem_elem_t * const) 0xff008000)
@@ -111,6 +121,7 @@ void thread_block_matmul_gemmini(kernel_arg_t *__UNIFORM__ arg,
uint32_t swish_dur = 0;
#endif
rd_cycles_force(marker0);
MARK_BEG();
const uint32_t dim_m = arg->dim_m;
const uint32_t dim_n = arg->dim_n;
@@ -122,20 +133,22 @@ void thread_block_matmul_gemmini(kernel_arg_t *__UNIFORM__ arg,
constexpr uint32_t a_elems_per_thread = TILE_MK / num_threads_in_cluster;
constexpr uint32_t b_elems_per_thread = TILE_NK / num_threads_in_cluster;
constexpr uint32_t c_elems_per_thread = TILE_MN / num_threads_in_cluster;
constexpr uint32_t e_mult = sizeof(uint32_t) / sizeof(smem_elem_t);
const uint32_t hw_tid = tid_in_threadblock % num_threads_in_cluster;
// the dram coordinates are (i1 + i0, j1 + j0). i0 and j0 are both spatially mapped only.
const uint32_t j0 = HW_TID() % DIM;
const uint32_t i0 = (HW_TID() / DIM) % DIM;
const uint32_t j0 = HW_TID() * e_mult % DIM;
const uint32_t i0 = (HW_TID() * e_mult / DIM) % DIM;
// j1 is both spatially and temporally mapped. j1 increases every iteration.
const uint32_t j1_idx = (HW_TID() / DIM / DIM) * DIM; // A: % TILE_K, B: % TILE_N, C: % TILE_N
const uint32_t j1_idx = (HW_TID() * e_mult / DIM / DIM) * DIM; // A: % TILE_K, B: % TILE_N, C: % TILE_N
// every iteratioon, j1 increases by j1_stride
constexpr uint32_t j1_stride = (num_threads_in_cluster / DIM / DIM) * DIM; // mod TILE_W after stride
constexpr uint32_t j1_stride = (num_threads_in_cluster * e_mult / DIM / DIM) * DIM; // mod TILE_W after stride
// i1 is only temporally mapped. i1 increments every one or more iterations
constexpr uint32_t i1_stride = DIM; // step per increment (increment doesnt happen every iteration)
constexpr uint32_t i1_iters = (DIM * DIM * (TILE_K / DIM)) / num_threads_in_cluster; // num of iters before striding
constexpr uint32_t i1_iters_a = (DIM * DIM * (TILE_K / DIM)) / num_threads_in_cluster / e_mult; // num of iters before striding
constexpr uint32_t i1_iters_b = (DIM * DIM * (TILE_N / DIM)) / num_threads_in_cluster / e_mult; // num of iters before striding
const uint32_t num_tile_rows_per_tb = num_tiles_m / NUM_CLUSTERS;
@@ -158,33 +171,50 @@ void thread_block_matmul_gemmini(kernel_arg_t *__UNIFORM__ arg,
mem_elem_t * const smem_acc_tile_start = SMEM_ADDR_Q2 + hw_tid;
#endif
#ifndef FP32
#ifdef HOPPER
constexpr uint32_t every_iter = j1_stride;
const uint32_t every_4iters_a = i1_stride * dim_k;
const uint32_t runtime_const_a = i0 * dim_k + j1_idx + j0;
const uint32_t every_2iters_b = i1_stride * dim_n;
const uint32_t runtime_const_b = i0 * dim_n + j1_idx + j0;
#else
constexpr uint32_t every_iter = j1_stride;
const uint32_t every_2iters_a = i1_stride * dim_k;
const uint32_t runtime_const_a = i0 * dim_k + j1_idx + j0;
const uint32_t every_iter_b = i1_stride * dim_n;
const uint32_t runtime_const_b = i0 * dim_n + j1_idx + j0;
#endif
#else
constexpr uint32_t every_iter = j1_stride;
const uint32_t every_2iters_a = i1_stride * dim_k;
const uint32_t runtime_const_a = i0 * dim_k + j1_idx + j0;
const uint32_t every_2iters_b = i1_stride * dim_n;
const uint32_t runtime_const_b = i0 * dim_n + j1_idx + j0;
#endif
for (int tile_k = 0; tile_k < num_tiles_k; tile_k += 1) {
// TODO: double buffer
rd_cycles(marker1);
#ifdef HARDCODE
// #if (TILE_MK / NUM_THREADS / NUM_WARPS / CORES_PER_CLUSTER) != 8
// #if (TILE_MK / NUM_THREADS / NUM_WARPS_ / CORES_PER_CLUSTER) != 8
// #error CANNOT UNROLL
// #endif
constexpr uint32_t every_iter = j1_stride;
const uint32_t every_2iters_a = i1_stride * (dim_k * sizeof(smem_elem_t) / 4);
const uint32_t runtime_const_a = i0 * (dim_k * sizeof(smem_elem_t) / 4) + j1_idx + j0;
const uint32_t every_2iters_b = i1_stride * dim_n;
const uint32_t runtime_const_b = i0 * dim_n + j1_idx + j0;
const mem_elem_t * const dram_a_tile_start = (const mem_elem_t * const) (A + tile_i * TILE_M * dim_k + tile_k * TILE_K + runtime_const_a);
const mem_elem_t * const dram_b_tile_start = (const mem_elem_t * const) (B + tile_k * TILE_K * dim_n + tile_j * TILE_N + runtime_const_b);
#ifdef DBUF
mem_elem_t * const smem_a_tile_start = (mem_elem_t * const) (((tile_k & 1) ? SMEM_ADDR_Q1 : SMEM_ADDR_Q0) + HW_TID());
mem_elem_t * const smem_b_tile_start = (mem_elem_t * const) (((tile_k & 1) ? SMEM_ADDR_Q3 : SMEM_ADDR_Q2) + HW_TID());
mem_elem_t * const smem_a_tile_start = (mem_elem_t * const) (((tile_k & 1) ? SMEM_ADDR_Q1 : SMEM_ADDR_Q0) + HW_TID() * e_mult);
mem_elem_t * const smem_b_tile_start = (mem_elem_t * const) (((tile_k & 1) ? SMEM_ADDR_Q3 : SMEM_ADDR_Q2) + HW_TID() * e_mult);
#else
mem_elem_t * const smem_a_tile_start = (mem_elem_t * const) (SMEM_ADDR_Q0 + HW_TID());
mem_elem_t * const smem_b_tile_start = (mem_elem_t * const) (SMEM_ADDR_Q3 + HW_TID());
mem_elem_t * const smem_a_tile_start = (mem_elem_t * const) (SMEM_ADDR_Q0 + HW_TID() * e_mult);
mem_elem_t * const smem_b_tile_start = (mem_elem_t * const) (SMEM_ADDR_Q3 + HW_TID() * e_mult);
#endif
{
#ifndef REGBLOCK
/*
smem_a_tile_start[0 * num_threads_in_cluster + hw_tid] = \
dram_a_tile_start[every_iter * 0 + every_2iters_a * 0];
smem_a_tile_start[1 * num_threads_in_cluster + hw_tid] = \
@@ -218,6 +248,174 @@ void thread_block_matmul_gemmini(kernel_arg_t *__UNIFORM__ arg,
dram_b_tile_start[every_iter * 0 + every_2iters_b * 3];
smem_b_tile_start[7 * num_threads_in_cluster + hw_tid] = \
dram_b_tile_start[every_iter * 1 + every_2iters_b * 3];
*/
#else
#ifndef FP32
#ifdef HOPPER
mem_elem_t v0 = dram_a_tile_start[every_iter * 0 + every_4iters_a * 0];
mem_elem_t v1 = dram_a_tile_start[every_iter * 1 + every_4iters_a * 0];
mem_elem_t v2 = dram_a_tile_start[every_iter * 2 + every_4iters_a * 0];
mem_elem_t v3 = dram_a_tile_start[every_iter * 3 + every_4iters_a * 0];
smem_a_tile_start[0 * num_threads_in_cluster] = v0;
smem_a_tile_start[1 * num_threads_in_cluster] = v1;
smem_a_tile_start[2 * num_threads_in_cluster] = v2;
smem_a_tile_start[3 * num_threads_in_cluster] = v3;
v0 = dram_a_tile_start[every_iter * 0 + every_4iters_a * 1];
v1 = dram_a_tile_start[every_iter * 1 + every_4iters_a * 1];
v2 = dram_a_tile_start[every_iter * 2 + every_4iters_a * 1];
v3 = dram_a_tile_start[every_iter * 3 + every_4iters_a * 1];
smem_a_tile_start[4 * num_threads_in_cluster] = v0;
smem_a_tile_start[5 * num_threads_in_cluster] = v1;
smem_a_tile_start[6 * num_threads_in_cluster] = v2;
smem_a_tile_start[7 * num_threads_in_cluster] = v3;
v0 = dram_a_tile_start[every_iter * 0 + every_4iters_a * 2];
v1 = dram_a_tile_start[every_iter * 1 + every_4iters_a * 2];
v2 = dram_a_tile_start[every_iter * 2 + every_4iters_a * 2];
v3 = dram_a_tile_start[every_iter * 3 + every_4iters_a * 2];
smem_a_tile_start[8 * num_threads_in_cluster] = v0;
smem_a_tile_start[9 * num_threads_in_cluster] = v1;
smem_a_tile_start[10 * num_threads_in_cluster] = v2;
smem_a_tile_start[11 * num_threads_in_cluster] = v3;
v0 = dram_a_tile_start[every_iter * 0 + every_4iters_a * 3];
v1 = dram_a_tile_start[every_iter * 1 + every_4iters_a * 3];
v2 = dram_a_tile_start[every_iter * 2 + every_4iters_a * 3];
v3 = dram_a_tile_start[every_iter * 3 + every_4iters_a * 3];
smem_a_tile_start[12 * num_threads_in_cluster] = v0;
smem_a_tile_start[13 * num_threads_in_cluster] = v1;
smem_a_tile_start[14 * num_threads_in_cluster] = v2;
smem_a_tile_start[15 * num_threads_in_cluster] = v3;
v0 = dram_a_tile_start[every_iter * 0 + every_4iters_a * 4];
v1 = dram_a_tile_start[every_iter * 1 + every_4iters_a * 4];
v2 = dram_a_tile_start[every_iter * 2 + every_4iters_a * 4];
v3 = dram_a_tile_start[every_iter * 3 + every_4iters_a * 4];
smem_a_tile_start[16 * num_threads_in_cluster] = v0;
smem_a_tile_start[17 * num_threads_in_cluster] = v1;
smem_a_tile_start[18 * num_threads_in_cluster] = v2;
smem_a_tile_start[19 * num_threads_in_cluster] = v3;
v0 = dram_a_tile_start[every_iter * 0 + every_4iters_a * 5];
v1 = dram_a_tile_start[every_iter * 1 + every_4iters_a * 5];
v2 = dram_a_tile_start[every_iter * 2 + every_4iters_a * 5];
v3 = dram_a_tile_start[every_iter * 3 + every_4iters_a * 5];
smem_a_tile_start[20 * num_threads_in_cluster] = v0;
smem_a_tile_start[21 * num_threads_in_cluster] = v1;
smem_a_tile_start[22 * num_threads_in_cluster] = v2;
smem_a_tile_start[23 * num_threads_in_cluster] = v3;
v0 = dram_a_tile_start[every_iter * 0 + every_4iters_a * 6];
v1 = dram_a_tile_start[every_iter * 1 + every_4iters_a * 6];
v2 = dram_a_tile_start[every_iter * 2 + every_4iters_a * 6];
v3 = dram_a_tile_start[every_iter * 3 + every_4iters_a * 6];
smem_a_tile_start[24 * num_threads_in_cluster] = v0;
smem_a_tile_start[25 * num_threads_in_cluster] = v1;
smem_a_tile_start[26 * num_threads_in_cluster] = v2;
smem_a_tile_start[27 * num_threads_in_cluster] = v3;
v0 = dram_a_tile_start[every_iter * 0 + every_4iters_a * 7];
v1 = dram_a_tile_start[every_iter * 1 + every_4iters_a * 7];
v2 = dram_a_tile_start[every_iter * 2 + every_4iters_a * 7];
v3 = dram_a_tile_start[every_iter * 3 + every_4iters_a * 7];
smem_a_tile_start[28 * num_threads_in_cluster] = v0;
smem_a_tile_start[29 * num_threads_in_cluster] = v1;
smem_a_tile_start[30 * num_threads_in_cluster] = v2;
smem_a_tile_start[31 * num_threads_in_cluster] = v3;
// --------------------
v0 = dram_b_tile_start[every_iter * 0 + every_2iters_b * 0];
v1 = dram_b_tile_start[every_iter * 1 + every_2iters_b * 0];
v2 = dram_b_tile_start[every_iter * 0 + every_2iters_b * 1];
v3 = dram_b_tile_start[every_iter * 1 + every_2iters_b * 1];
smem_b_tile_start[0 * num_threads_in_cluster] = v0;
smem_b_tile_start[1 * num_threads_in_cluster] = v1;
smem_b_tile_start[2 * num_threads_in_cluster] = v2;
smem_b_tile_start[3 * num_threads_in_cluster] = v3;
v0 = dram_b_tile_start[every_iter * 0 + every_2iters_b * 2];
v1 = dram_b_tile_start[every_iter * 1 + every_2iters_b * 2];
v2 = dram_b_tile_start[every_iter * 0 + every_2iters_b * 3];
v3 = dram_b_tile_start[every_iter * 1 + every_2iters_b * 3];
smem_b_tile_start[4 * num_threads_in_cluster] = v0;
smem_b_tile_start[5 * num_threads_in_cluster] = v1;
smem_b_tile_start[6 * num_threads_in_cluster] = v2;
smem_b_tile_start[7 * num_threads_in_cluster] = v3;
v0 = dram_b_tile_start[every_iter * 0 + every_2iters_b * 4];
v1 = dram_b_tile_start[every_iter * 1 + every_2iters_b * 4];
v2 = dram_b_tile_start[every_iter * 0 + every_2iters_b * 5];
v3 = dram_b_tile_start[every_iter * 1 + every_2iters_b * 5];
smem_b_tile_start[8 * num_threads_in_cluster] = v0;
smem_b_tile_start[9 * num_threads_in_cluster] = v1;
smem_b_tile_start[10 * num_threads_in_cluster] = v2;
smem_b_tile_start[11 * num_threads_in_cluster] = v3;
v0 = dram_b_tile_start[every_iter * 0 + every_2iters_b * 6];
v1 = dram_b_tile_start[every_iter * 1 + every_2iters_b * 6];
v2 = dram_b_tile_start[every_iter * 0 + every_2iters_b * 7];
v3 = dram_b_tile_start[every_iter * 1 + every_2iters_b * 7];
smem_b_tile_start[12 * num_threads_in_cluster] = v0;
smem_b_tile_start[13 * num_threads_in_cluster] = v1;
smem_b_tile_start[14 * num_threads_in_cluster] = v2;
smem_b_tile_start[15 * num_threads_in_cluster] = v3;
#else
mem_elem_t v0 = dram_a_tile_start[every_iter * 0 + every_2iters_a * 0];
mem_elem_t v1 = dram_a_tile_start[every_iter * 1 + every_2iters_a * 0];
mem_elem_t v2 = dram_a_tile_start[every_iter * 0 + every_2iters_a * 1];
mem_elem_t v3 = dram_a_tile_start[every_iter * 1 + every_2iters_a * 1];
smem_a_tile_start[0 * num_threads_in_cluster] = v0;
smem_a_tile_start[1 * num_threads_in_cluster] = v1;
smem_a_tile_start[2 * num_threads_in_cluster] = v2;
smem_a_tile_start[3 * num_threads_in_cluster] = v3;
v0 = dram_b_tile_start[every_iter * 0 + every_iter_b * 0];
v1 = dram_b_tile_start[every_iter * 0 + every_iter_b * 1];
v2 = dram_b_tile_start[every_iter * 0 + every_iter_b * 2];
v3 = dram_b_tile_start[every_iter * 0 + every_iter_b * 3];
smem_b_tile_start[0 * num_threads_in_cluster] = v0;
smem_b_tile_start[1 * num_threads_in_cluster] = v1;
smem_b_tile_start[2 * num_threads_in_cluster] = v2;
smem_b_tile_start[3 * num_threads_in_cluster] = v3;
v0 = dram_a_tile_start[every_iter * 0 + every_2iters_a * 2];
v1 = dram_a_tile_start[every_iter * 1 + every_2iters_a * 2];
v2 = dram_a_tile_start[every_iter * 0 + every_2iters_a * 3];
v3 = dram_a_tile_start[every_iter * 1 + every_2iters_a * 3];
smem_a_tile_start[4 * num_threads_in_cluster] = v0;
smem_a_tile_start[5 * num_threads_in_cluster] = v1;
smem_a_tile_start[6 * num_threads_in_cluster] = v2;
smem_a_tile_start[7 * num_threads_in_cluster] = v3;
v0 = dram_b_tile_start[every_iter * 0 + every_iter_b * 4];
v1 = dram_b_tile_start[every_iter * 0 + every_iter_b * 5];
v2 = dram_b_tile_start[every_iter * 0 + every_iter_b * 6];
v3 = dram_b_tile_start[every_iter * 0 + every_iter_b * 7];
smem_b_tile_start[4 * num_threads_in_cluster] = v0;
smem_b_tile_start[5 * num_threads_in_cluster] = v1;
smem_b_tile_start[6 * num_threads_in_cluster] = v2;
smem_b_tile_start[7 * num_threads_in_cluster] = v3;
v0 = dram_a_tile_start[every_iter * 0 + every_2iters_a * 4];
v1 = dram_a_tile_start[every_iter * 1 + every_2iters_a * 4];
v2 = dram_a_tile_start[every_iter * 0 + every_2iters_a * 5];
v3 = dram_a_tile_start[every_iter * 1 + every_2iters_a * 5];
smem_a_tile_start[8 * num_threads_in_cluster] = v0;
smem_a_tile_start[9 * num_threads_in_cluster] = v1;
smem_a_tile_start[10 * num_threads_in_cluster] = v2;
smem_a_tile_start[11 * num_threads_in_cluster] = v3;
v0 = dram_a_tile_start[every_iter * 0 + every_2iters_a * 6];
v1 = dram_a_tile_start[every_iter * 1 + every_2iters_a * 6];
v2 = dram_a_tile_start[every_iter * 0 + every_2iters_a * 7];
v3 = dram_a_tile_start[every_iter * 1 + every_2iters_a * 7];
smem_a_tile_start[12 * num_threads_in_cluster] = v0;
smem_a_tile_start[13 * num_threads_in_cluster] = v1;
smem_a_tile_start[14 * num_threads_in_cluster] = v2;
smem_a_tile_start[15 * num_threads_in_cluster] = v3;
#endif
#else
mem_elem_t v0 = dram_a_tile_start[every_iter * 0 + every_2iters_a * 0];
mem_elem_t v1 = dram_a_tile_start[every_iter * 1 + every_2iters_a * 0];
@@ -264,14 +462,14 @@ void thread_block_matmul_gemmini(kernel_arg_t *__UNIFORM__ arg,
smem_a_tile_start[10 * num_threads_in_cluster] = v2;
smem_a_tile_start[11 * num_threads_in_cluster] = v3;
// v0 = dram_b_tile_start[every_iter * 0 + every_2iters_b * 4];
// v1 = dram_b_tile_start[every_iter * 1 + every_2iters_b * 4];
// v2 = dram_b_tile_start[every_iter * 0 + every_2iters_b * 5];
// v3 = dram_b_tile_start[every_iter * 1 + every_2iters_b * 5];
// smem_b_tile_start[8 * num_threads_in_cluster] = v0;
// smem_b_tile_start[9 * num_threads_in_cluster] = v1;
// smem_b_tile_start[10 * num_threads_in_cluster] = v2;
// smem_b_tile_start[11 * num_threads_in_cluster] = v3;
v0 = dram_b_tile_start[every_iter * 0 + every_2iters_b * 4];
v1 = dram_b_tile_start[every_iter * 1 + every_2iters_b * 4];
v2 = dram_b_tile_start[every_iter * 0 + every_2iters_b * 5];
v3 = dram_b_tile_start[every_iter * 1 + every_2iters_b * 5];
smem_b_tile_start[8 * num_threads_in_cluster] = v0;
smem_b_tile_start[9 * num_threads_in_cluster] = v1;
smem_b_tile_start[10 * num_threads_in_cluster] = v2;
smem_b_tile_start[11 * num_threads_in_cluster] = v3;
v0 = dram_a_tile_start[every_iter * 0 + every_2iters_a * 6];
v1 = dram_a_tile_start[every_iter * 1 + every_2iters_a * 6];
@@ -282,14 +480,15 @@ void thread_block_matmul_gemmini(kernel_arg_t *__UNIFORM__ arg,
smem_a_tile_start[14 * num_threads_in_cluster] = v2;
smem_a_tile_start[15 * num_threads_in_cluster] = v3;
// v0 = dram_b_tile_start[every_iter * 0 + every_2iters_b * 6];
// v1 = dram_b_tile_start[every_iter * 1 + every_2iters_b * 6];
// v2 = dram_b_tile_start[every_iter * 0 + every_2iters_b * 7];
// v3 = dram_b_tile_start[every_iter * 1 + every_2iters_b * 7];
// smem_b_tile_start[12 * num_threads_in_cluster] = v0;
// smem_b_tile_start[13 * num_threads_in_cluster] = v1;
// smem_b_tile_start[14 * num_threads_in_cluster] = v2;
// smem_b_tile_start[15 * num_threads_in_cluster] = v3;
v0 = dram_b_tile_start[every_iter * 0 + every_2iters_b * 6];
v1 = dram_b_tile_start[every_iter * 1 + every_2iters_b * 6];
v2 = dram_b_tile_start[every_iter * 0 + every_2iters_b * 7];
v3 = dram_b_tile_start[every_iter * 1 + every_2iters_b * 7];
smem_b_tile_start[12 * num_threads_in_cluster] = v0;
smem_b_tile_start[13 * num_threads_in_cluster] = v1;
smem_b_tile_start[14 * num_threads_in_cluster] = v2;
smem_b_tile_start[15 * num_threads_in_cluster] = v3;
#endif
#endif
}
#else
@@ -353,7 +552,7 @@ void thread_block_matmul_gemmini(kernel_arg_t *__UNIFORM__ arg,
rd_cycles(marker2);
// cluster wide barrier to wait for A and B loads to complete
threadblock_barrier(/*barrier_id=*/0, /*count=*/NUM_WARPS);
threadblock_barrier(/*barrier_id=*/0, /*count=*/NUM_WARPS_);
rd_cycles(marker3);
if (HW_TID() == 0) {
#ifdef DBUF
@@ -396,7 +595,7 @@ void thread_block_matmul_gemmini(kernel_arg_t *__UNIFORM__ arg,
}
rd_cycles(marker4);
#ifndef DBUF
threadblock_barrier(/*barrier_id=*/0, /*count=*/NUM_WARPS);
threadblock_barrier(/*barrier_id=*/0, /*count=*/NUM_WARPS_);
#endif
rd_cycles(marker5);
@@ -411,7 +610,7 @@ void thread_block_matmul_gemmini(kernel_arg_t *__UNIFORM__ arg,
smem_acc_tile_start[thread_i * s] = smem_c_tile_start[hw_tid + s * thread_i];
}
} else {
#if (TILE_NK / NUM_THREADS / NUM_WARPS / CORES_PER_CLUSTER) != 8
#if (TILE_NK / NUM_THREADS_ / NUM_WARPS_ / CORES_PER_CLUSTER_) != 8
#error CANNOT UNROLL
#endif
for (int thread_i = 0; thread_i < c_elems_per_thread; thread_i += 8) {
@@ -458,7 +657,7 @@ void thread_block_matmul_gemmini(kernel_arg_t *__UNIFORM__ arg,
}
#ifdef OFFLOAD_ACCUMULATE
threadblock_barrier(/*barrier_id=*/0, /*count=*/NUM_WARPS);
threadblock_barrier(/*barrier_id=*/0, /*count=*/NUM_WARPS_);
rd_cycles(marker6);
// mvout to scratchpad for activation
if (HW_TID() == 0) {
@@ -474,21 +673,76 @@ void thread_block_matmul_gemmini(kernel_arg_t *__UNIFORM__ arg,
#endif
gemmini_fence();
}
threadblock_barrier(/*barrier_id=*/0, /*count=*/NUM_WARPS);
threadblock_barrier(/*barrier_id=*/0, /*count=*/NUM_WARPS_);
#endif
rd_cycles(marker7);
// move out to dram
#ifdef HARDCODE
// #if (TILE_MN / NUM_THREADS / NUM_WARPS / CORES_PER_CLUSTER) != 8
// #if (TILE_MN / NUM_THREADS / NUM_WARPS_ / CORES_PER_CLUSTER) != 8
// #error CANNOT UNROLL
// #endif
constexpr uint32_t every_iter = j1_stride;
const uint32_t every_2iters = i1_stride * dim_n;
const uint32_t runtime_const = i0 * dim_n + j1_idx + j0;
mem_elem_t * const dram_c_tile_start = (mem_elem_t * const) (C + tile_i * TILE_M * dim_n + tile_j * TILE_N + runtime_const);
mem_elem_t * const dram_c_tile_start = (mem_elem_t * const) (C + tile_i * TILE_M * dim_n + tile_j * TILE_N + runtime_const_b);
#ifdef REGBLOCK
#ifndef FP32
#ifdef HOPPER
mem_elem_t v0 = smem_acc_tile_start[0 * num_threads_in_cluster];
mem_elem_t v1 = smem_acc_tile_start[1 * num_threads_in_cluster];
mem_elem_t v2 = smem_acc_tile_start[2 * num_threads_in_cluster];
mem_elem_t v3 = smem_acc_tile_start[3 * num_threads_in_cluster];
dram_c_tile_start[every_iter * 0 + every_2iters_b * 0] = v0;
dram_c_tile_start[every_iter * 1 + every_2iters_b * 0] = v1;
dram_c_tile_start[every_iter * 0 + every_2iters_b * 1] = v2;
dram_c_tile_start[every_iter * 1 + every_2iters_b * 1] = v3;
v0 = smem_acc_tile_start[4 * num_threads_in_cluster];
v1 = smem_acc_tile_start[5 * num_threads_in_cluster];
v2 = smem_acc_tile_start[6 * num_threads_in_cluster];
v3 = smem_acc_tile_start[7 * num_threads_in_cluster];
dram_c_tile_start[every_iter * 0 + every_2iters_b * 2] = v0;
dram_c_tile_start[every_iter * 1 + every_2iters_b * 2] = v1;
dram_c_tile_start[every_iter * 0 + every_2iters_b * 3] = v2;
dram_c_tile_start[every_iter * 1 + every_2iters_b * 3] = v3;
v0 = smem_acc_tile_start[8 * num_threads_in_cluster];
v1 = smem_acc_tile_start[9 * num_threads_in_cluster];
v2 = smem_acc_tile_start[10 * num_threads_in_cluster];
v3 = smem_acc_tile_start[11 * num_threads_in_cluster];
dram_c_tile_start[every_iter * 0 + every_2iters_b * 4] = v0;
dram_c_tile_start[every_iter * 1 + every_2iters_b * 4] = v1;
dram_c_tile_start[every_iter * 0 + every_2iters_b * 5] = v2;
dram_c_tile_start[every_iter * 1 + every_2iters_b * 5] = v3;
v0 = smem_acc_tile_start[12 * num_threads_in_cluster];
v1 = smem_acc_tile_start[13 * num_threads_in_cluster];
v2 = smem_acc_tile_start[14 * num_threads_in_cluster];
v3 = smem_acc_tile_start[15 * num_threads_in_cluster];
dram_c_tile_start[every_iter * 0 + every_2iters_b * 6] = v0;
dram_c_tile_start[every_iter * 1 + every_2iters_b * 6] = v1;
dram_c_tile_start[every_iter * 0 + every_2iters_b * 7] = v2;
dram_c_tile_start[every_iter * 1 + every_2iters_b * 7] = v3;
#else // not HOPPER
mem_elem_t v0 = smem_acc_tile_start[0 * num_threads_in_cluster];
mem_elem_t v1 = smem_acc_tile_start[1 * num_threads_in_cluster];
mem_elem_t v2 = smem_acc_tile_start[2 * num_threads_in_cluster];
mem_elem_t v3 = smem_acc_tile_start[3 * num_threads_in_cluster];
dram_c_tile_start[every_iter * 0 + every_iter_b * 0] = v0;
dram_c_tile_start[every_iter * 0 + every_iter_b * 1] = v1;
dram_c_tile_start[every_iter * 0 + every_iter_b * 2] = v2;
dram_c_tile_start[every_iter * 0 + every_iter_b * 3] = v3;
v0 = smem_acc_tile_start[4 * num_threads_in_cluster];
v1 = smem_acc_tile_start[5 * num_threads_in_cluster];
v2 = smem_acc_tile_start[6 * num_threads_in_cluster];
v3 = smem_acc_tile_start[7 * num_threads_in_cluster];
dram_c_tile_start[every_iter * 0 + every_iter_b * 4] = v0;
dram_c_tile_start[every_iter * 0 + every_iter_b * 5] = v1;
dram_c_tile_start[every_iter * 0 + every_iter_b * 6] = v2;
dram_c_tile_start[every_iter * 0 + every_iter_b * 7] = v3;
#endif // HOPPER
#else // FP32
mem_elem_t v0 = smem_acc_tile_start[0 * num_threads_in_cluster];
mem_elem_t v1 = smem_acc_tile_start[1 * num_threads_in_cluster];
mem_elem_t v2 = smem_acc_tile_start[2 * num_threads_in_cluster];
@@ -503,10 +757,10 @@ void thread_block_matmul_gemmini(kernel_arg_t *__UNIFORM__ arg,
rd_cycles_force(swish_end);
swish_dur += swish_end - swish_start;
#endif
dram_c_tile_start[every_iter * 0 + every_2iters * 0] = v0;
dram_c_tile_start[every_iter * 1 + every_2iters * 0] = v1;
dram_c_tile_start[every_iter * 0 + every_2iters * 1] = v2;
dram_c_tile_start[every_iter * 1 + every_2iters * 1] = v3;
dram_c_tile_start[every_iter * 0 + every_2iters_b * 0] = v0;
dram_c_tile_start[every_iter * 1 + every_2iters_b * 0] = v1;
dram_c_tile_start[every_iter * 0 + every_2iters_b * 1] = v2;
dram_c_tile_start[every_iter * 1 + every_2iters_b * 1] = v3;
v0 = smem_acc_tile_start[4 * num_threads_in_cluster];
v1 = smem_acc_tile_start[5 * num_threads_in_cluster];
@@ -521,10 +775,10 @@ void thread_block_matmul_gemmini(kernel_arg_t *__UNIFORM__ arg,
rd_cycles_force(swish_end);
swish_dur += swish_end - swish_start;
#endif
dram_c_tile_start[every_iter * 0 + every_2iters * 2] = v0;
dram_c_tile_start[every_iter * 1 + every_2iters * 2] = v1;
dram_c_tile_start[every_iter * 0 + every_2iters * 3] = v2;
dram_c_tile_start[every_iter * 1 + every_2iters * 3] = v3;
dram_c_tile_start[every_iter * 0 + every_2iters_b * 2] = v0;
dram_c_tile_start[every_iter * 1 + every_2iters_b * 2] = v1;
dram_c_tile_start[every_iter * 0 + every_2iters_b * 3] = v2;
dram_c_tile_start[every_iter * 1 + every_2iters_b * 3] = v3;
// v0 = smem_acc_tile_start[8 * num_threads_in_cluster];
// v1 = smem_acc_tile_start[9 * num_threads_in_cluster];
@@ -543,6 +797,7 @@ void thread_block_matmul_gemmini(kernel_arg_t *__UNIFORM__ arg,
// dram_c_tile_start[every_iter * 1 + every_2iters * 6] = v1;
// dram_c_tile_start[every_iter * 0 + every_2iters * 7] = v2;
// dram_c_tile_start[every_iter * 1 + every_2iters * 7] = v3;
#endif // FP16/FP32
#else
dram_c_tile_start[every_iter * 0 + every_2iters * 0] = \
smem_acc_tile_start[0 * num_threads_in_cluster];
@@ -577,7 +832,8 @@ void thread_block_matmul_gemmini(kernel_arg_t *__UNIFORM__ arg,
}
// last thread block complete
if (threadblock_id == NUM_CLUSTERS - 1) {
threadblock_barrier(/*barrier_id=*/0, /*count=*/NUM_WARPS);
threadblock_barrier(/*barrier_id=*/0, /*count=*/NUM_WARPS_);
MARK_END();
rd_cycles_force(marker9);
#ifdef POWER
if (HW_TID() == 0) {
@@ -610,7 +866,7 @@ void thread_block_matmul_gemmini(kernel_arg_t *__UNIFORM__ arg,
#endif
PRINTF("dram mvout cycles: %d\n", marker8 - marker7);
}
threadblock_barrier(/*barrier_id=*/1, /*count=*/NUM_WARPS);
threadblock_barrier(/*barrier_id=*/1, /*count=*/NUM_WARPS_);
}
#endif
if (HW_TID() == 0) {
@@ -623,7 +879,7 @@ void thread_block_matmul_gemmini(kernel_arg_t *__UNIFORM__ arg,
}
#endif
}
threadblock_barrier(/*barrier_id=*/0, /*count=*/NUM_WARPS);
threadblock_barrier(/*barrier_id=*/0, /*count=*/NUM_WARPS_);
vx_tmc(0);
}
@@ -640,7 +896,7 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
int main() {
kernel_arg_t *arg = (kernel_arg_t *)KERNEL_ARG_DEV_MEM_ADDR;
const uint32_t num_threads_in_cluster = vx_num_threads() * vx_num_warps() * CORES_PER_CLUSTER;
const uint32_t num_threads_in_cluster = NUM_THREADS_IN_CLUSTER; // vx_num_threads() * vx_num_warps() * CORES_PER_CLUSTER;
const uint32_t grid_size = num_threads_in_cluster * NUM_CLUSTERS;
#ifdef RADIANCE
vx_spawn_tasks_cluster(grid_size, (vx_spawn_tasks_cb)kernel_body, arg);