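Address annoying swizzling problems: reorder operand concatenations in VX_tensor_core_warp and VX_tensor_octet, and increase the pending_uops FIFO depth from 8 to 16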

This commit is contained in:
joshua
2024-03-28 03:00:15 -07:00
parent e16584ddd9
commit 08d7721e11

View File

@@ -42,13 +42,13 @@ module VX_tensor_core_warp import VX_gpu_pkg::*; #(
for (genvar i = 0; i < 4; ++i) begin for (genvar i = 0; i < 4; ++i) begin
wire [7:0][31:0] octet_A = { wire [7:0][31:0] octet_A = {
dispatch_if.data.rs1_data[4*i +: 4], dispatch_if.data.rs1_data[16+4*i +: 4] dispatch_if.data.rs1_data[16+4*i +: 4], dispatch_if.data.rs1_data[4*i +: 4]
}; };
wire [7:0][31:0] octet_B = { wire [7:0][31:0] octet_B = {
dispatch_if.data.rs2_data[4*i +: 4], dispatch_if.data.rs2_data[16+4*i +: 4] dispatch_if.data.rs2_data[16+4*i +: 4], dispatch_if.data.rs2_data[4*i +: 4]
}; };
wire [7:0][31:0] octet_C = { wire [7:0][31:0] octet_C = {
dispatch_if.data.rs3_data[4*i +: 4], dispatch_if.data.rs3_data[16+4*i +: 4] dispatch_if.data.rs3_data[16+4*i +: 4], dispatch_if.data.rs3_data[4*i +: 4]
}; };
logic [3:0][3:0][31:0] octet_D; logic [3:0][3:0][31:0] octet_D;
@@ -125,7 +125,7 @@ module VX_tensor_core_warp import VX_gpu_pkg::*; #(
// this is probably a little oversized // this is probably a little oversized
VX_fifo_queue #( VX_fifo_queue #(
.DATAW(DATAW), .DATAW(DATAW),
.DEPTH(8) .DEPTH(16)
) pending_uops ( ) pending_uops (
.clk(clk), .clk(clk),
.reset(reset), .reset(reset),
@@ -207,19 +207,19 @@ module VX_tensor_octet #(
always @(*) begin always @(*) begin
case (step) case (step)
2'b00: begin 2'b00: begin
A_half = { A_in[1:0], A_in[5:4] }; A_half = { A_in[5:4], A_in[1:0] };
B_half = B_in[3:0]; B_half = B_in[3:0];
end end
2'b01: begin 2'b01: begin
A_half = { A_in[3:2], A_in[7:6] }; A_half = { A_in[7:6], A_in[3:2] };
B_half = B_in[3:0]; B_half = B_in[3:0];
end end
2'b10: begin 2'b10: begin
A_half = { A_in[1:0], A_in[5:4] }; A_half = { A_in[5:4], A_in[1:0] };
B_half = B_in[7:4]; B_half = B_in[7:4];
end end
2'b11: begin 2'b11: begin
A_half = { A_in[3:2], A_in[7:6] }; A_half = { A_in[7:6], A_in[3:2] };
B_half = B_in[7:4]; B_half = B_in[7:4];
end end
endcase endcase
@@ -261,22 +261,22 @@ module VX_tensor_octet #(
assign operands_ready = ~stall; assign operands_ready = ~stall;
wire [3:0][1:0][31:0] A_tile = { wire [3:0][1:0][31:0] A_tile = {
{ A_buffer[0], A_half[0] }, { A_half[3], A_buffer[3] },
{ A_buffer[1], A_half[1] }, { A_half[2], A_buffer[2] },
{ A_buffer[2], A_half[2] }, { A_half[1], A_buffer[1] },
{ A_buffer[3], A_half[3] } { A_half[0], A_buffer[0] }
}; };
wire [1:0][3:0][31:0] B_tile = { wire [1:0][3:0][31:0] B_tile = {
B_buffer, B_half B_half, B_buffer
}; };
logic [3:0][3:0][31:0] C_tile; logic [3:0][3:0][31:0] C_tile;
always @(*) begin always @(*) begin
C_tile = { C_tile = {
C_buffer[0], C_half[0], C_buffer[1], C_half[1], C_half[7], C_buffer[7], C_half[5], C_buffer[5],
C_buffer[2], C_half[2], C_buffer[3], C_half[3], C_half[6], C_buffer[6], C_half[4], C_buffer[4],
C_buffer[4], C_half[4], C_buffer[5], C_half[5], C_half[3], C_buffer[3], C_half[1], C_buffer[1],
C_buffer[6], C_half[6], C_buffer[7], C_half[7] C_half[2], C_buffer[2], C_half[0], C_buffer[0]
}; };
end end