From 3f4abc542cee2c92d3ea1705382c1657498ee002 Mon Sep 17 00:00:00 2001 From: Hansung Kim Date: Mon, 19 Aug 2024 17:37:26 -0700 Subject: [PATCH] tensor: Fix dimensions and makefile --- tests/kernel/tensor/Makefile | 2 + tests/kernel/tensor/main.cpp | 79 ++++++++++++++++++------------------ 2 files changed, 41 insertions(+), 40 deletions(-) diff --git a/tests/kernel/tensor/Makefile b/tests/kernel/tensor/Makefile index 840e5ca8..10548774 100644 --- a/tests/kernel/tensor/Makefile +++ b/tests/kernel/tensor/Makefile @@ -14,7 +14,9 @@ $(PROJECT).elf: $(SRCS) $(DEPS) $(CC) $(CFLAGS) $(SRCS) $(LDFLAGS) -o $(PROJECT).elf $(OBJCOPY) --set-section-flags .operand.a=$(OBJCOPY_FLAGS) $@ $(OBJCOPY) --set-section-flags .operand.b=$(OBJCOPY_FLAGS) $@ + $(OBJCOPY) --set-section-flags .operand.c=$(OBJCOPY_FLAGS) $@ $(OBJCOPY) --set-section-flags .args=$(OBJCOPY_FLAGS) $@ $(OBJCOPY) --update-section .operand.a=input.a.bin $@ || true $(OBJCOPY) --update-section .operand.b=input.b.bin $@ || true + $(OBJCOPY) --update-section .operand.c=input.c.bin $@ || true $(OBJCOPY) --update-section .args=args.bin $@ || true diff --git a/tests/kernel/tensor/main.cpp b/tests/kernel/tensor/main.cpp index ca06bb47..c373507a 100644 --- a/tests/kernel/tensor/main.cpp +++ b/tests/kernel/tensor/main.cpp @@ -12,7 +12,6 @@ constexpr int DIM_K = 8; // #include "test_data.h" const float *A = reinterpret_cast(0xa0000000UL); const float *B = reinterpret_cast(0xa1000000UL); -// FIXME: C region is uninitialized const float *C = reinterpret_cast(0xa2000000UL); // single "substep" wmma instruction @@ -104,8 +103,8 @@ void vx_wmma_load() { map_operand_8lanes(tid, row, col); // load A - // A is stored row-major in the memory, - // loaded row-major into the RF. + // A is stored K-major in the memory, + // loaded K-major into the RF. // // For 32 lanes config, each operand element is read twice by two // threadgroups (Sec. III-B); i.e. 8 regs * 32 lanes = 256 fp32 elements = 2 @@ -120,26 +119,26 @@ void vx_wmma_load() { asm volatile("flw f7, %0" ::"m"(A[DIM_K * row + 7])); // load B - // B is stored row-major in the memory, - // loaded column-major into the RF. - // asm volatile("flw f8 , %0" ::"m"(B[DIM_N * 0 + col])); - // asm volatile("flw f9 , %0" ::"m"(B[DIM_N * 1 + col])); - // asm volatile("flw f10, %0" ::"m"(B[DIM_N * 2 + col])); - // asm volatile("flw f11, %0" ::"m"(B[DIM_N * 3 + col])); - // asm volatile("flw f12, %0" ::"m"(B[DIM_N * 4 + col])); - // asm volatile("flw f13, %0" ::"m"(B[DIM_N * 5 + col])); - // asm volatile("flw f14, %0" ::"m"(B[DIM_N * 6 + col])); - // asm volatile("flw f15, %0" ::"m"(B[DIM_N * 7 + col])); - // B is stored column-major in the memory, - // loaded column-major into the RF. - asm volatile("flw f8 , %0" ::"m"(B[DIM_K * row + 0])); - asm volatile("flw f9 , %0" ::"m"(B[DIM_K * row + 1])); - asm volatile("flw f10, %0" ::"m"(B[DIM_K * row + 2])); - asm volatile("flw f11, %0" ::"m"(B[DIM_K * row + 3])); - asm volatile("flw f12, %0" ::"m"(B[DIM_K * row + 4])); - asm volatile("flw f13, %0" ::"m"(B[DIM_K * row + 5])); - asm volatile("flw f14, %0" ::"m"(B[DIM_K * row + 6])); - asm volatile("flw f15, %0" ::"m"(B[DIM_K * row + 7])); + // B is stored N-major in the memory, + // loaded K-major into the RF. + asm volatile("flw f8 , %0" ::"m"(B[DIM_N * 0 + col])); + asm volatile("flw f9 , %0" ::"m"(B[DIM_N * 1 + col])); + asm volatile("flw f10, %0" ::"m"(B[DIM_N * 2 + col])); + asm volatile("flw f11, %0" ::"m"(B[DIM_N * 3 + col])); + asm volatile("flw f12, %0" ::"m"(B[DIM_N * 4 + col])); + asm volatile("flw f13, %0" ::"m"(B[DIM_N * 5 + col])); + asm volatile("flw f14, %0" ::"m"(B[DIM_N * 6 + col])); + asm volatile("flw f15, %0" ::"m"(B[DIM_N * 7 + col])); + // B is stored K-major in the memory, + // loaded K-major into the RF. + // asm volatile("flw f8 , %0" ::"m"(B[DIM_K * row + 0])); + // asm volatile("flw f9 , %0" ::"m"(B[DIM_K * row + 1])); + // asm volatile("flw f10, %0" ::"m"(B[DIM_K * row + 2])); + // asm volatile("flw f11, %0" ::"m"(B[DIM_K * row + 3])); + // asm volatile("flw f12, %0" ::"m"(B[DIM_K * row + 4])); + // asm volatile("flw f13, %0" ::"m"(B[DIM_K * row + 5])); + // asm volatile("flw f14, %0" ::"m"(B[DIM_K * row + 6])); + // asm volatile("flw f15, %0" ::"m"(B[DIM_K * row + 7])); map_c_8lanes(tid, row, col); @@ -178,24 +177,24 @@ void store_wmma_result() { map_c_8lanes(tid, row, col); // store C - float *const results_wid = results + (DIM_M * DIM_M * wid); + float *const results_wid = results + (DIM_M * DIM_N * wid); // uncomment to have two accum buffers in rf - // asm volatile("fsw f16, %0" ::"m"(results_wid[DIM_M * (row + 0) + (col + 0)])); - // asm volatile("fsw f17, %0" ::"m"(results_wid[DIM_M * (row + 0) + (col + 1)])); - // asm volatile("fsw f18, %0" ::"m"(results_wid[DIM_M * (row + 2) + (col + 0)])); - // asm volatile("fsw f19, %0" ::"m"(results_wid[DIM_M * (row + 2) + (col + 1)])); - // asm volatile("fsw f20, %0" ::"m"(results_wid[DIM_M * (row + 0) + (col + 4)])); - // asm volatile("fsw f21, %0" ::"m"(results_wid[DIM_M * (row + 0) + (col + 5)])); - // asm volatile("fsw f22, %0" ::"m"(results_wid[DIM_M * (row + 2) + (col + 4)])); - // asm volatile("fsw f23, %0" ::"m"(results_wid[DIM_M * (row + 2) + (col + 5)])); - asm volatile("fsw f24, %0" ::"m"(results_wid[DIM_M * (row + 0) + (col + 0)])); - asm volatile("fsw f25, %0" ::"m"(results_wid[DIM_M * (row + 0) + (col + 1)])); - asm volatile("fsw f26, %0" ::"m"(results_wid[DIM_M * (row + 2) + (col + 0)])); - asm volatile("fsw f27, %0" ::"m"(results_wid[DIM_M * (row + 2) + (col + 1)])); - asm volatile("fsw f28, %0" ::"m"(results_wid[DIM_M * (row + 0) + (col + 4)])); - asm volatile("fsw f29, %0" ::"m"(results_wid[DIM_M * (row + 0) + (col + 5)])); - asm volatile("fsw f30, %0" ::"m"(results_wid[DIM_M * (row + 2) + (col + 4)])); - asm volatile("fsw f31, %0" ::"m"(results_wid[DIM_M * (row + 2) + (col + 5)])); + // asm volatile("fsw f16, %0" ::"m"(results_wid[DIM_N * (row + 0) + (col + 0)])); + // asm volatile("fsw f17, %0" ::"m"(results_wid[DIM_N * (row + 0) + (col + 1)])); + // asm volatile("fsw f18, %0" ::"m"(results_wid[DIM_N * (row + 2) + (col + 0)])); + // asm volatile("fsw f19, %0" ::"m"(results_wid[DIM_N * (row + 2) + (col + 1)])); + // asm volatile("fsw f20, %0" ::"m"(results_wid[DIM_N * (row + 0) + (col + 4)])); + // asm volatile("fsw f21, %0" ::"m"(results_wid[DIM_N * (row + 0) + (col + 5)])); + // asm volatile("fsw f22, %0" ::"m"(results_wid[DIM_N * (row + 2) + (col + 4)])); + // asm volatile("fsw f23, %0" ::"m"(results_wid[DIM_N * (row + 2) + (col + 5)])); + asm volatile("fsw f24, %0" ::"m"(results_wid[DIM_N * (row + 0) + (col + 0)])); + asm volatile("fsw f25, %0" ::"m"(results_wid[DIM_N * (row + 0) + (col + 1)])); + asm volatile("fsw f26, %0" ::"m"(results_wid[DIM_N * (row + 2) + (col + 0)])); + asm volatile("fsw f27, %0" ::"m"(results_wid[DIM_N * (row + 2) + (col + 1)])); + asm volatile("fsw f28, %0" ::"m"(results_wid[DIM_N * (row + 0) + (col + 4)])); + asm volatile("fsw f29, %0" ::"m"(results_wid[DIM_N * (row + 0) + (col + 5)])); + asm volatile("fsw f30, %0" ::"m"(results_wid[DIM_N * (row + 2) + (col + 4)])); + asm volatile("fsw f31, %0" ::"m"(results_wid[DIM_N * (row + 2) + (col + 5)])); } void print_wmma_result() {