seemingly working fp32 implementation
This commit is contained in:
89
hw/unittest/tensor/Makefile
Normal file
89
hw/unittest/tensor/Makefile
Normal file
@@ -0,0 +1,89 @@
|
||||
# Directory that receives the test executable; override on the command line.
DESTDIR ?= .

# Source-tree locations. DPI_DIR and THIRD_PARTY_DIR are turned into absolute
# paths by $(abspath ...); RTL_DIR and SIM_DIR stay relative to this directory.
RTL_DIR         = ../../rtl
DPI_DIR         = $(abspath ../../dpi)
SIM_DIR         = ../../../sim
THIRD_PARTY_DIR = $(abspath ../../../third_party)

# Extra configuration / parameter switches; empty unless the caller adds some.
CONFIGS +=
PARAMS  +=

# C++ flags, later handed to Verilator through -CFLAGS.
# NOTE(review): the relative -I paths look like they are meant to be resolved
# from Verilator's obj_dir rather than from this directory -- confirm.
CXXFLAGS += -std=c++17 -Wall -Wextra -Wfatal-errors -Wno-array-bounds
CXXFLAGS += -fPIC -Wno-maybe-uninitialized
CXXFLAGS += -fcoroutines
CXXFLAGS += -I../../.. -I../../common -I../../../../sim/common
# THIRD_PARTY_DIR and DPI_DIR are already absolute, so they are used as-is
# (no extra leading '/'), matching how RTL_INCLUDE references DPI_DIR.
CXXFLAGS += -I$(THIRD_PARTY_DIR)/softfloat/source/include
CXXFLAGS += -I$(DPI_DIR)
CXXFLAGS += $(CONFIGS)

# Prebuilt SoftFloat archive linked into the test executable.
LDFLAGS += $(THIRD_PARTY_DIR)/softfloat/build/Linux-x86_64-GCC/softfloat.a
|
||||
|
||||
# RTL debug-trace switches. They only reach the build through DBG_FLAGS,
# which the DEBUG section of this Makefile appends to the tool flags.
DBG_TRACE_FLAGS += \
    -DDBG_TRACE_CACHE_BANK \
    -DDBG_TRACE_CACHE_MSHR \
    -DDBG_TRACE_CACHE_TAG \
    -DDBG_TRACE_CACHE_DATA

# Debug defines: the requested DEBUG level, VCD output, and the trace switches.
DBG_FLAGS += -DDEBUG_LEVEL=$(DEBUG) -DVCD_OUTPUT $(DBG_TRACE_FLAGS)
|
||||
|
||||
# SystemVerilog package passed to Verilator together with the design.
RTL_PKGS = $(RTL_DIR)/VX_gpu_pkg.sv

# Include search path for the RTL and DPI sources.
RTL_INCLUDE = \
    -I$(RTL_DIR) \
    -I$(DPI_DIR) \
    -I$(RTL_DIR)/libs \
    -I$(RTL_DIR)/fpu

# C++ sources compiled into the test executable: DPI helpers, the shared
# float routines from the simulator tree, and this test's driver.
SRCS += \
    $(DPI_DIR)/util_dpi.cpp \
    $(DPI_DIR)/float_dpi.cpp \
    $(SIM_DIR)/common/rvfloats.cpp \
    ./main.cpp

# RTL under test and its testbench wrapper; listed as prerequisites of the
# executable so that editing them triggers a rebuild.
RTL_SRCS += \
    $(RTL_DIR)/fpu/VX_tensor_core.sv \
    $(RTL_DIR)/fpu/VX_tensor_tb.sv

# Top-level module handed to Verilator.
TOP = VX_tensor_tb
|
||||
|
||||
# Verilator command line: generate a C++ model of $(TOP) and link it into an
# executable. Stricter checking (-Wall -Wpedantic, --assert) is currently
# left off.
VL_FLAGS = --exe \
           --language 1800-2009 \
           -Wno-DECLFILENAME -Wno-REDEFMACRO \
           --x-initial unique --x-assign unique \
           -DSIMULATION -DSV_DPI \
           $(CONFIGS) \
           $(PARAMS) \
           $(RTL_INCLUDE) \
           $(RTL_PKGS) \
           --cc $(TOP) --top-module $(TOP) \
           --timing

# Job count given to Verilator's -j option (parallel build jobs); defaults to
# the CPU count reported by Python's multiprocessing module.
THREADS ?= $(shell python -c 'import multiprocessing as mp; print(mp.cpu_count())')
VL_FLAGS += -j $(THREADS)
# A multithreaded simulation model (--threads) is not enabled.
#VL_FLAGS += --threads $(THREADS)
|
||||
|
||||
# Debugging: a release build (DEBUG unset) defines NDEBUG and optimises;
# a DEBUG build enables waveform tracing, debug info and the debug defines.
ifndef DEBUG
    VL_FLAGS += -DNDEBUG
    CXXFLAGS += -O2 -DNDEBUG
