From 08d7721e1163636f705ad869ecb781364431e32c Mon Sep 17 00:00:00 2001 From: joshua Date: Thu, 28 Mar 2024 03:00:15 -0700 Subject: [PATCH] annoying swizzling problems --- hw/rtl/core/VX_tensor_core.sv | 34 +++++++++++++++++----------------- 1 file changed, 17 insertions(+), 17 deletions(-) diff --git a/hw/rtl/core/VX_tensor_core.sv b/hw/rtl/core/VX_tensor_core.sv index 28369140..1055351d 100644 --- a/hw/rtl/core/VX_tensor_core.sv +++ b/hw/rtl/core/VX_tensor_core.sv @@ -42,13 +42,13 @@ module VX_tensor_core_warp import VX_gpu_pkg::*; #( for (genvar i = 0; i < 4; ++i) begin wire [7:0][31:0] octet_A = { - dispatch_if.data.rs1_data[4*i +: 4], dispatch_if.data.rs1_data[16+4*i +: 4] + dispatch_if.data.rs1_data[16+4*i +: 4], dispatch_if.data.rs1_data[4*i +: 4] }; wire [7:0][31:0] octet_B = { - dispatch_if.data.rs2_data[4*i +: 4], dispatch_if.data.rs2_data[16+4*i +: 4] + dispatch_if.data.rs2_data[16+4*i +: 4], dispatch_if.data.rs2_data[4*i +: 4] }; wire [7:0][31:0] octet_C = { - dispatch_if.data.rs3_data[4*i +: 4], dispatch_if.data.rs3_data[16+4*i +: 4] + dispatch_if.data.rs3_data[16+4*i +: 4], dispatch_if.data.rs3_data[4*i +: 4] }; logic [3:0][3:0][31:0] octet_D; @@ -125,7 +125,7 @@ module VX_tensor_core_warp import VX_gpu_pkg::*; #( // this is probably a little oversized VX_fifo_queue #( .DATAW(DATAW), - .DEPTH(8) + .DEPTH(16) ) pending_uops ( .clk(clk), .reset(reset), @@ -207,19 +207,19 @@ module VX_tensor_octet #( always @(*) begin case (step) 2'b00: begin - A_half = { A_in[1:0], A_in[5:4] }; + A_half = { A_in[5:4], A_in[1:0] }; B_half = B_in[3:0]; end 2'b01: begin - A_half = { A_in[3:2], A_in[7:6] }; + A_half = { A_in[7:6], A_in[3:2] }; B_half = B_in[3:0]; end 2'b10: begin - A_half = { A_in[1:0], A_in[5:4] }; + A_half = { A_in[5:4], A_in[1:0] }; B_half = B_in[7:4]; end 2'b11: begin - A_half = { A_in[3:2], A_in[7:6] }; + A_half = { A_in[7:6], A_in[3:2] }; B_half = B_in[7:4]; end endcase @@ -261,22 +261,22 @@ module VX_tensor_octet #( assign operands_ready = ~stall; wire [3:0][1:0][31:0] A_tile = { - { A_buffer[0], A_half[0] }, - { A_buffer[1], A_half[1] }, - { A_buffer[2], A_half[2] }, - { A_buffer[3], A_half[3] } + { A_half[3], A_buffer[3] }, + { A_half[2], A_buffer[2] }, + { A_half[1], A_buffer[1] }, + { A_half[0], A_buffer[0] } }; wire [1:0][3:0][31:0] B_tile = { - B_buffer, B_half + B_half, B_buffer }; logic [3:0][3:0][31:0] C_tile; always @(*) begin C_tile = { - C_buffer[0], C_half[0], C_buffer[1], C_half[1], - C_buffer[2], C_half[2], C_buffer[3], C_half[3], - C_buffer[4], C_half[4], C_buffer[5], C_half[5], - C_buffer[6], C_half[6], C_buffer[7], C_half[7] + C_half[7], C_buffer[7], C_half[5], C_buffer[5], + C_half[6], C_buffer[6], C_half[4], C_buffer[4], + C_half[3], C_buffer[3], C_half[1], C_buffer[1], + C_half[2], C_buffer[2], C_half[0], C_buffer[0] }; end