tensor: Embed binary instead of hardcoding literals

the C compiler doesn't support fp16
This commit is contained in:
Hansung Kim
2024-07-31 16:07:55 -07:00
parent 1b5daccac9
commit c1906ebb4f
2 changed files with 61 additions and 37 deletions

View File

@@ -6,3 +6,15 @@ DEPS += b_matrix.h
DEPS += c_matrix.h
include ../common.mk
OBJCOPY ?= $(RISCV_TOOLCHAIN_PATH)/bin/$(RISCV_PREFIX)-objcopy
OBJCOPY_FLAGS ?= "LOAD,ALLOC,DATA,CONTENTS"
BINFILES := args.bin input.a.bin input.b.bin
$(PROJECT).elf: $(SRCS) $(DEPS)
$(CC) $(CFLAGS) $(SRCS) $(LDFLAGS) -o $(PROJECT).elf
$(OBJCOPY) --set-section-flags .operand.a=$(OBJCOPY_FLAGS) $@
$(OBJCOPY) --set-section-flags .operand.b=$(OBJCOPY_FLAGS) $@
$(OBJCOPY) --set-section-flags .args=$(OBJCOPY_FLAGS) $@
$(OBJCOPY) --update-section .operand.a=input.a.bin $@ || true
$(OBJCOPY) --update-section .operand.b=input.b.bin $@ || true
$(OBJCOPY) --update-section .args=args.bin $@ || true

View File

