|
|
|
|
@@ -8,6 +8,7 @@
|
|
|
|
|
|
|
|
|
|
#define NUM_CLUSTERS 1
|
|
|
|
|
// #define FP32
|
|
|
|
|
#define HOPPER
|
|
|
|
|
|
|
|
|
|
#ifdef FP32
|
|
|
|
|
// fp32
|
|
|
|
|
@@ -39,7 +40,16 @@ typedef float mem_elem_t;
|
|
|
|
|
#define TILE_MN 8192
|
|
|
|
|
#define TILE_MK 16384
|
|
|
|
|
#define TILE_NK 8192
|
|
|
|
|
#define NUM_WARPS_ 8
|
|
|
|
|
#define NUM_THREADS_ 8
|
|
|
|
|
|
|
|
|
|
// Cluster geometry, selected by the HOPPER switch.
// In both branches NUM_THREADS_IN_CLUSTER equals
// NUM_THREADS_ * NUM_WARPS_ * CORES_PER_CLUSTER_:
//   HOPPER:     8 * 8 * 4 = 256
//   otherwise:  8 * 8 * 8 = 512
// The macros are independent literals, so update them together.
#ifdef HOPPER
|
|
|
|
|
#define NUM_THREADS_IN_CLUSTER 256
|
|
|
|
|
#define CORES_PER_CLUSTER_ 4
|
|
|
|
|
#else
|
|
|
|
|
#define NUM_THREADS_IN_CLUSTER 512
|
|
|
|
|
#define CORES_PER_CLUSTER_ 8
|
|
|
|
|
#endif
|
|
|
|
|
|
|
|
|
|
#define SMEM_ADDR_Q0 ((mem_elem_t * const) 0xff000000)
|
|
|
|
|
#define SMEM_ADDR_Q1 ((mem_elem_t * const) 0xff008000)
|
|
|
|
|
@@ -111,6 +121,7 @@ void thread_block_matmul_gemmini(kernel_arg_t *__UNIFORM__ arg,
|
|
|
|
|
uint32_t swish_dur = 0;
|
|
|
|
|
#endif
|
|
|
|
|
rd_cycles_force(marker0);
|
|
|
|
|
MARK_BEG();
|
|
|
|
|
|
|
|
|
|
const uint32_t dim_m = arg->dim_m;
|
|
|
|
|
const uint32_t dim_n = arg->dim_n;
|
|
|
|
|
@@ -122,20 +133,22 @@ void thread_block_matmul_gemmini(kernel_arg_t *__UNIFORM__ arg,
|
|
|
|
|
constexpr uint32_t a_elems_per_thread = TILE_MK / num_threads_in_cluster;
|
|
|
|
|
constexpr uint32_t b_elems_per_thread = TILE_NK / num_threads_in_cluster;
|
|
|
|
|
constexpr uint32_t c_elems_per_thread = TILE_MN / num_threads_in_cluster;
|
|
|
|
|
constexpr uint32_t e_mult = sizeof(uint32_t) / sizeof(smem_elem_t);
|
|
|
|
|
const uint32_t hw_tid = tid_in_threadblock % num_threads_in_cluster;
|
|
|
|
|
|
|
|
|
|
// the dram coordinates are (i1 + i0, j1 + j0). i0 and j0 are both spatially mapped only.
|
|
|
|
|
const uint32_t j0 = HW_TID() % DIM;
|
|
|
|
|
const uint32_t i0 = (HW_TID() / DIM) % DIM;
|
|
|
|
|
const uint32_t j0 = HW_TID() * e_mult % DIM;
|
|
|
|
|
const uint32_t i0 = (HW_TID() * e_mult / DIM) % DIM;
|
|
|
|
|
|
|
|
|
|
// j1 is both spatially and temporally mapped. j1 increases every iteration.
|
|
|
|
|
const uint32_t j1_idx = (HW_TID() / DIM / DIM) * DIM; // A: % TILE_K, B: % TILE_N, C: % TILE_N
|
|
|
|
|
const uint32_t j1_idx = (HW_TID() * e_mult / DIM / DIM) * DIM; // A: % TILE_K, B: % TILE_N, C: % TILE_N
|
|
|
|
|
// every iteration, j1 increases by j1_stride
|
|
|
|
|
constexpr uint32_t j1_stride = (num_threads_in_cluster / DIM / DIM) * DIM; // mod TILE_W after stride
|
|
|
|
|
constexpr uint32_t j1_stride = (num_threads_in_cluster * e_mult / DIM / DIM) * DIM; // mod TILE_W after stride
|
|
|
|
|
|
|
|
|
|
// i1 is only temporally mapped. i1 increments every one or more iterations
|
|
|
|
|
constexpr uint32_t i1_stride = DIM; // step per increment (increment doesnt happen every iteration)
|
|
|
|
|
constexpr uint32_t i1_iters = (DIM * DIM * (TILE_K / DIM)) / num_threads_in_cluster; // num of iters before striding
|
|
|
|
|
constexpr uint32_t i1_iters_a = (DIM * DIM * (TILE_K / DIM)) / num_threads_in_cluster / e_mult; // num of iters before striding
|
|
|
|
|
constexpr uint32_t i1_iters_b = (DIM * DIM * (TILE_N / DIM)) / num_threads_in_cluster / e_mult; // num of iters before striding
|
|
|
|
|
|
|
|
|
|
const uint32_t num_tile_rows_per_tb = num_tiles_m / NUM_CLUSTERS;
|
|
|
|
|
|
|
|
|
|
@@ -158,33 +171,50 @@ void thread_block_matmul_gemmini(kernel_arg_t *__UNIFORM__ arg,
|
|
|
|
|
mem_elem_t * const smem_acc_tile_start = SMEM_ADDR_Q2 + hw_tid;
|
|
|
|
|
#endif
|
|
|
|
|
|
|
|
|
|
#ifndef FP32
|
|
|
|
|
#ifdef HOPPER
|
|
|
|
|
constexpr uint32_t every_iter = j1_stride;
|
|
|
|
|
const uint32_t every_4iters_a = i1_stride * dim_k;
|
|
|
|
|
const uint32_t runtime_const_a = i0 * dim_k + j1_idx + j0;
|
|
|
|
|
const uint32_t every_2iters_b = i1_stride * dim_n;
|
|
|
|
|
const uint32_t runtime_const_b = i0 * dim_n + j1_idx + j0;
|
|
|
|
|
#else
|
|
|
|
|
constexpr uint32_t every_iter = j1_stride;
|
|
|
|
|
const uint32_t every_2iters_a = i1_stride * dim_k;
|
|
|
|
|
const uint32_t runtime_const_a = i0 * dim_k + j1_idx + j0;
|
|
|
|
|
const uint32_t every_iter_b = i1_stride * dim_n;
|
|
|
|
|
const uint32_t runtime_const_b = i0 * dim_n + j1_idx + j0;
|
|
|
|
|
#endif
|
|
|
|
|
#else
|
|
|
|
|
constexpr uint32_t every_iter = j1_stride;
|
|
|
|
|
const uint32_t every_2iters_a = i1_stride * dim_k;
|
|
|
|
|
const uint32_t runtime_const_a = i0 * dim_k + j1_idx + j0;
|
|
|
|
|
const uint32_t every_2iters_b = i1_stride * dim_n;
|
|
|
|
|
const uint32_t runtime_const_b = i0 * dim_n + j1_idx + j0;
|
|
|
|
|
#endif
|
|
|
|
|
|
|
|
|
|
for (int tile_k = 0; tile_k < num_tiles_k; tile_k += 1) {
|
|
|
|
|
// TODO: double buffer
|
|
|
|
|
rd_cycles(marker1);
|
|
|
|
|
|
|
|
|
|
#ifdef HARDCODE
|
|
|
|
|
// #if (TILE_MK / NUM_THREADS / NUM_WARPS / CORES_PER_CLUSTER) != 8
|
|
|
|
|
// #if (TILE_MK / NUM_THREADS / NUM_WARPS_ / CORES_PER_CLUSTER) != 8
|
|
|
|
|
// #error CANNOT UNROLL
|
|
|
|
|
// #endif
|
|
|
|
|
|
|
|
|
|
constexpr uint32_t every_iter = j1_stride;
|
|
|
|
|
const uint32_t every_2iters_a = i1_stride * (dim_k * sizeof(smem_elem_t) / 4);
|
|
|
|
|
const uint32_t runtime_const_a = i0 * (dim_k * sizeof(smem_elem_t) / 4) + j1_idx + j0;
|
|
|
|
|
const uint32_t every_2iters_b = i1_stride * dim_n;
|
|
|
|
|
const uint32_t runtime_const_b = i0 * dim_n + j1_idx + j0;
|
|
|
|
|
|
|
|
|
|
const mem_elem_t * const dram_a_tile_start = (const mem_elem_t * const) (A + tile_i * TILE_M * dim_k + tile_k * TILE_K + runtime_const_a);
|
|
|
|
|
const mem_elem_t * const dram_b_tile_start = (const mem_elem_t * const) (B + tile_k * TILE_K * dim_n + tile_j * TILE_N + runtime_const_b);
|
|
|
|
|
#ifdef DBUF
|
|
|
|
|
mem_elem_t * const smem_a_tile_start = (mem_elem_t * const) (((tile_k & 1) ? SMEM_ADDR_Q1 : SMEM_ADDR_Q0) + HW_TID());
|
|
|
|
|
mem_elem_t * const smem_b_tile_start = (mem_elem_t * const) (((tile_k & 1) ? SMEM_ADDR_Q3 : SMEM_ADDR_Q2) + HW_TID());
|
|
|
|
|
mem_elem_t * const smem_a_tile_start = (mem_elem_t * const) (((tile_k & 1) ? SMEM_ADDR_Q1 : SMEM_ADDR_Q0) + HW_TID() * e_mult);
|
|
|
|
|
mem_elem_t * const smem_b_tile_start = (mem_elem_t * const) (((tile_k & 1) ? SMEM_ADDR_Q3 : SMEM_ADDR_Q2) + HW_TID() * e_mult);
|
|
|
|
|
#else
|
|
|
|
|
mem_elem_t * const smem_a_tile_start = (mem_elem_t * const) (SMEM_ADDR_Q0 + HW_TID());
|
|
|
|
|
mem_elem_t * const smem_b_tile_start = (mem_elem_t * const) (SMEM_ADDR_Q3 + HW_TID());
|
|
|
|
|
mem_elem_t * const smem_a_tile_start = (mem_elem_t * const) (SMEM_ADDR_Q0 + HW_TID() * e_mult);
|
|
|
|
|
mem_elem_t * const smem_b_tile_start = (mem_elem_t * const) (SMEM_ADDR_Q3 + HW_TID() * e_mult);
|
|
|
|
|
#endif
|
|
|
|
|
|
|
|
|
|
{
|
|
|
|
|
#ifndef REGBLOCK
|
|
|
|
|
/*
|
|
|
|
|
smem_a_tile_start[0 * num_threads_in_cluster + hw_tid] = \
|
|
|
|
|
dram_a_tile_start[every_iter * 0 + every_2iters_a * 0];
|
|
|
|
|
smem_a_tile_start[1 * num_threads_in_cluster + hw_tid] = \
|
|
|
|
|
@@ -218,7 +248,175 @@ void thread_block_matmul_gemmini(kernel_arg_t *__UNIFORM__ arg,
|
|
|
|
|
dram_b_tile_start[every_iter * 0 + every_2iters_b * 3];
|
|
|
|
|
smem_b_tile_start[7 * num_threads_in_cluster + hw_tid] = \
|
|
|
|
|
dram_b_tile_start[every_iter * 1 + every_2iters_b * 3];
|
|
|
|
|
*/
|
|
|
|
|
#else
|
|
|
|
|
#ifndef FP32
|
|
|
|
|
#ifdef HOPPER
|
|
|
|
|
mem_elem_t v0 = dram_a_tile_start[every_iter * 0 + every_4iters_a * 0];
|
|
|
|
|
mem_elem_t v1 = dram_a_tile_start[every_iter * 1 + every_4iters_a * 0];
|
|
|
|
|
mem_elem_t v2 = dram_a_tile_start[every_iter * 2 + every_4iters_a * 0];
|
|
|
|
|
mem_elem_t v3 = dram_a_tile_start[every_iter * 3 + every_4iters_a * 0];
|
|
|
|
|
smem_a_tile_start[0 * num_threads_in_cluster] = v0;
|
|
|
|
|
smem_a_tile_start[1 * num_threads_in_cluster] = v1;
|
|
|
|
|
smem_a_tile_start[2 * num_threads_in_cluster] = v2;
|
|
|
|
|
smem_a_tile_start[3 * num_threads_in_cluster] = v3;
|
|
|
|
|
|
|
|
|
|
v0 = dram_a_tile_start[every_iter * 0 + every_4iters_a * 1];
|
|
|
|
|
v1 = dram_a_tile_start[every_iter * 1 + every_4iters_a * 1];
|
|
|
|
|
v2 = dram_a_tile_start[every_iter * 2 + every_4iters_a * 1];
|
|
|
|
|
v3 = dram_a_tile_start[every_iter * 3 + every_4iters_a * 1];
|
|
|
|
|
smem_a_tile_start[4 * num_threads_in_cluster] = v0;
|
|
|
|
|
smem_a_tile_start[5 * num_threads_in_cluster] = v1;
|
|
|
|
|
smem_a_tile_start[6 * num_threads_in_cluster] = v2;
|
|
|
|
|
smem_a_tile_start[7 * num_threads_in_cluster] = v3;
|
|
|
|
|
|
|
|
|
|
v0 = dram_a_tile_start[every_iter * 0 + every_4iters_a * 2];
|
|
|
|
|
v1 = dram_a_tile_start[every_iter * 1 + every_4iters_a * 2];
|
|
|
|
|
v2 = dram_a_tile_start[every_iter * 2 + every_4iters_a * 2];
|
|
|
|
|
v3 = dram_a_tile_start[every_iter * 3 + every_4iters_a * 2];
|
|
|
|
|
smem_a_tile_start[8 * num_threads_in_cluster] = v0;
|
|
|
|
|
smem_a_tile_start[9 * num_threads_in_cluster] = v1;
|
|
|
|
|
smem_a_tile_start[10 * num_threads_in_cluster] = v2;
|
|
|
|
|
smem_a_tile_start[11 * num_threads_in_cluster] = v3;
|
|
|
|
|
|
|
|
|
|
v0 = dram_a_tile_start[every_iter * 0 + every_4iters_a * 3];
|
|
|
|
|
v1 = dram_a_tile_start[every_iter * 1 + every_4iters_a * 3];
|
|
|
|
|
v2 = dram_a_tile_start[every_iter * 2 + every_4iters_a * 3];
|
|
|
|
|
v3 = dram_a_tile_start[every_iter * 3 + every_4iters_a * 3];
|
|
|
|
|
smem_a_tile_start[12 * num_threads_in_cluster] = v0;
|
|
|
|
|
smem_a_tile_start[13 * num_threads_in_cluster] = v1;
|
|
|
|
|
smem_a_tile_start[14 * num_threads_in_cluster] = v2;
|
|
|
|
|
smem_a_tile_start[15 * num_threads_in_cluster] = v3;
|
|
|
|
|
|
|
|
|
|
v0 = dram_a_tile_start[every_iter * 0 + every_4iters_a * 4];
|
|
|
|
|
v1 = dram_a_tile_start[every_iter * 1 + every_4iters_a * 4];
|
|
|
|
|
v2 = dram_a_tile_start[every_iter * 2 + every_4iters_a * 4];
|
|
|
|
|
v3 = dram_a_tile_start[every_iter * 3 + every_4iters_a * 4];
|
|
|
|
|
smem_a_tile_start[16 * num_threads_in_cluster] = v0;
|
|
|
|
|
smem_a_tile_start[17 * num_threads_in_cluster] = v1;
|
|
|
|
|
smem_a_tile_start[18 * num_threads_in_cluster] = v2;
|
|
|
|
|
smem_a_tile_start[19 * num_threads_in_cluster] = v3;
|
|
|
|
|
|
|
|
|
|
v0 = dram_a_tile_start[every_iter * 0 + every_4iters_a * 5];
|
|
|
|
|
v1 = dram_a_tile_start[every_iter * 1 + every_4iters_a * 5];
|
|
|
|
|
v2 = dram_a_tile_start[every_iter * 2 + every_4iters_a * 5];
|
|
|
|
|
v3 = dram_a_tile_start[every_iter * 3 + every_4iters_a * 5];
|
|
|
|
|
smem_a_tile_start[20 * num_threads_in_cluster] = v0;
|
|
|
|
|
smem_a_tile_start[21 * num_threads_in_cluster] = v1;
|
|
|
|
|
smem_a_tile_start[22 * num_threads_in_cluster] = v2;
|
|
|
|
|
smem_a_tile_start[23 * num_threads_in_cluster] = v3;
|
|
|
|
|
|
|
|
|
|
v0 = dram_a_tile_start[every_iter * 0 + every_4iters_a * 6];
|
|
|
|
|
v1 = dram_a_tile_start[every_iter * 1 + every_4iters_a * 6];
|
|
|
|
|
v2 = dram_a_tile_start[every_iter * 2 + every_4iters_a * 6];
|
|
|
|
|
v3 = dram_a_tile_start[every_iter * 3 + every_4iters_a * 6];
|
|
|
|
|
smem_a_tile_start[24 * num_threads_in_cluster] = v0;
|
|
|
|
|
smem_a_tile_start[25 * num_threads_in_cluster] = v1;
|
|
|
|
|
smem_a_tile_start[26 * num_threads_in_cluster] = v2;
|
|
|
|
|
smem_a_tile_start[27 * num_threads_in_cluster] = v3;
|
|
|
|
|
|
|
|
|
|
v0 = dram_a_tile_start[every_iter * 0 + every_4iters_a * 7];
|
|
|
|
|
v1 = dram_a_tile_start[every_iter * 1 + every_4iters_a * 7];
|
|
|
|
|
v2 = dram_a_tile_start[every_iter * 2 + every_4iters_a * 7];
|
|
|
|
|
v3 = dram_a_tile_start[every_iter * 3 + every_4iters_a * 7];
|
|
|
|
|
smem_a_tile_start[28 * num_threads_in_cluster] = v0;
|
|
|
|
|
smem_a_tile_start[29 * num_threads_in_cluster] = v1;
|
|
|
|
|
smem_a_tile_start[30 * num_threads_in_cluster] = v2;
|
|
|
|
|
smem_a_tile_start[31 * num_threads_in_cluster] = v3;
|
|
|
|
|
|
|
|
|
|
// --------------------
|
|
|
|
|
|
|
|
|
|
v0 = dram_b_tile_start[every_iter * 0 + every_2iters_b * 0];
|
|
|
|
|
v1 = dram_b_tile_start[every_iter * 1 + every_2iters_b * 0];
|
|
|
|
|
v2 = dram_b_tile_start[every_iter * 0 + every_2iters_b * 1];
|
|
|
|
|
v3 = dram_b_tile_start[every_iter * 1 + every_2iters_b * 1];
|
|
|
|
|
smem_b_tile_start[0 * num_threads_in_cluster] = v0;
|
|
|
|
|
smem_b_tile_start[1 * num_threads_in_cluster] = v1;
|
|
|
|
|
smem_b_tile_start[2 * num_threads_in_cluster] = v2;
|
|
|
|
|
smem_b_tile_start[3 * num_threads_in_cluster] = v3;
|
|
|
|
|
|
|
|
|
|
v0 = dram_b_tile_start[every_iter * 0 + every_2iters_b * 2];
|
|
|
|
|
v1 = dram_b_tile_start[every_iter * 1 + every_2iters_b * 2];
|
|
|
|
|
v2 = dram_b_tile_start[every_iter * 0 + every_2iters_b * 3];
|
|
|
|
|
v3 = dram_b_tile_start[every_iter * 1 + every_2iters_b * 3];
|
|
|
|
|
smem_b_tile_start[4 * num_threads_in_cluster] = v0;
|
|
|
|
|
smem_b_tile_start[5 * num_threads_in_cluster] = v1;
|
|
|
|
|
smem_b_tile_start[6 * num_threads_in_cluster] = v2;
|
|
|
|
|
smem_b_tile_start[7 * num_threads_in_cluster] = v3;
|
|
|
|
|
|
|
|
|
|
v0 = dram_b_tile_start[every_iter * 0 + every_2iters_b * 4];
|
|
|
|
|
v1 = dram_b_tile_start[every_iter * 1 + every_2iters_b * 4];
|
|
|
|
|
v2 = dram_b_tile_start[every_iter * 0 + every_2iters_b * 5];
|
|
|
|
|
v3 = dram_b_tile_start[every_iter * 1 + every_2iters_b * 5];
|
|
|
|
|
smem_b_tile_start[8 * num_threads_in_cluster] = v0;
|
|
|
|
|
smem_b_tile_start[9 * num_threads_in_cluster] = v1;
|
|
|
|
|
smem_b_tile_start[10 * num_threads_in_cluster] = v2;
|
|
|
|
|
smem_b_tile_start[11 * num_threads_in_cluster] = v3;
|
|
|
|
|
|
|
|
|
|
v0 = dram_b_tile_start[every_iter * 0 + every_2iters_b * 6];
|
|
|
|
|
v1 = dram_b_tile_start[every_iter * 1 + every_2iters_b * 6];
|
|
|
|
|
v2 = dram_b_tile_start[every_iter * 0 + every_2iters_b * 7];
|
|
|
|
|
v3 = dram_b_tile_start[every_iter * 1 + every_2iters_b * 7];
|
|
|
|
|
smem_b_tile_start[12 * num_threads_in_cluster] = v0;
|
|
|
|
|
smem_b_tile_start[13 * num_threads_in_cluster] = v1;
|
|
|
|
|
smem_b_tile_start[14 * num_threads_in_cluster] = v2;
|
|
|
|
|
smem_b_tile_start[15 * num_threads_in_cluster] = v3;
|
|
|
|
|
#else
|
|
|
|
|
mem_elem_t v0 = dram_a_tile_start[every_iter * 0 + every_2iters_a * 0];
|
|
|
|
|
mem_elem_t v1 = dram_a_tile_start[every_iter * 1 + every_2iters_a * 0];
|
|
|
|
|
mem_elem_t v2 = dram_a_tile_start[every_iter * 0 + every_2iters_a * 1];
|
|
|
|
|
mem_elem_t v3 = dram_a_tile_start[every_iter * 1 + every_2iters_a * 1];
|
|
|
|
|
smem_a_tile_start[0 * num_threads_in_cluster] = v0;
|
|
|
|
|
smem_a_tile_start[1 * num_threads_in_cluster] = v1;
|
|
|
|
|
smem_a_tile_start[2 * num_threads_in_cluster] = v2;
|
|
|
|
|
smem_a_tile_start[3 * num_threads_in_cluster] = v3;
|
|
|
|
|
|
|
|
|
|
v0 = dram_b_tile_start[every_iter * 0 + every_iter_b * 0];
|
|
|
|
|
v1 = dram_b_tile_start[every_iter * 0 + every_iter_b * 1];
|
|
|
|
|
v2 = dram_b_tile_start[every_iter * 0 + every_iter_b * 2];
|
|
|
|
|
v3 = dram_b_tile_start[every_iter * 0 + every_iter_b * 3];
|
|
|
|
|
smem_b_tile_start[0 * num_threads_in_cluster] = v0;
|
|
|
|
|
smem_b_tile_start[1 * num_threads_in_cluster] = v1;
|
|
|
|
|
smem_b_tile_start[2 * num_threads_in_cluster] = v2;
|
|
|
|
|
smem_b_tile_start[3 * num_threads_in_cluster] = v3;
|
|
|
|
|
|
|
|
|
|
v0 = dram_a_tile_start[every_iter * 0 + every_2iters_a * 2];
|
|
|
|
|
v1 = dram_a_tile_start[every_iter * 1 + every_2iters_a * 2];
|
|
|
|
|
v2 = dram_a_tile_start[every_iter * 0 + every_2iters_a * 3];
|
|
|
|
|
v3 = dram_a_tile_start[every_iter * 1 + every_2iters_a * 3];
|
|
|
|
|
smem_a_tile_start[4 * num_threads_in_cluster] = v0;
|
|
|
|
|
smem_a_tile_start[5 * num_threads_in_cluster] = v1;
|
|
|
|
|
smem_a_tile_start[6 * num_threads_in_cluster] = v2;
|
|
|
|
|
smem_a_tile_start[7 * num_threads_in_cluster] = v3;
|
|
|
|
|
|
|
|
|
|
v0 = dram_b_tile_start[every_iter * 0 + every_iter_b * 4];
|
|
|
|
|
v1 = dram_b_tile_start[every_iter * 0 + every_iter_b * 5];
|
|
|
|
|
v2 = dram_b_tile_start[every_iter * 0 + every_iter_b * 6];
|
|
|
|
|
v3 = dram_b_tile_start[every_iter * 0 + every_iter_b * 7];
|
|
|
|
|
smem_b_tile_start[4 * num_threads_in_cluster] = v0;
|
|
|
|
|
smem_b_tile_start[5 * num_threads_in_cluster] = v1;
|
|
|
|
|
smem_b_tile_start[6 * num_threads_in_cluster] = v2;
|
|
|
|
|
smem_b_tile_start[7 * num_threads_in_cluster] = v3;
|
|
|
|
|
|
|
|
|
|
v0 = dram_a_tile_start[every_iter * 0 + every_2iters_a * 4];
|
|
|
|
|
v1 = dram_a_tile_start[every_iter * 1 + every_2iters_a * 4];
|
|
|
|
|
v2 = dram_a_tile_start[every_iter * 0 + every_2iters_a * 5];
|
|
|
|
|
v3 = dram_a_tile_start[every_iter * 1 + every_2iters_a * 5];
|
|
|
|
|
smem_a_tile_start[8 * num_threads_in_cluster] = v0;
|
|
|
|
|
smem_a_tile_start[9 * num_threads_in_cluster] = v1;
|
|
|
|
|
smem_a_tile_start[10 * num_threads_in_cluster] = v2;
|
|
|
|
|
smem_a_tile_start[11 * num_threads_in_cluster] = v3;
|
|
|
|
|
|
|
|
|
|
v0 = dram_a_tile_start[every_iter * 0 + every_2iters_a * 6];
|
|
|
|
|
v1 = dram_a_tile_start[every_iter * 1 + every_2iters_a * 6];
|
|
|
|
|
v2 = dram_a_tile_start[every_iter * 0 + every_2iters_a * 7];
|
|
|
|
|
v3 = dram_a_tile_start[every_iter * 1 + every_2iters_a * 7];
|
|
|
|
|
smem_a_tile_start[12 * num_threads_in_cluster] = v0;
|
|
|
|
|
smem_a_tile_start[13 * num_threads_in_cluster] = v1;
|
|
|
|
|
smem_a_tile_start[14 * num_threads_in_cluster] = v2;
|
|
|
|
|
smem_a_tile_start[15 * num_threads_in_cluster] = v3;
|
|
|
|
|
#endif
|
|
|
|
|
#else
|
|
|
|
|
mem_elem_t v0 = dram_a_tile_start[every_iter * 0 + every_2iters_a * 0];
|
|
|
|
|
mem_elem_t v1 = dram_a_tile_start[every_iter * 1 + every_2iters_a * 0];
|
|
|
|
|
mem_elem_t v2 = dram_a_tile_start[every_iter * 0 + every_2iters_a * 1];
|
|
|
|
|
@@ -264,14 +462,14 @@ void thread_block_matmul_gemmini(kernel_arg_t *__UNIFORM__ arg,
|
|
|
|
|
smem_a_tile_start[10 * num_threads_in_cluster] = v2;
|
|
|
|
|
smem_a_tile_start[11 * num_threads_in_cluster] = v3;
|
|
|
|
|
|
|
|
|
|
// v0 = dram_b_tile_start[every_iter * 0 + every_2iters_b * 4];
|
|
|
|
|
// v1 = dram_b_tile_start[every_iter * 1 + every_2iters_b * 4];
|
|
|
|
|
// v2 = dram_b_tile_start[every_iter * 0 + every_2iters_b * 5];
|
|
|
|
|
// v3 = dram_b_tile_start[every_iter * 1 + every_2iters_b * 5];
|
|
|
|
|
// smem_b_tile_start[8 * num_threads_in_cluster] = v0;
|
|
|
|
|
// smem_b_tile_start[9 * num_threads_in_cluster] = v1;
|
|
|
|
|
// smem_b_tile_start[10 * num_threads_in_cluster] = v2;
|
|
|
|
|
// smem_b_tile_start[11 * num_threads_in_cluster] = v3;
|
|
|
|
|
v0 = dram_b_tile_start[every_iter * 0 + every_2iters_b * 4];
|
|
|
|
|
v1 = dram_b_tile_start[every_iter * 1 + every_2iters_b * 4];
|
|
|
|
|
v2 = dram_b_tile_start[every_iter * 0 + every_2iters_b * 5];
|
|
|
|
|
v3 = dram_b_tile_start[every_iter * 1 + every_2iters_b * 5];
|
|
|
|
|
smem_b_tile_start[8 * num_threads_in_cluster] = v0;
|
|
|
|
|
smem_b_tile_start[9 * num_threads_in_cluster] = v1;
|
|
|
|
|
smem_b_tile_start[10 * num_threads_in_cluster] = v2;
|
|
|
|
|
smem_b_tile_start[11 * num_threads_in_cluster] = v3;
|
|
|
|
|
|
|
|
|
|
v0 = dram_a_tile_start[every_iter * 0 + every_2iters_a * 6];
|
|
|
|
|
v1 = dram_a_tile_start[every_iter * 1 + every_2iters_a * 6];
|
|
|
|
|
@@ -282,14 +480,15 @@ void thread_block_matmul_gemmini(kernel_arg_t *__UNIFORM__ arg,
|
|
|
|
|
smem_a_tile_start[14 * num_threads_in_cluster] = v2;
|
|
|
|
|
smem_a_tile_start[15 * num_threads_in_cluster] = v3;
|
|
|
|
|
|
|
|
|
|
// v0 = dram_b_tile_start[every_iter * 0 + every_2iters_b * 6];
|
|
|
|
|
// v1 = dram_b_tile_start[every_iter * 1 + every_2iters_b * 6];
|
|
|
|
|
// v2 = dram_b_tile_start[every_iter * 0 + every_2iters_b * 7];
|
|
|
|
|
// v3 = dram_b_tile_start[every_iter * 1 + every_2iters_b * 7];
|
|
|
|
|
// smem_b_tile_start[12 * num_threads_in_cluster] = v0;
|
|
|
|
|
// smem_b_tile_start[13 * num_threads_in_cluster] = v1;
|
|
|
|
|
// smem_b_tile_start[14 * num_threads_in_cluster] = v2;
|
|
|
|
|
// smem_b_tile_start[15 * num_threads_in_cluster] = v3;
|
|
|
|
|
v0 = dram_b_tile_start[every_iter * 0 + every_2iters_b * 6];
|
|
|
|
|
v1 = dram_b_tile_start[every_iter * 1 + every_2iters_b * 6];
|
|
|
|
|
v2 = dram_b_tile_start[every_iter * 0 + every_2iters_b * 7];
|
|
|
|
|
v3 = dram_b_tile_start[every_iter * 1 + every_2iters_b * 7];
|
|
|
|
|
smem_b_tile_start[12 * num_threads_in_cluster] = v0;
|
|
|
|
|
smem_b_tile_start[13 * num_threads_in_cluster] = v1;
|
|
|
|
|
smem_b_tile_start[14 * num_threads_in_cluster] = v2;
|
|
|
|
|
smem_b_tile_start[15 * num_threads_in_cluster] = v3;
|
|
|
|
|
#endif
|
|
|
|
|
#endif
|
|
|
|
|
}
|
|
|
|
|
#else
|
|
|
|
|
@@ -353,7 +552,7 @@ void thread_block_matmul_gemmini(kernel_arg_t *__UNIFORM__ arg,
|
|
|
|
|
|
|
|
|
|
rd_cycles(marker2);
|
|
|
|
|
// cluster wide barrier to wait for A and B loads to complete
|
|
|
|
|
threadblock_barrier(/*barrier_id=*/0, /*count=*/NUM_WARPS);
|
|
|
|
|
threadblock_barrier(/*barrier_id=*/0, /*count=*/NUM_WARPS_);
|
|
|
|
|
rd_cycles(marker3);
|
|
|
|
|
if (HW_TID() == 0) {
|
|
|
|
|
#ifdef DBUF
|
|
|
|
|
@@ -396,7 +595,7 @@ void thread_block_matmul_gemmini(kernel_arg_t *__UNIFORM__ arg,
|
|
|
|
|
}
|
|
|
|
|
rd_cycles(marker4);
|
|
|
|
|
#ifndef DBUF
|
|
|
|
|
threadblock_barrier(/*barrier_id=*/0, /*count=*/NUM_WARPS);
|
|
|
|
|
threadblock_barrier(/*barrier_id=*/0, /*count=*/NUM_WARPS_);
|
|
|
|
|
#endif
|
|
|
|
|
rd_cycles(marker5);
|
|
|
|
|
|
|
|
|
|
@@ -411,7 +610,7 @@ void thread_block_matmul_gemmini(kernel_arg_t *__UNIFORM__ arg,
|
|
|
|
|
smem_acc_tile_start[thread_i * s] = smem_c_tile_start[hw_tid + s * thread_i];
|
|
|
|
|
}
|
|
|
|
|
} else {
|
|
|
|
|
#if (TILE_NK / NUM_THREADS / NUM_WARPS / CORES_PER_CLUSTER) != 8
|
|
|
|
|
#if (TILE_NK / NUM_THREADS_ / NUM_WARPS_ / CORES_PER_CLUSTER_) != 8
|
|
|
|
|
#error CANNOT UNROLL
|
|
|
|
|
#endif
|
|
|
|
|
for (int thread_i = 0; thread_i < c_elems_per_thread; thread_i += 8) {
|
|
|
|
|
@@ -458,7 +657,7 @@ void thread_block_matmul_gemmini(kernel_arg_t *__UNIFORM__ arg,
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
#ifdef OFFLOAD_ACCUMULATE
|
|
|
|
|
threadblock_barrier(/*barrier_id=*/0, /*count=*/NUM_WARPS);
|
|
|
|
|
threadblock_barrier(/*barrier_id=*/0, /*count=*/NUM_WARPS_);
|
|
|
|
|
rd_cycles(marker6);
|
|
|
|
|
// mvout to scratchpad for activation
|
|
|
|
|
if (HW_TID() == 0) {
|
|
|
|
|
@@ -474,21 +673,76 @@ void thread_block_matmul_gemmini(kernel_arg_t *__UNIFORM__ arg,
|
|
|
|
|
#endif
|
|
|
|
|
gemmini_fence();
|
|
|
|
|
}
|
|
|
|
|
threadblock_barrier(/*barrier_id=*/0, /*count=*/NUM_WARPS);
|
|
|
|
|
threadblock_barrier(/*barrier_id=*/0, /*count=*/NUM_WARPS_);
|
|
|
|
|
#endif
|
|
|
|
|
rd_cycles(marker7);
|
|
|
|
|
|
|
|
|
|
// move out to dram
|
|
|
|
|
#ifdef HARDCODE
|
|
|
|
|
// #if (TILE_MN / NUM_THREADS / NUM_WARPS / CORES_PER_CLUSTER) != 8
|
|
|
|
|
// #if (TILE_MN / NUM_THREADS / NUM_WARPS_ / CORES_PER_CLUSTER) != 8
|
|
|
|
|
// #error CANNOT UNROLL
|
|
|
|
|
// #endif
|
|
|
|
|
constexpr uint32_t every_iter = j1_stride;
|
|
|
|
|
const uint32_t every_2iters = i1_stride * dim_n;
|
|
|
|
|
const uint32_t runtime_const = i0 * dim_n + j1_idx + j0;
|
|
|
|
|
mem_elem_t * const dram_c_tile_start = (mem_elem_t * const) (C + tile_i * TILE_M * dim_n + tile_j * TILE_N + runtime_const);
|
|
|
|
|
mem_elem_t * const dram_c_tile_start = (mem_elem_t * const) (C + tile_i * TILE_M * dim_n + tile_j * TILE_N + runtime_const_b);
|
|
|
|
|
|
|
|
|
|
#ifdef REGBLOCK
|
|
|
|
|
#ifndef FP32
|
|
|
|
|
#ifdef HOPPER
|
|
|
|
|
mem_elem_t v0 = smem_acc_tile_start[0 * num_threads_in_cluster];
|
|
|
|
|
mem_elem_t v1 = smem_acc_tile_start[1 * num_threads_in_cluster];
|
|
|
|
|
mem_elem_t v2 = smem_acc_tile_start[2 * num_threads_in_cluster];
|
|
|
|
|
mem_elem_t v3 = smem_acc_tile_start[3 * num_threads_in_cluster];
|
|
|
|
|
dram_c_tile_start[every_iter * 0 + every_2iters_b * 0] = v0;
|
|
|
|
|
dram_c_tile_start[every_iter * 1 + every_2iters_b * 0] = v1;
|
|
|
|
|
dram_c_tile_start[every_iter * 0 + every_2iters_b * 1] = v2;
|
|
|
|
|
dram_c_tile_start[every_iter * 1 + every_2iters_b * 1] = v3;
|
|
|
|
|
|
|
|
|
|
v0 = smem_acc_tile_start[4 * num_threads_in_cluster];
|
|
|
|
|
v1 = smem_acc_tile_start[5 * num_threads_in_cluster];
|
|
|
|
|
v2 = smem_acc_tile_start[6 * num_threads_in_cluster];
|
|
|
|
|
v3 = smem_acc_tile_start[7 * num_threads_in_cluster];
|
|
|
|
|
dram_c_tile_start[every_iter * 0 + every_2iters_b * 2] = v0;
|
|
|
|
|
dram_c_tile_start[every_iter * 1 + every_2iters_b * 2] = v1;
|
|
|
|
|
dram_c_tile_start[every_iter * 0 + every_2iters_b * 3] = v2;
|
|
|
|
|
dram_c_tile_start[every_iter * 1 + every_2iters_b * 3] = v3;
|
|
|
|
|
|
|
|
|
|
v0 = smem_acc_tile_start[8 * num_threads_in_cluster];
|
|
|
|
|
v1 = smem_acc_tile_start[9 * num_threads_in_cluster];
|
|
|
|
|
v2 = smem_acc_tile_start[10 * num_threads_in_cluster];
|
|
|
|
|
v3 = smem_acc_tile_start[11 * num_threads_in_cluster];
|
|
|
|
|
dram_c_tile_start[every_iter * 0 + every_2iters_b * 4] = v0;
|
|
|
|
|
dram_c_tile_start[every_iter * 1 + every_2iters_b * 4] = v1;
|
|
|
|
|
dram_c_tile_start[every_iter * 0 + every_2iters_b * 5] = v2;
|
|
|
|
|
dram_c_tile_start[every_iter * 1 + every_2iters_b * 5] = v3;
|
|
|
|
|
|
|
|
|
|
v0 = smem_acc_tile_start[12 * num_threads_in_cluster];
|
|
|
|
|
v1 = smem_acc_tile_start[13 * num_threads_in_cluster];
|
|
|
|
|
v2 = smem_acc_tile_start[14 * num_threads_in_cluster];
|
|
|
|
|
v3 = smem_acc_tile_start[15 * num_threads_in_cluster];
|
|
|
|
|
dram_c_tile_start[every_iter * 0 + every_2iters_b * 6] = v0;
|
|
|
|
|
dram_c_tile_start[every_iter * 1 + every_2iters_b * 6] = v1;
|
|
|
|
|
dram_c_tile_start[every_iter * 0 + every_2iters_b * 7] = v2;
|
|
|
|
|
dram_c_tile_start[every_iter * 1 + every_2iters_b * 7] = v3;
|
|
|
|
|
#else // not HOPPER
|
|
|
|
|
mem_elem_t v0 = smem_acc_tile_start[0 * num_threads_in_cluster];
|
|
|
|
|
mem_elem_t v1 = smem_acc_tile_start[1 * num_threads_in_cluster];
|
|
|
|
|
mem_elem_t v2 = smem_acc_tile_start[2 * num_threads_in_cluster];
|
|
|
|
|
mem_elem_t v3 = smem_acc_tile_start[3 * num_threads_in_cluster];
|
|
|
|
|
dram_c_tile_start[every_iter * 0 + every_iter_b * 0] = v0;
|
|
|
|
|
dram_c_tile_start[every_iter * 0 + every_iter_b * 1] = v1;
|
|
|
|
|
dram_c_tile_start[every_iter * 0 + every_iter_b * 2] = v2;
|
|
|
|
|
dram_c_tile_start[every_iter * 0 + every_iter_b * 3] = v3;
|
|
|
|
|
|
|
|
|
|
v0 = smem_acc_tile_start[4 * num_threads_in_cluster];
|
|
|
|
|
v1 = smem_acc_tile_start[5 * num_threads_in_cluster];
|
|
|
|
|
v2 = smem_acc_tile_start[6 * num_threads_in_cluster];
|
|
|
|
|
v3 = smem_acc_tile_start[7 * num_threads_in_cluster];
|
|
|
|
|
dram_c_tile_start[every_iter * 0 + every_iter_b * 4] = v0;
|
|
|
|
|
dram_c_tile_start[every_iter * 0 + every_iter_b * 5] = v1;
|
|
|
|
|
dram_c_tile_start[every_iter * 0 + every_iter_b * 6] = v2;
|
|
|
|
|
dram_c_tile_start[every_iter * 0 + every_iter_b * 7] = v3;
|
|
|
|
|
#endif // HOPPER
|
|
|
|
|
|
|
|
|
|
#else // FP32
|
|
|
|
|
mem_elem_t v0 = smem_acc_tile_start[0 * num_threads_in_cluster];
|
|
|
|
|
mem_elem_t v1 = smem_acc_tile_start[1 * num_threads_in_cluster];
|
|
|
|
|
mem_elem_t v2 = smem_acc_tile_start[2 * num_threads_in_cluster];
|
|
|
|
|
@@ -503,10 +757,10 @@ void thread_block_matmul_gemmini(kernel_arg_t *__UNIFORM__ arg,
|
|
|
|
|
rd_cycles_force(swish_end);
|
|
|
|
|
swish_dur += swish_end - swish_start;
|
|
|
|
|
#endif
|
|
|
|
|
dram_c_tile_start[every_iter * 0 + every_2iters * 0] = v0;
|
|
|
|
|
dram_c_tile_start[every_iter * 1 + every_2iters * 0] = v1;
|
|
|
|
|
dram_c_tile_start[every_iter * 0 + every_2iters * 1] = v2;
|
|
|
|
|
dram_c_tile_start[every_iter * 1 + every_2iters * 1] = v3;
|
|
|
|
|
dram_c_tile_start[every_iter * 0 + every_2iters_b * 0] = v0;
|
|
|
|
|
dram_c_tile_start[every_iter * 1 + every_2iters_b * 0] = v1;
|
|
|
|
|
dram_c_tile_start[every_iter * 0 + every_2iters_b * 1] = v2;
|
|
|
|
|
dram_c_tile_start[every_iter * 1 + every_2iters_b * 1] = v3;
|
|
|
|
|
|
|
|
|
|
v0 = smem_acc_tile_start[4 * num_threads_in_cluster];
|
|
|
|
|
v1 = smem_acc_tile_start[5 * num_threads_in_cluster];
|
|
|
|
|
@@ -521,10 +775,10 @@ void thread_block_matmul_gemmini(kernel_arg_t *__UNIFORM__ arg,
|
|
|
|
|
rd_cycles_force(swish_end);
|
|
|
|
|
swish_dur += swish_end - swish_start;
|
|
|
|
|
#endif
|
|
|
|
|
dram_c_tile_start[every_iter * 0 + every_2iters * 2] = v0;
|
|
|
|
|
dram_c_tile_start[every_iter * 1 + every_2iters * 2] = v1;
|
|
|
|
|
dram_c_tile_start[every_iter * 0 + every_2iters * 3] = v2;
|
|
|
|
|
dram_c_tile_start[every_iter * 1 + every_2iters * 3] = v3;
|
|
|
|
|
dram_c_tile_start[every_iter * 0 + every_2iters_b * 2] = v0;
|
|
|
|
|
dram_c_tile_start[every_iter * 1 + every_2iters_b * 2] = v1;
|
|
|
|
|
dram_c_tile_start[every_iter * 0 + every_2iters_b * 3] = v2;
|
|
|
|
|
dram_c_tile_start[every_iter * 1 + every_2iters_b * 3] = v3;
|
|
|
|
|
|
|
|
|
|
// v0 = smem_acc_tile_start[8 * num_threads_in_cluster];
|
|
|
|
|
// v1 = smem_acc_tile_start[9 * num_threads_in_cluster];
|
|
|
|
|
@@ -543,6 +797,7 @@ void thread_block_matmul_gemmini(kernel_arg_t *__UNIFORM__ arg,
|
|
|
|
|
// dram_c_tile_start[every_iter * 1 + every_2iters * 6] = v1;
|
|
|
|
|
// dram_c_tile_start[every_iter * 0 + every_2iters * 7] = v2;
|
|
|
|
|
// dram_c_tile_start[every_iter * 1 + every_2iters * 7] = v3;
|
|
|
|
|
#endif // FP16/FP32
|
|
|
|
|
#else
|
|
|
|
|
dram_c_tile_start[every_iter * 0 + every_2iters * 0] = \
|
|
|
|
|
smem_acc_tile_start[0 * num_threads_in_cluster];
|
|
|
|
|
@@ -577,7 +832,8 @@ void thread_block_matmul_gemmini(kernel_arg_t *__UNIFORM__ arg,
|
|
|
|
|
}
|
|
|
|
|
// last thread block complete
|
|
|
|
|
if (threadblock_id == NUM_CLUSTERS - 1) {
|
|
|
|
|
threadblock_barrier(/*barrier_id=*/0, /*count=*/NUM_WARPS);
|
|
|
|
|
threadblock_barrier(/*barrier_id=*/0, /*count=*/NUM_WARPS_);
|
|
|
|
|
MARK_END();
|
|
|
|
|
rd_cycles_force(marker9);
|
|
|
|
|
#ifdef POWER
|
|
|
|
|
if (HW_TID() == 0) {
|
|
|
|
|
@@ -610,7 +866,7 @@ void thread_block_matmul_gemmini(kernel_arg_t *__UNIFORM__ arg,
|
|
|
|
|
#endif
|
|
|
|
|
PRINTF("dram mvout cycles: %d\n", marker8 - marker7);
|
|
|
|
|
}
|
|
|
|
|
threadblock_barrier(/*barrier_id=*/1, /*count=*/NUM_WARPS);
|
|
|
|
|
threadblock_barrier(/*barrier_id=*/1, /*count=*/NUM_WARPS_);
|
|
|
|
|
}
|
|
|
|
|
#endif
|
|
|
|
|
if (HW_TID() == 0) {
|
|
|
|
|
@@ -623,7 +879,7 @@ void thread_block_matmul_gemmini(kernel_arg_t *__UNIFORM__ arg,
|
|
|
|
|
}
|
|
|
|
|
#endif
|
|
|
|
|
}
|
|
|
|
|
threadblock_barrier(/*barrier_id=*/0, /*count=*/NUM_WARPS);
|
|
|
|
|
threadblock_barrier(/*barrier_id=*/0, /*count=*/NUM_WARPS_);
|
|
|
|
|
vx_tmc(0);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
@@ -640,7 +896,7 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
|
|
|
|
|
int main() {
|
|
|
|
|
kernel_arg_t *arg = (kernel_arg_t *)KERNEL_ARG_DEV_MEM_ADDR;
|
|
|
|
|
|
|
|
|
|
const uint32_t num_threads_in_cluster = vx_num_threads() * vx_num_warps() * CORES_PER_CLUSTER;
|
|
|
|
|
const uint32_t num_threads_in_cluster = NUM_THREADS_IN_CLUSTER; // vx_num_threads() * vx_num_warps() * CORES_PER_CLUSTER;
|
|
|
|
|
const uint32_t grid_size = num_threads_in_cluster * NUM_CLUSTERS;
|
|
|
|
|
#ifdef RADIANCE
|
|
|
|
|
vx_spawn_tasks_cluster(grid_size, (vx_spawn_tasks_cb)kernel_body, arg);
|
|
|
|
|
@@ -650,4 +906,4 @@ int main() {
|
|
|
|
|
vx_spawn_tasks_contiguous(grid_size, (vx_spawn_tasks_cb)kernel_body, arg);
|
|
|
|
|
#endif
|
|
|
|
|
return 0;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|