diff --git a/hw/rtl/core/VX_tensor_core.sv b/hw/rtl/core/VX_tensor_core.sv index 71ed8538..14d8175b 100644 --- a/hw/rtl/core/VX_tensor_core.sv +++ b/hw/rtl/core/VX_tensor_core.sv @@ -1,7 +1,7 @@ `ifdef EXT_T_ENABLE `include "VX_fpu_define.vh" -module VX_tensor_core #( +module VX_tensor_core import VX_gpu_pkg::*; #( ) ( input clk, @@ -10,15 +10,54 @@ module VX_tensor_core #( VX_dispatch_if.slave dispatch_if [`ISSUE_WIDTH], VX_commit_if.master commit_if [`ISSUE_WIDTH] ); - for (genvar i = 0; i < `ISSUE_WIDTH; ++i) begin + localparam BLOCK_SIZE = 1; + localparam NUM_LANES = `NUM_THREADS; + // localparam PARTIAL_BW = (BLOCK_SIZE != `ISSUE_WIDTH) || (NUM_LANES != `NUM_THREADS); + localparam PARTIAL_BW = 1; + + VX_execute_if #( + .NUM_LANES (NUM_LANES) + ) execute_if[BLOCK_SIZE](); + + `RESET_RELAY (dispatch_reset, reset); + + VX_dispatch_unit #( + .BLOCK_SIZE (BLOCK_SIZE), + .NUM_LANES (NUM_LANES), + .OUT_REG (PARTIAL_BW ? 1 : 0) + ) dispatch_unit ( + .clk (clk), + .reset (dispatch_reset), + .dispatch_if(dispatch_if), + .execute_if (execute_if) + ); + + VX_commit_if #( + .NUM_LANES (NUM_LANES) + ) commit_block_if[BLOCK_SIZE](); + + `RESET_RELAY (commit_reset, reset); + + VX_gather_unit #( + .BLOCK_SIZE (BLOCK_SIZE), + .NUM_LANES (NUM_LANES), + .OUT_REG (PARTIAL_BW ? 3 : 0) // FIXME: why 3? + ) gather_unit ( + .clk (clk), + .reset (commit_reset), + .commit_in_if (commit_block_if), + .commit_out_if (commit_if) + ); + + for (genvar block_idx = 0; block_idx < BLOCK_SIZE; ++block_idx) begin VX_tensor_core_warp #( - .ISW(i) + .ISW(1) // FIXME: not block_idx ) tensor_core ( .clk(clk), .reset(reset), - .dispatch_if(dispatch_if[i]), - .commit_if(commit_if[i]) + .execute_if(execute_if[block_idx]), + .commit_if(commit_block_if[block_idx]) ); end @@ -30,7 +69,7 @@ module VX_tensor_core_warp import VX_gpu_pkg::*; #( input clk, input reset, - VX_dispatch_if.slave dispatch_if, + VX_execute_if.slave execute_if, VX_commit_if.master commit_if ); localparam NUM_OCTETS = (`NUM_THREADS / 8); @@ -39,14 +78,15 @@ module VX_tensor_core_warp import VX_gpu_pkg::*; #( // FIXME: not sure this is the right logic. just filling in what works localparam LANE_OFFSET_THREADGROUP = (4 * NUM_OCTETS); - wire [1:0] step = 2'(dispatch_if.data.op_type); + wire [1:0] step = 2'(execute_if.data.op_type); logic [NUM_OCTETS-1:0] octet_results_valid; logic [NUM_OCTETS-1:0] octet_results_ready; logic [NUM_OCTETS-1:0] octet_operands_ready; + // FIXME: should be NUM_LANES? logic [`NUM_THREADS-1:0][`XLEN-1:0] wb_data_0; logic [`NUM_THREADS-1:0][`XLEN-1:0] wb_data_1; - assign dispatch_if.ready = &octet_operands_ready; + assign execute_if.ready = &octet_operands_ready; `ifdef EXT_T_ENABLE for (genvar i = 0; i < NUM_OCTETS; ++i) begin @@ -55,13 +95,13 @@ module VX_tensor_core_warp import VX_gpu_pkg::*; #( `endif // lane-to-octet mapping; see figure 13 of the paper wire [7:0][31:0] octet_A = { - dispatch_if.data.rs1_data[LANE_OFFSET_THREADGROUP + 4*i +: 4], dispatch_if.data.rs1_data[4*i +: 4] + execute_if.data.rs1_data[LANE_OFFSET_THREADGROUP + 4*i +: 4], execute_if.data.rs1_data[4*i +: 4] }; wire [7:0][31:0] octet_B = { - dispatch_if.data.rs2_data[LANE_OFFSET_THREADGROUP + 4*i +: 4], dispatch_if.data.rs2_data[4*i +: 4] + execute_if.data.rs2_data[LANE_OFFSET_THREADGROUP + 4*i +: 4], execute_if.data.rs2_data[4*i +: 4] }; wire [7:0][31:0] octet_C = { - dispatch_if.data.rs3_data[LANE_OFFSET_THREADGROUP + 4*i +: 4], dispatch_if.data.rs3_data[4*i +: 4] + execute_if.data.rs3_data[LANE_OFFSET_THREADGROUP + 4*i +: 4], execute_if.data.rs3_data[4*i +: 4] }; logic [3:0][3:0][31:0] octet_D; @@ -77,7 +117,7 @@ module VX_tensor_core_warp import VX_gpu_pkg::*; #( .A_in(octet_A), .B_in(octet_B), .C_in(octet_C), - .operands_valid(dispatch_if.valid), + .operands_valid(execute_if.valid), .operands_ready(octet_operands_ready[i]), .step(step), @@ -126,18 +166,18 @@ module VX_tensor_core_warp import VX_gpu_pkg::*; #( localparam DATAW = `UUID_WIDTH + `NW_WIDTH + `NUM_THREADS + `XLEN + 1 + `NR_BITS; - wire dispatch_if_fire = dispatch_if.valid && dispatch_if.ready; + wire execute_if_fire = execute_if.valid && execute_if.ready; wire commit_if_fire = commit_if.valid && commit_if.ready; - wire [DATAW-1:0] dispatch_if_data_enq = { - dispatch_if.data.uuid, - wis_to_wid(dispatch_if.data.wis, ISW), - dispatch_if.data.tmask, - dispatch_if.data.PC, - dispatch_if.data.wb, - dispatch_if.data.rd + wire [DATAW-1:0] execute_if_data_enq = { + execute_if.data.uuid, + execute_if.data.wid, + execute_if.data.tmask, + execute_if.data.PC, + execute_if.data.wb, + execute_if.data.rd }; - wire [DATAW-1:0] dispatch_if_data_deq; + wire [DATAW-1:0] execute_if_data_deq; // this is probably a little oversized VX_fifo_queue #( @@ -146,10 +186,10 @@ module VX_tensor_core_warp import VX_gpu_pkg::*; #( ) pending_uops ( .clk(clk), .reset(reset), - .push(dispatch_if_fire), + .push(execute_if_fire), .pop(commit_if_fire), - .data_in(dispatch_if_data_enq), - .data_out(dispatch_if_data_deq), + .data_in(execute_if_data_enq), + .data_out(execute_if_data_deq), `UNUSED_PIN(empty), `UNUSED_PIN(alm_empty), `UNUSED_PIN(full), // should be impossible to overflow @@ -163,7 +203,7 @@ module VX_tensor_core_warp import VX_gpu_pkg::*; #( localparam COMMIT_DATAW = `UUID_WIDTH + `NW_WIDTH + `NUM_THREADS + `XLEN + 1 + `NR_BITS + (`NUM_THREADS * `XLEN) + 1 + 1 + 1; wire [COMMIT_DATAW-1:0] commit_if_data = { - dispatch_if_data_deq, /* uuid ~ rd */ + execute_if_data_deq, /* uuid ~ rd */ subcommit == 1'b0 ? wb_data_0 : wb_data_1, /* data */ 1'b0, /* pid */ 1'b1, /* sop */ @@ -227,6 +267,10 @@ module VX_tensor_octet #( // note that not all lanes participate at every step case (step) 2'b00: begin + // Two A_in segments correspond to two 2x2 subtiles of A read + // by two threadgroups: [0:2,0:2] and [4:6,0:2] in Step 0 of + // Figure 10(b). B_in OTOH is shared by two threadgroups. + // Note k-dimension is shrunk from 4 to 2. A_half = { A_in[5:4], A_in[1:0] }; B_half = B_in[3:0]; end