From fb626ee21ca65d2b57437f1a8148f297fce67f7e Mon Sep 17 00:00:00 2001 From: Hansung Kim Date: Sun, 5 May 2024 18:35:52 -0700 Subject: [PATCH] tensor: doc --- hw/rtl/core/VX_tensor_core.sv | 28 ++++++++++++++++++++-------- 1 file changed, 20 insertions(+), 8 deletions(-) diff --git a/hw/rtl/core/VX_tensor_core.sv b/hw/rtl/core/VX_tensor_core.sv index f7f49fc7..a08498d1 100644 --- a/hw/rtl/core/VX_tensor_core.sv +++ b/hw/rtl/core/VX_tensor_core.sv @@ -43,7 +43,8 @@ module VX_tensor_core_warp import VX_gpu_pkg::*; #( assign dispatch_if.ready = &octet_operands_ready; - for (genvar i = 0; i < 4; ++i) begin + for (genvar i = 0; i < 4/*octets*/; ++i) begin + // lane-to-octet mapping; see figure 13 of the paper wire [7:0][31:0] octet_A = { dispatch_if.data.rs1_data[16+4*i +: 4], dispatch_if.data.rs1_data[4*i +: 4] }; @@ -81,6 +82,9 @@ module VX_tensor_core_warp import VX_gpu_pkg::*; #( assign octet_results_valid[i] = result_valid; assign result_ready = octet_results_ready[i]; + // each octet produces 4x4 output partial sum, but the 8 lanes mapped + // to the octet can only do 8 fp32 writeback at a time; so we need to + // split writeback over two cycles assign wb_data_0[4*i+0] = octet_D[0][0]; assign wb_data_0[4*i+1] = octet_D[1][0]; assign wb_data_0[4*i+2] = octet_D[0][2]; @@ -150,11 +154,11 @@ module VX_tensor_core_warp import VX_gpu_pkg::*; #( localparam COMMIT_DATAW = `UUID_WIDTH + `NW_WIDTH + `NUM_THREADS + `XLEN + 1 + `NR_BITS + (`NUM_THREADS * `XLEN) + 1 + 1 + 1; wire [COMMIT_DATAW-1:0] commit_if_data = { - dispatch_if_data_deq, - subcommit == 1'b0 ? wb_data_0 : wb_data_1, - 1'b0, - 1'b1, - 1'b1 + dispatch_if_data_deq, /* uuid ~ rd */ + subcommit == 1'b0 ? wb_data_0 : wb_data_1, /* data */ + 1'b0, /* pid */ + 1'b1, /* sop */ + 1'b1 /* eop */ }; assign commit_if.data = commit_if_data; @@ -204,8 +208,11 @@ module VX_tensor_octet #( logic [3:0][31:0] B_buffer, B_buffer_n; logic [7:0][31:0] C_buffer, C_buffer_n; - // half the inputs are buffered, half are not (instead coming straight from operand bus) - // unlike the real tensor core, the banks are only 32 bit rather than 64 bit + // half the inputs are buffered, half are not (instead coming straight + // from operand bus) unlike the real tensor core. + // the banks are only 32 bit rather than 64 bit (a pair of fp32 regs). + // since A and B are supplied by 4 lanes each, we get 4 fp32's at a time + // (8 for C). logic [3:0][31:0] A_half; logic [3:0][31:0] B_half; logic [7:0][31:0] C_half; @@ -265,15 +272,18 @@ module VX_tensor_octet #( wire stall = result_valid && ~result_ready; assign operands_ready = ~stall; + // A is 4x2 fp32 matrix wire [3:0][1:0][31:0] A_tile = { { A_half[3], A_buffer[3] }, { A_half[2], A_buffer[2] }, { A_half[1], A_buffer[1] }, { A_half[0], A_buffer[0] } }; + // B is 2x4 fp32 matrix wire [1:0][3:0][31:0] B_tile = { B_half, B_buffer }; + // C is 4x4 fp32 matrix logic [3:0][3:0][31:0] C_tile; always @(*) begin @@ -286,6 +296,8 @@ module VX_tensor_octet #( end wire do_hmma = (substep == 1'b1 && operands_valid && operands_ready); + + // this does (m,n,k)=(4,4,2) matmul, modeling compute of a single octet VX_tensor_dpu #( .ISW(ISW), .OCTET(OCTET)