diff --git a/hw/rtl/core/VX_tensor_core.sv b/hw/rtl/core/VX_tensor_core.sv index 0612ca12..5f32f504 100644 --- a/hw/rtl/core/VX_tensor_core.sv +++ b/hw/rtl/core/VX_tensor_core.sv @@ -125,10 +125,9 @@ module VX_tensor_core_warp import VX_gpu_pkg::*; #( .operands_valid(execute_if.valid), .operands_wid(execute_if.data.wid), .operands_last_in_pair(last_in_pair), + .operands_step(step), .operands_ready(octet_operands_ready[i]), - .step(step), - .D_out(octet_D), .D_wid(wb_wid), .result_valid(result_valid), @@ -186,18 +185,38 @@ module VX_tensor_core_warp import VX_gpu_pkg::*; #( // pid/sop/eop set later }; + // wire [DATAW-1:0] execute_if_data_deq; + + // VX_fifo_queue #( + // .DATAW(DATAW), + // .DEPTH(4 /* FIXME: arbitrary */) + // ) pending_uops ( + // .clk(clk), + // .reset(reset), + // .push(execute_if_fire), + // .pop(commit_if_fire), + // .data_in(execute_if_data_enq), + // .data_out(execute_if_data_deq), + // `UNUSED_PIN(empty), + // `UNUSED_PIN(alm_empty), + // `UNUSED_PIN(full), // should be impossible to overflow + // `UNUSED_PIN(alm_full), + // `UNUSED_PIN(size) + // ); + wire [`NUM_WARPS-1:0][DATAW-1:0] execute_if_data_deq; for (genvar i = 0; i < `NUM_WARPS; i++) begin - wire enq = execute_if_fire && (execute_if.data.wid == `NW_WIDTH'(i)); - wire deq = commit_if_fire && ( wb_wid == `NW_WIDTH'(i)); - logic full; - // execute_if request queue. // This has to be separated per-warp, as otherwise requests from // multiple warps can be enqueued interleaved, which makes it hard to - // ensure two consecutive dequeues are associated to the same warp for + // ensure two consecutive dequeues are associated with the same warp for // commit. + + wire enq = execute_if_fire && (execute_if.data.wid == `NW_WIDTH'(i)); + wire deq = commit_if_fire && ( wb_wid == `NW_WIDTH'(i)); + wire full; + VX_fifo_queue #( .DATAW(DATAW), .DEPTH(4 /* FIXME: arbitrary */) @@ -215,7 +234,7 @@ module VX_tensor_core_warp import VX_gpu_pkg::*; #( `UNUSED_PIN(size) ); - `RUNTIME_ASSERT(!full, ("tensor core uop queue is full!")); + `RUNTIME_ASSERT(!(!reset && full), ("tensor core uop queue is full!")); end // unlike execute which can be interleaved between warps, commit is @@ -229,6 +248,7 @@ module VX_tensor_core_warp import VX_gpu_pkg::*; #( localparam COMMIT_DATAW = `UUID_WIDTH + `NW_WIDTH + `NUM_THREADS + `XLEN + 1 + `NR_BITS + (`NUM_THREADS * `XLEN) + 1 + 1 + 1; wire [COMMIT_DATAW-1:0] commit_if_data = { execute_if_data_deq[wb_wid], /* uuid ~ rd */ + // execute_if_data_deq, /* uuid ~ rd */ subcommit == 1'b0 ? wb_data_0 : wb_data_1, /* data */ 1'b0, /* pid */ 1'b1, /* sop */ @@ -271,11 +291,10 @@ module VX_tensor_octet #( input operands_valid, input [`NW_WIDTH-1:0] operands_wid, input operands_last_in_pair, + input [1:0] operands_step, // we have to backpressure due to there potentially being contention over commit output operands_ready, - input [1:0] step, - output [3:0][3:0][31:0] D_out, output [`NW_WIDTH-1:0] D_wid, output result_valid, @@ -292,11 +311,73 @@ module VX_tensor_octet #( logic [3:0][31:0] A_half; logic [3:0][31:0] B_half; logic [7:0][31:0] C_half; + logic [3:0][31:0] A_half_buf; + logic [3:0][31:0] B_half_buf; + logic [7:0][31:0] C_half_buf; + logic [`NUM_WARPS-1:0] substeps; logic [`NUM_WARPS-1:0] substeps_n; - always @(*) begin + wire [7:0][31:0] A_in_buf; + wire [7:0][31:0] B_in_buf; + wire [7:0][31:0] C_in_buf; + wire operands_valid_buf; + wire operands_ready_buf; + wire [`NW_WIDTH-1:0] operands_wid_buf; + wire operands_last_in_pair_buf; + wire [1:0] operands_step_buf; + + wire inbuf_empty; + wire inbuf_full; + wire inbuf_ready_in; + assign inbuf_ready_in = !inbuf_full; + assign operands_ready = inbuf_ready_in; + assign operands_valid_buf = !inbuf_empty; + + wire inbuf_enq = operands_ready && operands_valid && operands_last_in_pair; + wire inbuf_deq = operands_valid_buf && operands_ready_buf; + + // the 'issue queue' for the dpu. + // This exists to decouple the input of the dot-product unit from + // execute_if.ready. execute_if can arrive intermittently according to + // the frontend's behavior, and since the dpu can also stall for a fixed + // initiation latency, we need to decouple the two to efficiently feed the + // dpu. + // This only applies to the last instruction in a pair, since the first + // instruction only acts to buffer the operands and can execute + // immediately without backpressure. So we don't enqueue them. + VX_fifo_queue #( + .DATAW ($bits(A_in) + $bits(B_in) + $bits(C_in) + + $bits(operands_wid) + $bits(operands_step) + $bits(operands_last_in_pair)), + .DEPTH (4 /* FIXME: arbitrary */) + ) input_buffer ( + .clk (clk), + .reset (reset), + .push (inbuf_enq), + .pop (inbuf_deq), + .data_in ({A_in, B_in, C_in, operands_wid, operands_step, operands_last_in_pair}), + .data_out ({A_in_buf, B_in_buf, C_in_buf, operands_wid_buf, operands_step_buf, operands_last_in_pair_buf}), + .empty (inbuf_empty), + `UNUSED_PIN(alm_empty), + .full (inbuf_full), + `UNUSED_PIN(alm_full), + `UNUSED_PIN(size) + ); + + typedef struct { + logic [3:0][31:0] A_half; + logic [3:0][31:0] B_half; + logic [7:0][31:0] C_half; + } half_t; + + function half_t get_operand_half( + input logic [1:0] step, + input logic [7:0][31:0] A_in, + input logic [7:0][31:0] B_in, + input logic [7:0][31:0] C_in + ); + half_t half; // note that not all lanes participate at every step case (step) 2'b00: begin @@ -304,28 +385,34 @@ module VX_tensor_octet #( // by two threadgroups: [0:2,0:2] and [4:6,0:2] in Step 0 of // Figure 10(b). B_in OTOH is shared by two threadgroups. // Note k-dimension is shrunk from 4 to 2. - A_half = { A_in[5:4], A_in[1:0] }; - B_half = B_in[3:0]; + half.A_half = { A_in[5:4], A_in[1:0] }; + half.B_half = B_in[3:0]; end 2'b01: begin - A_half = { A_in[7:6], A_in[3:2] }; - B_half = B_in[3:0]; + half.A_half = { A_in[7:6], A_in[3:2] }; + half.B_half = B_in[3:0]; end 2'b10: begin - A_half = { A_in[5:4], A_in[1:0] }; - B_half = B_in[7:4]; + half.A_half = { A_in[5:4], A_in[1:0] }; + half.B_half = B_in[7:4]; end 2'b11: begin - A_half = { A_in[7:6], A_in[3:2] }; - B_half = B_in[7:4]; + half.A_half = { A_in[7:6], A_in[3:2] }; + half.B_half = B_in[7:4]; end endcase - C_half = C_in; - end + half.C_half = C_in; + return half; + endfunction - logic substep; - wire operands_fire = operands_ready && operands_valid; - wire substep_n = operands_fire && operands_last_in_pair; + half_t halves; + half_t halves_buf; + assign halves = get_operand_half(operands_step, A_in, B_in, C_in); + assign halves_buf = get_operand_half(operands_step_buf, A_in_buf, B_in_buf, C_in_buf); + + wire do_hmma = operands_ready_buf && operands_valid_buf && operands_last_in_pair_buf; + wire operands_first_in_pair_fire = operands_ready && operands_valid && (!operands_last_in_pair); + // wire operands_first_in_pair_fire = operands_ready && operands_valid; always @(*) begin A_buffer_n = A_buffer; @@ -333,20 +420,15 @@ module VX_tensor_octet #( C_buffer_n = C_buffer; substeps_n = substeps; - if (operands_fire) begin - substeps_n[operands_wid] = ~substeps[operands_wid]; - if (!operands_last_in_pair) begin - A_buffer_n[operands_wid] = A_half; - B_buffer_n[operands_wid] = B_half; - C_buffer_n[operands_wid] = C_half; - end + if (operands_first_in_pair_fire) begin + substeps_n[operands_wid] = 1'b1; // ready for hmma + A_buffer_n[operands_wid] = halves.A_half; + B_buffer_n[operands_wid] = halves.B_half; + C_buffer_n[operands_wid] = halves.C_half; + end + if (do_hmma) begin + substeps_n[operands_wid_buf] = 1'b0; // finished hmma, ready for next operand end - - // if (operands_fire && (substep == 1'b0)) begin - // A_buffer_n[operands_wid] = A_half; - // B_buffer_n[operands_wid] = B_half; - // C_buffer_n[operands_wid] = C_half; - // end end always @(posedge clk) begin @@ -354,43 +436,39 @@ module VX_tensor_octet #( A_buffer <= '0; B_buffer <= '0; C_buffer <= '0; - - substep <= '0; substeps <= '0; end else begin A_buffer <= A_buffer_n; B_buffer <= B_buffer_n; C_buffer <= C_buffer_n; - - substep <= substep_n; substeps <= substeps_n; end end - wire hmma_ready; wire outbuf_ready_in; - // wire stall = result_valid && ~result_ready; // backpressure from commit wire stall = ~outbuf_ready_in; + wire hmma_ready; + // assign operands_ready = ~stall; // TODO: Below line is to only allow 1 warp to occupy the octet at a time; // currently, dpu is fully-pipelined and allows concurrency between // multiple warps. This seems to be not a problem though given that the // RF operand read takes >=2 cycles, which should be the end-to-end // latency of the DPU anyways - assign operands_ready = hmma_ready && ~stall; + assign operands_ready_buf = hmma_ready && ~stall; // A is 4x2 fp32 matrix wire [3:0][1:0][31:0] A_tile = { - { A_half[3], A_buffer[operands_wid][3] }, - { A_half[2], A_buffer[operands_wid][2] }, - { A_half[1], A_buffer[operands_wid][1] }, - { A_half[0], A_buffer[operands_wid][0] } + { halves_buf.A_half[3], A_buffer[operands_wid_buf][3] }, + { halves_buf.A_half[2], A_buffer[operands_wid_buf][2] }, + { halves_buf.A_half[1], A_buffer[operands_wid_buf][1] }, + { halves_buf.A_half[0], A_buffer[operands_wid_buf][0] } }; // B is 2x4 fp32 matrix wire [1:0][3:0][31:0] B_tile = { - B_half, B_buffer[operands_wid] + halves_buf.B_half, B_buffer[operands_wid_buf] }; // C is 4x4 fp32 matrix logic [3:0][3:0][31:0] C_tile; @@ -398,14 +476,12 @@ module VX_tensor_octet #( logic [`NW_WIDTH-1:0] D_wid_dpu; always @(*) begin - C_tile[3] = { C_half[7], C_buffer[operands_wid][7], C_half[5], C_buffer[operands_wid][5] }; - C_tile[2] = { C_half[6], C_buffer[operands_wid][6], C_half[4], C_buffer[operands_wid][4] }; - C_tile[1] = { C_half[3], C_buffer[operands_wid][3], C_half[1], C_buffer[operands_wid][1] }; - C_tile[0] = { C_half[2], C_buffer[operands_wid][2], C_half[0], C_buffer[operands_wid][0] }; + C_tile[3] = { halves_buf.C_half[7], C_buffer[operands_wid_buf][7], halves_buf.C_half[5], C_buffer[operands_wid_buf][5] }; + C_tile[2] = { halves_buf.C_half[6], C_buffer[operands_wid_buf][6], halves_buf.C_half[4], C_buffer[operands_wid_buf][4] }; + C_tile[1] = { halves_buf.C_half[3], C_buffer[operands_wid_buf][3], halves_buf.C_half[1], C_buffer[operands_wid_buf][1] }; + C_tile[0] = { halves_buf.C_half[2], C_buffer[operands_wid_buf][2], halves_buf.C_half[0], C_buffer[operands_wid_buf][0] }; end - // wire do_hmma = operands_fire && (substeps[operands_wid] == 1'b1); - wire do_hmma = operands_fire && operands_last_in_pair; wire dpu_valid; // this does (m,n,k)=(4,4,2) matmul, modeling compute of a single octet @@ -423,7 +499,7 @@ module VX_tensor_octet #( .A_tile(A_tile), .B_tile(B_tile), .C_tile(C_tile), - .wid(operands_wid), + .wid(operands_wid_buf), .valid_out(dpu_valid), .D_tile(D_tile), @@ -438,14 +514,14 @@ module VX_tensor_octet #( wire outbuf_enq = outbuf_ready_in && dpu_valid; wire outbuf_deq = result_valid && result_ready; - // buffer to stage the result tile for 2 cycles until commit/writeback is - // complete. This decouples the irregular dpu output traffic from the - // regular, every-2-cycle commit traffic and thereby ensures the commit - // pipeline is used more efficiently. + // buffer to stage the result D tile for 2 cycles until commit/writeback + // is complete. This decouples the irregular dpu output traffic from the + // regular, every-2-cycle commit traffic to ensure the commit pipeline is + // used more efficiently. // TODO: This is probably oversized. VX_fifo_queue #( .DATAW ($bits(D_wid) + $bits(D_out)), - .DEPTH (8 /* FIXME: arbitrary */) + .DEPTH (4 /* FIXME: arbitrary */) ) output_buffer ( .clk (clk), .reset (reset), diff --git a/hw/rtl/fpu/VX_tensor_dpu.sv b/hw/rtl/fpu/VX_tensor_dpu.sv index 7a3ee41d..7e96a296 100644 --- a/hw/rtl/fpu/VX_tensor_dpu.sv +++ b/hw/rtl/fpu/VX_tensor_dpu.sv @@ -51,7 +51,7 @@ module VX_tensor_dpu #( .clk (clk), .reset (reset), .enable (~stall), - .data_in ({valid_in, wid, result_hmma}), + .data_in ({valid_in && ready_in, wid, result_hmma}), .data_out ({valid_out, D_wid, D_tile}) ); endmodule