tensor: Enable scaling NUM_THREADS by octets
todo: lane-to-octet mapping is arbitrary atm
This commit is contained in:
@@ -10,8 +10,6 @@ module VX_tensor_core #(
|
||||
VX_dispatch_if.slave dispatch_if [`ISSUE_WIDTH],
|
||||
VX_commit_if.master commit_if [`ISSUE_WIDTH]
|
||||
);
|
||||
`STATIC_ASSERT(`NUM_THREADS == 32, ("tensor core requires # of threads in a warp to be 32 (try running w/ CONFIGS=\"-DNUM_THREADS=32\")"));
|
||||
|
||||
for (genvar i = 0; i < `ISSUE_WIDTH; ++i) begin
|
||||
VX_tensor_core_warp #(
|
||||
.ISW(i)
|
||||
@@ -35,29 +33,35 @@ module VX_tensor_core_warp import VX_gpu_pkg::*; #(
|
||||
VX_dispatch_if.slave dispatch_if,
|
||||
VX_commit_if.master commit_if
|
||||
);
|
||||
localparam NUM_OCTETS = (`NUM_THREADS / 8);
|
||||
// offet in the lane numbers that get mapped to the two threadgroups in an
|
||||
// octet. E.g. two tgs map lane 0-3 and lane 16-19 -> 16
|
||||
// FIXME: not sure this is the right logic. just filling in what works
|
||||
localparam LANE_OFFSET_THREADGROUP = (4 * NUM_OCTETS);
|
||||
|
||||
wire [1:0] step = 2'(dispatch_if.data.op_type);
|
||||
logic [3:0] octet_results_valid;
|
||||
logic [3:0] octet_results_ready;
|
||||
logic [3:0] octet_operands_ready;
|
||||
logic [NUM_OCTETS-1:0] octet_results_valid;
|
||||
logic [NUM_OCTETS-1:0] octet_results_ready;
|
||||
logic [NUM_OCTETS-1:0] octet_operands_ready;
|
||||
logic [`NUM_THREADS-1:0][`XLEN-1:0] wb_data_0;
|
||||
logic [`NUM_THREADS-1:0][`XLEN-1:0] wb_data_1;
|
||||
|
||||
assign dispatch_if.ready = &octet_operands_ready;
|
||||
|
||||
`ifdef EXT_T_ENABLE
|
||||
for (genvar i = 0; i < 4/*octets*/; ++i) begin
|
||||
for (genvar i = 0; i < NUM_OCTETS; ++i) begin
|
||||
`else
|
||||
for (genvar i = 0; i < 0; ++i) begin
|
||||
`endif
|
||||
// lane-to-octet mapping; see figure 13 of the paper
|
||||
wire [7:0][31:0] octet_A = {
|
||||
dispatch_if.data.rs1_data[16+4*i +: 4], dispatch_if.data.rs1_data[4*i +: 4]
|
||||
dispatch_if.data.rs1_data[LANE_OFFSET_THREADGROUP + 4*i +: 4], dispatch_if.data.rs1_data[4*i +: 4]
|
||||
};
|
||||
wire [7:0][31:0] octet_B = {
|
||||
dispatch_if.data.rs2_data[16+4*i +: 4], dispatch_if.data.rs2_data[4*i +: 4]
|
||||
dispatch_if.data.rs2_data[LANE_OFFSET_THREADGROUP + 4*i +: 4], dispatch_if.data.rs2_data[4*i +: 4]
|
||||
};
|
||||
wire [7:0][31:0] octet_C = {
|
||||
dispatch_if.data.rs3_data[16+4*i +: 4], dispatch_if.data.rs3_data[4*i +: 4]
|
||||
dispatch_if.data.rs3_data[LANE_OFFSET_THREADGROUP + 4*i +: 4], dispatch_if.data.rs3_data[4*i +: 4]
|
||||
};
|
||||
|
||||
logic [3:0][3:0][31:0] octet_D;
|
||||
@@ -100,15 +104,15 @@ module VX_tensor_core_warp import VX_gpu_pkg::*; #(
|
||||
assign wb_data_1[4*i+2] = octet_D[0][3];
|
||||
assign wb_data_1[4*i+3] = octet_D[1][3];
|
||||
|
||||
assign wb_data_0[4*i+16+0] = octet_D[2][0];
|
||||
assign wb_data_0[4*i+16+1] = octet_D[3][0];
|
||||
assign wb_data_0[4*i+16+2] = octet_D[2][2];
|
||||
assign wb_data_0[4*i+16+3] = octet_D[3][2];
|
||||
assign wb_data_0[4*i+LANE_OFFSET_THREADGROUP+0] = octet_D[2][0];
|
||||
assign wb_data_0[4*i+LANE_OFFSET_THREADGROUP+1] = octet_D[3][0];
|
||||
assign wb_data_0[4*i+LANE_OFFSET_THREADGROUP+2] = octet_D[2][2];
|
||||
assign wb_data_0[4*i+LANE_OFFSET_THREADGROUP+3] = octet_D[3][2];
|
||||
|
||||
assign wb_data_1[4*i+16+0] = octet_D[2][1];
|
||||
assign wb_data_1[4*i+16+1] = octet_D[3][1];
|
||||
assign wb_data_1[4*i+16+2] = octet_D[2][3];
|
||||
assign wb_data_1[4*i+16+3] = octet_D[3][3];
|
||||
assign wb_data_1[4*i+LANE_OFFSET_THREADGROUP+0] = octet_D[2][1];
|
||||
assign wb_data_1[4*i+LANE_OFFSET_THREADGROUP+1] = octet_D[3][1];
|
||||
assign wb_data_1[4*i+LANE_OFFSET_THREADGROUP+2] = octet_D[2][3];
|
||||
assign wb_data_1[4*i+LANE_OFFSET_THREADGROUP+3] = octet_D[3][3];
|
||||
end
|
||||
|
||||
/* commit_if.data_t parts that we need to keep around:
|
||||
|
||||
49
hw/rtl/core/VX_tensor_ucode_8lanes.vh
Normal file
49
hw/rtl/core/VX_tensor_ucode_8lanes.vh
Normal file
@@ -0,0 +1,49 @@
|
||||
// uop metadata (sequencing, next state), execution metadata (EX_TYPE, OP_TYPE, OP_MOD), wb, use pc, use imm, pc, imm, rd, rs1, rs2, rs3
|
||||
HMMA_SET0_STEP0_0: begin
|
||||
uop = {NEXT, HMMA_SET0_STEP0_1, `EX_BITS'(`EX_TENSOR), `INST_OP_BITS'(0), `INST_MOD_BITS'(0), 1'b1, 1'b0, 1'b0, 32'b0, 32'b0, `FREG(16), `FREG(0), `FREG(8), `FREG(16)};
|
||||
end
|
||||
HMMA_SET0_STEP0_1: begin
|
||||
uop = {NEXT, HMMA_SET0_STEP1_0, `EX_BITS'(`EX_TENSOR), `INST_OP_BITS'(0), `INST_MOD_BITS'(1), 1'b1, 1'b0, 1'b0, 32'b0, 32'b0, `FREG(17), `FREG(1), `FREG(9), `FREG(17)};
|
||||
end
|
||||
HMMA_SET0_STEP1_0: begin
|
||||
uop = {NEXT, HMMA_SET0_STEP1_1, `EX_BITS'(`EX_TENSOR), `INST_OP_BITS'(1), `INST_MOD_BITS'(0), 1'b1, 1'b0, 1'b0, 32'b0, 32'b0, `FREG(18), `FREG(0), `FREG(8), `FREG(18)};
|
||||
end
|
||||
HMMA_SET0_STEP1_1: begin
|
||||
uop = {NEXT, HMMA_SET0_STEP2_0, `EX_BITS'(`EX_TENSOR), `INST_OP_BITS'(1), `INST_MOD_BITS'(1), 1'b1, 1'b0, 1'b0, 32'b0, 32'b0, `FREG(19), `FREG(1), `FREG(9), `FREG(19)};
|
||||
end
|
||||
HMMA_SET0_STEP2_0: begin
|
||||
uop = {NEXT, HMMA_SET0_STEP2_1, `EX_BITS'(`EX_TENSOR), `INST_OP_BITS'(2), `INST_MOD_BITS'(0), 1'b1, 1'b0, 1'b0, 32'b0, 32'b0, `FREG(20), `FREG(0), `FREG(8), `FREG(20)};
|
||||
end
|
||||
HMMA_SET0_STEP2_1: begin
|
||||
uop = {NEXT, HMMA_SET0_STEP3_0, `EX_BITS'(`EX_TENSOR), `INST_OP_BITS'(2), `INST_MOD_BITS'(1), 1'b1, 1'b0, 1'b0, 32'b0, 32'b0, `FREG(21), `FREG(1), `FREG(9), `FREG(21)};
|
||||
end
|
||||
HMMA_SET0_STEP3_0: begin
|
||||
uop = {NEXT, HMMA_SET0_STEP3_1, `EX_BITS'(`EX_TENSOR), `INST_OP_BITS'(3), `INST_MOD_BITS'(0), 1'b1, 1'b0, 1'b0, 32'b0, 32'b0, `FREG(22), `FREG(0), `FREG(8), `FREG(22)};
|
||||
end
|
||||
HMMA_SET0_STEP3_1: begin
|
||||
uop = {NEXT, HMMA_SET1_STEP0_0, `EX_BITS'(`EX_TENSOR), `INST_OP_BITS'(3), `INST_MOD_BITS'(1), 1'b1, 1'b0, 1'b0, 32'b0, 32'b0, `FREG(23), `FREG(1), `FREG(9), `FREG(23)};
|
||||
end
|
||||
HMMA_SET1_STEP0_0: begin
|
||||
uop = {NEXT, HMMA_SET1_STEP0_1, `EX_BITS'(`EX_TENSOR), `INST_OP_BITS'(0), `INST_MOD_BITS'(0), 1'b1, 1'b0, 1'b0, 32'b0, 32'b0, `FREG(16), `FREG(2), `FREG(10), `FREG(16)};
|
||||
end
|
||||
HMMA_SET1_STEP0_1: begin
|
||||
uop = {NEXT, HMMA_SET1_STEP1_0, `EX_BITS'(`EX_TENSOR), `INST_OP_BITS'(0), `INST_MOD_BITS'(1), 1'b1, 1'b0, 1'b0, 32'b0, 32'b0, `FREG(17), `FREG(3), `FREG(11), `FREG(17)};
|
||||
end
|
||||
HMMA_SET1_STEP1_0: begin
|
||||
uop = {NEXT, HMMA_SET1_STEP1_1, `EX_BITS'(`EX_TENSOR), `INST_OP_BITS'(1), `INST_MOD_BITS'(0), 1'b1, 1'b0, 1'b0, 32'b0, 32'b0, `FREG(18), `FREG(2), `FREG(10), `FREG(18)};
|
||||
end
|
||||
HMMA_SET1_STEP1_1: begin
|
||||
uop = {NEXT, HMMA_SET1_STEP2_0, `EX_BITS'(`EX_TENSOR), `INST_OP_BITS'(1), `INST_MOD_BITS'(1), 1'b1, 1'b0, 1'b0, 32'b0, 32'b0, `FREG(19), `FREG(3), `FREG(11), `FREG(19)};
|
||||
end
|
||||
HMMA_SET1_STEP2_0: begin
|
||||
uop = {NEXT, HMMA_SET1_STEP2_1, `EX_BITS'(`EX_TENSOR), `INST_OP_BITS'(2), `INST_MOD_BITS'(0), 1'b1, 1'b0, 1'b0, 32'b0, 32'b0, `FREG(20), `FREG(2), `FREG(10), `FREG(20)};
|
||||
end
|
||||
HMMA_SET1_STEP2_1: begin
|
||||
uop = {NEXT, HMMA_SET1_STEP3_0, `EX_BITS'(`EX_TENSOR), `INST_OP_BITS'(2), `INST_MOD_BITS'(1), 1'b1, 1'b0, 1'b0, 32'b0, 32'b0, `FREG(21), `FREG(3), `FREG(11), `FREG(21)};
|
||||
end
|
||||
HMMA_SET1_STEP3_0: begin
|
||||
uop = {NEXT, HMMA_SET1_STEP3_1, `EX_BITS'(`EX_TENSOR), `INST_OP_BITS'(3), `INST_MOD_BITS'(0), 1'b1, 1'b0, 1'b0, 32'b0, 32'b0, `FREG(22), `FREG(2), `FREG(10), `FREG(22)};
|
||||
end
|
||||
HMMA_SET1_STEP3_1: begin
|
||||
uop = {FINISH, HMMA_SET0_STEP0_0, `EX_BITS'(`EX_TENSOR), `INST_OP_BITS'(3), `INST_MOD_BITS'(1), 1'b1, 1'b0, 1'b0, 32'b1, 32'b1, `FREG(23), `FREG(3), `FREG(11), `FREG(23)};
|
||||
end
|
||||
Reference in New Issue
Block a user