From 4cac1adf7d2b93e6b02f1bc7c1e7914af5d00944 Mon Sep 17 00:00:00 2001
From: Hansung Kim
Date: Mon, 7 Oct 2024 17:10:59 -0700
Subject: [PATCH] Add dummy code for decoupled Hopper tensor core

Define EXT_T_HOPPER, which selects between the two tensor core
implementations when EXT_T_ENABLE is defined: the decoupled Hopper-style
core if EXT_T_HOPPER is defined, the core-coupled Volta-style core
otherwise.  EXT_T_HOPPER is defined by default; define
EXT_T_HOPPER_DISABLE to build the Volta-style core instead.

The Hopper path is a placeholder: decode is identical to the Volta path,
and VX_tensor_hopper_core_block only queues commit metadata for warp 0
and commits all-zero writeback data.

Also change the default NUM_CORES from 8 to 4.
---
Reviewer note: the indentation in this patch was reconstructed from a
whitespace-collapsed copy.  If context lines do not match the tree, apply
with `git am --ignore-whitespace`.

 hw/rtl/VX_config.vh                  |   7 +-
 hw/rtl/core/VX_decode.sv             |  34 ++++++---
 hw/rtl/core/VX_tensor_core.sv        |  18 +++--
 hw/rtl/core/VX_tensor_hopper_core.sv | 116 +++++++++++++++++++++++++++
 4 files changed, 160 insertions(+), 15 deletions(-)
 create mode 100644 hw/rtl/core/VX_tensor_hopper_core.sv

diff --git a/hw/rtl/VX_config.vh b/hw/rtl/VX_config.vh
index a9ff2742..f309a84a 100644
--- a/hw/rtl/VX_config.vh
+++ b/hw/rtl/VX_config.vh
@@ -40,8 +40,13 @@
 `define EXT_F_ENABLE
 `endif
 
+// tensor core extension; core-coupled Volta-style unless EXT_T_HOPPER is defined
 `ifndef EXT_T_DISABLE
 `define EXT_T_ENABLE
+// decoupled Hopper-style tensor core (default; opt out with EXT_T_HOPPER_DISABLE)
+`ifndef EXT_T_HOPPER_DISABLE
+`define EXT_T_HOPPER
+`endif
 `endif
 
 `ifndef XLEN_32
@@ -83,7 +88,7 @@
 `endif
 
 `ifndef NUM_CORES
-`define NUM_CORES 8
+`define NUM_CORES 4
 `endif
 
 `ifndef NUM_WARPS
diff --git a/hw/rtl/core/VX_decode.sv b/hw/rtl/core/VX_decode.sv
index cf21d72f..62fdde76 100644
--- a/hw/rtl/core/VX_decode.sv
+++ b/hw/rtl/core/VX_decode.sv
@@ -542,16 +542,30 @@ module VX_decode #(
                 endcase
             end
         `ifdef EXT_T_ENABLE
-            `INST_EXT4: begin
-                ex_type = `EX_TENSOR;
-                op_type = `INST_TENSOR_HMMA;
-                // tensor core macroop is encoded as r-type
-                use_rd = 1;
-                `USED_IREG (rd);
-                `USED_IREG (rs1);
-                `USED_IREG (rs2);
-                `USED_IREG (rs3);
-            end
+            `ifdef EXT_T_HOPPER
+            // TODO: Hopper-specific decode; currently identical to the Volta path below
+            `INST_EXT4: begin
+                ex_type = `EX_TENSOR;
+                op_type = `INST_TENSOR_HMMA;
+                // tensor core macroop is encoded as r-type
+                use_rd = 1;
+                `USED_IREG (rd);
+                `USED_IREG (rs1);
+                `USED_IREG (rs2);
+                `USED_IREG (rs3);
+            end
+            `else
+            `INST_EXT4: begin
+                ex_type = `EX_TENSOR;
+                op_type = `INST_TENSOR_HMMA;
+                // tensor core macroop is encoded as r-type
+                use_rd = 1;
+                `USED_IREG (rd);
+                `USED_IREG (rs1);
+                `USED_IREG (rs2);
+                `USED_IREG (rs3);
+            end
+            `endif
         `endif
             default:;
         endcase
diff --git a/hw/rtl/core/VX_tensor_core.sv b/hw/rtl/core/VX_tensor_core.sv
index 730d7855..802af43d 100644
--- a/hw/rtl/core/VX_tensor_core.sv
+++ b/hw/rtl/core/VX_tensor_core.sv
@@ -12,7 +12,7 @@ module VX_tensor_core import VX_gpu_pkg::*; #(
 );
     localparam BLOCK_SIZE = 1;
     localparam NUM_LANES = `NUM_THREADS;
