tensor: Connect SMEM addr/rf IO

This commit is contained in:
Hansung Kim
2024-10-28 19:42:02 -07:00
parent 4376bd33a2
commit 8a66b5ed89

View File

@@ -25,13 +25,15 @@ module VX_tensor_hopper_core_block import VX_gpu_pkg::*; #(
- wb
- rd
*/
wire [`UUID_WIDTH-1:0] execute_if_data_uuid;
wire [`NW_WIDTH-1:0] execute_if_data_wid;
wire [NUM_LANES-1:0] execute_if_data_tmask;
wire [`INST_ALU_BITS-1:0] execute_if_data_op_type;
wire [`XLEN-1:0] execute_if_data_PC;
wire execute_if_data_wb;
wire [`NR_BITS-1:0] execute_if_data_rd;
wire [`UUID_WIDTH-1:0] execute_if_data_uuid;
wire [`NW_WIDTH-1:0] execute_if_data_wid;
wire [NUM_LANES-1:0] execute_if_data_tmask;
wire [`INST_ALU_BITS-1:0] execute_if_data_op_type;
wire [`XLEN-1:0] execute_if_data_PC;
wire execute_if_data_wb;
wire [`NR_BITS-1:0] execute_if_data_rd;
wire [NUM_LANES-1:0][`XLEN-1:0] execute_if_data_rs1;
wire [NUM_LANES-1:0][`XLEN-1:0] execute_if_data_rs2;
wire metadata_queue_full;
wire metadata_queue_empty;
@@ -52,7 +54,8 @@ module VX_tensor_hopper_core_block import VX_gpu_pkg::*; #(
wire enq = operand_enq_fire;
wire deq = metadata_deq;
localparam DATAW = `UUID_WIDTH + `NW_WIDTH + `NUM_THREADS + `INST_ALU_BITS + `XLEN + 1 + `NR_BITS;
localparam DATAW = `UUID_WIDTH + `NW_WIDTH + `NUM_THREADS + `INST_ALU_BITS + `XLEN + 1 +
`NR_BITS + (NUM_LANES * `XLEN) + (NUM_LANES * `XLEN);
VX_fifo_queue #(
.DATAW(DATAW),
.DEPTH(METADATA_QUEUE_DEPTH)
@@ -63,10 +66,12 @@ module VX_tensor_hopper_core_block import VX_gpu_pkg::*; #(
.pop(deq),
.data_in({execute_if.data.uuid, execute_if.data.wid,
execute_if.data.tmask, execute_if.data.op_type, execute_if.data.PC,
execute_if.data.wb, execute_if.data.rd}),
execute_if.data.wb, execute_if.data.rd,
execute_if.data.rs1_data, execute_if.data.rs2_data}),
.data_out({execute_if_data_uuid, execute_if_data_wid,
execute_if_data_tmask, execute_if_data_op_type, execute_if_data_PC,
execute_if_data_wb, execute_if_data_rd}),
execute_if_data_wb, execute_if_data_rd,
execute_if_data_rs1, execute_if_data_rs2}),
.empty(metadata_queue_empty),
`UNUSED_PIN(alm_empty),
.full(metadata_queue_full),
@@ -94,6 +99,10 @@ module VX_tensor_hopper_core_block import VX_gpu_pkg::*; #(
// commit
wire initiate_valid = metadata_valid && commit_if.ready && !hmma_wait;
wire [`NW_WIDTH-1:0] initiate_wid = execute_if_data_wid;
wire [`XLEN-1:0] initiate_addr_a = execute_if_data_rs1[0];
wire [`XLEN-1:0] initiate_addr_b = execute_if_data_rs2[0];
`RUNTIME_ASSERT(!metadata_valid || execute_if_data_tmask[0],
("tmask for HGMMA instruction is invalid"))
// we're recycling execute_if.op_type as operands_if.op_type which might
// have a different width; let's be safe
@@ -107,17 +116,17 @@ module VX_tensor_hopper_core_block import VX_gpu_pkg::*; #(
// /*
// fake fsm driving tc rf port
reg [11:0] counter;
always @(posedge clk) begin
if (reset) begin
counter <= 12'd1;
end else begin
counter <= counter + 12'd1;
end
end
assign regfile_if.req_valid = (counter[3:0] != 4'd0);
assign regfile_if.req_data.wis = '0;
assign regfile_if.req_data.rs = counter[11:7];
// reg [11:0] counter;
// always @(posedge clk) begin
// if (reset) begin
// counter <= 12'd1;
// end else begin
// counter <= counter + 12'd1;
// end
// end
// assign regfile_if.req_valid = (counter[6:0] == 7'd0);
// assign regfile_if.req_data.wis = '0;
// assign regfile_if.req_data.rs = counter[11:7];
// */
TensorCoreDecoupled tensor_hopper_core (
@@ -127,6 +136,8 @@ module VX_tensor_hopper_core_block import VX_gpu_pkg::*; #(
.io_initiate_ready(initiate_ready),
.io_initiate_valid(initiate_valid),
.io_initiate_bits_wid(initiate_wid),
.io_initiate_bits_addressA(initiate_addr_a),
.io_initiate_bits_addressB(initiate_addr_b),
.io_writeback_ready(writeback_ready),
.io_writeback_valid(writeback_valid),
@@ -150,6 +161,7 @@ module VX_tensor_hopper_core_block import VX_gpu_pkg::*; #(
.io_respB_valid(smem_B_if.rsp_valid),
.io_respB_bits_source(smem_B_if.rsp_data.tag),
.io_respB_bits_data(smem_B_if.rsp_data.data),
.io_respC(regfile_if.rsp_data.data),
.io_reqA_ready(smem_A_if.req_ready),
.io_reqA_valid(smem_A_if.req_valid),
@@ -158,8 +170,15 @@ module VX_tensor_hopper_core_block import VX_gpu_pkg::*; #(
.io_reqB_ready(smem_B_if.req_ready),
.io_reqB_valid(smem_B_if.req_valid),
.io_reqB_bits_source(smem_B_if.req_data.tag),
.io_reqB_bits_address(smem_B_if.req_data.addr)
.io_reqB_bits_address(smem_B_if.req_data.addr),
.io_reqC_valid(regfile_if.req_valid),
.io_reqC_bits(regfile_if.req_data.rs[4:0])
);
// force bit 5 of the register index high, i.e. add an offset of 32, to address the fp regs
assign regfile_if.req_data.rs[5] = 1'b1;
assign regfile_if.req_data.wis = '0;
`STATIC_ASSERT((`ISSUE_WIDTH == `NUM_WARPS),
("static assertion failed: tensor_hopper_core assumes ISSUE_WIDTH == NUM_WARPS"))
// VX_tensor_hopper_core #(
// ) tensor_hopper_core (