initial tcore impl
This commit is contained in:
@@ -391,7 +391,7 @@
|
||||
|
||||
// Tensor Core Latency
|
||||
`ifndef LATENCY_HMMA
|
||||
`define LATENCY_HMMA 4
|
||||
`define LATENCY_HMMA 8
|
||||
`endif
|
||||
|
||||
// Icache Configurable Knobs //////////////////////////////////////////////////
|
||||
|
||||
@@ -10,6 +10,291 @@ module VX_tensor_core #(
|
||||
VX_commit_if.master commit_if [`ISSUE_WIDTH]
|
||||
);
|
||||
`STATIC_ASSERT(`NUM_THREADS == 32, ("tensor core requires # of threads in a warp to be 32"));
|
||||
`UNUSED_VAR(clk);
|
||||
`UNUSED_VAR(reset);
|
||||
|
||||
for (genvar i = 0; i < `ISSUE_WIDTH; ++i) begin
|
||||
VX_tensor_core_warp #(
|
||||
.ISW(i)
|
||||
) tensor_core (
|
||||
.clk(clk),
|
||||
.reset(reset),
|
||||
|
||||
.dispatch_if(dispatch_if[i]),
|
||||
.commit_if(commit_if[i])
|
||||
);
|
||||
end
|
||||
|
||||
endmodule
|
||||
|
||||
// Tensor core datapath for a single issue slot.
//
// Operand data arriving on dispatch_if is split across four VX_tensor_octet
// instances. The bookkeeping fields of every accepted dispatch (uuid, wid,
// tmask, PC, wb, rd) are parked in a FIFO and re-attached to the results when
// they are committed. Each octet result is written back as two consecutive
// commits: wb_data_0 while subcommit is 0, then wb_data_1 while subcommit is 1.
//
// Parameters:
//   ISW - issue slot index; passed to wis_to_wid() to rebuild the warp id.
//
// Ports:
//   clk, reset  - clock and synchronous reset.
//   dispatch_if - incoming operands and uop metadata (slave side).
//   commit_if   - outgoing write-back data and uop metadata (master side).
//
// NOTE: combinational signals below are declared as nets with a declaration
// assignment (`wire x = expr;`), which is a continuous assignment. Writing
// `logic x = expr;` instead would only initialize the variable once and the
// signal would not follow its inputs afterwards.
module VX_tensor_core_warp import VX_gpu_pkg::*; #(
    parameter ISW
) (
    input clk,
    input reset,

    VX_dispatch_if.slave dispatch_if,
    VX_commit_if.master commit_if
);
    // Step selector handed to every octet: the low two bits of op_type.
    wire [1:0] step = 2'(dispatch_if.data.op_type);

    // Per-octet handshake signals, one bit per octet.
    logic [3:0] octet_results_valid;
    logic [3:0] octet_results_ready;
    wire  [3:0] octet_operands_ready;

    // The two write-back words produced per thread for every octet result.
    logic [`NUM_THREADS-1:0][`XLEN-1:0] wb_data_0;
    logic [`NUM_THREADS-1:0][`XLEN-1:0] wb_data_1;

    // Each octet output used to be wired straight onto dispatch_if.ready,
    // which gave that signal four drivers. The octets see identical valid /
    // result_ready inputs, so their ready outputs are expected to agree;
    // reduce them into a single driver.
    assign dispatch_if.ready = (& octet_operands_ready);

    for (genvar i = 0; i < 4; ++i) begin
        // Octet i consumes threads [4*i +: 4] and [16+4*i +: 4] of each
        // source register.
        wire [7:0][31:0] octet_A = {
            dispatch_if.data.rs1_data[4*i +: 4], dispatch_if.data.rs1_data[16+4*i +: 4]
        };
        wire [7:0][31:0] octet_B = {
            dispatch_if.data.rs2_data[4*i +: 4], dispatch_if.data.rs2_data[16+4*i +: 4]
        };
        wire [7:0][31:0] octet_C = {
            dispatch_if.data.rs3_data[4*i +: 4], dispatch_if.data.rs3_data[16+4*i +: 4]
        };

        logic [3:0][3:0][31:0] octet_D;
        logic result_valid;
        logic result_ready;

        VX_tensor_octet #(

        ) octet (
            .clk(clk),
            .reset(reset),

            .A_in(octet_A),
            .B_in(octet_B),
            .C_in(octet_C),
            .operands_valid(dispatch_if.valid),
            .operands_ready(octet_operands_ready[i]),

            .step(step),

            .D_out(octet_D),
            .result_valid(result_valid),
            .result_ready(result_ready)
        );

        // these should always be in lockstep
        assign octet_results_valid[i] = result_valid;
        assign result_ready = octet_results_ready[i];

        // Result scatter: the first index of octet_D selects the thread
        // (0/1 -> threads 4*i.., 2/3 -> threads 4*i+16..), the second index
        // selects the write-back word (0/2 -> wb_data_0, 1/3 -> wb_data_1).
        assign wb_data_0[4*i+0] = octet_D[0][0];
        assign wb_data_0[4*i+1] = octet_D[1][0];
        assign wb_data_0[4*i+2] = octet_D[0][2];
        assign wb_data_0[4*i+3] = octet_D[1][2];

        assign wb_data_1[4*i+0] = octet_D[0][1];
        assign wb_data_1[4*i+1] = octet_D[1][1];
        assign wb_data_1[4*i+2] = octet_D[0][3];
        assign wb_data_1[4*i+3] = octet_D[1][3];

        assign wb_data_0[4*i+16+0] = octet_D[2][0];
        assign wb_data_0[4*i+16+1] = octet_D[3][0];
        assign wb_data_0[4*i+16+2] = octet_D[2][2];
        assign wb_data_0[4*i+16+3] = octet_D[3][2];

        assign wb_data_1[4*i+16+0] = octet_D[2][1];
        assign wb_data_1[4*i+16+1] = octet_D[3][1];
        assign wb_data_1[4*i+16+2] = octet_D[2][3];
        assign wb_data_1[4*i+16+3] = octet_D[3][3];
    end

    /* commit_if.data_t parts that we need to keep around:
        - uuid
        - wid
        - tmask
        - PC
        - wb
        - rd
    */

    // Width of the metadata word, in the field order listed above.
    localparam DATAW = `UUID_WIDTH + `NW_WIDTH + `NUM_THREADS + `XLEN + 1 + `NR_BITS;

    wire dispatch_if_fire = dispatch_if.valid && dispatch_if.ready;
    wire commit_if_fire = commit_if.valid && commit_if.ready;

    wire [DATAW-1:0] dispatch_if_data_enq = {
        dispatch_if.data.uuid,
        wis_to_wid(dispatch_if.data.wis, ISW),
        dispatch_if.data.tmask,
        dispatch_if.data.PC,
        dispatch_if.data.wb,
        dispatch_if.data.rd
    };

    wire [DATAW-1:0] dispatch_if_data_deq;

    // Metadata FIFO: one entry is pushed per accepted dispatch and one entry
    // is popped per accepted commit.
    // this is probably a little oversized
    VX_fifo_queue #(
        .DATAW(DATAW),
        .DEPTH(8)
    ) pending_uops (
        .clk(clk),
        .reset(reset),
        .push(dispatch_if_fire),
        .pop(commit_if_fire),
        .data_in(dispatch_if_data_enq),
        .data_out(dispatch_if_data_deq),
        `UNUSED_PIN(empty),
        `UNUSED_PIN(alm_empty),
        `UNUSED_PIN(full), // should be impossible to overflow
        `UNUSED_PIN(alm_full),
        `UNUSED_PIN(size)
    );

    // subcommit selects which write-back word is currently on commit_if:
    // 0 -> wb_data_0, 1 -> wb_data_1.
    logic subcommit, subcommit_n;

    // A commit is offered only once every octet reports a valid result.
    wire all_valid = (& octet_results_valid);
    assign commit_if.valid = all_valid;

    // Metadata word, the selected write-back data, then three 1-bit fields.
    localparam COMMIT_DATAW = `UUID_WIDTH + `NW_WIDTH + `NUM_THREADS + `XLEN + 1 + `NR_BITS + (`NUM_THREADS * `XLEN) + 1 + 1 + 1;

    // NOTE(review): the three trailing constants (0, 1, 1) presumably map to
    // the last three single-bit members of commit_if.data - confirm against
    // the VX_commit_if definition.
    wire [COMMIT_DATAW-1:0] commit_if_data = {
        dispatch_if_data_deq,
        (subcommit == 1'b0 ? wb_data_0 : wb_data_1),
        1'b0,
        1'b1,
        1'b1
    };

    assign commit_if.data = commit_if_data;

    // Toggle subcommit on every accepted commit. The octets are released
    // (result_ready asserted) only while the second word is being accepted,
    // so their result is held stable across both commits.
    always @(*) begin
        subcommit_n = commit_if_fire ? ~subcommit : subcommit;
        if (commit_if_fire && subcommit == 1'b1) begin
            octet_results_ready = '1;
        end
        else begin
            octet_results_ready = '0;
        end
    end

    always @(posedge clk) begin
        if (reset) begin
            subcommit <= '0;
        end
        else begin
            subcommit <= subcommit_n;
        end
    end

endmodule
|
||||
|
||||
// One tensor-core octet.
//
// Operands are delivered over two consecutive handshakes ("substeps"). While
// substep is 0 the selected operand halves are captured into A/B/C_buffer;
// on the handshake that happens while substep is 1 the buffered halves and
// the halves currently on the inputs are combined into tiles and handed to
// the VX_tensor_dpu instance (valid_in = do_hmma).
//
// Ports:
//   clk, reset      - clock and synchronous reset.
//   A_in/B_in/C_in  - operand words for this octet.
//   operands_valid  - upstream has operands available.
//   operands_ready  - this octet can accept operands (low while stalled).
//   step            - selects which slices of A_in / B_in are used.
//   D_out           - result tile from the dot-product unit.
//   result_valid    - D_out is valid.
//   result_ready    - downstream has consumed D_out.
//
// NOTE: combinational signals below are declared as nets with a declaration
// assignment (`wire x = expr;`), which is a continuous assignment. Writing
// `logic x = expr;` instead would only initialize the variable once and the
// signal would not follow its inputs afterwards.
module VX_tensor_octet #(

) (
    input clk,
    input reset,

    input [7:0][31:0] A_in,
    input [7:0][31:0] B_in,
    input [7:0][31:0] C_in,
    input operands_valid, // we have to backpressure due to there potentially being contention over commit
    output operands_ready,

    input [1:0] step,

    output [3:0][3:0][31:0] D_out,
    output result_valid,
    input result_ready
);
    // 512 bits/octet * 4 octets per warp
    logic [3:0][31:0] A_buffer, A_buffer_n;
    logic [3:0][31:0] B_buffer, B_buffer_n;
    logic [7:0][31:0] C_buffer, C_buffer_n;

    // half the inputs are buffered, half are not (instead coming straight from operand bus)
    // unlike the real tensor core, the banks are only 32 bit rather than 64 bit
    logic [3:0][31:0] A_half;
    logic [3:0][31:0] B_half;
    logic [7:0][31:0] C_half;

    // Operand slice selection: step[0] picks the A slice, step[1] picks the
    // B slice; C is always passed through whole.
    always @(*) begin
        case (step)
            2'b00: begin
                A_half = { A_in[1:0], A_in[5:4] };
                B_half = B_in[3:0];
            end
            2'b01: begin
                A_half = { A_in[3:2], A_in[7:6] };
                B_half = B_in[3:0];
            end
            2'b10: begin
                A_half = { A_in[1:0], A_in[5:4] };
                B_half = B_in[7:4];
            end
            2'b11: begin
                A_half = { A_in[3:2], A_in[7:6] };
                B_half = B_in[7:4];
            end
        endcase
        C_half = C_in;
    end

    // substep toggles on every accepted operand handshake.
    logic substep;
    wire substep_n = (operands_ready && operands_valid) ? ~substep : substep;

    // The buffers follow the selected halves on every cycle in which substep
    // is 0 (whether or not a handshake occurs) and hold their value while
    // substep is 1.
    always @(*) begin
        A_buffer_n = A_buffer;
        B_buffer_n = B_buffer;
        C_buffer_n = C_buffer;

        if (substep == 1'b0) begin
            A_buffer_n = A_half;
            B_buffer_n = B_half;
            C_buffer_n = C_half;
        end
    end

    always @(posedge clk) begin
        if (reset) begin
            A_buffer <= '0;
            B_buffer <= '0;
            C_buffer <= '0;
            substep <= '0;
        end
        else begin
            A_buffer <= A_buffer_n;
            B_buffer <= B_buffer_n;
            C_buffer <= C_buffer_n;
            substep <= substep_n;
        end
    end

    // Stall while a valid result has not been taken; new operands are
    // refused for as long as the stall lasts.
    wire stall = result_valid && ~result_ready;
    assign operands_ready = ~stall;

    // Tiles fed to the dot-product unit: buffered half paired with the half
    // currently on the inputs. The first element listed in a concatenation
    // lands in the most-significant position, so A_tile[3] holds
    // { A_buffer[0], A_half[0] } and B_tile[1] holds B_buffer.
    wire [3:0][1:0][31:0] A_tile = {
        { A_buffer[0], A_half[0] },
        { A_buffer[1], A_half[1] },
        { A_buffer[2], A_half[2] },
        { A_buffer[3], A_half[3] }
    };
    wire [1:0][3:0][31:0] B_tile = {
        B_buffer, B_half
    };
    logic [3:0][3:0][31:0] C_tile;

    always @(*) begin
        C_tile = {
            C_buffer[0], C_half[0], C_buffer[1], C_half[1],
            C_buffer[2], C_half[2], C_buffer[3], C_half[3],
            C_buffer[4], C_half[4], C_buffer[5], C_half[5],
            C_buffer[6], C_half[6], C_buffer[7], C_half[7]
        };
    end

    // Launch the dot-product unit on the handshake that completes the
    // operand pair (the one accepted while substep is 1).
    wire do_hmma = (substep == 1'b1 && operands_valid && operands_ready);

    VX_tensor_dpu #(

    ) dpu (
        .clk(clk),
        .reset(reset),

        .stall(stall),

        .valid_in(do_hmma),
        .A_tile(A_tile),
        .B_tile(B_tile),
        .C_tile(C_tile),

        .valid_out(result_valid),
        .D_tile(D_out)
    );
endmodule
|
||||
|
||||
@@ -74,8 +74,8 @@ module VX_uop_sequencer import VX_gpu_pkg::*; (
|
||||
NEXT,
|
||||
HMMA_SET0_STEP0_1,
|
||||
`EX_BITS'(`EX_TENSOR),
|
||||
`INST_OP_BITS'(0), // denotes that the first half is being computed
|
||||
`INST_MOD_BITS'(0), // field is unused for HMMA
|
||||
`INST_OP_BITS'(0), // denotes that the first step is being computed
|
||||
`INST_MOD_BITS'(0), // denotes that this is first substep (tensor core also tracks this)
|
||||
1'b1, // write back
|
||||
1'b0, // don't use PC
|
||||
1'b0, // don't use immediate
|
||||
@@ -92,8 +92,8 @@ module VX_uop_sequencer import VX_gpu_pkg::*; (
|
||||
FINISH,
|
||||
HMMA_SET0_STEP0_0,
|
||||
`EX_BITS'(`EX_TENSOR),
|
||||
`INST_OP_BITS'(1), // denotes that the second half is being computed
|
||||
`INST_MOD_BITS'(0), // field is unused for HMMA
|
||||
`INST_OP_BITS'(0), // denotes that the first step is being computed
|
||||
`INST_MOD_BITS'(1), // denotes that this is first substep (tensor core also tracks this)
|
||||
1'b1, // write back
|
||||
1'b0, // don't use PC
|
||||
1'b0, // don't use immediate
|
||||
@@ -161,6 +161,12 @@ module VX_uop_sequencer import VX_gpu_pkg::*; (
|
||||
assign ibuffer_if.data = use_uop ? ibuffer_output : uop_sequencer_if.data;
|
||||
|
||||
always @(posedge clk) begin
|
||||
|
||||
if (use_uop) begin
|
||||
$display("unexpectedly used uop at %d", $time);
|
||||
end
|
||||
|
||||
|
||||
if (reset) begin
|
||||
upc_r <= '0;
|
||||
use_uop_1d <= '0;
|
||||
|
||||
@@ -6,6 +6,8 @@ module VX_tensor_dpu #(
|
||||
input clk,
|
||||
input reset,
|
||||
|
||||
input stall,
|
||||
|
||||
input valid_in,
|
||||
input [3:0][1:0][31:0] A_tile,
|
||||
input [1:0][3:0][31:0] B_tile,
|
||||
@@ -28,7 +30,7 @@ module VX_tensor_dpu #(
|
||||
) shift_reg (
|
||||
.clk (clk),
|
||||
.reset (reset),
|
||||
.enable (1'b1),
|
||||
.enable (~stall),
|
||||
.data_in ({valid_in, result_hmma}),
|
||||
.data_out ({valid_out, D_tile})
|
||||
);
|
||||
|
||||
@@ -17,6 +17,8 @@ module VX_tensor_tb(
|
||||
.clk(clk),
|
||||
.reset(reset),
|
||||
|
||||
.stall(1'b0),
|
||||
|
||||
.valid_in(valid_in),
|
||||
.A_tile(A_tile),
|
||||
.B_tile(B_tile),
|
||||
|
||||
Reference in New Issue
Block a user