-    // FIXME: @perf: PARTIAL_BW==1 increases power instantiating
+    // @perf: PARTIAL_BW==1 increases power instantiating
     // stream_buffers for ISSUE_WIDTH times
     localparam PARTIAL_BW = (BLOCK_SIZE != `ISSUE_WIDTH) || (NUM_LANES != `NUM_THREADS);
 
@@ -51,16 +51,27 @@ module VX_tensor_core import VX_gpu_pkg::*; #(
 );
 
     for (genvar block_idx = 0; block_idx < BLOCK_SIZE; ++block_idx) begin
-        VX_tensor_core_block #(
+`ifdef EXT_T_HOPPER
+        VX_tensor_hopper_core_block #(
             .ISW(1), // FIXME: not block_idx
             .FP16(FP16)
+        ) tensor_hopper_core (
+            .clk(clk),
+            .reset(reset),
+            .execute_if(execute_if[block_idx]),
+            .commit_if(commit_block_if[block_idx])
+        );
+`else
+        VX_tensor_core_block #(
+            .ISW(1), // FIXME: use block_idx
+            .FP16(FP16)
         ) tensor_core (
             .clk(clk),
             .reset(reset),
-
             .execute_if(execute_if[block_idx]),
             .commit_if(commit_block_if[block_idx])
         );
