tensor: doc
This commit is contained in:
@@ -43,7 +43,8 @@ module VX_tensor_core_warp import VX_gpu_pkg::*; #(
|
||||
|
||||
assign dispatch_if.ready = &octet_operands_ready;
|
||||
|
||||
for (genvar i = 0; i < 4; ++i) begin
|
||||
for (genvar i = 0; i < 4/*octets*/; ++i) begin
|
||||
// lane-to-octet mapping; see figure 13 of the paper
|
||||
wire [7:0][31:0] octet_A = {
|
||||
dispatch_if.data.rs1_data[16+4*i +: 4], dispatch_if.data.rs1_data[4*i +: 4]
|
||||
};
|
||||
@@ -81,6 +82,9 @@ module VX_tensor_core_warp import VX_gpu_pkg::*; #(
|
||||
assign octet_results_valid[i] = result_valid;
|
||||
assign result_ready = octet_results_ready[i];
|
||||
|
||||
// each octet produces 4x4 output partial sum, but the 8 lanes mapped
|
||||
// to the octet can only do 8 fp32 writeback at a time; so we need to
|
||||
// split writeback over two cycles
|
||||
assign wb_data_0[4*i+0] = octet_D[0][0];
|
||||
assign wb_data_0[4*i+1] = octet_D[1][0];
|
||||
assign wb_data_0[4*i+2] = octet_D[0][2];
|
||||
@@ -150,11 +154,11 @@ module VX_tensor_core_warp import VX_gpu_pkg::*; #(
|
||||
|
||||
localparam COMMIT_DATAW = `UUID_WIDTH + `NW_WIDTH + `NUM_THREADS + `XLEN + 1 + `NR_BITS + (`NUM_THREADS * `XLEN) + 1 + 1 + 1;
|
||||
wire [COMMIT_DATAW-1:0] commit_if_data = {
|
||||
dispatch_if_data_deq,
|
||||
subcommit == 1'b0 ? wb_data_0 : wb_data_1,
|
||||
1'b0,
|
||||
1'b1,
|
||||
1'b1
|
||||
dispatch_if_data_deq, /* uuid ~ rd */
|
||||
subcommit == 1'b0 ? wb_data_0 : wb_data_1, /* data */
|
||||
1'b0, /* pid */
|
||||
1'b1, /* sop */
|
||||
1'b1 /* eop */
|
||||
};
|
||||
|
||||
assign commit_if.data = commit_if_data;
|
||||
@@ -204,8 +208,11 @@ module VX_tensor_octet #(
|
||||
logic [3:0][31:0] B_buffer, B_buffer_n;
|
||||
logic [7:0][31:0] C_buffer, C_buffer_n;
|
||||
|
||||
// half the inputs are buffered, half are not (instead coming straight from operand bus)
|
||||
// unlike the real tensor core, the banks are only 32 bit rather than 64 bit
|
||||
// half the inputs are buffered, half are not (instead coming straight
|
||||
// from operand bus) unlike the real tensor core.
|
||||
// the banks are only 32 bit rather than 64 bit (a pair of fp32 regs).
|
||||
// since A and B are supplied by 4 lanes each, we get 4 fp32's at a time
|
||||
// (8 for C).
|
||||
logic [3:0][31:0] A_half;
|
||||
logic [3:0][31:0] B_half;
|
||||
logic [7:0][31:0] C_half;
|
||||
@@ -265,15 +272,18 @@ module VX_tensor_octet #(
|
||||
wire stall = result_valid && ~result_ready;
|
||||
assign operands_ready = ~stall;
|
||||
|
||||
// A is 4x2 fp32 matrix
|
||||
wire [3:0][1:0][31:0] A_tile = {
|
||||
{ A_half[3], A_buffer[3] },
|
||||
{ A_half[2], A_buffer[2] },
|
||||
{ A_half[1], A_buffer[1] },
|
||||
{ A_half[0], A_buffer[0] }
|
||||
};
|
||||
// B is 2x4 fp32 matrix
|
||||
wire [1:0][3:0][31:0] B_tile = {
|
||||
B_half, B_buffer
|
||||
};
|
||||
// C is 4x4 fp32 matrix
|
||||
logic [3:0][3:0][31:0] C_tile;
|
||||
|
||||
always @(*) begin
|
||||
@@ -286,6 +296,8 @@ module VX_tensor_octet #(
|
||||
end
|
||||
|
||||
wire do_hmma = (substep == 1'b1 && operands_valid && operands_ready);
|
||||
|
||||
// this does (m,n,k)=(4,4,2) matmul, modeling compute of a single octet
|
||||
VX_tensor_dpu #(
|
||||
.ISW(ISW),
|
||||
.OCTET(OCTET)
|
||||
|
||||
Reference in New Issue
Block a user