tensor: Split execution module from pipeline logic

This commit is contained in:
Hansung Kim
2024-10-11 20:09:09 -07:00
parent f7f23e0c05
commit 2934b1bd94
2 changed files with 77 additions and 38 deletions

View File

@@ -55,7 +55,7 @@ module VX_tensor_core import VX_gpu_pkg::*; #(
VX_tensor_hopper_core_block #(
.ISW(1), // FIXME: not block_idx
.FP16(FP16)
) tensor_hopper_core (
) tensor_hopper_core_block (
.clk(clk),
.reset(reset),
.execute_if(execute_if[block_idx]),

View File

@@ -12,7 +12,7 @@ module VX_tensor_hopper_core_block import VX_gpu_pkg::*; #(
VX_commit_if.master commit_if
);
localparam NUM_LANES = `NUM_THREADS;
localparam METADATA_QUEUE_DEPTH = 16; // FIXME: arbitrary
localparam METADATA_QUEUE_DEPTH = 2; // FIXME: arbitrary
/* commit_if.data_t parts that we need to keep around:
- uuid
@@ -22,29 +22,23 @@ module VX_tensor_hopper_core_block import VX_gpu_pkg::*; #(
- wb
- rd
*/
localparam DATAW = `UUID_WIDTH + `NW_WIDTH + `NUM_THREADS + `XLEN + 1 + `NR_BITS;
wire operand_enq_fire = execute_if.valid && execute_if.ready;
wire commit_if_fire = commit_if.valid && commit_if.ready;
wire [`NUM_WARPS-1:0][`UUID_WIDTH-1:0] execute_if_data_uuid;
wire [`NUM_WARPS-1:0][`NW_WIDTH-1:0] execute_if_data_wid;
wire [`NUM_WARPS-1:0][NUM_LANES-1:0] execute_if_data_tmask;
wire [`NUM_WARPS-1:0][`XLEN-1:0] execute_if_data_PC;
wire [`NUM_WARPS-1:0] execute_if_data_wb;
wire [`NUM_WARPS-1:0][`NR_BITS-1:0] execute_if_data_rd;
logic [DATAW-1:0] execute_if_data_new_rd;
wire [`NUM_WARPS-1:0] metadata_queue_fulls;
wire [`NUM_WARPS-1:0] metadata_queue_emptys;
// OR not AND, we don't want any warp full
// OR not AND; we don't want any warp to be full
wire metadata_queue_full = |(metadata_queue_fulls);
assign execute_if.ready = !metadata_queue_full;
`RUNTIME_ASSERT((!execute_if.valid || execute_if.data.wid == `NW_WIDTH'(0)),
("runtime error: WGMMA execute not supported for warps other than 0!"))
logic metadata_deq;
wire metadata_deq;
for (genvar i = 0; i < `NUM_WARPS; i++) begin
// Metadata queue for commit_if. This simply copies execute_if's
@@ -55,10 +49,12 @@ module VX_tensor_hopper_core_block import VX_gpu_pkg::*; #(
// ensure two consecutive dequeues are associated with the same warp for
// commit. (FIXME: this is not strictly necessary though.)
wire operand_enq_fire = execute_if.valid && execute_if.ready;
wire enq = operand_enq_fire && (execute_if.data.wid == `NW_WIDTH'(i));
// FIXME: commit only warp 0
wire deq = metadata_deq && commit_if.ready && (`NW_WIDTH'(i) == `NW_WIDTH'(0));
wire deq = metadata_deq && (`NW_WIDTH'(i) == `NW_WIDTH'(0));
localparam DATAW = `UUID_WIDTH + `NW_WIDTH + `NUM_THREADS + `XLEN + 1 + `NR_BITS;
VX_fifo_queue #(
.DATAW(DATAW),
.DEPTH(METADATA_QUEUE_DEPTH)
@@ -85,28 +81,84 @@ module VX_tensor_hopper_core_block import VX_gpu_pkg::*; #(
// the commit stage
`RUNTIME_ASSERT(!(!reset && metadata_queue_full), ("tensor core uop queue is full!"))
wire initiate_ready; // FIXME: unused
wire writeback_valid;
wire writeback_last;
wire metadata_valid = ~metadata_queue_emptys[0/*FIXME*/];
// dequeue metadata at the last writeback
assign metadata_deq = metadata_valid && writeback_valid && writeback_last;
VX_tensor_hopper_core #(
) tensor_hopper_core (
.clk(clk),
.reset(reset),
.initiate_valid(metadata_valid),
.initiate_wid(`NW_WIDTH'(0)/*FIXME*/),
.initiate_ready(initiate_ready),
.writeback_valid(writeback_valid),
`UNUSED_PIN(writeback_wid),
.writeback_last(writeback_last),
.writeback_ready(commit_if.ready)
);
wire [`NUM_THREADS-1:0][`XLEN-1:0] wb_data = '0;
assign commit_if.valid = writeback_valid;
assign commit_if.data.uuid = execute_if_data_uuid[0];
assign commit_if.data.wid = execute_if_data_wid[0];
assign commit_if.data.tmask = execute_if_data_tmask[0];
assign commit_if.data.PC = execute_if_data_PC[0];
assign commit_if.data.wb = writeback_last;
// custom rd
assign commit_if.data.rd = (`NR_BITS'(`NUM_IREGS) + `NR_BITS'(4'd3/*FIXME*/));
assign commit_if.data.data = wb_data;
assign commit_if.data.tensor = writeback_last;
assign commit_if.data.pid = 1'b0;
assign commit_if.data.sop = 1'b1;
// eop is deliberately set so that we don't underflow the pending_instr
// buffer in VX_schedule. An instruction is considered committed only
// when the eop bit is set to one (see VX_commit).
assign commit_if.data.eop = writeback_last;
endmodule
// TODO: replace this with a Chisel module
module VX_tensor_hopper_core #(
) (
input clk,
input reset,
input initiate_valid,
input [`NW_WIDTH-1:0] initiate_wid,
output initiate_ready,
output writeback_valid,
output [`NW_WIDTH-1:0] writeback_wid,
// indicates if this is the last writeback for the given wid, in which
// case the original HGMMA instruction should be signalled retired
output writeback_last,
input writeback_ready
);
// dummy FSM that generates commits
logic [1:0] state, state_n;
localparam STATE_IDLE = 4'd0;
logic [1:0] state, state_n;
assign initiate_ready = (state == STATE_IDLE);
always @(*) begin
state_n = state;
metadata_deq = 1'b0;
// when incremented to 1, count up until wrap-around to 0
if (state != STATE_IDLE) begin
state_n = state + 1'd1;
end else begin
// kick-off from idle when execute valid
// FIXME: only checks warp 0 for commit!
if (~metadata_queue_emptys[0/*FIXME*/]) begin
state_n = 4'd1;
end
end
// dequeue metadata when wrapping around
if ((state != STATE_IDLE) && (state_n == STATE_IDLE)) begin
metadata_deq = 1'b1;
// kick-off
if (initiate_valid && initiate_ready) begin
state_n = 4'd1;
end
end
@@ -118,23 +170,10 @@ module VX_tensor_hopper_core_block import VX_gpu_pkg::*; #(
end
end
// assign commit_if.valid = metadata_deq;
assign commit_if.valid = (state != STATE_IDLE);
assign writeback_valid = (state != STATE_IDLE);
assign writeback_wid = '0; // TODO
assign writeback_last = (state == 4'd15);
wire [`NUM_THREADS-1:0][`XLEN-1:0] wb_data = '0;
assign commit_if.data.uuid = execute_if_data_uuid[0];
assign commit_if.data.wid = execute_if_data_wid[0];
assign commit_if.data.tmask = execute_if_data_tmask[0];
assign commit_if.data.PC = execute_if_data_PC[0];
assign commit_if.data.wb = (state == 2'b11);
// custom rd
assign commit_if.data.rd = (`NR_BITS'(`NUM_IREGS) + `NR_BITS'(state));
assign commit_if.data.data = wb_data;
assign commit_if.data.tensor = (state == 2'b11);
assign commit_if.data.pid = 1'b0;
assign commit_if.data.sop = 1'b1;
assign commit_if.data.eop = (state == 2'b11);
endmodule
`endif