diff --git a/hw/rtl/VX_config.vh b/hw/rtl/VX_config.vh index 65d56e8a..5ef71794 100644 --- a/hw/rtl/VX_config.vh +++ b/hw/rtl/VX_config.vh @@ -391,7 +391,7 @@ // Tensor Core Latency `ifndef LATENCY_HMMA -`define LATENCY_HMMA 8 +`define LATENCY_HMMA 2 `endif // Icache Configurable Knobs ////////////////////////////////////////////////// diff --git a/hw/rtl/core/VX_tensor_core.sv b/hw/rtl/core/VX_tensor_core.sv index d1ee3b38..0612ca12 100644 --- a/hw/rtl/core/VX_tensor_core.sv +++ b/hw/rtl/core/VX_tensor_core.sv @@ -189,8 +189,8 @@ module VX_tensor_core_warp import VX_gpu_pkg::*; #( wire [`NUM_WARPS-1:0][DATAW-1:0] execute_if_data_deq; for (genvar i = 0; i < `NUM_WARPS; i++) begin - wire enq = execute_if_fire && (execute_if.data.wid == i); - wire deq = commit_if_fire && (wb_wid == i); + wire enq = execute_if_fire && (execute_if.data.wid == `NW_WIDTH'(i)); + wire deq = commit_if_fire && ( wb_wid == `NW_WIDTH'(i)); logic full; // execute_if request queue. @@ -395,7 +395,7 @@ module VX_tensor_octet #( // C is 4x4 fp32 matrix logic [3:0][3:0][31:0] C_tile; logic [3:0][3:0][31:0] D_tile; - logic [`NW_WIDTH-1:0] D_warp_id; + logic [`NW_WIDTH-1:0] D_wid_dpu; always @(*) begin C_tile[3] = { C_half[7], C_buffer[operands_wid][7], C_half[5], C_buffer[operands_wid][5] }; @@ -423,27 +423,41 @@ module VX_tensor_octet #( .A_tile(A_tile), .B_tile(B_tile), .C_tile(C_tile), - .warp_id(operands_wid), + .wid(operands_wid), .valid_out(dpu_valid), .D_tile(D_tile), - .D_warp_id(D_warp_id) + .D_wid(D_wid_dpu) ); + wire outbuf_empty; + wire outbuf_full; + assign outbuf_ready_in = ~outbuf_full; + assign result_valid = ~outbuf_empty; + + wire outbuf_enq = outbuf_ready_in && dpu_valid; + wire outbuf_deq = result_valid && result_ready; + // buffer to stage the result tile for 2 cycles until commit/writeback is - // complete - VX_stream_buffer #( + // complete. This decouples the irregular dpu output traffic from the + // regular, every-2-cycle commit traffic and thereby ensures the commit + // pipeline is used more efficiently. + // TODO: This is probably oversized. + VX_fifo_queue #( .DATAW ($bits(D_wid) + $bits(D_out)), - .OUT_REG (1) // not sure this is necessary + .DEPTH (8 /* FIXME: arbitrary */) ) output_buffer ( - .clk (clk), + .clk (clk), .reset (reset), - .valid_in (dpu_valid), - .ready_in (outbuf_ready_in), - .data_in ({D_warp_id, D_tile}), + .push (outbuf_enq), + .pop (outbuf_deq), + .data_in ({D_wid_dpu, D_tile}), .data_out ({D_wid, D_out}), - .ready_out (result_ready), - .valid_out (result_valid) + .empty (outbuf_empty), + `UNUSED_PIN(alm_empty), + .full (outbuf_full), // should be impossible to overflow + `UNUSED_PIN(alm_full), + `UNUSED_PIN(size) ); `ifdef PERF_ENABLE diff --git a/hw/rtl/fpu/VX_tensor_dpu.sv b/hw/rtl/fpu/VX_tensor_dpu.sv index 1ffbb6d3..7a3ee41d 100644 --- a/hw/rtl/fpu/VX_tensor_dpu.sv +++ b/hw/rtl/fpu/VX_tensor_dpu.sv @@ -15,11 +15,11 @@ module VX_tensor_dpu #( input [3:0][1:0][31:0] A_tile, input [1:0][3:0][31:0] B_tile, input [3:0][3:0][31:0] C_tile, - input [`NW_WIDTH-1:0] warp_id, + input [`NW_WIDTH-1:0] wid, output valid_out, output [3:0][3:0][31:0] D_tile, - output [`NW_WIDTH-1:0] D_warp_id + output [`NW_WIDTH-1:0] D_wid ); logic [3:0][3:0][31:0] result_hmma; @@ -44,15 +44,15 @@ module VX_tensor_dpu #( // fixed-latency model VX_shift_register #( - .DATAW (1 + $bits(warp_id) + $bits(D_tile)), + .DATAW (1 + $bits(wid) + $bits(D_tile)), .DEPTH (`LATENCY_HMMA), .RESETW (1) ) shift_reg ( .clk (clk), .reset (reset), .enable (~stall), - .data_in ({valid_in, warp_id, result_hmma}), - .data_out ({valid_out, D_warp_id, D_tile}) + .data_in ({valid_in, wid, result_hmma}), + .data_out ({valid_out, D_wid, D_tile}) ); endmodule `endif