From b254281295cecf7f90c2d6061d4a8dab69b86bff Mon Sep 17 00:00:00 2001
From: joshua
Date: Thu, 21 Mar 2024 01:29:38 -0700
Subject: [PATCH] initial tcore impl

---
 hw/rtl/VX_config.vh             |   2 +-
 hw/rtl/core/VX_tensor_core.sv   | 303 +++++++++++++++++++++++++++++++-
 hw/rtl/core/VX_uop_sequencer.sv |  14 +-
 hw/rtl/fpu/VX_tensor_dpu.sv     |   4 +-
 hw/rtl/fpu/VX_tensor_tb.sv      |   2 +
 5 files changed, 317 insertions(+), 8 deletions(-)

diff --git a/hw/rtl/VX_config.vh b/hw/rtl/VX_config.vh
index e8bb56fc..d741da8d 100644
--- a/hw/rtl/VX_config.vh
+++ b/hw/rtl/VX_config.vh
@@ -391,7 +391,7 @@
 
 // Tensor Core Latency
 `ifndef LATENCY_HMMA
-`define LATENCY_HMMA 4
+`define LATENCY_HMMA 8
 `endif
 
 // Icache Configurable Knobs //////////////////////////////////////////////////
diff --git a/hw/rtl/core/VX_tensor_core.sv b/hw/rtl/core/VX_tensor_core.sv
index c31f3f9f..a9419c66 100644
--- a/hw/rtl/core/VX_tensor_core.sv
+++ b/hw/rtl/core/VX_tensor_core.sv
@@ -10,6 +10,305 @@ module VX_tensor_core #(
     VX_commit_if.master commit_if [`ISSUE_WIDTH]
 );
     `STATIC_ASSERT(`NUM_THREADS == 32, ("tensor core requires # of threads in a warp to be 32"));
-    `UNUSED_VAR(clk);
-    `UNUSED_VAR(reset);
+
+    for (genvar i = 0; i < `ISSUE_WIDTH; ++i) begin
+        VX_tensor_core_warp #(
+            .ISW(i)
+        ) tensor_core (
+            .clk(clk),
+            .reset(reset),
+
+            .dispatch_if(dispatch_if[i]),
+            .commit_if(commit_if[i])
+        );
+    end
+
+endmodule
+
+// Tensor core for a single issue slot: fans the dispatched operands out to four
+// VX_tensor_octet instances, keeps each accepted uop's metadata in a FIFO, and
+// writes the octet results back as two commits (wb_data_0, then wb_data_1).
+module VX_tensor_core_warp import VX_gpu_pkg::*; #(
+    parameter ISW
+) (
+    input clk,
+    input reset,
+
+    VX_dispatch_if.slave dispatch_if,
+    VX_commit_if.master commit_if
+);
+    // net declaration assignments are continuous; `logic x = expr;` would only
+    // evaluate expr once, as the variable's initial value
+    wire [1:0] step = 2'(dispatch_if.data.op_type);
+    logic [3:0] octet_results_valid;
+    logic [3:0] octet_results_ready;
+    wire [3:0] octet_operands_ready;
+    logic [`NUM_THREADS-1:0][`XLEN-1:0] wb_data_0;
+    logic [`NUM_THREADS-1:0][`XLEN-1:0] wb_data_1;
+
+    for (genvar i = 0; i < 4; ++i) begin
+        wire [7:0][31:0] octet_A = {
+            dispatch_if.data.rs1_data[4*i +: 4], dispatch_if.data.rs1_data[16+4*i +: 4]
+        };
+        wire [7:0][31:0] octet_B = {
+            dispatch_if.data.rs2_data[4*i +: 4], dispatch_if.data.rs2_data[16+4*i +: 4]
+        };
+        wire [7:0][31:0] octet_C = {
+            dispatch_if.data.rs3_data[4*i +: 4], dispatch_if.data.rs3_data[16+4*i +: 4]
+        };
+
+        logic [3:0][3:0][31:0] octet_D;
+        logic result_valid;
+        logic result_ready;
+        VX_tensor_octet #(
+
+        ) octet (
+            .clk(clk),
+            .reset(reset),
+
+            .A_in(octet_A),
+            .B_in(octet_B),
+            .C_in(octet_C),
+            .operands_valid(dispatch_if.valid),
+            .operands_ready(octet_operands_ready[i]),
+
+            .step(step),
+
+            .D_out(octet_D),
+            .result_valid(result_valid),
+            .result_ready(result_ready)
+        );
+
+        // these should always be in lockstep
+        assign octet_results_valid[i] = result_valid;
+        assign result_ready = octet_results_ready[i];
+
+        assign wb_data_0[4*i+0] = octet_D[0][0];
+        assign wb_data_0[4*i+1] = octet_D[1][0];
+        assign wb_data_0[4*i+2] = octet_D[0][2];
+        assign wb_data_0[4*i+3] = octet_D[1][2];
+
+        assign wb_data_1[4*i+0] = octet_D[0][1];
+        assign wb_data_1[4*i+1] = octet_D[1][1];
+        assign wb_data_1[4*i+2] = octet_D[0][3];
+        assign wb_data_1[4*i+3] = octet_D[1][3];
+
+        assign wb_data_0[4*i+16+0] = octet_D[2][0];
+        assign wb_data_0[4*i+16+1] = octet_D[3][0];
+        assign wb_data_0[4*i+16+2] = octet_D[2][2];
+        assign wb_data_0[4*i+16+3] = octet_D[3][2];
+
+        assign wb_data_1[4*i+16+0] = octet_D[2][1];
+        assign wb_data_1[4*i+16+1] = octet_D[3][1];
+        assign wb_data_1[4*i+16+2] = octet_D[2][3];
+        assign wb_data_1[4*i+16+3] = octet_D[3][3];
+    end
+
+    // single driver for dispatch_if.ready: accept operands only when every
+    // octet can take them (the octets are expected to stay in lockstep)
+    assign dispatch_if.ready = (& octet_operands_ready);
+
+    /* commit_if.data_t parts that we need to keep around:
+        - uuid
+        - wid
+        - tmask
+        - PC
+        - wb
+        - rd
+    */
+
+    localparam DATAW = `UUID_WIDTH + `NW_WIDTH + `NUM_THREADS + `XLEN + 1 + `NR_BITS;
+
+    wire dispatch_if_fire = dispatch_if.valid && dispatch_if.ready;
+    wire commit_if_fire = commit_if.valid && commit_if.ready;
+    wire [DATAW-1:0] dispatch_if_data_enq = {
+        dispatch_if.data.uuid,
+        wis_to_wid(dispatch_if.data.wis, ISW),
+        dispatch_if.data.tmask,
+        dispatch_if.data.PC,
+        dispatch_if.data.wb,
+        dispatch_if.data.rd
+    };
+
+    wire [DATAW-1:0] dispatch_if_data_deq;
+
+    // this is probably a little oversized
+    VX_fifo_queue #(
+        .DATAW(DATAW),
+        .DEPTH(8)
+    ) pending_uops (
+        .clk(clk),
+        .reset(reset),
+        .push(dispatch_if_fire),
+        .pop(commit_if_fire),
+        .data_in(dispatch_if_data_enq),
+        .data_out(dispatch_if_data_deq),
+        `UNUSED_PIN(empty),
+        `UNUSED_PIN(alm_empty),
+        `UNUSED_PIN(full), // should be impossible to overflow
+        `UNUSED_PIN(alm_full),
+        `UNUSED_PIN(size)
+    );
+
+    logic subcommit, subcommit_n;
+    wire all_valid = (& octet_results_valid);
+    assign commit_if.valid = all_valid;
+
+    localparam COMMIT_DATAW = `UUID_WIDTH + `NW_WIDTH + `NUM_THREADS + `XLEN + 1 + `NR_BITS + (`NUM_THREADS * `XLEN) + 1 + 1 + 1;
+    wire [COMMIT_DATAW-1:0] commit_if_data = {
+        dispatch_if_data_deq,
+        subcommit == 1'b0 ? wb_data_0 : wb_data_1,
+        1'b0,
+        1'b1,
+        1'b1
+    };
+
+    assign commit_if.data = commit_if_data;
+
+    always @(*) begin
+        subcommit_n = commit_if_fire ? ~subcommit : subcommit;
+        if (commit_if_fire && subcommit == 1'b1) begin
+            octet_results_ready = '1;
+        end
+        else begin
+            octet_results_ready = '0;
+        end
+    end
+
+    always @(posedge clk) begin
+        if (reset) begin
+            subcommit <= '0;
+        end
+        else begin
+            subcommit <= subcommit_n;
+        end
+    end
+
+endmodule
+
+// One octet of the tensor core: latches half of the A/B/C operands on the first
+// substep, then pairs them with the live operands on the second substep and
+// issues the assembled tiles to VX_tensor_dpu.
+module VX_tensor_octet #(
+
+) (
+    input clk,
+    input reset,
+
+    input [7:0][31:0] A_in,
+    input [7:0][31:0] B_in,
+    input [7:0][31:0] C_in,
+    input operands_valid, // we have to backpressure due to there potentially being contention over commit
+    output operands_ready,
+
+    input [1:0] step,
+
+    output [3:0][3:0][31:0] D_out,
+    output result_valid,
+    input result_ready
+);
+    // 512 bits/octet * 4 octets per warp
+    logic [3:0][31:0] A_buffer, A_buffer_n;
+    logic [3:0][31:0] B_buffer, B_buffer_n;
+    logic [7:0][31:0] C_buffer, C_buffer_n;
+
+    // half the inputs are buffered, half are not (instead coming straight from operand bus)
+    // unlike the real tensor core, the banks are only 32 bit rather than 64 bit
+    logic [3:0][31:0] A_half;
+    logic [3:0][31:0] B_half;
+    logic [7:0][31:0] C_half;
+    always @(*) begin
+        case (step)
+            2'b00: begin
+                A_half = { A_in[1:0], A_in[5:4] };
+                B_half = B_in[3:0];
+            end
+            2'b01: begin
+                A_half = { A_in[3:2], A_in[7:6] };
+                B_half = B_in[3:0];
+            end
+            2'b10: begin
+                A_half = { A_in[1:0], A_in[5:4] };
+                B_half = B_in[7:4];
+            end
+            2'b11: begin
+                A_half = { A_in[3:2], A_in[7:6] };
+                B_half = B_in[7:4];
+            end
+        endcase
+        C_half = C_in;
+    end
+
+    logic substep;
+    wire substep_n = (operands_ready && operands_valid) ? ~substep : substep;
+
+    always @(*) begin
+        A_buffer_n = A_buffer;
+        B_buffer_n = B_buffer;
+        C_buffer_n = C_buffer;
+
+        if (substep == 1'b0) begin
+            A_buffer_n = A_half;
+            B_buffer_n = B_half;
+            C_buffer_n = C_half;
+        end
+    end
+
+    always @(posedge clk) begin
+        if (reset) begin
+            A_buffer <= '0;
+            B_buffer <= '0;
+            C_buffer <= '0;
+            substep <= '0;
+        end
+        else begin
+            A_buffer <= A_buffer_n;
+            B_buffer <= B_buffer_n;
+            C_buffer <= C_buffer_n;
+            substep <= substep_n;
+        end
+    end
+
+
+    wire stall = result_valid && ~result_ready;
+    assign operands_ready = ~stall;
+
+    wire [3:0][1:0][31:0] A_tile = {
+        { A_buffer[0], A_half[0] },
+        { A_buffer[1], A_half[1] },
+        { A_buffer[2], A_half[2] },
+        { A_buffer[3], A_half[3] }
+    };
+    wire [1:0][3:0][31:0] B_tile = {
+        B_buffer, B_half
+    };
+    logic [3:0][3:0][31:0] C_tile;
+
+    always @(*) begin
+        C_tile = {
+            C_buffer[0], C_half[0], C_buffer[1], C_half[1],
+            C_buffer[2], C_half[2], C_buffer[3], C_half[3],
+            C_buffer[4], C_half[4], C_buffer[5], C_half[5],
+            C_buffer[6], C_half[6], C_buffer[7], C_half[7]
+        };
+    end
+
+    // fire the dpu only on the second substep, when both operand halves are present
+    wire do_hmma = (substep == 1'b1 && operands_valid && operands_ready);
+    VX_tensor_dpu #(
+
+    ) dpu (
+        .clk(clk),
+        .reset(reset),
+
+        .stall(stall),
+
+        .valid_in(do_hmma),
+        .A_tile(A_tile),
+        .B_tile(B_tile),
+        .C_tile(C_tile),
+
+        .valid_out(result_valid),
+        .D_tile(D_out)
+    );
 endmodule
diff --git a/hw/rtl/core/VX_uop_sequencer.sv b/hw/rtl/core/VX_uop_sequencer.sv
index f18e473e..c57ea0ba 100644
--- a/hw/rtl/core/VX_uop_sequencer.sv
+++ b/hw/rtl/core/VX_uop_sequencer.sv
@@ -74,8 +74,8 @@ module VX_uop_sequencer import VX_gpu_pkg::*; (
             NEXT,
             HMMA_SET0_STEP0_1,
             `EX_BITS'(`EX_TENSOR),
