Files
kernels/lib/gemmini/include/gemmini_testutils.h
2025-01-29 21:22:41 -08:00

286 lines
8.1 KiB
C

// See LICENSE for license details.
#ifndef SRC_MAIN_C_GEMMINI_TESTUTILS_H
#define SRC_MAIN_C_GEMMINI_TESTUTILS_H
#undef abs
#include <stdint.h>
#include <stdlib.h>
#include <stdio.h>
#include <math.h>
#include <limits.h>
#include <stdbool.h>
#include "include/gemmini_params.h"
#include "include/gemmini.h"
#ifdef BAREMETAL
#undef assert
#define assert(expr) \
if (!(expr)) { \
printf("Failed assertion: " #expr "\n " __FILE__ ":%u\n", __LINE__); \
exit(1); \
}
#endif
// #define GEMMINI_ASSERTIONS
// Matmul utility functions
static void matmul(elem_t A[DIM][DIM], elem_t B[DIM][DIM], elem_t D[DIM][DIM], full_t C_full[DIM][DIM]) {
for (size_t r = 0; r < DIM; r++)
for (size_t c = 0; c < DIM; c++) {
C_full[r][c] = D[r][c];
for (size_t k = 0; k < DIM; k++)
C_full[r][c] += A[r][k]*B[k][c];
}
}
static void matmul_short(elem_t A[DIM][DIM], elem_t B[DIM][DIM], elem_t D[DIM][DIM], elem_t C[DIM][DIM]) {
for (size_t r = 0; r < DIM; r++)
for (size_t c = 0; c < DIM; c++) {
C[r][c] = D[r][c];
for (size_t k = 0; k < DIM; k++)
C[r][c] += A[r][k]*B[k][c];
}
}
static void matmul_full(elem_t A[DIM][DIM], elem_t B[DIM][DIM], full_t D[DIM][DIM], full_t C_full[DIM][DIM]) {
// Identical to the other matmul function, but with a 64-bit bias
for (size_t r = 0; r < DIM; r++)
for (size_t c = 0; c < DIM; c++) {
C_full[r][c] = D[r][c];
for (size_t k = 0; k < DIM; k++)
C_full[r][c] += A[r][k]*B[k][c];
}
}
static void matmul_A_transposed(elem_t A[DIM][DIM], elem_t B[DIM][DIM], elem_t D[DIM][DIM], full_t C_full[DIM][DIM]) {
for (size_t r = 0; r < DIM; r++)
for (size_t c = 0; c < DIM; c++) {
C_full[r][c] = D[r][c];
for (size_t k = 0; k < DIM; k++)
C_full[r][c] += A[k][r]*B[k][c];
}
}
static void matmul_short_A_transposed(elem_t A[DIM][DIM], elem_t B[DIM][DIM], elem_t D[DIM][DIM], elem_t C[DIM][DIM]) {
for (size_t r = 0; r < DIM; r++)
for (size_t c = 0; c < DIM; c++) {
C[r][c] = D[r][c];
for (size_t k = 0; k < DIM; k++)
C[r][c] += A[k][r]*B[k][c];
}
}
static void matmul_full_A_transposed(elem_t A[DIM][DIM], elem_t B[DIM][DIM], full_t D[DIM][DIM], full_t C_full[DIM][DIM]) {
for (size_t r = 0; r < DIM; r++)
for (size_t c = 0; c < DIM; c++) {
C_full[r][c] = D[r][c];
for (size_t k = 0; k < DIM; k++)
C_full[r][c] += A[k][r]*B[k][c];
}
}
static void matmul_B_transposed(elem_t A[DIM][DIM], elem_t B[DIM][DIM], elem_t D[DIM][DIM], full_t C_full[DIM][DIM]) {
for (size_t r = 0; r < DIM; r++)
for (size_t c = 0; c < DIM; c++) {
C_full[r][c] = D[r][c];
for (size_t k = 0; k < DIM; k++)
C_full[r][c] += A[r][k]*B[c][k];
}
}
static void matmul_short_B_transposed(elem_t A[DIM][DIM], elem_t B[DIM][DIM], elem_t D[DIM][DIM], elem_t C[DIM][DIM]) {
for (size_t r = 0; r < DIM; r++)
for (size_t c = 0; c < DIM; c++) {
C[r][c] = D[r][c];
for (size_t k = 0; k < DIM; k++)
C[r][c] += A[r][k]*B[c][k];
}
}
static void matmul_full_B_transposed(elem_t A[DIM][DIM], elem_t B[DIM][DIM], full_t D[DIM][DIM], full_t C_full[DIM][DIM]) {
for (size_t r = 0; r < DIM; r++)
for (size_t c = 0; c < DIM; c++) {
C_full[r][c] = D[r][c];
for (size_t k = 0; k < DIM; k++)
C_full[r][c] += A[r][k]*B[c][k];
}
}
static void matmul_AB_transposed(elem_t A[DIM][DIM], elem_t B[DIM][DIM], elem_t D[DIM][DIM], full_t C_full[DIM][DIM]) {
for (size_t r = 0; r < DIM; r++)
for (size_t c = 0; c < DIM; c++) {
C_full[r][c] = D[r][c];
for (size_t k = 0; k < DIM; k++)
C_full[r][c] += A[k][r]*B[c][k];
}
}
static void matmul_short_AB_transposed(elem_t A[DIM][DIM], elem_t B[DIM][DIM], elem_t D[DIM][DIM], elem_t C[DIM][DIM]) {
for (size_t r = 0; r < DIM; r++)
for (size_t c = 0; c < DIM; c++) {
C[r][c] = D[r][c];
for (size_t k = 0; k < DIM; k++)
C[r][c] += A[k][r]*B[c][k];
}
}
static void matmul_full_AB_transposed(elem_t A[DIM][DIM], elem_t B[DIM][DIM], full_t D[DIM][DIM], full_t C_full[DIM][DIM]) {
for (size_t r = 0; r < DIM; r++)
for (size_t c = 0; c < DIM; c++) {
C_full[r][c] = D[r][c];
for (size_t k = 0; k < DIM; k++)
C_full[r][c] += A[k][r]*B[c][k];
}
}
static void matadd(full_t sum[DIM][DIM], full_t m1[DIM][DIM], full_t m2[DIM][DIM]) {
for (size_t r = 0; r < DIM; r++)
for (size_t c = 0; c < DIM; c++)
sum[r][c] = m1[r][c] + m2[r][c];
}
// THIS IS A ROUNDING SHIFT! It also performs a saturating cast
static void matshift(full_t full[DIM][DIM], elem_t out[DIM][DIM], int shift) {
for (size_t r = 0; r < DIM; r++)
for (size_t c = 0; c < DIM; c++) {
// Bitshift and round element
full_t shifted = ROUNDING_RIGHT_SHIFT(full[r][c], shift);
// Saturate and cast element
#ifndef ELEM_T_IS_FLOAT
full_t elem = shifted > elem_t_max ? elem_t_max : (shifted < elem_t_min ? elem_t_min : shifted);
out[r][c] = elem;
#else
out[r][c] = shifted; // TODO should we also saturate when using floats?
#endif
}
}
static void matscale(full_t full[DIM][DIM], elem_t out[DIM][DIM], acc_scale_t scale) {
for (size_t r = 0; r < DIM; r++)
for (size_t c = 0; c < DIM; c++) {
// Bitshift and round element
full_t scaled = ACC_SCALE(full[r][c], scale);
// Saturate and cast element
#ifndef ELEM_T_IS_FLOAT
full_t elem = scaled > elem_t_max ? elem_t_max : (scaled < elem_t_min ? elem_t_min : scaled);
out[r][c] = elem;
#else
out[r][c] = scaled; // TODO should we also saturate when using floats?
#endif
}
}
static void matrelu(elem_t in[DIM][DIM], elem_t out[DIM][DIM]) {
for (size_t r = 0; r < DIM; r++)
for (size_t c = 0; c < DIM; c++)
out[r][c] = in[r][c] > 0 ? in[r][c] : 0;
}
static void transpose(elem_t in[DIM][DIM], elem_t out[DIM][DIM]) {
for (size_t r = 0; r < DIM; r++)
for (size_t c = 0; c < DIM; c++)
out[c][r] = in[r][c];
}
int rand() {
static uint32_t x = 777;
x = x * 1664525 + 1013904223;
return x >> 24;
}
#ifdef ELEM_T_IS_FLOAT
double rand_double() {
double a = (double)(rand() % 128) / (double)(1 + (rand() % 64));
double b = (double)(rand() % 128) / (double)(1 + (rand() % 64));
return a - b;
}
#endif
static void printMatrix(elem_t m[DIM][DIM]) {
for (size_t i = 0; i < DIM; ++i) {
for (size_t j = 0; j < DIM; ++j)
#ifndef ELEM_T_IS_FLOAT
printf("%d ", m[i][j]);
#else
printf("%x ", elem_t_to_elem_t_bits(m[i][j]));
#endif
printf("\n");
}
}
static void printMatrixAcc(acc_t m[DIM][DIM]) {
for (size_t i = 0; i < DIM; ++i) {
for (size_t j = 0; j < DIM; ++j)
#ifndef ELEM_T_IS_FLOAT
printf("%d ", m[i][j]);
#else
printf("%x ", acc_t_to_acc_t_bits(m[i][j]));
#endif
printf("\n");
}
}
static int is_equal(elem_t x[DIM][DIM], elem_t y[DIM][DIM]) {
for (size_t i = 0; i < DIM; ++i)
for (size_t j = 0; j < DIM; ++j) {
#ifndef ELEM_T_IS_FLOAT
if (x[i][j] != y[i][j])
#else
bool isnanx = elem_t_isnan(x[i][j]);
bool isnany = elem_t_isnan(y[i][j]);
if (x[i][j] != y[i][j] && !(isnanx && isnany))
#endif
return 0;
}
return 1;
}
static int is_equal_transposed(elem_t x[DIM][DIM], elem_t y[DIM][DIM]) {
for (size_t i = 0; i < DIM; ++i)
for (size_t j = 0; j < DIM; ++j) {
#ifndef ELEM_T_IS_FLOAT
if (x[i][j] != y[j][i])
#else
bool isnanx = elem_t_isnan(x[i][j]);
bool isnany = elem_t_isnan(y[j][i]);
if (x[i][j] != y[j][i] && !(isnanx && isnany))
#endif
return 0;
}
return 1;
}
// This is a GNU extension known as statment expressions
#define MAT_IS_EQUAL(dim_i, dim_j, x, y) \
({int result = 1; \
for (size_t i = 0; i < dim_i; i++) \
for (size_t j = 0; j < dim_j; ++j) { \
if (x[i][j] != y[i][j]) { \
result = 0; \
break; \
} \
} \
result;})
static uint64_t read_cycles() {
uint64_t cycles;
asm volatile ("rdcycle %0" : "=r" (cycles));
return cycles;
// const uint32_t * mtime = (uint32_t *)(33554432 + 0xbff8);
// const uint32_t * mtime = (uint32_t *)(33554432 + 0xbffc);
// return *mtime;
}
#undef abs
#endif // SRC_MAIN_C_GEMMINI_TESTUTILS_H