else
    VL_FLAGS += --trace --trace-structs $(DBG_FLAGS)
    CXXFLAGS += -g -O0 $(DBG_FLAGS)
endif

# Performance counters: PERF defines PERF_ENABLE for both the RTL and the C++.
ifdef PERF
    VL_FLAGS += -DPERF_ENABLE
    CXXFLAGS += -DPERF_ENABLE
endif
|
||||
|
||||
# Name of the test executable.
PROJECT = tensor

# These targets are commands, not files; declaring them phony keeps a stray
# file called "all", "run", "waves" or "clean" from masking them.
.PHONY: all run waves clean

all: $(DESTDIR)/$(PROJECT)

# Verilate the RTL and build the executable in one step.
# NOTE(review): the "../" prefix on -o assumes the output path is interpreted
# relative to Verilator's obj_dir and that DESTDIR is a relative path -- confirm
# before using an absolute DESTDIR.
$(DESTDIR)/$(PROJECT): $(SRCS) $(RTL_SRCS)
	verilator --build $(VL_FLAGS) $(SRCS) -CFLAGS '$(CXXFLAGS)' -LDFLAGS '$(LDFLAGS)' -o ../$@

# Build if necessary, then run the test.
run: $(DESTDIR)/$(PROJECT)
	$(DESTDIR)/$(PROJECT)

# Open the recorded waveform; trace.vcd must already exist.
waves: trace.vcd
	gtkwave -o trace.vcd

# Remove the Verilator build directory and the test executable.
clean:
	rm -rf obj_dir $(DESTDIR)/$(PROJECT)
|
||||
197
hw/unittest/tensor/main.cpp
Normal file
197
hw/unittest/tensor/main.cpp
Normal file
@@ -0,0 +1,197 @@
|
||||
// Copyright © 2019-2023
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "vl_simulator.h"
#include "VVX_tensor_tb.h"

#include <cmath>
#include <cstdlib>
#include <iostream>

#include <half.hpp>
|
||||
|
||||
// Tick budget constant (not referenced anywhere else in this file).
#define MAX_TICKS 20

// Default tracing window. Either bound may be predefined on the compiler
// command line; the defaults cover the whole run (0 .. max uint64).
#if !defined(TRACE_START_TIME)
#define TRACE_START_TIME 0ull
#endif

#if !defined(TRACE_STOP_TIME)
#define TRACE_STOP_TIME -1ull
#endif
|
||||
|
||||
// Test assertion: when the condition is false, print "FAILED: <expression>"
// on stdout and abort the process. Wrapped in do/while so that it behaves as
// a single statement.
#define CHECK(x)                                    \
  do {                                              \
    if (!(x)) {                                     \
      std::cout << "FAILED: " << #x << std::endl;   \
      std::abort();                                 \
    }                                               \
  } while (false)
|
||||
|
||||
static uint64_t timestamp = 0;
|
||||
static bool trace_enabled = false;
|
||||
static uint64_t trace_start_time = TRACE_START_TIME;
|
||||
static uint64_t trace_stop_time = TRACE_STOP_TIME;
|
||||
|
||||
double sc_time_stamp() {
|
||||
return timestamp;
|
||||
}
|
||||
|
||||
bool sim_trace_enabled() {
|
||||
if (timestamp >= trace_start_time
|
||||
&& timestamp < trace_stop_time)
|
||||
return true;
|
||||
return trace_enabled;
|
||||
}
|
||||
|
||||
// Sets the manual tracing flag that sim_trace_enabled() falls back to when
// the current time is outside the configured tracing window.
void sim_trace_enable(bool enable) {
  trace_enabled = enable;
}
|
||||
|
||||
using Device = VVX_tensor_tb;
|
||||
using half_float::half;
|
||||
|
||||
static_assert(sizeof(half) == 2);
|
||||
// Returns the raw 16-bit encoding of a half, zero-extended to 32 bits.
uint32_t half2bits(half h) {
  uint16_t raw;
  // Byte copy instead of a pointer cast to stay clear of aliasing problems.
  memcpy(&raw, &h, sizeof(raw));
  return raw;
}
|
||||
|
||||
// Returns the raw 32-bit encoding of a float.
uint32_t float2bits(float f) {
  uint32_t raw;
  // Byte copy instead of a pointer cast to stay clear of aliasing problems.
  memcpy(&raw, &f, sizeof(raw));
  return raw;
}
|
||||
|
||||
// Rebuilds a float from its raw 32-bit encoding (inverse of float2bits).
float bits2float(uint32_t b) {
  float value;
  // Byte copy instead of a pointer cast to stay clear of aliasing problems.
  memcpy(&value, &b, sizeof(value));
  return value;
}
|
||||
|
||||
// Tile dimensions: A is M x K, B is K x M, C and D are M x M.
// Typed constants rather than single-letter macros, so that the names M and K
// obey normal scoping rules and cannot rewrite unrelated identifiers.
constexpr int M = 4;
constexpr int K = 2;

