Do two-cycle compute with 1 FEDP per lane

This commit is contained in:
Hansung Kim
2024-05-29 22:01:03 -07:00
parent 35273b3d74
commit 73a2f5781e

View File

@@ -30,47 +30,43 @@ module VX_tensor_dpu #(
always @(posedge clk) begin
if (reset) begin
ready_reg <= '1;
end else if (valid_in) begin
end else if (valid_in && ready_in) begin
ready_reg <= '0;
dpi_print_results(int'(ISW), int'(OCTET), A_tile, B_tile, C_tile, result_hmma);
end else if (valid_out) begin
end else if (valid_out && ready_out) begin
ready_reg <= '1;
end
end
// ready as soon as valid_out
// assign ready_in = ready_reg || valid_out;
// assign ready_in = ready_reg;
// fully pipelined; ready_in is coupled to ready_out by immediately
// stalling
assign ready_in = ready_out;
// assign ready_in = ready_out;
// wire dpu_valid;
// wire [31:0] dpu_data;
// TensorDotProductUnit dpu_pipe (
// .clock (clk),
// .reset (reset),
// .io_in_valid (valid_in && ready_in),
// .io_in_bits_a_0 (32'h40000000),
// .io_in_bits_a_1 (32'h40000000),
// .io_in_bits_a_2 (32'h40000000),
// .io_in_bits_a_3 (32'h40000000),
// .io_in_bits_b_0 (32'h40000000),
// .io_in_bits_b_1 (32'h40000000),
// .io_in_bits_b_2 (32'h40000000),
// .io_in_bits_b_3 (32'h40000000),
// .io_in_bits_c (32'h3f800000),
// .io_out_valid (dpu_valid),
// .io_out_bits_data (dpu_data)
// // fixed-latency queue
// VX_shift_register #(
// .DATAW (1 + $bits(wid)/* + $bits(D_tile)*/),
// .DEPTH (`LATENCY_HMMA + 1),
// .RESETW (1)
// ) shift_reg (
// .clk (clk),
// .reset (reset),
// .enable (ready_out),
// .data_in ({valid_in && ready_in, wid /*, result_hmma*/}),
// .data_out ({valid_out, D_wid/*, D_tile */})
// );
logic [1:0] threadgroup_valids;
logic [1:0] threadgroup_readys;
// B_tile is shared across the two threadgroups; see Figure 13
VX_tensor_threadgroup #(
) threadgroup_0 (
.clk (clk),
.reset (reset),
.valid_in (valid_in),
.ready_in (threadgroup_readys[0]),
.stall (!ready_out),
.A_frag (A_tile[1:0]),
.B_frag (B_tile),
@@ -83,6 +79,7 @@ module VX_tensor_dpu #(
.clk (clk),
.reset (reset),
.valid_in (valid_in),
.ready_in (threadgroup_readys[1]),
.stall (!ready_out),
.A_frag (A_tile[3:2]),
.B_frag (B_tile),
@@ -91,21 +88,36 @@ module VX_tensor_dpu #(
.D_frag (D_tile[3:2])
);
// fixed-latency queue
VX_shift_register #(
.DATAW (1 + $bits(wid)/* + $bits(D_tile)*/),
.DEPTH (`LATENCY_HMMA),
.RESETW (1)
) shift_reg (
.clk (clk),
.reset (reset),
.enable (ready_out),
.data_in ({valid_in && ready_in, wid /*, result_hmma*/}),
.data_out ({valid_out, D_wid/*, D_tile */})
wire empty;
wire full;
wire enq = valid_in && ready_in;
wire deq = valid_out && ready_out;
assign ready_in = &(threadgroup_readys);
assign valid_out = &(threadgroup_valids);
// Warp-ID side queue, needed so that multiple warps can be in flight
// (multithreading): wid is pushed when an input is accepted
// (enq = valid_in && ready_in) and D_wid is popped when an output is taken
// (deq = valid_out && ready_out).
// DEPTH is set to 2 * `LATENCY_HMMA.
// NOTE(review): this assumes VX_fifo_queue is first-in-first-out and that
// results leave in the order their inputs were accepted - confirm.
// NOTE(review): `full` is only checked by the RUNTIME_ASSERT that follows
// this instance; `empty` is not referenced in the visible part of this
// module.
VX_fifo_queue #(
.DATAW ($bits(wid)),
.DEPTH (`LATENCY_HMMA + `LATENCY_HMMA)
) wid_queue (
.clk (clk),
.reset (reset),
.push (enq),
.pop (deq),
.data_in (wid),
.data_out (D_wid),
.empty (empty),
`UNUSED_PIN(alm_empty),
.full (full), // should be impossible to overflow
`UNUSED_PIN(alm_full),
`UNUSED_PIN(size)
);
`RUNTIME_ASSERT(reset || (&(threadgroup_valids) == valid_out),
("FEDP and metadata queue went out of sync!"))
`RUNTIME_ASSERT(reset || !full, ("dpu wid queue is full!"))
// `RUNTIME_ASSERT(reset || (&(threadgroup_valids) == valid_out),
// ("FEDP and metadata queue went out of sync!"))
endmodule
// Computes a (m,n,k) = (2,4,2) matrix multiply over 2 cycles.
@@ -116,6 +128,7 @@ module VX_tensor_threadgroup #(
input reset,
input valid_in,
output ready_in,
input stall,
input [1:0][1:0][31:0] A_frag,
input [1:0][3:0][31:0] B_frag,
@@ -123,35 +136,123 @@ module VX_tensor_threadgroup #(
output valid_out,
output [1:0][3:0][31:0] D_frag
);
wire [1:0][1:0][31:0] A_frag_buf;
wire [1:0][3:0][31:0] B_frag_buf;
wire [1:0][3:0][31:0] C_frag_buf;
wire valid_buf;
wire ready_buf;
wire enq = valid_in && ready_in;
wire deq = valid_buf && ready_buf;
wire empty;
wire full;
assign ready_in = !full;
assign valid_buf = !empty;
// Operand staging queue in front of the FEDPs.
// The concatenated {A_frag, B_frag, C_frag} operands are pushed on enq
// (valid_in && ready_in) and the queue output drives A_frag_buf, B_frag_buf
// and C_frag_buf. The entry is popped on deq (valid_buf && ready_buf).
// `full` gates ready_in and `empty` gates valid_buf (see the assigns above).
// NOTE(review): assumes VX_fifo_queue keeps the oldest entry on data_out
// until it is popped - confirm.
VX_fifo_queue #(
.DATAW ($bits(A_frag) + $bits(B_frag) + $bits(C_frag)),
.DEPTH (4)
) input_buffer (
.clk (clk),
.reset (reset),
.push (enq),
.pop (deq),
.data_in ({A_frag, B_frag, C_frag}),
.data_out ({A_frag_buf, B_frag_buf, C_frag_buf}),
.empty (empty),
`UNUSED_PIN(alm_empty),
.full (full),
`UNUSED_PIN(alm_full),
`UNUSED_PIN(size)
);
logic [3:0] fedp_valids;
wire fedp_valid_out = &(fedp_valids);
wire fedp_ready_out = !stall;
wire fedp_fire_out = fedp_valid_out && fedp_ready_out;
wire fedp_valid_in = valid_buf;
wire fedp_ready_in = fedp_ready_out; // coupled
wire fedp_fire_in = fedp_valid_in && fedp_ready_in;
// 0: FEDP uses first half from input_buffer
// 1: FEDP uses last half and pops input_buffer
logic step_in;
// 0: FEDP produces first half of D_frag
// 1: FEDP produces last half of D_frag and asserts valid_out
logic step_out;
assign ready_buf = fedp_fire_in && (step_in == 1'b1);
// FIXME: shrink D_reg; the visible code only writes and reads elements
// [0][0], [0][2], [1][0] and [1][2]
logic [1:0][3:0][31:0] D_reg, D_reg_n;
wire [3:0][31:0] D_half;
// Next-state logic for D_reg.
// D_reg holds its value by default. When the FEDP outputs fire while
// step_out == 0, the four D_half words are captured into columns 0 and 2 of
// both rows. Nothing is captured when step_out == 1.
always @(*) begin
D_reg_n = D_reg;
if (fedp_fire_out) begin
if (step_out == 1'b0) begin
// D_half[i] is stored at row i / 2, column (i % 2) * 2
D_reg_n[0][0] = D_half[0];
D_reg_n[0][2] = D_half[1];
D_reg_n[1][0] = D_half[2];
D_reg_n[1][2] = D_half[3];
end
end
end
// Sequential state for the two-step schedule.
// Reset clears step_in, step_out and D_reg. Otherwise step_in toggles on
// every fedp_fire_in, step_out toggles on every fedp_fire_out, and D_reg
// loads D_reg_n every cycle.
always @(posedge clk) begin
if (reset) begin
step_in <= '0;
step_out <= '0;
D_reg <= '0;
end else begin
// input side: alternate between the first and last half of the operands
if (fedp_fire_in) begin
step_in <= ~step_in;
end
// output side: alternate between the first and last half of D_frag
if (fedp_fire_out) begin
step_out <= ~step_out;
end
D_reg <= D_reg_n;
end
end
assign D_frag[0][0] = D_reg[0][0];
assign D_frag[0][2] = D_reg[0][2];
assign D_frag[1][0] = D_reg[1][0];
assign D_frag[1][2] = D_reg[1][2];
assign D_frag[0][1] = D_half[0];
assign D_frag[0][3] = D_half[1];
assign D_frag[1][1] = D_half[2];
assign D_frag[1][3] = D_half[3];
// 4 FEDPs per threadgroup
// FIXME: experimenting with 8 FEDPs first
logic [1:0][3:0] valids;
for (genvar D_row = 0; D_row < 2; ++D_row) begin
for (genvar D_col = 0; D_col < 4; ++D_col) begin
for (genvar i = 0; i < 4; ++i) begin
localparam int d_row = i / 2;
localparam int d_col = (i % 2) * 2;
// four-element dot product (FEDP) unit
TensorDotProductUnit fedp (
.clock (clk),
.reset (reset),
.io_in_valid (valid_in),
.io_in_bits_a_0 (A_frag[D_row][0]),
.io_in_bits_a_1 (A_frag[D_row][1]),
.io_in_valid (fedp_fire_in),
.io_in_bits_a_0 (A_frag_buf[d_row][0]),
.io_in_bits_a_1 (A_frag_buf[d_row][1]),
.io_in_bits_a_2 (32'h0),
.io_in_bits_a_3 (32'h0),
.io_in_bits_b_0 (B_frag[0][D_col]),
.io_in_bits_b_1 (B_frag[1][D_col]),
.io_in_bits_b_0 (step_in == 1'b0 ? B_frag_buf[0][d_col] : B_frag_buf[0][d_col + 1]),
.io_in_bits_b_1 (step_in == 1'b0 ? B_frag_buf[1][d_col] : B_frag_buf[1][d_col + 1]),
.io_in_bits_b_2 (32'h0),
.io_in_bits_b_3 (32'h0),
.io_in_bits_c (C_frag[D_row][D_col]),
.io_in_bits_c (step_in == 1'b0 ? C_frag_buf[d_row][d_col] : C_frag_buf[d_row][d_col + 1]),
.io_stall (stall),
.io_out_valid (valids[D_row][D_col]),
.io_out_bits_data (D_frag[D_row][D_col])
.io_out_valid (fedp_valids[i]),
.io_out_bits_data (D_half[i])
);
end
end
assign valid_out = (&(valids[0])) && (&(valids[1]));
assign valid_out = fedp_valid_out && (step_out == 1'b1);
endmodule
`endif