tensor: Split execution module from pipeline logic
This commit is contained in:
@@ -55,7 +55,7 @@ module VX_tensor_core import VX_gpu_pkg::*; #(
|
|||||||
VX_tensor_hopper_core_block #(
|
VX_tensor_hopper_core_block #(
|
||||||
.ISW(1), // FIXME: not block_idx
|
.ISW(1), // FIXME: not block_idx
|
||||||
.FP16(FP16)
|
.FP16(FP16)
|
||||||
) tensor_hopper_core (
|
) tensor_hopper_core_block (
|
||||||
.clk(clk),
|
.clk(clk),
|
||||||
.reset(reset),
|
.reset(reset),
|
||||||
.execute_if(execute_if[block_idx]),
|
.execute_if(execute_if[block_idx]),
|
||||||
|
|||||||
@@ -12,7 +12,7 @@ module VX_tensor_hopper_core_block import VX_gpu_pkg::*; #(
|
|||||||
VX_commit_if.master commit_if
|
VX_commit_if.master commit_if
|
||||||
);
|
);
|
||||||
localparam NUM_LANES = `NUM_THREADS;
|
localparam NUM_LANES = `NUM_THREADS;
|
||||||
localparam METADATA_QUEUE_DEPTH = 16; // FIXME: arbitrary
|
localparam METADATA_QUEUE_DEPTH = 2; // FIXME: arbitrary
|
||||||
|
|
||||||
/* commit_if.data_t parts that we need to keep around:
|
/* commit_if.data_t parts that we need to keep around:
|
||||||
- uuid
|
- uuid
|
||||||
@@ -22,29 +22,23 @@ module VX_tensor_hopper_core_block import VX_gpu_pkg::*; #(
|
|||||||
- wb
|
- wb
|
||||||
- rd
|
- rd
|
||||||
*/
|
*/
|
||||||
localparam DATAW = `UUID_WIDTH + `NW_WIDTH + `NUM_THREADS + `XLEN + 1 + `NR_BITS;
|
|
||||||
|
|
||||||
wire operand_enq_fire = execute_if.valid && execute_if.ready;
|
|
||||||
wire commit_if_fire = commit_if.valid && commit_if.ready;
|
|
||||||
|
|
||||||
wire [`NUM_WARPS-1:0][`UUID_WIDTH-1:0] execute_if_data_uuid;
|
wire [`NUM_WARPS-1:0][`UUID_WIDTH-1:0] execute_if_data_uuid;
|
||||||
wire [`NUM_WARPS-1:0][`NW_WIDTH-1:0] execute_if_data_wid;
|
wire [`NUM_WARPS-1:0][`NW_WIDTH-1:0] execute_if_data_wid;
|
||||||
wire [`NUM_WARPS-1:0][NUM_LANES-1:0] execute_if_data_tmask;
|
wire [`NUM_WARPS-1:0][NUM_LANES-1:0] execute_if_data_tmask;
|
||||||
wire [`NUM_WARPS-1:0][`XLEN-1:0] execute_if_data_PC;
|
wire [`NUM_WARPS-1:0][`XLEN-1:0] execute_if_data_PC;
|
||||||
wire [`NUM_WARPS-1:0] execute_if_data_wb;
|
wire [`NUM_WARPS-1:0] execute_if_data_wb;
|
||||||
wire [`NUM_WARPS-1:0][`NR_BITS-1:0] execute_if_data_rd;
|
wire [`NUM_WARPS-1:0][`NR_BITS-1:0] execute_if_data_rd;
|
||||||
logic [DATAW-1:0] execute_if_data_new_rd;
|
|
||||||
|
|
||||||
wire [`NUM_WARPS-1:0] metadata_queue_fulls;
|
wire [`NUM_WARPS-1:0] metadata_queue_fulls;
|
||||||
wire [`NUM_WARPS-1:0] metadata_queue_emptys;
|
wire [`NUM_WARPS-1:0] metadata_queue_emptys;
|
||||||
// OR not AND, we don't want any warp full
|
// OR not AND; we don't want any warp to be full
|
||||||
wire metadata_queue_full = |(metadata_queue_fulls);
|
wire metadata_queue_full = |(metadata_queue_fulls);
|
||||||
assign execute_if.ready = !metadata_queue_full;
|
assign execute_if.ready = !metadata_queue_full;
|
||||||
|
|
||||||
`RUNTIME_ASSERT((!execute_if.valid || execute_if.data.wid == `NW_WIDTH'(0)),
|
`RUNTIME_ASSERT((!execute_if.valid || execute_if.data.wid == `NW_WIDTH'(0)),
|
||||||
("runtime error: WGMMA execute not supported for warps other than 0!"))
|
("runtime error: WGMMA execute not supported for warps other than 0!"))
|
||||||
|
|
||||||
logic metadata_deq;
|
wire metadata_deq;
|
||||||
|
|
||||||
for (genvar i = 0; i < `NUM_WARPS; i++) begin
|
for (genvar i = 0; i < `NUM_WARPS; i++) begin
|
||||||
// Metadata queue for commit_if. This simply copies execute_if's
|
// Metadata queue for commit_if. This simply copies execute_if's
|
||||||
@@ -55,10 +49,12 @@ module VX_tensor_hopper_core_block import VX_gpu_pkg::*; #(
|
|||||||
// ensure two consecutive dequeues are associated with the same warp for
|
// ensure two consecutive dequeues are associated with the same warp for
|
||||||
// commit. (FIXME: this is not strictly necessary though.)
|
// commit. (FIXME: this is not strictly necessary though.)
|
||||||
|
|
||||||
|
wire operand_enq_fire = execute_if.valid && execute_if.ready;
|
||||||
wire enq = operand_enq_fire && (execute_if.data.wid == `NW_WIDTH'(i));
|
wire enq = operand_enq_fire && (execute_if.data.wid == `NW_WIDTH'(i));
|
||||||
// FIXME: commit only warp 0
|
// FIXME: commit only warp 0
|
||||||
wire deq = metadata_deq && commit_if.ready && (`NW_WIDTH'(i) == `NW_WIDTH'(0));
|
wire deq = metadata_deq && (`NW_WIDTH'(i) == `NW_WIDTH'(0));
|
||||||
|
|
||||||
|
localparam DATAW = `UUID_WIDTH + `NW_WIDTH + `NUM_THREADS + `XLEN + 1 + `NR_BITS;
|
||||||
VX_fifo_queue #(
|
VX_fifo_queue #(
|
||||||
.DATAW(DATAW),
|
.DATAW(DATAW),
|
||||||
.DEPTH(METADATA_QUEUE_DEPTH)
|
.DEPTH(METADATA_QUEUE_DEPTH)
|
||||||
@@ -85,28 +81,84 @@ module VX_tensor_hopper_core_block import VX_gpu_pkg::*; #(
|
|||||||
// the commit stage
|
// the commit stage
|
||||||
`RUNTIME_ASSERT(!(!reset && metadata_queue_full), ("tensor core uop queue is full!"))
|
`RUNTIME_ASSERT(!(!reset && metadata_queue_full), ("tensor core uop queue is full!"))
|
||||||
|
|
||||||
|
wire initiate_ready; // FIXME: unused
|
||||||
|
wire writeback_valid;
|
||||||
|
wire writeback_last;
|
||||||
|
|
||||||
|
wire metadata_valid = ~metadata_queue_emptys[0/*FIXME*/];
|
||||||
|
// dequeue metadata at the last writeback
|
||||||
|
assign metadata_deq = metadata_valid && writeback_valid && writeback_last;
|
||||||
|
|
||||||
|
VX_tensor_hopper_core #(
|
||||||
|
) tensor_hopper_core (
|
||||||
|
.clk(clk),
|
||||||
|
.reset(reset),
|
||||||
|
|
||||||
|
.initiate_valid(metadata_valid),
|
||||||
|
.initiate_wid(`NW_WIDTH'(0)/*FIXME*/),
|
||||||
|
.initiate_ready(initiate_ready),
|
||||||
|
|
||||||
|
.writeback_valid(writeback_valid),
|
||||||
|
`UNUSED_PIN(writeback_wid),
|
||||||
|
.writeback_last(writeback_last),
|
||||||
|
.writeback_ready(commit_if.ready)
|
||||||
|
);
|
||||||
|
|
||||||
|
wire [`NUM_THREADS-1:0][`XLEN-1:0] wb_data = '0;
|
||||||
|
|
||||||
|
assign commit_if.valid = writeback_valid;
|
||||||
|
assign commit_if.data.uuid = execute_if_data_uuid[0];
|
||||||
|
assign commit_if.data.wid = execute_if_data_wid[0];
|
||||||
|
assign commit_if.data.tmask = execute_if_data_tmask[0];
|
||||||
|
assign commit_if.data.PC = execute_if_data_PC[0];
|
||||||
|
assign commit_if.data.wb = writeback_last;
|
||||||
|
// custom rd
|
||||||
|
assign commit_if.data.rd = (`NR_BITS'(`NUM_IREGS) + `NR_BITS'(4'd3/*FIXME*/));
|
||||||
|
assign commit_if.data.data = wb_data;
|
||||||
|
assign commit_if.data.tensor = writeback_last;
|
||||||
|
assign commit_if.data.pid = 1'b0;
|
||||||
|
assign commit_if.data.sop = 1'b1;
|
||||||
|
// eop is deliberately set so that we don't underflow the pending_instr
|
||||||
|
// buffer in VX_schedule. An instruction is considered committed only
|
||||||
|
// when the eop bit is set to one (see VX_commit).
|
||||||
|
assign commit_if.data.eop = writeback_last;
|
||||||
|
endmodule
|
||||||
|
|
||||||
|
|
||||||
|
// TODO: replace this with a Chisel module
|
||||||
|
module VX_tensor_hopper_core #(
|
||||||
|
) (
|
||||||
|
input clk,
|
||||||
|
input reset,
|
||||||
|
|
||||||
|
input initiate_valid,
|
||||||
|
input [`NW_WIDTH-1:0] initiate_wid,
|
||||||
|
output initiate_ready,
|
||||||
|
|
||||||
|
output writeback_valid,
|
||||||
|
output [`NW_WIDTH-1:0] writeback_wid,
|
||||||
|
// indicates if this is the last writeback for the given wid, in which
|
||||||
|
// case the original HGMMA instruction should be signalled retired
|
||||||
|
output writeback_last,
|
||||||
|
input writeback_ready
|
||||||
|
);
|
||||||
// dummy FSM that generates commits
|
// dummy FSM that generates commits
|
||||||
logic [1:0] state, state_n;
|
|
||||||
localparam STATE_IDLE = 4'd0;
|
localparam STATE_IDLE = 4'd0;
|
||||||
|
logic [1:0] state, state_n;
|
||||||
|
|
||||||
|
assign initiate_ready = (state == STATE_IDLE);
|
||||||
|
|
||||||
always @(*) begin
|
always @(*) begin
|
||||||
state_n = state;
|
state_n = state;
|
||||||
metadata_deq = 1'b0;
|
|
||||||
|
|
||||||
// when incremented to 1, count up until wrap-around to 0
|
// when incremented to 1, count up until wrap-around to 0
|
||||||
if (state != STATE_IDLE) begin
|
if (state != STATE_IDLE) begin
|
||||||
state_n = state + 1'd1;
|
state_n = state + 1'd1;
|
||||||
end else begin
|
|
||||||
// kick-off from idle when execute valid
|
|
||||||
// FIXME: only checks warp 0 for commit!
|
|
||||||
if (~metadata_queue_emptys[0/*FIXME*/]) begin
|
|
||||||
state_n = 4'd1;
|
|
||||||
end
|
|
||||||
end
|
end
|
||||||
|
|
||||||
// dequeue metadata when wrapping around
|
// kick-off
|
||||||
if ((state != STATE_IDLE) && (state_n == STATE_IDLE)) begin
|
if (initiate_valid && initiate_ready) begin
|
||||||
metadata_deq = 1'b1;
|
state_n = 4'd1;
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
|
||||||
@@ -118,23 +170,10 @@ module VX_tensor_hopper_core_block import VX_gpu_pkg::*; #(
|
|||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
|
||||||
// assign commit_if.valid = metadata_deq;
|
assign writeback_valid = (state != STATE_IDLE);
|
||||||
assign commit_if.valid = (state != STATE_IDLE);
|
assign writeback_wid = '0; // TODO
|
||||||
|
assign writeback_last = (state == 4'd15);
|
||||||
|
|
||||||
wire [`NUM_THREADS-1:0][`XLEN-1:0] wb_data = '0;
|
|
||||||
|
|
||||||
assign commit_if.data.uuid = execute_if_data_uuid[0];
|
|
||||||
assign commit_if.data.wid = execute_if_data_wid[0];
|
|
||||||
assign commit_if.data.tmask = execute_if_data_tmask[0];
|
|
||||||
assign commit_if.data.PC = execute_if_data_PC[0];
|
|
||||||
assign commit_if.data.wb = (state == 2'b11);
|
|
||||||
// custom rd
|
|
||||||
assign commit_if.data.rd = (`NR_BITS'(`NUM_IREGS) + `NR_BITS'(state));
|
|
||||||
assign commit_if.data.data = wb_data;
|
|
||||||
assign commit_if.data.tensor = (state == 2'b11);
|
|
||||||
assign commit_if.data.pid = 1'b0;
|
|
||||||
assign commit_if.data.sop = 1'b1;
|
|
||||||
assign commit_if.data.eop = (state == 2'b11);
|
|
||||||
endmodule
|
endmodule
|
||||||
|
|
||||||
`endif
|
`endif
|
||||||
|
|||||||
Reference in New Issue
Block a user