Stage half-operands per warp

An easy solution for handling multiple concurrent warp operations:
stage each warp's half-operands in its own per-warp register.  This might
increase the area requirement by quite a bit.

TODO: Commit is not being handled correctly yet
This commit is contained in:
Hansung Kim
2024-05-25 19:08:17 -07:00
parent 45d86b26a2
commit 8775458a8f

View File

@@ -83,6 +83,8 @@ module VX_tensor_core_warp import VX_gpu_pkg::*; #(
localparam LANE_OFFSET_THREADGROUP = (4 * NUM_OCTETS);
wire [1:0] step = 2'(execute_if.data.op_type);
wire operands_last_in_pair = (execute_if.data.op_mod == `INST_MOD_BITS'(1));
logic [NUM_OCTETS-1:0] octet_results_valid;
logic [NUM_OCTETS-1:0] octet_results_ready;
logic [NUM_OCTETS-1:0] octet_operands_ready;
@@ -111,6 +113,8 @@ module VX_tensor_core_warp import VX_gpu_pkg::*; #(
logic [3:0][3:0][31:0] octet_D;
logic result_valid;
logic result_ready;
// op_mod is reused to indicate instruction's id in pair
VX_tensor_octet #(
.ISW(ISW),
.OCTET(i)
@@ -122,6 +126,8 @@ module VX_tensor_core_warp import VX_gpu_pkg::*; #(
.B_in(octet_B),
.C_in(octet_C),
.operands_valid(execute_if.valid),
.operands_wid(execute_if.data.wid),
.operands_last_in_pair(operands_last_in_pair),
.operands_ready(octet_operands_ready[i]),
.step(step),
@@ -245,11 +251,14 @@ module VX_tensor_octet #(
input clk,
input reset,
input [7:0][31:0] A_in,
input [7:0][31:0] B_in,
input [7:0][31:0] C_in,
input operands_valid, // we have to backpressure due to there potentially being contention over commit
output operands_ready,
input [7:0][31:0] A_in,
input [7:0][31:0] B_in,
input [7:0][31:0] C_in,
input operands_valid,
input [`NW_WIDTH-1:0] operands_wid,
input operands_last_in_pair,
// we have to backpressure due to there potentially being contention over commit
output operands_ready,
input [1:0] step,
@@ -258,9 +267,9 @@ module VX_tensor_octet #(
input result_ready
);
// 512 bits/octet * 4 octets per warp
logic [3:0][31:0] A_buffer, A_buffer_n;
logic [3:0][31:0] B_buffer, B_buffer_n;
logic [7:0][31:0] C_buffer, C_buffer_n;
logic [`NUM_WARPS-1:0][3:0][31:0] A_buffer, A_buffer_n;
logic [`NUM_WARPS-1:0][3:0][31:0] B_buffer, B_buffer_n;
logic [`NUM_WARPS-1:0][7:0][31:0] C_buffer, C_buffer_n;
// half the inputs are buffered, half are not (instead coming straight
// from operand bus) unlike the real tensor core.
@@ -268,6 +277,10 @@ module VX_tensor_octet #(
logic [3:0][31:0] A_half;
logic [3:0][31:0] B_half;
logic [7:0][31:0] C_half;
logic [`NUM_WARPS-1:0] substeps;
logic [`NUM_WARPS-1:0] substeps_n;
always @(*) begin
// note that not all lanes participate at every step
case (step)
@@ -296,18 +309,29 @@ module VX_tensor_octet #(
end
logic substep;
wire substep_n = (operands_ready && operands_valid) ? ~substep : substep;
wire operands_fire = operands_ready && operands_valid;
wire substep_n = operands_fire && operands_last_in_pair;
always @(*) begin
A_buffer_n = A_buffer;
B_buffer_n = B_buffer;
C_buffer_n = C_buffer;
substeps_n = substeps;
if (substep == 1'b0) begin
A_buffer_n = A_half;
B_buffer_n = B_half;
C_buffer_n = C_half;
if (operands_fire) begin
substeps_n[operands_wid] = ~substeps[operands_wid];
if (!operands_last_in_pair) begin
A_buffer_n[operands_wid] = A_half;
B_buffer_n[operands_wid] = B_half;
C_buffer_n[operands_wid] = C_half;
end
end
// if (operands_fire && (substep == 1'b0)) begin
// A_buffer_n[operands_wid] = A_half;
// B_buffer_n[operands_wid] = B_half;
// C_buffer_n[operands_wid] = C_half;
// end
end
always @(posedge clk) begin
@@ -315,13 +339,17 @@ module VX_tensor_octet #(
A_buffer <= '0;
B_buffer <= '0;
C_buffer <= '0;
substep <= '0;
substeps <= '0;
end
else begin
A_buffer <= A_buffer_n;
B_buffer <= B_buffer_n;
C_buffer <= C_buffer_n;
substep <= substep_n;
substeps <= substeps_n;
end
end
@@ -330,39 +358,38 @@ module VX_tensor_octet #(
// wire stall = result_valid && ~result_ready;
// backpressure from commit
wire stall = ~outbuf_ready_in;
assign operands_ready = ~stall;
// assign operands_ready = ~stall;
// TODO: Below line is to only allow 1 warp to occupy the octet at a time;
// currently, dpu is fully-pipelined and allows concurrency between
// multiple warps. This seems to be not a problem though given that the
// RF operand read takes >=2 cycles, which should be the end-to-end
// latency of the DPU anyways
// assign operands_ready = hmma_ready && ~stall;
assign operands_ready = hmma_ready && ~stall;
// A is 4x2 fp32 matrix
wire [3:0][1:0][31:0] A_tile = {
{ A_half[3], A_buffer[3] },
{ A_half[2], A_buffer[2] },
{ A_half[1], A_buffer[1] },
{ A_half[0], A_buffer[0] }
{ A_half[3], A_buffer[operands_wid][3] },
{ A_half[2], A_buffer[operands_wid][2] },
{ A_half[1], A_buffer[operands_wid][1] },
{ A_half[0], A_buffer[operands_wid][0] }
};
// B is 2x4 fp32 matrix
wire [1:0][3:0][31:0] B_tile = {
B_half, B_buffer
B_half, B_buffer[operands_wid]
};
// C is 4x4 fp32 matrix
logic [3:0][3:0][31:0] C_tile;
logic [3:0][3:0][31:0] D_tile;
always @(*) begin
C_tile = {
C_half[7], C_buffer[7], C_half[5], C_buffer[5],
C_half[6], C_buffer[6], C_half[4], C_buffer[4],
C_half[3], C_buffer[3], C_half[1], C_buffer[1],
C_half[2], C_buffer[2], C_half[0], C_buffer[0]
};
C_tile[3] = { C_half[7], C_buffer[operands_wid][7], C_half[5], C_buffer[operands_wid][5] };
C_tile[2] = { C_half[6], C_buffer[operands_wid][6], C_half[4], C_buffer[operands_wid][4] };
C_tile[1] = { C_half[3], C_buffer[operands_wid][3], C_half[1], C_buffer[operands_wid][1] };
C_tile[0] = { C_half[2], C_buffer[operands_wid][2], C_half[0], C_buffer[operands_wid][0] };
end
wire do_hmma = (substep == 1'b1 && operands_valid && operands_ready);
// wire do_hmma = operands_fire && (substeps[operands_wid] == 1'b1);
wire do_hmma = operands_fire && operands_last_in_pair;
wire dpu_valid;
// this does (m,n,k)=(4,4,2) matmul, modeling compute of a single octet