From f5a9ca5bf31fc4ddc70a81b5c7a5e6d8bc751697 Mon Sep 17 00:00:00 2001 From: Hansung Kim Date: Wed, 29 May 2024 14:47:25 -0700 Subject: [PATCH] tensor: Enqueue both insts in pair to issue queue Otherwise the first-in-pair instructions can run ahead, latching their inputs for the next pair before the second-in-pair instructions finish computing on the current one. This might introduce more frontend stalls; more experimentation is needed. --- hw/rtl/core/VX_tensor_core.sv | 27 +++++++++++++++++---------- hw/rtl/fpu/VX_tensor_dpu.sv | 2 +- 2 files changed, 18 insertions(+), 11 deletions(-) diff --git a/hw/rtl/core/VX_tensor_core.sv b/hw/rtl/core/VX_tensor_core.sv index 5f32f504..2fc54fc5 100644 --- a/hw/rtl/core/VX_tensor_core.sv +++ b/hw/rtl/core/VX_tensor_core.sv @@ -219,7 +219,7 @@ module VX_tensor_core_warp import VX_gpu_pkg::*; #( VX_fifo_queue #( .DATAW(DATAW), - .DEPTH(4 /* FIXME: arbitrary */) + .DEPTH(8 /* FIXME: arbitrary */) ) pending_uops ( .clk(clk), .reset(reset), @@ -335,7 +335,8 @@ module VX_tensor_octet #( assign operands_ready = inbuf_ready_in; assign operands_valid_buf = !inbuf_empty; - wire inbuf_enq = operands_ready && operands_valid && operands_last_in_pair; + // wire inbuf_enq = operands_ready && operands_valid && operands_last_in_pair; + wire inbuf_enq = operands_ready && operands_valid; wire inbuf_deq = operands_valid_buf && operands_ready_buf; // the 'issue queue' for the dpu. 
@@ -350,7 +351,7 @@ module VX_tensor_octet #( VX_fifo_queue #( .DATAW ($bits(A_in) + $bits(B_in) + $bits(C_in) + $bits(operands_wid) + $bits(operands_step) + $bits(operands_last_in_pair)), - .DEPTH (4 /* FIXME: arbitrary */) + .DEPTH (8 /* FIXME: arbitrary */) ) input_buffer ( .clk (clk), .reset (reset), @@ -365,6 +366,9 @@ module VX_tensor_octet #( `UNUSED_PIN(size) ); + // FIXME: this shouldn't be necessary + `RUNTIME_ASSERT(reset || !inbuf_full, ("dpu issue queue is full!")) + typedef struct { logic [3:0][31:0] A_half; logic [3:0][31:0] B_half; @@ -411,8 +415,8 @@ module VX_tensor_octet #( assign halves_buf = get_operand_half(operands_step_buf, A_in_buf, B_in_buf, C_in_buf); wire do_hmma = operands_ready_buf && operands_valid_buf && operands_last_in_pair_buf; - wire operands_first_in_pair_fire = operands_ready && operands_valid && (!operands_last_in_pair); - // wire operands_first_in_pair_fire = operands_ready && operands_valid; + // wire operands_first_in_pair_fire = operands_ready && operands_valid && (!operands_last_in_pair); + wire operands_first_in_pair_fire = operands_ready_buf && operands_valid_buf && (!operands_last_in_pair_buf); always @(*) begin A_buffer_n = A_buffer; @@ -421,10 +425,10 @@ module VX_tensor_octet #( substeps_n = substeps; if (operands_first_in_pair_fire) begin - substeps_n[operands_wid] = 1'b1; // ready for hmma - A_buffer_n[operands_wid] = halves.A_half; - B_buffer_n[operands_wid] = halves.B_half; - C_buffer_n[operands_wid] = halves.C_half; + substeps_n[operands_wid_buf] = 1'b1; // ready for hmma + A_buffer_n[operands_wid_buf] = halves_buf.A_half; + B_buffer_n[operands_wid_buf] = halves_buf.B_half; + C_buffer_n[operands_wid_buf] = halves_buf.C_half; end if (do_hmma) begin substeps_n[operands_wid_buf] = 1'b0; // finished hmma, ready for next operand @@ -521,7 +525,7 @@ module VX_tensor_octet #( // TODO: This is probably oversized. 
VX_fifo_queue #( .DATAW ($bits(D_wid) + $bits(D_out)), - .DEPTH (4 /* FIXME: arbitrary */) + .DEPTH (8 /* FIXME: arbitrary */) ) output_buffer ( .clk (clk), .reset (reset), @@ -536,6 +540,9 @@ module VX_tensor_octet #( `UNUSED_PIN(size) ); + // FIXME: this shouldn't be necessary + `RUNTIME_ASSERT(reset || !outbuf_full, ("dpu result queue is full!")) + `ifdef PERF_ENABLE logic [`PERF_CTR_BITS-1:0] perf_tensor_dpu_total; diff --git a/hw/rtl/fpu/VX_tensor_dpu.sv b/hw/rtl/fpu/VX_tensor_dpu.sv index 33529370..90c2c7ed 100644 --- a/hw/rtl/fpu/VX_tensor_dpu.sv +++ b/hw/rtl/fpu/VX_tensor_dpu.sv @@ -95,7 +95,7 @@ module VX_tensor_dpu #( VX_shift_register #( .DATAW (1 + $bits(wid)/* + $bits(D_tile)*/), // .DEPTH (`LATENCY_HMMA), - .DEPTH (2), + .DEPTH (4), .RESETW (1) ) shift_reg ( .clk (clk),