tensor: Add dispatch unit to narrow to BLOCK_SIZE=1

This commit is contained in:
Hansung Kim
2024-05-15 15:34:26 -07:00
parent 9f9ec10960
commit 1a1094b2bb

View File

@@ -1,7 +1,7 @@
`ifdef EXT_T_ENABLE
`include "VX_fpu_define.vh"
module VX_tensor_core #(
module VX_tensor_core import VX_gpu_pkg::*; #(
) (
input clk,
@@ -10,15 +10,54 @@ module VX_tensor_core #(
VX_dispatch_if.slave dispatch_if [`ISSUE_WIDTH],
VX_commit_if.master commit_if [`ISSUE_WIDTH]
);
for (genvar i = 0; i < `ISSUE_WIDTH; ++i) begin
localparam BLOCK_SIZE = 1;
localparam NUM_LANES = `NUM_THREADS;
// localparam PARTIAL_BW = (BLOCK_SIZE != `ISSUE_WIDTH) || (NUM_LANES != `NUM_THREADS);
localparam PARTIAL_BW = 1;
VX_execute_if #(
.NUM_LANES (NUM_LANES)
) execute_if[BLOCK_SIZE]();
`RESET_RELAY (dispatch_reset, reset);
VX_dispatch_unit #(
.BLOCK_SIZE (BLOCK_SIZE),
.NUM_LANES (NUM_LANES),
.OUT_REG (PARTIAL_BW ? 1 : 0)
) dispatch_unit (
.clk (clk),
.reset (dispatch_reset),
.dispatch_if(dispatch_if),
.execute_if (execute_if)
);
VX_commit_if #(
.NUM_LANES (NUM_LANES)
) commit_block_if[BLOCK_SIZE]();
`RESET_RELAY (commit_reset, reset);
VX_gather_unit #(
.BLOCK_SIZE (BLOCK_SIZE),
.NUM_LANES (NUM_LANES),
.OUT_REG (PARTIAL_BW ? 3 : 0) // FIXME: why 3?
) gather_unit (
.clk (clk),
.reset (commit_reset),
.commit_in_if (commit_block_if),
.commit_out_if (commit_if)
);
for (genvar block_idx = 0; block_idx < BLOCK_SIZE; ++block_idx) begin
VX_tensor_core_warp #(
.ISW(i)
.ISW(1) // FIXME: not block_idx
) tensor_core (
.clk(clk),
.reset(reset),
.dispatch_if(dispatch_if[i]),
.commit_if(commit_if[i])
.execute_if(execute_if[block_idx]),
.commit_if(commit_block_if[block_idx])
);
end
@@ -30,7 +69,7 @@ module VX_tensor_core_warp import VX_gpu_pkg::*; #(
input clk,
input reset,
VX_dispatch_if.slave dispatch_if,
VX_execute_if.slave execute_if,
VX_commit_if.master commit_if
);
localparam NUM_OCTETS = (`NUM_THREADS / 8);
@@ -39,14 +78,15 @@ module VX_tensor_core_warp import VX_gpu_pkg::*; #(
// FIXME: not sure this is the right logic. just filling in what works
localparam LANE_OFFSET_THREADGROUP = (4 * NUM_OCTETS);
wire [1:0] step = 2'(dispatch_if.data.op_type);
wire [1:0] step = 2'(execute_if.data.op_type);
logic [NUM_OCTETS-1:0] octet_results_valid;
logic [NUM_OCTETS-1:0] octet_results_ready;
logic [NUM_OCTETS-1:0] octet_operands_ready;
// FIXME: should be NUM_LANES?
logic [`NUM_THREADS-1:0][`XLEN-1:0] wb_data_0;
logic [`NUM_THREADS-1:0][`XLEN-1:0] wb_data_1;
assign dispatch_if.ready = &octet_operands_ready;
assign execute_if.ready = &octet_operands_ready;
`ifdef EXT_T_ENABLE
for (genvar i = 0; i < NUM_OCTETS; ++i) begin
@@ -55,13 +95,13 @@ module VX_tensor_core_warp import VX_gpu_pkg::*; #(
`endif
// lane-to-octet mapping; see figure 13 of the paper
wire [7:0][31:0] octet_A = {
dispatch_if.data.rs1_data[LANE_OFFSET_THREADGROUP + 4*i +: 4], dispatch_if.data.rs1_data[4*i +: 4]
execute_if.data.rs1_data[LANE_OFFSET_THREADGROUP + 4*i +: 4], execute_if.data.rs1_data[4*i +: 4]
};
wire [7:0][31:0] octet_B = {
dispatch_if.data.rs2_data[LANE_OFFSET_THREADGROUP + 4*i +: 4], dispatch_if.data.rs2_data[4*i +: 4]
execute_if.data.rs2_data[LANE_OFFSET_THREADGROUP + 4*i +: 4], execute_if.data.rs2_data[4*i +: 4]
};
wire [7:0][31:0] octet_C = {
dispatch_if.data.rs3_data[LANE_OFFSET_THREADGROUP + 4*i +: 4], dispatch_if.data.rs3_data[4*i +: 4]
execute_if.data.rs3_data[LANE_OFFSET_THREADGROUP + 4*i +: 4], execute_if.data.rs3_data[4*i +: 4]
};
logic [3:0][3:0][31:0] octet_D;
@@ -77,7 +117,7 @@ module VX_tensor_core_warp import VX_gpu_pkg::*; #(
.A_in(octet_A),
.B_in(octet_B),
.C_in(octet_C),
.operands_valid(dispatch_if.valid),
.operands_valid(execute_if.valid),
.operands_ready(octet_operands_ready[i]),
.step(step),
@@ -126,18 +166,18 @@ module VX_tensor_core_warp import VX_gpu_pkg::*; #(
localparam DATAW = `UUID_WIDTH + `NW_WIDTH + `NUM_THREADS + `XLEN + 1 + `NR_BITS;
wire dispatch_if_fire = dispatch_if.valid && dispatch_if.ready;
wire execute_if_fire = execute_if.valid && execute_if.ready;
wire commit_if_fire = commit_if.valid && commit_if.ready;
wire [DATAW-1:0] dispatch_if_data_enq = {
dispatch_if.data.uuid,
wis_to_wid(dispatch_if.data.wis, ISW),
dispatch_if.data.tmask,
dispatch_if.data.PC,
dispatch_if.data.wb,
dispatch_if.data.rd
wire [DATAW-1:0] execute_if_data_enq = {
execute_if.data.uuid,
execute_if.data.wid,
execute_if.data.tmask,
execute_if.data.PC,
execute_if.data.wb,
execute_if.data.rd
};
wire [DATAW-1:0] dispatch_if_data_deq;
wire [DATAW-1:0] execute_if_data_deq;
// this is probably a little oversized
VX_fifo_queue #(
@@ -146,10 +186,10 @@ module VX_tensor_core_warp import VX_gpu_pkg::*; #(
) pending_uops (
.clk(clk),
.reset(reset),
.push(dispatch_if_fire),
.push(execute_if_fire),
.pop(commit_if_fire),
.data_in(dispatch_if_data_enq),
.data_out(dispatch_if_data_deq),
.data_in(execute_if_data_enq),
.data_out(execute_if_data_deq),
`UNUSED_PIN(empty),
`UNUSED_PIN(alm_empty),
`UNUSED_PIN(full), // should be impossible to overflow
@@ -163,7 +203,7 @@ module VX_tensor_core_warp import VX_gpu_pkg::*; #(
localparam COMMIT_DATAW = `UUID_WIDTH + `NW_WIDTH + `NUM_THREADS + `XLEN + 1 + `NR_BITS + (`NUM_THREADS * `XLEN) + 1 + 1 + 1;
wire [COMMIT_DATAW-1:0] commit_if_data = {
dispatch_if_data_deq, /* uuid ~ rd */
execute_if_data_deq, /* uuid ~ rd */
subcommit == 1'b0 ? wb_data_0 : wb_data_1, /* data */
1'b0, /* pid */
1'b1, /* sop */
@@ -227,6 +267,10 @@ module VX_tensor_octet #(
// note that not all lanes participate at every step
case (step)
2'b00: begin
// Two A_in segments correspond to two 2x2 subtiles of A read
// by two threadgroups: [0:2,0:2] and [4:6,0:2] in Step 0 of
// Figure 10(b). B_in OTOH is shared by two threadgroups.
// Note k-dimension is shrunk from 4 to 2.
A_half = { A_in[5:4], A_in[1:0] };
B_half = B_in[3:0];
end