tensor: doc

This commit is contained in:
Hansung Kim
2024-05-05 18:35:52 -07:00
parent 9ea291eea2
commit fb626ee21c

View File

@@ -43,7 +43,8 @@ module VX_tensor_core_warp import VX_gpu_pkg::*; #(
assign dispatch_if.ready = &octet_operands_ready;
for (genvar i = 0; i < 4; ++i) begin
for (genvar i = 0; i < 4/*octets*/; ++i) begin
// lane-to-octet mapping; see figure 13 of the paper
wire [7:0][31:0] octet_A = {
dispatch_if.data.rs1_data[16+4*i +: 4], dispatch_if.data.rs1_data[4*i +: 4]
};
@@ -81,6 +82,9 @@ module VX_tensor_core_warp import VX_gpu_pkg::*; #(
assign octet_results_valid[i] = result_valid;
assign result_ready = octet_results_ready[i];
// each octet produces 4x4 output partial sum, but the 8 lanes mapped
// to the octet can only do 8 fp32 writeback at a time; so we need to
// split writeback over two cycles
assign wb_data_0[4*i+0] = octet_D[0][0];
assign wb_data_0[4*i+1] = octet_D[1][0];
assign wb_data_0[4*i+2] = octet_D[0][2];
@@ -150,11 +154,11 @@ module VX_tensor_core_warp import VX_gpu_pkg::*; #(
localparam COMMIT_DATAW = `UUID_WIDTH + `NW_WIDTH + `NUM_THREADS + `XLEN + 1 + `NR_BITS + (`NUM_THREADS * `XLEN) + 1 + 1 + 1;
wire [COMMIT_DATAW-1:0] commit_if_data = {
dispatch_if_data_deq,
subcommit == 1'b0 ? wb_data_0 : wb_data_1,
1'b0,
1'b1,
1'b1
dispatch_if_data_deq, /* uuid ~ rd */
subcommit == 1'b0 ? wb_data_0 : wb_data_1, /* data */
1'b0, /* pid */
1'b1, /* sop */
1'b1 /* eop */
};
assign commit_if.data = commit_if_data;
@@ -204,8 +208,11 @@ module VX_tensor_octet #(
logic [3:0][31:0] B_buffer, B_buffer_n;
logic [7:0][31:0] C_buffer, C_buffer_n;
// half the inputs are buffered, half are not (instead coming straight from operand bus)
// unlike the real tensor core, the banks are only 32 bit rather than 64 bit
// half the inputs are buffered, half are not (instead coming straight
// from operand bus) unlike the real tensor core.
// the banks are only 32 bit rather than 64 bit (a pair of fp32 regs).
// since A and B are supplied by 4 lanes each, we get 4 fp32's at a time
// (8 for C).
logic [3:0][31:0] A_half;
logic [3:0][31:0] B_half;
logic [7:0][31:0] C_half;
@@ -265,15 +272,18 @@ module VX_tensor_octet #(
wire stall = result_valid && ~result_ready;
assign operands_ready = ~stall;
// A is 4x2 fp32 matrix
wire [3:0][1:0][31:0] A_tile = {
{ A_half[3], A_buffer[3] },
{ A_half[2], A_buffer[2] },
{ A_half[1], A_buffer[1] },
{ A_half[0], A_buffer[0] }
};
// B is 2x4 fp32 matrix
wire [1:0][3:0][31:0] B_tile = {
B_half, B_buffer
};
// C is 4x4 fp32 matrix
logic [3:0][3:0][31:0] C_tile;
always @(*) begin
@@ -286,6 +296,8 @@ module VX_tensor_octet #(
end
wire do_hmma = (substep == 1'b1 && operands_valid && operands_ready);
// this does (m,n,k)=(4,4,2) matmul, modeling compute of a single octet
VX_tensor_dpu #(
.ISW(ISW),
.OCTET(OCTET)