From 5bd25985c6ecabad0eb1277c86474d0969c1d106 Mon Sep 17 00:00:00 2001 From: joshua Date: Sat, 4 May 2024 23:01:47 -0700 Subject: [PATCH] i kinda forgot most of changes --- hw/dpi/float_dpi.cpp | 152 ++++++++++++- hw/dpi/float_dpi.vh | 1 + hw/rtl/core/VX_commit.sv | 13 +- hw/rtl/core/VX_operands.sv | 16 +- hw/rtl/core/VX_tensor_core.sv | 9 +- hw/rtl/core/VX_tensor_ucode.vh | 2 +- hw/rtl/core/VX_uop_sequencer.sv | 8 +- hw/rtl/core/generate_ucode.py | 6 +- hw/rtl/fpu/VX_tensor_dpu.sv | 9 +- kernel/include/vx_spawn.h | 1 + kernel/src/vx_spawn.c | 105 +++++++++ tests/kernel/tensor/main.cpp | 2 +- tests/regression/sgemm_tcore/Makefile | 9 + tests/regression/sgemm_tcore/common.h | 18 ++ tests/regression/sgemm_tcore/kernel.cpp | 285 ++++++++++++++++++++++++ tests/regression/sgemm_tcore/main.cpp | 270 ++++++++++++++++++++++ 16 files changed, 882 insertions(+), 24 deletions(-) create mode 100644 tests/regression/sgemm_tcore/Makefile create mode 100644 tests/regression/sgemm_tcore/common.h create mode 100644 tests/regression/sgemm_tcore/kernel.cpp create mode 100644 tests/regression/sgemm_tcore/main.cpp diff --git a/hw/dpi/float_dpi.cpp b/hw/dpi/float_dpi.cpp index d5209bed..a9d9eaef 100644 --- a/hw/dpi/float_dpi.cpp +++ b/hw/dpi/float_dpi.cpp @@ -56,6 +56,7 @@ extern "C" { void dpi_fmax(bool enable, int dst_fmt, int64_t a, int64_t b, int64_t* result, svBitVecVal* fflags); void dpi_hmma(bool enable, const svBitVecVal* A_tile, const svBitVecVal* B_tile, const svBitVecVal* C_tile, svBitVecVal* D_tile); + void dpi_print_results(int wid, int octet, const svBitVecVal* A_tile, const svBitVecVal* B_tile, const svBitVecVal* C_tile, const svBitVecVal* D_tile); } inline uint64_t nan_box(uint32_t value) { @@ -413,4 +414,153 @@ void dpi_hmma(bool enable, const svBitVecVal* A_tile, const svBitVecVal* B_tile, } write_float_array(D_tile, &c_D_tile[0][0], M, M); -} \ No newline at end of file +} + +// 1 copy per warp +float A_tile_full[4][16][8]; +float B_tile_full[4][8][16]; +float C_tile_full[4][16][16]; +float D_tile_full[4][16][16]; +int steps[4]; + +void print_array(float* array, int rows, int cols) { + for (int i = 0; i < rows; i += 1) { + for (int j = 0; j < cols; j += 1) { + std::cout << array[i*cols+j] << " "; + } + std::cout << "\n"; + } + std::cout << std::endl; +} + +void dpi_print_results(int wid, int octet, const svBitVecVal* A_tile, const svBitVecVal* B_tile, const svBitVecVal* C_tile, const svBitVecVal* D_tile) { + // std::cout << "A: " << std::endl; + fill_float_array(A_tile, &c_A_tile[0][0], M, K); + // std::cout << "B: " << std::endl; + fill_float_array(B_tile, &c_B_tile[0][0], K, M); + // std::cout << "C: " << std::endl; + fill_float_array(C_tile, &c_C_tile[0][0], M, M); + // for some reason this still holds onto old value? very strange + // std::cout << "D: " << std::endl; + fill_float_array(D_tile, &c_D_tile[0][0], M, M); + + int octet_row_offset; + int octet_col_offset; + switch(octet) { + case 0: + octet_row_offset = 0; + octet_col_offset = 0; + break; + case 1: + octet_row_offset = 8; + octet_col_offset = 0; + break; + case 2: + octet_row_offset = 0; + octet_col_offset = 8; + break; + case 3: + octet_row_offset = 8; + octet_col_offset = 8; + break; + } + + int step_row_offset; + int step_col_offset; + int step = (steps[wid] % 16) / 4; + int set = (steps[wid] / 16); + switch(step) { + case 0: + step_row_offset = 0; + step_col_offset = 0; + break; + case 1: + step_row_offset = 2; + step_col_offset = 0; + break; + case 2: + step_row_offset = 0; + step_col_offset = 4; + break; + case 3: + step_row_offset = 2; + step_col_offset = 4; + break; + } + + if (steps[0] >= 48) { + // std::cout << "octet " << octet << " step " << steps[0] << "\n"; + // print_array(&c_D_tile[0][0], 4, 4); + } + + D_tile_full[wid][octet_row_offset+step_row_offset+0][octet_col_offset+step_col_offset+0] = c_D_tile[0][0]; + D_tile_full[wid][octet_row_offset+step_row_offset+0][octet_col_offset+step_col_offset+1] = c_D_tile[0][1]; + D_tile_full[wid][octet_row_offset+step_row_offset+0][octet_col_offset+step_col_offset+2] = c_D_tile[0][2]; + D_tile_full[wid][octet_row_offset+step_row_offset+0][octet_col_offset+step_col_offset+3] = c_D_tile[0][3]; + D_tile_full[wid][octet_row_offset+step_row_offset+1][octet_col_offset+step_col_offset+0] = c_D_tile[1][0]; + D_tile_full[wid][octet_row_offset+step_row_offset+1][octet_col_offset+step_col_offset+1] = c_D_tile[1][1]; + D_tile_full[wid][octet_row_offset+step_row_offset+1][octet_col_offset+step_col_offset+2] = c_D_tile[1][2]; + D_tile_full[wid][octet_row_offset+step_row_offset+1][octet_col_offset+step_col_offset+3] = c_D_tile[1][3]; + D_tile_full[wid][octet_row_offset+step_row_offset+4][octet_col_offset+step_col_offset+0] = c_D_tile[2][0]; + D_tile_full[wid][octet_row_offset+step_row_offset+4][octet_col_offset+step_col_offset+1] = c_D_tile[2][1]; + D_tile_full[wid][octet_row_offset+step_row_offset+4][octet_col_offset+step_col_offset+2] = c_D_tile[2][2]; + D_tile_full[wid][octet_row_offset+step_row_offset+4][octet_col_offset+step_col_offset+3] = c_D_tile[2][3]; + D_tile_full[wid][octet_row_offset+step_row_offset+5][octet_col_offset+step_col_offset+0] = c_D_tile[3][0]; + D_tile_full[wid][octet_row_offset+step_row_offset+5][octet_col_offset+step_col_offset+1] = c_D_tile[3][1]; + D_tile_full[wid][octet_row_offset+step_row_offset+5][octet_col_offset+step_col_offset+2] = c_D_tile[3][2]; + D_tile_full[wid][octet_row_offset+step_row_offset+5][octet_col_offset+step_col_offset+3] = c_D_tile[3][3]; + + if (octet == 0 || octet == 1) { + octet_row_offset = octet * 8; + if (step == 0) { + step_row_offset = 0; + } + if (step == 1) { + step_row_offset = 2; + } + if (step == 0 || step == 1) { + A_tile_full[wid][octet_row_offset+step_row_offset+0][set*2+0] = c_A_tile[0][0]; + A_tile_full[wid][octet_row_offset+step_row_offset+0][set*2+1] = c_A_tile[0][1]; + A_tile_full[wid][octet_row_offset+step_row_offset+1][set*2+0] = c_A_tile[1][0]; + A_tile_full[wid][octet_row_offset+step_row_offset+1][set*2+1] = c_A_tile[1][1]; + A_tile_full[wid][octet_row_offset+step_row_offset+4][set*2+0] = c_A_tile[2][0]; + A_tile_full[wid][octet_row_offset+step_row_offset+4][set*2+1] = c_A_tile[2][1]; + A_tile_full[wid][octet_row_offset+step_row_offset+5][set*2+0] = c_A_tile[3][0]; + A_tile_full[wid][octet_row_offset+step_row_offset+5][set*2+1] = c_A_tile[3][1]; + } + } + + if (octet == 0 || octet == 2) { + octet_col_offset = octet * 4; + if (step == 0) { + step_col_offset = 0; + } + else if (step == 2) { + step_col_offset = 4; + } + if (step == 0 || step == 2) { + B_tile_full[wid][set*2+0][octet_col_offset+step_col_offset+0] = c_B_tile[0][0]; + B_tile_full[wid][set*2+0][octet_col_offset+step_col_offset+1] = c_B_tile[0][1]; + B_tile_full[wid][set*2+0][octet_col_offset+step_col_offset+2] = c_B_tile[0][2]; + B_tile_full[wid][set*2+0][octet_col_offset+step_col_offset+3] = c_B_tile[0][3]; + B_tile_full[wid][set*2+1][octet_col_offset+step_col_offset+0] = c_B_tile[1][0]; + B_tile_full[wid][set*2+1][octet_col_offset+step_col_offset+1] = c_B_tile[1][1]; + B_tile_full[wid][set*2+1][octet_col_offset+step_col_offset+2] = c_B_tile[1][2]; + B_tile_full[wid][set*2+1][octet_col_offset+step_col_offset+3] = c_B_tile[1][3]; + } + } + + steps[wid] += 1; + if (steps[wid] % 64 == 0) { + steps[wid] = 0; + std::cout << "warp " << wid << " finished wmma\n"; + std::cout << "A tile" << "\n"; + print_array(&A_tile_full[wid][0][0], 16, 8); + std::cout << "B tile" << "\n"; + print_array(&B_tile_full[wid][0][0], 8, 16); + // std::cout << "C tile" << "\n"; + // print_array(&C_tile_full[wid][0][0], 16, 16); + std::cout << "D tile" << "\n"; + print_array(&D_tile_full[wid][0][0], 16, 16); + } +} diff --git a/hw/dpi/float_dpi.vh b/hw/dpi/float_dpi.vh index c8e7c9cb..08abf25b 100644 --- a/hw/dpi/float_dpi.vh +++ b/hw/dpi/float_dpi.vh @@ -45,5 +45,6 @@ import "DPI-C" function void dpi_fmin(input logic enable, input int dst_fmt, inp import "DPI-C" function void dpi_fmax(input logic enable, input int dst_fmt, input longint a, input longint b, output longint result, output bit[4:0] fflags); import "DPI-C" function void dpi_hmma(input logic enable, input bit[3:0][1:0][31:0] A_tile, input bit[1:0][3:0][31:0] B_tile, input bit[3:0][3:0][31:0] C_tile, output bit[3:0][3:0][31:0] D_tile); +import "DPI-C" function void dpi_print_results(input int wid, input int octet, input bit[3:0][1:0][31:0] A_tile, input bit[1:0][3:0][31:0] B_tile, input bit[3:0][3:0][31:0] C_tile, input bit[3:0][3:0][31:0] D_tile); `endif diff --git a/hw/rtl/core/VX_commit.sv b/hw/rtl/core/VX_commit.sv index f0d94925..5fda49f9 100644 --- a/hw/rtl/core/VX_commit.sv +++ b/hw/rtl/core/VX_commit.sv @@ -53,6 +53,7 @@ module VX_commit import VX_gpu_pkg::*; #( wire [`ISSUE_WIDTH-1:0][`NUM_THREADS-1:0] commit_tmask; wire [`ISSUE_WIDTH-1:0] commit_eop; wire [`ISSUE_WIDTH-1:0][`EX_BITS-1:0] commit_sel; + `UNUSED_VAR (commit_sel) for (genvar i = 0; i < `ISSUE_WIDTH; ++i) begin @@ -175,14 +176,17 @@ module VX_commit import VX_gpu_pkg::*; #( // relies on 1 cycle delay of arbiter and continuous issuing of tensor instructions, // so probably want to change this at some point // (i.e. pass a "don't count this towards pending instructions" signal down the pipeline) - logic [`ISSUE_WIDTH-1:0][4:0] hmma_ctr, hmma_ctr_n; + // logic [`ISSUE_WIDTH-1:0][4:0] hmma_ctr, hmma_ctr_n; wire [`ISSUE_WIDTH-1:0] final_hmma; `ifdef EXT_T_ENABLE for (genvar i = 0; i < `ISSUE_WIDTH; ++i) begin - assign hmma_ctr_n[i] = (tensor_commit_if[i].valid && tensor_commit_if[i].ready) ? hmma_ctr[i] + 5'b1 : hmma_ctr[i]; - assign final_hmma[i] = (commit_sel[i] != `EX_BITS'(2) || hmma_ctr == '0); + // assign hmma_ctr_n[i] = (tensor_commit_if[i].valid && tensor_commit_if[i].ready) ? hmma_ctr[i] + 5'b1 : hmma_ctr[i]; + // assign final_hmma[i] = (commit_sel[i] != `EX_BITS'(2) || hmma_ctr == '0); + // i suppose this is now a feature and not a bug + // if PC is 0, this means it is not final step of a wmma, shouldn't be committed + assign final_hmma[i] = (commit_if[i].data.PC != 32'b0); end - + /* always @(posedge clk) begin if (reset) begin hmma_ctr <= '0; @@ -191,6 +195,7 @@ module VX_commit import VX_gpu_pkg::*; #( hmma_ctr <= hmma_ctr_n; end end + */ `else assign final_hmma = '1; `endif diff --git a/hw/rtl/core/VX_operands.sv b/hw/rtl/core/VX_operands.sv index 3fe1d7f4..f4ff1edc 100644 --- a/hw/rtl/core/VX_operands.sv +++ b/hw/rtl/core/VX_operands.sv @@ -300,14 +300,14 @@ module VX_operands import VX_gpu_pkg::*; #( cycle <= cycle_n; end - if (cycle == 32'd25000) begin - for (integer k = 0; k < `NUM_REGS * ISSUE_RATIO; ++k) begin - integer warp = i * ISSUE_RATIO + (k / `NUM_REGS); - integer thread = j; - integer register = k % `NUM_REGS; - $display("warp %0d, thread %0d, register %0d: %0x", warp, thread, register, gpr_ram.ram[k]); - end - end + // if (cycle == 32'd25000) begin + // for (integer k = 0; k < `NUM_REGS * ISSUE_RATIO; ++k) begin + // integer warp = i * ISSUE_RATIO + (k / `NUM_REGS); + // integer thread = j; + // integer register = k % `NUM_REGS; + // $display("warp %0d, thread %0d, register %0d: %0x", warp, thread, register, gpr_ram.ram[k]); + // end + // end end end end diff --git a/hw/rtl/core/VX_tensor_core.sv b/hw/rtl/core/VX_tensor_core.sv index 1055351d..de27f757 100644 --- a/hw/rtl/core/VX_tensor_core.sv +++ b/hw/rtl/core/VX_tensor_core.sv @@ -55,7 +55,8 @@ module VX_tensor_core_warp import VX_gpu_pkg::*; #( logic result_valid; logic result_ready; VX_tensor_octet #( - + .ISW(ISW), + .OCTET(i) ) octet ( .clk(clk), .reset(reset), @@ -177,7 +178,8 @@ module VX_tensor_core_warp import VX_gpu_pkg::*; #( endmodule module VX_tensor_octet #( - + parameter ISW, + parameter OCTET ) ( input clk, input reset, @@ -282,7 +284,8 @@ module VX_tensor_octet #( wire do_hmma = (substep == 1'b1 && operands_valid && operands_ready); VX_tensor_dpu #( - + .ISW(ISW), + .OCTET(OCTET) ) dpu ( .clk(clk), .reset(reset), diff --git a/hw/rtl/core/VX_tensor_ucode.vh b/hw/rtl/core/VX_tensor_ucode.vh index 7603b4a1..8c3243de 100644 --- a/hw/rtl/core/VX_tensor_ucode.vh +++ b/hw/rtl/core/VX_tensor_ucode.vh @@ -92,5 +92,5 @@ HMMA_SET3_STEP3_0: begin uop = {NEXT, HMMA_SET3_STEP3_1, `EX_BITS'(`EX_TENSOR), `INST_OP_BITS'(3), `INST_MOD_BITS'(0), 1'b1, 1'b0, 1'b0, 32'b0, 32'b0, `FREG(22), `FREG(6), `FREG(14), `FREG(22)}; end HMMA_SET3_STEP3_1: begin - uop = {FINISH, HMMA_SET0_STEP0_0, `EX_BITS'(`EX_TENSOR), `INST_OP_BITS'(3), `INST_MOD_BITS'(1), 1'b1, 1'b0, 1'b0, 32'b0, 32'b0, `FREG(23), `FREG(7), `FREG(15), `FREG(23)}; + uop = {FINISH, HMMA_SET0_STEP0_0, `EX_BITS'(`EX_TENSOR), `INST_OP_BITS'(3), `INST_MOD_BITS'(1), 1'b1, 1'b0, 1'b0, 32'b1, 32'b1, `FREG(23), `FREG(7), `FREG(15), `FREG(23)}; end diff --git a/hw/rtl/core/VX_uop_sequencer.sv b/hw/rtl/core/VX_uop_sequencer.sv index 3d2bb233..8fe2f3d9 100644 --- a/hw/rtl/core/VX_uop_sequencer.sv +++ b/hw/rtl/core/VX_uop_sequencer.sv @@ -125,16 +125,16 @@ module VX_uop_sequencer import VX_gpu_pkg::*; ( always @(posedge clk) begin if (uop_start) begin - $display("UOP start @ %t", $time); - $display("use_uop=%0d, use_uop_1d=%0d, uop_start=%0d, ibuffer_if.valid=%0d, ibuffer_if.ready=%0d", use_uop, use_uop_1d, uop_start, ibuffer_if.valid, ibuffer_if.ready); + // $display("UOP start @ %t", $time); + // $display("use_uop=%0d, use_uop_1d=%0d, uop_start=%0d, ibuffer_if.valid=%0d, ibuffer_if.ready=%0d", use_uop, use_uop_1d, uop_start, ibuffer_if.valid, ibuffer_if.ready); end if (uop_fire) begin - $display("UOP fire @ %t", $time); + // $display("UOP fire @ %t", $time); end if (uop_finish) begin - $display("UOP finish @ %t", $time); + // $display("UOP finish @ %t", $time); end if (reset) begin diff --git a/hw/rtl/core/generate_ucode.py b/hw/rtl/core/generate_ucode.py index 4cda111f..671b6c7b 100644 --- a/hw/rtl/core/generate_ucode.py +++ b/hw/rtl/core/generate_ucode.py @@ -55,17 +55,21 @@ with open('VX_tensor_ucode.vh', 'w') as f: finish = (next_sequence_num == 0) name = "HMMA_SET{}_STEP{}_{}" - ucode = "{}, HMMA_SET{}_STEP{}_{}, `EX_BITS'(`EX_TENSOR), `INST_OP_BITS'({}), `INST_MOD_BITS'({}), 1'b1, 1'b0, 1'b0, 32'b0, 32'b0, `FREG({}), `FREG({}), `FREG({}), `FREG({})" + ucode = "{}, HMMA_SET{}_STEP{}_{}, `EX_BITS'(`EX_TENSOR), `INST_OP_BITS'({}), `INST_MOD_BITS'({}), 1'b1, 1'b0, 1'b0, 32'b{}, 32'b{}, `FREG({}), `FREG({}), `FREG({}), `FREG({})" name = name.format( set_num, step, substep, ) + + pc_imm = 1 if finish else 0 ucode = ucode.format( "FINISH" if finish else "NEXT", next_set_num, next_step, next_substep, step, substep, + pc_imm, + pc_imm, rs3_rd[(step, substep)], rs1[(set_num, substep)], rs2[(set_num, substep)], diff --git a/hw/rtl/fpu/VX_tensor_dpu.sv b/hw/rtl/fpu/VX_tensor_dpu.sv index 8108790c..3927e1e3 100644 --- a/hw/rtl/fpu/VX_tensor_dpu.sv +++ b/hw/rtl/fpu/VX_tensor_dpu.sv @@ -1,7 +1,8 @@ `include "VX_fpu_define.vh" module VX_tensor_dpu #( - + parameter ISW, + parameter OCTET ) ( input clk, input reset, @@ -21,6 +22,12 @@ module VX_tensor_dpu #( always @(*) begin dpi_hmma(valid_in, A_tile, B_tile, C_tile, result_hmma); end + + always @(posedge clk) begin + if (~reset && valid_in) begin + dpi_print_results(int'(ISW), int'(OCTET), A_tile, B_tile, C_tile, result_hmma); + end + end VX_shift_register #( diff --git a/kernel/include/vx_spawn.h b/kernel/include/vx_spawn.h index 2584b997..d8797945 100644 --- a/kernel/include/vx_spawn.h +++ b/kernel/include/vx_spawn.h @@ -48,6 +48,7 @@ void vx_wspawn_wait(); void vx_spawn_kernel(context_t * ctx, vx_spawn_kernel_cb callback, void * arg); void vx_spawn_tasks(int num_tasks, vx_spawn_tasks_cb callback, void * arg); +void vx_spawn_tasks_contiguous(int num_tasks, vx_spawn_tasks_cb callback , void * arg); void vx_serial(vx_serial_cb callback, void * arg); diff --git a/kernel/src/vx_spawn.c b/kernel/src/vx_spawn.c index fd8258e1..b1ef7230 100644 --- a/kernel/src/vx_spawn.c +++ b/kernel/src/vx_spawn.c @@ -83,6 +83,38 @@ static void __attribute__ ((noinline)) spawn_tasks_rem_stub() { (p_wspawn_args->callback)(task_id, p_wspawn_args->arg); } +static void __attribute__ ((noinline)) spawn_tasks_contiguous_all_stub() { + int NT = vx_num_threads(); + int NW = vx_num_warps(); + int cid = vx_core_id(); + int wid = vx_warp_id(); + int tid = vx_thread_id(); + + wspawn_tasks_args_t* p_wspawn_args = (wspawn_tasks_args_t*)g_wspawn_args[cid]; + + int waves = p_wspawn_args->NWs + (wid < p_wspawn_args->RWs); + int offset = p_wspawn_args->offset + (NT * wid + tid); + + vx_spawn_tasks_cb callback = p_wspawn_args->callback; + void* arg = p_wspawn_args->arg; + for (int wave_id = 0; wave_id < waves; ++wave_id) { + int task_id = offset + (wave_id * NT * NW); + callback(task_id, arg); + } +} + +static void __attribute__ ((noinline)) spawn_tasks_contiguous_all_cb() { + // activate all threads + vx_tmc(-1); + + // call stub routine + spawn_tasks_contiguous_all_stub(); + + // disable warp + // deadlock here on warps 1, 2, 3 + vx_tmc_zero(); +} + static void __attribute__ ((noinline)) spawn_tasks_all_cb() { // activate all threads vx_tmc(-1); @@ -94,6 +126,79 @@ static void __attribute__ ((noinline)) spawn_tasks_all_cb() { vx_tmc_zero(); } +void vx_spawn_tasks_contiguous(int num_tasks, vx_spawn_tasks_cb callback , void * arg) { + // device specs + int NC = vx_num_cores(); + int NW = vx_num_warps(); + int NT = vx_num_threads(); + + // current core id + int core_id = vx_core_id(); + if (core_id >= NUM_CORES_MAX) + return; + + // calculate necessary active cores + int WT = NW * NT; + int nC = (num_tasks > WT) ? (num_tasks / WT) : 1; + int nc = MIN(nC, NC); + if (core_id >= nc) + return; // terminate extra cores + + // number of tasks per core + int tasks_per_core = num_tasks / nc; + int tasks_per_core_n1 = tasks_per_core; + if (core_id == (nc-1)) { + int rem = num_tasks - (nc * tasks_per_core); + tasks_per_core_n1 += rem; // last core also executes remaining tasks + } + + // number of tasks per warp + int TW = tasks_per_core_n1 / NT; // occupied warps + int rT = tasks_per_core_n1 - TW * NT; // remaining threads + int fW = 1, rW = 0; + if (TW >= NW) { + fW = TW / NW; // full warps iterations + rW = TW - fW * NW; // remaining warps + } + + wspawn_tasks_args_t wspawn_args = { callback, arg, core_id * tasks_per_core, fW, rW }; + g_wspawn_args[core_id] = &wspawn_args; + + if (TW >= 1) { + // execute callback on other warps + int nw = MIN(TW, NW); + vx_wspawn(nw, spawn_tasks_contiguous_all_cb); + + // activate all threads + vx_tmc(-1); + + // call stub routine + spawn_tasks_contiguous_all_stub(); + + // back to single-threaded + vx_tmc_one(); + + // wait for spawn warps to terminate + // deadlock here on warp 0! + vx_wspawn_wait(); + } + + if (rT != 0) { + // adjust offset + wspawn_args.offset += (tasks_per_core_n1 - rT); + + // activate remaining threads + int tmask = (1 << rT) - 1; + vx_tmc(tmask); + + // call stub routine + spawn_tasks_rem_stub(); + + // back to single-threaded + vx_tmc_one(); + } +} + void vx_spawn_tasks(int num_tasks, vx_spawn_tasks_cb callback , void * arg) { // device specs int NC = vx_num_cores(); diff --git a/tests/kernel/tensor/main.cpp b/tests/kernel/tensor/main.cpp index 2c45fa9b..5fc222b2 100644 --- a/tests/kernel/tensor/main.cpp +++ b/tests/kernel/tensor/main.cpp @@ -90,7 +90,7 @@ int main() vx_wmma(); store_wmma_result(); vx_tmc(1); - print_wmma_result(); + // print_wmma_result(); return 0; } \ No newline at end of file diff --git a/tests/regression/sgemm_tcore/Makefile b/tests/regression/sgemm_tcore/Makefile new file mode 100644 index 00000000..0c378af0 --- /dev/null +++ b/tests/regression/sgemm_tcore/Makefile @@ -0,0 +1,9 @@ +PROJECT = sgemm_tcore + +SRCS = main.cpp common.h + +VX_SRCS = kernel.cpp + +OPTS ?= -n16 + +include ../common.mk \ No newline at end of file diff --git a/tests/regression/sgemm_tcore/common.h b/tests/regression/sgemm_tcore/common.h new file mode 100644 index 00000000..d94a270f --- /dev/null +++ b/tests/regression/sgemm_tcore/common.h @@ -0,0 +1,18 @@ +#ifndef _COMMON_H_ +#define _COMMON_H_ + +#include + +#define KERNEL_ARG_DEV_MEM_ADDR 0x7fff0000 +#define DEV_SMEM_START_ADDR 0xff000000 + +typedef struct { + uint32_t dim_m; + uint32_t dim_n; + uint32_t dim_k; + uint64_t addr_a; + uint64_t addr_b; + uint64_t addr_c; +} kernel_arg_t; + +#endif \ No newline at end of file diff --git a/tests/regression/sgemm_tcore/kernel.cpp b/tests/regression/sgemm_tcore/kernel.cpp new file mode 100644 index 00000000..f4e467f4 --- /dev/null +++ b/tests/regression/sgemm_tcore/kernel.cpp @@ -0,0 +1,285 @@ +#define RISCV_CUSTOM3 0x7B + +#include +#include +#include +#include +#include "common.h" + +#define BM 16 +#define BN 16 +#define BK 8 + +inline void vx_wmma() { + asm volatile (".insn r %0, 0, 0, x0, x0, x0" :: "i"(RISCV_CUSTOM3)); +} + +void vx_wmma_load(volatile float *smem_A, volatile float *smem_B, int warp_x, int warp_y, int thread_in_warp) { + int tid = thread_in_warp; + int tg = tid / 4; + + // load A + int row = tid % 4; + row += (tg * 8) % 16; + row += (tg / 4) * 4; + + int smem_A_m = 32; + int smem_A_n = 8; + int smem_B_m = 8; + int smem_B_n = 32; + + int A_offset = (row + BM * warp_y) * smem_A_n; + + asm volatile ("flw f0, %0" :: "m"(smem_A[A_offset + 0])); + asm volatile ("flw f1, %0" :: "m"(smem_A[A_offset + 1])); + asm volatile ("flw f2, %0" :: "m"(smem_A[A_offset + 2])); + asm volatile ("flw f3, %0" :: "m"(smem_A[A_offset + 3])); + asm volatile ("flw f4, %0" :: "m"(smem_A[A_offset + 4])); + asm volatile ("flw f5, %0" :: "m"(smem_A[A_offset + 5])); + asm volatile ("flw f6, %0" :: "m"(smem_A[A_offset + 6])); + asm volatile ("flw f7, %0" :: "m"(smem_A[A_offset + 7])); + + // load B + int col = tid % 4; + col += ((tg % 4) / 2) * 8; + col += (tg / 4) * 4; + + asm volatile ("flw f8 , %0" :: "m"(smem_B[(0 * smem_B_n) + warp_x * BN + col])); + asm volatile ("flw f9 , %0" :: "m"(smem_B[(1 * smem_B_n) + warp_x * BN + col])); + asm volatile ("flw f10, %0" :: "m"(smem_B[(2 * smem_B_n) + warp_x * BN + col])); + asm volatile ("flw f11, %0" :: "m"(smem_B[(3 * smem_B_n) + warp_x * BN + col])); + asm volatile ("flw f12, %0" :: "m"(smem_B[(4 * smem_B_n) + warp_x * BN + col])); + asm volatile ("flw f13, %0" :: "m"(smem_B[(5 * smem_B_n) + warp_x * BN + col])); + asm volatile ("flw f14, %0" :: "m"(smem_B[(6 * smem_B_n) + warp_x * BN + col])); + asm volatile ("flw f15, %0" :: "m"(smem_B[(7 * smem_B_n) + warp_x * BN + col])); +} + +inline void initialize_C() { + // initialize C to zeros + asm volatile ("fmv.w.x f16, x0"); + asm volatile ("fmv.w.x f17, x0"); + asm volatile ("fmv.w.x f18, x0"); + asm volatile ("fmv.w.x f19, x0"); + asm volatile ("fmv.w.x f20, x0"); + asm volatile ("fmv.w.x f21, x0"); + asm volatile ("fmv.w.x f22, x0"); + asm volatile ("fmv.w.x f23, x0"); +} + +inline void write_results( + volatile float *local_warp_results, + int thread_in_warp, + int warp_x, + int warp_y, + int dim_m, + int dim_n, + float *C, + int threadblock_id_x, + int threadblock_id_y +) { + int tid = thread_in_warp; + int tg = tid / 4; + + asm volatile ("fsw f16, %0" :: "m"(local_warp_results[tid*8+0])); + asm volatile ("fsw f17, %0" :: "m"(local_warp_results[tid*8+1])); + asm volatile ("fsw f18, %0" :: "m"(local_warp_results[tid*8+2])); + asm volatile ("fsw f19, %0" :: "m"(local_warp_results[tid*8+3])); + asm volatile ("fsw f20, %0" :: "m"(local_warp_results[tid*8+4])); + asm volatile ("fsw f21, %0" :: "m"(local_warp_results[tid*8+5])); + asm volatile ("fsw f22, %0" :: "m"(local_warp_results[tid*8+6])); + asm volatile ("fsw f23, %0" :: "m"(local_warp_results[tid*8+7])); + + /* + col = ((threadgroup % 4) // 2) * 8 + row = (threadgroup * 8) % 16 + row += (threadgroup // 4) * 4 + offsets = [(0, 0), (0, 1), (2, 0), (2, 1), (0, 4), (0, 5), (2, 4), (2, 5)] + offset = offsets[register-16] + row += offset[0] + col += offset[1] + thread_offsets = [(0, 0), (1, 0), (0, 2), (1, 2)] + thread_offset = thread_offsets[thread % 4] + row += thread_offset[0] + col += thread_offset[1] + return (row, col) + */ + + int local_col = ((tg % 4) / 2) * 8; + int local_row = (tg * 8) % 16; + local_row += (tg / 4) * 4; + + // int row_offsets[8] = {0, 0, 2, 2, 0, 0, 2, 2}; + // int col_offsets[8] = {0, 1, 0, 1, 4, 5, 4, 5}; + + // int thread_row_offsets[4] = {0, 1, 0, 1}; + // int thread_col_offsets[4] = {0, 0, 2, 2}; + int thread_row_offset = (tid % 4) % 2; + int thread_col_offset = ((tid % 4) / 2) * 2; + + float *global_offset_C = C + (threadblock_id_y * BM * 2 + warp_y * BM) * dim_n + threadblock_id_x * BN * 2 + warp_x * BM; + for (int i = 0; i < 8; i += 1) { + int row_offset = ((i / 2) % 2) * 2; + int col_offset = (i / 4) * 4 + i % 2; + + int adjusted_local_row = local_row + thread_row_offset + row_offset; + int adjusted_local_col = local_col + thread_col_offset + col_offset; + + float v = local_warp_results[tid*8+i]; + global_offset_C[adjusted_local_row * dim_n + adjusted_local_col] = v; + } +} + +void threadblock_barrier(unsigned int tid_in_threadblock, unsigned int barrier_id, unsigned int count) { + vx_fence(); + vx_barrier(barrier_id, count); +} + +void thread_block_gemm(kernel_arg_t *__UNIFORM__ arg, + const uint32_t tid_in_threadblock, + const uint32_t threadblock_dim_x, + const uint32_t threadblock_dim_y, + const uint32_t threadblock_id_x, + const uint32_t threadblock_id_y, + const uint32_t threadblock_id, + float *sharedmem_per_threadblock) { + const float *A = (const float *)arg->addr_a; + const float *B = (const float *)arg->addr_b; + float *C = (float *)arg->addr_c; + + const uint32_t dim_m = arg->dim_m; + const uint32_t dim_n = arg->dim_n; + const uint32_t dim_k = arg->dim_k; + + // FIXME: Output block size is assumed to be square, i.e. BM == BN + // const uint32_t BM = threadblock_dim_y; + // const uint32_t BN = threadblock_dim_y; + // const uint32_t BK = threadblock_dim_x; + // constexpr uint32_t BM = 8; + // constexpr uint32_t BN = 8; + // constexpr uint32_t BK = 2; + + const uint32_t warp_in_threadblock = tid_in_threadblock / 32; + const uint32_t tid_in_warp = tid_in_threadblock % 32; + const uint32_t warp_x = warp_in_threadblock % 2; + const uint32_t warp_y = warp_in_threadblock / 2; + + const uint32_t global_a_row = threadblock_dim_y * threadblock_id_y; + + // 32 * 8 block of A, being loaded by 4 warps + const uint32_t local_a_row = warp_y * BM + warp_x * (BM / 2) + (tid_in_warp / BK); + const uint32_t local_a_col = tid_in_warp % BK; + + // 8 * 32 block of B, being loaded by 4 warps + // do a fat coalesced load + const uint32_t global_b_col = threadblock_dim_x * threadblock_id_x; + const uint32_t local_b_row = warp_in_threadblock; + const uint32_t local_b_col = tid_in_warp; + + + volatile float *local_a = sharedmem_per_threadblock; + const size_t local_a_elems = (threadblock_dim_y * BK); + volatile float *local_b = sharedmem_per_threadblock + local_a_elems; + const size_t local_b_elems = (threadblock_dim_x * BK); + volatile float *local_warp_results = local_b + local_b_elems + (warp_in_threadblock * BM * BN); + + // clear out C + initialize_C(); + + for (uint32_t k = 0; k < dim_k; k += BK) { + // Data move from GMEM to SMEM + // + // Make sure global offset values for A and B are contiguous between + // neighboring threads to ensure GMEM coalescing. (not possible for A here, but for B it's doable) + + // coalesced load from global memory -> unit-stride store into shared memory + uint32_t global_a_offset = + dim_k * (global_a_row + local_a_row) + (k + local_a_col); + uint32_t shared_a_offset = + BK * local_a_row + local_a_col; + + local_a[shared_a_offset] = A[global_a_offset]; + + global_a_offset += dim_k * (BM / 4); + shared_a_offset += BK * (BM / 4); + + local_a[shared_a_offset] = A[global_a_offset]; + + uint32_t global_b_offset = + dim_n * (k + local_b_row) + (global_b_col + local_b_col); + uint32_t shared_b_offset = + (BN * 2) * (local_b_row) + local_b_col; + + local_b[shared_b_offset] = B[global_b_offset]; + + global_b_offset += dim_n * (BK / 2); + shared_b_offset += (BN * 2) * (BK / 2); + + local_b[shared_b_offset] = B[global_b_offset]; + + // want all 4 warps to reach barrier before moving on (just use barrier 0 for now) + threadblock_barrier(tid_in_threadblock, 0, 4); + + // perform wmma + vx_wmma_load(local_a, local_b, warp_x, warp_y, tid_in_warp); + vx_wmma(); + + // same as above + threadblock_barrier(tid_in_threadblock, 0, 4); + } + + write_results( + local_warp_results, + tid_in_warp, + warp_x, + warp_y, + dim_m, + dim_n, + C, + threadblock_id_x, + threadblock_id_y + ); +} + +void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { + // @perf: All threads are running these compute whose result is mostly same + // across the threadblock + const int NT = 32; // vx_num_threads(); + const int NW = 4; // vx_num_warps(); + const uint32_t threads_per_threadblock = NT * NW; + + // matches 4 warp capacity + const uint32_t threadblock_dim_x = 2 * BN; + const uint32_t threadblock_dim_y = 2 * BM; + const int threadblock_id = task_id / threads_per_threadblock; + const int tid_in_threadblock = task_id % threads_per_threadblock; + + const uint32_t dim_m = arg->dim_m; + const uint32_t dim_n = arg->dim_n; + const uint32_t dim_n_in_blocks = dim_n / threadblock_dim_x; + const int threadblock_id_x = threadblock_id % dim_n_in_blocks; + const int threadblock_id_y = threadblock_id / dim_n_in_blocks; + + // "static" shared memory allocation. This would determine threadblock + // occupancy of a single cluster + // only 1 threadblock running at a time, so this is ok + float *sharedmem_per_threadblock = + (float *)DEV_SMEM_START_ADDR; // + (2 * BM * BK) + (2 * BN * BK) * threadblock_id; + + thread_block_gemm(arg, tid_in_threadblock, threadblock_dim_x, + threadblock_dim_y, threadblock_id_x, threadblock_id_y, + threadblock_id, sharedmem_per_threadblock); +} + +int main() { + kernel_arg_t *arg = (kernel_arg_t *)KERNEL_ARG_DEV_MEM_ADDR; + int NT = vx_num_threads(); + + // TODO: add support for edge-case (m, n not divisible by 16) + const uint32_t grid_size = arg->dim_m * arg->dim_n * NT / (BM * BN); + + // for now, simplifying assumption of just 1 core + // vx_spawn_tasks_contiguous first runs warps 1 through NW, then NW+1 through 2*NW, etc. + // we can thus treat 1 through NW as a single threadblock for the purposes of the optimization. + vx_spawn_tasks_contiguous(grid_size, (vx_spawn_tasks_cb)kernel_body, arg); + return 0; +} \ No newline at end of file diff --git a/tests/regression/sgemm_tcore/main.cpp b/tests/regression/sgemm_tcore/main.cpp new file mode 100644 index 00000000..5ae65809 --- /dev/null +++ b/tests/regression/sgemm_tcore/main.cpp @@ -0,0 +1,270 @@ +#include +#include +#include +#include +#include +#include +#include "common.h" + +#define RT_CHECK(_expr) \ + do { \ + int _ret = _expr; \ + if (0 == _ret) \ + break; \ + printf("Error: '%s' returned %d!\n", #_expr, (int)_ret); \ + cleanup(); \ + exit(-1); \ + } while (false) + +/////////////////////////////////////////////////////////////////////////////// + +const char* kernel_file = "kernel.bin"; +uint32_t count = 0; + +std::vector src_a_data; +std::vector src_b_data; +std::vector ref_data; + +vx_device_h device = nullptr; +std::vector staging_buf; +kernel_arg_t kernel_arg = {}; + +static void show_usage() { + std::cout << "Vortex Test." << std::endl; + std::cout << "Usage: [-k: kernel] [-n words] [-h: help]" << std::endl; +} + +static void parse_args(int argc, char **argv) { + int c; + while ((c = getopt(argc, argv, "n:k:h?")) != -1) { + switch (c) { + case 'n': + count = atoi(optarg); + break; + case 'k': + kernel_file = optarg; + break; + case 'h': + case '?': { + show_usage(); + exit(0); + } break; + default: + show_usage(); + exit(-1); + } + } +} + +void cleanup() { + if (device) { + vx_mem_free(device, kernel_arg.addr_a); + vx_mem_free(device, kernel_arg.addr_b); + vx_mem_free(device, kernel_arg.addr_c); + vx_dev_close(device); + } +} + +void generate_source_matrix(uint32_t dim_m, uint32_t dim_n, uint32_t dim_k) { + src_a_data.resize(dim_m * dim_k); + src_b_data.resize(dim_k * dim_n); + + for (uint32_t i = 0; i < src_a_data.size(); ++i) { + src_a_data[i] = static_cast(i); + std::cout << "A: " << i << ": value=" << src_a_data[i] << std::endl; + } + for (uint32_t i = 0; i < src_b_data.size(); ++i) { + src_b_data[i] = static_cast(i); + std::cout << "B: " << i << ": value=" << src_b_data[i] << std::endl; + } +} + +void generate_reference_matmul(uint32_t dim_m, uint32_t dim_n, uint32_t dim_k) { + ref_data.resize(dim_m * dim_n); + + for (uint32_t i = 0; i < dim_m; ++i) { + for (uint32_t j = 0; j < dim_n; ++j) { + float ref = 0.0f; + for (uint32_t k = 0; k < dim_k; ++k) { + ref += src_a_data[dim_k * i + k] * src_b_data[dim_n * k + j]; + } + ref_data.at(dim_n * i + j) = ref; + } + } +} + +int run_test(const kernel_arg_t& kernel_arg, + uint32_t buf_size, + uint32_t dim_m, uint32_t dim_n) { + // start device + std::cout << "start device" << std::endl; + RT_CHECK(vx_start(device)); + + // wait for completion + std::cout << "wait for completion" << std::endl; + RT_CHECK(vx_ready_wait(device, VX_MAX_TIMEOUT)); + + // download destination buffer + std::cout << "download destination buffer" << std::endl; + RT_CHECK(vx_copy_from_dev(device, staging_buf.data(), kernel_arg.addr_c, buf_size)); + + // verify result + std::cout << "verify result" << std::endl; + { + int errors = 0; + auto buf_ptr = (float*)staging_buf.data(); + for (uint32_t i = 0; i < dim_m * dim_n; ++i) { + float ref = ref_data.at(i); + float cur = buf_ptr[i]; + if (std::abs((cur - ref) / ref) > 1e-6) { + std::cout << "error at result #" << std::dec << i + << std::hex << ": actual=" << cur << ", expected=" << ref << std::endl; + ++errors; + } + } + if (errors != 0) { + std::cout << "Found " << std::dec << errors << " errors!" << std::endl; + std::cout << "FAILED!" << std::endl; + return 1; + } + } + + return 0; +} + +int main(int argc, char *argv[]) { + // parse command arguments + parse_args(argc, argv); + + if (count == 0) { + count = 1; + } + + std::srand(50); + + // open device connection + std::cout << "open device connection" << std::endl; + RT_CHECK(vx_dev_open(&device)); + + // FIXME: hardcoded + uint32_t dim_m = 64; + uint32_t dim_n = 64; + uint32_t dim_k = 64; + + generate_source_matrix(dim_m, dim_n, dim_k); + generate_reference_matmul(dim_m, dim_n, dim_k); + + uint32_t src_a_buf_size = src_a_data.size() * sizeof(src_a_data[0]); + uint32_t src_b_buf_size = src_b_data.size() * sizeof(src_b_data[0]); + uint32_t dst_buf_size = ref_data.size() * sizeof(src_a_data[0]); + + std::cout << "buffer size: " << dst_buf_size << " bytes" << std::endl; + + // upload program + std::cout << "upload program" << std::endl; + RT_CHECK(vx_upload_kernel_file(device, kernel_file)); + + // allocate device memory + std::cout << "allocate device memory" << std::endl; + RT_CHECK(vx_mem_alloc(device, src_a_buf_size, VX_MEM_TYPE_GLOBAL, &kernel_arg.addr_a)); + RT_CHECK(vx_mem_alloc(device, src_b_buf_size, VX_MEM_TYPE_GLOBAL, &kernel_arg.addr_b)); + RT_CHECK(vx_mem_alloc(device, dst_buf_size, VX_MEM_TYPE_GLOBAL, &kernel_arg.addr_c)); + + kernel_arg.dim_m = dim_m; + kernel_arg.dim_n = dim_n; + kernel_arg.dim_k = dim_k; + + std::cout << "dev_addr_a=0x" << std::hex << kernel_arg.addr_a << std::endl; + std::cout << "dev_addr_b=0x" << std::hex << kernel_arg.addr_b << std::endl; + std::cout << "dev_addr_c=0x" << std::hex << kernel_arg.addr_c << std::endl; + + // allocate staging buffer + { + std::cout << "allocate staging buffer" << std::endl; + uint32_t staging_buf_size = std::max( + src_a_buf_size, + std::max( + src_b_buf_size, + std::max(dst_buf_size, sizeof(kernel_arg_t)))); + staging_buf.resize(staging_buf_size); + } + + // upload kernel argument + { + std::cout << "upload kernel argument" << std::endl; + auto buf_ptr = staging_buf.data(); + memcpy(buf_ptr, &kernel_arg, sizeof(kernel_arg_t)); + RT_CHECK(vx_copy_to_dev(device, KERNEL_ARG_DEV_MEM_ADDR, staging_buf.data(), sizeof(kernel_arg_t))); + + std::cout << "uploading argument buffer to device, device mem address=" + << std::hex << KERNEL_ARG_DEV_MEM_ADDR << ", size=" << std::dec + << sizeof(kernel_arg_t) << " bytes\n"; + std::ofstream file("args.bin", std::ios::binary | std::ios::out); + if (!file) { + std::cerr << "error: failed to open args.bin for writing\n"; + exit(EXIT_FAILURE); + } + file.write(reinterpret_cast(staging_buf.data()), + sizeof(kernel_arg_t)); + file.close(); + } + + // upload source buffer + { + { + auto buf_ptr = staging_buf.data(); + memcpy(buf_ptr, src_a_data.data(), src_a_data.size() * sizeof(float)); + RT_CHECK(vx_copy_to_dev(device, kernel_arg.addr_a, staging_buf.data(), + src_a_buf_size)); + + std::cout << "uploading source A matrix to device, device mem address=" + << std::hex << kernel_arg.addr_a << ", size=" << std::dec + << src_a_buf_size << " bytes\n"; + std::ofstream file("input.a.bin", std::ios::binary | std::ios::out); + if (!file) { + std::cerr << "error: failed to open args.bin for writing\n"; + exit(EXIT_FAILURE); + } + file.write(reinterpret_cast(buf_ptr), src_a_buf_size); + file.close(); + } + { + auto buf_ptr = staging_buf.data(); + memcpy(buf_ptr, src_b_data.data(), src_b_data.size() * sizeof(float)); + RT_CHECK(vx_copy_to_dev(device, kernel_arg.addr_b, staging_buf.data(), + src_b_buf_size)); + + std::cout << "uploading source B matrix to device, device mem address=" + << std::hex << kernel_arg.addr_b << ", size=" << std::dec + << src_b_buf_size << " bytes\n"; + std::ofstream file("input.b.bin", std::ios::binary | std::ios::out); + if (!file) { + std::cerr << "error: failed to open args.bin for writing\n"; + exit(EXIT_FAILURE); + } + file.write(reinterpret_cast(buf_ptr), src_b_buf_size); + file.close(); + } + } + + // clear destination buffer + { + std::cout << "clear destination buffer" << std::endl; + auto buf_ptr = (int32_t*)staging_buf.data(); + for (uint32_t i = 0; i < ref_data.size(); ++i) { + buf_ptr[i] = 0xdeadbeef; + } + RT_CHECK(vx_copy_to_dev(device, kernel_arg.addr_c, staging_buf.data(), dst_buf_size)); + } + + // run tests + std::cout << "run tests" << std::endl; + RT_CHECK(run_test(kernel_arg, dst_buf_size, kernel_arg.dim_m, kernel_arg.dim_n)); + std::cout << "PASSED!" << std::endl; + + // cleanup + std::cout << "cleanup" << std::endl; + cleanup(); + + return 0; +} \ No newline at end of file