Files
kernels/lib/gemmini/include/gemmini.h
2025-01-29 21:22:41 -08:00

3612 lines
142 KiB
C

// See LICENSE for license details.
#ifndef SRC_MAIN_C_GEMMINI_H
#define SRC_MAIN_C_GEMMINI_H
#undef abs
#include <stdint.h>
#include <stdlib.h>
#include <stdio.h>
#include <math.h>
#include <limits.h>
#include <stdbool.h>
#include "include/gemmini_params.h"
#define GEMMINI_ASSERTIONS
// Accelerator interface
#include "rocc-software/src/xcustom.h"
// Counter Definition
#include "include/gemmini_counter.h"
#define k_CONFIG 0
#define k_MVIN2 1
#define k_MVIN 2
#define k_MVOUT 3
#define k_COMPUTE_PRELOADED 4
#define k_COMPUTE_ACCUMULATE 5
#define k_PRELOAD 6
#define k_FLUSH 7
#define k_LOOP_WS 8
#define k_LOOP_WS_CONFIG_BOUNDS 9
#define k_LOOP_WS_CONFIG_ADDRS_AB 10
#define k_LOOP_WS_CONFIG_ADDRS_DC 11
#define k_LOOP_WS_CONFIG_STRIDES_AB 12
#define k_LOOP_WS_CONFIG_STRIDES_DC 13
#define k_MVIN3 14
#define k_COUNTER 126
#define k_LOOP_CONV_WS 15
#define k_LOOP_CONV_WS_CONFIG_1 16
#define k_LOOP_CONV_WS_CONFIG_2 17
#define k_LOOP_CONV_WS_CONFIG_3 18
#define k_LOOP_CONV_WS_CONFIG_4 19
#define k_LOOP_CONV_WS_CONFIG_5 20
#define k_LOOP_CONV_WS_CONFIG_6 21
// CLKGATE_EN: 22
#define k_MVOUT_SPAD 23
#define k_LOOP_WS_CONFIG_SPAD_AB 24
#define k_LOOP_WS_CONFIG_SPAD_C 25
#define CONFIG_EX 0
#define CONFIG_LD 1
#define CONFIG_ST 2
#define CONFIG_BERT 3
#define GARBAGE_ADDR ((uint32_t)(-1))
#define OUTPUT_STATIONARY 0
#define WEIGHT_STATIONARY 1
#define NO_ACTIVATION 0
#define RELU 1
#define LAYERNORM 2
#define IGELU 3
#define SOFTMAX 4
#ifdef ELEM_T_IS_FLOAT
elem_t elem_t_bits_to_elem_t(elem_t_bits x) {
union {
elem_t_bits b;
elem_t f;
} un;
un.b = x;
return un.f;
}
elem_t_bits elem_t_to_elem_t_bits(elem_t x) {
union {
elem_t_bits b;
elem_t f;
} un;
un.f = x;
return un.b;
}
acc_t acc_t_bits_to_acc_t(acc_t_bits x) {
union {
acc_t_bits b;
acc_t f;
} un;
un.b = x;
return un.f;
}
acc_t_bits acc_t_to_acc_t_bits(acc_t x) {
union {
acc_t_bits b;
acc_t f;
} un;
un.f = x;
return un.b;
}
bool elem_t_isnan(elem_t x) {
elem_t_bits bits = elem_t_to_elem_t_bits(x);
uint64_t exp = (bits >> (ELEM_T_SIG_BITS-1)) & (((uint64_t)1 << ELEM_T_EXP_BITS) - 1);
uint64_t sig = bits & (((uint64_t)1 << ELEM_T_SIG_BITS) - 1);
bool is_nan_or_inf = exp == (((uint64_t)1 << ELEM_T_EXP_BITS) - 1);
bool is_not_inf = sig != 0;
return is_nan_or_inf && is_not_inf;
}
bool acc_t_isnan(acc_t x) {
acc_t_bits bits = acc_t_to_acc_t_bits(x);
uint64_t exp = (bits >> (ACC_T_SIG_BITS-1)) & (((uint64_t)1 << ACC_T_EXP_BITS) - 1);
uint64_t sig = bits & (((uint64_t)1 << ACC_T_SIG_BITS) - 1);
bool is_nan_or_inf = exp == (((uint64_t)1 << ACC_T_EXP_BITS) - 1);
bool is_not_inf = sig != 0;
return is_nan_or_inf && is_not_inf;
}
#endif
#ifdef HAS_MVIN_SCALE
static scale_t scale_t_bits_to_scale_t(scale_t_bits x) {
union {
scale_t_bits b;
scale_t f;
} un;
un.b = x;
return un.f;
}
static scale_t_bits scale_t_to_scale_t_bits(scale_t x) {
union {
scale_t_bits b;
scale_t f;
} un;
un.f = x;
return un.b;
}
#else
#define scale_t_to_scale_t_bits(x) 0
#endif
#ifdef HAS_MVIN_ACC_SCALE
static scale_acc_t scale_acc_t_bits_to_scale_acc_t(scale_acc_t_bits x) {
union {
scale_acc_t_bits b;
scale_acc_t f;
} un;
un.b = x;
return un.f;
}
static scale_acc_t_bits scale_acc_t_to_scale_acc_t_bits(scale_acc_t x) {
union {
scale_acc_t_bits b;
scale_acc_t f;
} un;
un.f = x;
return un.b;
}
#endif
static acc_scale_t acc_scale_t_bits_to_acc_scale_t(acc_scale_t_bits x) {
union {
acc_scale_t_bits b;
acc_scale_t f;
} un;
un.b = x;
return un.f;
}
static acc_scale_t_bits acc_scale_t_to_acc_scale_t_bits(acc_scale_t x) {
union {
acc_scale_t_bits b;
acc_scale_t f;
} un;
un.f = x;
return un.b;
}
#define ROCC_INSTRUCTION_RS1_RS2(x, rs1, rs2, funct) \
ROCC_INSTRUCTION_0_R_R(x, rs1, rs2, funct)
// mvin and mvout
#define gemmini_extended_mvin(dram_addr, spad_addr, cols, rows) \
ROCC_INSTRUCTION_RS1_RS2(XCUSTOM_ACC, dram_addr, ((uint64_t)(rows) << (ADDR_LEN + 16)) | ((uint64_t)(cols) << ADDR_LEN) | (spad_addr), k_MVIN)
#define gemmini_extended_mvin2(dram_addr, spad_addr, cols, rows) \
ROCC_INSTRUCTION_RS1_RS2(XCUSTOM_ACC, dram_addr, ((uint64_t)(rows) << (ADDR_LEN + 16)) | ((uint64_t)(cols) << ADDR_LEN) | (spad_addr), k_MVIN2)
#define gemmini_extended_mvin3(dram_addr, spad_addr, cols, rows) \
ROCC_INSTRUCTION_RS1_RS2(XCUSTOM_ACC, dram_addr, ((uint64_t)(rows) << (ADDR_LEN + 16)) | ((uint64_t)(cols) << ADDR_LEN) | (spad_addr), k_MVIN3)
#define gemmini_block_mvin(dram_addr, spad_addr, len) \
gemmini_extended_mvin(dram_addr, spad_addr, (len) * DIM, DIM)
#define gemmini_mvin(dram_addr, spad_addr) \
gemmini_extended_mvin(dram_addr, spad_addr, DIM, DIM)
#define gemmini_extended_mvout(dram_addr, spad_addr, cols, rows) \
ROCC_INSTRUCTION_RS1_RS2(XCUSTOM_ACC, dram_addr, ((uint64_t)(rows) << (ADDR_LEN + 16)) | ((uint64_t)(cols) << ADDR_LEN) | (uint64_t)(spad_addr), k_MVOUT)
#define gemmini_extended_mvout_spad(dst_addr, dst_stride, src_addr, cols, rows) \
ROCC_INSTRUCTION_RS1_RS2(XCUSTOM_ACC, ((uint64_t)(dst_stride) << 32) | (uint64_t)(dst_addr), ((uint64_t)(rows) << (ADDR_LEN + 16)) | ((uint64_t)(cols) << ADDR_LEN) | (uint64_t)(src_addr), k_MVOUT_SPAD)
#define gemmini_mvout_spad(dst_addr, src_addr) \
gemmini_extended_mvout_spad(dst_addr, 1, src_addr, DIM, DIM)
#define gemmini_mvout(dram_addr, spad_addr) \
gemmini_extended_mvout(dram_addr, spad_addr, DIM, DIM)
// compute
#define gemmini_extended_compute_preloaded(A, BD, A_cols, A_rows, BD_cols, BD_rows) \
ROCC_INSTRUCTION_RS1_RS2(XCUSTOM_ACC, ((uint64_t)(A_rows) << (ADDR_LEN + 16)) | ((uint64_t)(A_cols) << ADDR_LEN) | (uint64_t)(A), ((uint64_t)(BD_rows) << (ADDR_LEN + 16)) | ((uint64_t)(BD_cols) << ADDR_LEN) | (uint64_t)(BD), k_COMPUTE_PRELOADED)
#define gemmini_extended_compute_accumulated(A, BD, A_cols, A_rows, BD_cols, BD_rows) \
ROCC_INSTRUCTION_RS1_RS2(XCUSTOM_ACC, ((uint64_t)(A_rows) << (ADDR_LEN + 16)) | ((uint64_t)(A_cols) << ADDR_LEN) | (uint64_t)(A), ((uint64_t)(BD_rows) << (ADDR_LEN + 16)) | ((uint64_t)(BD_cols) << ADDR_LEN) | (uint64_t)(BD), k_COMPUTE_ACCUMULATE)
#define gemmini_compute_preloaded(A, BD) \
gemmini_extended_compute_preloaded(A, BD, DIM, DIM, DIM, DIM)
#define gemmini_compute_accumulated(A, BD) \
gemmini_extended_compute_accumulated(A, BD, DIM, DIM, DIM, DIM)
// preload
#define gemmini_extended_preload(BD, C, BD_cols, BD_rows, C_cols, C_rows) \
ROCC_INSTRUCTION_RS1_RS2(XCUSTOM_ACC, ((uint64_t)(BD_rows) << (ADDR_LEN + 16)) | ((uint64_t)(BD_cols) << ADDR_LEN) | (uint64_t)(BD), ((uint64_t)(C_rows) << (ADDR_LEN + 16)) | ((uint64_t)(C_cols) << ADDR_LEN) | (uint64_t)(C), k_PRELOAD)
#define gemmini_preload(BD, C) \
gemmini_extended_preload(BD, C, DIM, DIM, DIM, DIM)
#define gemmini_preload_zeros(C) \
gemmini_preload(GARBAGE_ADDR, C)
// config
#define gemmini_extended3_config_ex(dataflow, sys_act, sys_shift, sys_acc_scale, C_stride, A_stride, A_transpose, B_transpose, set_only_strides) \
ROCC_INSTRUCTION_RS1_RS2(XCUSTOM_ACC, ((uint64_t)acc_scale_t_to_acc_scale_t_bits((acc_scale_t)sys_acc_scale) << 32) | ((uint64_t)(A_stride) << 16) | (B_transpose << 9) | (A_transpose << 8) | ((set_only_strides) << 7) | ((sys_act) << 3) | ((dataflow) << 2) | CONFIG_EX, ((uint64_t)(C_stride) << 48) | (sys_shift), k_CONFIG); \
#define gemmini_extended2_config_ex(dataflow, sys_act, sys_shift, A_stride, A_transpose, B_transpose) \
gemmini_extended3_config_ex(dataflow, sys_act, sys_shift, ACC_SCALE_IDENTITY, 1, A_stride, A_transpose, B_transpose, false)
#define gemmini_extended_config_ex(dataflow, sys_act, sys_shift, A_stride, A_transpose, B_transpose) \
gemmini_extended2_config_ex(dataflow, sys_act, sys_shift, A_stride, A_transpose, B_transpose)
#define gemmini_config_ex(dataflow, sys_act, sys_shift) \
gemmini_extended_config_ex(dataflow, sys_act, sys_shift, 1, 0, 0)
// Note: The "pixel_repeats" parameter below is still experimental, andthere is
// a high chance that it will be removed in future releases.
#define gemmini_extended5_config_ld(stride, scale, shrunk, block_mvin_stride, pixel_repeats, id) \
ROCC_INSTRUCTION_RS1_RS2(XCUSTOM_ACC, ((uint64_t)(scale_t_to_scale_t_bits(scale)) << 32) | ((uint64_t)(block_mvin_stride) << 16) | ((uint64_t)(pixel_repeats) << 8) | ((id) << 3) | ((shrunk) << 2) | CONFIG_LD, stride, k_CONFIG)
#define gemmini_extended4_config_ld(stride, scale, shrunk, block_mvin_stride, id) \
gemmini_extended5_config_ld(stride, scale, shrunk, block_mvin_stride, 1, id) \
#define gemmini_extended3_config_ld(stride, scale, shrunk, id) \
gemmini_extended4_config_ld(stride, scale, shrunk, DIM, id)
#define gemmini_extended2_config_ld(stride, scale, shrunk) \
gemmini_extended3_config_ld(stride, scale, shrunk, 0)
#define gemmini_extended_config_ld(stride, scale) \
gemmini_extended2_config_ld(stride, scale, false)
#define gemmini_config_ld(stride) \
gemmini_extended_config_ld(stride, MVIN_SCALE_IDENTITY)
#define gemmini_extended2_config_st(stride, acc_act, acc_scale, pool_stride, pool_size, pool_out_dim, porows, pocols, orows, ocols, upad, lpad) \
ROCC_INSTRUCTION_RS1_RS2(XCUSTOM_ACC, ((uint64_t)(ocols) << 56) | ((uint64_t)(orows) << 48) | ((uint64_t)(pocols) << 40) | ((uint64_t)(porows) << 32) | ((uint64_t)(pool_out_dim) << 24) | ((uint64_t)(lpad) << 10) | ((uint64_t)(upad) << 8) | ((uint64_t)(pool_size) << 6) | ((uint64_t)(pool_stride) << 4) | ((uint64_t)(acc_act) << 2) | CONFIG_ST, ((uint64_t)acc_scale_t_to_acc_scale_t_bits((acc_scale_t)acc_scale) << 32) | ((uint32_t)stride), k_CONFIG)
#define gemmini_extended_config_st(stride, acc_act, acc_scale) \
gemmini_extended2_config_st(stride, acc_act, acc_scale, 0, 0, 0, 0, 0, 0, 0, 0, 0)
#define gemmini_config_st(stride) \
gemmini_extended_config_st(stride, NO_ACTIVATION, ACC_SCALE_IDENTITY)
#define gemmini_config_norm(q_const, q_const_type, set_stats_id_only, act_msb, stat_id, igelu_qb, igelu_qc) \
ROCC_INSTRUCTION_RS1_RS2(XCUSTOM_ACC, (((uint64_t) ((uint32_t) q_const)) << 32) | ((q_const_type & 1) << 18) | ((set_stats_id_only & 1) << 17) | ((act_msb & 1) << 16) | ((uint64_t)stat_id << 8) | CONFIG_BERT, ((uint64_t)((uint32_t)(igelu_qc)) << 32) | ((uint64_t)((uint32_t)(igelu_qb))), k_CONFIG)
// flush
#define gemmini_flush(skip) \
ROCC_INSTRUCTION_RS1_RS2(XCUSTOM_ACC, skip, 0, k_FLUSH)
// fence
#define gemmini_fence() asm volatile("fence")
// Counter access
#define gemmini_counter_access(rd, config_reg) \
{ \
uint32_t _placeholder; \
ROCC_INSTRUCTION(XCUSTOM_ACC, rd, config_reg, _placeholder, k_COUNTER) \
}
// Read counter
static uint32_t counter_read(size_t index) {
uint32_t config_reg = (index & 0x7) << 4;
uint32_t res;
gemmini_counter_access(res, config_reg);
return res;
}
// Configure counter to take a new signal
static void counter_configure(size_t index, size_t counter_code) {
int non_incremental = counter_code > INCREMENTAL_COUNTERS;
if (non_incremental) {
counter_code -= INCREMENTAL_COUNTERS;
}
uint32_t config_reg = (index & 0x7) << 4 | 0x8 | (counter_code & 0x3f) << 12 | non_incremental << 31;
uint32_t placeholder;
gemmini_counter_access(placeholder, config_reg);
}
// Take a snapshot
static void counter_snapshot_take() {
uint32_t config_reg = 0x4;
uint32_t placeholder;
gemmini_counter_access(placeholder, config_reg);
}
// Counter snapshot reset
static void counter_snapshot_reset() {
uint32_t config_reg = 0x2;
uint32_t placeholder;
gemmini_counter_access(placeholder, config_reg);
}
// Counter module reset
static void counter_reset() {
uint32_t config_reg = 0x1;
uint32_t placeholder;
gemmini_counter_access(placeholder, config_reg);
}
int ceil_divide_int(int a, int b){
int c = (a % b == 0) ? ((int)(a/b)) :(((int)(a/b)) + 1);
if(a < b) c = 1;
return c;
}
// weight-stationary matmul loop
#define gemmini_loop_ws(I, J, K, pad_I, pad_J, pad_K, A, B, D, C, A_stride, B_stride, D_stride, C_stride, A_transpose, B_transpose, full_C, low_D, ex_accumulate, act, a_spad_id, b_spad_id, is_resadd) \
{ \
ROCC_INSTRUCTION_RS1_RS2(XCUSTOM_ACC, ((uint64_t)(pad_K) << 32) | ((uint64_t)(pad_J) << 16) | (uint64_t)(pad_I), ((uint64_t)(K) << 32) | ((uint64_t)(J) << 16) | (uint64_t)(I), k_LOOP_WS_CONFIG_BOUNDS) \
ROCC_INSTRUCTION_RS1_RS2(XCUSTOM_ACC, A, B, k_LOOP_WS_CONFIG_ADDRS_AB) \
ROCC_INSTRUCTION_RS1_RS2(XCUSTOM_ACC, D, C, k_LOOP_WS_CONFIG_ADDRS_DC) \
ROCC_INSTRUCTION_RS1_RS2(XCUSTOM_ACC, A_stride, B_stride, k_LOOP_WS_CONFIG_STRIDES_AB) \
ROCC_INSTRUCTION_RS1_RS2(XCUSTOM_ACC, D_stride, C_stride, k_LOOP_WS_CONFIG_STRIDES_DC) \
ROCC_INSTRUCTION_RS1_RS2(XCUSTOM_ACC, ((uint64_t)(a_spad_id) << 18) | ((uint64_t)(b_spad_id) << 16) | ((uint64_t)(act) << 8) | ((low_D) << 2) | ((full_C) << 1) | (ex_accumulate), ((is_resadd) << 2) | ((B_transpose) << 1) | (A_transpose), k_LOOP_WS) \
}
#define gemmini_loop_ws_spad(I, J, K, pad_I, pad_J, pad_K, A, B, D, C, A_transpose, B_transpose, full_C, low_D, ex_accumulate, act, a_spad_id, b_spad_id, is_resadd, skips) \
{ \
ROCC_INSTRUCTION_RS1_RS2(XCUSTOM_ACC, ((uint64_t)(pad_K) << 32) | ((uint64_t)(pad_J) << 16) | (uint64_t)(pad_I), ((uint64_t)(K) << 32) | ((uint64_t)(J) << 16) | (uint64_t)(I), k_LOOP_WS_CONFIG_BOUNDS) \
ROCC_INSTRUCTION_RS1_RS2(XCUSTOM_ACC, A, B, k_LOOP_WS_CONFIG_SPAD_AB) \
ROCC_INSTRUCTION_RS1_RS2(XCUSTOM_ACC, ((uint64_t)(a_spad_id) << 18) | ((uint64_t)(b_spad_id) << 16) | ((uint64_t)(act) << 8) | ((low_D) << 2) | ((full_C) << 1) | (ex_accumulate), ((uint64_t)(C) << 32) | 0x200U | (skips) | ((is_resadd) << 2) | ((B_transpose) << 1) | (A_transpose), k_LOOP_WS) \
}
// weight-stationary conv loop
#define gemmini_loop_conv_ws(batch_size, in_row_dim, in_col_dim, in_channels, out_channels, out_row_dim, out_col_dim, pool_out_row_dim, pool_out_col_dim, stride, padding, kernel_dim, kernel_dilation, pool_size, pool_stride, pool_padding, batches, porows, pocols, pochs, krows, kcols, kchs, lpad, rpad, upad, dpad, plpad, prpad, pupad, pdpad, orows, ocols, weights, output, bias, input, no_bias, no_pool, downsample, wrot180, input_dilated, activation, trans_output_1203, trans_weight_1203, trans_weight_0132, trans_input_3120, max_pixels_per_row, in_stride, weight_stride, out_stride, dw, a_spad_id, b_spad_id) \
{ \
ROCC_INSTRUCTION_RS1_RS2(XCUSTOM_ACC, ((uint64_t)(out_channels) << 48) | ((uint64_t)(in_channels) << 32) | ((uint64_t)(in_row_dim) << 16) | (uint64_t)(batch_size), \
((uint64_t)(padding) << 56) | ((uint64_t)(stride) << 48) | ((uint64_t)(out_col_dim) << 32) | ((uint64_t)(pool_out_row_dim) << 16) | (uint64_t)(out_row_dim), k_LOOP_CONV_WS_CONFIG_1) \
ROCC_INSTRUCTION_RS1_RS2(XCUSTOM_ACC, ((uint64_t)(kernel_dim) << 48) | ((uint64_t)(pool_out_col_dim) << 32) | ((uint64_t)(pool_size) << 16) | ((uint64_t)(pool_stride) << 8) | (uint64_t)(pool_padding), \
((uint64_t)(batches) << 48) | ((uint64_t)(porows) << 32) | ((uint64_t)(pocols) << 16) | (uint64_t)(pochs), k_LOOP_CONV_WS_CONFIG_2) \
ROCC_INSTRUCTION_RS1_RS2(XCUSTOM_ACC, ((uint64_t)(krows) << 48) | ((uint64_t)(kcols) << 32) | ((uint64_t)(kchs) << 16) | (uint64_t)(lpad), \
((uint64_t)(rpad) << 48) | ((uint64_t)(upad) << 32) | ((uint64_t)(dpad) << 24) | ((uint64_t)(plpad) << 16) | ((uint64_t)(in_col_dim)), k_LOOP_CONV_WS_CONFIG_3) \
ROCC_INSTRUCTION_RS1_RS2(XCUSTOM_ACC, ((uint64_t)(orows) << 48) | ((uint64_t)(prpad) << 32) | ((uint64_t)(pupad) << 21) | ((uint64_t)(pdpad) << 10) | (uint64_t)(kernel_dilation), \
((uint64_t)(in_stride) << 48) | ((uint64_t)(weight_stride) << 32) | ((uint64_t)(out_stride) << 16) | (uint64_t)(ocols), k_LOOP_CONV_WS_CONFIG_4) \
ROCC_INSTRUCTION_RS1_RS2(XCUSTOM_ACC, weights, \
output, k_LOOP_CONV_WS_CONFIG_5) \
ROCC_INSTRUCTION_RS1_RS2(XCUSTOM_ACC, bias, \
input, k_LOOP_CONV_WS_CONFIG_6) \
ROCC_INSTRUCTION_RS1_RS2(XCUSTOM_ACC, ((uint64_t)(a_spad_id) << 18) | ((uint64_t)(b_spad_id) << 16) | ((uint64_t)(max_pixels_per_row) << 8) | ((dw) << 6) | ((trans_input_3120) << 5) | ((trans_weight_0132) << 4) | ((trans_weight_1203) << 3) | ((trans_output_1203) << 2) | ((wrot180) << 1) | (no_bias), \
((activation) << 3)| ((input_dilated) << 2) | ((downsample) << 1) | (no_pool), \
k_LOOP_CONV_WS) \
}
// Tiling functions
static void sp_tiled_matmul_os(const elem_t * A, const elem_t * B, const void * D, void * C,
scale_t A_scale_factor, scale_t B_scale_factor, scale_acc_t D_scale_factor,
size_t I, size_t J, size_t K, size_t pad_I, size_t pad_J, size_t pad_K,
size_t A_row_stride, size_t B_row_stride, size_t D_row_stride, size_t C_row_stride,
bool a_transpose, bool b_transpose,
bool full_C, bool low_D,
bool no_bias, bool repeating_bias,
int act,
int a_spad_id, int b_spad_id) {
const uint32_t A_sp_addr_start = 0;
const uint32_t B_sp_addr_start = BANK_NUM * BANK_ROWS - K * J * DIM;
const uint32_t D_sp_addr_start = 1 << (ADDR_LEN-1);
const uint32_t C_sp_addr_start = (3 << (ADDR_LEN-2)) | (full_C << (ADDR_LEN-3));
const int A_blocks = K <= MAX_BLOCK_LEN ? K : MAX_BLOCK_LEN;
const int B_blocks = J <= MAX_BLOCK_LEN ? J : MAX_BLOCK_LEN;
const int D_blocks = J <= MAX_BLOCK_LEN_ACC ? J : MAX_BLOCK_LEN_ACC;
// Move-in D
if (D != NULL && !no_bias) {
const size_t D_stride = repeating_bias ? 0 : D_row_stride * sizeof(acc_t);
gemmini_extended_config_ld(D_stride, D_scale_factor);
for (size_t i = 0; i < I; i++) {
for (size_t j = 0; j < J; j += D_blocks) {
const size_t bias_row = repeating_bias ? 0 : i;
const acc_t * const D_dram_addr = (acc_t *)D + (bias_row * D_row_stride + j)*DIM;
const uint32_t D_sp_addr_acc = D_sp_addr_start + (i*J + j)*DIM;
const size_t blocks = j + D_blocks <= J ? D_blocks : J-j;
const size_t cols = blocks * DIM - (j + blocks >= J ? pad_J : 0);
const size_t rows = DIM - (i == I-1 ? pad_I : 0);
gemmini_extended_mvin(D_dram_addr, D_sp_addr_acc, cols, rows);
}
}
}
// Move-in B
gemmini_extended_config_ld(B_row_stride * sizeof(elem_t), B_scale_factor);
for (size_t j = 0; j < J; j += B_blocks) {
for (size_t k = 0; k < K; k++) {
const elem_t * const B_dram_addr = B + (k*B_row_stride + j)*DIM;
const uint32_t B_sp_addr = B_sp_addr_start + (k*J + j)*DIM;
const size_t blocks = j + B_blocks <= J ? B_blocks : J-j;
const size_t cols = blocks * DIM - (j + blocks >= J ? pad_J : 0);
const size_t rows = DIM - (k == K-1 ? pad_K : 0);
gemmini_extended_mvin(B_dram_addr, B_sp_addr, cols, rows);
}
}
// Move-in A
gemmini_extended_config_ld(A_row_stride * sizeof(elem_t), A_scale_factor);
for (size_t i = 0; i < I; i++) {
for (size_t k = 0; k < K; k += A_blocks) {
const elem_t * const A_dram_addr = A + (i*A_row_stride + k)*DIM;
const uint32_t A_sp_addr = A_sp_addr_start + (i*K + k)*DIM;
const size_t blocks = k + A_blocks <= K ? A_blocks : K-k;
const size_t cols = blocks * DIM - (k + blocks >= K ? pad_K : 0);
const size_t rows = DIM - (i == I-1 ? pad_I : 0);
gemmini_extended_mvin(A_dram_addr, A_sp_addr, cols, rows);
}
}
for (size_t i = 0; i < I; i++) {
for (size_t j = 0; j < J; j++) {
const uint32_t C_sp_addr = C_sp_addr_start + (i*J + j)*DIM;
for (size_t k = 0; k < K; k++) {
const uint32_t A_sp_addr = A_sp_addr_start + (i*K + k)*DIM;
const uint32_t B_sp_addr = B_sp_addr_start + (k*J + j)*DIM;
uint32_t out_sp_addr = k == K-1 ? C_sp_addr : GARBAGE_ADDR;
// If we're not using a bias, then we want to overwrite what's in the
// accumulator, rather than writing over it
int no_bias_new_matrix = no_bias && D != NULL && k == K-1;
if (no_bias_new_matrix) {
out_sp_addr &= ~(1 << (ADDR_LEN-2));
}
const size_t A_cols = DIM - (k == K - 1 ? pad_K : 0);
const size_t A_rows = DIM - (i == I - 1 ? pad_I : 0);
const size_t B_cols = DIM - (j == J - 1 ? pad_J : 0);
const size_t B_rows = DIM - (k == K - 1 ? pad_K : 0);
const size_t C_cols = DIM - (j == J - 1 ? pad_J : 0);
const size_t C_rows = DIM - (i == I - 1 ? pad_I : 0);
gemmini_extended_preload(GARBAGE_ADDR, out_sp_addr, DIM, DIM, C_cols, C_rows);
if (k == 0) { // First iteration
gemmini_extended_compute_preloaded(A_sp_addr, B_sp_addr, A_cols, A_rows, B_cols, B_rows);
} else { // All other iterations
gemmini_extended_compute_accumulated(A_sp_addr, B_sp_addr, A_cols, A_rows, B_cols, B_rows);
}
}
}
}
// Move-out C
if (C != NULL) {
const size_t sizeof_C = full_C ? sizeof(acc_t) : sizeof(elem_t);
for (size_t i = 0; i < I; i++) {
for (size_t j = 0; j < J; j++) {
void * const C_dram_addr = (int8_t*)C + (i*C_row_stride + j)*DIM*sizeof_C;
const uint32_t C_sp_addr = C_sp_addr_start + (i*J + j)*DIM;
const size_t C_cols = DIM - (j == J - 1 ? pad_J : 0);
const size_t C_rows = DIM - (i == I - 1 ? pad_I : 0);
gemmini_extended_mvout(C_dram_addr, C_sp_addr, C_cols, C_rows);
}
}
}
}
static void sp_tiled_matmul_ws(const elem_t * A, const elem_t * B,
const void * D, void * C,
scale_t A_scale_factor, scale_t B_scale_factor, scale_acc_t D_scale_factor,
size_t I, size_t J, size_t K, size_t pad_I, size_t pad_J, size_t pad_K,
size_t A_row_stride, size_t B_row_stride, size_t D_row_stride, size_t C_row_stride,
bool a_transpose, bool b_transpose,
bool full_C, bool low_D,
bool no_bias, bool repeating_bias,
int act,
int a_spad_id, int b_spad_id) {
/*
const uint32_t A_sp_addr_start = 0;
const uint32_t B_sp_addr_start = BANK_NUM * BANK_ROWS - K * J * DIM;
const uint32_t D_sp_addr_start = 1 << (ADDR_LEN-1);
const uint32_t C_sp_addr_start = 3 << (ADDR_LEN-2) | (full_C << (ADDR_LEN-3));
const int A_blocks = a_transpose ? (I <= MAX_BLOCK_LEN ? I : MAX_BLOCK_LEN) :
(K <= MAX_BLOCK_LEN ? K : MAX_BLOCK_LEN);
const int B_blocks = b_transpose ? (K <= MAX_BLOCK_LEN ? K : MAX_BLOCK_LEN) :
(J <= MAX_BLOCK_LEN ? J : MAX_BLOCK_LEN);
const int D_blocks = low_D ? (J <= MAX_BLOCK_LEN ? J : MAX_BLOCK_LEN) :
(J <= MAX_BLOCK_LEN_ACC ? J : MAX_BLOCK_LEN_ACC);
const int C_blocks = full_C ? 1 : (J <= MAX_BLOCK_LEN ? J : MAX_BLOCK_LEN);
const size_t sizeof_D = low_D ? sizeof(elem_t) : sizeof(acc_t);
const size_t sizeof_C = full_C ? sizeof(acc_t) : sizeof(elem_t);
// Move-in D
if (D != NULL && !no_bias) {
for (size_t i = 0; i < I; i++) {
const size_t rows = DIM - (i == I-1 ? pad_I : 0);
for (size_t j = 0; j < J; j += D_blocks) {
const size_t bias_row = repeating_bias ? 0 : i;
const void * const D_dram_addr = (int8_t *)D + (bias_row * D_row_stride + j)*DIM*sizeof_D;
const uint32_t D_sp_addr_acc = D_sp_addr_start + (i*J + j)*DIM;
size_t blocks = j + D_blocks <= J ? D_blocks : J-j;
const size_t cols = blocks * DIM - (j + blocks >= J ? pad_J : 0);
gemmini_extended_mvin3(D_dram_addr, D_sp_addr_acc, cols, rows);
}
}
}
for (size_t k = 0; k < K; k++) {
for (size_t j = 0; j < J; j++) {
for (size_t i = 0; i < I; i++) {
const uint32_t A_sp_addr = a_transpose ? (A_sp_addr_start + (k*I + i)*DIM) :
(A_sp_addr_start + (i*K + k)*DIM);
const uint32_t B_sp_addr = b_transpose ? (B_sp_addr_start + (j*K + k)*DIM) :
(B_sp_addr_start + (k*J + j)*DIM);
const uint32_t C_sp_addr = C_sp_addr_start + (i*J + j)*DIM;
// Mvin A
if (a_transpose) {
if (j == 0 && i % A_blocks == 0) {
const elem_t * const A_dram_addr = A + (k*A_row_stride + i)*DIM;
const size_t blocks = i + A_blocks <= I ? A_blocks : I-i;
const size_t cols = blocks * DIM - (i + blocks >= I ? pad_I : 0);
const size_t rows = DIM - (k == K-1 ? pad_K : 0);
gemmini_extended_mvin(A_dram_addr, A_sp_addr, cols, rows);
}
} else {
if (j == 0 && k % A_blocks == 0) {
const elem_t * const A_dram_addr = A + (i*A_row_stride + k)*DIM;
const size_t blocks = k + A_blocks <= K ? A_blocks : K-k;
const size_t cols = blocks * DIM - (k + blocks >= K ? pad_K : 0);
const size_t rows = DIM - (i == I-1 ? pad_I : 0);
gemmini_extended_mvin(A_dram_addr, A_sp_addr, cols, rows);
}
}
// Mvin B
if (b_transpose) {
if (i == 0 && k % B_blocks == 0) {
const elem_t * const B_dram_addr = B + (j*B_row_stride + k)*DIM;
const size_t blocks = k + B_blocks <= K ? B_blocks : K-k;
const size_t cols = blocks * DIM - (k + blocks >= K ? pad_K : 0);
const size_t rows = DIM - (j == J-1 ? pad_J : 0);
gemmini_extended_mvin2(B_dram_addr, B_sp_addr, cols, rows);
}
} else {
if (i == 0 && j % B_blocks == 0) {
const elem_t * const B_dram_addr = B + (k*B_row_stride + j)*DIM;
const size_t blocks = j + B_blocks <= J ? B_blocks : J-j;
const size_t cols = blocks * DIM - (j + blocks >= J ? pad_J : 0);
const size_t rows = DIM - (k == K-1 ? pad_K : 0);
gemmini_extended_mvin2(B_dram_addr, B_sp_addr, cols, rows);
}
}
// Compute
{
uint32_t pre_sp_addr = i == 0 ? B_sp_addr : GARBAGE_ADDR;
uint32_t out_sp_addr = C_sp_addr;
// If we're not using a bias, then we want to overwrite what's in the
// accumulator, rather than writing over it
int no_bias_new_matrix = no_bias && D != NULL && k == 0;
if (no_bias_new_matrix) {
out_sp_addr &= ~(1 << (ADDR_LEN-2));
}
const size_t A_cols = DIM - (k == K - 1 ? pad_K : 0);
const size_t A_rows = DIM - (i == I - 1 ? pad_I : 0);
const size_t B_cols = DIM - (j == J - 1 ? pad_J : 0);
const size_t B_rows = DIM - (k == K - 1 ? pad_K : 0);
const size_t C_cols = DIM - (j == J - 1 ? pad_J : 0);
const size_t C_rows = DIM - (i == I - 1 ? pad_I : 0);
gemmini_extended_preload(pre_sp_addr, out_sp_addr, B_cols, B_rows, C_cols, C_rows);
if (i == 0) { // First iteration
gemmini_extended_compute_preloaded(A_sp_addr, GARBAGE_ADDR, A_cols, A_rows, DIM, DIM);
} else { // All other iterations
gemmini_extended_compute_accumulated(A_sp_addr, GARBAGE_ADDR, A_cols, A_rows, DIM, DIM);
}
}
if (C != NULL && k == K-1) {
// Move-out C (if not normalizing)
if (((act != LAYERNORM) && (act != SOFTMAX)) && (j == J-1 || j % C_blocks == C_blocks-1)) {
const size_t rounded_j = (j / C_blocks) * C_blocks;
const uint32_t rounded_C_sp_addr = C_sp_addr_start + (i*J + rounded_j)*DIM;
void * const C_dram_addr = (int8_t*)C + (i*C_row_stride + rounded_j)*DIM*sizeof_C;
const size_t blocks = rounded_j + C_blocks <= J ? C_blocks : J-rounded_j;
const size_t cols = blocks * DIM - (rounded_j + blocks >= J ? pad_J : 0);
const size_t rows = DIM - (i == I - 1 ? pad_I : 0);
gemmini_extended_mvout(C_dram_addr, rounded_C_sp_addr, cols, rows);
}
// Move-out C (if normalizing)
if (act == LAYERNORM && j == J - 1) {
uint32_t norm_cmds[][2] = {{1,2},{3,4},{0,0}};
const int norm_cmds_size = sizeof(norm_cmds) / sizeof(norm_cmds[0]);
const size_t rows = DIM - (i == I-1 ? pad_I : 0);
for (size_t row = 0; row < rows; row += NORM_STAT_IDS) {
const size_t stat_ids = rows - row > NORM_STAT_IDS ?
NORM_STAT_IDS : rows - row;
for (int cmd = 0; cmd < norm_cmds_size; cmd++) {
for (size_t stat_id = 0; stat_id < stat_ids; stat_id++) {
gemmini_config_norm(0, 0, 0, 0, stat_id, 0, 0);
const size_t r = row + stat_id;
for (size_t jj = 0; jj < J; jj += C_blocks) {
uint32_t norm_C_sp_addr = C_sp_addr_start + (i*J + jj)*DIM + r;
if (jj + C_blocks >= J) {
norm_C_sp_addr |= (norm_cmds[cmd][1] << 26); // Final mean/inv-std-dev calculation
} else {
norm_C_sp_addr |= (norm_cmds[cmd][0] << 26); // Accumulate sum/variance
}
void * const C_dram_addr = (int8_t*)C +
(i*C_row_stride + jj) * DIM * sizeof_C +
r * C_row_stride * sizeof_C;
const size_t blocks = jj + C_blocks <= J ? C_blocks : J-jj;
const size_t cols = blocks * DIM - (jj + blocks >= J ? pad_J : 0);
gemmini_extended_mvout(C_dram_addr, norm_C_sp_addr, cols, 1);
}
}
}
}
} else if (act == SOFTMAX && j == J - 1) {
uint32_t norm_cmds[][2] = {{5,5},{6,7},{0,0}};
const int norm_cmds_size = sizeof(norm_cmds) / sizeof(norm_cmds[0]);
const size_t rows = DIM - (i == I-1 ? pad_I : 0);
for (size_t row = 0; row < rows; row += NORM_STAT_IDS) {
const size_t stat_ids = rows - row > NORM_STAT_IDS ?
NORM_STAT_IDS : rows - row;
for (int cmd = 0; cmd < norm_cmds_size; cmd++) {
for (size_t stat_id = 0; stat_id < stat_ids; stat_id++) {
// set stat id only
gemmini_config_norm(0, 0, 1, 0, stat_id, 0, 0);
const size_t r = row + stat_id;
for (size_t jj = 0; jj < J; jj += C_blocks) {
uint32_t norm_C_sp_addr = C_sp_addr_start + (i*J + jj)*DIM + r;
if (jj + C_blocks >= J) {
norm_C_sp_addr |= (norm_cmds[cmd][1] << 26); // Final mean/inv-std-dev calculation
} else {
norm_C_sp_addr |= (norm_cmds[cmd][0] << 26); // Accumulate sum/variance
}
void * const C_dram_addr = (int8_t*)C +
(i*C_row_stride + jj) * DIM * sizeof_C +
r * C_row_stride * sizeof_C;
const size_t blocks = jj + C_blocks <= J ? C_blocks : J-jj;
const size_t cols = blocks * DIM - (jj + blocks >= J ? pad_J : 0);
gemmini_extended_mvout(C_dram_addr, norm_C_sp_addr, cols, 1);
}
}
}
}
}
}
}
}
}
*/
// Combined loop
gemmini_loop_ws(I, J, K, pad_I, pad_J, pad_K, A, B, no_bias ? NULL : D, C,
A_row_stride, B_row_stride, repeating_bias ? 0 : D_row_stride, C_row_stride,
a_transpose, b_transpose,
full_C, low_D, !no_bias || D == NULL,
act, a_spad_id, b_spad_id, false);
}
static void tiled_matmul_outer(size_t dim_I, size_t dim_J, size_t dim_K,
const elem_t* A, const elem_t* B,
const void * D, void * C,
size_t stride_A, size_t stride_B, size_t stride_D, size_t stride_C,
scale_t A_scale_factor, scale_t B_scale_factor, scale_acc_t D_scale_factor,
size_t tile_I, size_t tile_J, size_t tile_K,
int act, acc_scale_t scale, acc_scale_t bert_scale,
bool repeating_bias,
bool a_transpose, bool b_transpose,
bool full_C, bool low_D,
uint8_t weightA,
int dataflow) {
const size_t dim_I_padded = (dim_I / DIM + (dim_I % DIM != 0)) * DIM;
const size_t dim_J_padded = (dim_J / DIM + (dim_J % DIM != 0)) * DIM;
const size_t dim_K_padded = (dim_K / DIM + (dim_K % DIM != 0)) * DIM;
const size_t I0 = dim_I_padded / (tile_I*DIM) + (dim_I_padded % (tile_I*DIM) != 0);
const size_t J0 = dim_J_padded / (tile_J*DIM) + (dim_J_padded % (tile_J*DIM) != 0);
const size_t K0 = dim_K_padded / (tile_K*DIM) + (dim_K_padded % (tile_K*DIM) != 0);
// These lines here are supposed to help us deal with when the dimensions of
// the systolic array aren't divisible by the tiling factors
const size_t last_I = dim_I_padded % (tile_I*DIM) == 0 ? tile_I : (dim_I_padded/DIM) % tile_I;
const size_t last_J = dim_J_padded % (tile_J*DIM) == 0 ? tile_J : (dim_J_padded/DIM) % tile_J;
const size_t last_K = dim_K_padded % (tile_K*DIM) == 0 ? tile_K : (dim_K_padded/DIM) % tile_K;
// These lines are supposed to figure out how much padding the hardware is
// supposed to add for the final tile
const size_t padding_I = dim_I_padded - dim_I;
const size_t padding_J = dim_J_padded - dim_J;
const size_t padding_K = dim_K_padded - dim_K;
const bool no_bias = D == NULL;
if (no_bias) {
D = (void*) 1; // Dummy address which isn't NULL
}
const size_t sizeof_D = low_D ? sizeof(elem_t) : sizeof(acc_t) ;
const size_t sizeof_C = full_C ? sizeof(acc_t) : sizeof(elem_t);
gemmini_extended_config_ex(dataflow, act & 3, 0, 1, a_transpose, b_transpose);
gemmini_extended_config_st(stride_C * sizeof_C, act & 3, scale);
gemmini_extended3_config_ld(stride_A * sizeof(elem_t), A_scale_factor, false, 0);
gemmini_extended3_config_ld(stride_B * sizeof(elem_t), B_scale_factor, false, 1)
gemmini_extended3_config_ld(repeating_bias ? 0 : (stride_D * sizeof_D), D_scale_factor, low_D, 2);
if (act == IGELU) {
const acc_scale_t sqrt_2 = 1.41421356237;
const acc_scale_t S = bert_scale;
const acc_scale_t S_erf = (-0.2888 * ((S*S) / 2));
const acc_t qb = -1.769 / (S / sqrt_2);
const acc_t qc = 1.0 / S_erf;
gemmini_config_norm(0, 0, 0, 0, 0, qb, qc);
}
if (act == SOFTMAX) {
const scale_t a = 0.3585;
const scale_t b = 1.353;
const scale_t c = 0.344;
const acc_t qln2 = (int) (0.693147 / bert_scale);
const acc_t qln2_inv = 65536 / qln2;
const acc_t qb = b / bert_scale;
const acc_t qc = c / (a*bert_scale*bert_scale);
gemmini_config_norm(qln2, 0, 0, 1, 0, qb, qc);
gemmini_config_norm(qln2_inv, 1, 0, 1, 0, qb, qc);
}
void (*inner)(const elem_t *, const elem_t *, const void *, void *,
scale_t, scale_t, scale_acc_t,
size_t, size_t, size_t, size_t, size_t, size_t,
size_t, size_t, size_t, size_t,
bool, bool,
bool, bool,
bool, bool,
int, int, int);
if (dataflow == OUTPUT_STATIONARY) {
inner = &sp_tiled_matmul_os;
} else /* if (dataflow == WEIGHT_STATIONARY) */ {
inner = &sp_tiled_matmul_ws;
}
// reuse operand if it fits scratchpad
int a_spad_id = 0;
int b_spad_id = 0;
bool b_reuse = (J0 * K0 <= 2) && (dataflow == WEIGHT_STATIONARY);
bool a_reuse = (I0 * K0 <= 2) && (dataflow == WEIGHT_STATIONARY);
for (size_t i0 = 0; i0 < I0; i0++)
for (size_t j0 = 0; j0 < J0; j0++)
for (size_t k0 = 0; k0 < K0; k0++) {
if(a_reuse)
a_spad_id = ((i0+k0) == 0) ? 1 : 2;
if(b_reuse)
b_spad_id = ((j0+k0) == 0) ? 1 : 2;
const void * pre;
if (k0 != 0) {
pre = NULL;
} else {
size_t bias_row = repeating_bias ? 0 : i0*tile_I*DIM;
// pre = &(((acc_t*)D)[bias_row * stride_D + j0 * tile_J * DIM]);
pre = (int8_t*)D + (bias_row * stride_D + j0 * tile_J * DIM)*sizeof_D;
}
void * out = k0 == K0-1 ? (int8_t*)C + (i0*tile_I*DIM*stride_C + j0*tile_J*DIM)*sizeof_C : NULL;
const size_t I = i0 < I0-1 ? tile_I : last_I;
const size_t J = j0 < J0-1 ? tile_J : last_J;
const size_t K = k0 < K0-1 ? tile_K : last_K;
const size_t pad_I = i0 == I0-1 ? padding_I : 0;
const size_t pad_J = j0 == J0-1 ? padding_J : 0;
const size_t pad_K = k0 == K0-1 ? padding_K : 0;
const elem_t * a = a_transpose ? (A + k0*tile_K*DIM*stride_A + i0*tile_I*DIM)
: (A + i0*tile_I*DIM*stride_A + k0*tile_K*DIM);
const elem_t * b = b_transpose ? (B + j0*tile_J*DIM*stride_B + k0*tile_K*DIM)
: (B + k0*tile_K*DIM*stride_B + j0*tile_J*DIM);
if(a_reuse && j0 >= 1) a = NULL;
if(b_reuse && i0 >= 1) b = NULL;
//printf("a_reuse: %d, b_reuse: %d, a_spad_id: %d, b_spad_id: %d, a: %llu, b: %llu \n", a_reuse, b_reuse, a_spad_id, b_spad_id, a, b);
(*inner)(a, b, pre, out,
A_scale_factor, B_scale_factor, D_scale_factor,
I, J, K,
pad_I, pad_J, pad_K,
stride_A, stride_B, stride_D, stride_C,
a_transpose, b_transpose,
full_C, low_D,
no_bias, repeating_bias,
act, a_spad_id, b_spad_id);
}
gemmini_fence();
}
static acc_t int_sqrt(acc_t n) {
if (n == 0) return 0;
int bits = 0;
for (acc_t x = n; x > 0; x /= 2)
bits++;
acc_t x_prev = 1 << ((bits + 1) / 2);
while (1) {
acc_t x_next = (x_prev + n / x_prev) / 2;
if (x_next >= x_prev) return x_prev;
x_prev = x_next;
};
}
static elem_t scale_and_sat(acc_t x, int act, acc_scale_t scale, acc_scale_t bert_scale) {
// Apply I-GELU if needed
if (act == IGELU) {
const acc_scale_t sqrt_2 = 1.41421356237;
const acc_scale_t S = bert_scale;
const acc_scale_t S_erf = (-0.2888 * (S/sqrt_2)*(S/sqrt_2));
const acc_t q1 = 1 / S_erf;
const acc_t qb = -1.769 / (S / sqrt_2);
const acc_t qc = 1.0 / (-0.2888 * (S / sqrt_2) * (S / sqrt_2));
const acc_t q = x;
const acc_t q_sign = q < 0 ? -1 : 1;
const acc_t q_clipped = abs(q) > (-qb) ? (-qb) : abs(q);
const acc_t q_poly = (q_clipped + qb)*(q_clipped + qb) + qc;
const acc_t q_erf = q_sign * q_poly;
x = q * (q_erf + q1);
}
// Scale value down and round it
x = ACC_SCALE(x, scale);
// Clip result
x = x > elem_t_max ? elem_t_max : (x < elem_t_min ? elem_t_min : x);
// Apply activation function
if (act == RELU) {
x = x < 0 ? 0 : x;
}
return x;
}
#ifdef HAS_MVIN_SCALE
#define GEMMINI_SCALE(x, scale) MVIN_SCALE((x), (scale))
#else
#define GEMMINI_SCALE(x, scale) (x)
#endif
#ifdef HAS_MVIN_ACC_SCALE
#define GEMMINI_ACC_SCALE(x, scale) MVIN_SCALE_ACC((x), (scale))
#else
#define GEMMINI_ACC_SCALE(x, scale) (x)
#endif
static void matmul_cpu(bool transA, bool transB, size_t DIM_I, size_t DIM_J, size_t DIM_K,
const elem_t* A, const elem_t* B, const acc_t * D,
elem_t* C,
size_t stride_A, size_t stride_B, size_t stride_D, size_t stride_C,
scale_t A_scale_factor, scale_t B_scale_factor, scale_acc_t D_scale_factor,
int act, acc_scale_t scale, acc_scale_t bert_scale, bool repeating_bias) {
const int no_bias = D == NULL;
if (act != LAYERNORM && act != SOFTMAX && !transA && !transB && DIM_I % 4 == 0 && DIM_J % 4 == 0) {
for (size_t i = 0; i < DIM_I; i += 4) {
for (size_t j = 0; j < DIM_J; j += 4) {
acc_t result[4][4]; // = {{0, 0, 0, 0}, {0, 0, 0, 0}, {0, 0, 0, 0}, {0, 0, 0, 0}};
for (size_t ii = 0; ii < 4; ii++)
for (size_t jj = 0; jj < 4; jj++) {
const size_t bias_row = repeating_bias ? 0 : i + ii;
result[ii][jj] = no_bias ? 0 :
GEMMINI_ACC_SCALE(*(D + bias_row*stride_D + j + jj), D_scale_factor);
}
for (size_t k = 0; k < DIM_K; k++) {
result[0][0] +=
GEMMINI_SCALE(*(A + i*stride_A + k), A_scale_factor) *
GEMMINI_SCALE(*(B + k*stride_B + j), B_scale_factor);
result[0][1] +=
GEMMINI_SCALE(*(A + i*stride_A + k), A_scale_factor) *
GEMMINI_SCALE(*(B + k*stride_B + j+1), B_scale_factor);
result[0][2] +=
GEMMINI_SCALE(*(A + i*stride_A + k), A_scale_factor) *
GEMMINI_SCALE(*(B + k*stride_B + j+2), B_scale_factor);
result[0][3] +=
GEMMINI_SCALE(*(A + i*stride_A + k), A_scale_factor) *
GEMMINI_SCALE(*(B + k*stride_B + j+3), B_scale_factor);
result[1][0] +=
GEMMINI_SCALE(*(A + (i+1)*stride_A + k), A_scale_factor) *
GEMMINI_SCALE(*(B + k*stride_B + j), B_scale_factor);
result[1][1] +=
GEMMINI_SCALE(*(A + (i+1)*stride_A + k), A_scale_factor) *
GEMMINI_SCALE(*(B + k*stride_B + j+1), B_scale_factor);
result[1][2] +=
GEMMINI_SCALE(*(A + (i+1)*stride_A + k), A_scale_factor) *
GEMMINI_SCALE(*(B + k*stride_B + j+2), B_scale_factor);
result[1][3] +=
GEMMINI_SCALE(*(A + (i+1)*stride_A + k), A_scale_factor) *
GEMMINI_SCALE(*(B + k*stride_B + j+3), B_scale_factor);
result[2][0] +=
GEMMINI_SCALE(*(A + (i+2)*stride_A + k), A_scale_factor) *
GEMMINI_SCALE(*(B + k*stride_B + j), B_scale_factor);
result[2][1] +=
GEMMINI_SCALE(*(A + (i+2)*stride_A + k), A_scale_factor) *
GEMMINI_SCALE(*(B + k*stride_B + j+1), B_scale_factor);
result[2][2] +=
GEMMINI_SCALE(*(A + (i+2)*stride_A + k), A_scale_factor) *
GEMMINI_SCALE(*(B + k*stride_B + j+2), B_scale_factor);
result[2][3] +=
GEMMINI_SCALE(*(A + (i+2)*stride_A + k), A_scale_factor) *
GEMMINI_SCALE(*(B + k*stride_B + j+3), B_scale_factor);
result[3][0] +=
GEMMINI_SCALE(*(A + (i+3)*stride_A + k), A_scale_factor) *
GEMMINI_SCALE(*(B + k*stride_B + j), B_scale_factor);
result[3][1] +=
GEMMINI_SCALE(*(A + (i+3)*stride_A + k), A_scale_factor) *
GEMMINI_SCALE(*(B + k*stride_B + j+1), B_scale_factor);
result[3][2] +=
GEMMINI_SCALE(*(A + (i+3)*stride_A + k), A_scale_factor) *
GEMMINI_SCALE(*(B + k*stride_B + j+2), B_scale_factor);
result[3][3] +=
GEMMINI_SCALE(*(A + (i+3)*stride_A + k), A_scale_factor) *
GEMMINI_SCALE(*(B + k*stride_B + j+3), B_scale_factor);
}
*(C + i*stride_C + j) =
scale_and_sat(result[0][0], act, scale, bert_scale);
*(C + i*stride_C + j+1) =
scale_and_sat(result[0][1], act, scale, bert_scale);
*(C + i*stride_C + j+2) =
scale_and_sat(result[0][2], act, scale, bert_scale);
*(C + i*stride_C + j+3) =
scale_and_sat(result[0][3], act, scale, bert_scale);
*(C + (i+1)*stride_C + j) =
scale_and_sat(result[1][0], act, scale, bert_scale);
*(C + (i+1)*stride_C + j+1) =
scale_and_sat(result[1][1], act, scale, bert_scale);
*(C + (i+1)*stride_C + j+2) =
scale_and_sat(result[1][2], act, scale, bert_scale);
*(C + (i+1)*stride_C + j+3) =
scale_and_sat(result[1][3], act, scale, bert_scale);
*(C + (i+2)*stride_C + j) =
scale_and_sat(result[2][0], act, scale, bert_scale);
*(C + (i+2)*stride_C + j+1) =
scale_and_sat(result[2][1], act, scale, bert_scale);
*(C + (i+2)*stride_C + j+2) =
scale_and_sat(result[2][2], act, scale, bert_scale);
*(C + (i+2)*stride_C + j+3) =
scale_and_sat(result[2][3], act, scale, bert_scale);
*(C + (i+3)*stride_C + j) =
scale_and_sat(result[3][0], act, scale, bert_scale);
*(C + (i+3)*stride_C + j+1) =
scale_and_sat(result[3][1], act, scale, bert_scale);
*(C + (i+3)*stride_C + j+2) =
scale_and_sat(result[3][2], act, scale, bert_scale);
*(C + (i+3)*stride_C + j+3) =
scale_and_sat(result[3][3], act, scale, bert_scale);
}
}
} else {
size_t A_dim_strides[2] = {!transA ? stride_A : 1, !transA ? 1 : stride_A}; // i, j stride
size_t B_dim_strides[2] = {!transB ? 1 : stride_B, !transB ? stride_B : 1}; // j, k stride
// We also create a buffer that we can use for layernorms and softmaxes
static acc_t c_buffer[1024];
const size_t c_buffer_sz = sizeof(c_buffer)/sizeof(c_buffer[0]);
if ((act == LAYERNORM || act == SOFTMAX) && DIM_J > c_buffer_sz) {
printf("Matmul is too large to normalize\n");
exit(1);
}
for (size_t i = 0; i < DIM_I; i++) {
for (size_t j = 0; j < DIM_J; j++) {
elem_t* c = C + (i * stride_C) + j;
const size_t bias_row = repeating_bias ? 0 : i;
acc_t sum = no_bias ? 0 : GEMMINI_ACC_SCALE(*(D + bias_row * stride_D + j), D_scale_factor);
for (size_t k = 0; k < DIM_K; k++) {
const elem_t* a = A + i * A_dim_strides[0] + k * A_dim_strides[1];
const elem_t* b = B + j * B_dim_strides[0] + k * B_dim_strides[1];
sum += (GEMMINI_SCALE(*a, A_scale_factor) * GEMMINI_SCALE(*b, B_scale_factor));
}
if (act == LAYERNORM || act == SOFTMAX)
c_buffer[j] = sum;
else
*c = scale_and_sat(sum, act, scale, bert_scale);
}
#ifdef HAS_NORMALIZATIONS
if (act == LAYERNORM) {
acc_t sum = 0;
for (size_t j = 0; j < DIM_J; j++)
sum += c_buffer[j];
acc_t mean = sum / (acc_t)DIM_J;
acc_t total_err_sq = 0;
for (size_t j = 0; j < DIM_J; j++)
total_err_sq += (c_buffer[j] - mean)*(c_buffer[j] - mean);
acc_t variance = total_err_sq / (acc_t)DIM_J;
acc_t stddev = int_sqrt(variance);
if (variance == 0) stddev = 1;
for (size_t j = 0; j < DIM_J; j++) {
c_buffer[j] -= mean;
// c_buffer[j] /= stddev;
c_buffer[j] = ROUND_NEAR_EVEN((double)c_buffer[j] / stddev); // TODO I don't think I-BERT uses round-near-even, so we shouldn't either. We just use this rounding mode here in order to match the hardware.
elem_t* c = C + (i * stride_C) + j;
*c = scale_and_sat(c_buffer[j], act, scale, bert_scale);
}
} else if (act == SOFTMAX) {
const scale_t a = 0.3585;
const scale_t b = 1.353;
const scale_t c = 0.344;
// is SCALE supposed to be input scale?
const acc_t qln2 = (acc_t) (0.693147 / bert_scale);
const acc_t qln2_inv = 65536 / qln2;
const acc_t qb = b / bert_scale;
const acc_t qc = c / (a*bert_scale*bert_scale);
// pass 1: get max_q
acc_t max_q = -2147483648;
for (size_t j = 0; j < DIM_J; j++) {
if (c_buffer[j] > max_q) max_q = c_buffer[j];
}
// pass 2: calculate iexp(q_tilde) and sum(q_tilde)
acc_t sum_exp = 0;
for (size_t j = 0; j < DIM_J; j++) {
acc_t q = c_buffer[j] - max_q;
acc_t z = (acc_t) (-q * qln2_inv) >> 16;
acc_t qp = q + z * qln2;
acc_t q_exp = (qp + qb)*(qp + qb) + qc;
c_buffer[j] = q_exp >> z;
sum_exp += c_buffer[j];
}
// pass 3: divide by sum
scale_t factor = (127.f) / (float) sum_exp; // what corresponds to 1 in output?
for (size_t j = 0; j < DIM_J; j++) {
elem_t* c = C + (i * stride_C) + j;
*c = scale_and_sat(c_buffer[j], act, factor, bert_scale);
}
}
#endif
}
}
}
#undef GEMMINI_SCALE
// General matmul which can be run with different dataflows, or on the CPU
enum tiled_matmul_type_t {OS, WS, CPU}; // TODO rename this so it's name also applies to convs
// This function runs a tiled matrix mulctiplication, with hardcoded tiling
// factors
static void tiled_matmul(size_t dim_I, size_t dim_J, size_t dim_K,
const elem_t* A, const elem_t* B,
const void * D, void* C,
size_t stride_A, size_t stride_B, size_t stride_D, size_t stride_C,
scale_t A_scale_factor, scale_t B_scale_factor, scale_acc_t D_scale_factor,
int act, acc_scale_t scale, acc_scale_t bert_scale,
bool repeating_bias,
size_t tile_I, size_t tile_J, size_t tile_K,
bool transpose_A, bool transpose_B,
bool full_C, bool low_D,
uint8_t weightA,
enum tiled_matmul_type_t tiled_matmul_type) {
#ifdef GEMMINI_ASSERTIONS
// Make sure that the tiling factors make sense
if (tile_I <= 0) {
printf("tile_I is non-positive\n");
exit(1);
} else if (tile_J <= 0) {
printf("tile_J is non-positive\n");
exit(1);
} else if (tile_K <= 0) {
printf("tile_K is non-positive\n");
exit(1);
}
const size_t dim_I_padded = (dim_I / DIM + (dim_I % DIM != 0)) * DIM;
const size_t dim_J_padded = (dim_J / DIM + (dim_J % DIM != 0)) * DIM;
const size_t dim_K_padded = (dim_K / DIM + (dim_K % DIM != 0)) * DIM;
if (tile_I * DIM > dim_I_padded) {
printf("tile_I is too large (tile_I * DIM > dim_I_padded)\n");
exit(1);
} else if (tile_J * DIM > dim_J_padded) {
printf("tile_J is too large (tile_J * DIM > dim_J_padded)\n");
exit(1);
} else if (tile_K * DIM > dim_K_padded) {
printf("tile_K is too large (tile_K * DIM > dim_K_padded)\n");
exit(1);
}
const bool double_buffered = tiled_matmul_type == WS;
const size_t total_spad_size = double_buffered ? BANK_NUM * BANK_ROWS / 2 :
BANK_NUM * BANK_ROWS;
const size_t total_acc_size = double_buffered ? ACC_ROWS / 2 : ACC_ROWS;
const size_t total_spad_rows =
(tile_I * tile_K * DIM) + // Rows to store A
(tile_K * tile_J * DIM); // Rows to store B
if (total_spad_rows > total_spad_size) {
printf("Not enough space in scratchpad to store A and B matrices\n");
exit(1);
}
const size_t total_acc_rows =
tile_I * tile_J * DIM; // Rows to store C
if (total_acc_rows > total_acc_size) {
printf("Not enough space in accumulator to store C\n");
exit(1);
}
if (tile_I > 65535 || tile_J > 65535 || tile_K > 65535) {
printf("I, J, and K tiling factors must be less than 65535, to fit within the bounds of the LOOP_WS function");
exit(1);
}
char matmul_type_str[][4] = {"OS", "WS", "CPU"};
// Check if transpose options are correct
if (((tiled_matmul_type == OS) && (transpose_A || transpose_B)) ||
(tiled_matmul_type == WS && transpose_A && transpose_B)) {
printf("Not implemented: %s matmul, a_transpose=%d, b_transpose=%d\n", matmul_type_str[tiled_matmul_type], transpose_A, transpose_B);
exit(1);
}
// Check if full_C options are correct
if ((tiled_matmul_type == CPU && (full_C || low_D)) ||
(tiled_matmul_type == OS && low_D)) {
printf("Not implemented: %s matmul, full_C=%d, low_D=%d\n", matmul_type_str[tiled_matmul_type], full_C, low_D);
}
if (act == LAYERNORM || act == SOFTMAX) {
if (tiled_matmul_type == OS) {
printf("Not implemented: %s matmul, act=%d\n", matmul_type_str[tiled_matmul_type], act);
}
if (tile_J * DIM < dim_J) {
printf("When doing layernorm or softmax, the full J dimension of the matrix must fit in the accumulator\n");
}
}
#endif
// Run a tiled matrix multiplication on either Gemmini or the CPU
if (tiled_matmul_type == OS || tiled_matmul_type == WS) {
tiled_matmul_outer(dim_I, dim_J, dim_K,
A, B, D, C,
stride_A, stride_B, stride_D, stride_C,
A_scale_factor, B_scale_factor, D_scale_factor,
tile_I, tile_J, tile_K,
act, scale, bert_scale, repeating_bias,
transpose_A, transpose_B,
full_C, low_D,
weightA,
(int)tiled_matmul_type);
} else /*if (tiled_matmul_type == CPU)*/ {
matmul_cpu(transpose_A, transpose_B, dim_I, dim_J, dim_K,
A, B, (const acc_t*) D, (elem_t*)C,
stride_A, stride_B, stride_D, stride_C,
A_scale_factor, B_scale_factor, D_scale_factor,
act, scale, bert_scale, repeating_bias);
}
}
static size_t tiled_matmul_total_spad_rows(size_t I, size_t J, size_t K) {
return (I * K + K * J) * DIM;
}
static size_t tiled_matmul_total_acc_rows(size_t I, size_t J) {
return (I * J) * DIM;
}
// This function runs a tiled matrix multiplication, with automatically
// calculated tiling factors
static void tiled_matmul_auto(size_t dim_I, size_t dim_J, size_t dim_K,
const elem_t* A, const elem_t* B,
const void * D, void * C,
size_t stride_A, size_t stride_B, size_t stride_D, size_t stride_C,
scale_t A_scale_factor, scale_t B_scale_factor, scale_acc_t D_scale_factor,
int act, acc_scale_t scale, acc_scale_t bert_scale,
bool repeating_bias,
bool transpose_A, bool transpose_B,
bool full_C, bool low_D,
uint8_t weightA,
enum tiled_matmul_type_t tiled_matmul_type) {
#define partition_rows (BANK_NUM * BANK_ROWS / 2)
#define mats_in_partition (partition_rows / DIM)
#define mats_in_acc (ACC_ROWS / DIM)
#define max_tile_i_j ((size_t)sqrt(mats_in_acc))
#define max_tile_k (mats_in_partition / max_tile_i_j)
// "db_" means "double-buffered"
#define db_partition_rows ((BANK_NUM * BANK_ROWS / 2) / 2)
#define db_mats_in_partition (db_partition_rows / DIM)
#define db_mats_in_acc ((ACC_ROWS / 2) / DIM)
#define db_max_tile_i_j ((size_t)sqrt(db_mats_in_acc))
#define db_max_tile_k (db_mats_in_partition / db_max_tile_i_j)
const size_t dim_I_padded = (dim_I / DIM + (dim_I % DIM != 0)) * DIM;
const size_t dim_J_padded = (dim_J / DIM + (dim_J % DIM != 0)) * DIM;
const size_t dim_K_padded = (dim_K / DIM + (dim_K % DIM != 0)) * DIM;
const bool double_buffered = tiled_matmul_type == WS;
const size_t max_spad_rows = double_buffered ? BANK_NUM * BANK_ROWS / 2 :
BANK_NUM * BANK_ROWS;
const size_t max_acc_rows = double_buffered ? ACC_ROWS / 2 : ACC_ROWS;
size_t tile_I, tile_J, tile_K;
if (act == LAYERNORM || act == SOFTMAX) {
tile_I = 1;
tile_J = dim_J_padded/DIM;
tile_K = 1;
} else if (double_buffered) {
tile_I = dim_I_padded/DIM < db_max_tile_i_j ? dim_I_padded/DIM : db_max_tile_i_j;
tile_J = dim_J_padded/DIM < db_max_tile_i_j ? dim_J_padded/DIM : db_max_tile_i_j;
tile_K = dim_K_padded/DIM < db_max_tile_k ? dim_K_padded/DIM : db_max_tile_k;
} else {
tile_I = dim_I_padded/DIM < max_tile_i_j ? dim_I_padded/DIM : max_tile_i_j;
tile_J = dim_J_padded/DIM < max_tile_i_j ? dim_J_padded/DIM : max_tile_i_j;
tile_K = dim_K_padded/DIM < max_tile_k ? dim_K_padded/DIM : max_tile_k;
}
// Fill scratchpad as much as possible
while (true) {
bool increased = false;
if (tiled_matmul_total_spad_rows(tile_I, tile_J+1, tile_K) <= max_spad_rows &&
tiled_matmul_total_acc_rows(tile_I, tile_J+1) <= max_acc_rows &&
(tile_J+1) * DIM <= dim_J_padded) {
tile_J++;
increased = true;
}
if (tiled_matmul_total_spad_rows(tile_I+1, tile_J, tile_K) <= max_spad_rows &&
tiled_matmul_total_acc_rows(tile_I+1, tile_J) <= max_acc_rows &&
(tile_I+1) * DIM <= dim_I_padded) {
tile_I++;
increased = true;
}
if (tiled_matmul_total_spad_rows(tile_I, tile_J, tile_K+1) <= max_spad_rows &&
(tile_K+1) * DIM <= dim_K_padded) {
tile_K++;
increased = true;
}
if (!increased)
break;
}
#ifdef PRINT_TILE
#if PRINT_TILE
const int spad_rows = tiled_matmul_total_spad_rows(tile_I, tile_J, tile_K);
const int acc_rows = tiled_matmul_total_acc_rows(tile_I, tile_J);
printf("tile_I: %d\n", tile_I);
printf("tile_J: %d\n", tile_J);
printf("tile_K: %d\n\n", tile_K);
printf("spad_rows: %d\n", spad_rows);
printf("acc_rows: %d\n\n", acc_rows);
printf("spad_row utilization: %d%%\n", (spad_rows * 100) / max_spad_rows);
printf("acc_row utilization: %d%%\n\n", (acc_rows * 100) / max_acc_rows);
exit(EXIT_SUCCESS);
#endif
#endif
tiled_matmul(dim_I, dim_J, dim_K,
A, B, D, C,
stride_A, stride_B, stride_D, stride_C,
A_scale_factor, B_scale_factor, D_scale_factor,
act, scale, bert_scale, repeating_bias,
tile_I, tile_J, tile_K,
transpose_A, transpose_B,
full_C, low_D,
weightA,
tiled_matmul_type);
#undef partition_rows
#undef mats_in_partition
#undef mats_in_acc
#undef max_tile_i_j
#undef max_tile_k
}
static void sp_tiled_conv(
int batch_size, int in_row_dim, int in_col_dim, int in_channels,
int out_channels, int out_row_dim, int out_col_dim,
int pool_out_row_dim, int pool_out_col_dim,
int stride, int padding, int kernel_dim, int kernel_dilation,
int in_stride, int weight_stride, int out_stride,
int pool_size, int pool_stride, int pool_padding,
int batches,
int porows, int pocols, int pochs,
int krows, int kcols, int kchs,
int lpad, int rpad, int upad, int dpad,
int plpad, int prpad, int pupad, int pdpad,
const elem_t * input,
const elem_t * weights,
elem_t * output,
const acc_t * bias,
int act, acc_scale_t scale,
bool wrot180, bool trans_output_1203, bool trans_input_3120,
bool trans_weight_1203, bool trans_weight_0132,
bool no_bias, bool no_pool, bool downsample, bool input_dilated,
bool dw, int a_spad_id, int b_spad_id) {
// When dw convs are true, we assume that kchs and ochs are 1
if (dw) { kchs = 1; pochs = 1; }
const int orows = porows * pool_stride + pool_size - 1 - pupad - pdpad;
const int ocols = pocols * pool_stride + pool_size - 1 - plpad - prpad;
const int ochs = pochs;
// Calculate image dimensions
// Note: "irows" and "icols" includes padding
const int dilated_krows = krows + (kernel_dilation - 1)*(krows - 1);
const int dilated_kcols = kcols + (kernel_dilation - 1)*(kcols - 1);
int irows = orows * stride + dilated_krows - 1;
int icols = ocols * stride + dilated_kcols - 1;
int irows_unpadded = irows - upad - dpad;
int icols_unpadded = icols - lpad - rpad;
const int ichs = kchs;
#define UNDILATED(x) ((input_dilated) ? (((x)+1)/2) : (x))
if (input_dilated) {
irows_unpadded = (irows_unpadded+1)/2;
icols_unpadded = (icols_unpadded+1)/2;
irows = irows_unpadded + UNDILATED(upad) + UNDILATED(dpad);
icols = icols_unpadded + UNDILATED(lpad) + UNDILATED(rpad);
}
#ifdef HAS_FIRST_LAYER_OPTIMIZATIONS
const bool transposed = trans_output_1203 || trans_input_3120 ||
trans_weight_1203 || trans_weight_0132;
int max_pixels_per_row = transposed || wrot180 || downsample ||
input_dilated || kernel_dilation > 1 ||
ichs > DIM ? 1 : DIM/ichs;
if (max_pixels_per_row > kcols) max_pixels_per_row = kcols;
#else
const int max_pixels_per_row = 1;
#endif
// Calculate spad address offsets
const int out_channels_per_bank = ochs / DIM + (ochs % DIM != 0);
const int in_channels_per_bank = kchs / DIM + (kchs % DIM != 0);
const int B_rows = trans_weight_0132 ?
in_channels_per_bank * kcols * krows * ochs :
out_channels_per_bank * kcols * krows * kchs;
static uint32_t D_sp_addr_row = 0;
static uint32_t C_sp_addr_row = 0;
const uint32_t A_sp_addr_start = 0;
const uint32_t B_sp_addr_start = BANK_NUM * BANK_ROWS - B_rows;
const uint32_t D_sp_addr_start = (1 << (ADDR_LEN - 1)) + D_sp_addr_row;
const uint32_t C_sp_addr_start = (3 << (ADDR_LEN - 2)) + C_sp_addr_row;
if (bias != 0) {
D_sp_addr_row = (D_sp_addr_row + ACC_ROWS / 2) % ACC_ROWS;
}
if (output != 0) {
C_sp_addr_row = (C_sp_addr_row + ACC_ROWS / 2) % ACC_ROWS;
}
gemmini_loop_conv_ws(batch_size, in_row_dim, in_col_dim, in_channels, out_channels, out_row_dim, out_col_dim, pool_out_row_dim, pool_out_col_dim, stride, padding, kernel_dim, kernel_dilation, pool_size, pool_stride, pool_padding, batches, porows, pocols, pochs, krows, kcols, kchs, lpad, rpad, upad, dpad, plpad, prpad, pupad, pdpad, orows, ocols, weights, output, bias, input, no_bias, no_pool, downsample, wrot180, input_dilated, act, trans_output_1203, trans_weight_1203, trans_weight_0132, trans_input_3120, max_pixels_per_row, in_stride, weight_stride, out_stride, dw, a_spad_id, b_spad_id);
/*
if (!no_pool) {
printf("Pooling with rectangular convolutions is currently not supported.\n");
exit(1);
}
// Only rectangular convolutions will use the following C code
// mvin bias
if (bias != NULL) {
// TODO we probably don't need quite this many nested loops for this part
const int max_ochs_per_mvin = ochs < MAX_BLOCK_LEN_ACC * DIM ? ochs :
MAX_BLOCK_LEN_ACC * DIM;
gemmini_extended4_config_ld(0, MVIN_SCALE_IDENTITY, false, batches * orows * ocols, 2);
for (int b = 0; b < batches; b++)
for (int orow = 0; orow < orows; orow++)
for (int ocol = 0; ocol < ocols; ocol += DIM) {
const int I = ocols - ocol > DIM ? DIM : ocols - ocol;
for (int och = 0; och < ochs; och += max_ochs_per_mvin) {
const int J = ochs - och > max_ochs_per_mvin ? max_ochs_per_mvin : ochs - och;
const uint32_t D_sp_addr = D_sp_addr_start + (och / DIM) * batches * orows * ocols + b * orows * ocols + orow * ocols + ocol;
const acc_t * bias_dram_addr = no_bias ? NULL : bias + och;
gemmini_extended_mvin3(bias_dram_addr,
D_sp_addr,
J, I);
}
}
}
// mvin input
if (input != NULL){
int max_chs_per_mvin = ichs < MAX_BLOCK_LEN * DIM ? ichs :
MAX_BLOCK_LEN * DIM;
if (trans_input_3120) {
max_chs_per_mvin = batches < MAX_BLOCK_LEN * DIM ? batches :
MAX_BLOCK_LEN * DIM;
}
const int dram_stride = trans_input_3120 ?
batch_size * sizeof(elem_t) :
in_channels * sizeof(elem_t);
const int spad_stride = trans_input_3120 ?
ichs * (irows >> downsample) * (icols >> downsample) :
batches * (irows >> downsample) * (icols >> downsample);
gemmini_extended5_config_ld(dram_stride << downsample, MVIN_SCALE_IDENTITY, false, spad_stride, max_pixels_per_row, 0);
const int b_it = trans_input_3120 ? max_chs_per_mvin : 1;
const int ich_it = trans_input_3120 ? 1 : max_chs_per_mvin;
for (int b = 0; b < batches; b += b_it)
for (int irow = -UNDILATED(upad); irow < irows_unpadded + UNDILATED(dpad); irow += 1 + downsample) {
const int irow_padded = irow + UNDILATED(upad);
for (int icol = -UNDILATED(lpad); icol < icols_unpadded + UNDILATED(rpad);) {
// TODO There might be some unnecessary mvins here at the edge of the image
int I = icols_unpadded - icol > (DIM << downsample) ?
(DIM << downsample) : icols_unpadded - icol;
if (icol < 0) {
I = -icol > DIM ? DIM : -icol;
} else if (icol >= icols_unpadded) {
I = icols_unpadded + UNDILATED(rpad) - icol > DIM ? DIM : icols_unpadded + UNDILATED(rpad) - icol;
}
const int icol_padded = icol + UNDILATED(lpad);
for (int ich = 0; ich < ichs; ich += ich_it) {
int K = ichs - ich > max_chs_per_mvin ?
max_chs_per_mvin : ichs - ich;
if (trans_input_3120) {
K = batches - b > max_chs_per_mvin ?
max_chs_per_mvin : batches - b;
}
#define DS(x) ((x) >> (downsample))
uint32_t A_sp_addr = A_sp_addr_start + (ich / DIM) * batches * DS(irows) * DS(icols) + b * DS(irows) * DS(icols) + DS(irow_padded) * DS(icols) + DS(icol_padded);
if (trans_input_3120) {
A_sp_addr = A_sp_addr_start + (b / DIM) * ichs * DS(irows) * DS(icols) + ich * DS(irows) * DS(icols) + DS(irow_padded) * DS(icols) + DS(icol_padded);
}
const bool is_zeros = irow < 0 || irow >= irows_unpadded || icol < 0 || icol >= icols_unpadded;
const elem_t * in = input + (b*in_row_dim*in_col_dim + irow*in_col_dim + icol) * in_stride + ich;
if (is_zeros) {
in = NULL;
} else if (trans_input_3120) {
in = input + (ich*in_row_dim*in_col_dim + irow*in_col_dim + icol) * batch_size + b;
}
gemmini_extended_mvin(in,
A_sp_addr,
K, I >> downsample);
}
icol += I;
}
}
}
// mvin weights
if (weights != NULL) {
int max_chs_per_mvin = ochs < MAX_BLOCK_LEN * DIM ? ochs :
MAX_BLOCK_LEN * DIM;
if (trans_weight_0132) {
max_chs_per_mvin = kchs < MAX_BLOCK_LEN * DIM ? kchs :
MAX_BLOCK_LEN * DIM;
}
size_t dram_stride = weight_stride * sizeof(elem_t);
if (dw) {
dram_stride = sizeof(elem_t);
} else if (trans_weight_1203) {
dram_stride = kernel_dim * kernel_dim * out_channels * sizeof(elem_t);
} else if (trans_weight_0132) {
dram_stride = in_channels * sizeof(elem_t);
}
const size_t spad_block_stride = trans_weight_0132 ?
krows * kcols * ochs : krows * kcols * kchs;
gemmini_extended4_config_ld(dram_stride, MVIN_SCALE_IDENTITY, false, spad_block_stride, 1);
const size_t och_it = trans_weight_0132 ? DIM : max_chs_per_mvin;
const size_t kch_it = trans_weight_0132 ? max_chs_per_mvin : DIM;
for (int och = 0; och < ochs; och += och_it) {
for (int krow = 0; krow < krows; krow++)
for (int kcol = 0; kcol < kcols; kcol++)
for (int kch = 0; kch < kchs; kch += kch_it) {
int K = kchs - kch > DIM ? DIM : kchs - kch;
int J = ochs - och > max_chs_per_mvin ? max_chs_per_mvin : ochs - och;
if (trans_weight_0132) {
K = ochs - och > DIM ? DIM : ochs - och;
J = kchs - kch > max_chs_per_mvin ? max_chs_per_mvin : kchs - kch;
}
uint32_t B_sp_addr = B_sp_addr_start + (och / DIM) * krows * kcols * kchs + krow * kcols * kchs + kcol * kchs + kch;
if (trans_weight_0132) {
B_sp_addr = B_sp_addr_start + (kch / DIM) * krows * kcols * ochs + krow * kcols * ochs + kcol * ochs + och;
}
const elem_t * w = weights + (krow*kernel_dim*in_channels + kcol*in_channels + kch) * weight_stride + och;
if (dw) {
w = weights + krow * kernel_dim + kcol;
} else if (trans_weight_1203) {
w = weights + (kch * kernel_dim * kernel_dim + krow * kernel_dim + kcol) * out_channels + och;
} else if (trans_weight_0132) {
w = weights + (krow * kernel_dim * out_channels + kcol * out_channels + och) * in_channels + kch;
}
gemmini_extended_mvin2(w, B_sp_addr, J, K);
}
}
}
// Compute
{
const int b_it = trans_input_3120 ? DIM : 1;
const int ocol_it = trans_input_3120 ? 1 : (DIM << input_dilated);
if (trans_input_3120) {
gemmini_extended3_config_ex(0, 0, 0, 0, orows * ocols, irows * icols, 0, 0, true);
}
for (int och = 0; och < ochs; och += DIM) {
for (int krow = 0; krow < krows; krow++) {
for (int kcol = 0; kcol < kcols; kcol += max_pixels_per_row) {
for (int kch = 0; kch < kchs; kch += DIM) {
bool new_weights = true;
for (int b = 0; b < batches; b += b_it) {
for (int orow = 0; orow < orows; orow++) {
// Skip some kernel rows due to input-dilation
if (input_dilated && ((krow * kernel_dilation + orow * stride - upad) % 2 != 0)) {
continue;
}
for (int ocol = 0; ocol < ocols;) {
// Skip some cols dimensions due to input-dilation
if (input_dilated && ((kcol + ocol * stride - lpad) % 2 != 0)) {
ocol++;
continue;
}
int irow = orow * stride + krow * kernel_dilation;
int icol = ocol * stride + kcol * kernel_dilation;
if (input_dilated) {
irow = (irow + 1) / 2;
icol = (icol + 1) / 2;
}
const int pixels = kcols - kcol > max_pixels_per_row ?
max_pixels_per_row : kcols - kcol;
const uint32_t C_sp_addr = C_sp_addr_start + (och / DIM) * batches * orows * ocols + b * orows * ocols + orow * ocols + ocol;
// Over here, construct a new matrix
//
// Let us assume that we only ever operate on
// one pixel in one row.
// Thus, krows == kcols == 1
//
// Then, for every set of I, J, and K values
// - I = ocols
// - J = ochs
// - K = kchs
int I = UNDILATED(ocols - ocol > (DIM << input_dilated) ? (DIM << input_dilated) : ocols - ocol);
const int J = ochs - och > DIM ? DIM : ochs - och;
const int K = pixels * (kchs - kch > DIM ? DIM : kchs - kch);
if (trans_input_3120) {
I = batches - b > DIM ? DIM : batches - b;
}
uint32_t A_sp_addr = A_sp_addr_start + (kch / DIM) * batches * DS(irows) * DS(icols) + b * DS(irows) * DS(icols) + DS(irow) * DS(icols) + DS(icol);
if (trans_input_3120) {
A_sp_addr = A_sp_addr_start + (b / DIM) * kchs * DS(irows) * DS(icols) + kch * DS(irows) * DS(icols) + DS(irow) * DS(icols) + DS(icol);
}
const int krow_ = wrot180 ? krows - krow - 1 : krow;
const int kcol_ = wrot180 ? kcols - kcol - 1 : kcol;
uint32_t B_sp_addr = B_sp_addr_start + (och / DIM) * krows * kcols * kchs + krow_ * kcols * kchs + kcol_ * kchs + kch;
if (trans_weight_0132) {
B_sp_addr = B_sp_addr_start + (kch / DIM) * krows * kcols * ochs + krow_ * kcols * ochs + kcol_ * ochs + och;
}
const uint32_t pre_sp_addr = new_weights ?
B_sp_addr : GARBAGE_ADDR;
// perform matmul
gemmini_extended_preload(pre_sp_addr, C_sp_addr, J, K, J, I);
if (new_weights) {
gemmini_extended_compute_preloaded(A_sp_addr, GARBAGE_ADDR, K, I, J, I);
} else {
gemmini_extended_compute_accumulated(A_sp_addr, GARBAGE_ADDR, K, I, J, I);
}
ocol += ocol_it;
new_weights = false;
}
}
}
}
}
}
}
}
#undef DS
#undef UNDILATED
// mvout output
if (output != NULL) {
if (no_pool) {
for (int b = 0; b < batches; b++)
for (int orow = 0; orow < orows; orow++)
for (int ocol = 0; ocol < ocols; ocol += DIM) {
const int I = ocols - ocol > DIM ? DIM : ocols - ocol;
for (int och = 0; och < ochs; och += DIM) {
const int J = ochs - och > DIM ? DIM : ochs - och;
const uint32_t C_sp_addr = C_sp_addr_start + (och / DIM) * batches * orows * ocols + b * orows * ocols + orow * ocols + ocol;
elem_t * out = output + (b*out_row_dim*out_col_dim + orow*out_col_dim + ocol) * out_stride + och;
if (trans_output_1203) {
out = output + (orow*out_col_dim*batch_size + ocol*batch_size + b) * out_channels + och;
}
gemmini_extended_mvout(out,
C_sp_addr,
J, I);
}
}
} else {
printf("Pooling with rectangular convolutions is currently not supported.\n");
exit(1);
*/
/*
gemmini_extended2_config_st(out_channels * sizeof(elem_t), act, scale, pool_stride, pool_size, pool_out_row_dim, porows, pocols, orows, ocols, pupad, plpad);
for (int b = 0; b < batches; b++) {
for (int poch = 0; poch < pochs; poch += DIM) {
const int channels = poch + DIM >= pochs ? pochs - poch : DIM;
elem_t * pout = output + (b * pool_out_row_dim * pool_out_col_dim)*out_channels + poch;
const uint32_t C_sp_addr = C_sp_addr_start + (poch / DIM) * batches * orows * ocols + b * orows * ocols;
gemmini_extended_mvout(pout,
C_sp_addr,
channels, 0);
}
}
gemmini_extended_config_st(out_channels * sizeof(elem_t), act, scale);
<<<<<<< HEAD
*/
// }
// }
// }
//}
}
static int tiled_conv_total_spad_rows_dw(bool acc, bool weight,
int stride,
int batches,
int porows, int pocols, int ochs,
int krows, int kcols, int kchs,
int pool_size, int pool_stride) {
const int orows = porows * pool_stride + pool_size - 1;
const int ocols = pocols * pool_stride + pool_size - 1;
const int irows = orows * stride + krows - 1; // - 2 * padding;
const int icols = ocols * stride + kcols - 1; // - 2 * padding;
const int ichs = kchs;
const int in_channels_per_bank = ichs / DIM + (ichs % DIM != 0);
const int out_channels_per_bank = ochs / DIM + (ochs % DIM != 0);
const int A_rows = in_channels_per_bank * batches * irows * icols;
const int B_rows = out_channels_per_bank * kcols * krows * kchs;
const int C_rows = out_channels_per_bank * batches * orows * ocols;
if (acc)
return C_rows;
else if(weight)
return B_rows;
else
return A_rows;
}
static int tiled_conv_total_spad_rows(bool acc,
int stride,
int input_dilation,
int kernel_dilation,
bool downsample,
bool trans_weight_0132,
bool trans_input_3120,
int batches,
int porows, int pocols, int ochs,
int krows, int kcols, int kchs,
int pool_size, int pool_stride) {
const int orows = porows * pool_stride + pool_size - 1;
const int ocols = pocols * pool_stride + pool_size - 1;
const int krows_dilated = krows + (kernel_dilation - 1)*(krows - 1);
const int kcols_dilated = kcols + (kernel_dilation - 1)*(kcols - 1);
int irows = orows * stride + krows_dilated - 1; // - 2 * padding;
int icols = ocols * stride + kcols_dilated - 1; // - 2 * padding;
const int ichs = kchs;
irows = irows / input_dilation + (irows % input_dilation != 0);
icols = icols / input_dilation + (icols % input_dilation != 0);
const int in_channels_per_bank = ichs / DIM + (ichs % DIM != 0);
const int out_channels_per_bank = ochs / DIM + (ochs % DIM != 0);
const int batches_per_bank = batches / DIM + (batches % DIM != 0);
const int A_rows = trans_input_3120 ?
(batches_per_bank * ichs * (irows >> downsample) * (icols >> downsample)) :
(in_channels_per_bank * batches * (irows >> downsample) * (icols >> downsample));
const int B_rows = trans_weight_0132 ?
in_channels_per_bank * kcols * krows * ochs :
out_channels_per_bank * kcols * krows * kchs;
const int C_rows = out_channels_per_bank * batches * orows * ocols;
return acc ? C_rows : A_rows + B_rows;
}
static void conv_cpu_without_pool(
int batch_size, int in_row_dim, int in_col_dim, int in_channels,
int out_channels, int out_row_dim, int out_col_dim,
int stride, int input_dilation, int kernel_dilation, int padding, int kernel_dim,
int in_stride, int weight_stride, int out_stride,
bool wrot180, bool trans_output_1203, bool trans_input_3120,
bool trans_weight_1203, bool trans_weight_0132,
const elem_t * input,
const elem_t * weights,
const acc_t * bias,
elem_t * output,
int act, acc_scale_t scale) {
bool no_bias = bias == NULL;
for (int b = 0; b < batch_size; b++) {
for (int orow = 0; orow < out_row_dim; orow++) {
for (int ocol = 0; ocol < out_col_dim; ocol++) {
for (int och = 0; och < out_channels; och++) {
acc_t opixel = no_bias ? 0 : bias[och];
for (int krow = 0; krow < kernel_dim; krow++) {
if ((orow * stride + krow * kernel_dilation - padding) % input_dilation != 0)
continue;
const int irow = (orow * stride + krow * kernel_dilation - padding) / input_dilation;
for (int kcol = 0; kcol < kernel_dim; kcol++) {
if ((ocol * stride + kcol * kernel_dilation - padding) % input_dilation != 0)
continue;
const int icol = (ocol * stride + kcol * kernel_dilation - padding) / input_dilation;
for (int kch = 0; kch < in_channels; kch++) {
const elem_t *in = input + (b * in_row_dim * in_col_dim + irow * in_col_dim + icol) * in_stride + kch;
if (trans_input_3120) {
// NHWC to CHWN
in = input + (kch * in_row_dim * in_col_dim + irow * in_col_dim + icol) * batch_size + b;
}
elem_t ipixel = irow < 0 || irow >= in_row_dim || icol < 0 || icol >= in_col_dim ?
0 : *in;
const int krow_ = wrot180 ? kernel_dim - krow - 1 : krow;
const int kcol_ = wrot180 ? kernel_dim - kcol - 1 : kcol;
elem_t weight = *(weights + (krow_ * kernel_dim * in_channels + kcol_ * in_channels + kch) * weight_stride + och);
if (trans_weight_1203) {
// HWIO to WIHO
weight = *(weights + (kch * kernel_dim * kernel_dim + krow_ * kernel_dim + kcol_) * out_channels + och);
} else if (trans_weight_0132) {
// HWIO to HWOI
weight = *(weights + (krow_ * kernel_dim * out_channels + kcol_ * out_channels + och) * in_channels + kch);
}
opixel += weight * ipixel;
}
}
}
elem_t *out = output + (b * out_row_dim * out_col_dim + orow * out_col_dim + ocol) * out_stride + och;
if (trans_output_1203) {
// NHWC to HWNC
out = output + (orow * out_col_dim * batch_size + ocol * batch_size + b) * out_channels + och;
}
*out = scale_and_sat(opixel, act, scale, 0);
}
}
}
}
}
static void conv_dw_cpu_without_pool(
int batch_size, int in_row_dim, int in_col_dim,
int channels, int out_row_dim, int out_col_dim,
int stride, int padding, int kernel_dim,
const elem_t * input,
const elem_t * weights,
const acc_t * bias,
elem_t * output,
int act, acc_scale_t scale) {
bool no_bias = bias == NULL;
for (int b = 0; b < batch_size; b++) {
for (int orow = 0; orow < out_row_dim; orow++) {
for (int ocol = 0; ocol < out_col_dim; ocol++) {
for (int ch = 0; ch < channels; ch++) {
acc_t opixel = no_bias ? 0 : bias[ch];
for (int krow = 0; krow < kernel_dim; krow++) {
const int irow = orow * stride + krow - padding;
for (int kcol = 0; kcol < kernel_dim; kcol++) {
const int icol = ocol * stride + kcol - padding;
const elem_t * in = input + (b * in_row_dim * in_col_dim + irow * in_col_dim + icol) * channels + ch;
const elem_t ipixel = irow < 0 || irow >= in_row_dim || icol < 0 || icol >= in_col_dim ?
0 : *in;
const elem_t weight = *(weights + (ch * kernel_dim + krow) * kernel_dim + kcol);
opixel += weight * ipixel;
}
}
elem_t *out = output + (b * out_row_dim * out_col_dim + orow * out_col_dim + ocol) * channels + ch;
*out = scale_and_sat(opixel, act, scale, 0);
}
}
}
}
}
static void conv_cpu(
int batch_size, int in_row_dim, int in_col_dim, int in_channels,
int out_channels, int out_row_dim, int out_col_dim,
int stride, int input_dilation, int kernel_dilation, int padding, int kernel_dim,
int in_stride, int weight_stride, int out_stride,
bool wrot180, bool trans_output_1203, bool trans_input_3120,
bool trans_weight_1203, bool trans_weight_0132,
const elem_t * input,
const elem_t * weights,
const acc_t * bias,
elem_t * output,
int act, acc_scale_t scale,
int pool_size, int pool_stride, int pool_padding) {
const bool no_pool = pool_stride == 0;
if (no_pool) {
conv_cpu_without_pool(
batch_size, in_row_dim, in_col_dim, in_channels,
out_channels, out_row_dim, out_col_dim,
stride, input_dilation, kernel_dilation, padding, kernel_dim,
in_stride, weight_stride, out_stride,
wrot180, trans_output_1203, trans_input_3120,
trans_weight_1203, trans_weight_0132,
input, weights, bias, output,
act, scale);
return;
}
const bool no_bias = bias == NULL;
const int pool_out_row_dim = (out_row_dim + 2 * pool_padding - pool_size) / pool_stride + 1;
const int pool_out_col_dim = (out_col_dim + 2 * pool_padding - pool_size) / pool_stride + 1;
for (int b = 0; b < batch_size; b++) {
for (int porow = 0; porow < pool_out_row_dim; porow++) {
for (int pocol = 0; pocol < pool_out_col_dim; pocol++) {
for (int poch = 0; poch < out_channels; poch++) {
elem_t running_max = 0;
bool running_max_initialized = false;
for (int pwrow = 0; pwrow < pool_size; pwrow++) {
const int orow = porow * pool_stride + pwrow - pool_padding;
for (int pwcol = 0; pwcol < pool_size; pwcol++) {
const int ocol = pocol * pool_stride + pwcol - pool_padding;
if (orow < 0 || orow >= out_row_dim || ocol < 0 || ocol >= out_col_dim) {
if (!running_max_initialized || running_max < 0) {
running_max = 0;
running_max_initialized = true;
}
} else {
acc_t opixel = no_bias ? 0 : bias[poch];
for (int krow = 0; krow < kernel_dim; krow++) {
if ((orow * stride + krow * kernel_dilation - padding) % input_dilation != 0)
continue;
const int irow = (orow * stride + krow * kernel_dilation - padding) / input_dilation;
for (int kcol = 0; kcol < kernel_dim; kcol++) {
if ((ocol * stride + kcol * kernel_dilation - padding) % input_dilation != 0)
continue;
const int icol = (ocol * stride + kcol * kernel_dilation - padding) / input_dilation;
for (int kch = 0; kch < in_channels; kch++) {
const elem_t * in = input + (b * in_row_dim * in_col_dim + irow * in_col_dim + icol) * in_stride + kch;
if (trans_input_3120) {
// NHWC to CHWN
in = input + (kch * in_row_dim * in_col_dim + irow * in_col_dim + icol) * batch_size + b;
}
elem_t ipixel = irow < 0 || irow >= in_row_dim || icol < 0 || icol >= in_col_dim ?
0 : *in;
const int krow_ = wrot180 ? kernel_dim - krow - 1 : krow;
const int kcol_ = wrot180 ? kernel_dim - kcol - 1 : kcol;
elem_t weight = *(weights + (krow_ * kernel_dim * in_channels + kcol_ * in_channels + kch) * weight_stride + poch);
if (trans_weight_1203) {
// HWIO to WIHO
weight = *(weights + (kch * kernel_dim * kernel_dim + krow_ * kernel_dim + kcol_) * out_channels + poch);
} else if (trans_weight_0132) {
// HWIO to HWOI
weight = *(weights + (krow_ * kernel_dim * out_channels + kcol_ * out_channels + poch) * in_channels + kch);
}
opixel += weight * ipixel;
}
}
}
opixel = scale_and_sat(opixel, act, scale, 0);
if (!running_max_initialized || opixel > running_max) {
running_max = opixel;
running_max_initialized = true;
}
}
if (pwrow == pool_size - 1 && pwcol == pool_size - 1) {
elem_t * out = output + (b * pool_out_row_dim * pool_out_col_dim + porow * pool_out_col_dim + pocol) * out_stride + poch;
if (trans_output_1203) {
// NHWC to HWNC
out = output + (porow * pool_out_col_dim * batch_size + pocol * batch_size + b) * out_channels + poch;
}
*out = running_max;
}
}
}
}
}
}
}
}
static void conv_dw_cpu(
int batch_size, int in_row_dim, int in_col_dim,
int channels, int out_row_dim, int out_col_dim,
int stride, int padding, int kernel_dim,
const elem_t * input,
const elem_t * weights,
const acc_t * bias,
elem_t * output,
int act, acc_scale_t scale,
int pool_size, int pool_stride, int pool_padding) {
const bool no_pool = pool_stride == 0;
if (no_pool) {
conv_dw_cpu_without_pool(
batch_size, in_row_dim, in_col_dim,
channels, out_row_dim, out_col_dim,
stride, padding, kernel_dim,
input, weights, bias, output,
act, scale);
return;
}
const bool no_bias = bias == NULL;
const int pool_out_row_dim = (out_row_dim + 2 * pool_padding - pool_size) / pool_stride + 1;
const int pool_out_col_dim = (out_col_dim + 2 * pool_padding - pool_size) / pool_stride + 1;
for (int b = 0; b < batch_size; b++) {
for (int porow = 0; porow < pool_out_row_dim; porow++) {
for (int pocol = 0; pocol < pool_out_col_dim; pocol++) {
for (int ch = 0; ch < channels; ch++) {
elem_t running_max = 0;
bool running_max_initialized = false;
for (int pwrow = 0; pwrow < pool_size; pwrow++) {
const int orow = porow * pool_stride + pwrow - pool_padding;
for (int pwcol = 0; pwcol < pool_size; pwcol++) {
const int ocol = pocol * pool_stride + pwcol - pool_padding;
if (orow < 0 || orow >= out_row_dim || ocol < 0 || ocol >= out_col_dim) {
if (!running_max_initialized || running_max < 0) {
running_max = 0;
running_max_initialized = true;
}
} else {
acc_t opixel = no_bias ? 0 : bias[ch];
for (int krow = 0; krow < kernel_dim; krow++) {
const int irow = orow * stride + krow - padding;
for (int kcol = 0; kcol < kernel_dim; kcol++) {
const int icol = ocol * stride + kcol - padding;
const elem_t * in = input + (b * in_row_dim * in_col_dim + irow * in_col_dim + icol) * channels + ch;
elem_t ipixel = irow < 0 || irow >= in_row_dim || icol < 0 || icol >= in_col_dim ?
0 : *in;
const elem_t weight = *(weights + (ch * kernel_dim + krow) * kernel_dim + kcol);
opixel += weight * ipixel;
}
}
opixel = scale_and_sat(opixel, act, scale, 0);
if (!running_max_initialized || opixel > running_max) {
running_max = opixel;
running_max_initialized = true;
}
}
if (pwrow == pool_size - 1 && pwcol == pool_size - 1) {
elem_t * out = output + (b * pool_out_row_dim * pool_out_col_dim + porow * pool_out_col_dim + pocol) * channels + ch;
*out = running_max;
}
}
}
}
}
}
}
}
static void tiled_conv(
int batch_size,
int in_row_dim, int in_col_dim, int in_channels,
int out_channels, int out_row_dim, int out_col_dim,
int stride, int input_dilation, int kernel_dilation, int padding, int kernel_dim,
int in_stride, int weight_stride, int out_stride,
bool wrot180, bool trans_output_1203, bool trans_input_3120,
bool trans_weight_1203, bool trans_weight_0132,
int batches,
int porows, int pocols, int pochs,
int krows, int kcols, int kchs,
const elem_t * input,
const elem_t * weights,
const acc_t * bias,
elem_t * output,
int act, acc_scale_t scale,
int pool_size, int pool_stride, int pool_padding,
enum tiled_matmul_type_t tiled_conv_type) {
#ifdef GEMMINI_ASSERTIONS
if (trans_weight_1203 && trans_weight_0132) {
printf("Only one weight transformation can be applied at a time\n");
exit(1);
}
#endif
if (tiled_conv_type == CPU) {
if (pool_size == 1 && pool_stride == 1 && pool_padding == 0) {
pool_stride = 0;
}
// assume in_dim_rows = in_dim_cols
// and out_dim_rows = out_dim_cols for now
conv_cpu(
batch_size, in_row_dim, in_col_dim, in_channels,
out_channels, out_row_dim, out_col_dim,
stride, input_dilation, kernel_dilation, padding, kernel_dim,
in_stride, weight_stride, out_stride,
wrot180, trans_output_1203, trans_input_3120,
trans_weight_1203, trans_weight_0132,
input, weights, bias, output,
act, scale,
pool_size, pool_stride, pool_padding);
return;
} else if (tiled_conv_type == OS) {
printf("Gemmini convs do not currently support OS\n");
exit(1);
}
// TODO move everything below this into a tiled_conv_outer function to match the tiled_matmul function
bool no_bias = false;
if (bias == NULL) {
bias = (acc_t*)1;
no_bias = true;
}
bool no_pool = pool_stride == 0;
if (no_pool) {
pool_size = 1;
pool_stride = 1;
pool_padding = 0;
}
const bool downsample = stride == 2 && kernel_dim == 1 && in_row_dim % 2 == 0 && in_col_dim % 2 == 0
&& padding == 0 && no_pool && input_dilation == 1 && !trans_input_3120;
const int input_dilated = input_dilation == 2;
#ifdef GEMMINI_ASSERTIONS
{
// const int orows = porows * pool_stride + pool_size - 1;
// const int ocols = pocols * pool_stride + pool_size - 1;
// Check that data will fit in scratchpad
const int spad_rows = tiled_conv_total_spad_rows(false,
stride, input_dilation, kernel_dilation, downsample, trans_weight_0132, trans_input_3120,
batches, porows, pocols, pochs, krows, kcols, kchs, pool_size, pool_stride);
const int acc_rows = tiled_conv_total_spad_rows(true,
stride, input_dilation, kernel_dilation, downsample, trans_weight_0132, trans_input_3120,
batches, porows, pocols, pochs, krows, kcols, kchs, pool_size, pool_stride);
if (spad_rows > BANK_NUM * BANK_ROWS / 2) {
printf("not enough scratchpad space to store inputs and weights, %d\n", spad_rows);
exit(1);
}
if (acc_rows > ACC_ROWS / 2) {
printf("not enough accumulator space to store outputs\n");
exit(1);
}
if (kernel_dim <= padding) {
printf("kernel_dim must be larger than padding\n");
exit(1);
}
if (input_dilation > 2) {
printf("input_dilation > 2 is only supported on CPU\n");
exit(1);
}
if (input_dilation > 1 && stride > 1) {
printf("input input_dilation is only supported when stride == 1\n");
exit(1);
}
if (trans_output_1203 && !no_pool) {
printf("Output can only be transposed when pooling is disabled\n");
exit(1);
}
if (trans_input_3120 && trans_weight_0132) {
printf("Cannot transpose innermost dimensions of both inputs and weights on WS.\n");
exit(1);
}
}
#endif
const size_t st_dram_stride = trans_output_1203 ?
batch_size * out_channels * sizeof(elem_t) :
out_stride * sizeof(elem_t);
gemmini_extended_config_st(st_dram_stride, act, scale);
gemmini_extended3_config_ex(WEIGHT_STATIONARY, 0, 0, 0, input_dilation, stride >> downsample, trans_input_3120, trans_weight_0132, false);
const int pool_out_row_dim = (out_row_dim + 2 * pool_padding - pool_size) / pool_stride + 1;
const int pool_out_col_dim = (out_col_dim + 2 * pool_padding - pool_size) / pool_stride + 1;
const int dilated_in_row_dim = in_row_dim + (input_dilation - 1) * (in_row_dim- 1);
const int dilated_in_col_dim = in_col_dim + (input_dilation - 1) * (in_col_dim- 1);
size_t a_spad_id = 0;
size_t b_spad_id = 0;
int porow_end = pool_out_row_dim;
int porow_start = 0;
bool a_reuse = false;
bool b_reuse = false;
size_t num_kch = ceil_divide_int(in_channels, kchs);
size_t num_poch = ceil_divide_int(out_channels, pochs);
size_t num_b = ceil_divide_int(batch_size, batches);
size_t num_porow = ceil_divide_int((porow_end - porow_start), porows);
size_t num_pocol = ceil_divide_int(pool_out_col_dim, pocols);
size_t num_krow = ceil_divide_int(kernel_dim, krows);
size_t num_kcol = ceil_divide_int(kernel_dim, kcols);
// printf("num_kch: %d, num_poch: %d, num_b: %d, num_porow: %d, num_pocol: %d, num_krow: %d, num_kcol: %d\n", num_kch, num_poch, num_b, num_porow, num_pocol, num_krow, num_kcol);
if(num_kch * num_poch * num_krow * num_kcol <= 2)
b_reuse = true;
if(num_kch * num_krow * num_kcol * num_b * num_porow * num_pocol <= 2)
a_reuse = true;
for (int b = 0; b < batch_size; b += batches) {
for (int porow = porow_start; porow < porow_end; porow += porows) {
const int orow = porow * pool_stride - pool_padding;
for (int pocol = 0; pocol < pool_out_col_dim; pocol += pocols) {
const int ocol = pocol * pool_stride - pool_padding;
for (int poch = 0; poch < out_channels; poch += pochs) {
for (int krow = 0; krow < kernel_dim; krow += krows) {
const int orow_floored = orow < 0 ? 0 : orow;
int irow = orow_floored * stride + krow * kernel_dilation - padding;
for (int kcol = 0; kcol < kernel_dim; kcol += kcols) {
const int ocol_floored = ocol < 0 ? 0 : ocol;
int icol = ocol_floored * stride + kcol * kernel_dilation - padding;
for (int kch = 0; kch < in_channels; kch += kchs) {
if(a_reuse)
a_spad_id = (kch + krow + kcol + b + (porow - porow_start) + pocol) == 0 ? 1 : 2;
if(b_reuse)
b_spad_id = (kch + poch + krow + kcol) == 0 ? 1 : 2;
elem_t * out = output + (b * pool_out_row_dim * pool_out_col_dim + porow * pool_out_col_dim + pocol) * out_stride + poch;
if (trans_output_1203) {
out = output + (porow * pool_out_col_dim * batch_size + pocol * batch_size + b) * out_channels + poch;
}
if (krow + krows < kernel_dim ||
kcol + kcols < kernel_dim ||
kch + kchs < in_channels) {
out = NULL;
}
const acc_t * bias_ = bias + poch;
if (krow > 0 ||
kcol > 0 ||
kch > 0) {
bias_ = NULL;
}
const int batches_ = batch_size - b > batches ? batches : batch_size - b;
const int porows_ = pool_out_row_dim - porow > porows ? porows : pool_out_row_dim - porow;
const int pocols_ = pool_out_col_dim - pocol > pocols ? pocols : pool_out_col_dim - pocol;
const int pochs_ = out_channels - poch > pochs ? pochs : out_channels - poch;
const int krows_ = kernel_dim - krow > krows ? krows : kernel_dim - krow;
const int kcols_ = kernel_dim - kcol > kcols ? kcols : kernel_dim - kcol;
const int kchs_ = in_channels - kch > kchs ? kchs : in_channels - kch;
const int ocols_ = pocols_ * pool_stride + pool_size - 1;
const int orows_ = porows_ * pool_stride + pool_size - 1;
const int plpad = ocol < 0 ? -ocol : 0;
const int prpad = ocol + ocols_ > out_col_dim ? ocol + ocols_ - out_col_dim : 0;
const int pupad = orow < 0 ? -orow : 0;
const int pdpad = orow + orows_ > out_row_dim ? orow + orows_ - out_row_dim : 0;
const int dilated_krows_ = krows_ + (kernel_dilation - 1)*(krows_ - 1);
const int dilated_kcols_ = kcols_ + (kernel_dilation - 1)*(kcols_ - 1);
const int icols_ = (ocols_ - plpad - prpad) * stride + dilated_kcols_ - 1;
const int irows_ = (orows_ - pupad - pdpad) * stride + dilated_krows_ - 1;
int lpad = icol < 0 ? -icol : 0;
int rpad = icol + icols_ > dilated_in_col_dim ? icol + icols_ - dilated_in_col_dim : 0;
int upad = irow < 0 ? -irow : 0;
int dpad = irow + irows_ > dilated_in_row_dim ? irow + irows_ - dilated_in_row_dim : 0;
if (input_dilated) {
lpad += lpad == 0 && icol % 2 != 0;
rpad += rpad == 0 && (icol + icols_) % 2 != 1;
upad += upad == 0 && irow % 2 != 0;
dpad += dpad == 0 && (irow + irows_) % 2 != 1;
}
int krow_ = krow;
int kcol_ = kcol;
if (wrot180) {
krow_ = kernel_dim - krow - krows_;
kcol_ = kernel_dim - kcol - kcols_;
}
const elem_t * weights_slice = weights + (krow_*kernel_dim*in_channels + kcol_*in_channels + kch) * weight_stride + poch;
if (trans_weight_1203) {
weights_slice = weights + (kch*kernel_dim*kernel_dim + krow_*kernel_dim+kcol_) * out_channels + poch;
} else if (trans_weight_0132) {
weights_slice = weights + (krow_*kernel_dim*out_channels + kcol_*out_channels + poch) * in_channels + kch;
}
const elem_t * in = input + (b *in_row_dim * in_col_dim + ((irow+upad)>>input_dilated) * in_col_dim + ((icol+lpad)>>input_dilated)) * in_stride + kch;
if (trans_input_3120) {
in = input + (kch * in_row_dim * in_col_dim + ((irow+upad)>>input_dilated) * in_col_dim + ((icol+lpad)>>input_dilated)) * batch_size + b;
}
if(b_reuse && (pocol + (porow - porow_start) + b > 0)) weights_slice = NULL;
if(a_reuse && (poch > 0)) in = NULL;
//printf("a_reuse: %d, b_reuse: %d, a_spad_id: %d, b_spad_id: %d, in: %llu, weight: %llu \n", a_reuse, b_reuse, a_spad_id, b_spad_id, in, weights_slice);
sp_tiled_conv(
batch_size, in_row_dim, in_col_dim, in_channels,
out_channels, out_row_dim, out_col_dim,
pool_out_row_dim, pool_out_col_dim,
stride, padding, kernel_dim, kernel_dilation,
in_stride, weight_stride, out_stride,
pool_size, pool_stride, pool_padding,
batches_,
porows_, pocols_, pochs_,
krows_, kcols_, kchs_,
lpad, rpad, upad, dpad,
plpad, prpad, pupad, pdpad,
in,
weights_slice,
out,
bias_,
act, scale,
wrot180, trans_output_1203, trans_input_3120,
trans_weight_1203, trans_weight_0132,
no_bias, no_pool, downsample, input_dilated,
false, a_spad_id, b_spad_id);
}
}
}
}
}
}
}
}
static void tiled_conv_dw(
int batch_size, int in_row_dim, int in_col_dim,
int channels, int out_row_dim, int out_col_dim,
int stride, int padding, int kernel_dim,
int batches,
int porows, int pocols,
int krows, int kcols,
const elem_t * input,
const elem_t * weights,
const acc_t * bias,
elem_t * output,
int act, acc_scale_t scale,
int pool_size, int pool_stride, int pool_padding,
enum tiled_matmul_type_t tiled_conv_type) {
if (tiled_conv_type == CPU) {
if (pool_size == 1 && pool_stride == 1 && pool_padding == 0) {
pool_stride = 0;
}
conv_dw_cpu(
batch_size, in_row_dim, in_col_dim,
channels, out_row_dim, out_col_dim,
stride, padding, kernel_dim,
input, weights, bias, output,
act, scale,
pool_size, pool_stride, pool_padding);
return;
} else if (tiled_conv_type == OS) {
printf("Gemmini convs do not currently support OS\n");
exit(1);
}
// TODO move everything below this into a tiled_conv_outer function to match the tiled_matmul function
bool no_bias = false;
if (bias == NULL) {
bias = (acc_t*)1;
no_bias = true;
}
bool no_pool = pool_stride == 0;
if (no_pool) {
pool_size = 1;
pool_stride = 1;
pool_padding = 0;
}
#ifdef GEMMINI_ASSERTIONS
{
// const int orows = porows * pool_stride + pool_size - 1;
// const int ocols = pocols * pool_stride + pool_size - 1;
// Check that data will fit in scratchpad
const int spad_rows = tiled_conv_total_spad_rows(false,
stride, 1, 1, false, false, false,
batches, porows, pocols, 1, krows, kcols, 1, pool_size, pool_stride);
const int acc_rows = tiled_conv_total_spad_rows(true,
stride, 1, 1, false, false, false,
batches, porows, pocols, 1, krows, kcols, 1, pool_size, pool_stride);
if (spad_rows > BANK_NUM * BANK_ROWS / 2) {
printf("not enough scratchpad space to store inputs and weights, %d\n", spad_rows);
exit(1);
}
if (acc_rows > ACC_ROWS / 2) {
printf("not enough accumulator space to store outputs\n");
exit(1);
}
if (kernel_dim <= padding) {
printf("kernel_dim must be larger than padding\n");
exit(1);
}
}
#endif
const size_t st_dram_stride = channels * sizeof(elem_t);
gemmini_extended_config_st(st_dram_stride, act, scale);
gemmini_extended3_config_ex(WEIGHT_STATIONARY, 0, 0, 0, 1, stride, false, false, false);
const int pool_out_row_dim = (out_row_dim + 2 * pool_padding - pool_size) / pool_stride + 1;
const int pool_out_col_dim = (out_col_dim + 2 * pool_padding - pool_size) / pool_stride + 1;
for (int b = 0; b < batch_size; b += batches) {
for (int porow = 0; porow < pool_out_row_dim; porow += porows) {
const int orow = porow * pool_stride - pool_padding;
for (int pocol = 0; pocol < pool_out_col_dim; pocol += pocols) {
const int ocol = pocol * pool_stride - pool_padding;
for (int ch = 0; ch < channels; ch++) {
for (int krow = 0; krow < kernel_dim; krow += krows) {
const int orow_floored = orow < 0 ? 0 : orow;
int irow = orow_floored * stride + krow - padding;
for (int kcol = 0; kcol < kernel_dim; kcol += kcols) {
const int ocol_floored = ocol < 0 ? 0 : ocol;
int icol = ocol_floored * stride + kcol - padding;
elem_t * out = output + (b * pool_out_row_dim * pool_out_col_dim + porow * pool_out_col_dim + pocol) * channels + ch;
if (krow + krows < kernel_dim ||
kcol + kcols < kernel_dim) {
out = NULL;
}
const acc_t * bias_ = bias + ch;
if (krow > 0 ||
kcol > 0) {
bias_ = NULL;
}
const int batches_ = batch_size - b > batches ? batches : batch_size - b;
const int porows_ = pool_out_row_dim - porow > porows ? porows : pool_out_row_dim - porow;
const int pocols_ = pool_out_col_dim - pocol > pocols ? pocols : pool_out_col_dim - pocol;
const int krows_ = kernel_dim - krow > krows ? krows : kernel_dim - krow;
const int kcols_ = kernel_dim - kcol > kcols ? kcols : kernel_dim - kcol;
const int ocols_ = pocols_ * pool_stride + pool_size - 1;
const int orows_ = porows_ * pool_stride + pool_size - 1;
const int plpad = ocol < 0 ? -ocol : 0;
const int prpad = ocol + ocols_ > out_col_dim ? ocol + ocols_ - out_col_dim : 0;
const int pupad = orow < 0 ? -orow : 0;
const int pdpad = orow + orows_ > out_row_dim ? orow + orows_ - out_row_dim : 0;
const int icols_ = (ocols_ - plpad - prpad) * stride + kcols_ - 1;
const int irows_ = (orows_ - pupad - pdpad) * stride + krows_ - 1;
int lpad = icol < 0 ? -icol : 0;
int rpad = icol + icols_ > in_col_dim ? icol + icols_ - in_col_dim : 0;
int upad = irow < 0 ? -irow : 0;
int dpad = irow + irows_ > in_row_dim ? irow + irows_ - in_row_dim : 0;
const elem_t * weights_slice = weights + (ch*kernel_dim + krow) * kernel_dim + kcol;
const elem_t *in = input + (b * in_row_dim * in_col_dim + (irow+upad) * in_col_dim + (icol+lpad)) * channels + ch;
sp_tiled_conv(
batch_size, in_row_dim, in_col_dim, channels,
channels, out_row_dim, out_col_dim,
pool_out_row_dim, pool_out_col_dim,
stride, padding, kernel_dim, 1,
channels, 1, channels,
pool_size, pool_stride, pool_padding,
batches_,
porows_, pocols_, 1,
krows_, kcols_, 1,
lpad, rpad, upad, dpad,
plpad, prpad, pupad, pdpad,
in,
weights_slice,
out,
bias_,
act, scale,
false, false, false,
false, false,
no_bias, no_pool, false, false,
true, 0, 0);
}
}
}
}
}
}
}
// need to specify each operand/output's stride
// stride only for trans == false, wrot == false
static void tiled_conv_stride_auto(
int batch_size, int in_row_dim, int in_col_dim, int in_channels,
int out_channels, int out_row_dim, int out_col_dim,
int stride, int input_dilation, int kernel_dilation, int padding, int kernel_dim,
int in_stride, int weight_stride, int out_stride, // specify in/output's stride
bool wrot180, bool trans_output_1203, bool trans_input_3120,
bool trans_weight_1203, bool trans_weight_0132,
const elem_t * input,
const elem_t * weights,
const acc_t * bias,
elem_t * output,
int act, acc_scale_t scale,
int pool_size, int pool_stride, int pool_padding,
enum tiled_matmul_type_t tiled_conv_type) {
const bool no_pool = pool_stride == 0;
if (no_pool) {
pool_size = 1;
pool_stride = 1;
pool_padding = 0;
}
const int pool_out_row_dim = (out_row_dim + 2 * pool_padding - pool_size) / pool_stride + 1;
const int pool_out_col_dim = (out_col_dim + 2 * pool_padding - pool_size) / pool_stride + 1;
const bool downsample = stride == 2 && kernel_dim == 1 && padding == 0 && no_pool && in_row_dim % 2 == 0 && in_col_dim % 2 == 0;
// Tile convolution params
// int args[] = {batch_size, porows, pocols, pochs, krows, kcols, kchs};
int args[] = {batch_size, pool_out_row_dim, pool_out_col_dim, out_channels, kernel_dim, kernel_dim, in_channels};
const int max_args[] = {batch_size, pool_out_row_dim, pool_out_col_dim, out_channels, kernel_dim, kernel_dim, in_channels};
const int orows_idx = 1;
const int ocols_idx = 2;
const int out_channels_idx = 3;
const int in_channels_idx = 6;
// We divide by 2 for the sake of double-buffering
const int max_spad_rows = (BANK_NUM*BANK_ROWS / 2);
const int max_acc_rows = (ACC_ROWS / 2);
int spad_rows = tiled_conv_total_spad_rows(false,
stride, input_dilation, kernel_dilation, downsample, trans_weight_0132, trans_input_3120,
args[0], args[1], args[2], args[3], args[4], args[5], args[6], pool_size, pool_stride);
int acc_rows = tiled_conv_total_spad_rows(true,
stride, input_dilation, kernel_dilation, downsample, trans_weight_0132, trans_input_3120,
args[0], args[1], args[2], args[3], args[4], args[5], args[6], pool_size, pool_stride);
while (spad_rows > max_spad_rows || acc_rows > max_acc_rows) {
int max_val = -1;
int max_idx = -1;
for (size_t i = 0; i < sizeof(args)/sizeof(args[0]); i++) {
// We avoid reducing ocols when possible to keep the spatial array fully utilized
if (!(i == ocols_idx && args[i] <= DIM && args[orows_idx] > 1)
&& args[i] > max_val) {
max_val = args[i];
max_idx = i;
}
}
if (max_idx == out_channels_idx || max_idx == in_channels_idx) {
// For input and output channels, there's no point in subtracting by just one
if (args[max_idx] % DIM != 0) {
args[max_idx] = (args[max_idx] / DIM) * DIM;
} else {
args[max_idx] -= DIM;
}
args[max_idx] = args[max_idx] == 0 ? 1 : args[max_idx];
} else {
args[max_idx]--;
}
spad_rows = tiled_conv_total_spad_rows(false,
stride, input_dilation, kernel_dilation, downsample, trans_weight_0132, trans_input_3120,
args[0], args[1], args[2], args[3], args[4], args[5], args[6], pool_size, pool_stride);
acc_rows = tiled_conv_total_spad_rows(true,
stride, input_dilation, kernel_dilation, downsample, trans_weight_0132, trans_input_3120,
args[0], args[1], args[2], args[3], args[4], args[5], args[6], pool_size, pool_stride);
}
// Check if we can increase ocols
bool not_increased = false;
while (!not_increased) {
not_increased = true;
int args_candidate[] = {args[0], args[1], args[2], args[3], args[4], args[5], args[6]};
args_candidate[ocols_idx]++;
if (args_candidate[ocols_idx] > max_args[ocols_idx])
continue;
spad_rows = tiled_conv_total_spad_rows(false,
stride, input_dilation, kernel_dilation, downsample, trans_weight_0132, trans_input_3120,
args_candidate[0], args_candidate[1], args_candidate[2], args_candidate[3], args_candidate[4], args_candidate[5], args_candidate[6], pool_size, pool_stride);
acc_rows = tiled_conv_total_spad_rows(true,
stride, input_dilation, kernel_dilation, downsample, trans_weight_0132, trans_input_3120,
args_candidate[0], args_candidate[1], args_candidate[2], args_candidate[3], args_candidate[4], args_candidate[5], args_candidate[6], pool_size, pool_stride);
if (spad_rows <= max_spad_rows && acc_rows <= max_acc_rows) {
args[ocols_idx] = args_candidate[ocols_idx];
not_increased = false;
}
}
// Check if there are any parameters that we can currently still increase
bool nothing_increased = false;
while (!nothing_increased) {
nothing_increased = true;
for (size_t i = 0; i < sizeof(args)/sizeof(args[0]); i++) {
int args_candidate[] = {args[0], args[1], args[2], args[3], args[4], args[5], args[6]};
args_candidate[i]++;
if (args_candidate[i] > max_args[i])
continue;
spad_rows = tiled_conv_total_spad_rows(false,
stride, input_dilation, kernel_dilation, downsample, trans_weight_0132, trans_input_3120,
args_candidate[0], args_candidate[1], args_candidate[2], args_candidate[3], args_candidate[4], args_candidate[5], args_candidate[6], pool_size, pool_stride);
acc_rows = tiled_conv_total_spad_rows(true,
stride, input_dilation, kernel_dilation, downsample, trans_weight_0132, trans_input_3120,
args_candidate[0], args_candidate[1], args_candidate[2], args_candidate[3], args_candidate[4], args_candidate[5], args_candidate[6], pool_size, pool_stride);
if (spad_rows <= max_spad_rows && acc_rows <= max_acc_rows) {
args[i] = args_candidate[i];
nothing_increased = false;
}
}
}
const int batches = args[0];
const int orows = args[1];
const int ocols = args[2];
const int ochs = args[3];
const int krows = args[4];
const int kcols = args[5];
const int kchs = args[6];
/*
spad_rows = tiled_conv_total_spad_rows(false,
stride, input_dilation, kernel_dilation, downsample, trans_weight_0132, trans_input_3120,
args[0], args[1], args[2], args[3], args[4], args[5], args[6], pool_size, pool_stride);
acc_rows = tiled_conv_total_spad_rows(true,
stride, input_dilation, kernel_dilation, downsample, trans_weight_0132, trans_input_3120,
args[0], args[1], args[2], args[3], args[4], args[5], args[6], pool_size, pool_stride);
*/
#ifdef PRINT_TILE
#if PRINT_TILE
printf("batches = %d\n", batches);
printf("orows = %d\n", orows);
printf("ocols = %d\n", ocols);
printf("ochs = %d\n", ochs);
printf("krows = %d\n", krows);
printf("kcols = %d\n", kcols);
printf("kchs = %d\n\n", kchs);
printf("total spad_rows reserved: %d\n", spad_rows);
printf("total acc_rows reserved: %d\n\n", acc_rows);
printf("scratchpad row utilization: %d%%\n", (spad_rows*100) / max_spad_rows);
printf("accumulator row utilization: %d%%\n\n", (acc_rows*100) / max_acc_rows);
printf("inner matmul size: i=%d, j=%d, k=%d\n\n", ocols, ochs, kchs);
#endif
#endif
tiled_conv(
batch_size, in_row_dim, in_col_dim, in_channels,
out_channels, out_row_dim, out_col_dim,
stride, input_dilation, kernel_dilation, padding, kernel_dim,
in_stride, weight_stride, out_stride,
wrot180, trans_output_1203, trans_input_3120,
trans_weight_1203, trans_weight_0132,
batches,
orows, ocols, ochs,
krows, kcols, kchs,
input,
weights,
bias,
output,
act, scale,
pool_size, no_pool ? 0 : pool_stride, pool_padding,
tiled_conv_type);
}
static void tiled_conv_auto(
int batch_size, int in_row_dim, int in_col_dim, int in_channels,
int out_channels, int out_row_dim, int out_col_dim,
int stride, int input_dilation, int kernel_dilation, int padding, int kernel_dim,
bool wrot180, bool trans_output_1203, bool trans_input_3120,
bool trans_weight_1203, bool trans_weight_0132,
const elem_t * input,
const elem_t * weights,
const acc_t * bias,
elem_t * output,
int act, acc_scale_t scale,
int pool_size, int pool_stride, int pool_padding,
enum tiled_matmul_type_t tiled_conv_type) {
int in_stride = in_channels;
int out_stride = out_channels;
int weight_stride = out_channels;
tiled_conv_stride_auto(
batch_size, in_row_dim, in_col_dim, in_channels,
out_channels, out_row_dim, out_col_dim,
stride, input_dilation, kernel_dilation, padding, kernel_dim,
in_stride, weight_stride, out_stride,
wrot180, trans_output_1203, trans_input_3120,
trans_weight_1203, trans_weight_0132,
input, weights, bias, output,
act, scale, pool_size, pool_stride, pool_padding,
tiled_conv_type);
}
// This function is for a convolution with kernel_dim=1, stride==2, padding=0, and no pooling
static void tiled_conv_downsample(
int batch_size, int in_row_dim, int in_col_dim, int in_channels,
int out_channels, int out_row_dim, int out_col_dim,
int in_stride, int weight_stride, int out_stride,
const elem_t * input,
const elem_t * weights,
const acc_t * bias,
elem_t * output,
int act, acc_scale_t scale,
enum tiled_matmul_type_t tiled_conv_type) {
// Rectangular dimensions for this function are currently not supported
if (in_row_dim != in_col_dim || out_row_dim != out_col_dim) {
printf("Rectangular convolutions for tiled_conv_downsample are currently not supported.\n");
exit(1);
}
const int in_dim = in_row_dim;
const int out_dim = out_row_dim;
const int stride = 2;
for (int b = 0; b < batch_size; b++) {
for (int irow = 0; irow < in_row_dim; irow += stride) {
const int orow = irow / stride;
const int I = in_col_dim / stride; // number of columns in row
const int J = out_channels;
const int K = in_channels;
const elem_t * A = input + (b * in_dim + irow) * in_dim * in_stride;
const elem_t * B = weights;
const acc_t * D = bias;
elem_t * C = output + (b * out_dim + orow) * out_dim * out_stride;
const int A_stride = in_stride * 2;
const int B_stride = weight_stride;
const int D_stride = out_stride;
const int C_stride = out_stride;
tiled_matmul_auto(I, J, K, A, B, (void*)D, (void*)C,
A_stride, B_stride, D_stride, C_stride,
MVIN_SCALE_IDENTITY, MVIN_SCALE_IDENTITY,
MVIN_SCALE_IDENTITY, act, scale, 0,
true, false, false, false, false, 0, tiled_conv_type);
}
}
}
//for mobilenet's depthwise convs
static void tiled_conv_dw_auto(
int batch_size, int in_row_dim, int in_col_dim,
int channels, int out_row_dim, int out_col_dim,
int stride, int padding, int kernel_dim,
elem_t * input,
elem_t * weights,
acc_t * bias,
elem_t * output,
int act, acc_scale_t scale,
int pool_size, int pool_stride, int pool_padding,
enum tiled_matmul_type_t tiled_conv_type) {
const bool no_pool = pool_stride == 0;
if (no_pool) {
pool_size = 1;
pool_stride = 1;
pool_padding = 0;
}
const int pool_out_row_dim = (out_row_dim + 2 * pool_padding - pool_size) / pool_stride + 1;
const int pool_out_col_dim = (out_col_dim + 2 * pool_padding - pool_size) / pool_stride + 1;
// Tile convolution params
// int args[] = {batch_size, porows, pocols, pochs, krows, kcols, kchs};
int args[] = {batch_size, pool_out_row_dim, pool_out_col_dim, 1, kernel_dim, kernel_dim, 1};
const int max_args[] = {batch_size, pool_out_row_dim, pool_out_col_dim, 1, kernel_dim, kernel_dim, 1};
const int orows_idx = 1;
const int ocols_idx = 2;
const int out_channels_idx = 3;
// We divide by 2 for the sake of double-buffering
const int max_spad_rows = (BANK_NUM*BANK_ROWS / 2);
const int max_acc_rows = (ACC_ROWS / 2);
int spad_rows = tiled_conv_total_spad_rows(false,
stride, 1, 1, false, false, false,
args[0], args[1], args[2], args[3], args[4], args[5], args[6], pool_size, pool_stride);
int acc_rows = tiled_conv_total_spad_rows(true,
stride, 1, 1, false, false, false,
args[0], args[1], args[2], args[3], args[4], args[5], args[6], pool_size, pool_stride);
while (spad_rows > max_spad_rows || acc_rows > max_acc_rows) {
int max_val = -1;
int max_idx = -1;
for (size_t i = 0; i < sizeof(args)/sizeof(args[0]); i++) {
// We avoid reducing ocols when possible to keep the spatial array fully utilized
if (!(i == ocols_idx && args[i] <= DIM && args[orows_idx] > 1)
&& args[i] > max_val) {
max_val = args[i];
max_idx = i;
}
}
if (max_idx == out_channels_idx) {
// For input and output channels, there's no point in subtracting by just one
if (args[max_idx] % DIM != 0) {
args[max_idx] = (args[max_idx] / DIM) * DIM;
} else {
args[max_idx] -= DIM;
}
args[max_idx] = args[max_idx] == 0 ? 1 : args[max_idx];
} else {
args[max_idx]--;
}
spad_rows = tiled_conv_total_spad_rows(false,
stride, 1, 1, false, false, false,
args[0], args[1], args[2], args[3], args[4], args[5], args[6], pool_size, pool_stride);
acc_rows = tiled_conv_total_spad_rows(true,
stride, 1, 1, false, false, false,
args[0], args[1], args[2], args[3], args[4], args[5], args[6], pool_size, pool_stride);
}
// Check if we can increase ocols
bool not_increased = false;
while (!not_increased) {
not_increased = true;
int args_candidate[] = {args[0], args[1], args[2], args[3], args[4], args[5], args[6]};
args_candidate[ocols_idx]++;
if (args_candidate[ocols_idx] > max_args[ocols_idx])
continue;
spad_rows = tiled_conv_total_spad_rows(false,
stride, 1, 1, false, false, false,
args_candidate[0], args_candidate[1], args_candidate[2], args_candidate[3], args_candidate[4], args_candidate[5], args_candidate[6], pool_size, pool_stride);
acc_rows = tiled_conv_total_spad_rows(true,
stride, 1, 1, false, false, false,
args_candidate[0], args_candidate[1], args_candidate[2], args_candidate[3], args_candidate[4], args_candidate[5], args_candidate[6], pool_size, pool_stride);
if (spad_rows <= max_spad_rows && acc_rows <= max_acc_rows) {
args[ocols_idx] = args_candidate[ocols_idx];
not_increased = false;
}
}
// Check if there are any parameters that we can currently still increase
bool nothing_increased = false;
while (!nothing_increased) {
nothing_increased = true;
for (size_t i = 0; i < sizeof(args)/sizeof(args[0]); i++) {
int args_candidate[] = {args[0], args[1], args[2], args[3], args[4], args[5], args[6]};
args_candidate[i]++;
if (args_candidate[i] > max_args[i])
continue;
spad_rows = tiled_conv_total_spad_rows(false,
stride, 1, 1, false, false, false,
args_candidate[0], args_candidate[1], args_candidate[2], args_candidate[3], args_candidate[4], args_candidate[5], args_candidate[6], pool_size, pool_stride);
acc_rows = tiled_conv_total_spad_rows(true,
stride, 1, 1, false, false, false,
args_candidate[0], args_candidate[1], args_candidate[2], args_candidate[3], args_candidate[4], args_candidate[5], args_candidate[6], pool_size, pool_stride);
if (spad_rows <= max_spad_rows && acc_rows <= max_acc_rows) {
args[i] = args_candidate[i];
nothing_increased = false;
}
}
}
const int batches = args[0];
const int orows = args[1];
const int ocols = args[2];
const int ochs = 1; // args[3];
const int krows = args[4];
const int kcols = args[5];
const int kchs = 1; // args[6];
/*
spad_rows = tiled_conv_total_spad_rows(false,
stride, 1, 1, false, false, false,
args[0], args[1], args[2], args[3], args[4], args[5], args[6], pool_size, pool_stride);
acc_rows = tiled_conv_total_spad_rows(true,
stride, 1, 1, false, false, false,
args[0], args[1], args[2], args[3], args[4], args[5], args[6], pool_size, pool_stride);
printf("batches = %d\n", batches);
printf("orows = %d\n", orows);
printf("ocols = %d\n", ocols);
printf("ochs = %d\n", ochs);
printf("krows = %d\n", krows);
printf("kcols = %d\n", kcols);
printf("kchs = %d\n\n", kchs);
printf("total spad_rows reserved: %d\n", spad_rows);
printf("total acc_rows reserved: %d\n\n", acc_rows);
printf("scratchpad row utilization: %d%%\n", (spad_rows*100) / max_spad_rows);
printf("accumulator row utilization: %d%%\n\n", (acc_rows*100) / max_acc_rows);
printf("inner matmul size: i=%d, j=%d, k=%d\n\n", ocols, ochs, kchs);
*/
tiled_conv_dw(
batch_size, in_row_dim, in_col_dim,
channels, out_row_dim, out_col_dim,
stride, padding, kernel_dim,
batches,
orows, ocols,
krows, kcols,
input,
weights,
bias,
output,
act, scale,
pool_size, no_pool ? 0 : pool_stride, pool_padding,
tiled_conv_type);
}
static void resadd_cpu(const size_t I, const size_t J,
const size_t stride,
const scale_t A_scale,
const scale_t B_scale,
const acc_scale_t C_scale,
const elem_t * A,
const elem_t * B,
elem_t * C,
bool relu) {
const int minimum = relu ? 0 : elem_t_min;
for (size_t i = 0; i < I; i++) {
for (size_t j = 0; j < J; j++) {
const elem_t * a = A + i * stride + j;
const elem_t * b = B + i * stride + j;
elem_t * c = C + i * stride + j;
acc_t result = MVIN_SCALE(*a, A_scale) + MVIN_SCALE(*b, B_scale);
result = ACC_SCALE(result, C_scale);
result = result > elem_t_max ? elem_t_max :
(result < minimum ? minimum : result);
*c = result;
}
}
}
static void sp_tiled_resadd(const size_t I, const size_t J,
const scale_t A_scale,
const scale_t B_scale,
const elem_t * A, const elem_t * B, elem_t * C,
size_t A_row_stride, size_t B_row_stride, size_t C_row_stride,
bool relu) {
int pad_I = ((I%DIM) == 0) ? 0 : DIM - (I % DIM);
int pad_J = ((J%DIM) == 0) ? 0 : DIM - (J % DIM);
int tile_I = (I%DIM == 0) ? (int)(I/DIM) : (int)(I/DIM) + 1;
int tile_J = (J%DIM == 0) ? (int)(J/DIM) : (int)(J/DIM) + 1;
//printf("pad I: %d, pad_J: %d, tile_I: %d, tile_J: %d\n", pad_I, pad_J, tile_I, tile_J);
gemmini_loop_ws(tile_I, tile_J, 0, pad_I, pad_J, 0, A, B, NULL, C, A_row_stride, B_row_stride, 0, C_row_stride, false, false, false, false, false, relu, 0, 0, true);
/*
// Use the new mvin2 command to overlap mvin A, mvin B, and mvout C
size_t blocks = (J/DIM + (J % DIM != 0));
if (blocks > MAX_BLOCK_LEN) blocks = MAX_BLOCK_LEN;
const uint32_t D_sp_addr_start = 1 << (ADDR_LEN-1);
const uint32_t C_sp_addr_start = 3 << (ADDR_LEN-2);
const size_t rounded_up_J = (J / DIM + (J % DIM != 0)) * DIM;
// Mvin A
// printf("Mving A\n");
for (size_t i = 0; i < I; i += DIM) {
for (size_t j = 0; j < J; j += blocks * DIM) {
const size_t cols = j + blocks*DIM <= J ? blocks*DIM : J-j;
const size_t rows = i + DIM <= I ? DIM : I-i;
const elem_t * const A_dram_addr = A + i * A_row_stride + j;
const uint32_t A_sp_addr = D_sp_addr_start + i * (rounded_up_J/DIM) + j;
gemmini_extended_mvin(A_dram_addr, A_sp_addr, cols, rows);
}
}
// Mvin B
printf("Mving B\n");
for (size_t i = 0; i < I; i += DIM) {
for (size_t j = 0; j < J; j += blocks * DIM) {
const size_t cols = j + blocks*DIM <= J ? blocks*DIM : J-j;
const size_t rows = i + DIM <= I ? DIM : I-i;
const elem_t * const B_dram_addr = B + i * B_row_stride + j;
const uint32_t B_sp_addr = C_sp_addr_start + i * (rounded_up_J/DIM) + j;
gemmini_extended_mvin2(B_dram_addr, B_sp_addr, cols, rows);
}
}
// Mvout C from accumulator
// printf("Mvout C from accumulator\n");
for (size_t i = 0; i < I; i += DIM) {
for (size_t j = 0; j < J; j += blocks * DIM) {
const size_t cols = j + blocks*DIM <= J ? blocks*DIM : J-j;
const size_t rows = i + DIM <= I ? DIM : I-i;
elem_t * const C_dram_addr = C + i * C_row_stride + j;
const uint32_t C_sp_addr = D_sp_addr_start + i * (rounded_up_J/DIM) + j;
gemmini_extended_mvout(C_dram_addr, C_sp_addr, cols, rows);
}
}
*/
}
// Compute MVIN_SCALE(A, A_scale) + MVIN_SCALE(B, B_scale) = C
static void tiled_resadd(const size_t I, const size_t J,
const size_t stride,
const size_t tile_I, const size_t tile_J,
const scale_t A_scale,
const scale_t B_scale,
const acc_scale_t C_scale,
const elem_t * A,
const elem_t * B,
elem_t * C,
bool relu,
enum tiled_matmul_type_t matadd_type) {
gemmini_extended_config_st(stride * sizeof(elem_t), relu ? RELU : NO_ACTIVATION, C_scale);
gemmini_config_ex(WS, 0, 0);
gemmini_extended4_config_ld(stride * sizeof(elem_t), A_scale, true, DIM, 0);
gemmini_extended4_config_ld(stride * sizeof(elem_t), B_scale, true, DIM, 1);
for (size_t i = 0; i < I; i += tile_I) {
for (size_t j = 0; j < J; j += tile_J) {
const size_t I_tile = i + tile_I <= I ? tile_I : I - i;
const size_t J_tile = j + tile_J <= J ? tile_J : J - j;
const elem_t * a = A + i * stride + j;
const elem_t * b = B + i * stride + j;
elem_t * c = C + i * stride + j;
sp_tiled_resadd(I_tile, J_tile,
A_scale, B_scale, a, b, c,
stride, stride, stride,
relu);
}
}
gemmini_fence();
}
// Compute (A >> A_shift) + B = C
// specify stride
static void tiled_resadd_stride_auto(const size_t I, const size_t J,
const scale_t A_scale,
const scale_t B_scale,
const acc_scale_t C_scale,
const size_t stride,
const elem_t * A,
const elem_t * B,
elem_t * C,
bool relu,
enum tiled_matmul_type_t matadd_type) {
if (matadd_type == CPU) {
resadd_cpu(I, J, stride,
A_scale, B_scale, C_scale, A, B, C,
relu);
return;
}
size_t tile_I = I, tile_J = J;
// size_t total_spad_rows = 2 * (tile_I / DIM + (tile_I % DIM != 0))*DIM * (tile_J / DIM + (tile_J % DIM != 0));
size_t total_acc_rows = (tile_I / DIM + (tile_I % DIM != 0))*DIM * (tile_J / DIM + (tile_J % DIM != 0));
// TODO this is a very inefficient way of doing this...
while (total_acc_rows > ACC_ROWS / 2) {
//if(tile_J > MAX_BLOCK_LEN * DIM)
// tile_J = MAX_BLOCK_LEN * DIM;
//else
if (tile_I >= tile_J || tile_J <= DIM)
tile_I /= 2;
else
tile_J -= DIM;
total_acc_rows = (tile_I / DIM + (tile_I % DIM != 0))*DIM * (tile_J / DIM + (tile_J % DIM != 0));
}
// printf("tile_I: %llu\n", tile_I);
// printf("tile_J: %llu\n", tile_J);
if (matadd_type == WS) {
tiled_resadd(I, J, stride, tile_I, tile_J,
A_scale, B_scale, C_scale, A, B, C,
relu, matadd_type);
}
else {
printf("Unsupported type\n");
exit(1);
}
}
static void tiled_resadd_auto(const size_t I, const size_t J,
const scale_t A_scale,
const scale_t B_scale,
const acc_scale_t C_scale,
const elem_t * A,
const elem_t * B,
elem_t * C,
bool relu,
enum tiled_matmul_type_t matadd_type) {
tiled_resadd_stride_auto(I, J,
A_scale, B_scale, C_scale,
J,
A, B, C,
relu, matadd_type);
}
static void global_average_cpu(const elem_t * input, elem_t * output,
int batches, int channels, int dim) {
const int count = dim * dim;
for (int batch = 0; batch < batches; batch++) {
for (int channel = 0; channel < channels; channel++) {
acc_t sum = 0;
for (int row = 0; row < dim; row++) {
for (int col = 0; col < dim; col++) {
size_t pixel = batch * dim * dim + row * dim + col;
sum += input[pixel * channels + channel];
}
}
#ifdef ELEM_T_IS_FLOAT
output[batch * channels + channel] = sum / count;
#else
output[batch * channels + channel] = (sum + count/2) / count;
#endif
}
}
}
static void sp_tiled_global_average(const elem_t * input, elem_t * output,
int batches, int channels, int dim, int channel_tile_size) {
const uint32_t C_acc_addr_start = ((uint32_t)1 << 31);
size_t blocks = channel_tile_size/DIM + (channel_tile_size % DIM != 0);
if (blocks > MAX_BLOCK_LEN) blocks = MAX_BLOCK_LEN;
for (int channel = 0; channel < channel_tile_size; channel += blocks*DIM) {
for (int row = 0; row < dim; row++) {
for (int col = 0; col < dim; col++) {
const elem_t * in = input +
(row * dim + col) * channels +
channel;
const uint32_t acc_addr_start = C_acc_addr_start |
((row != 0 || col != 0) << 30);
const uint32_t acc_addr = acc_addr_start + channel / DIM;
const size_t cols = channel + blocks*DIM <= channel_tile_size ?
blocks*DIM : channel_tile_size - channel;
const size_t rows = 1;
gemmini_extended_mvin(in, acc_addr, cols, rows);
}
}
}
for (int channel = 0; channel < channel_tile_size; channel += DIM) {
elem_t * out = output + channel;
const uint32_t acc_addr = C_acc_addr_start + channel / DIM;
const size_t cols = channel + DIM <= channel_tile_size ?
DIM : channel_tile_size - channel;
const size_t rows = 1; // TODO we should move out more than just one row here
gemmini_extended_mvout(out, acc_addr, cols, rows);
}
}
static void tiled_global_average(const elem_t * input, elem_t * output,
int batches, int channels, int dim,
int channel_tile_size) {
gemmini_extended4_config_ld(DIM*sizeof(elem_t), MVIN_SCALE_IDENTITY, true, 1, 0);
gemmini_config_ex(0, NO_ACTIVATION, 0);
gemmini_extended_config_st(0, NO_ACTIVATION, 1.0 / (dim*dim));
for (int batch = 0; batch < batches; batch++) {
for (int channel = 0; channel < channels; channel += channel_tile_size) {
const int tile_size = channel + channel_tile_size <= channels ?
channel_tile_size : channels - channel;
sp_tiled_global_average(input + batch * dim * dim * channels + channel,
output + batch * channels + channel,
batches, channels, dim, tile_size);
}
}
}
static void tiled_global_average_auto(const elem_t * input, elem_t * output,
int batches, int channels, int dim,
enum tiled_matmul_type_t type) {
if (type == CPU) {
return global_average_cpu(input, output, batches, channels, dim);
}
int channel_tile_size = channels;
int acc_rows = channel_tile_size / DIM + (channel_tile_size % DIM != 0);
while (acc_rows > ACC_ROWS) {
channel_tile_size--;
acc_rows = channel_tile_size / DIM + (channel_tile_size % DIM != 0);
}
tiled_global_average(input, output, batches, channels, dim,
channel_tile_size);
}
static void sp_tiled_norm(const size_t I, const size_t J,
const acc_t * in, elem_t * out,
size_t A_row_stride, size_t C_row_stride,
int act) {
#ifdef HAS_NORMALIZATIONS
size_t A_blocks = (J/DIM + (J % DIM != 0));
if (A_blocks > MAX_BLOCK_LEN_ACC) A_blocks = MAX_BLOCK_LEN_ACC;
size_t C_blocks = (J/DIM + (J % DIM != 0));
if (C_blocks > MAX_BLOCK_LEN) C_blocks = MAX_BLOCK_LEN;
const uint32_t D_sp_addr_start = 1 << (ADDR_LEN-1);
const uint32_t C_sp_addr_start = 3 << (ADDR_LEN-2);
const size_t rounded_up_J = (J / DIM + (J % DIM != 0)) * DIM;
for (size_t i = 0; i < I; i += DIM) {
// Mvin
for (size_t j = 0; j < J; j += A_blocks * DIM) {
const size_t cols = j + A_blocks*DIM <= J ? A_blocks*DIM : J-j;
const size_t rows = i + DIM <= I ? DIM : I-i;
const acc_t * const A_dram_addr = in + i * A_row_stride + j;
const uint32_t A_sp_addr = D_sp_addr_start + i * (rounded_up_J/DIM) + j;
gemmini_extended_mvin(A_dram_addr, A_sp_addr, cols, rows);
}
// Mvout
if (act == LAYERNORM) {
uint32_t norm_cmds[][2] = {{1,2},{3,4},{0,0}};
const int norm_cmds_size = sizeof(norm_cmds) / sizeof(norm_cmds[0]);
const size_t rows = I - i < DIM ? I - i : DIM;
for (size_t row = 0; row < rows; row += NORM_STAT_IDS) {
const size_t stat_ids = rows - row > NORM_STAT_IDS ?
NORM_STAT_IDS : rows - row;
for (int cmd = 0; cmd < norm_cmds_size; cmd++) {
for (size_t stat_id = 0; stat_id < stat_ids; stat_id++) {
gemmini_config_norm(0, 0, 0, 0, stat_id, 0, 0);
const size_t r = row + stat_id;
for (size_t jj = 0; jj < J; jj += C_blocks * DIM) {
uint32_t norm_C_sp_addr = C_sp_addr_start + i * (rounded_up_J/DIM) + jj + r;
if (jj + C_blocks*DIM >= J) {
norm_C_sp_addr |= (norm_cmds[cmd][1] << 26); // Final mean/inv-std-dev calculation
} else {
norm_C_sp_addr |= (norm_cmds[cmd][0] << 26); // Accumulate sum/variance
}
void * const C_dram_addr = (int8_t*)out +
(i*C_row_stride + jj) * sizeof(elem_t) +
r * C_row_stride * sizeof(elem_t);
const size_t cols = J - jj < C_blocks * DIM ? J - jj : C_blocks * DIM;
gemmini_extended_mvout(C_dram_addr, norm_C_sp_addr, cols, 1);
}
}
}
}
} else if (act == SOFTMAX) {
uint32_t norm_cmds[][2] = {{5,5},{6,7},{0,0}};
const int norm_cmds_size = sizeof(norm_cmds) / sizeof(norm_cmds[0]);
const size_t rows = I - i < DIM ? I - i : DIM;
for (size_t row = 0; row < rows; row += NORM_STAT_IDS) {
const size_t stat_ids = rows - row > NORM_STAT_IDS ?
NORM_STAT_IDS : rows - row;
for (int cmd = 0; cmd < norm_cmds_size; cmd++) {
for (size_t stat_id = 0; stat_id < stat_ids; stat_id++) {
// set stat id only
gemmini_config_norm(0, 0, 1, 0, stat_id, 0, 0);
const size_t r = row + stat_id;
for (size_t jj = 0; jj < J; jj += C_blocks * DIM) {
uint32_t norm_C_sp_addr = C_sp_addr_start + i * (rounded_up_J/DIM) + jj + r;
if (jj + C_blocks*DIM >= J) {
norm_C_sp_addr |= (norm_cmds[cmd][1] << 26); // Final mean/inv-std-dev calculation
} else {
norm_C_sp_addr |= (norm_cmds[cmd][0] << 26); // Accumulate sum/variance
}
void * const C_dram_addr = (int8_t*)out +
(i*C_row_stride + jj) * sizeof(elem_t) +
r * C_row_stride * sizeof(elem_t);
const size_t cols = J - jj < C_blocks * DIM ? J - jj : C_blocks * DIM;
gemmini_extended_mvout(C_dram_addr, norm_C_sp_addr, cols, 1);
}
}
}
}
}
}
#else
printf("Normalizations not supported in this Gemmini config\n");
exit(1);
#endif
}
static void tiled_norm(const size_t I, const size_t J,
const size_t tile_I, const size_t tile_J,
const acc_t * in,
elem_t * out,
const acc_scale_t C_scale,
int act,
enum tiled_matmul_type_t norm_type) {
gemmini_extended_config_st(J * sizeof(elem_t), act & 3, C_scale);
gemmini_config_ex(WS, 0, 0); // TODO is this actually required?
gemmini_extended4_config_ld(J * sizeof(acc_t), MVIN_SCALE_IDENTITY, false, DIM, 0);
gemmini_extended4_config_ld(J * sizeof(acc_t), MVIN_SCALE_IDENTITY, false, DIM, 1);
if (act == SOFTMAX) {
const scale_t a = 0.3585;
const scale_t b = 1.353;
const scale_t c = 0.344;
// TODO let bert-scale be set by the programmer
acc_scale_t bert_scale = 0.05;
const acc_t qln2 = (int) (0.693147 / bert_scale);
const acc_t qln2_inv = 65536 / qln2;
const acc_t qb = b / bert_scale;
const acc_t qc = c / (a*bert_scale*bert_scale);
gemmini_config_norm(qln2, 0, 0, 1, 0, qb, qc);
gemmini_config_norm(qln2_inv, 1, 0, 1, 0, qb, qc);
}
for (size_t i = 0; i < I; i += tile_I) {
for (size_t j = 0; j < J; j += tile_J) {
const size_t I_tile = i + tile_I <= I ? tile_I : I - i;
const size_t J_tile = j + tile_J <= J ? tile_J : J - j;
const acc_t * in_ = in + i * J + j;
elem_t * out_ = out + i * J + j;
sp_tiled_norm(I_tile, J_tile,
in_, out_,
J, J,
act);
}
}
gemmini_fence();
}
static void tiled_norm_auto(const size_t I, const size_t J,
const acc_t * in,
elem_t * out,
const acc_scale_t C_scale,
int act,
enum tiled_matmul_type_t norm_type) {
size_t tile_I = I, tile_J = J;
size_t total_acc_rows = (tile_I / DIM + (tile_I % DIM != 0))*DIM * (tile_J / DIM + (tile_J % DIM != 0));
while (total_acc_rows > ACC_ROWS) {
if (tile_I > 1) {
tile_I--;
} else {
// TODO we should be able to tile over J as well to avoid this issue
printf("Can't fit pre-normalized tensor into accumulator");
exit(1);
}
total_acc_rows = (tile_I / DIM + (tile_I % DIM != 0))*DIM * (tile_J / DIM + (tile_J % DIM != 0));
}
if (norm_type) {
tiled_norm(I, J, tile_I, tile_J,
in, out,
C_scale, act, norm_type);
} else {
printf("Unsupported type\n");
exit(1);
}
}
#undef abs
#endif // SRC_MAIN_C_GEMMINI_H