-            `INST_OP_BITS'(0), // denotes that the first half is being computed
-            `INST_MOD_BITS'(0), // field is unused for HMMA
+            `INST_OP_BITS'(0), // denotes that the first step is being computed
+            `INST_MOD_BITS'(0), // denotes that this is first substep (tensor core also tracks this)
             1'b1, // write back
             1'b0, // don't use PC
             1'b0, // don't use immediate
@@ -92,8 +92,8 @@ module VX_uop_sequencer import VX_gpu_pkg::*; (
             FINISH,
             HMMA_SET0_STEP0_0,
             `EX_BITS'(`EX_TENSOR),
-            `INST_OP_BITS'(1), // denotes that the second half is being computed
-            `INST_MOD_BITS'(0), // field is unused for HMMA
+            `INST_OP_BITS'(0), // denotes that the first step is being computed
+            `INST_MOD_BITS'(1), // denotes that this is the second substep (tensor core also tracks this)
             1'b1, // write back
             1'b0, // don't use PC
             1'b0, // don't use immediate
@@ -161,6 +161,12 @@ module VX_uop_sequencer import VX_gpu_pkg::*; (
     assign ibuffer_if.data = use_uop ? ibuffer_output : uop_sequencer_if.data;
 
     always @(posedge clk) begin
+
+        if (use_uop) begin
+            $display("unexpectedly used uop at %d", $time);
+        end
+
+
         if (reset) begin
             upc_r <= '0;
             use_uop_1d <= '0;
diff --git a/hw/rtl/fpu/VX_tensor_dpu.sv b/hw/rtl/fpu/VX_tensor_dpu.sv
index f9147c9d..8108790c 100644
--- a/hw/rtl/fpu/VX_tensor_dpu.sv
+++ b/hw/rtl/fpu/VX_tensor_dpu.sv
@@ -6,6 +6,8 @@ module VX_tensor_dpu #(
     input clk,
     input reset,
 
+    input stall,
+
     input valid_in,
     input [3:0][1:0][31:0] A_tile,
     input [1:0][3:0][31:0] B_tile,
@@ -28,7 +30,7 @@ module VX_tensor_dpu #(
     ) shift_reg (
         .clk (clk),
         .reset (reset),
-        .enable (1'b1),
+        .enable (~stall),
         .data_in ({valid_in, result_hmma}),
         .data_out ({valid_out, D_tile})
     );
diff --git a/hw/rtl/fpu/VX_tensor_tb.sv b/hw/rtl/fpu/VX_tensor_tb.sv
index 9fa9fa41..4a6076b0 100644
--- a/hw/rtl/fpu/VX_tensor_tb.sv
+++ b/hw/rtl/fpu/VX_tensor_tb.sv
@@ -17,6 +17,8 @@ module VX_tensor_tb(
         .clk(clk),
         .reset(reset),
 
+        .stall(1'b0),
+
         .valid_in(valid_in),
         .A_tile(A_tile),
         .B_tile(B_tile),