tensor: Fix dimensions and makefile

This commit is contained in:
Hansung Kim
2024-08-19 17:37:26 -07:00
parent a98da9e3ca
commit 3f4abc542c
2 changed files with 41 additions and 40 deletions

View File

@@ -14,7 +14,9 @@ $(PROJECT).elf: $(SRCS) $(DEPS)
$(CC) $(CFLAGS) $(SRCS) $(LDFLAGS) -o $(PROJECT).elf
$(OBJCOPY) --set-section-flags .operand.a=$(OBJCOPY_FLAGS) $@
$(OBJCOPY) --set-section-flags .operand.b=$(OBJCOPY_FLAGS) $@
$(OBJCOPY) --set-section-flags .operand.c=$(OBJCOPY_FLAGS) $@
$(OBJCOPY) --set-section-flags .args=$(OBJCOPY_FLAGS) $@
$(OBJCOPY) --update-section .operand.a=input.a.bin $@ || true
$(OBJCOPY) --update-section .operand.b=input.b.bin $@ || true
$(OBJCOPY) --update-section .operand.c=input.c.bin $@ || true
$(OBJCOPY) --update-section .args=args.bin $@ || true

View File

@@ -12,7 +12,6 @@ constexpr int DIM_K = 8;
// #include "test_data.h"
const float *A = reinterpret_cast<const float *>(0xa0000000UL);
const float *B = reinterpret_cast<const float *>(0xa1000000UL);
// FIXME: C region is uninitialized
const float *C = reinterpret_cast<const float *>(0xa2000000UL);
// single "substep" wmma instruction
@@ -104,8 +103,8 @@ void vx_wmma_load() {
map_operand_8lanes(tid, row, col);
// load A
// A is stored row-major in the memory,
// loaded row-major into the RF.
// A is stored K-major in the memory,
// loaded K-major into the RF.
//
// For 32 lanes config, each operand element is read twice by two
// threadgroups (Sec. III-B); i.e. 8 regs * 32 lanes = 256 fp32 elements = 2
@@ -120,26 +119,26 @@ void vx_wmma_load() {
asm volatile("flw f7, %0" ::"m"(A[DIM_K * row + 7]));
// load B
// B is stored row-major in the memory,
// loaded column-major into the RF.
// asm volatile("flw f8 , %0" ::"m"(B[DIM_N * 0 + col]));
// asm volatile("flw f9 , %0" ::"m"(B[DIM_N * 1 + col]));
// asm volatile("flw f10, %0" ::"m"(B[DIM_N * 2 + col]));
// asm volatile("flw f11, %0" ::"m"(B[DIM_N * 3 + col]));
// asm volatile("flw f12, %0" ::"m"(B[DIM_N * 4 + col]));
// asm volatile("flw f13, %0" ::"m"(B[DIM_N * 5 + col]));
// asm volatile("flw f14, %0" ::"m"(B[DIM_N * 6 + col]));
// asm volatile("flw f15, %0" ::"m"(B[DIM_N * 7 + col]));
// B is stored column-major in the memory,
// loaded column-major into the RF.
asm volatile("flw f8 , %0" ::"m"(B[DIM_K * row + 0]));
asm volatile("flw f9 , %0" ::"m"(B[DIM_K * row + 1]));
asm volatile("flw f10, %0" ::"m"(B[DIM_K * row + 2]));
asm volatile("flw f11, %0" ::"m"(B[DIM_K * row + 3]));
asm volatile("flw f12, %0" ::"m"(B[DIM_K * row + 4]));
asm volatile("flw f13, %0" ::"m"(B[DIM_K * row + 5]));
asm volatile("flw f14, %0" ::"m"(B[DIM_K * row + 6]));
asm volatile("flw f15, %0" ::"m"(B[DIM_K * row + 7]));
// B is stored N-major in the memory,
// loaded K-major into the RF.
asm volatile("flw f8 , %0" ::"m"(B[DIM_N * 0 + col]));
asm volatile("flw f9 , %0" ::"m"(B[DIM_N * 1 + col]));
asm volatile("flw f10, %0" ::"m"(B[DIM_N * 2 + col]));
asm volatile("flw f11, %0" ::"m"(B[DIM_N * 3 + col]));
asm volatile("flw f12, %0" ::"m"(B[DIM_N * 4 + col]));
asm volatile("flw f13, %0" ::"m"(B[DIM_N * 5 + col]));
asm volatile("flw f14, %0" ::"m"(B[DIM_N * 6 + col]));
asm volatile("flw f15, %0" ::"m"(B[DIM_N * 7 + col]));
// B is stored K-major in the memory,
// loaded K-major into the RF.
// asm volatile("flw f8 , %0" ::"m"(B[DIM_K * row + 0]));
// asm volatile("flw f9 , %0" ::"m"(B[DIM_K * row + 1]));
// asm volatile("flw f10, %0" ::"m"(B[DIM_K * row + 2]));
// asm volatile("flw f11, %0" ::"m"(B[DIM_K * row + 3]));
// asm volatile("flw f12, %0" ::"m"(B[DIM_K * row + 4]));
// asm volatile("flw f13, %0" ::"m"(B[DIM_K * row + 5]));
// asm volatile("flw f14, %0" ::"m"(B[DIM_K * row + 6]));
// asm volatile("flw f15, %0" ::"m"(B[DIM_K * row + 7]));
map_c_8lanes(tid, row, col);
@@ -178,24 +177,24 @@ void store_wmma_result() {
map_c_8lanes(tid, row, col);
// store C
float *const results_wid = results + (DIM_M * DIM_M * wid);
float *const results_wid = results + (DIM_M * DIM_N * wid);
// uncomment to have two accum buffers in rf
// asm volatile("fsw f16, %0" ::"m"(results_wid[DIM_M * (row + 0) + (col + 0)]));
// asm volatile("fsw f17, %0" ::"m"(results_wid[DIM_M * (row + 0) + (col + 1)]));
// asm volatile("fsw f18, %0" ::"m"(results_wid[DIM_M * (row + 2) + (col + 0)]));
// asm volatile("fsw f19, %0" ::"m"(results_wid[DIM_M * (row + 2) + (col + 1)]));
// asm volatile("fsw f20, %0" ::"m"(results_wid[DIM_M * (row + 0) + (col + 4)]));
// asm volatile("fsw f21, %0" ::"m"(results_wid[DIM_M * (row + 0) + (col + 5)]));
// asm volatile("fsw f22, %0" ::"m"(results_wid[DIM_M * (row + 2) + (col + 4)]));
// asm volatile("fsw f23, %0" ::"m"(results_wid[DIM_M * (row + 2) + (col + 5)]));
asm volatile("fsw f24, %0" ::"m"(results_wid[DIM_M * (row + 0) + (col + 0)]));
asm volatile("fsw f25, %0" ::"m"(results_wid[DIM_M * (row + 0) + (col + 1)]));
asm volatile("fsw f26, %0" ::"m"(results_wid[DIM_M * (row + 2) + (col + 0)]));
asm volatile("fsw f27, %0" ::"m"(results_wid[DIM_M * (row + 2) + (col + 1)]));
asm volatile("fsw f28, %0" ::"m"(results_wid[DIM_M * (row + 0) + (col + 4)]));
asm volatile("fsw f29, %0" ::"m"(results_wid[DIM_M * (row + 0) + (col + 5)]));
asm volatile("fsw f30, %0" ::"m"(results_wid[DIM_M * (row + 2) + (col + 4)]));
asm volatile("fsw f31, %0" ::"m"(results_wid[DIM_M * (row + 2) + (col + 5)]));
// asm volatile("fsw f16, %0" ::"m"(results_wid[DIM_N * (row + 0) + (col + 0)]));
// asm volatile("fsw f17, %0" ::"m"(results_wid[DIM_N * (row + 0) + (col + 1)]));
// asm volatile("fsw f18, %0" ::"m"(results_wid[DIM_N * (row + 2) + (col + 0)]));
// asm volatile("fsw f19, %0" ::"m"(results_wid[DIM_N * (row + 2) + (col + 1)]));
// asm volatile("fsw f20, %0" ::"m"(results_wid[DIM_N * (row + 0) + (col + 4)]));
// asm volatile("fsw f21, %0" ::"m"(results_wid[DIM_N * (row + 0) + (col + 5)]));
// asm volatile("fsw f22, %0" ::"m"(results_wid[DIM_N * (row + 2) + (col + 4)]));
// asm volatile("fsw f23, %0" ::"m"(results_wid[DIM_N * (row + 2) + (col + 5)]));
asm volatile("fsw f24, %0" ::"m"(results_wid[DIM_N * (row + 0) + (col + 0)]));
asm volatile("fsw f25, %0" ::"m"(results_wid[DIM_N * (row + 0) + (col + 1)]));
asm volatile("fsw f26, %0" ::"m"(results_wid[DIM_N * (row + 2) + (col + 0)]));
asm volatile("fsw f27, %0" ::"m"(results_wid[DIM_N * (row + 2) + (col + 1)]));
asm volatile("fsw f28, %0" ::"m"(results_wid[DIM_N * (row + 0) + (col + 4)]));
asm volatile("fsw f29, %0" ::"m"(results_wid[DIM_N * (row + 0) + (col + 5)]));
asm volatile("fsw f30, %0" ::"m"(results_wid[DIM_N * (row + 2) + (col + 4)]));
asm volatile("fsw f31, %0" ::"m"(results_wid[DIM_N * (row + 2) + (col + 5)]));
}
void print_wmma_result() {