initial tcore impl

This commit is contained in:
joshua
2024-03-21 01:29:38 -07:00
parent f9b4509936
commit b254281295
5 changed files with 303 additions and 8 deletions

View File

@@ -391,7 +391,7 @@
// Tensor Core Latency
`ifndef LATENCY_HMMA
`define LATENCY_HMMA 4
`define LATENCY_HMMA 8
`endif
// Icache Configurable Knobs //////////////////////////////////////////////////

View File

@@ -10,6 +10,291 @@ module VX_tensor_core #(
VX_commit_if.master commit_if [`ISSUE_WIDTH]
);
`STATIC_ASSERT(`NUM_THREADS == 32, ("tensor core requires # of threads in a warp to be 32"));
`UNUSED_VAR(clk);
`UNUSED_VAR(reset);
for (genvar i = 0; i < `ISSUE_WIDTH; ++i) begin
VX_tensor_core_warp #(
.ISW(i)
) tensor_core (
.clk(clk),
.reset(reset),
.dispatch_if(dispatch_if[i]),
.commit_if(commit_if[i])
);
end
endmodule
// Tensor core datapath for a single issue slot.
//
// Splits the 32-thread operand registers across four VX_tensor_octet
// instances, tracks the metadata of in-flight uops in a FIFO, and writes the
// result tile back through commit_if as two consecutive commits
// (wb_data_0 first, then wb_data_1).
//
// Parameters:
//   ISW - issue slot index, passed to wis_to_wid() to rebuild the warp id.
module VX_tensor_core_warp import VX_gpu_pkg::*; #(
    parameter ISW
) (
    input clk,
    input reset,

    VX_dispatch_if.slave dispatch_if,
    VX_commit_if.master commit_if
);
    // NOTE: these are nets on purpose. A `logic x = expr;` declaration at
    // module scope is a one-time initialization at time 0, not a continuous
    // assignment, so the value would never track the dispatch interface.
    wire [1:0] step = 2'(dispatch_if.data.op_type);

    wire  [3:0] octet_operands_ready;
    logic [3:0] octet_results_valid;
    logic [3:0] octet_results_ready;

    logic [`NUM_THREADS-1:0][`XLEN-1:0] wb_data_0;
    logic [`NUM_THREADS-1:0][`XLEN-1:0] wb_data_1;

    for (genvar i = 0; i < 4; ++i) begin
        // Each octet takes threads [4*i +: 4] and [16+4*i +: 4] of every operand.
        // NOTE(review): the lower thread group lands in elements [7:4] of the
        // concatenation - confirm this matches the indexing VX_tensor_octet expects.
        wire [7:0][31:0] octet_A = {
            dispatch_if.data.rs1_data[4*i +: 4], dispatch_if.data.rs1_data[16+4*i +: 4]
        };
        wire [7:0][31:0] octet_B = {
            dispatch_if.data.rs2_data[4*i +: 4], dispatch_if.data.rs2_data[16+4*i +: 4]
        };
        wire [7:0][31:0] octet_C = {
            dispatch_if.data.rs3_data[4*i +: 4], dispatch_if.data.rs3_data[16+4*i +: 4]
        };

        logic [3:0][3:0][31:0] octet_D;
        logic result_valid;
        logic result_ready;

        VX_tensor_octet #(
        ) octet (
            .clk(clk),
            .reset(reset),

            .A_in(octet_A),
            .B_in(octet_B),
            .C_in(octet_C),
            .operands_valid(dispatch_if.valid),
            // Per-octet ready bit; the interface ready is driven once below so
            // that dispatch_if.ready has a single driver.
            .operands_ready(octet_operands_ready[i]),
            .step(step),

            .D_out(octet_D),
            .result_valid(result_valid),
            .result_ready(result_ready)
        );

        // these should always be in lockstep
        assign octet_results_valid[i] = result_valid;
        assign result_ready = octet_results_ready[i];

        // Scatter the octet's 4x4 result tile into the two writeback vectors:
        // even columns go to wb_data_0, odd columns to wb_data_1; rows 0-1 go
        // to the lower thread group, rows 2-3 to the upper thread group.
        assign wb_data_0[4*i+0] = octet_D[0][0];
        assign wb_data_0[4*i+1] = octet_D[1][0];
        assign wb_data_0[4*i+2] = octet_D[0][2];
        assign wb_data_0[4*i+3] = octet_D[1][2];

        assign wb_data_1[4*i+0] = octet_D[0][1];
        assign wb_data_1[4*i+1] = octet_D[1][1];
        assign wb_data_1[4*i+2] = octet_D[0][3];
        assign wb_data_1[4*i+3] = octet_D[1][3];

        assign wb_data_0[4*i+16+0] = octet_D[2][0];
        assign wb_data_0[4*i+16+1] = octet_D[3][0];
        assign wb_data_0[4*i+16+2] = octet_D[2][2];
        assign wb_data_0[4*i+16+3] = octet_D[3][2];

        assign wb_data_1[4*i+16+0] = octet_D[2][1];
        assign wb_data_1[4*i+16+1] = octet_D[3][1];
        assign wb_data_1[4*i+16+2] = octet_D[2][3];
        assign wb_data_1[4*i+16+3] = octet_D[3][3];
    end

    // The octets run in lockstep, so their ready bits agree; reduce them into
    // the single driver of the dispatch handshake.
    assign dispatch_if.ready = (& octet_operands_ready);

    /* commit_if.data_t parts that we need to keep around:
       - uuid
       - wid
       - tmask
       - PC
       - wb
       - rd
    */
    localparam DATAW = `UUID_WIDTH + `NW_WIDTH + `NUM_THREADS + `XLEN + 1 + `NR_BITS;

    wire dispatch_if_fire = dispatch_if.valid && dispatch_if.ready;
    wire commit_if_fire = commit_if.valid && commit_if.ready;

    wire [DATAW-1:0] dispatch_if_data_enq = {
        dispatch_if.data.uuid,
        wis_to_wid(dispatch_if.data.wis, ISW),
        dispatch_if.data.tmask,
        dispatch_if.data.PC,
        dispatch_if.data.wb,
        dispatch_if.data.rd
    };
    wire [DATAW-1:0] dispatch_if_data_deq;

    // Metadata of uops that have been accepted but not yet committed.
    // One entry is pushed per accepted uop and one popped per commit.
    // this is probably a little oversized
    VX_fifo_queue #(
        .DATAW(DATAW),
        .DEPTH(8)
    ) pending_uops (
        .clk(clk),
        .reset(reset),
        .push(dispatch_if_fire),
        .pop(commit_if_fire),
        .data_in(dispatch_if_data_enq),
        .data_out(dispatch_if_data_deq),
        `UNUSED_PIN(empty),
        `UNUSED_PIN(alm_empty),
        `UNUSED_PIN(full), // should be impossible to overflow
        `UNUSED_PIN(alm_full),
        `UNUSED_PIN(size)
    );

    // subcommit selects which half of the result is being written back:
    // 0 -> wb_data_0, 1 -> wb_data_1.
    logic subcommit, subcommit_n;

    wire all_valid = (& octet_results_valid);
    assign commit_if.valid = all_valid;

    // Queued metadata + writeback data + three trailing 1-bit fields.
    localparam COMMIT_DATAW = DATAW + (`NUM_THREADS * `XLEN) + 1 + 1 + 1;

    // NOTE(review): the three trailing bits are presumably the remaining
    // single-bit fields of commit_if.data (e.g. pid/sop/eop) - confirm
    // against the VX_commit_if definition.
    wire [COMMIT_DATAW-1:0] commit_if_data = {
        dispatch_if_data_deq,
        ((subcommit == 1'b0) ? wb_data_0 : wb_data_1),
        1'b0,
        1'b1,
        1'b1
    };

    assign commit_if.data = commit_if_data;

    always @(*) begin
        subcommit_n = commit_if_fire ? ~subcommit : subcommit;

        // The octets may only release their result once the second half has
        // been committed; until then they keep result_valid asserted.
        if (commit_if_fire && subcommit == 1'b1) begin
            octet_results_ready = '1;
        end
        else begin
            octet_results_ready = '0;
        end
    end

    always @(posedge clk) begin
        if (reset) begin
            subcommit <= '0;
        end
        else begin
            subcommit <= subcommit_n;
        end
    end
endmodule
// One octet of the tensor core.
//
// An HMMA step is delivered as two operand transfers ("substeps"). During
// substep 0 the selected operand halves are captured into local buffers;
// when the substep-1 transfer is accepted, the buffered halves and the
// halves currently on the operand bus are combined into tiles and issued to
// the dot-product unit (VX_tensor_dpu).
//
// Ports:
//   A_in/B_in/C_in  - operand words for this octet from the operand bus
//   operands_valid  - operand bus carries a valid transfer
//   operands_ready  - deasserted while a finished result is waiting to drain
//   step            - selects which slices of A_in/B_in are used
//   D_out           - 4x4 result tile from the DPU
//   result_valid    - D_out is valid
//   result_ready    - consumer has taken D_out
module VX_tensor_octet #(
) (
    input clk,
    input reset,

    input [7:0][31:0] A_in,
    input [7:0][31:0] B_in,
    input [7:0][31:0] C_in,
    input operands_valid, // we have to backpressure due to there potentially being contention over commit
    output operands_ready,
    input [1:0] step,

    output [3:0][3:0][31:0] D_out,
    output result_valid,
    input result_ready
);
    // 512 bits/octet * 4 octets per warp
    logic [3:0][31:0] A_buffer, A_buffer_n;
    logic [3:0][31:0] B_buffer, B_buffer_n;
    logic [7:0][31:0] C_buffer, C_buffer_n;

    // half the inputs are buffered, half are not (instead coming straight from operand bus)
    // unlike the real tensor core, the banks are only 32 bit rather than 64 bit
    logic [3:0][31:0] A_half;
    logic [3:0][31:0] B_half;
    logic [7:0][31:0] C_half;

    // step[0] picks the A slice pair, step[1] picks the B half; C is passed
    // through unchanged.
    always @(*) begin
        case (step)
            2'b00: begin
                A_half = { A_in[1:0], A_in[5:4] };
                B_half = B_in[3:0];
            end
            2'b01: begin
                A_half = { A_in[3:2], A_in[7:6] };
                B_half = B_in[3:0];
            end
            2'b10: begin
                A_half = { A_in[1:0], A_in[5:4] };
                B_half = B_in[7:4];
            end
            2'b11: begin
                A_half = { A_in[3:2], A_in[7:6] };
                B_half = B_in[7:4];
            end
        endcase
        C_half = C_in;
    end

    logic substep;
    // NOTE: must be a net. `logic substep_n = expr;` would only be evaluated
    // once at time 0, so substep would never toggle.
    wire substep_n = (operands_ready && operands_valid) ? ~substep : substep;

    always @(*) begin
        A_buffer_n = A_buffer;
        B_buffer_n = B_buffer;
        C_buffer_n = C_buffer;

        // While waiting for / receiving the first substep, keep sampling the
        // bus; the value present on the accepting edge is the one retained
        // through substep 1.
        if (substep == 1'b0) begin
            A_buffer_n = A_half;
            B_buffer_n = B_half;
            C_buffer_n = C_half;
        end
    end

    always @(posedge clk) begin
        if (reset) begin
            A_buffer <= '0;
            B_buffer <= '0;
            C_buffer <= '0;
            substep <= '0;
        end
        else begin
            A_buffer <= A_buffer_n;
            B_buffer <= B_buffer_n;
            C_buffer <= C_buffer_n;
            substep <= substep_n;
        end
    end

    // Hold the DPU pipeline and refuse new operands while a finished result
    // has not been accepted downstream.
    wire stall = result_valid && ~result_ready;
    assign operands_ready = ~stall;

    // Tiles are nets for the same reason as substep_n: they have to follow
    // the buffers and the operand bus continuously.
    // NOTE(review): concatenation places the first listed element at the
    // highest index (A_tile[3] = {A_buffer[0], A_half[0]}); confirm this is
    // the row order VX_tensor_dpu expects. Same applies to C_tile below.
    wire [3:0][1:0][31:0] A_tile = {
        { A_buffer[0], A_half[0] },
        { A_buffer[1], A_half[1] },
        { A_buffer[2], A_half[2] },
        { A_buffer[3], A_half[3] }
    };

    wire [1:0][3:0][31:0] B_tile = {
        B_buffer, B_half
    };

    logic [3:0][3:0][31:0] C_tile;
    always @(*) begin
        // Interleave buffered and live C words into the 4x4 accumulator tile.
        C_tile = {
            C_buffer[0], C_half[0], C_buffer[1], C_half[1],
            C_buffer[2], C_half[2], C_buffer[3], C_half[3],
            C_buffer[4], C_half[4], C_buffer[5], C_half[5],
            C_buffer[6], C_half[6], C_buffer[7], C_half[7]
        };
    end

    // Issue to the DPU only when the second substep's operands are accepted.
    wire do_hmma = (substep == 1'b1 && operands_valid && operands_ready);

    VX_tensor_dpu #(
    ) dpu (
        .clk(clk),
        .reset(reset),

        .stall(stall),

        .valid_in(do_hmma),
        .A_tile(A_tile),
        .B_tile(B_tile),
        .C_tile(C_tile),

        .valid_out(result_valid),
        .D_tile(D_out)
    );
endmodule

View File

@@ -74,8 +74,8 @@ module VX_uop_sequencer import VX_gpu_pkg::*; (
NEXT,
HMMA_SET0_STEP0_1,
`EX_BITS'(`EX_TENSOR),
`INST_OP_BITS'(0), // denotes that the first half is being computed
`INST_MOD_BITS'(0), // field is unused for HMMA
`INST_OP_BITS'(0), // denotes that the first step is being computed
`INST_MOD_BITS'(0), // denotes that this is first substep (tensor core also tracks this)
1'b1, // write back
1'b0, // don't use PC
1'b0, // don't use immediate
@@ -92,8 +92,8 @@ module VX_uop_sequencer import VX_gpu_pkg::*; (
FINISH,
HMMA_SET0_STEP0_0,
`EX_BITS'(`EX_TENSOR),
`INST_OP_BITS'(1), // denotes that the second half is being computed
`INST_MOD_BITS'(0), // field is unused for HMMA
`INST_OP_BITS'(0), // denotes that the first step is being computed
`INST_MOD_BITS'(1), // denotes that this is the second substep (tensor core also tracks this)
1'b1, // write back
1'b0, // don't use PC
1'b0, // don't use immediate
@@ -161,6 +161,12 @@ module VX_uop_sequencer import VX_gpu_pkg::*; (
assign ibuffer_if.data = use_uop ? ibuffer_output : uop_sequencer_if.data;
always @(posedge clk) begin
if (use_uop) begin
$display("unexpectedly used uop at %d", $time);
end
if (reset) begin
upc_r <= '0;
use_uop_1d <= '0;

View File

@@ -6,6 +6,8 @@ module VX_tensor_dpu #(
input clk,
input reset,
input stall,
input valid_in,
input [3:0][1:0][31:0] A_tile,
input [1:0][3:0][31:0] B_tile,
@@ -28,7 +30,7 @@ module VX_tensor_dpu #(
) shift_reg (
.clk (clk),
.reset (reset),
.enable (1'b1),
.enable (~stall),
.data_in ({valid_in, result_hmma}),
.data_out ({valid_out, D_tile})
);

View File

@@ -17,6 +17,8 @@ module VX_tensor_tb(
.clk(clk),
.reset(reset),
.stall(1'b0),
.valid_in(valid_in),
.A_tile(A_tile),
.B_tile(B_tile),