tensor: Fix out-of-sync enqueue to dpu and metadata queue

This commit is contained in:
Hansung Kim
2024-05-30 18:03:04 -07:00
parent 97f37b1c75
commit 0a032ab400
2 changed files with 37 additions and 74 deletions

View File

@@ -77,7 +77,8 @@ module VX_tensor_core_warp import VX_gpu_pkg::*; #(
// octet. E.g. two tgs map lane 0-3 and lane 16-19 -> 16
// FIXME: not sure this is the right logic. just filling in what works
localparam LANE_OFFSET_THREADGROUP = (4 * NUM_OCTETS);
localparam METADATA_QUEUE_DEPTH = 4;
// Depth is tied to the HMMA latency. This is only a rule of thumb, not a
// derived bound.
localparam METADATA_QUEUE_DEPTH = `LATENCY_HMMA;
wire [1:0] step = 2'(execute_if.data.op_type);
wire last_in_pair = (execute_if.data.op_mod == `INST_MOD_BITS'(1));
@@ -89,7 +90,11 @@ module VX_tensor_core_warp import VX_gpu_pkg::*; #(
logic [`NUM_THREADS-1:0][`XLEN-1:0] wb_data_0;
logic [`NUM_THREADS-1:0][`XLEN-1:0] wb_data_1;
wire [`NW_WIDTH-1:0] wb_wid;
// valid signal synced between the functional units (octet) and the
// metadata queue
wire operands_valid_synced;
`ifdef EXT_T_ENABLE
for (genvar i = 0; i < NUM_OCTETS; ++i) begin
`else
@@ -121,7 +126,7 @@ module VX_tensor_core_warp import VX_gpu_pkg::*; #(
.A_in(octet_A),
.B_in(octet_B),
.C_in(octet_C),
.operands_valid(execute_if.valid),
.operands_valid(operands_valid_synced),
.operands_wid(execute_if.data.wid),
.operands_last_in_pair(last_in_pair),
.operands_step(step),
@@ -172,8 +177,10 @@ module VX_tensor_core_warp import VX_gpu_pkg::*; #(
localparam DATAW = `UUID_WIDTH + `NW_WIDTH + `NUM_THREADS + `XLEN + 1 + `NR_BITS;
wire execute_if_fire = execute_if.valid && execute_if.ready;
wire commit_if_fire = commit_if.valid && commit_if.ready;
wire commit_if_ready_override;
wire operand_enq_fire = operands_valid_synced && execute_if.ready;
wire commit_if_fire = commit_if.valid && commit_if_ready_override;
wire [DATAW-1:0] execute_if_data_enq = {
execute_if.data.uuid,
execute_if.data.wid,
@@ -184,31 +191,14 @@ module VX_tensor_core_warp import VX_gpu_pkg::*; #(
// pid/sop/eop set later
};
// wire [DATAW-1:0] execute_if_data_deq;
// VX_fifo_queue #(
// .DATAW(DATAW),
// .DEPTH(4 /* FIXME: arbitrary */)
// ) pending_uops (
// .clk(clk),
// .reset(reset),
// .push(execute_if_fire),
// .pop(commit_if_fire),
// .data_in(execute_if_data_enq),
// .data_out(execute_if_data_deq),
// `UNUSED_PIN(empty),
// `UNUSED_PIN(alm_empty),
// `UNUSED_PIN(full), // should be impossible to overflow
// `UNUSED_PIN(alm_full),
// `UNUSED_PIN(size)
// );
wire [`NUM_WARPS-1:0][DATAW-1:0] execute_if_data_deq;
wire [`NUM_WARPS-1:0] metadata_queue_fulls;
// OR-reduce, not AND-reduce: report full as soon as any single warp's
// metadata queue is full.
wire metadata_queue_full = |(metadata_queue_fulls);
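// Keep the metadata queue and the octet issue queues in sync: gate both the
// valid seen by the octets and the ready returned to execute_if with the
// same metadata-queue-full condition.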
assign operands_valid_synced = execute_if.valid && !metadata_queue_full;
assign execute_if.ready = &(octet_operands_ready) && !metadata_queue_full;
for (genvar i = 0; i < `NUM_WARPS; i++) begin
@@ -220,8 +210,8 @@ module VX_tensor_core_warp import VX_gpu_pkg::*; #(
// ensure two consecutive dequeues are associated with the same warp for
// commit. (FIXME: this is not strictly necessary though.)
wire enq = execute_if_fire && (execute_if.data.wid == `NW_WIDTH'(i));
wire deq = commit_if_fire && ( wb_wid == `NW_WIDTH'(i));
wire enq = operand_enq_fire && (execute_if.data.wid == `NW_WIDTH'(i));
wire deq = commit_if_fire && ( wb_wid == `NW_WIDTH'(i));
VX_fifo_queue #(
.DATAW(DATAW),
@@ -253,8 +243,8 @@ module VX_tensor_core_warp import VX_gpu_pkg::*; #(
wire all_valid = (& octet_results_valid);
// define this to inject artificial commit backpressure for debugging
`define INJECT_COMMIT_BACKPRESSURE
`ifndef INJECT_COMMIT_BACKPRESSURE
// `define TENSOR_INJECT_COMMIT_BACKPRESSURE
`ifndef TENSOR_INJECT_COMMIT_BACKPRESSURE
assign commit_if.valid = all_valid;
assign commit_if_ready_override = commit_if.ready;
`else
@@ -358,47 +348,6 @@ module VX_tensor_octet #(
wire operands_last_in_pair_buf;
wire [1:0] operands_step_buf;
// wire inbuf_empty;
// wire inbuf_full;
// wire inbuf_ready_in;
// assign inbuf_ready_in = !inbuf_full;
// assign operands_ready = inbuf_ready_in;
// assign operands_valid_buf = !inbuf_empty;
// // wire inbuf_enq = operands_ready && operands_valid && operands_last_in_pair;
// wire inbuf_enq = operands_ready && operands_valid;
// wire inbuf_deq = operands_valid_buf && operands_ready_buf;
// // the 'issue queue' for the dpu.
// // This exists to decouple the input of the dot-product unit from
// // execute_if.ready. execute_if can arrive intermittently according to
// // the frontend's behavior, and since the dpu can also stall for a fixed
// // initiation latency, we need to decouple the two to efficiently feed the
// // dpu.
// // This only applies to the last instruction in a pair, since the first
// // instruction only acts to buffer the operands and can execute
// // immediately without backpressure. So we don't enqueue them.
// VX_fifo_queue #(
// .DATAW ($bits(A_in) + $bits(B_in) + $bits(C_in) +
// $bits(operands_wid) + $bits(operands_step) + $bits(operands_last_in_pair)),
// .DEPTH (ISSUE_QUEUE_DEPTH)
// ) input_buffer (
// .clk (clk),
// .reset (reset),
// .push (inbuf_enq),
// .pop (inbuf_deq),
// .data_in ({A_in, B_in, C_in, operands_wid, operands_step, operands_last_in_pair}),
// .data_out ({A_in_buf, B_in_buf, C_in_buf, operands_wid_buf, operands_step_buf, operands_last_in_pair_buf}),
// .empty (inbuf_empty),
// `UNUSED_PIN(alm_empty),
// .full (inbuf_full),
// `UNUSED_PIN(alm_full),
// `UNUSED_PIN(size)
// );
// // FIXME: this shouldn't be necessary
// `RUNTIME_ASSERT(reset || !inbuf_full, ("dpu issue queue is full!"))
assign A_in_buf = A_in;
assign B_in_buf = B_in;
assign C_in_buf = C_in;
@@ -521,7 +470,8 @@ module VX_tensor_octet #(
// this does (m,n,k)=(4,4,2) matmul, modeling compute of a single octet
VX_tensor_dpu #(
.ISW(ISW),
.OCTET(OCTET)
.OCTET(OCTET),
.ISSUE_QUEUE_DEPTH(2)
) dpu (
.clk(clk),
.reset(reset),
@@ -556,7 +506,7 @@ module VX_tensor_octet #(
// TODO: This is probably oversized.
VX_fifo_queue #(
.DATAW ($bits(D_wid) + $bits(D_out)),
.DEPTH (`LATENCY_HMMA)
.DEPTH (2 /*`LATENCY_HMMA*/)
) output_buffer (
.clk (clk),
.reset (reset),

View File

@@ -3,7 +3,8 @@
module VX_tensor_dpu #(
parameter ISW,
parameter OCTET
parameter OCTET,
parameter ISSUE_QUEUE_DEPTH = `LATENCY_HMMA
) (
input clk,
input reset,
@@ -62,6 +63,7 @@ module VX_tensor_dpu #(
logic [1:0] threadgroup_readys;
// B_tile is shared across the two threadgroups; see Figure 13
VX_tensor_threadgroup #(
.ISSUE_QUEUE_DEPTH(ISSUE_QUEUE_DEPTH)
) threadgroup_0 (
.clk (clk),
.reset (reset),
@@ -75,6 +77,7 @@ module VX_tensor_dpu #(
.D_frag (D_tile[1:0])
);
VX_tensor_threadgroup #(
.ISSUE_QUEUE_DEPTH(ISSUE_QUEUE_DEPTH)
) threadgroup_1 (
.clk (clk),
.reset (reset),
@@ -99,7 +102,7 @@ module VX_tensor_dpu #(
// need to pass along warp id's to do multithreading
VX_fifo_queue #(
.DATAW ($bits(wid)),
.DEPTH (`LATENCY_HMMA + `LATENCY_HMMA)
.DEPTH (ISSUE_QUEUE_DEPTH)
) wid_queue (
.clk (clk),
.reset (reset),
@@ -121,6 +124,7 @@ endmodule
// does (m,n,k) = (2,4,2) matmul compute over 2 cycles.
// matches Figure 10(b) of the paper.
module VX_tensor_threadgroup #(
parameter ISSUE_QUEUE_DEPTH
) (
input clk,
input reset,
@@ -149,9 +153,18 @@ module VX_tensor_threadgroup #(
assign ready_in = !full;
assign valid_buf = !empty;
// 'Issue queue' for the FEDP units.
// This exists to decouple the execution of the dot-product unit from
// the operand arrival. Operands from execute_if can arrive
// intermittently according to the frontend's behavior, and since the dpu
// can also stall for a fixed initiation latency, we need to decouple the
// two to efficiently feed the dpu.
//
// TODO: better queue design possible; e.g. B_frag is shared by two
// threadgroups, so we need only 1 queue per octet for B
VX_fifo_queue #(
.DATAW ($bits(A_frag) + $bits(B_frag) + $bits(C_frag)),
.DEPTH (4)
.DEPTH (ISSUE_QUEUE_DEPTH)
) input_buffer (
.clk (clk),
.reset (reset),