tensor: Add dispatch unit to narrow to BLOCK_SIZE=1
This commit is contained in:
@@ -1,7 +1,7 @@
|
||||
`ifdef EXT_T_ENABLE
|
||||
`include "VX_fpu_define.vh"
|
||||
|
||||
module VX_tensor_core #(
|
||||
module VX_tensor_core import VX_gpu_pkg::*; #(
|
||||
|
||||
) (
|
||||
input clk,
|
||||
@@ -10,15 +10,54 @@ module VX_tensor_core #(
|
||||
VX_dispatch_if.slave dispatch_if [`ISSUE_WIDTH],
|
||||
VX_commit_if.master commit_if [`ISSUE_WIDTH]
|
||||
);
|
||||
for (genvar i = 0; i < `ISSUE_WIDTH; ++i) begin
|
||||
localparam BLOCK_SIZE = 1;
|
||||
localparam NUM_LANES = `NUM_THREADS;
|
||||
// localparam PARTIAL_BW = (BLOCK_SIZE != `ISSUE_WIDTH) || (NUM_LANES != `NUM_THREADS);
|
||||
localparam PARTIAL_BW = 1;
|
||||
|
||||
VX_execute_if #(
|
||||
.NUM_LANES (NUM_LANES)
|
||||
) execute_if[BLOCK_SIZE]();
|
||||
|
||||
`RESET_RELAY (dispatch_reset, reset);
|
||||
|
||||
VX_dispatch_unit #(
|
||||
.BLOCK_SIZE (BLOCK_SIZE),
|
||||
.NUM_LANES (NUM_LANES),
|
||||
.OUT_REG (PARTIAL_BW ? 1 : 0)
|
||||
) dispatch_unit (
|
||||
.clk (clk),
|
||||
.reset (dispatch_reset),
|
||||
.dispatch_if(dispatch_if),
|
||||
.execute_if (execute_if)
|
||||
);
|
||||
|
||||
VX_commit_if #(
|
||||
.NUM_LANES (NUM_LANES)
|
||||
) commit_block_if[BLOCK_SIZE]();
|
||||
|
||||
`RESET_RELAY (commit_reset, reset);
|
||||
|
||||
VX_gather_unit #(
|
||||
.BLOCK_SIZE (BLOCK_SIZE),
|
||||
.NUM_LANES (NUM_LANES),
|
||||
.OUT_REG (PARTIAL_BW ? 3 : 0) // FIXME: why 3?
|
||||
) gather_unit (
|
||||
.clk (clk),
|
||||
.reset (commit_reset),
|
||||
.commit_in_if (commit_block_if),
|
||||
.commit_out_if (commit_if)
|
||||
);
|
||||
|
||||
for (genvar block_idx = 0; block_idx < BLOCK_SIZE; ++block_idx) begin
|
||||
VX_tensor_core_warp #(
|
||||
.ISW(i)
|
||||
.ISW(1) // FIXME: not block_idx
|
||||
) tensor_core (
|
||||
.clk(clk),
|
||||
.reset(reset),
|
||||
|
||||
.dispatch_if(dispatch_if[i]),
|
||||
.commit_if(commit_if[i])
|
||||
.execute_if(execute_if[block_idx]),
|
||||
.commit_if(commit_block_if[block_idx])
|
||||
);
|
||||
end
|
||||
|
||||
@@ -30,7 +69,7 @@ module VX_tensor_core_warp import VX_gpu_pkg::*; #(
|
||||
input clk,
|
||||
input reset,
|
||||
|
||||
VX_dispatch_if.slave dispatch_if,
|
||||
VX_execute_if.slave execute_if,
|
||||
VX_commit_if.master commit_if
|
||||
);
|
||||
localparam NUM_OCTETS = (`NUM_THREADS / 8);
|
||||
@@ -39,14 +78,15 @@ module VX_tensor_core_warp import VX_gpu_pkg::*; #(
|
||||
// FIXME: not sure this is the right logic; this just fills in what works.
|
||||
localparam LANE_OFFSET_THREADGROUP = (4 * NUM_OCTETS);
|
||||
|
||||
wire [1:0] step = 2'(dispatch_if.data.op_type);
|
||||
wire [1:0] step = 2'(execute_if.data.op_type);
|
||||
logic [NUM_OCTETS-1:0] octet_results_valid;
|
||||
logic [NUM_OCTETS-1:0] octet_results_ready;
|
||||
logic [NUM_OCTETS-1:0] octet_operands_ready;
|
||||
// FIXME: should be NUM_LANES?
|
||||
logic [`NUM_THREADS-1:0][`XLEN-1:0] wb_data_0;
|
||||
logic [`NUM_THREADS-1:0][`XLEN-1:0] wb_data_1;
|
||||
|
||||
assign dispatch_if.ready = &octet_operands_ready;
|
||||
assign execute_if.ready = &octet_operands_ready;
|
||||
|
||||
`ifdef EXT_T_ENABLE
|
||||
for (genvar i = 0; i < NUM_OCTETS; ++i) begin
|
||||
@@ -55,13 +95,13 @@ module VX_tensor_core_warp import VX_gpu_pkg::*; #(
|
||||
`endif
|
||||
// lane-to-octet mapping; see figure 13 of the paper
|
||||
wire [7:0][31:0] octet_A = {
|
||||
dispatch_if.data.rs1_data[LANE_OFFSET_THREADGROUP + 4*i +: 4], dispatch_if.data.rs1_data[4*i +: 4]
|
||||
execute_if.data.rs1_data[LANE_OFFSET_THREADGROUP + 4*i +: 4], execute_if.data.rs1_data[4*i +: 4]
|
||||
};
|
||||
wire [7:0][31:0] octet_B = {
|
||||
dispatch_if.data.rs2_data[LANE_OFFSET_THREADGROUP + 4*i +: 4], dispatch_if.data.rs2_data[4*i +: 4]
|
||||
execute_if.data.rs2_data[LANE_OFFSET_THREADGROUP + 4*i +: 4], execute_if.data.rs2_data[4*i +: 4]
|
||||
};
|
||||
wire [7:0][31:0] octet_C = {
|
||||
dispatch_if.data.rs3_data[LANE_OFFSET_THREADGROUP + 4*i +: 4], dispatch_if.data.rs3_data[4*i +: 4]
|
||||
execute_if.data.rs3_data[LANE_OFFSET_THREADGROUP + 4*i +: 4], execute_if.data.rs3_data[4*i +: 4]
|
||||
};
|
||||
|
||||
logic [3:0][3:0][31:0] octet_D;
|
||||
@@ -77,7 +117,7 @@ module VX_tensor_core_warp import VX_gpu_pkg::*; #(
|
||||
.A_in(octet_A),
|
||||
.B_in(octet_B),
|
||||
.C_in(octet_C),
|
||||
.operands_valid(dispatch_if.valid),
|
||||
.operands_valid(execute_if.valid),
|
||||
.operands_ready(octet_operands_ready[i]),
|
||||
|
||||
.step(step),
|
||||
@@ -126,18 +166,18 @@ module VX_tensor_core_warp import VX_gpu_pkg::*; #(
|
||||
|
||||
localparam DATAW = `UUID_WIDTH + `NW_WIDTH + `NUM_THREADS + `XLEN + 1 + `NR_BITS;
|
||||
|
||||
wire dispatch_if_fire = dispatch_if.valid && dispatch_if.ready;
|
||||
wire execute_if_fire = execute_if.valid && execute_if.ready;
|
||||
wire commit_if_fire = commit_if.valid && commit_if.ready;
|
||||
wire [DATAW-1:0] dispatch_if_data_enq = {
|
||||
dispatch_if.data.uuid,
|
||||
wis_to_wid(dispatch_if.data.wis, ISW),
|
||||
dispatch_if.data.tmask,
|
||||
dispatch_if.data.PC,
|
||||
dispatch_if.data.wb,
|
||||
dispatch_if.data.rd
|
||||
wire [DATAW-1:0] execute_if_data_enq = {
|
||||
execute_if.data.uuid,
|
||||
execute_if.data.wid,
|
||||
execute_if.data.tmask,
|
||||
execute_if.data.PC,
|
||||
execute_if.data.wb,
|
||||
execute_if.data.rd
|
||||
};
|
||||
|
||||
wire [DATAW-1:0] dispatch_if_data_deq;
|
||||
wire [DATAW-1:0] execute_if_data_deq;
|
||||
|
||||
// this is probably a little oversized
|
||||
VX_fifo_queue #(
|
||||
@@ -146,10 +186,10 @@ module VX_tensor_core_warp import VX_gpu_pkg::*; #(
|
||||
) pending_uops (
|
||||
.clk(clk),
|
||||
.reset(reset),
|
||||
.push(dispatch_if_fire),
|
||||
.push(execute_if_fire),
|
||||
.pop(commit_if_fire),
|
||||
.data_in(dispatch_if_data_enq),
|
||||
.data_out(dispatch_if_data_deq),
|
||||
.data_in(execute_if_data_enq),
|
||||
.data_out(execute_if_data_deq),
|
||||
`UNUSED_PIN(empty),
|
||||
`UNUSED_PIN(alm_empty),
|
||||
`UNUSED_PIN(full), // should be impossible to overflow
|
||||
@@ -163,7 +203,7 @@ module VX_tensor_core_warp import VX_gpu_pkg::*; #(
|
||||
|
||||
localparam COMMIT_DATAW = `UUID_WIDTH + `NW_WIDTH + `NUM_THREADS + `XLEN + 1 + `NR_BITS + (`NUM_THREADS * `XLEN) + 1 + 1 + 1;
|
||||
wire [COMMIT_DATAW-1:0] commit_if_data = {
|
||||
dispatch_if_data_deq, /* uuid ~ rd */
|
||||
execute_if_data_deq, /* uuid ~ rd */
|
||||
subcommit == 1'b0 ? wb_data_0 : wb_data_1, /* data */
|
||||
1'b0, /* pid */
|
||||
1'b1, /* sop */
|
||||
@@ -227,6 +267,10 @@ module VX_tensor_octet #(
|
||||
// note that not all lanes participate at every step
|
||||
case (step)
|
||||
2'b00: begin
|
||||
// Two A_in segments correspond to two 2x2 subtiles of A read
|
||||
// by two threadgroups: [0:2,0:2] and [4:6,0:2] in Step 0 of
|
||||
// Figure 10(b). B_in, on the other hand, is shared by both threadgroups.
|
||||
// Note that the k-dimension is shrunk from 4 to 2.
|
||||
A_half = { A_in[5:4], A_in[1:0] };
|
||||
B_half = B_in[3:0];
|
||||
end
|
||||
|
||||
Reference in New Issue
Block a user