tensor: Enqueue both insts in pair to issue queue

Otherwise the first-in-pair instructions can run ahead, latching their
inputs for the next pair before the second-in-pair insts finish compute
on the current one.  Might introduce more frontend stalls, need more
experimenting
This commit is contained in:
Hansung Kim
2024-05-29 14:47:25 -07:00
parent e9df173745
commit f5a9ca5bf3
2 changed files with 18 additions and 11 deletions

View File

@@ -219,7 +219,7 @@ module VX_tensor_core_warp import VX_gpu_pkg::*; #(
VX_fifo_queue #(
.DATAW(DATAW),
.DEPTH(4 /* FIXME: arbitrary */)
.DEPTH(8 /* FIXME: arbitrary */)
) pending_uops (
.clk(clk),
.reset(reset),
@@ -335,7 +335,8 @@ module VX_tensor_octet #(
assign operands_ready = inbuf_ready_in;
assign operands_valid_buf = !inbuf_empty;
wire inbuf_enq = operands_ready && operands_valid && operands_last_in_pair;
// wire inbuf_enq = operands_ready && operands_valid && operands_last_in_pair;
wire inbuf_enq = operands_ready && operands_valid;
wire inbuf_deq = operands_valid_buf && operands_ready_buf;
// the 'issue queue' for the dpu.
@@ -350,7 +351,7 @@ module VX_tensor_octet #(
VX_fifo_queue #(
.DATAW ($bits(A_in) + $bits(B_in) + $bits(C_in) +
$bits(operands_wid) + $bits(operands_step) + $bits(operands_last_in_pair)),
.DEPTH (4 /* FIXME: arbitrary */)
.DEPTH (8 /* FIXME: arbitrary */)
) input_buffer (
.clk (clk),
.reset (reset),
@@ -365,6 +366,9 @@ module VX_tensor_octet #(
`UNUSED_PIN(size)
);
// FIXME: this shouldn't be necessary
`RUNTIME_ASSERT(reset || !inbuf_full, ("dpu issue queue is full!"))
typedef struct {
logic [3:0][31:0] A_half;
logic [3:0][31:0] B_half;
@@ -411,8 +415,8 @@ module VX_tensor_octet #(
assign halves_buf = get_operand_half(operands_step_buf, A_in_buf, B_in_buf, C_in_buf);
wire do_hmma = operands_ready_buf && operands_valid_buf && operands_last_in_pair_buf;
wire operands_first_in_pair_fire = operands_ready && operands_valid && (!operands_last_in_pair);
// wire operands_first_in_pair_fire = operands_ready && operands_valid;
// wire operands_first_in_pair_fire = operands_ready && operands_valid && (!operands_last_in_pair);
wire operands_first_in_pair_fire = operands_ready_buf && operands_valid_buf && (!operands_last_in_pair_buf);
always @(*) begin
A_buffer_n = A_buffer;
@@ -421,10 +425,10 @@ module VX_tensor_octet #(
substeps_n = substeps;
if (operands_first_in_pair_fire) begin
substeps_n[operands_wid] = 1'b1; // ready for hmma
A_buffer_n[operands_wid] = halves.A_half;
B_buffer_n[operands_wid] = halves.B_half;
C_buffer_n[operands_wid] = halves.C_half;
substeps_n[operands_wid_buf] = 1'b1; // ready for hmma
A_buffer_n[operands_wid_buf] = halves_buf.A_half;
B_buffer_n[operands_wid_buf] = halves_buf.B_half;
C_buffer_n[operands_wid_buf] = halves_buf.C_half;
end
if (do_hmma) begin
substeps_n[operands_wid_buf] = 1'b0; // finished hmma, ready for next operand
@@ -521,7 +525,7 @@ module VX_tensor_octet #(
// TODO: This is probably oversized.
VX_fifo_queue #(
.DATAW ($bits(D_wid) + $bits(D_out)),
.DEPTH (4 /* FIXME: arbitrary */)
.DEPTH (8 /* FIXME: arbitrary */)
) output_buffer (
.clk (clk),
.reset (reset),
@@ -536,6 +540,9 @@ module VX_tensor_octet #(
`UNUSED_PIN(size)
);
// FIXME: this shouldn't be necessary
`RUNTIME_ASSERT(reset || !outbuf_full, ("dpu result queue is full!"))
`ifdef PERF_ENABLE
logic [`PERF_CTR_BITS-1:0] perf_tensor_dpu_total;

View File

@@ -95,7 +95,7 @@ module VX_tensor_dpu #(
VX_shift_register #(
.DATAW (1 + $bits(wid)/* + $bits(D_tile)*/),
// .DEPTH (`LATENCY_HMMA),
.DEPTH (2),
.DEPTH (4),
.RESETW (1)
) shift_reg (
.clk (clk),