diff --git a/hw/rtl/core/VX_tensor_core.sv b/hw/rtl/core/VX_tensor_core.sv index 1f7a95db..cad70b97 100644 --- a/hw/rtl/core/VX_tensor_core.sv +++ b/hw/rtl/core/VX_tensor_core.sv @@ -55,7 +55,7 @@ module VX_tensor_core import VX_gpu_pkg::*; #( VX_tensor_hopper_core_block #( .ISW(1), // FIXME: not block_idx .FP16(FP16) - ) tensor_hopper_core ( + ) tensor_hopper_core_block ( .clk(clk), .reset(reset), .execute_if(execute_if[block_idx]), diff --git a/hw/rtl/core/VX_tensor_hopper_core.sv b/hw/rtl/core/VX_tensor_hopper_core.sv index dc763d48..8abe463e 100644 --- a/hw/rtl/core/VX_tensor_hopper_core.sv +++ b/hw/rtl/core/VX_tensor_hopper_core.sv @@ -12,7 +12,7 @@ module VX_tensor_hopper_core_block import VX_gpu_pkg::*; #( VX_commit_if.master commit_if ); localparam NUM_LANES = `NUM_THREADS; - localparam METADATA_QUEUE_DEPTH = 16; // FIXME: arbitrary + localparam METADATA_QUEUE_DEPTH = 2; // FIXME: arbitrary /* commit_if.data_t parts that we need to keep around: - uuid @@ -22,29 +22,23 @@ module VX_tensor_hopper_core_block import VX_gpu_pkg::*; #( - wb - rd */ - localparam DATAW = `UUID_WIDTH + `NW_WIDTH + `NUM_THREADS + `XLEN + 1 + `NR_BITS; - - wire operand_enq_fire = execute_if.valid && execute_if.ready; - wire commit_if_fire = commit_if.valid && commit_if.ready; - wire [`NUM_WARPS-1:0][`UUID_WIDTH-1:0] execute_if_data_uuid; wire [`NUM_WARPS-1:0][`NW_WIDTH-1:0] execute_if_data_wid; wire [`NUM_WARPS-1:0][NUM_LANES-1:0] execute_if_data_tmask; wire [`NUM_WARPS-1:0][`XLEN-1:0] execute_if_data_PC; wire [`NUM_WARPS-1:0] execute_if_data_wb; wire [`NUM_WARPS-1:0][`NR_BITS-1:0] execute_if_data_rd; - logic [DATAW-1:0] execute_if_data_new_rd; wire [`NUM_WARPS-1:0] metadata_queue_fulls; wire [`NUM_WARPS-1:0] metadata_queue_emptys; - // OR not AND, we don't want any warp full + // OR not AND; we don't want any warp to be full wire metadata_queue_full = |(metadata_queue_fulls); assign execute_if.ready = !metadata_queue_full; `RUNTIME_ASSERT((!execute_if.valid || execute_if.data.wid == 
`NW_WIDTH'(0)), ("runtime error: WGMMA execute not supported for warps other than 0!")) - logic metadata_deq; + wire metadata_deq; for (genvar i = 0; i < `NUM_WARPS; i++) begin // Metadata queue for commit_if. This simply copies execute_if's @@ -55,10 +49,12 @@ module VX_tensor_hopper_core_block import VX_gpu_pkg::*; #( // ensure two consecutive dequeues are associated with the same warp for // commit. (FIXME: this is not strictly necessary though.) + wire operand_enq_fire = execute_if.valid && execute_if.ready; wire enq = operand_enq_fire && (execute_if.data.wid == `NW_WIDTH'(i)); // FIXME: commit only warp 0 - wire deq = metadata_deq && commit_if.ready && (`NW_WIDTH'(i) == `NW_WIDTH'(0)); + wire deq = metadata_deq && (`NW_WIDTH'(i) == `NW_WIDTH'(0)); + localparam DATAW = `UUID_WIDTH + `NW_WIDTH + `NUM_THREADS + `XLEN + 1 + `NR_BITS; VX_fifo_queue #( .DATAW(DATAW), .DEPTH(METADATA_QUEUE_DEPTH) @@ -85,28 +81,84 @@ module VX_tensor_hopper_core_block import VX_gpu_pkg::*; #( // the commit stage `RUNTIME_ASSERT(!(!reset && metadata_queue_full), ("tensor core uop queue is full!")) + wire initiate_ready; // FIXME: unused + wire writeback_valid; + wire writeback_last; + + wire metadata_valid = ~metadata_queue_emptys[0/*FIXME*/]; + // dequeue metadata at the last writeback + assign metadata_deq = metadata_valid && writeback_valid && writeback_last; + + VX_tensor_hopper_core #( + ) tensor_hopper_core ( + .clk(clk), + .reset(reset), + + .initiate_valid(metadata_valid), + .initiate_wid(`NW_WIDTH'(0)/*FIXME*/), + .initiate_ready(initiate_ready), + + .writeback_valid(writeback_valid), + `UNUSED_PIN(writeback_wid), + .writeback_last(writeback_last), + .writeback_ready(commit_if.ready) + ); + + wire [`NUM_THREADS-1:0][`XLEN-1:0] wb_data = '0; + + assign commit_if.valid = writeback_valid; + assign commit_if.data.uuid = execute_if_data_uuid[0]; + assign commit_if.data.wid = execute_if_data_wid[0]; + assign commit_if.data.tmask = execute_if_data_tmask[0]; + assign commit_if.data.PC 
= execute_if_data_PC[0]; + assign commit_if.data.wb = writeback_last; + // custom rd + assign commit_if.data.rd = (`NR_BITS'(`NUM_IREGS) + `NR_BITS'(4'd3/*FIXME*/)); + assign commit_if.data.data = wb_data; + assign commit_if.data.tensor = writeback_last; + assign commit_if.data.pid = 1'b0; + assign commit_if.data.sop = 1'b1; + // eop is deliberately set so that we don't underflow the pending_instr + // buffer in VX_schedule. An instruction is considered committed only + // when the eop bit is set to one (see VX_commit). + assign commit_if.data.eop = writeback_last; +endmodule + + +// TODO: replace this with a Chisel module +module VX_tensor_hopper_core #( +) ( + input clk, + input reset, + + input initiate_valid, + input [`NW_WIDTH-1:0] initiate_wid, + output initiate_ready, + + output writeback_valid, + output [`NW_WIDTH-1:0] writeback_wid, + // indicates if this is the last writeback for the given wid, in which + // case the original HGMMA instruction should be signalled retired + output writeback_last, + input writeback_ready +); // dummy FSM that generates commits - logic [1:0] state, state_n; localparam STATE_IDLE = 4'd0; + logic [1:0] state, state_n; + + assign initiate_ready = (state == STATE_IDLE); always @(*) begin state_n = state; - metadata_deq = 1'b0; // when incremented to 1, count up until wrap-around to 0 if (state != STATE_IDLE) begin state_n = state + 1'd1; - end else begin - // kick-off from idle when execute valid - // FIXME: only checks warp 0 for commit! 
- if (~metadata_queue_emptys[0/*FIXME*/]) begin - state_n = 4'd1; - end end - // dequeue metadata when wrapping around - if ((state != STATE_IDLE) && (state_n == STATE_IDLE)) begin - metadata_deq = 1'b1; + // kick-off + if (initiate_valid && initiate_ready) begin + state_n = 4'd1; end end @@ -118,23 +170,10 @@ module VX_tensor_hopper_core_block import VX_gpu_pkg::*; #( end end - // assign commit_if.valid = metadata_deq; - assign commit_if.valid = (state != STATE_IDLE); + assign writeback_valid = (state != STATE_IDLE); + assign writeback_wid = '0; // TODO + assign writeback_last = (&state); /* all-ones is the last count before wrap-around to STATE_IDLE; 4'd15 is unreachable with the 2-bit state */ - wire [`NUM_THREADS-1:0][`XLEN-1:0] wb_data = '0; - - assign commit_if.data.uuid = execute_if_data_uuid[0]; - assign commit_if.data.wid = execute_if_data_wid[0]; - assign commit_if.data.tmask = execute_if_data_tmask[0]; - assign commit_if.data.PC = execute_if_data_PC[0]; - assign commit_if.data.wb = (state == 2'b11); - // custom rd - assign commit_if.data.rd = (`NR_BITS'(`NUM_IREGS) + `NR_BITS'(state)); - assign commit_if.data.data = wb_data; - assign commit_if.data.tensor = (state == 2'b11); - assign commit_if.data.pid = 1'b0; - assign commit_if.data.sop = 1'b1; - assign commit_if.data.eop = (state == 2'b11); endmodule `endif