tensor: Split execution module from pipeline logic
This commit is contained in:
@@ -55,7 +55,7 @@ module VX_tensor_core import VX_gpu_pkg::*; #(
|
||||
VX_tensor_hopper_core_block #(
|
||||
.ISW(1), // FIXME: not block_idx
|
||||
.FP16(FP16)
|
||||
) tensor_hopper_core (
|
||||
) tensor_hopper_core_block (
|
||||
.clk(clk),
|
||||
.reset(reset),
|
||||
.execute_if(execute_if[block_idx]),
|
||||
|
||||
@@ -12,7 +12,7 @@ module VX_tensor_hopper_core_block import VX_gpu_pkg::*; #(
|
||||
VX_commit_if.master commit_if
|
||||
);
|
||||
localparam NUM_LANES = `NUM_THREADS;
|
||||
localparam METADATA_QUEUE_DEPTH = 16; // FIXME: arbitrary
|
||||
localparam METADATA_QUEUE_DEPTH = 2; // FIXME: arbitrary
|
||||
|
||||
/* commit_if.data_t parts that we need to keep around:
|
||||
- uuid
|
||||
@@ -22,29 +22,23 @@ module VX_tensor_hopper_core_block import VX_gpu_pkg::*; #(
|
||||
- wb
|
||||
- rd
|
||||
*/
|
||||
localparam DATAW = `UUID_WIDTH + `NW_WIDTH + `NUM_THREADS + `XLEN + 1 + `NR_BITS;
|
||||
|
||||
wire operand_enq_fire = execute_if.valid && execute_if.ready;
|
||||
wire commit_if_fire = commit_if.valid && commit_if.ready;
|
||||
|
||||
wire [`NUM_WARPS-1:0][`UUID_WIDTH-1:0] execute_if_data_uuid;
|
||||
wire [`NUM_WARPS-1:0][`NW_WIDTH-1:0] execute_if_data_wid;
|
||||
wire [`NUM_WARPS-1:0][NUM_LANES-1:0] execute_if_data_tmask;
|
||||
wire [`NUM_WARPS-1:0][`XLEN-1:0] execute_if_data_PC;
|
||||
wire [`NUM_WARPS-1:0] execute_if_data_wb;
|
||||
wire [`NUM_WARPS-1:0][`NR_BITS-1:0] execute_if_data_rd;
|
||||
logic [DATAW-1:0] execute_if_data_new_rd;
|
||||
|
||||
wire [`NUM_WARPS-1:0] metadata_queue_fulls;
|
||||
wire [`NUM_WARPS-1:0] metadata_queue_emptys;
|
||||
// OR not AND, we don't want any warp full
|
||||
// OR not AND; we don't want any warp to be full
|
||||
wire metadata_queue_full = |(metadata_queue_fulls);
|
||||
assign execute_if.ready = !metadata_queue_full;
|
||||
|
||||
`RUNTIME_ASSERT((!execute_if.valid || execute_if.data.wid == `NW_WIDTH'(0)),
|
||||
("runtime error: WGMMA execute not supported for warps other than 0!"))
|
||||
|
||||
logic metadata_deq;
|
||||
wire metadata_deq;
|
||||
|
||||
for (genvar i = 0; i < `NUM_WARPS; i++) begin
|
||||
// Metadata queue for commit_if. This simply copies execute_if's
|
||||
@@ -55,10 +49,12 @@ module VX_tensor_hopper_core_block import VX_gpu_pkg::*; #(
|
||||
// ensure two consecutive dequeues are associated with the same warp for
|
||||
// commit. (FIXME: this is not strictly necessary though.)
|
||||
|
||||
wire operand_enq_fire = execute_if.valid && execute_if.ready;
|
||||
wire enq = operand_enq_fire && (execute_if.data.wid == `NW_WIDTH'(i));
|
||||
// FIXME: commit only warp 0
|
||||
wire deq = metadata_deq && commit_if.ready && (`NW_WIDTH'(i) == `NW_WIDTH'(0));
|
||||
wire deq = metadata_deq && (`NW_WIDTH'(i) == `NW_WIDTH'(0));
|
||||
|
||||
localparam DATAW = `UUID_WIDTH + `NW_WIDTH + `NUM_THREADS + `XLEN + 1 + `NR_BITS;
|
||||
VX_fifo_queue #(
|
||||
.DATAW(DATAW),
|
||||
.DEPTH(METADATA_QUEUE_DEPTH)
|
||||
@@ -85,28 +81,84 @@ module VX_tensor_hopper_core_block import VX_gpu_pkg::*; #(
|
||||
// the commit stage
|
||||
`RUNTIME_ASSERT(!(!reset && metadata_queue_full), ("tensor core uop queue is full!"))
|
||||
|
||||
wire initiate_ready; // FIXME: unused
|
||||
wire writeback_valid;
|
||||
wire writeback_last;
|
||||
|
||||
wire metadata_valid = ~metadata_queue_emptys[0/*FIXME*/];
|
||||
// dequeue metadata at the last writeback
|
||||
assign metadata_deq = metadata_valid && writeback_valid && writeback_last;
|
||||
|
||||
VX_tensor_hopper_core #(
|
||||
) tensor_hopper_core (
|
||||
.clk(clk),
|
||||
.reset(reset),
|
||||
|
||||
.initiate_valid(metadata_valid),
|
||||
.initiate_wid(`NW_WIDTH'(0)/*FIXME*/),
|
||||
.initiate_ready(initiate_ready),
|
||||
|
||||
.writeback_valid(writeback_valid),
|
||||
`UNUSED_PIN(writeback_wid),
|
||||
.writeback_last(writeback_last),
|
||||
.writeback_ready(commit_if.ready)
|
||||
);
|
||||
|
||||
wire [`NUM_THREADS-1:0][`XLEN-1:0] wb_data = '0;
|
||||
|
||||
assign commit_if.valid = writeback_valid;
|
||||
assign commit_if.data.uuid = execute_if_data_uuid[0];
|
||||
assign commit_if.data.wid = execute_if_data_wid[0];
|
||||
assign commit_if.data.tmask = execute_if_data_tmask[0];
|
||||
assign commit_if.data.PC = execute_if_data_PC[0];
|
||||
assign commit_if.data.wb = writeback_last;
|
||||
// custom rd
|
||||
assign commit_if.data.rd = (`NR_BITS'(`NUM_IREGS) + `NR_BITS'(4'd3/*FIXME*/));
|
||||
assign commit_if.data.data = wb_data;
|
||||
assign commit_if.data.tensor = writeback_last;
|
||||
assign commit_if.data.pid = 1'b0;
|
||||
assign commit_if.data.sop = 1'b1;
|
||||
// eop is deliberately set so that we don't underflow the pending_instr
|
||||
// buffer in VX_schedule. An instruction is considered committed only
|
||||
// when the eop bit is set to one (see VX_commit).
|
||||
assign commit_if.data.eop = writeback_last;
|
||||
endmodule
|
||||
|
||||
|
||||
// TODO: replace this with a Chisel module
|
||||
module VX_tensor_hopper_core #(
|
||||
) (
|
||||
input clk,
|
||||
input reset,
|
||||
|
||||
input initiate_valid,
|
||||
input [`NW_WIDTH-1:0] initiate_wid,
|
||||
output initiate_ready,
|
||||
|
||||
output writeback_valid,
|
||||
output [`NW_WIDTH-1:0] writeback_wid,
|
||||
// indicates if this is the last writeback for the given wid, in which
|
||||
// case the original HGMMA instruction should be signalled retired
|
||||
output writeback_last,
|
||||
input writeback_ready
|
||||
);
|
||||
// dummy FSM that generates commits
|
||||
logic [1:0] state, state_n;
|
||||
localparam STATE_IDLE = 4'd0;
|
||||
logic [1:0] state, state_n;
|
||||
|
||||
assign initiate_ready = (state == STATE_IDLE);
|
||||
|
||||
always @(*) begin
|
||||
state_n = state;
|
||||
metadata_deq = 1'b0;
|
||||
|
||||
// when incremented to 1, count up until wrap-around to 0
|
||||
if (state != STATE_IDLE) begin
|
||||
state_n = state + 1'd1;
|
||||
end else begin
|
||||
// kick-off from idle when execute valid
|
||||
// FIXME: only checks warp 0 for commit!
|
||||
if (~metadata_queue_emptys[0/*FIXME*/]) begin
|
||||
state_n = 4'd1;
|
||||
end
|
||||
end
|
||||
|
||||
// dequeue metadata when wrapping around
|
||||
if ((state != STATE_IDLE) && (state_n == STATE_IDLE)) begin
|
||||
metadata_deq = 1'b1;
|
||||
// kick-off
|
||||
if (initiate_valid && initiate_ready) begin
|
||||
state_n = 4'd1;
|
||||
end
|
||||
end
|
||||
|
||||
@@ -118,23 +170,10 @@ module VX_tensor_hopper_core_block import VX_gpu_pkg::*; #(
|
||||
end
|
||||
end
|
||||
|
||||
// assign commit_if.valid = metadata_deq;
|
||||
assign commit_if.valid = (state != STATE_IDLE);
|
||||
assign writeback_valid = (state != STATE_IDLE);
|
||||
assign writeback_wid = '0; // TODO
|
||||
assign writeback_last = (state == 4'd15);
|
||||
|
||||
wire [`NUM_THREADS-1:0][`XLEN-1:0] wb_data = '0;
|
||||
|
||||
assign commit_if.data.uuid = execute_if_data_uuid[0];
|
||||
assign commit_if.data.wid = execute_if_data_wid[0];
|
||||
assign commit_if.data.tmask = execute_if_data_tmask[0];
|
||||
assign commit_if.data.PC = execute_if_data_PC[0];
|
||||
assign commit_if.data.wb = (state == 2'b11);
|
||||
// custom rd
|
||||
assign commit_if.data.rd = (`NR_BITS'(`NUM_IREGS) + `NR_BITS'(state));
|
||||
assign commit_if.data.data = wb_data;
|
||||
assign commit_if.data.tensor = (state == 2'b11);
|
||||
assign commit_if.data.pid = 1'b0;
|
||||
assign commit_if.data.sop = 1'b1;
|
||||
assign commit_if.data.eop = (state == 2'b11);
|
||||
endmodule
|
||||
|
||||
`endif
|
||||
|
||||
Reference in New Issue
Block a user