diff --git a/kernel/linker/vx_link32.ld b/kernel/linker/vx_link32.ld index e84c0342..eae88853 100644 --- a/kernel/linker/vx_link32.ld +++ b/kernel/linker/vx_link32.ld @@ -10,6 +10,7 @@ ENTRY(_start) MEMORY { DRAM0 (rwx): ORIGIN = 0x80000000, LENGTH = 512M + DRAMARG (rwx): ORIGIN = 0x9fff0000, LENGTH = 8K DRAM1 (rwx): ORIGIN = 0xa0000000, LENGTH = 16M DRAM2 (rwx): ORIGIN = 0xa1000000, LENGTH = 16M } @@ -275,6 +276,10 @@ SECTIONS .gnu.attributes 0 : { KEEP (*(.gnu.attributes)) } /DISCARD/ : { *(.note.GNU-stack) *(.gnu_debuglink) *(.gnu.lto_*) } + .args : { + *(.args) + . += 8K; + }> DRAMARG .operand.a : { *(.operand.a) . += 32K; diff --git a/kernel/src/vx_spawn.c b/kernel/src/vx_spawn.c index ffbbaccb..759b915c 100644 --- a/kernel/src/vx_spawn.c +++ b/kernel/src/vx_spawn.c @@ -170,7 +170,7 @@ static void __attribute__ ((noinline)) spawn_tasks_cluster_all_cb() { vx_tmc_zero(); } -static void __attribute__ ((noinline)) spawn_tasks_all_cb() { +static void __attribute__ ((noinline)) spawn_tasks_all_cb() { // activate all threads vx_tmc(-1); @@ -254,6 +254,7 @@ void vx_spawn_tasks_cluster(int num_tasks, vx_spawn_tasks_cb callback, void *arg vx_wspawn_wait(); } + // TODO: this is incomplete // TODO: Instead of launching an additional wave just to work on remaining // threads, handle this in the last wave amongst other full warps. 
if (rem_threads_in_last_warp != 0 && core_id_in_cluster == 0) { diff --git a/tests/kernel/Makefile b/tests/kernel/Makefile index ab4fdd07..f7c46754 100644 --- a/tests/kernel/Makefile +++ b/tests/kernel/Makefile @@ -1,19 +1,23 @@ all: $(MAKE) -C conform $(MAKE) -C hello - $(MAKE) -C fibonacci + $(MAKE) -C fibonacci + $(MAKE) -C reductions run-simx: $(MAKE) -C conform run-simx $(MAKE) -C hello run-simx $(MAKE) -C fibonacci run-simx + $(MAKE) -C reductions run-simx run-rtlsim: $(MAKE) -C conform run-rtlsim $(MAKE) -C hello run-rtlsim $(MAKE) -C fibonacci run-rtlsim + $(MAKE) -C reductions run-rtlsim clean: $(MAKE) -C conform clean $(MAKE) -C hello clean $(MAKE) -C fibonacci clean + $(MAKE) -C reductions clean diff --git a/tests/kernel/common.mk b/tests/kernel/common.mk index 7bf4b520..b58b3110 100644 --- a/tests/kernel/common.mk +++ b/tests/kernel/common.mk @@ -21,6 +21,9 @@ CP = $(RISCV_TOOLCHAIN_PATH)/bin/$(RISCV_PREFIX)-objcopy SIM_DIR = ../../../sim CFLAGS += -O3 -mcmodel=medany -fno-exceptions -nostartfiles -fdata-sections -ffunction-sections +CFLAGS += -ffixed-ft0 -ffixed-ft1 -ffixed-ft2 -ffixed-ft3 -ffixed-ft4 -ffixed-ft5 -ffixed-ft6 -ffixed-ft7 +CFLAGS += -ffixed-fs0 -ffixed-fs1 -ffixed-fs2 -ffixed-fs3 -ffixed-fs4 -ffixed-fs5 -ffixed-fs6 -ffixed-fs7 +CFLAGS += -ffixed-fa0 -ffixed-fa1 -ffixed-fa2 -ffixed-fa3 -ffixed-fa4 -ffixed-fa5 -ffixed-fa6 -ffixed-fa7 CFLAGS += -I$(VORTEX_KN_PATH)/include -I$(VORTEX_KN_PATH)/../hw LDFLAGS += -lm -Wl,-Bstatic,--gc-sections,-T,$(VORTEX_KN_PATH)/linker/vx_link$(XLEN).ld,--defsym=STARTUP_ADDR=0x80000000 $(VORTEX_KN_PATH)/libvortexrt.a @@ -33,7 +36,7 @@ $(PROJECT).dump: $(PROJECT).elf $(PROJECT).bin: $(PROJECT).elf $(CP) -O binary $(PROJECT).elf $(PROJECT).bin -$(PROJECT).elf: $(SRCS) +$(PROJECT).elf: $(SRCS) $(DEPS) $(CC) $(CFLAGS) $(SRCS) $(LDFLAGS) -o $(PROJECT).elf run-rtlsim: $(PROJECT).bin diff --git a/tests/kernel/reductions/Makefile b/tests/kernel/reductions/Makefile new file mode 100644 index 00000000..76e96c46 --- 
/dev/null +++ b/tests/kernel/reductions/Makefile @@ -0,0 +1,5 @@ +PROJECT = reductions + +SRCS = main.cpp + +include ../common.mk diff --git a/tests/kernel/reductions/main.cpp b/tests/kernel/reductions/main.cpp new file mode 100644 index 00000000..fcadddb6 --- /dev/null +++ b/tests/kernel/reductions/main.cpp @@ -0,0 +1,217 @@ +#define RISCV_CUSTOM2 0x5B +#define ADD_FUNC7 0b0000000 +#define ADDU_FUNC7 0b1000000 +#define MIN_FUNC7 0b0000001 +#define MINU_FUNC7 0b1000001 +#define MAX_FUNC7 0b0000010 +#define MAXU_FUNC7 0b1000010 +#define AND_FUNC7 0b0000011 +#define OR_FUNC7 0b0000100 +#define XOR_FUNC7 0b0000101 + +/* + 6'h0: begin + op_type = func7[6] ? `INST_RED_ADDU : `INST_RED_ADD; + end + 6'h1: begin + op_type = func7[6] ? `INST_RED_MINU : `INST_RED_MIN; + end + 6'h2: begin + op_type = func7[6] ? `INST_RED_MAXU : `INST_RED_MAX; + end + 6'h3: begin + op_type = `INST_RED_AND; + end + 6'h4: begin + op_type = `INST_RED_OR; + end + 6'h5: begin + op_type = `INST_RED_XOR; + end +*/ + +#include +#include +#include + +int x[4] = {3, 7, 2, 5}; +int y = -1; + +inline int vx_add_reduce(int v) { + int ret; + asm volatile (".insn r %2, 0, %3, %0, %1, x0" : "=r"(ret) : "r"(v), "i"(RISCV_CUSTOM2), "i"(ADD_FUNC7)); + return ret; +} + +inline int vx_min_reduce(int v) { + int ret; + asm volatile (".insn r %2, 0, %3, %0, %1, x0" : "=r"(ret) : "r"(v), "i"(RISCV_CUSTOM2), "i"(MIN_FUNC7)); + return ret; +} + +inline unsigned vx_minu_reduce(unsigned v) { + unsigned ret; + asm volatile (".insn r %2, 0, %3, %0, %1, x0" : "=r"(ret) : "r"(v), "i"(RISCV_CUSTOM2), "i"(MINU_FUNC7)); + return ret; +} + +inline int vx_max_reduce(int v) { + int ret; + asm volatile (".insn r %2, 0, %3, %0, %1, x0" : "=r"(ret) : "r"(v), "i"(RISCV_CUSTOM2), "i"(MAX_FUNC7)); + return ret; +} + +inline unsigned vx_maxu_reduce(unsigned v) { + unsigned ret; + asm volatile (".insn r %2, 0, %3, %0, %1, x0" : "=r"(ret) : "r"(v), "i"(RISCV_CUSTOM2), "i"(MAXU_FUNC7)); + return ret; +} + + +inline unsigned 
vx_and_reduce(unsigned v) { + unsigned ret; + asm volatile (".insn r %2, 0, %3, %0, %1, x0" : "=r"(ret) : "r"(v), "i"(RISCV_CUSTOM2), "i"(AND_FUNC7)); + return ret; +} + +inline unsigned vx_or_reduce(unsigned v) { + unsigned ret; + asm volatile (".insn r %2, 0, %3, %0, %1, x0" : "=r"(ret) : "r"(v), "i"(RISCV_CUSTOM2), "i"(OR_FUNC7)); + return ret; +} + +inline unsigned vx_xor_reduce(unsigned v) { + unsigned ret; + asm volatile (".insn r %2, 0, %3, %0, %1, x0" : "=r"(ret) : "r"(v), "i"(RISCV_CUSTOM2), "i"(XOR_FUNC7)); + return ret; +} + +void test_add_reduce() { + vx_tmc(-1); + int tid = vx_thread_id(); + int v = x[tid]; + int reduced = vx_add_reduce(v); + vx_tmc(1); + + y = reduced; +} + +unsigned unsigned_vector[4] = {(unsigned)-1, 0, (unsigned)-2, 5}; + +void test_min_reduce() { + vx_tmc(-1); + int tid = vx_thread_id(); + int v = unsigned_vector[tid]; + int reduced = vx_min_reduce(v); + vx_tmc(1); + + y = reduced; +} + +void test_max_reduce() { + vx_tmc(-1); + int tid = vx_thread_id(); + int v = unsigned_vector[tid]; + int reduced = vx_max_reduce(v); + vx_tmc(1); + + y = reduced; +} + +void test_minu_reduce() { + vx_tmc(-1); + int tid = vx_thread_id(); + unsigned v = unsigned_vector[tid]; + unsigned reduced = vx_minu_reduce(v); + vx_tmc(1); + + y = reduced; +} + +void test_maxu_reduce() { + vx_tmc(-1); + int tid = vx_thread_id(); + unsigned v = unsigned_vector[tid]; + unsigned reduced = vx_maxu_reduce(v); + vx_tmc(1); + + y = reduced; +} + +// assumes NUM_THREADS == 4 +unsigned bit_vectors[4] = {0b11010110000111001100010100100110, 0b10010100011010001010000000001110, 0b10001001010111110001110000000010, 0b00010011010100101101110111001111}; + +void test_and_reduce() { + vx_tmc(-1); + int tid = vx_thread_id(); + unsigned v = bit_vectors[tid]; + unsigned reduced = vx_and_reduce(v); + vx_tmc(1); + + y = reduced; +} + +void test_or_reduce() { + vx_tmc(-1); + int tid = vx_thread_id(); + unsigned v = bit_vectors[tid]; + unsigned reduced = vx_or_reduce(v); + vx_tmc(1); + + 
y = reduced; +} + +void test_xor_reduce() { + vx_tmc(-1); + int tid = vx_thread_id(); + unsigned v = bit_vectors[tid]; + unsigned reduced = vx_xor_reduce(v); + vx_tmc(1); + + y = reduced; +} + +int main() +{ + int expected; + + test_add_reduce(); + vx_printf("add reduce result: %d\n", y); + vx_printf("expected: %d\n", x[0] + x[1] + x[2] + x[3]); + + test_min_reduce(); + vx_printf("min reduce result: %d\n", y); + expected = MIN((int)unsigned_vector[0], MIN((int)unsigned_vector[1], MIN((int)unsigned_vector[2], (int)unsigned_vector[3]))); + vx_printf("expected: %d\n", expected); + + test_max_reduce(); + vx_printf("max reduce result: %d\n", y); + expected = MAX((int)unsigned_vector[0], MAX((int)unsigned_vector[1], MAX((int)unsigned_vector[2], (int)unsigned_vector[3]))); + vx_printf("expected: %d\n", expected); + + test_minu_reduce(); + vx_printf("minu reduce result: %d\n", y); + expected = MIN(unsigned_vector[0], MIN(unsigned_vector[1], MIN(unsigned_vector[2], unsigned_vector[3]))); + vx_printf("expected: %d\n", expected); + + test_maxu_reduce(); + vx_printf("maxu reduce result: %d\n", y); + expected = MAX(unsigned_vector[0], MAX(unsigned_vector[1], MAX(unsigned_vector[2], unsigned_vector[3]))); + vx_printf("expected: %d\n", expected); + + test_and_reduce(); + vx_printf("and reduce result: %d\n", y); + vx_printf("expected: %d\n", bit_vectors[0] & bit_vectors[1] & bit_vectors[2] & bit_vectors[3]); + + + test_or_reduce(); + vx_printf("or reduce result: %d\n", y); + vx_printf("expected: %d\n", bit_vectors[0] | bit_vectors[1] | bit_vectors[2] | bit_vectors[3]); + + test_xor_reduce(); + vx_printf("xor reduce result: %d\n", y); + vx_printf("expected: %d\n", bit_vectors[0] ^ bit_vectors[1] ^ bit_vectors[2] ^ bit_vectors[3]); + + + return 0; +} diff --git a/tests/kernel/tensor/Makefile b/tests/kernel/tensor/Makefile new file mode 100644 index 00000000..19cb340e --- /dev/null +++ b/tests/kernel/tensor/Makefile @@ -0,0 +1,8 @@ +PROJECT = tensor + +SRCS = main.cpp +DEPS = 
a_matrix.h +DEPS += b_matrix.h +DEPS += c_matrix.h + +include ../common.mk diff --git a/tests/kernel/tensor/check_correctness.py b/tests/kernel/tensor/check_correctness.py new file mode 100644 index 00000000..13e28891 --- /dev/null +++ b/tests/kernel/tensor/check_correctness.py @@ -0,0 +1,107 @@ +import numpy as np +import struct + +A_array = np.zeros((16, 8)) +B_array = np.zeros((8, 16)) +C_array = np.zeros((16, 16)) + +file = input("simulator output filename: ") + +def hex2float(float_hex_str): + # print(float_hex_str.strip()) + return struct.unpack(">f",struct.pack(">i",int(float_hex_str,16)))[0] + +def C_index(threadgroup, thread, register): + """ + col = ((tg % 4) / 2) * 8; + row = (tg * 8) % 16; + row += (tg / 4) * 4; + + asm volatile ("flw f16, %0" :: "m"(C[row+0][col+0])); + asm volatile ("flw f17, %0" :: "m"(C[row+0][col+1])); + asm volatile ("flw f18, %0" :: "m"(C[row+2][col+0])); + asm volatile ("flw f19, %0" :: "m"(C[row+2][col+1])); + asm volatile ("flw f20, %0" :: "m"(C[row+0][col+4])); + asm volatile ("flw f21, %0" :: "m"(C[row+0][col+5])); + asm volatile ("flw f22, %0" :: "m"(C[row+2][col+4])); + asm volatile ("flw f23, %0" :: "m"(C[row+2][col+5])); + """ + + col = ((threadgroup % 4) // 2) * 8 + row = (threadgroup * 8) % 16 + row += (threadgroup // 4) * 4 + offsets = [(0, 0), (0, 1), (2, 0), (2, 1), (0, 4), (0, 5), (2, 4), (2, 5)] + offset = offsets[register-16] + row += offset[0] + col += offset[1] + thread_offsets = [(0, 0), (1, 0), (0, 2), (1, 2)] + thread_offset = thread_offsets[thread % 4] + row += thread_offset[0] + col += thread_offset[1] + if C_array[row, col] != 0: + print("bad") + return (row, col) + + +with open(file) as f: + for line in f.readlines(): + line = line.strip() + if "warp" in line: + a, b, c = line.split(',') + _, a = a.split(' ') + _, b = b.strip().split(' ') + c, d = c.strip().split(':') + _, c = c.split(' ') + warp = int(a) + thread = int(b) + register = int(c) + value = d.strip() + + if warp != 0: + continue + if not (32 
<= register < 32+24): + continue + + register = register - 32 + + # threadgroups 0, 4, 1, 5 have all elements of A + threadgroup = thread // 4 + if threadgroup in [0, 4, 1, 5]: + row = [0, 4, 1, 5].index(threadgroup) * 4 + thread % 4 + if 0 <= register < 8: + A_array[row, register] = hex2float(value) + + if threadgroup in [0, 4, 2, 6]: + col = [0, 4, 2, 6].index(threadgroup) * 4 + thread % 4 + if 8 <= register < 16: + B_array[register-8, col] = hex2float(value) + + if 16 <= register < 24: + # print(value) + C_array[C_index(threadgroup, thread, register)] = hex2float(value) + + +expected = np.load("abc.npz") +# expected_A = expected['A_array'] +# expected_B = expected['B_array'] +# expected_C = expected['C_array'] +expected_A = expected['A_array'][0:8, 0:8] +expected_B = expected['B_array'][0:8, 0:8] +expected_C = expected['C_array'][0:8, 0:8] +expected_C = expected_C + expected_A @ expected_B +print('expected A:') +print(expected_A) +print('expected B:') +print(expected_B) +print('expected C:') +print(expected_C[0:8, 0:8]) +print('got C:') +print(C_array[0:8, 0:8]) +print('diff C:') +print(expected_C[0:8, 0:8] - C_array[0:8, 0:8]) + +expected_C.astype('float32').tofile("c_expected.bin") + +assert np.allclose(expected_A, A_array) +assert np.allclose(expected_B, B_array) +assert np.allclose(expected_C, C_array) diff --git a/tests/kernel/tensor/create_test_case.py b/tests/kernel/tensor/create_test_case.py new file mode 100644 index 00000000..35ad7d73 --- /dev/null +++ b/tests/kernel/tensor/create_test_case.py @@ -0,0 +1,32 @@ +import numpy as np +A_array = np.random.rand(16, 8) +B_array = np.random.rand(8, 16) +C_array = np.random.rand(16, 16) +# A_array = np.zeros((16, 8)) +# B_array = np.zeros((8, 16)) +# A_array[0,:] = 1.0 +# B_array[:,4] = 1.0 +# C_array = np.zeros((16, 16)) +# for i in range(16): +# for j in range(16): +# C_array[i,j] = i * 16 + j + +with open('a_matrix.h', 'w') as f: + for i in range(A_array.shape[0]): + for j in range(A_array.shape[1]): + 
f.write(f'{A_array[i,j]}f, ') + f.write('\n') + +with open('b_matrix.h', 'w') as f: + for i in range(B_array.shape[0]): + for j in range(B_array.shape[1]): + f.write(f'{B_array[i,j]}f, ') + f.write('\n') + +with open('c_matrix.h', 'w') as f: + for i in range(C_array.shape[0]): + for j in range(C_array.shape[1]): + f.write(f'{C_array[i,j]}f, ') + f.write('\n') + +np.savez("abc", A_array=A_array, B_array=B_array, C_array=C_array) \ No newline at end of file diff --git a/tests/kernel/tensor/main.cpp b/tests/kernel/tensor/main.cpp new file mode 100644 index 00000000..d90c38be --- /dev/null +++ b/tests/kernel/tensor/main.cpp @@ -0,0 +1,222 @@ +#define RISCV_CUSTOM3 0x7B + +#include +#include +#include +#include + +constexpr int DIM_M = 8; + +inline void vx_wmma() { + asm volatile (".insn r %0, 0, 0, x0, x0, x0" :: "i"(RISCV_CUSTOM3)); +} + +inline void vx_wmma_new() { + asm volatile (".insn r %0, 0, 0, x1, x0, x0" :: "i"(RISCV_CUSTOM3)); +} + +#include "test_data.h" + +inline constexpr void map_operand_32lanes(const int tid, int &row, int &col) { + const int tg = tid / 4; + + // A (row major) + // Figure 7(a) in paper + // row 0~ 3: threadgroups 0 and 2 + // row 4~ 7: threadgroups 4 and 6 + // row 8~11: threadgroups 1 and 3 + // row 12~15: threadgroups 5 and 7 + row = tid % 4; + row += (tg * 8) % 16; + row += (tg / 4) * 4; + + // B (column major) + // NOTE: Matrix B mapping in Figure 7(a) is incorrect; below is the + // corrected mapping: + // col 0~ 3: threadgroups 0 and 1 + // col 4~ 7: threadgroups 4 and 5 + // col 8~11: threadgroups 2 and 3 + // col 12~15: threadgroups 6 and 7 + col = tid % 4; + col += ((tg % 4) / 2) * 8; + col += (tg / 4) * 4; +} + +inline constexpr void map_operand_8lanes(const int tid, int &row, int &col) { + const int tg = tid / 4; + + // A (row major) + // row 0~ 3: threadgroup 0 + // row 4~ 7: threadgroup 1 + row = tid % 4; + row += tg * 4; + + // B (column major) + // col 0~ 3: threadgroup 0 + // col 4~ 7: threadgroup 1 + col = tid % 4; + col 
+= tg * 4; +} + +inline constexpr void map_c_32lanes(const int tid, int &row, int &col) { + const int tg = tid / 4; + + // C + // Figure 7(b), left + col = ((tg % 4) / 2) * 8; + row = (tg * 8) % 16; + row += (tg / 4) * 4; + + // Figure 7(b), right + row += (tid % 4) % 2; + col += ((tid % 4) / 2) * 2; +} + +inline constexpr void map_c_8lanes(const int tid, int &row, int &col) { + const int tg = tid / 4; + + // C + col = 0; + row = tg * 4; + + // Figure 7(b), right + row += (tid % 4) % 2; + col += ((tid % 4) / 2) * 2; +} + +void vx_wmma_load() { + int tid = vx_thread_id(); + int tg = tid / 4; + + int row = 0; + int col = 0; + + map_operand_8lanes(tid, row, col); + + // load A + // each operand element is read twice by two threadgroups (Sec. III-B); + // i.e. 8 regs * 32 lanes = 256 fp32 elements = 2 * (16 * 8) elements + asm volatile ("flw f0, %0" :: "m"(A[row][0])); + asm volatile ("flw f1, %0" :: "m"(A[row][1])); + asm volatile ("flw f2, %0" :: "m"(A[row][2])); + asm volatile ("flw f3, %0" :: "m"(A[row][3])); + asm volatile ("flw f4, %0" :: "m"(A[row][4])); + asm volatile ("flw f5, %0" :: "m"(A[row][5])); + asm volatile ("flw f6, %0" :: "m"(A[row][6])); + asm volatile ("flw f7, %0" :: "m"(A[row][7])); + + // load B + asm volatile ("flw f8 , %0" :: "m"(B[0][col])); + asm volatile ("flw f9 , %0" :: "m"(B[1][col])); + asm volatile ("flw f10, %0" :: "m"(B[2][col])); + asm volatile ("flw f11, %0" :: "m"(B[3][col])); + asm volatile ("flw f12, %0" :: "m"(B[4][col])); + asm volatile ("flw f13, %0" :: "m"(B[5][col])); + asm volatile ("flw f14, %0" :: "m"(B[6][col])); + asm volatile ("flw f15, %0" :: "m"(B[7][col])); + + map_c_8lanes(tid, row, col); + + // load C + asm volatile ("flw f16, %0" :: "m"(C[row+0][col+0])); + asm volatile ("flw f17, %0" :: "m"(C[row+0][col+1])); + asm volatile ("flw f18, %0" :: "m"(C[row+2][col+0])); + asm volatile ("flw f19, %0" :: "m"(C[row+2][col+1])); + asm volatile ("flw f20, %0" :: "m"(C[row+0][col+4])); + asm volatile ("flw f21, %0" :: 
"m"(C[row+0][col+5])); + asm volatile ("flw f22, %0" :: "m"(C[row+2][col+4])); + asm volatile ("flw f23, %0" :: "m"(C[row+2][col+5])); + asm volatile ("flw f24, %0" :: "m"(C[row+0][col+0])); + asm volatile ("flw f25, %0" :: "m"(C[row+0][col+1])); + asm volatile ("flw f26, %0" :: "m"(C[row+2][col+0])); + asm volatile ("flw f27, %0" :: "m"(C[row+2][col+1])); + asm volatile ("flw f28, %0" :: "m"(C[row+0][col+4])); + asm volatile ("flw f29, %0" :: "m"(C[row+0][col+5])); + asm volatile ("flw f30, %0" :: "m"(C[row+2][col+4])); + asm volatile ("flw f31, %0" :: "m"(C[row+2][col+5])); +} + +// float results[32*8]; +float *const results = reinterpret_cast(0xc0000000UL); + +void store_wmma_result() { + int wid = vx_warp_id(); + int tid = vx_thread_id(); + int tg = tid / 4; + + int row = 0; + int col = 0; + + map_c_8lanes(tid, row, col); + + // store C + // asm volatile ("fsw f16, %0" :: "m"(results[tid*8+0])); + // asm volatile ("fsw f17, %0" :: "m"(results[tid*8+1])); + // asm volatile ("fsw f18, %0" :: "m"(results[tid*8+2])); + // asm volatile ("fsw f19, %0" :: "m"(results[tid*8+3])); + // asm volatile ("fsw f20, %0" :: "m"(results[tid*8+4])); + // asm volatile ("fsw f21, %0" :: "m"(results[tid*8+5])); + // asm volatile ("fsw f22, %0" :: "m"(results[tid*8+6])); + // asm volatile ("fsw f23, %0" :: "m"(results[tid*8+7])); + + float *const results_wid = results + (DIM_M * DIM_M * wid); + + // asm volatile("fsw f16, %0" ::"m"(results_wid[DIM_M * (row + 0) + (col + 0)])); + // asm volatile("fsw f17, %0" ::"m"(results_wid[DIM_M * (row + 0) + (col + 1)])); + // asm volatile("fsw f18, %0" ::"m"(results_wid[DIM_M * (row + 2) + (col + 0)])); + // asm volatile("fsw f19, %0" ::"m"(results_wid[DIM_M * (row + 2) + (col + 1)])); + // asm volatile("fsw f20, %0" ::"m"(results_wid[DIM_M * (row + 0) + (col + 4)])); + // asm volatile("fsw f21, %0" ::"m"(results_wid[DIM_M * (row + 0) + (col + 5)])); + // asm volatile("fsw f22, %0" ::"m"(results_wid[DIM_M * (row + 2) + (col + 4)])); + // asm 
volatile("fsw f23, %0" ::"m"(results_wid[DIM_M * (row + 2) + (col + 5)])); + asm volatile("fsw f24, %0" ::"m"(results_wid[DIM_M * (row + 0) + (col + 0)])); + asm volatile("fsw f25, %0" ::"m"(results_wid[DIM_M * (row + 0) + (col + 1)])); + asm volatile("fsw f26, %0" ::"m"(results_wid[DIM_M * (row + 2) + (col + 0)])); + asm volatile("fsw f27, %0" ::"m"(results_wid[DIM_M * (row + 2) + (col + 1)])); + asm volatile("fsw f28, %0" ::"m"(results_wid[DIM_M * (row + 0) + (col + 4)])); + asm volatile("fsw f29, %0" ::"m"(results_wid[DIM_M * (row + 0) + (col + 5)])); + asm volatile("fsw f30, %0" ::"m"(results_wid[DIM_M * (row + 2) + (col + 4)])); + asm volatile("fsw f31, %0" ::"m"(results_wid[DIM_M * (row + 2) + (col + 5)])); +} + +void print_wmma_result() { + const int num_threads = vx_num_threads(); + + for (int tid = 0; tid < num_threads; tid += 1) { + for (int reg = 0; reg < 8; reg += 1) { + vx_printf("thread %d, f%d: %x\n", tid, 16 + reg, + *((int *)&results[tid * 8 + reg])); + } + } +} + +void wmma() { + vx_tmc(-1); + + // if (vx_warp_id() == 1) { + // for (int i = 0; i < 100; i++) { + // asm volatile ("nop"); + // } + // } + + vx_wmma_load(); + // #pragma GCC unroll 100 + // for (int i = 0; i < 100; i++) { + // vx_wmma(); + // } + vx_wmma_new(); + + store_wmma_result(); + // print_wmma_result(); + vx_tmc(1); +} + +int main() { + const int num_warps = vx_num_warps(); + + vx_wspawn(num_warps, wmma); + wmma(); + vx_wspawn_wait(); + + return 0; +} diff --git a/tests/kernel/tensor/test_data.h b/tests/kernel/tensor/test_data.h new file mode 100644 index 00000000..83b05157 --- /dev/null +++ b/tests/kernel/tensor/test_data.h @@ -0,0 +1,11 @@ +float A[16][8] = { + #include "a_matrix.h" +}; + +float B[8][16] = { + #include "b_matrix.h" +}; + +float C[16][16] = { + #include "c_matrix.h" +}; \ No newline at end of file diff --git a/tests/regression/common.mk b/tests/regression/common.mk index 09a57795..6c55d435 100644 --- a/tests/regression/common.mk +++ b/tests/regression/common.mk 
@@ -49,7 +49,7 @@ VX_CP = $(LLVM_VORTEX)/bin/llvm-objcopy #VX_CP = $(RISCV_TOOLCHAIN_PATH)/bin/$(RISCV_PREFIX)-objcopy VX_CFLAGS += -v -O3 -std=c++17 -VX_CFLAGS += -mcmodel=medany -fno-rtti -fno-exceptions -nostartfiles -fdata-sections -ffunction-sections +VX_CFLAGS += -mcmodel=medany -fno-rtti -fno-exceptions -nostartfiles -fdata-sections -ffunction-sections -mllvm -inline-threshold=8192 VX_CFLAGS += -I$(VORTEX_KN_PATH)/include -I$(VORTEX_KN_PATH)/../hw -I$(GEMMINI_SW_PATH) VX_CFLAGS += -DNDEBUG -DLLVM_VORTEX @@ -83,7 +83,7 @@ endif # CONFIG is supplied from the command line to differentiate ELF files with custom suffixes CONFIGEXT = $(if $(CONFIG),.$(CONFIG),) -all: $(PROJECT) kernel.bin kernel.dump kernel.radiance.dump kernel.radiance$(CONFIGEXT).dump +all: $(PROJECT) kernel.bin kernel.dump kernel.radiance.dump kernel$(CONFIGEXT).dump kernel.radiance$(CONFIGEXT).dump kernel.dump: kernel.elf $(VX_DP) -D kernel.elf > kernel.dump @@ -92,6 +92,9 @@ kernel.radiance.dump: kernel.radiance.elf $(VX_DP) -D kernel.radiance.elf > kernel.radiance.dump ifneq ($(CONFIG),) +kernel$(CONFIGEXT).dump: kernel$(CONFIGEXT).elf + $(VX_DP) -D kernel$(CONFIGEXT).elf > kernel$(CONFIGEXT).dump + kernel.radiance$(CONFIGEXT).dump: kernel.radiance$(CONFIGEXT).elf $(VX_DP) -D kernel.radiance$(CONFIGEXT).elf > kernel.radiance$(CONFIGEXT).dump endif @@ -99,19 +102,30 @@ endif kernel.bin: kernel.elf kernel.radiance.elf $(VX_CP) -O binary kernel.elf kernel.bin -kernel.elf: $(VX_SRCS) - $(VX_CXX) $(VX_CFLAGS) $(VX_SRCS) $(VX_LDFLAGS) -o kernel.elf - -OBJCOPY ?= "riscv32-unknown-elf-objcopy" +OBJCOPY ?= $(RISCV_TOOLCHAIN_PATH)/bin/$(RISCV_PREFIX)-objcopy OBJCOPY_FLAGS ?= "LOAD,ALLOC,DATA,CONTENTS" -kernel.radiance.elf: kernel.elf - $(VX_CXX) $(VX_CFLAGS) $(VX_SRCS) $(VX_LDFLAGS) -DRADIANCE -o kernel.radiance.elf - $(OBJCOPY) --set-section-flags .operand.a=$(OBJCOPY_FLAGS) kernel.radiance.elf - $(OBJCOPY) --set-section-flags .operand.b=$(OBJCOPY_FLAGS) kernel.radiance.elf - $(OBJCOPY) 
--update-section .operand.a=input.a.bin kernel.radiance.elf - $(OBJCOPY) --update-section .operand.b=input.b.bin kernel.radiance.elf +kernel.elf: $(VX_SRCS) + $(VX_CXX) $(VX_CFLAGS) $(VX_SRCS) $(VX_LDFLAGS) -o $@ + $(OBJCOPY) --set-section-flags .operand.a=$(OBJCOPY_FLAGS) $@ + $(OBJCOPY) --set-section-flags .operand.b=$(OBJCOPY_FLAGS) $@ + $(OBJCOPY) --set-section-flags .args=$(OBJCOPY_FLAGS) $@ + $(OBJCOPY) --update-section .operand.a=input.a.bin $@ + $(OBJCOPY) --update-section .operand.b=input.b.bin $@ + $(OBJCOPY) --update-section .args=args.bin $@ + +kernel.radiance.elf: $(VX_SRCS) + $(VX_CXX) $(VX_CFLAGS) $(VX_SRCS) $(VX_LDFLAGS) -DRADIANCE -o $@ + $(OBJCOPY) --set-section-flags .operand.a=$(OBJCOPY_FLAGS) $@ + $(OBJCOPY) --set-section-flags .operand.b=$(OBJCOPY_FLAGS) $@ + $(OBJCOPY) --set-section-flags .args=$(OBJCOPY_FLAGS) $@ + $(OBJCOPY) --update-section .operand.a=input.a.bin $@ + $(OBJCOPY) --update-section .operand.b=input.b.bin $@ + $(OBJCOPY) --update-section .args=args.bin $@ ifneq ($(CONFIG),) +kernel$(CONFIGEXT).elf: kernel.elf + cp $< $@ + kernel.radiance$(CONFIGEXT).elf: kernel.radiance.elf cp $< $@ endif diff --git a/tests/regression/sgemm_gemmini/common.h b/tests/regression/sgemm_gemmini/common.h index 74941562..5c84f3b7 100644 --- a/tests/regression/sgemm_gemmini/common.h +++ b/tests/regression/sgemm_gemmini/common.h @@ -3,7 +3,7 @@ #include -#define KERNEL_ARG_DEV_MEM_ADDR 0x7fff0000 +#define KERNEL_ARG_DEV_MEM_ADDR 0x9fff0000 #define DEV_SMEM_START_ADDR 0xff000000 typedef struct { diff --git a/tests/regression/sgemm_gemmini_dma/.gitignore b/tests/regression/sgemm_gemmini_dma/.gitignore index 7c35ba59..66b3e811 100644 --- a/tests/regression/sgemm_gemmini_dma/.gitignore +++ b/tests/regression/sgemm_gemmini_dma/.gitignore @@ -1,5 +1,5 @@ *.bin *.dump *.elf -sgemm_wg +sgemm_gemmini_dma .depend diff --git a/tests/regression/sgemm_gemmini_dma/common.h b/tests/regression/sgemm_gemmini_dma/common.h index 74941562..5c84f3b7 100644 --- 
a/tests/regression/sgemm_gemmini_dma/common.h +++ b/tests/regression/sgemm_gemmini_dma/common.h @@ -3,7 +3,7 @@ #include -#define KERNEL_ARG_DEV_MEM_ADDR 0x7fff0000 +#define KERNEL_ARG_DEV_MEM_ADDR 0x9fff0000 #define DEV_SMEM_START_ADDR 0xff000000 typedef struct { diff --git a/tests/regression/sgemm_gemmini_dma/kernel.cpp b/tests/regression/sgemm_gemmini_dma/kernel.cpp index 02c99077..049d1970 100644 --- a/tests/regression/sgemm_gemmini_dma/kernel.cpp +++ b/tests/regression/sgemm_gemmini_dma/kernel.cpp @@ -51,7 +51,7 @@ inline void threadblock_barrier(unsigned int barrier_id, unsigned int count) { void thread_block_matmul_gemmini(kernel_arg_t *__UNIFORM__ arg, const uint32_t threadblock_id, const uint32_t tid_in_threadblock) { - __asm__("matmul_start:"); + asm volatile ("matmul_start_%=:" :: ); const float * const A = (const float * const) arg->addr_a; const float * const B = (const float * const) arg->addr_b; float * const C = (float * const) arg->addr_c; @@ -178,4 +178,4 @@ int main() { vx_spawn_tasks_contiguous(grid_size, (vx_spawn_tasks_cb)kernel_body, arg); #endif return 0; -} \ No newline at end of file +} diff --git a/tests/regression/sgemm_gemmini_dma/main.cpp b/tests/regression/sgemm_gemmini_dma/main.cpp index 54531062..45548d91 100644 --- a/tests/regression/sgemm_gemmini_dma/main.cpp +++ b/tests/regression/sgemm_gemmini_dma/main.cpp @@ -193,8 +193,8 @@ int main(int argc, char *argv[]) { { std::cout << "upload kernel argument" << std::endl; auto buf_ptr = staging_buf.data(); - kernel_arg.addr_a = (uint64_t) 0x20000; - kernel_arg.addr_b = (uint64_t) 0x28000; + kernel_arg.addr_a = (uint64_t) 0xa0000000ULL; + kernel_arg.addr_b = (uint64_t) 0xa1000000ULL; kernel_arg.addr_c = (uint64_t) 0xc0000000ULL; memcpy(buf_ptr, &kernel_arg, sizeof(kernel_arg_t)); diff --git a/tests/regression/sgemm_gemmini_dma/sgemm_gemmini_dma b/tests/regression/sgemm_gemmini_dma/sgemm_gemmini_dma deleted file mode 100755 index 67ade61b..00000000 Binary files 
a/tests/regression/sgemm_gemmini_dma/sgemm_gemmini_dma and /dev/null differ diff --git a/tests/regression/sgemm_tcore/.gitignore b/tests/regression/sgemm_tcore/.gitignore new file mode 100644 index 00000000..6ef379cc --- /dev/null +++ b/tests/regression/sgemm_tcore/.gitignore @@ -0,0 +1 @@ +sgemm_tcore diff --git a/tests/regression/sgemm_tcore/Makefile b/tests/regression/sgemm_tcore/Makefile new file mode 100644 index 00000000..0c378af0 --- /dev/null +++ b/tests/regression/sgemm_tcore/Makefile @@ -0,0 +1,9 @@ +PROJECT = sgemm_tcore + +SRCS = main.cpp common.h + +VX_SRCS = kernel.cpp + +OPTS ?= -n16 + +include ../common.mk \ No newline at end of file diff --git a/tests/regression/sgemm_tcore/common.h b/tests/regression/sgemm_tcore/common.h new file mode 100644 index 00000000..5c84f3b7 --- /dev/null +++ b/tests/regression/sgemm_tcore/common.h @@ -0,0 +1,18 @@ +#ifndef _COMMON_H_ +#define _COMMON_H_ + +#include + +#define KERNEL_ARG_DEV_MEM_ADDR 0x9fff0000 +#define DEV_SMEM_START_ADDR 0xff000000 + +typedef struct { + uint32_t dim_m; + uint32_t dim_n; + uint32_t dim_k; + uint64_t addr_a; + uint64_t addr_b; + uint64_t addr_c; +} kernel_arg_t; + +#endif diff --git a/tests/regression/sgemm_tcore/kernel.4warps.cpp b/tests/regression/sgemm_tcore/kernel.4warps.cpp new file mode 100644 index 00000000..f498f57b --- /dev/null +++ b/tests/regression/sgemm_tcore/kernel.4warps.cpp @@ -0,0 +1,333 @@ +#define RISCV_CUSTOM3 0x7B + +#include +#include +#include +#include +#include "common.h" + +#define BM 16 +#define BN 16 +#define BK 8 + +inline constexpr void map_operand_32lanes(const int tid, int &row, int &col) { + const int tg = tid / 4; + + // A (row major) + // Figure 7(a) in paper + // row 0~ 3: threadgroups 0 and 2 + // row 4~ 7: threadgroups 4 and 6 + // row 8~11: threadgroups 1 and 3 + // row 12~15: threadgroups 5 and 7 + row = tid % 4; + row += (tg * 8) % 16; + row += (tg / 4) * 4; + + // B (column major) + // NOTE: Matrix B mapping in Figure 7(a) is incorrect; below is 
the + // corrected mapping: + // col 0~ 3: threadgroups 0 and 1 + // col 4~ 7: threadgroups 4 and 5 + // col 8~11: threadgroups 2 and 3 + // col 12~15: threadgroups 6 and 7 + col = tid % 4; + col += ((tg % 4) / 2) * 8; + col += (tg / 4) * 4; +} + +inline constexpr void map_operand_8lanes(const int tid, int &row, int &col) { + const int tg = tid / 4; + + // A (row major) + // row 0~ 3: threadgroup 0 + // row 4~ 7: threadgroup 1 + row = tid % 4; + row += tg * 4; + + // B (column major) + // col 0~ 3: threadgroup 0 + // col 4~ 7: threadgroup 1 + col = tid % 4; + col += tg * 4; +} + +inline constexpr void map_c_32lanes(const int tid, int &row, int &col) { + const int tg = tid / 4; + + // C + // Figure 7(b), left + col = ((tg % 4) / 2) * 8; + row = (tg * 8) % 16; + row += (tg / 4) * 4; + + // Figure 7(b), right + row += (tid % 4) % 2; + col += ((tid % 4) / 2) * 2; +} + +inline constexpr void map_c_8lanes(const int tid, int &row, int &col) { + const int tg = tid / 4; + + // C + col = 0; + row = tg * 4; + + // Figure 7(b), right + row += (tid % 4) % 2; + col += ((tid % 4) / 2) * 2; +} + +inline void vx_wmma() { + asm volatile (".insn r %0, 0, 0, x0, x0, x0" :: "i"(RISCV_CUSTOM3)); +} + +void vx_wmma_load(volatile float *smem_A, volatile float *smem_B, int warp_x, + int warp_y, int thread_in_warp) { + int tid = thread_in_warp; + int tg = tid / 4; + + int row = 0; + int col = 0; + map_operand_32lanes(tid, row, col); + + int smem_A_m = 32; + int smem_A_n = 8; + int smem_B_m = 8; + int smem_B_n = 32; + + int A_offset = (row + BM * warp_y) * smem_A_n; + + asm volatile("flw f0, %0" ::"m"(smem_A[A_offset + 0])); + asm volatile("flw f1, %0" ::"m"(smem_A[A_offset + 1])); + asm volatile("flw f2, %0" ::"m"(smem_A[A_offset + 2])); + asm volatile("flw f3, %0" ::"m"(smem_A[A_offset + 3])); + asm volatile("flw f4, %0" ::"m"(smem_A[A_offset + 4])); + asm volatile("flw f5, %0" ::"m"(smem_A[A_offset + 5])); + asm volatile("flw f6, %0" ::"m"(smem_A[A_offset + 6])); + asm volatile("flw f7, 
%0" ::"m"(smem_A[A_offset + 7])); + + asm volatile("flw f8 , %0" ::"m"(smem_B[(0 * smem_B_n) + warp_x * BN + col])); + asm volatile("flw f9 , %0" ::"m"(smem_B[(1 * smem_B_n) + warp_x * BN + col])); + asm volatile("flw f10, %0" ::"m"(smem_B[(2 * smem_B_n) + warp_x * BN + col])); + asm volatile("flw f11, %0" ::"m"(smem_B[(3 * smem_B_n) + warp_x * BN + col])); + asm volatile("flw f12, %0" ::"m"(smem_B[(4 * smem_B_n) + warp_x * BN + col])); + asm volatile("flw f13, %0" ::"m"(smem_B[(5 * smem_B_n) + warp_x * BN + col])); + asm volatile("flw f14, %0" ::"m"(smem_B[(6 * smem_B_n) + warp_x * BN + col])); + asm volatile("flw f15, %0" ::"m"(smem_B[(7 * smem_B_n) + warp_x * BN + col])); +} + +inline void initialize_C() { + // initialize C to zeros + asm volatile("fmv.w.x f16, x0"); + asm volatile("fmv.w.x f17, x0"); + asm volatile("fmv.w.x f18, x0"); + asm volatile("fmv.w.x f19, x0"); + asm volatile("fmv.w.x f20, x0"); + asm volatile("fmv.w.x f21, x0"); + asm volatile("fmv.w.x f22, x0"); + asm volatile("fmv.w.x f23, x0"); +} + +inline void write_results(volatile float *local_warp_results, + int thread_in_warp, int warp_x, int warp_y, int dim_m, + int dim_n, float *C, int threadblock_id_x, + int threadblock_id_y) { + int tid = thread_in_warp; + int tg = tid / 4; + + asm volatile("fsw f16, %0" ::"m"(local_warp_results[tid * 8 + 0])); + asm volatile("fsw f17, %0" ::"m"(local_warp_results[tid * 8 + 1])); + asm volatile("fsw f18, %0" ::"m"(local_warp_results[tid * 8 + 2])); + asm volatile("fsw f19, %0" ::"m"(local_warp_results[tid * 8 + 3])); + asm volatile("fsw f20, %0" ::"m"(local_warp_results[tid * 8 + 4])); + asm volatile("fsw f21, %0" ::"m"(local_warp_results[tid * 8 + 5])); + asm volatile("fsw f22, %0" ::"m"(local_warp_results[tid * 8 + 6])); + asm volatile("fsw f23, %0" ::"m"(local_warp_results[tid * 8 + 7])); + + /* + col = ((threadgroup % 4) // 2) * 8 + row = (threadgroup * 8) % 16 + row += (threadgroup // 4) * 4 + offsets = [(0, 0), (0, 1), (2, 0), (2, 1), (0, 4), (0, 5), 
(2, 4), (2, 5)] + offset = offsets[register-16] + row += offset[0] + col += offset[1] + thread_offsets = [(0, 0), (1, 0), (0, 2), (1, 2)] + thread_offset = thread_offsets[thread % 4] + row += thread_offset[0] + col += thread_offset[1] + return (row, col) + */ + + int local_row = 0; + int local_col = 0; + map_c_32lanes(tid, local_row, local_col); + + float *global_offset_C = C + + (threadblock_id_y * BM * 2 + warp_y * BM) * dim_n + + threadblock_id_x * BN * 2 + warp_x * BN; // columns advance in units of BN, rows in units of BM + for (int i = 0; i < 8; i += 1) { + int row_offset = ((i / 2) % 2) * 2; + int col_offset = (i / 4) * 4 + i % 2; + + int adjusted_local_row = local_row + row_offset; + int adjusted_local_col = local_col + col_offset; + + float v = local_warp_results[tid * 8 + i]; + global_offset_C[adjusted_local_row * dim_n + adjusted_local_col] = v; + } +} + +void threadblock_barrier(unsigned int tid_in_threadblock, unsigned int barrier_id, unsigned int count) { + vx_fence(); + vx_barrier(barrier_id, count); +} + +void thread_block_gemm(kernel_arg_t *__UNIFORM__ arg, + const uint32_t tid_in_threadblock, + const uint32_t threadblock_dim_x, + const uint32_t threadblock_dim_y, + const uint32_t threadblock_id_x, + const uint32_t threadblock_id_y, + const uint32_t threadblock_id, + float *sharedmem_per_threadblock) { + const float *A = (const float *)arg->addr_a; + const float *B = (const float *)arg->addr_b; + float *C = (float *)arg->addr_c; + + const uint32_t dim_m = arg->dim_m; + const uint32_t dim_n = arg->dim_n; + const uint32_t dim_k = arg->dim_k; + + // FIXME: Output block size is assumed to be square, i.e.
BM == BN + // const uint32_t BM = threadblock_dim_y; + // const uint32_t BN = threadblock_dim_y; + // const uint32_t BK = threadblock_dim_x; + // constexpr uint32_t BM = 8; + // constexpr uint32_t BN = 8; + // constexpr uint32_t BK = 2; + + const uint32_t warp_in_threadblock = tid_in_threadblock / 32; + const uint32_t tid_in_warp = tid_in_threadblock % 32; + const uint32_t warp_x = warp_in_threadblock % 2; + const uint32_t warp_y = warp_in_threadblock / 2; + + const uint32_t global_a_row = threadblock_dim_y * threadblock_id_y; + + // 32 * 8 block of A, being loaded by 4 warps + const uint32_t local_a_row = warp_y * BM + warp_x * (BM / 2) + (tid_in_warp / BK); + const uint32_t local_a_col = tid_in_warp % BK; + + // 8 * 32 block of B, being loaded by 4 warps + // do a fat coalesced load + const uint32_t global_b_col = threadblock_dim_x * threadblock_id_x; + const uint32_t local_b_row = warp_in_threadblock; + const uint32_t local_b_col = tid_in_warp; + + volatile float *local_a = sharedmem_per_threadblock; + const size_t local_a_elems = (threadblock_dim_y * BK); + volatile float *local_b = sharedmem_per_threadblock + local_a_elems; + const size_t local_b_elems = (threadblock_dim_x * BK); + volatile float *local_warp_results = local_b + local_b_elems + (warp_in_threadblock * BM * BN); + + // clear out C + initialize_C(); + + for (uint32_t k = 0; k < dim_k; k += BK) { + // Data move from GMEM to SMEM + // + // Make sure global offset values for A and B are contiguous between + // neighboring threads to ensure GMEM coalescing. 
(not possible for A here, but for B it's doable) + + // coalesced load from global memory -> unit-stride store into shared memory + uint32_t global_a_offset = + dim_k * (global_a_row + local_a_row) + (k + local_a_col); + uint32_t shared_a_offset = + BK * local_a_row + local_a_col; + + local_a[shared_a_offset] = A[global_a_offset]; + + global_a_offset += dim_k * (BM / 4); + shared_a_offset += BK * (BM / 4); + + local_a[shared_a_offset] = A[global_a_offset]; + + uint32_t global_b_offset = + dim_n * (k + local_b_row) + (global_b_col + local_b_col); + uint32_t shared_b_offset = + (BN * 2) * (local_b_row) + local_b_col; + + local_b[shared_b_offset] = B[global_b_offset]; + + global_b_offset += dim_n * (BK / 2); + shared_b_offset += (BN * 2) * (BK / 2); + + local_b[shared_b_offset] = B[global_b_offset]; + + // want all 4 warps to reach barrier before moving on (just use barrier 0 for now) + threadblock_barrier(tid_in_threadblock, 0, 4); + + // perform wmma + vx_wmma_load(local_a, local_b, warp_x, warp_y, tid_in_warp); + vx_wmma(); + + // same as above + threadblock_barrier(tid_in_threadblock, 0, 4); + } + + write_results( + local_warp_results, + tid_in_warp, + warp_x, + warp_y, + dim_m, + dim_n, + C, + threadblock_id_x, + threadblock_id_y + ); +} + +void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { + // @perf: All threads are running these compute whose result is mostly same + // across the threadblock + const int NT = 32; // vx_num_threads(); + const int NW = 4; // vx_num_warps(); + const uint32_t threads_per_threadblock = NT * NW; + + // matches 4 warp capacity + const uint32_t threadblock_dim_x = 2 * BN; + const uint32_t threadblock_dim_y = 2 * BM; + const int threadblock_id = task_id / threads_per_threadblock; + const int tid_in_threadblock = task_id % threads_per_threadblock; + + const uint32_t dim_m = arg->dim_m; + const uint32_t dim_n = arg->dim_n; + const uint32_t dim_n_in_blocks = dim_n / threadblock_dim_x; + const int threadblock_id_x = 
threadblock_id % dim_n_in_blocks; + const int threadblock_id_y = threadblock_id / dim_n_in_blocks; + + // "static" shared memory allocation. This would determine threadblock + // occupancy of a single cluster + // only 1 threadblock running at a time, so this is ok + float *sharedmem_per_threadblock = + (float *)DEV_SMEM_START_ADDR; // + (2 * BM * BK) + (2 * BN * BK) * threadblock_id; + + thread_block_gemm(arg, tid_in_threadblock, threadblock_dim_x, + threadblock_dim_y, threadblock_id_x, threadblock_id_y, + threadblock_id, sharedmem_per_threadblock); +} + +int main() { + kernel_arg_t *arg = (kernel_arg_t *)KERNEL_ARG_DEV_MEM_ADDR; + int NT = vx_num_threads(); + + // TODO: add support for edge-case (m, n not divisible by 16) + const uint32_t grid_size = arg->dim_m * arg->dim_n * NT / (BM * BN); + + // for now, simplifying assumption of just 1 core + // vx_spawn_tasks_contiguous first runs warps 1 through NW, then NW+1 through 2*NW, etc. + // we can thus treat 1 through NW as a single threadblock for the purposes of the optimization. + vx_spawn_tasks_contiguous(grid_size, (vx_spawn_tasks_cb)kernel_body, arg); + return 0; +} diff --git a/tests/regression/sgemm_tcore/kernel.cpp b/tests/regression/sgemm_tcore/kernel.cpp new file mode 100644 index 00000000..a56203a4 --- /dev/null +++ b/tests/regression/sgemm_tcore/kernel.cpp @@ -0,0 +1,755 @@ +#define RISCV_CUSTOM3 0x7B + +#include +#include +#include +#include +#include "common.h" + +#define NUM_LANES 8 + +// number of loop around the inner 0..TCK..BK loop to simulate perfect-DRAM +// scenario +#define BK_LOOP 1 +#define TRANSPOSE_AS 1 +// GMEM_COALESCED sets bank conflict-free accesses for +// 1: GMEM loads of A matrix +// 0: SMEM stores of A matrix +#define GMEM_COALESCED_A 1 + +#define DOUBLE_BUFFER 1 + +// Constraints on parameters: +// * Memory: +// (BM + BN) * BK * sizeof(float) <= sharedmem size. 
+// BM * BK == BN * BK >= threadblock size >= NT * CORES_PER_CLUSTER +// When larger, the kernel runs a sequential loop to read into sharedmem; +// but smaller case is not handled. +// * Compute: +// ( M* N) / (TM*TN) == grid size >= NC*NW*NT +// (BM*BN) / (TM*TN) == threadblock size < NT * NW * CORES_PER_CLUSTER +// (BM*BN) / (TM*TN) == threadblock size >= NT * CORES_PER_CLUSTER +// * Combining BM * BK >= (BM*BN) / (TM*TN) == threadblock yields +// BM <= BK*TM*TN +#define BM 32 +#define BN 32 +#define BK 32 +#define WM 16 +#define WN 8 +#define TCM 8 +#define TCN 8 +#define TCK 8 +#define WMITER (WM / TCM) +#define WNITER (WN / TCN) +#define ELEM_PER_THREAD (WMITER * WNITER * ((TCM * TCN) / NUM_LANES) / (DOUBLE_BUFFER ? 2 : 1)) + +// FIXME: NUM_THREADS and NUM_WARPS hardcoded +#if ((BM * BN / ELEM_PER_THREAD) > (CORES_PER_CLUSTER * 8 * 8)) +#error "threadblock size too big for cluster" +#endif + +inline constexpr void map_operand_32lanes(const int tid, int &row, int &col) { + const int tg = tid / 4; + + // A (row major) + // Figure 7(a) in paper + // row 0~ 3: threadgroups 0 and 2 + // row 4~ 7: threadgroups 4 and 6 + // row 8~11: threadgroups 1 and 3 + // row 12~15: threadgroups 5 and 7 + row = tid % 4; + row += (tg * 8) % 16; + row += (tg / 4) * 4; + + // B (column major) + // NOTE: Matrix B mapping in Figure 7(a) is incorrect; below is the + // corrected mapping: + // col 0~ 3: threadgroups 0 and 1 + // col 4~ 7: threadgroups 4 and 5 + // col 8~11: threadgroups 2 and 3 + // col 12~15: threadgroups 6 and 7 + col = tid % 4; + col += ((tg % 4) / 2) * 8; + col += (tg / 4) * 4; +} + +inline constexpr void map_operand_8lanes(const int tid, int &row, int &col) { + const int tg = tid / 4; + + // A (row major) + // row 0~ 3: threadgroup 0 + // row 4~ 7: threadgroup 1 + row = tid % 4; + row += tg * 4; + + // B (column major) + // col 0~ 3: threadgroup 0 + // col 4~ 7: threadgroup 1 + col = tid % 4; + col += tg * 4; +} + +inline constexpr void map_operand(const int tid, 
int &row, int &col) { + if constexpr (NUM_LANES == 32) { + map_operand_32lanes(tid, row, col); + } else if constexpr (NUM_LANES == 8) { + map_operand_8lanes(tid, row, col); + } else { + // FIXME: not allowed + } +} + +inline constexpr void map_c_32lanes(const int tid, int &row, int &col) { + const int tg = tid / 4; + + // C + // Figure 7(b), left + col = ((tg % 4) / 2) * 8; + row = (tg * 8) % 16; + row += (tg / 4) * 4; + + // Figure 7(b), right + row += (tid % 4) % 2; + col += ((tid % 4) / 2) * 2; +} + +inline constexpr void map_c_8lanes(const int tid, int &row, int &col) { + const int tg = tid / 4; + + // C + col = 0; + row = tg * 4; + + // Figure 7(b), right + row += (tid % 4) % 2; + col += ((tid % 4) / 2) * 2; +} + +inline constexpr void map_c(const int tid, int &row, int &col) { + if constexpr (NUM_LANES == 32) { + map_c_32lanes(tid, row, col); + } else if constexpr (NUM_LANES == 8) { + map_c_8lanes(tid, row, col); + } else { + // FIXME: not allowed + } +} + +inline void vx_wmma(const int dest_reg) { + if (dest_reg == 0) { + asm volatile (".insn r %0, 0, 0, x0, x0, x0" :: "i"(RISCV_CUSTOM3)); + } else { + asm volatile (".insn r %0, 0, 0, x1, x0, x0" :: "i"(RISCV_CUSTOM3)); + } +} + +// `local_k` is assumed to be multiple of TCK +inline void vx_wmma_load_a(volatile float *smem_A, const int local_k, + const int warp_row, const int wm_iter, const int thread_in_warp) { + const int tid = thread_in_warp; + const int tg = tid / 4; + + // TODO: this is duplicately computed between vx_wmma_load_a and vx_wmma_load_b + int row = 0; + int col = 0; + map_operand(tid, row, col); + + constexpr int smem_A_rows = BM; + constexpr int smem_A_cols = BK; + constexpr int smem_AS_rows = BK; + constexpr int smem_AS_cols = BM; + + if constexpr (!TRANSPOSE_AS) { + int A_offset = (WM * warp_row + TCM * wm_iter + row) * smem_A_cols; + + // @perf: bank conflicts + // f0-f7 stores a single row of A + asm volatile("flw f0, %0" ::"m"(smem_A[A_offset + (local_k + 0)])); + asm volatile("flw
f1, %0" ::"m"(smem_A[A_offset + (local_k + 1)])); + asm volatile("flw f2, %0" ::"m"(smem_A[A_offset + (local_k + 2)])); + asm volatile("flw f3, %0" ::"m"(smem_A[A_offset + (local_k + 3)])); + asm volatile("flw f4, %0" ::"m"(smem_A[A_offset + (local_k + 4)])); + asm volatile("flw f5, %0" ::"m"(smem_A[A_offset + (local_k + 5)])); + asm volatile("flw f6, %0" ::"m"(smem_A[A_offset + (local_k + 6)])); + asm volatile("flw f7, %0" ::"m"(smem_A[A_offset + (local_k + 7)])); + } else { + // transposed A + // f0-f7 stores a single row of A + volatile float *smem_addr; + smem_addr = &smem_A[((local_k + 0) * smem_AS_cols) + (WM * warp_row + TCM * wm_iter) + row]; + asm volatile("flw f0, %0(%1)" :: "i"(smem_AS_cols * 0 * sizeof(float)), "r"(smem_addr)); + asm volatile("flw f1, %0(%1)" :: "i"(smem_AS_cols * 1 * sizeof(float)), "r"(smem_addr)); + asm volatile("flw f2, %0(%1)" :: "i"(smem_AS_cols * 2 * sizeof(float)), "r"(smem_addr)); + asm volatile("flw f3, %0(%1)" :: "i"(smem_AS_cols * 3 * sizeof(float)), "r"(smem_addr)); + asm volatile("flw f4, %0(%1)" :: "i"(smem_AS_cols * 4 * sizeof(float)), "r"(smem_addr)); + asm volatile("flw f5, %0(%1)" :: "i"(smem_AS_cols * 5 * sizeof(float)), "r"(smem_addr)); + asm volatile("flw f6, %0(%1)" :: "i"(smem_AS_cols * 6 * sizeof(float)), "r"(smem_addr)); + asm volatile("flw f7, %0(%1)" :: "i"(smem_AS_cols * 7 * sizeof(float)), "r"(smem_addr)); + + // asm volatile("flw f0, %0" ::"m"(smem_A[((local_k + 0) * smem_AS_cols) + (WM * warp_row + TCM * wm_iter) + row])); + // asm volatile("flw f1, %0" ::"m"(smem_A[((local_k + 1) * smem_AS_cols) + (WM * warp_row + TCM * wm_iter) + row])); + // asm volatile("flw f2, %0" ::"m"(smem_A[((local_k + 2) * smem_AS_cols) + (WM * warp_row + TCM * wm_iter) + row])); + // asm volatile("flw f3, %0" ::"m"(smem_A[((local_k + 3) * smem_AS_cols) + (WM * warp_row + TCM * wm_iter) + row])); + // asm volatile("flw f4, %0" ::"m"(smem_A[((local_k + 4) * smem_AS_cols) + (WM * warp_row + TCM * wm_iter) + row])); + // asm
volatile("flw f5, %0" ::"m"(smem_A[((local_k + 5) * smem_AS_cols) + (WM * warp_row + TCM * wm_iter) + row])); + // asm volatile("flw f6, %0" ::"m"(smem_A[((local_k + 6) * smem_AS_cols) + (WM * warp_row + TCM * wm_iter) + row])); + // asm volatile("flw f7, %0" ::"m"(smem_A[((local_k + 7) * smem_AS_cols) + (WM * warp_row + TCM * wm_iter) + row])); + } +} + +// `local_k` is assumed to be multiple of TCK +inline void vx_wmma_load_b(volatile float *smem_B, const int local_k, + const int warp_col, const int wn_iter, + const int thread_in_warp) { + const int tid = thread_in_warp; + const int tg = tid / 4; + + int row = 0; + int col = 0; + map_operand(tid, row, col); + + constexpr int smem_B_rows = BK; + constexpr int smem_B_cols = BN; + + // f8-f15 stores a single column of B + volatile float *smem_addr; + smem_addr = &smem_B[((local_k + 0) * smem_B_cols) + (WN * warp_col + TCN * wn_iter) + col]; + asm volatile("flw f8, %0(%1)" :: "i"(smem_B_cols * 0 * sizeof(float)), "r"(smem_addr)); + asm volatile("flw f9, %0(%1)" :: "i"(smem_B_cols * 1 * sizeof(float)), "r"(smem_addr)); + asm volatile("flw f10, %0(%1)" :: "i"(smem_B_cols * 2 * sizeof(float)), "r"(smem_addr)); + asm volatile("flw f11, %0(%1)" :: "i"(smem_B_cols * 3 * sizeof(float)), "r"(smem_addr)); + asm volatile("flw f12, %0(%1)" :: "i"(smem_B_cols * 4 * sizeof(float)), "r"(smem_addr)); + asm volatile("flw f13, %0(%1)" :: "i"(smem_B_cols * 5 * sizeof(float)), "r"(smem_addr)); + asm volatile("flw f14, %0(%1)" :: "i"(smem_B_cols * 6 * sizeof(float)), "r"(smem_addr)); + asm volatile("flw f15, %0(%1)" :: "i"(smem_B_cols * 7 * sizeof(float)), "r"(smem_addr)); + + // asm volatile("flw f8, %0" ::"m"(smem_B[((local_k + 0) * smem_B_cols) + (WN * warp_col + TCN * wn_iter) + col])); + // asm volatile("flw f9, %0" ::"m"(smem_B[((local_k + 1) * smem_B_cols) + (WN * warp_col + TCN * wn_iter) + col])); + // asm volatile("flw f10, %0" ::"m"(smem_B[((local_k + 2) * smem_B_cols) + (WN * warp_col + TCN * wn_iter) + col])); + // asm 
volatile("flw f11, %0" ::"m"(smem_B[((local_k + 3) * smem_B_cols) + (WN * warp_col + TCN * wn_iter) + col])); + // asm volatile("flw f12, %0" ::"m"(smem_B[((local_k + 4) * smem_B_cols) + (WN * warp_col + TCN * wn_iter) + col])); + // asm volatile("flw f13, %0" ::"m"(smem_B[((local_k + 5) * smem_B_cols) + (WN * warp_col + TCN * wn_iter) + col])); + // asm volatile("flw f14, %0" ::"m"(smem_B[((local_k + 6) * smem_B_cols) + (WN * warp_col + TCN * wn_iter) + col])); + // asm volatile("flw f15, %0" ::"m"(smem_B[((local_k + 7) * smem_B_cols) + (WN * warp_col + TCN * wn_iter) + col])); +} + +inline void initialize_C(const int dest_reg) { + // initialize C to zeros + if (dest_reg == 0) { + asm volatile("fmv.w.x f16, x0"); + asm volatile("fmv.w.x f17, x0"); + asm volatile("fmv.w.x f18, x0"); + asm volatile("fmv.w.x f19, x0"); + asm volatile("fmv.w.x f20, x0"); + asm volatile("fmv.w.x f21, x0"); + asm volatile("fmv.w.x f22, x0"); + asm volatile("fmv.w.x f23, x0"); + } else { + asm volatile("fmv.w.x f24, x0"); + asm volatile("fmv.w.x f25, x0"); + asm volatile("fmv.w.x f26, x0"); + asm volatile("fmv.w.x f27, x0"); + asm volatile("fmv.w.x f28, x0"); + asm volatile("fmv.w.x f29, x0"); + asm volatile("fmv.w.x f30, x0"); + asm volatile("fmv.w.x f31, x0"); + } +} + +inline void write_results(const int thread_in_warp, const int warp_col, + const int warp_row, const int wn_iter, + const int wm_iter, const int dim_n, + float *C, const int threadblock_id_x, + const int threadblock_id_y) { + int tid = thread_in_warp; + int tg = tid / 4; + + // these are [0, TCM/TCN) + int tid_row = 0; + int tid_col = 0; + map_c(tid, tid_row, tid_col); + + int local_row = (WM * warp_row + TCM * wm_iter) + tid_row; + int local_col = (WN * warp_col + TCN * wn_iter) + tid_col; + + float *global_offset_C = C + + (BM * threadblock_id_y) * dim_n + + BN * threadblock_id_x; + + // @perf: this likely causes a lot of gmem bank conflicts + if (wm_iter == 0) { + volatile float *gmem_addr; + volatile float 
*gmem_addr_tmp; + gmem_addr = &global_offset_C[dim_n * (local_row + 0) + (local_col + 0)]; + asm volatile ("fsw f16, %0" :: "m"(*(gmem_addr + 0))); + asm volatile ("fsw f17, %0" :: "m"(*(gmem_addr + 1))); + gmem_addr_tmp = gmem_addr + (2 * dim_n); + asm volatile ("fsw f18, %0" :: "m"(*(gmem_addr_tmp + 0))); + asm volatile ("fsw f19, %0" :: "m"(*(gmem_addr_tmp + 1))); + gmem_addr += 4; + asm volatile ("fsw f20, %0" :: "m"(*(gmem_addr + 0))); + asm volatile ("fsw f21, %0" :: "m"(*(gmem_addr + 1))); + gmem_addr_tmp = gmem_addr + (2 * dim_n); + asm volatile ("fsw f22, %0" :: "m"(*(gmem_addr_tmp + 0))); + asm volatile ("fsw f23, %0" :: "m"(*(gmem_addr_tmp + 1))); + // asm volatile ("fsw f16, %0" :: "m"(global_offset_C[dim_n * (local_row + 0) + (local_col + 0)])); + // asm volatile ("fsw f17, %0" :: "m"(global_offset_C[dim_n * (local_row + 0) + (local_col + 1)])); + // asm volatile ("fsw f18, %0" :: "m"(global_offset_C[dim_n * (local_row + 2) + (local_col + 0)])); + // asm volatile ("fsw f19, %0" :: "m"(global_offset_C[dim_n * (local_row + 2) + (local_col + 1)])); + // asm volatile ("fsw f20, %0" :: "m"(global_offset_C[dim_n * (local_row + 0) + (local_col + 4)])); + // asm volatile ("fsw f21, %0" :: "m"(global_offset_C[dim_n * (local_row + 0) + (local_col + 5)])); + // asm volatile ("fsw f22, %0" :: "m"(global_offset_C[dim_n * (local_row + 2) + (local_col + 4)])); + // asm volatile ("fsw f23, %0" :: "m"(global_offset_C[dim_n * (local_row + 2) + (local_col + 5)])); + } else { + volatile float *gmem_addr; + volatile float *gmem_addr_tmp; + gmem_addr = &global_offset_C[dim_n * (local_row + 0) + (local_col + 0)]; + gmem_addr_tmp = gmem_addr + (2 * dim_n); + asm volatile ("fsw f24, %0" :: "m"(*(gmem_addr + 0))); + asm volatile ("fsw f25, %0" :: "m"(*(gmem_addr + 1))); + asm volatile ("fsw f26, %0" :: "m"(*(gmem_addr_tmp + 0))); + asm volatile ("fsw f27, %0" :: "m"(*(gmem_addr_tmp + 1))); + gmem_addr += 4; + gmem_addr_tmp = gmem_addr + (2 * dim_n); + asm volatile ("fsw f28, 
%0" :: "m"(*(gmem_addr + 0))); + asm volatile ("fsw f29, %0" :: "m"(*(gmem_addr + 1))); + asm volatile ("fsw f30, %0" :: "m"(*(gmem_addr_tmp + 0))); + asm volatile ("fsw f31, %0" :: "m"(*(gmem_addr_tmp + 1))); + } +} + +inline void threadblock_barrier(const uint32_t barrier_id, const uint32_t count) { + vx_fence(); + vx_barrier(barrier_id, count); + // vx_barrier(0, count); +} + +inline void global_dmem_load(const uint32_t dim_n, const uint32_t dim_k, + const uint32_t k, const float *A, const float *B, + volatile float *local_a, volatile float *local_b, + const uint32_t tid_in_threadblock, + const uint32_t threadblock_id_x, + const uint32_t threadblock_id_y) { + const uint32_t local_a_row = tid_in_threadblock / BK; + const uint32_t local_a_col = tid_in_threadblock % BK; + const uint32_t local_as_row = tid_in_threadblock / BM; + const uint32_t local_as_col = tid_in_threadblock % BM; + const uint32_t local_b_row = tid_in_threadblock / BN; + const uint32_t local_b_col = tid_in_threadblock % BN; + + constexpr uint32_t threads_in_warpgroup = + (BM * BN) / ELEM_PER_THREAD / (DOUBLE_BUFFER ? 2 : 1); // FIXME + + // Data move from GMEM to SMEM + // + // Make sure global offset values for A and B are contiguous between + // neighboring threads to ensure GMEM coalescing. 
+ // + // TODO: Sharedmem swizzling is important here + if constexpr (!TRANSPOSE_AS) { + // FIXME: !TRANSPOSE_AS code is old + + const uint32_t global_a_row = BM * threadblock_id_y + local_a_row; + // number of rows a full TB can read at a time + constexpr uint32_t row_stride_a = threads_in_warpgroup / BK; + const float *global_a = A + dim_k * global_a_row + (k + local_a_col); + volatile float *local_a_tmp = local_a + BK * local_a_row + local_a_col; + +#pragma GCC unroll 1 + for (uint32_t local_row_offset = 0; local_row_offset < BM; + local_row_offset += row_stride_a) { + // const uint32_t global_a_offset = + // dim_k * (global_a_row + local_row_offset) + (k + local_a_col); + // local_a[BK * (local_a_row + local_row_offset) + local_a_col] = + // A[global_a_offset]; + *local_a_tmp = *global_a; + + global_a += dim_k * row_stride_a; + local_a_tmp += BK * row_stride_a; + } + } else { + if constexpr (!GMEM_COALESCED_A) { + constexpr uint32_t row_stride_as = threads_in_warpgroup / BM; + const uint32_t global_a_row = BM * threadblock_id_y + local_as_col; + const float *global_a = A + dim_k * global_a_row + (k + local_as_row); + // FIXME experimenting with global coalescing + // const uint32_t global_a_row = BM * threadblock_id_y + local_as_row; + // const float *global_a = A + dim_k * global_a_row + (k + local_as_col); + volatile float *local_a_tmp = local_a + BM * local_as_row + local_as_col; + + static_assert( + row_stride_as * 8 <= BK, + "manual loop unrolling condition not met; consider increasing BK"); + static_assert( + (BK % (row_stride_as * 8)) == 0, + "manual loop unrolling condition not met; BK should be power-of-two"); + +#pragma GCC unroll 1 + for (uint32_t local_row_offset = 0; local_row_offset < BK; + local_row_offset += row_stride_as * 8) { + // @perf: bank conflicts here + // const uint32_t global_a_offset = + // dim_k * (global_a_row) + (k + local_as_row + local_row_offset); + // FIXME experimenting with global coalescing + // const uint32_t 
global_a_offset = + // dim_k * (global_a_row + local_row_offset) + (k + local_as_col); + // local_a[BM * (local_as_row + local_row_offset) + local_as_col] = + // A[global_a_offset]; + + // *local_a_tmp = *global_a; + asm volatile ("flw ft0, (%0)" :: "r"(global_a)); + global_a += row_stride_as; + asm volatile ("flw ft1, (%0)" :: "r"(global_a)); + global_a += row_stride_as; + asm volatile ("flw ft2, (%0)" :: "r"(global_a)); + global_a += row_stride_as; + asm volatile ("flw ft3, (%0)" :: "r"(global_a)); + global_a += row_stride_as; + asm volatile ("flw ft4, (%0)" :: "r"(global_a)); + global_a += row_stride_as; + asm volatile ("flw ft5, (%0)" :: "r"(global_a)); + global_a += row_stride_as; + asm volatile ("flw ft6, (%0)" :: "r"(global_a)); + global_a += row_stride_as; + asm volatile ("flw ft7, (%0)" :: "r"(global_a)); + global_a += row_stride_as; + + asm volatile ("fsw ft0, %0(%1)" :: "i"(BM * row_stride_as * 0 * sizeof(float)), "r"(local_a_tmp)); + asm volatile ("fsw ft1, %0(%1)" :: "i"(BM * row_stride_as * 1 * sizeof(float)), "r"(local_a_tmp)); + asm volatile ("fsw ft2, %0(%1)" :: "i"(BM * row_stride_as * 2 * sizeof(float)), "r"(local_a_tmp)); + asm volatile ("fsw ft3, %0(%1)" :: "i"(BM * row_stride_as * 3 * sizeof(float)), "r"(local_a_tmp)); + asm volatile ("fsw ft4, %0(%1)" :: "i"(BM * row_stride_as * 4 * sizeof(float)), "r"(local_a_tmp)); + asm volatile ("fsw ft5, %0(%1)" :: "i"(BM * row_stride_as * 5 * sizeof(float)), "r"(local_a_tmp)); + asm volatile ("fsw ft6, %0(%1)" :: "i"(BM * row_stride_as * 6 * sizeof(float)), "r"(local_a_tmp)); + asm volatile ("fsw ft7, %0(%1)" :: "i"(BM * row_stride_as * 7 * sizeof(float)), "r"(local_a_tmp)); + local_a_tmp += BM * row_stride_as * 8; + } + } else { + constexpr uint32_t row_stride_a = threads_in_warpgroup / BK; + const uint32_t global_a_row = BM * threadblock_id_y + local_a_row; + const float *global_a = A + dim_k * global_a_row + (k + local_a_col); + // NOTE that SMEM writes are transposed + volatile float *local_a_tmp = 
local_a + BM * local_a_col + local_a_row; + + static_assert( + row_stride_a * 8 <= BM, + "manual loop unrolling condition not met; consider increasing BM"); + static_assert( + (BM % (row_stride_a * 8)) == 0, + "manual loop unrolling condition not met; BM should be power-of-two"); + +#pragma GCC unroll 1 + for (uint32_t local_row_offset = 0; local_row_offset < BM; + local_row_offset += row_stride_a * 8) { + // const uint32_t global_a_offset = + // dim_k * (global_a_row + local_row_offset) + (k + local_a_col); + // NOTE that SMEM writes are transposed + // local_a[BM * (local_a_col) + local_a_row + local_row_offset] = + // A[global_a_offset]; + + asm volatile ("flw ft0, (%0)" :: "r"(global_a)); + global_a += dim_k * row_stride_a; + asm volatile ("flw ft1, (%0)" :: "r"(global_a)); + global_a += dim_k * row_stride_a; + asm volatile ("flw ft2, (%0)" :: "r"(global_a)); + global_a += dim_k * row_stride_a; + asm volatile ("flw ft3, (%0)" :: "r"(global_a)); + global_a += dim_k * row_stride_a; + asm volatile ("flw ft4, (%0)" :: "r"(global_a)); + global_a += dim_k * row_stride_a; + asm volatile ("flw ft5, (%0)" :: "r"(global_a)); + global_a += dim_k * row_stride_a; + asm volatile ("flw ft6, (%0)" :: "r"(global_a)); + global_a += dim_k * row_stride_a; + asm volatile ("flw ft7, (%0)" :: "r"(global_a)); + global_a += dim_k * row_stride_a; + + // stride along columns + asm volatile ("fsw ft0, %0(%1)" :: "i"(row_stride_a * 0 * sizeof(float)), "r"(local_a_tmp)); + asm volatile ("fsw ft1, %0(%1)" :: "i"(row_stride_a * 1 * sizeof(float)), "r"(local_a_tmp)); + asm volatile ("fsw ft2, %0(%1)" :: "i"(row_stride_a * 2 * sizeof(float)), "r"(local_a_tmp)); + asm volatile ("fsw ft3, %0(%1)" :: "i"(row_stride_a * 3 * sizeof(float)), "r"(local_a_tmp)); + asm volatile ("fsw ft4, %0(%1)" :: "i"(row_stride_a * 4 * sizeof(float)), "r"(local_a_tmp)); + asm volatile ("fsw ft5, %0(%1)" :: "i"(row_stride_a * 5 * sizeof(float)), "r"(local_a_tmp)); + asm volatile ("fsw ft6, %0(%1)" :: "i"(row_stride_a 
* 6 * sizeof(float)), "r"(local_a_tmp)); + asm volatile ("fsw ft7, %0(%1)" :: "i"(row_stride_a * 7 * sizeof(float)), "r"(local_a_tmp)); + local_a_tmp += row_stride_a * 8; + } + } + } + + constexpr uint32_t row_stride_b = threads_in_warpgroup / BN; + const uint32_t global_b_col = BN * threadblock_id_x + local_b_col; + const float *global_b = B + dim_n * (k + local_b_row) + global_b_col; + volatile float *local_b_tmp = local_b + BN * local_b_row + local_b_col; + + static_assert( + row_stride_b * 8 <= BK, + "manual loop unrolling condition not met; consider increasing BK"); + static_assert( + (BK % (row_stride_b * 8)) == 0, + "manual loop unrolling condition not met; BK should be power-of-two"); + +#pragma GCC unroll 1 + for (uint32_t load_offset = 0; load_offset < BK; + load_offset += row_stride_b * 8) { + // const uint32_t global_b_offset = + // dim_n * (k + local_b_row + load_offset) + global_b_col; + // local_b[BN * (local_b_row + load_offset) + local_b_col] = + // B[global_b_offset]; + + // *local_b_tmp = *global_b; + + // global_b += dim_n * row_stride_b; + // local_b_tmp += BN * row_stride_b; + + asm volatile ("flw ft0, (%0)" :: "r"(global_b)); + global_b += dim_n * row_stride_b; + asm volatile ("flw ft1, (%0)" :: "r"(global_b)); + global_b += dim_n * row_stride_b; + asm volatile ("flw ft2, (%0)" :: "r"(global_b)); + global_b += dim_n * row_stride_b; + asm volatile ("flw ft3, (%0)" :: "r"(global_b)); + global_b += dim_n * row_stride_b; + asm volatile ("flw ft4, (%0)" :: "r"(global_b)); + global_b += dim_n * row_stride_b; + asm volatile ("flw ft5, (%0)" :: "r"(global_b)); + global_b += dim_n * row_stride_b; + asm volatile ("flw ft6, (%0)" :: "r"(global_b)); + global_b += dim_n * row_stride_b; + asm volatile ("flw ft7, (%0)" :: "r"(global_b)); + global_b += dim_n * row_stride_b; + + asm volatile ("fsw ft0, %0(%1)" :: "i"(BN * row_stride_b * 0 * sizeof(float)), "r"(local_b_tmp)); + asm volatile ("fsw ft1, %0(%1)" :: "i"(BN * row_stride_b * 1 * sizeof(float)), 
"r"(local_b_tmp)); + asm volatile ("fsw ft2, %0(%1)" :: "i"(BN * row_stride_b * 2 * sizeof(float)), "r"(local_b_tmp)); + asm volatile ("fsw ft3, %0(%1)" :: "i"(BN * row_stride_b * 3 * sizeof(float)), "r"(local_b_tmp)); + asm volatile ("fsw ft4, %0(%1)" :: "i"(BN * row_stride_b * 4 * sizeof(float)), "r"(local_b_tmp)); + asm volatile ("fsw ft5, %0(%1)" :: "i"(BN * row_stride_b * 5 * sizeof(float)), "r"(local_b_tmp)); + asm volatile ("fsw ft6, %0(%1)" :: "i"(BN * row_stride_b * 6 * sizeof(float)), "r"(local_b_tmp)); + asm volatile ("fsw ft7, %0(%1)" :: "i"(BN * row_stride_b * 7 * sizeof(float)), "r"(local_b_tmp)); + local_b_tmp += BN * row_stride_b * 8; + } +} + +inline void thread_block_gemm(kernel_arg_t *__UNIFORM__ arg, + const uint32_t tid_in_threadblock, + const uint32_t threads_per_threadblock, + const uint32_t threadblock_dim_x, + const uint32_t threadblock_dim_y, + /*const uint32_t threadblock_id_x, + const uint32_t threadblock_id_y,*/ + // const uint32_t threadblock_id_in_cluster, + float *sharedmem_per_threadblock) { + const float *A = (const float *)arg->addr_a; + const float *B = (const float *)arg->addr_b; + float *C = (float *)arg->addr_c; + + const uint32_t dim_m = arg->dim_m; + const uint32_t dim_n = arg->dim_n; + const uint32_t dim_k = arg->dim_k; + + const uint32_t local_a_row = tid_in_threadblock / BK; + const uint32_t local_a_col = tid_in_threadblock % BK; + const uint32_t local_as_row = tid_in_threadblock / BM; + const uint32_t local_as_col = tid_in_threadblock % BM; + const uint32_t local_b_row = tid_in_threadblock / BN; + const uint32_t local_b_col = tid_in_threadblock % BN; + + const uint32_t threads_per_warpgroup = threads_per_threadblock / (DOUBLE_BUFFER ? 
2 : 1); + const uint32_t warpgroup_id = tid_in_threadblock / threads_per_warpgroup; + const uint32_t tid_in_warpgroup = tid_in_threadblock % threads_per_warpgroup; // FIXME + const uint32_t warp_in_warpgroup = tid_in_warpgroup / NUM_LANES; + // FIXME: warp_row / BN should be warp-specialized? + const uint32_t warp_row = warp_in_warpgroup / (BN / WN); + const uint32_t warp_col = warp_in_warpgroup % (BN / WN); + const uint32_t tid_in_warp = tid_in_threadblock % NUM_LANES; + + volatile float *local_a = sharedmem_per_threadblock; + // const size_t local_a_elems = threadblock_dim_x * threadblock_dim_y; + constexpr size_t local_a_elems = (BM * BK); + volatile float *local_b = sharedmem_per_threadblock + local_a_elems; + constexpr size_t local_b_elems = (BK * BN); + + volatile float *local_a_buf = local_b + local_b_elems; + volatile float *local_b_buf = local_a_buf + local_a_elems; + + if (warpgroup_id == 0) { +#pragma GCC unroll 1 + for (uint32_t block_m = 0; (block_m * BM) < dim_m; block_m++) { +#pragma GCC unroll 1 + for (uint32_t block_n = 0; (block_n * BN) < dim_n; block_n++) { + if constexpr (DOUBLE_BUFFER) { + // initiate software pipeline + global_dmem_load(dim_n, dim_k, 0 /*k*/, A, B, local_a, local_b, + tid_in_warpgroup, block_n, block_m); + + threadblock_barrier(0/*threadblock_id_in_cluster*/, threadblock_dim_y); + } + + // NOTE: this *should* be signed integer to trigger arithmetic + // right-shift + int32_t k_index = 0; +#pragma GCC unroll 1 + for (uint32_t k = 0; k < (8 * dim_k) - BK; k += BK) { + volatile float *local_a_produce; + volatile float *local_b_produce; + if constexpr (DOUBLE_BUFFER) { + const uint32_t mask_odd = (k_index & 1) << 31 >> 31; + const uint32_t mask_even = ((k_index & 1) ^ 1) << 31 >> 31; + // local_a_produce = (k_index % 2) ? local_a : local_a_buf; + // local_b_produce = (k_index % 2) ? 
local_b : local_b_buf; + local_a_produce = reinterpret_cast<volatile float *>( + (mask_odd & reinterpret_cast<uintptr_t>(local_a)) | + (mask_even & reinterpret_cast<uintptr_t>(local_a_buf))); + local_b_produce = reinterpret_cast<volatile float *>( + (mask_odd & reinterpret_cast<uintptr_t>(local_b)) | + (mask_even & reinterpret_cast<uintptr_t>(local_b_buf))); + } else { + local_a_produce = local_a; + local_b_produce = local_b; + } + k_index++; + + global_dmem_load(dim_n, dim_k, k + BK /*runahead*/, A, B, + local_a_produce, local_b_produce, tid_in_warpgroup, + block_n, block_m); + + threadblock_barrier(0/*threadblock_id_in_cluster*/, threadblock_dim_y); + } + + // sync with final consumer stage in the k-loop + threadblock_barrier(0/*threadblock_id_in_cluster*/, threadblock_dim_y); + } + } + } else { +#pragma GCC unroll 1 + for (uint32_t block_m = 0; (block_m * BM) < dim_m; block_m++) { +#pragma GCC unroll 1 + for (uint32_t block_n = 0; (block_n * BN) < dim_n; block_n++) { + // clear out C + initialize_C(0); + initialize_C(1); + + // sync with initial producer stage in the k-loop + threadblock_barrier(0/*threadblock_id_in_cluster*/, threadblock_dim_y); + + // NOTE: this *should* be signed integer to trigger arithmetic + // right-shift, so each mask below is all-ones or all-zeros and selects a buffer without a branch + int32_t k_index = 0; +#pragma GCC unroll 1 + for (uint32_t k = 0; k < (8 * dim_k); k += BK) { + volatile float *local_a_consume; + volatile float *local_b_consume; + if constexpr (DOUBLE_BUFFER) { + // local_a_consume = (k_index % 2) ? local_a_buf : local_a; + // local_b_consume = (k_index % 2) ?
local_b_buf : local_b; + // FIXME: swap multiply with bitshifts + const uint32_t mask_odd = (k_index & 1) << 31 >> 31; + const uint32_t mask_even = ((k_index & 1) ^ 1) << 31 >> 31; + local_a_consume = reinterpret_cast<volatile float *>( + (mask_odd & reinterpret_cast<uintptr_t>(local_a_buf)) | + (mask_even & reinterpret_cast<uintptr_t>(local_a))); + local_b_consume = reinterpret_cast<volatile float *>( + (mask_odd & reinterpret_cast<uintptr_t>(local_b_buf)) | + (mask_even & reinterpret_cast<uintptr_t>(local_b))); + } else { + local_a_consume = local_a; + local_b_consume = local_b; + } + k_index++; + + // @perf: this loop spills to stack a lot because of all the flws in +#pragma GCC unroll 1 + for (int i = 0; i < BK_LOOP; i++) { +#pragma GCC unroll 2 + for (uint32_t local_k = 0; local_k < BK; local_k += TCK) { +#pragma GCC unroll 2 + for (int wn_iter = 0; wn_iter < WNITER; wn_iter++) { + // SMEM -> RF + vx_wmma_load_b(local_b_consume, local_k, warp_col, wn_iter, + tid_in_warp); +#pragma GCC unroll 2 + for (int wm_iter = 0; wm_iter < WMITER; wm_iter++) { + // SMEM -> RF + vx_wmma_load_a(local_a_consume, local_k, warp_row, wm_iter, + tid_in_warp); + // perform mma + vx_wmma(wm_iter); + } + } + } + } + + threadblock_barrier(0/*threadblock_id_in_cluster*/, threadblock_dim_y); + } + +#pragma GCC unroll 1 + for (int wm_iter = 0; wm_iter < WMITER; wm_iter++) { +#pragma GCC unroll 1 + for (int wn_iter = 0; wn_iter < WNITER; wn_iter++) { + if (warpgroup_id == 1) { + write_results(tid_in_warp, warp_col, warp_row, wn_iter, wm_iter, + dim_n, C, block_n, block_m); + } + } + } + } + } + } +} + +void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { + // @perf: All threads are running these compute whose result is mostly same + // across the threadblock + + const uint32_t threads_per_threadblock = (BM * BN) / (ELEM_PER_THREAD); +#ifdef RADIANCE + const uint32_t threadblocks_per_core = CORES_PER_CLUSTER * vx_num_threads() * + vx_num_warps() / + threads_per_threadblock; +#else + const uint32_t threadblocks_per_core = + vx_num_threads() * vx_num_warps()
/ threads_per_threadblock; +#endif + const uint32_t threadblock_dim_x = vx_num_threads(); + const uint32_t threadblock_dim_y = vx_num_warps() / threadblocks_per_core; + const int threadblock_id = task_id / threads_per_threadblock; + const int threadblock_id_in_cluster = threadblock_id % threadblocks_per_core; + const int tid_in_threadblock = task_id % threads_per_threadblock; + + const uint32_t dim_m = arg->dim_m; + const uint32_t dim_n = arg->dim_n; + const uint32_t dim_n_in_blocks = dim_n / BN; + const int threadblock_id_x = threadblock_id % dim_n_in_blocks; + const int threadblock_id_y = threadblock_id / dim_n_in_blocks; + + // "static" shared memory allocation. This would determine threadblock + // occupancy of a single cluster + float *sharedmem_per_threadblock = + (float *)DEV_SMEM_START_ADDR + (2 * BM * BK) * threadblock_id_in_cluster; + + const int warp_id = vx_warp_id(); + thread_block_gemm(arg, tid_in_threadblock, threads_per_threadblock, + threadblock_dim_x, threadblock_dim_y, /*threadblock_id_x, + threadblock_id_y,*/ /*threadblock_id_in_cluster, */ + sharedmem_per_threadblock); +} + +int main() { + kernel_arg_t *arg = (kernel_arg_t *)KERNEL_ARG_DEV_MEM_ADDR; + + const uint32_t threads_per_cluster = + CORES_PER_CLUSTER * vx_num_threads() * vx_num_warps(); + // const uint32_t grid_size = arg->dim_m * arg->dim_n / ELEM_PER_THREAD; + const uint32_t grid_size = threads_per_cluster; + +#ifdef RADIANCE + vx_spawn_tasks_cluster(grid_size, (vx_spawn_tasks_cb)kernel_body, arg); +#else + // NOTE: This kernel assumes contiguous thread scheduling for efficient shared + // memory allocation, and therefore does not work with original vx_spawn_tasks + vx_spawn_tasks_contiguous(grid_size, (vx_spawn_tasks_cb)kernel_body, arg); +#endif + return 0; +} diff --git a/tests/regression/sgemm_tcore/main.cpp b/tests/regression/sgemm_tcore/main.cpp new file mode 100644 index 00000000..84283992 --- /dev/null +++ b/tests/regression/sgemm_tcore/main.cpp @@ -0,0 +1,282 @@ +#include 
+#include +#include +#include +#include +#include +#include "common.h" + +#define RT_CHECK(_expr) \ + do { \ + int _ret = _expr; \ + if (0 == _ret) \ + break; \ + printf("Error: '%s' returned %d!\n", #_expr, (int)_ret); \ + cleanup(); \ + exit(-1); \ + } while (false) + +/////////////////////////////////////////////////////////////////////////////// + +const char* kernel_file = "kernel.bin"; +uint32_t count = 0; + +std::vector src_a_data; +std::vector src_b_data; +std::vector ref_data; + +vx_device_h device = nullptr; +std::vector staging_buf; +kernel_arg_t kernel_arg = {}; + +static void show_usage() { + std::cout << "Vortex Test." << std::endl; + std::cout << "Usage: [-k: kernel] [-n words] [-h: help]" << std::endl; +} + +static void parse_args(int argc, char **argv) { + int c; + while ((c = getopt(argc, argv, "n:k:h?")) != -1) { + switch (c) { + case 'n': + count = atoi(optarg); + break; + case 'k': + kernel_file = optarg; + break; + case 'h': + case '?': { + show_usage(); + exit(0); + } break; + default: + show_usage(); + exit(-1); + } + } +} + +void cleanup() { + if (device) { + // vx_mem_free(device, kernel_arg.addr_a); + // vx_mem_free(device, kernel_arg.addr_b); + // vx_mem_free(device, kernel_arg.addr_c); + vx_dev_close(device); + } +} + +void generate_source_matrix(uint32_t dim_m, uint32_t dim_n, uint32_t dim_k) { + src_a_data.resize(dim_m * dim_k); + src_b_data.resize(dim_k * dim_n); + + for (uint32_t i = 0; i < src_a_data.size(); ++i) { + src_a_data[i] = static_cast(i); + std::cout << "A: " << i << ": value=" << src_a_data[i] << std::endl; + } + for (uint32_t i = 0; i < src_b_data.size(); ++i) { + src_b_data[i] = static_cast(i); + std::cout << "B: " << i << ": value=" << src_b_data[i] << std::endl; + } +} + +void generate_reference_matmul(uint32_t dim_m, uint32_t dim_n, uint32_t dim_k) { + ref_data.resize(dim_m * dim_n); + + for (uint32_t i = 0; i < dim_m; ++i) { + for (uint32_t j = 0; j < dim_n; ++j) { + float ref = 0.0f; + for (uint32_t k = 0; k < 
dim_k; ++k) { + ref += src_a_data[dim_k * i + k] * src_b_data[dim_n * k + j]; + } + ref_data.at(dim_n * i + j) = ref; + } + } +} + +int run_test(const kernel_arg_t& kernel_arg, + uint32_t buf_size, + uint32_t dim_m, uint32_t dim_n) { + // start device + std::cout << "start device" << std::endl; + RT_CHECK(vx_start(device)); + + // wait for completion + std::cout << "wait for completion" << std::endl; + RT_CHECK(vx_ready_wait(device, VX_MAX_TIMEOUT)); + + // download destination buffer + std::cout << "download destination buffer" << std::endl; + RT_CHECK(vx_copy_from_dev(device, staging_buf.data(), kernel_arg.addr_c, buf_size)); + + // verify result + std::cout << "verify result" << std::endl; + { + int errors = 0; + auto buf_ptr = (float*)staging_buf.data(); + for (uint32_t i = 0; i < dim_m * dim_n; ++i) { + float ref = ref_data.at(i); + float cur = buf_ptr[i]; + if (std::abs((cur - ref) / ref) > 1e-6) { + std::cout << "error at result #" << std::dec << i + << std::hex << ": actual=" << cur << ", expected=" << ref << std::endl; + ++errors; + } + } + if (errors != 0) { + std::cout << "Found " << std::dec << errors << " errors!" << std::endl; + std::cout << "FAILED!" 
<< std::endl; + return 1; + } + } + + return 0; +} + +int main(int argc, char *argv[]) { + // parse command arguments + parse_args(argc, argv); + + if (count == 0) { + count = 1; + } + + std::srand(50); + + // open device connection + std::cout << "open device connection" << std::endl; + RT_CHECK(vx_dev_open(&device)); + + // FIXME: hardcoded + uint32_t dim_m = 128; + uint32_t dim_n = 128; + uint32_t dim_k = 128; + + generate_source_matrix(dim_m, dim_n, dim_k); + generate_reference_matmul(dim_m, dim_n, dim_k); + + std::cout << "write reference output" << std::endl; + std::ofstream ref_file("reference.c.bin", std::ios::binary | std::ios::out); + if (!ref_file) { + std::cerr << "error: failed to open reference.c.bin for writing\n"; + exit(EXIT_FAILURE); + } + ref_file.write(reinterpret_cast(ref_data.data()), ref_data.size() * sizeof(ref_data[0])); + ref_file.close(); + + uint32_t src_a_buf_size = src_a_data.size() * sizeof(src_a_data[0]); + uint32_t src_b_buf_size = src_b_data.size() * sizeof(src_b_data[0]); + uint32_t dst_buf_size = ref_data.size() * sizeof(src_a_data[0]); + + std::cout << "buffer size: " << dst_buf_size << " bytes" << std::endl; + + // upload program + std::cout << "upload program" << std::endl; + RT_CHECK(vx_upload_kernel_file(device, kernel_file)); + + // allocate device memory + std::cout << "allocate device memory" << std::endl; + // RT_CHECK(vx_mem_alloc(device, src_a_buf_size, VX_MEM_TYPE_GLOBAL, &kernel_arg.addr_a)); + // RT_CHECK(vx_mem_alloc(device, src_b_buf_size, VX_MEM_TYPE_GLOBAL, &kernel_arg.addr_b)); + // RT_CHECK(vx_mem_alloc(device, dst_buf_size, VX_MEM_TYPE_GLOBAL, &kernel_arg.addr_c)); + kernel_arg.addr_a = 0xa0000000; + kernel_arg.addr_b = 0xa1000000; + kernel_arg.addr_c = 0xc0000000; + + kernel_arg.dim_m = dim_m; + kernel_arg.dim_n = dim_n; + kernel_arg.dim_k = dim_k; + + std::cout << "dev_addr_a=0x" << std::hex << kernel_arg.addr_a << std::endl; + std::cout << "dev_addr_b=0x" << std::hex << kernel_arg.addr_b << std::endl; + 
std::cout << "dev_addr_c=0x" << std::hex << kernel_arg.addr_c << std::endl; + + // allocate staging buffer + { + std::cout << "allocate staging buffer" << std::endl; + uint32_t staging_buf_size = std::max( + src_a_buf_size, + std::max( + src_b_buf_size, + std::max(dst_buf_size, sizeof(kernel_arg_t)))); + staging_buf.resize(staging_buf_size); + } + + // upload kernel argument + { + std::cout << "upload kernel argument" << std::endl; + auto buf_ptr = staging_buf.data(); + memcpy(buf_ptr, &kernel_arg, sizeof(kernel_arg_t)); + RT_CHECK(vx_copy_to_dev(device, KERNEL_ARG_DEV_MEM_ADDR, staging_buf.data(), sizeof(kernel_arg_t))); + + std::cout << "uploading argument buffer to device, device mem address=" + << std::hex << KERNEL_ARG_DEV_MEM_ADDR << ", size=" << std::dec + << sizeof(kernel_arg_t) << " bytes\n"; + std::ofstream file("args.bin", std::ios::binary | std::ios::out); + if (!file) { + std::cerr << "error: failed to open args.bin for writing\n"; + exit(EXIT_FAILURE); + } + file.write(reinterpret_cast(staging_buf.data()), + sizeof(kernel_arg_t)); + file.close(); + } + + // upload source buffer + { + { + auto buf_ptr = staging_buf.data(); + memcpy(buf_ptr, src_a_data.data(), src_a_data.size() * sizeof(float)); + RT_CHECK(vx_copy_to_dev(device, kernel_arg.addr_a, staging_buf.data(), + src_a_buf_size)); + + std::cout << "uploading source A matrix to device, device mem address=" + << std::hex << kernel_arg.addr_a << ", size=" << std::dec + << src_a_buf_size << " bytes\n"; + std::ofstream file("input.a.bin", std::ios::binary | std::ios::out); + if (!file) { + std::cerr << "error: failed to open args.bin for writing\n"; + exit(EXIT_FAILURE); + } + file.write(reinterpret_cast(buf_ptr), src_a_buf_size); + file.close(); + } + { + auto buf_ptr = staging_buf.data(); + memcpy(buf_ptr, src_b_data.data(), src_b_data.size() * sizeof(float)); + RT_CHECK(vx_copy_to_dev(device, kernel_arg.addr_b, staging_buf.data(), + src_b_buf_size)); + + std::cout << "uploading source B matrix to 
device, device mem address=" + << std::hex << kernel_arg.addr_b << ", size=" << std::dec + << src_b_buf_size << " bytes\n"; + std::ofstream file("input.b.bin", std::ios::binary | std::ios::out); + if (!file) { + std::cerr << "error: failed to open args.bin for writing\n"; + exit(EXIT_FAILURE); + } + file.write(reinterpret_cast(buf_ptr), src_b_buf_size); + file.close(); + } + } + + // clear destination buffer + { + std::cout << "clear destination buffer" << std::endl; + auto buf_ptr = (int32_t*)staging_buf.data(); + for (uint32_t i = 0; i < ref_data.size(); ++i) { + buf_ptr[i] = 0xdeadbeef; + } + RT_CHECK(vx_copy_to_dev(device, kernel_arg.addr_c, staging_buf.data(), dst_buf_size)); + } + + // run tests + std::cout << "run tests" << std::endl; + RT_CHECK(run_test(kernel_arg, dst_buf_size, kernel_arg.dim_m, kernel_arg.dim_n)); + std::cout << "PASSED!" << std::endl; + + // cleanup + std::cout << "cleanup" << std::endl; + cleanup(); + + return 0; +} diff --git a/tests/regression/sgemm_wg/common.h b/tests/regression/sgemm_wg/common.h index 74941562..5c84f3b7 100644 --- a/tests/regression/sgemm_wg/common.h +++ b/tests/regression/sgemm_wg/common.h @@ -3,7 +3,7 @@ #include -#define KERNEL_ARG_DEV_MEM_ADDR 0x7fff0000 +#define KERNEL_ARG_DEV_MEM_ADDR 0x9fff0000 #define DEV_SMEM_START_ADDR 0xff000000 typedef struct {