+`endif
     end
 endmodule
 
@@ -275,7 +286,6 @@ module VX_tensor_core_block import VX_gpu_pkg::*; #(
     localparam COMMIT_DATAW = `UUID_WIDTH + `NW_WIDTH + `NUM_THREADS + `XLEN + 1 + `NR_BITS + (`NUM_THREADS * `XLEN) + 1 + 1 + 1;
     wire [COMMIT_DATAW-1:0] commit_if_data = {
         execute_if_data_deq[wb_wid], /* uuid ~ rd */
-        // execute_if_data_deq, /* uuid ~ rd */
         subcommit == 1'b0 ? wb_data_0 : wb_data_1, /* data */
         1'b0, /* pid */
         1'b1, /* sop */
diff --git a/hw/rtl/core/VX_tensor_hopper_core.sv b/hw/rtl/core/VX_tensor_hopper_core.sv
new file mode 100644
index 00000000..c79a7994
--- /dev/null
+++ b/hw/rtl/core/VX_tensor_hopper_core.sv
@@ -0,0 +1,116 @@
+`ifdef EXT_T_ENABLE
+`include "VX_fpu_define.vh"
+
+// Placeholder execute block for the decoupled Hopper-style tensor core.
+//
+// No matrix math is performed yet: the commit metadata of every accepted
+// execute_if request is pushed into a per-warp queue and replayed on
+// commit_if together with all-zero writeback data.  Only warp 0 is
+// supported for now (see the FIXMEs below).
+//
+// Parameters:
+//   ISW, FP16  - not referenced by this stub
+//                (NOTE(review): presumably mirror VX_tensor_core_block -- confirm)
+// Ports:
+//   execute_if - incoming uops; ready is deasserted while any per-warp
+//                metadata queue is full
+//   commit_if  - outgoing commits; valid while warp 0's queue is non-empty
+module VX_tensor_hopper_core_block import VX_gpu_pkg::*; #(
+    parameter ISW,
+    parameter FP16
+) (
+    input clk,
+    input reset,
+
+    VX_execute_if.slave execute_if,
+    VX_commit_if.master commit_if
+);
+    localparam METADATA_QUEUE_DEPTH = 2; // FIXME: arbitrary
+
+    /* commit_if.data_t parts that we need to keep around:
+       - uuid
+       - wid
+       - tmask
+       - PC
+       - wb
+       - rd
+    */
+
+    localparam DATAW = `UUID_WIDTH + `NW_WIDTH + `NUM_THREADS + `XLEN + 1 + `NR_BITS;
+
+    wire operand_enq_fire = execute_if.valid && execute_if.ready;
+    wire commit_if_fire = commit_if.valid && commit_if.ready;
+    wire [DATAW-1:0] execute_if_data_enq = {
+        execute_if.data.uuid,
+        execute_if.data.wid,
+        execute_if.data.tmask,
+        execute_if.data.PC,
+        execute_if.data.wb,
+        execute_if.data.rd
+        // pid/sop/eop set later
+    };
+
+    wire [`NUM_WARPS-1:0][DATAW-1:0] execute_if_data_deq;
+
+    wire [`NUM_WARPS-1:0] metadata_queue_fulls;
+    wire [`NUM_WARPS-1:0] metadata_queue_emptys;
+    // OR-reduce (not AND): stall as soon as any warp's queue is full
+    wire metadata_queue_full = |(metadata_queue_fulls);
+    assign execute_if.ready = !metadata_queue_full;
+
+    `RUNTIME_ASSERT((!execute_if.valid || execute_if.data.wid == `NW_WIDTH'(0)),
+        ("runtime error: WGMMA execute not supported for warps other than 0!"))
+
+    for (genvar i = 0; i < `NUM_WARPS; i++) begin
+        // Metadata queue for commit_if.  This simply copies execute_if's
+        // metadata and pops them in conjunction with commit fire.
+        //
+        // This has to be separated per-warp, as otherwise requests from
+        // multiple warps can be enqueued interleaved, which makes it hard to
+        // ensure two consecutive dequeues are associated with the same warp for
+        // commit.  (FIXME: this is not strictly necessary though.)
+
+        wire enq = operand_enq_fire && (execute_if.data.wid == `NW_WIDTH'(i));
+        // FIXME: commit only warp 0
+        wire deq = commit_if_fire && (`NW_WIDTH'(i) == `NW_WIDTH'(0));
+
+        VX_fifo_queue #(
+            .DATAW(DATAW),
+            .DEPTH(METADATA_QUEUE_DEPTH)
+        ) pending_uops (
+            .clk(clk),
+            .reset(reset),
+            .push(enq),
+            .pop(deq),
+            .data_in(execute_if_data_enq),
+            .data_out(execute_if_data_deq[i]),
+            .empty(metadata_queue_emptys[i]),
+            `UNUSED_PIN(alm_empty),
+            .full(metadata_queue_fulls[i]),
+            `UNUSED_PIN(alm_full),
+            `UNUSED_PIN(size)
+        );
+    end
+
+    // this shouldn't really happen unless there's a big contention over
+    // the commit stage
+    `RUNTIME_ASSERT(!(!reset && metadata_queue_full), ("tensor core uop queue is full!"))
+
+    // FIXME: only checks warp 0 for commit!
+    assign commit_if.valid = ~metadata_queue_emptys[0/*FIXME*/];
+
+    wire [`NUM_THREADS-1:0][`XLEN-1:0] wb_data = '0;
+
+    localparam COMMIT_DATAW = `UUID_WIDTH + `NW_WIDTH + `NUM_THREADS + `XLEN + 1 + `NR_BITS + (`NUM_THREADS * `XLEN) + 1 + 1 + 1;
+    wire [COMMIT_DATAW-1:0] commit_if_data = {
+        execute_if_data_deq[0/*FIXME*/], /* uuid ~ rd */
+        wb_data, /* data */
+        1'b0, /* pid */
+        1'b1, /* sop */
+        1'b1 /* eop */
+    };
+
+    assign commit_if.data = commit_if_data;
+endmodule
+
+`endif