seemingly working fp32 implementation

This commit is contained in:
joshua
2024-03-19 17:56:59 -07:00
parent beb3dce46d
commit 978dd3bdfe
9 changed files with 4450 additions and 0 deletions

View File

@@ -0,0 +1,89 @@
DESTDIR ?= .
RTL_DIR = ../../rtl
DPI_DIR = $(abspath ../../dpi)
SIM_DIR = ../../../sim
THIRD_PARTY_DIR = $(abspath ../../../third_party)
CONFIGS +=
PARAMS +=
CXXFLAGS += -std=c++17 -Wall -Wextra -Wfatal-errors -Wno-array-bounds
CXXFLAGS += -fPIC -Wno-maybe-uninitialized
CXXFLAGS += -fcoroutines
CXXFLAGS += -I../../.. -I../../common -I../../../../sim/common
CXXFLAGS += -I/$(THIRD_PARTY_DIR)/softfloat/source/include
CXXFLAGS += -I/$(DPI_DIR)
CXXFLAGS += $(CONFIGS)
LDFLAGS += $(THIRD_PARTY_DIR)/softfloat/build/Linux-x86_64-GCC/softfloat.a
# control RTL debug tracing states
DBG_TRACE_FLAGS += -DDBG_TRACE_CACHE_BANK
DBG_TRACE_FLAGS += -DDBG_TRACE_CACHE_MSHR
DBG_TRACE_FLAGS += -DDBG_TRACE_CACHE_TAG
DBG_TRACE_FLAGS += -DDBG_TRACE_CACHE_DATA
DBG_FLAGS += -DDEBUG_LEVEL=$(DEBUG) -DVCD_OUTPUT $(DBG_TRACE_FLAGS)
RTL_PKGS = $(RTL_DIR)/VX_gpu_pkg.sv
RTL_INCLUDE = -I$(RTL_DIR) -I$(DPI_DIR) -I$(RTL_DIR)/libs -I$(RTL_DIR)/fpu
# SRCS = cachesim.cpp testbench.cpp
SRCS += $(DPI_DIR)/util_dpi.cpp
SRCS += $(DPI_DIR)/float_dpi.cpp
SRCS += $(SIM_DIR)/common/rvfloats.cpp
SRCS += ./main.cpp
RTL_SRCS += $(RTL_DIR)/fpu/VX_tensor_core.sv
RTL_SRCS += $(RTL_DIR)/fpu/VX_tensor_tb.sv
TOP = VX_tensor_tb
VL_FLAGS = --exe
VL_FLAGS += --language 1800-2009 # -Wall -Wpedantic # --assert
VL_FLAGS += -Wno-DECLFILENAME -Wno-REDEFMACRO
VL_FLAGS += --x-initial unique --x-assign unique
VL_FLAGS += -DSIMULATION -DSV_DPI
VL_FLAGS += $(CONFIGS)
VL_FLAGS += $(PARAMS)
VL_FLAGS += $(RTL_INCLUDE)
VL_FLAGS += $(RTL_PKGS)
VL_FLAGS += --cc $(TOP) --top-module $(TOP)
VL_FLAGS += --timing
# Enable Verilator multithreaded simulation
THREADS ?= $(shell python -c 'import multiprocessing as mp; print(mp.cpu_count())')
VL_FLAGS += -j $(THREADS)
#VL_FLAGS += --threads $(THREADS)
# Debugigng
ifdef DEBUG
VL_FLAGS += --trace --trace-structs $(DBG_FLAGS)
CXXFLAGS += -g -O0 $(DBG_FLAGS)
else
VL_FLAGS += -DNDEBUG
CXXFLAGS += -O2 -DNDEBUG
endif
# Enable perf counters
ifdef PERF
VL_FLAGS += -DPERF_ENABLE
CXXFLAGS += -DPERF_ENABLE
endif
PROJECT = tensor
all: $(DESTDIR)/$(PROJECT)
$(DESTDIR)/$(PROJECT): $(SRCS) $(RTL_SRCS)
verilator --build $(VL_FLAGS) $(SRCS) -CFLAGS '$(CXXFLAGS)' -LDFLAGS '$(LDFLAGS)' -o ../$@
run: $(DESTDIR)/$(PROJECT)
$(DESTDIR)/$(PROJECT)
waves: trace.vcd
gtkwave -o trace.vcd
clean:
rm -rf obj_dir $(DESTDIR)/$(PROJECT)

197
hw/unittest/tensor/main.cpp Normal file
View File

