From 20faf87b80931052b984204b66c39b6dd40e4b81 Mon Sep 17 00:00:00 2001 From: Hansung Kim Date: Mon, 19 Aug 2024 16:42:02 -0700 Subject: [PATCH] tensor: Rename halves_buf to reduce confusion --- hw/rtl/core/VX_tensor_core.sv | 38 ++++++++++++++++++----------------- hw/rtl/fpu/VX_tensor_dpu.sv | 2 +- 2 files changed, 21 insertions(+), 19 deletions(-) diff --git a/hw/rtl/core/VX_tensor_core.sv b/hw/rtl/core/VX_tensor_core.sv index d52150ac..e9976085 100644 --- a/hw/rtl/core/VX_tensor_core.sv +++ b/hw/rtl/core/VX_tensor_core.sv @@ -425,9 +425,7 @@ module VX_tensor_octet #( endfunction half_t halves; - half_t halves_buf; - assign halves = get_operand_half(operands_step, A_in, B_in, C_in); - assign halves_buf = get_operand_half(operands_step_buf, A_in_buf, B_in_buf, C_in_buf); + assign halves = get_operand_half(operands_step_buf, A_in_buf, B_in_buf, C_in_buf); wire do_hmma = operands_ready_buf && operands_valid_buf && operands_last_in_pair_buf; // wire operands_first_in_pair_fire = operands_ready && operands_valid && (!operands_last_in_pair); @@ -454,9 +452,9 @@ module VX_tensor_octet #( if (operands_first_in_pair_fire) begin // NOTE: substeps is only used for debugging substeps_n[operands_wid_buf] = 1'b1; // ready for hmma - A_buffer_n[operands_wid_buf] = halves_buf.A_half; - B_buffer_n[operands_wid_buf] = halves_buf.B_half; - C_buffer_n[operands_wid_buf] = halves_buf.C_half; + A_buffer_n[operands_wid_buf] = halves.A_half; + B_buffer_n[operands_wid_buf] = halves.B_half; + C_buffer_n[operands_wid_buf] = halves.C_half; end if (do_hmma) begin substeps_n[operands_wid_buf] = 1'b0; // finished hmma, ready for next operand @@ -478,28 +476,32 @@ module VX_tensor_octet #( assign operands_ready_buf = hmma_ready; // all *_tiles below are row-major - // A is a 4x2 fp32 matrix + // A is a 4x2 fp32 matrix; rows 0-1 for one threadgroup, rows 2-3 for the + // other. The two columns (along k) are shared between the threadgroups. 
+ // Buffered data are combined with the current data along the K dimension. + // See Figure 10(b) of the paper. wire [3:0][1:0][31:0] A_tile = { - { halves_buf.A_half[3], A_buffer[operands_wid_buf][3] }, - { halves_buf.A_half[2], A_buffer[operands_wid_buf][2] }, - { halves_buf.A_half[1], A_buffer[operands_wid_buf][1] }, - { halves_buf.A_half[0], A_buffer[operands_wid_buf][0] } + { halves.A_half[3], A_buffer[operands_wid_buf][3] }, + { halves.A_half[2], A_buffer[operands_wid_buf][2] }, + { halves.A_half[1], A_buffer[operands_wid_buf][1] }, + { halves.A_half[0], A_buffer[operands_wid_buf][0] } }; - // B is a 2x4 fp32 matrix + // B is a 2x4 fp32 matrix, shared between the two threadgroups wire [1:0][3:0][31:0] B_tile = { - halves_buf.B_half, + halves.B_half, B_buffer[operands_wid_buf] }; - // C is a 4x4 fp32 matrix + // C is a 4x4 fp32 matrix; rows 0-1 for one threadgroup, rows 2-3 for the + // other logic [3:0][3:0][31:0] C_tile; wire [3:0][3:0][31:0] D_tile; wire [`NW_WIDTH-1:0] D_wid_dpu; always @(*) begin - C_tile[3] = { halves_buf.C_half[7], C_buffer[operands_wid_buf][7], halves_buf.C_half[5], C_buffer[operands_wid_buf][5] }; - C_tile[2] = { halves_buf.C_half[6], C_buffer[operands_wid_buf][6], halves_buf.C_half[4], C_buffer[operands_wid_buf][4] }; - C_tile[1] = { halves_buf.C_half[3], C_buffer[operands_wid_buf][3], halves_buf.C_half[1], C_buffer[operands_wid_buf][1] }; - C_tile[0] = { halves_buf.C_half[2], C_buffer[operands_wid_buf][2], halves_buf.C_half[0], C_buffer[operands_wid_buf][0] }; + C_tile[3] = { halves.C_half[7], C_buffer[operands_wid_buf][7], halves.C_half[5], C_buffer[operands_wid_buf][5] }; + C_tile[2] = { halves.C_half[6], C_buffer[operands_wid_buf][6], halves.C_half[4], C_buffer[operands_wid_buf][4] }; + C_tile[1] = { halves.C_half[3], C_buffer[operands_wid_buf][3], halves.C_half[1], C_buffer[operands_wid_buf][1] }; + C_tile[0] = { halves.C_half[2], C_buffer[operands_wid_buf][2], halves.C_half[0], C_buffer[operands_wid_buf][0] }; end wire dpu_valid; diff --git 
a/hw/rtl/fpu/VX_tensor_dpu.sv b/hw/rtl/fpu/VX_tensor_dpu.sv index 0504f457..0d8059ba 100644 --- a/hw/rtl/fpu/VX_tensor_dpu.sv +++ b/hw/rtl/fpu/VX_tensor_dpu.sv @@ -163,7 +163,7 @@ endmodule // does (m,n,k) = (2,4,2) matmul compute over 2 cycles. // see Figure 10(b) of the paper. module VX_tensor_threadgroup #( - parameter HALF_PRECISION = 1 + parameter HALF_PRECISION = 0 ) ( input clk, input reset,