@@ -4,9 +4,16 @@
#include <vx_intrinsics.h>
#include <stdio.h>
#include <vx_print.h>
#include "test_data.h"
constexpr int DIM_M = 8;
constexpr int DIM_N = 8;
constexpr int DIM_K = 8;
// #include "test_data.h"
const float *A = reinterpret_cast<const float *>(0xa0000000UL);
const float *B = reinterpret_cast<const float *>(0xa1000000UL);
// FIXME: C region is uninitialized
const float *C = reinterpret_cast<const float *>(0xa2000000UL);
// single "substep" wmma instruction
// use accum buffer 0 (f16-f23)
@@ -97,48 +104,54 @@ void vx_wmma_load() {
map_operand_8lanes(tid, row, col);
// load A
// each operand element is read twice by two threadgroups (Sec. III-B);
// i.e. 8 regs * 32 lanes = 256 fp32 elements = 2 * (16 * 8) elements
asm volatile("flw f0, %0" ::"m"(A[row][0]));
asm volatile("flw f1, %0" ::"m"(A[row][1]));
asm volatile("flw f2, %0" ::"m"(A[row][2]));
asm volatile("flw f3, %0" ::"m"(A[row][3]));
asm volatile("flw f4, %0" ::"m"(A[row][4]));
asm volatile("flw f5, %0" ::"m"(A[row][5]));
asm volatile("flw f6, %0" ::"m"(A[row][6]));
asm volatile("flw f7, %0" ::"m"(A[row][7]));
// A is stored row-major in the memory,
// loaded row-major into the RF.
//
// For 32 lanes config, each operand element is read twice by two
// threadgroups (Sec. III-B); i.e. 8 regs * 32 lanes = 256 fp32 elements = 2
// * (16 * 8) elements
asm volatile("flw f0, %0" ::"m"(A[DIM_K * row + 0]));
asm volatile("flw f1, %0" ::"m"(A[DIM_K * row + 1]));
asm volatile("flw f2, %0" ::"m"(A[DIM_K * row + 2]));
asm volatile("flw f3, %0" ::"m"(A[DIM_K * row + 3]));
asm volatile("flw f4, %0" ::"m"(A[DIM_K * row + 4]));
asm volatile("flw f5, %0" ::"m"(A[DIM_K * row + 5]));
asm volatile("flw f6, %0" ::"m"(A[DIM_K * row + 6]));
asm volatile("flw f7, %0" ::"m"(A[DIM_K * row + 7]));
// load B
asm volatile("flw f8 , %0" ::"m"(B[0][col]));
asm volatile("flw f9 , %0" ::"m"(B[1][col]));
asm volatile("flw f10, %0" ::"m"(B[2][col]));
asm volatile("flw f11, %0" ::"m"(B[3][col]));
asm volatile("flw f12, %0" ::"m"(B[4][col]));
asm volatile("flw f13, %0" ::"m"(B[5][col]));
asm volatile("flw f14, %0" ::"m"(B[6][col]));
asm volatile("flw f15, %0" ::"m"(B[7][col]));
// B is stored row-major in the memory,
// loaded column-major into the RF.
asm volatile("flw f8 , %0" ::"m"(B[DIM_N * 0 + col]));
asm volatile("flw f9 , %0" ::"m"(B[DIM_N * 1 + col]));
asm volatile("flw f10, %0" ::"m"(B[DIM_N * 2 + col]));
asm volatile("flw f11, %0" ::"m"(B[DIM_N * 3 + col]));
asm volatile("flw f12, %0" ::"m"(B[DIM_N * 4 + col]));
asm volatile("flw f13, %0" ::"m"(B[DIM_N * 5 + col]));
asm volatile("flw f14, %0" ::"m"(B[DIM_N * 6 + col]));
asm volatile("flw f15, %0" ::"m"(B[DIM_N * 7 + col]));
map_c_8lanes(tid, row, col);
// load C
// accum buffer 0
asm volatile("flw f16, %0" ::"m"(C[row + 0][col + 0]));
asm volatile("flw f17, %0" ::"m"(C[row + 0][col + 1]));
asm volatile("flw f18, %0" ::"m"(C[row + 2][col + 0]));
asm volatile("flw f19, %0" ::"m"(C[row + 2][col + 1]));
asm volatile("flw f20, %0" ::"m"(C[row + 0][col + 4]));
asm volatile("flw f21, %0" ::"m"(C[row + 0][col + 5]));
asm volatile("flw f22, %0" ::"m"(C[row + 2][col + 4]));
asm volatile("flw f23, %0" ::"m"(C[row + 2][col + 5]));
asm volatile("flw f16, %0" ::"m"(C[DIM_N * (row + 0) + col + 0]));
asm volatile("flw f17, %0" ::"m"(C[DIM_N * (row + 0) + col + 1]));
asm volatile("flw f18, %0" ::"m"(C[DIM_N * (row + 2) + col + 0]));
asm volatile("flw f19, %0" ::"m"(C[DIM_N * (row + 2) + col + 1]));
asm volatile("flw f20, %0" ::"m"(C[DIM_N * (row + 0) + col + 4]));
asm volatile("flw f21, %0" ::"m"(C[DIM_N * (row + 0) + col + 5]));
asm volatile("flw f22, %0" ::"m"(C[DIM_N * (row + 2) + col + 4]));
asm volatile("flw f23, %0" ::"m"(C[DIM_N * (row + 2) + col + 5]));
// accum buffer 1
asm volatile("flw f24, %0" ::"m"(C[row + 0][col + 0]));
asm volatile("flw f25, %0" ::"m"(C[row + 0][col + 1]));
asm volatile("flw f26, %0" ::"m"(C[row + 2][col + 0]));
asm volatile("flw f27, %0" ::"m"(C[row + 2][col + 1]));
asm volatile("flw f28, %0" ::"m"(C[row + 0][col + 4]));
asm volatile("flw f29, %0" ::"m"(C[row + 0][col + 5]));
asm volatile("flw f30, %0" ::"m"(C[row + 2][col + 4]));
asm volatile("flw f31, %0" ::"m"(C[row + 2][col + 5]));
asm volatile("flw f24, %0" ::"m"(C[DIM_N * (row + 0) + col + 0]));
asm volatile("flw f25, %0" ::"m"(C[DIM_N * (row + 0) + col + 1]));
asm volatile("flw f26, %0" ::"m"(C[DIM_N * (row + 2) + col + 0]));
asm volatile("flw f27, %0" ::"m"(C[DIM_N * (row + 2) + col + 1]));
asm volatile("flw f28, %0" ::"m"(C[DIM_N * (row + 0) + col + 4]));
asm volatile("flw f29, %0" ::"m"(C[DIM_N * (row + 0) + col + 5]));
asm volatile("flw f30, %0" ::"m"(C[DIM_N * (row + 2) + col + 4]));
asm volatile("flw f31, %0" ::"m"(C[DIM_N * (row + 2) + col + 5]));
}
// hardcoded device address for result
@@ -211,9 +224,8 @@ int main() {
const int num_warps = vx_num_warps();
// vx_wspawn(num_warps, wmma);
vx_wspawn(1, wmma);
wmma();
vx_wspawn_wait();
// vx_wspawn_wait();
return 0;
}