tensor: Enable scaling NUM_THREADS by octets

todo: lane-to-octet mapping is arbitrary atm
This commit is contained in:
Hansung Kim
2024-05-08 11:26:09 -07:00
parent d624b3e50a
commit 9f9ec10960
2 changed files with 70 additions and 17 deletions

View File

@@ -10,8 +10,6 @@ module VX_tensor_core #(
VX_dispatch_if.slave dispatch_if [`ISSUE_WIDTH],
VX_commit_if.master commit_if [`ISSUE_WIDTH]
);
`STATIC_ASSERT(`NUM_THREADS == 32, ("tensor core requires # of threads in a warp to be 32 (try running w/ CONFIGS=\"-DNUM_THREADS=32\")"));
for (genvar i = 0; i < `ISSUE_WIDTH; ++i) begin
VX_tensor_core_warp #(
.ISW(i)
@@ -35,29 +33,35 @@ module VX_tensor_core_warp import VX_gpu_pkg::*; #(
VX_dispatch_if.slave dispatch_if,
VX_commit_if.master commit_if
);
localparam NUM_OCTETS = (`NUM_THREADS / 8);
// offet in the lane numbers that get mapped to the two threadgroups in an
// octet. E.g. two tgs map lane 0-3 and lane 16-19 -> 16
// FIXME: not sure this is the right logic. just filling in what works
localparam LANE_OFFSET_THREADGROUP = (4 * NUM_OCTETS);
wire [1:0] step = 2'(dispatch_if.data.op_type);
logic [3:0] octet_results_valid;
logic [3:0] octet_results_ready;
logic [3:0] octet_operands_ready;
logic [NUM_OCTETS-1:0] octet_results_valid;
logic [NUM_OCTETS-1:0] octet_results_ready;
logic [NUM_OCTETS-1:0] octet_operands_ready;
logic [`NUM_THREADS-1:0][`XLEN-1:0] wb_data_0;
logic [`NUM_THREADS-1:0][`XLEN-1:0] wb_data_1;
assign dispatch_if.ready = &octet_operands_ready;
`ifdef EXT_T_ENABLE
for (genvar i = 0; i < 4/*octets*/; ++i) begin
for (genvar i = 0; i < NUM_OCTETS; ++i) begin
`else
for (genvar i = 0; i < 0; ++i) begin
`endif
// lane-to-octet mapping; see figure 13 of the paper
wire [7:0][31:0] octet_A = {
dispatch_if.data.rs1_data[16+4*i +: 4], dispatch_if.data.rs1_data[4*i +: 4]
dispatch_if.data.rs1_data[LANE_OFFSET_THREADGROUP + 4*i +: 4], dispatch_if.data.rs1_data[4*i +: 4]
};
wire [7:0][31:0] octet_B = {
dispatch_if.data.rs2_data[16+4*i +: 4], dispatch_if.data.rs2_data[4*i +: 4]
dispatch_if.data.rs2_data[LANE_OFFSET_THREADGROUP + 4*i +: 4], dispatch_if.data.rs2_data[4*i +: 4]
};
wire [7:0][31:0] octet_C = {
dispatch_if.data.rs3_data[16+4*i +: 4], dispatch_if.data.rs3_data[4*i +: 4]
dispatch_if.data.rs3_data[LANE_OFFSET_THREADGROUP + 4*i +: 4], dispatch_if.data.rs3_data[4*i +: 4]
};
logic [3:0][3:0][31:0] octet_D;
@@ -100,15 +104,15 @@ module VX_tensor_core_warp import VX_gpu_pkg::*; #(
assign wb_data_1[4*i+2] = octet_D[0][3];
assign wb_data_1[4*i+3] = octet_D[1][3];
assign wb_data_0[4*i+16+0] = octet_D[2][0];
assign wb_data_0[4*i+16+1] = octet_D[3][0];
assign wb_data_0[4*i+16+2] = octet_D[2][2];
assign wb_data_0[4*i+16+3] = octet_D[3][2];
assign wb_data_0[4*i+LANE_OFFSET_THREADGROUP+0] = octet_D[2][0];
assign wb_data_0[4*i+LANE_OFFSET_THREADGROUP+1] = octet_D[3][0];
assign wb_data_0[4*i+LANE_OFFSET_THREADGROUP+2] = octet_D[2][2];
assign wb_data_0[4*i+LANE_OFFSET_THREADGROUP+3] = octet_D[3][2];
assign wb_data_1[4*i+16+0] = octet_D[2][1];
assign wb_data_1[4*i+16+1] = octet_D[3][1];
assign wb_data_1[4*i+16+2] = octet_D[2][3];
assign wb_data_1[4*i+16+3] = octet_D[3][3];
assign wb_data_1[4*i+LANE_OFFSET_THREADGROUP+0] = octet_D[2][1];
assign wb_data_1[4*i+LANE_OFFSET_THREADGROUP+1] = octet_D[3][1];
assign wb_data_1[4*i+LANE_OFFSET_THREADGROUP+2] = octet_D[2][3];
assign wb_data_1[4*i+LANE_OFFSET_THREADGROUP+3] = octet_D[3][3];
end
/* commit_if.data_t parts that we need to keep around:

View File

@@ -0,0 +1,49 @@
// uop metadata (sequencing, next state), execution metadata (EX_TYPE, OP_TYPE, OP_MOD), wb, use pc, use imm, pc, imm, rd, rs1, rs2, rs3
HMMA_SET0_STEP0_0: begin
uop = {NEXT, HMMA_SET0_STEP0_1, `EX_BITS'(`EX_TENSOR), `INST_OP_BITS'(0), `INST_MOD_BITS'(0), 1'b1, 1'b0, 1'b0, 32'b0, 32'b0, `FREG(16), `FREG(0), `FREG(8), `FREG(16)};
end
HMMA_SET0_STEP0_1: begin
uop = {NEXT, HMMA_SET0_STEP1_0, `EX_BITS'(`EX_TENSOR), `INST_OP_BITS'(0), `INST_MOD_BITS'(1), 1'b1, 1'b0, 1'b0, 32'b0, 32'b0, `FREG(17), `FREG(1), `FREG(9), `FREG(17)};
end
HMMA_SET0_STEP1_0: begin
uop = {NEXT, HMMA_SET0_STEP1_1, `EX_BITS'(`EX_TENSOR), `INST_OP_BITS'(1), `INST_MOD_BITS'(0), 1'b1, 1'b0, 1'b0, 32'b0, 32'b0, `FREG(18), `FREG(0), `FREG(8), `FREG(18)};
end
HMMA_SET0_STEP1_1: begin
uop = {NEXT, HMMA_SET0_STEP2_0, `EX_BITS'(`EX_TENSOR), `INST_OP_BITS'(1), `INST_MOD_BITS'(1), 1'b1, 1'b0, 1'b0, 32'b0, 32'b0, `FREG(19), `FREG(1), `FREG(9), `FREG(19)};
end
HMMA_SET0_STEP2_0: begin
uop = {NEXT, HMMA_SET0_STEP2_1, `EX_BITS'(`EX_TENSOR), `INST_OP_BITS'(2), `INST_MOD_BITS'(0), 1'b1, 1'b0, 1'b0, 32'b0, 32'b0, `FREG(20), `FREG(0), `FREG(8), `FREG(20)};
end
HMMA_SET0_STEP2_1: begin
uop = {NEXT, HMMA_SET0_STEP3_0, `EX_BITS'(`EX_TENSOR), `INST_OP_BITS'(2), `INST_MOD_BITS'(1), 1'b1, 1'b0, 1'b0, 32'b0, 32'b0, `FREG(21), `FREG(1), `FREG(9), `FREG(21)};
end
HMMA_SET0_STEP3_0: begin
uop = {NEXT, HMMA_SET0_STEP3_1, `EX_BITS'(`EX_TENSOR), `INST_OP_BITS'(3), `INST_MOD_BITS'(0), 1'b1, 1'b0, 1'b0, 32'b0, 32'b0, `FREG(22), `FREG(0), `FREG(8), `FREG(22)};
end
HMMA_SET0_STEP3_1: begin
uop = {NEXT, HMMA_SET1_STEP0_0, `EX_BITS'(`EX_TENSOR), `INST_OP_BITS'(3), `INST_MOD_BITS'(1), 1'b1, 1'b0, 1'b0, 32'b0, 32'b0, `FREG(23), `FREG(1), `FREG(9), `FREG(23)};
end
HMMA_SET1_STEP0_0: begin
uop = {NEXT, HMMA_SET1_STEP0_1, `EX_BITS'(`EX_TENSOR), `INST_OP_BITS'(0), `INST_MOD_BITS'(0), 1'b1, 1'b0, 1'b0, 32'b0, 32'b0, `FREG(16), `FREG(2), `FREG(10), `FREG(16)};
end
HMMA_SET1_STEP0_1: begin
uop = {NEXT, HMMA_SET1_STEP1_0, `EX_BITS'(`EX_TENSOR), `INST_OP_BITS'(0), `INST_MOD_BITS'(1), 1'b1, 1'b0, 1'b0, 32'b0, 32'b0, `FREG(17), `FREG(3), `FREG(11), `FREG(17)};
end
HMMA_SET1_STEP1_0: begin
uop = {NEXT, HMMA_SET1_STEP1_1, `EX_BITS'(`EX_TENSOR), `INST_OP_BITS'(1), `INST_MOD_BITS'(0), 1'b1, 1'b0, 1'b0, 32'b0, 32'b0, `FREG(18), `FREG(2), `FREG(10), `FREG(18)};
end
HMMA_SET1_STEP1_1: begin
uop = {NEXT, HMMA_SET1_STEP2_0, `EX_BITS'(`EX_TENSOR), `INST_OP_BITS'(1), `INST_MOD_BITS'(1), 1'b1, 1'b0, 1'b0, 32'b0, 32'b0, `FREG(19), `FREG(3), `FREG(11), `FREG(19)};
end
HMMA_SET1_STEP2_0: begin
uop = {NEXT, HMMA_SET1_STEP2_1, `EX_BITS'(`EX_TENSOR), `INST_OP_BITS'(2), `INST_MOD_BITS'(0), 1'b1, 1'b0, 1'b0, 32'b0, 32'b0, `FREG(20), `FREG(2), `FREG(10), `FREG(20)};
end
HMMA_SET1_STEP2_1: begin
uop = {NEXT, HMMA_SET1_STEP3_0, `EX_BITS'(`EX_TENSOR), `INST_OP_BITS'(2), `INST_MOD_BITS'(1), 1'b1, 1'b0, 1'b0, 32'b0, 32'b0, `FREG(21), `FREG(3), `FREG(11), `FREG(21)};
end
HMMA_SET1_STEP3_0: begin
uop = {NEXT, HMMA_SET1_STEP3_1, `EX_BITS'(`EX_TENSOR), `INST_OP_BITS'(3), `INST_MOD_BITS'(0), 1'b1, 1'b0, 1'b0, 32'b0, 32'b0, `FREG(22), `FREG(2), `FREG(10), `FREG(22)};
end
HMMA_SET1_STEP3_1: begin
uop = {FINISH, HMMA_SET0_STEP0_0, `EX_BITS'(`EX_TENSOR), `INST_OP_BITS'(3), `INST_MOD_BITS'(1), 1'b1, 1'b0, 1'b0, 32'b1, 32'b1, `FREG(23), `FREG(3), `FREG(11), `FREG(23)};
end