@@ -0,0 +1,197 @@
// Copyright © 2019-2023
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "vl_simulator.h"
#include "VVX_tensor_tb.h"
#include <iostream>
#include <half.hpp>
#define MAX_TICKS 20
#ifndef TRACE_START_TIME
#define TRACE_START_TIME 0ull
#endif
#ifndef TRACE_STOP_TIME
#define TRACE_STOP_TIME -1ull
#endif
#define CHECK(x) \
do { \
if (x) \
break; \
std::cout << "FAILED: " << #x << std::endl; \
std::abort(); \
} while (false)
static uint64_t timestamp = 0;
static bool trace_enabled = false;
static uint64_t trace_start_time = TRACE_START_TIME;
static uint64_t trace_stop_time = TRACE_STOP_TIME;
double sc_time_stamp() {
return timestamp;
}
bool sim_trace_enabled() {
if (timestamp >= trace_start_time
&& timestamp < trace_stop_time)
return true;
return trace_enabled;
}
void sim_trace_enable(bool enable) {
trace_enabled = enable;
}
using Device = VVX_tensor_tb;
using half_float::half;
static_assert(sizeof(half) == 2);
uint32_t half2bits(half h) {
uint16_t half_bits;
memcpy(&half_bits, &h, sizeof(half));
return half_bits;
}
uint32_t float2bits(float f) {
uint32_t float_bits;
memcpy(&float_bits, &f, sizeof(f));
return float_bits;
}
float bits2float(uint32_t b) {
float f;
memcpy(&f, &b, sizeof(b));
return f;
}
// A is M * K, B is K * K * M, C is M * M, D is M * M
#define M 4
#define K 2
// row, column
float A_tile[M][K];
float B_tile[K][M];
float C_tile[M][M];
float D_tile[M][M];
void initialize_test_data() {
for (int i = 0; i < M; i += 1) {
for (int j = 0; j < K; j += 1) {
A_tile[i][j] = (float) (i * K + j);
}
}
for (int i = 0; i < K; i += 1) {
for (int j = 0; j < M; j += 1) {
B_tile[i][j] = (float) (j * K + i);
}
}
for (int i = 0; i < M; i += 1) {
for (int j = 0; j < M; j += 1) {
C_tile[i][j] = (float) (i * j);
}
}
}
void write_test_data(vl_simulator<Device>& sim) {
for (int i = 0; i < M; i += 1) {
for (int j = 0; j < K; j += 1) {
int index = (i * K + j);
uint32_t A_bits = float2bits(A_tile[i][j]);
sim->A_tile[index] = A_bits;
}
}
for (int i = 0; i < K; i += 1) {
for (int j = 0; j < M; j += 1) {
int index = (i * M + j);
uint32_t B_bits = float2bits(B_tile[i][j]);
sim->B_tile[index] = B_bits;
}
}
for (int i = 0; i < M; i += 1) {
for (int j = 0; j < M; j += 1) {
int index = (i * M + j);
uint32_t C_bits = float2bits(C_tile[i][j]);
sim->C_tile[index] = C_bits;
}
}
}
void read_result(vl_simulator<Device>& sim) {
for (int i = 0; i < M; i += 1) {
for (int j = 0; j < M; j += 1) {
int index = (i * M + j);
uint32_t D_bits = sim->D_tile[index];
float f = bits2float(D_bits);
D_tile[i][j] = f;
std::cout << f << " ";
}
std::cout << std::endl;
}
}
void expected() {
for (int i = 0; i < M; i += 1) {
for (int j = 0; j < M; j += 1) {
float accum = C_tile[i][j];
for (int k = 0; k < K; k += 1) {
accum += A_tile[i][k] * B_tile[k][j];
}
std::cout << accum << " ";
}
std::cout << std::endl;
}
}
int main(int argc, char **argv) {
// Initialize Verilators variables
Verilated::commandArgs(argc, argv);
vl_simulator<Device> sim;
initialize_test_data();
// run test
timestamp = sim.reset(0);
// advance clock
timestamp = sim.step(timestamp, 10);
sim->valid_in = 1;
write_test_data(sim);
timestamp = sim.step(timestamp, 2);
CHECK(sim->valid_out == 0);
sim->valid_in = 0;
timestamp = sim.step(timestamp, 2);
CHECK(sim->valid_out == 0);
timestamp = sim.step(timestamp, 2);
CHECK(sim->valid_out == 0);
timestamp = sim.step(timestamp, 2);
CHECK(sim->valid_out == 1);
read_result(sim);
timestamp = sim.step(timestamp, 2);
CHECK(sim->valid_out == 0);
expected();
std::cout << "PASSED!" << std::endl;
std::cout << "Simulation time: " << std::dec << timestamp/2 << " cycles" << std::endl;
return 0;
}