// Host-side copies of the operand and result tiles, indexed [row][column].
float A_tile[M][K];
float B_tile[K][M];
float C_tile[M][M];
float D_tile[M][M];
|
||||
|
||||
void initialize_test_data() {
|
||||
for (int i = 0; i < M; i += 1) {
|
||||
for (int j = 0; j < K; j += 1) {
|
||||
A_tile[i][j] = (float) (i * K + j);
|
||||
}
|
||||
}
|
||||
|
||||
for (int i = 0; i < K; i += 1) {
|
||||
for (int j = 0; j < M; j += 1) {
|
||||
B_tile[i][j] = (float) (j * K + i);
|
||||
}
|
||||
}
|
||||
for (int i = 0; i < M; i += 1) {
|
||||
for (int j = 0; j < M; j += 1) {
|
||||
C_tile[i][j] = (float) (i * j);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Drives the A, B and C input ports of the model with the host tiles.
// Each tile is flattened row by row (index = row * columns + column) and every
// element is written as its raw 32-bit float encoding.
void write_test_data(vl_simulator<Device>& sim) {
  for (int row = 0; row < M; ++row) {
    for (int col = 0; col < K; ++col) {
      sim->A_tile[row * K + col] = float2bits(A_tile[row][col]);
    }
  }

  for (int row = 0; row < K; ++row) {
    for (int col = 0; col < M; ++col) {
      sim->B_tile[row * M + col] = float2bits(B_tile[row][col]);
    }
  }

  for (int row = 0; row < M; ++row) {
    for (int col = 0; col < M; ++col) {
      sim->C_tile[row * M + col] = float2bits(C_tile[row][col]);
    }
  }
}
|
||||
|
||||
// Reads the D output port of the model into the host-side D_tile and prints
// the matrix, one row per line. The port is read as a row-major flattened
// array of raw 32-bit float encodings.
void read_result(vl_simulator<Device>& sim) {
  for (int row = 0; row < M; ++row) {
    for (int col = 0; col < M; ++col) {
      const uint32_t raw = sim->D_tile[row * M + col];
      const float value = bits2float(raw);
      D_tile[row][col] = value;
      std::cout << value << " ";
    }
    std::cout << std::endl;
  }
}
|
||||
|
||||
void expected() {
|
||||
for (int i = 0; i < M; i += 1) {
|
||||
for (int j = 0; j < M; j += 1) {
|
||||
float accum = C_tile[i][j];
|
||||
for (int k = 0; k < K; k += 1) {
|
||||
accum += A_tile[i][k] * B_tile[k][j];
|
||||
}
|
||||
|
||||
std::cout << accum << " ";
|
||||
}
|
||||
std::cout << std::endl;
|
||||
}
|
||||
}
|
||||
|
||||
int main(int argc, char **argv) {
|
||||
// Initialize Verilators variables
|
||||
Verilated::commandArgs(argc, argv);
|
||||
|
||||
vl_simulator<Device> sim;
|
||||
|
||||
initialize_test_data();
|
||||
// run test
|
||||
timestamp = sim.reset(0);
|
||||
|
||||
|
||||
// advance clock
|
||||
timestamp = sim.step(timestamp, 10);
|
||||
sim->valid_in = 1;
|
||||
write_test_data(sim);
|
||||
timestamp = sim.step(timestamp, 2);
|
||||
CHECK(sim->valid_out == 0);
|
||||
sim->valid_in = 0;
|
||||
timestamp = sim.step(timestamp, 2);
|
||||
CHECK(sim->valid_out == 0);
|
||||
timestamp = sim.step(timestamp, 2);
|
||||
CHECK(sim->valid_out == 0);
|
||||
timestamp = sim.step(timestamp, 2);
|
||||
CHECK(sim->valid_out == 1);
|
||||
read_result(sim);
|
||||
timestamp = sim.step(timestamp, 2);
|
||||
CHECK(sim->valid_out == 0);
|
||||
|
||||
expected();
|
||||
|
||||
std::cout << "PASSED!" << std::endl;
|
||||
std::cout << "Simulation time: " << std::dec << timestamp/2 << " cycles" << std::endl;
|
||||
|
||||
return 0;
|
||||
}
|
||||
Reference in New Issue
Block a user