Merge branch 'rtl' of https://github.com/hansungk/vortex-private into rtl
This commit is contained in:
@@ -347,7 +347,7 @@ void dpi_fmax(bool enable, int dst_fmt, int64_t a, int64_t b, int64_t* result, s
|
||||
|
||||
// A is M * K, B is K * M, C is M * M, D is M * M
|
||||
#define M 4
|
||||
#define K 2
|
||||
#define K 2 // FIXME: 4x4x1 / cycle / octet!
|
||||
|
||||
// all row major
|
||||
float c_A_tile[M][K];
|
||||
@@ -358,6 +358,15 @@ float c_D_tile[M][M];
|
||||
// code assumes that svBitVecVal is basically a uint32_t
|
||||
static_assert(sizeof(svBitVecVal) == 4);
|
||||
|
||||
void clear_float_array(float* c_tile, int rows, int cols) {
|
||||
for (int i = 0; i < rows; i += 1) {
|
||||
for (int j = 0; j < cols; j += 1) {
|
||||
int index = i * cols + j;
|
||||
c_tile[index] = 0.0f;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void fill_float_array(const svBitVecVal* sv_tile, float* c_tile, int rows, int cols) {
|
||||
|
||||
for (int i = 0; i < rows; i += 1) {
|
||||
@@ -396,6 +405,11 @@ void dpi_hmma(bool enable, const svBitVecVal* A_tile, const svBitVecVal* B_tile,
|
||||
if (!enable) {
|
||||
return;
|
||||
}
|
||||
clear_float_array(&c_A_tile[0][0], M, K);
|
||||
clear_float_array(&c_B_tile[0][0], K, M);
|
||||
clear_float_array(&c_C_tile[0][0], M, M);
|
||||
clear_float_array(&c_D_tile[0][0], M, M);
|
||||
|
||||
// std::cout << "A: " << std::endl;
|
||||
fill_float_array(A_tile, &c_A_tile[0][0], M, K);
|
||||
// std::cout << "B: " << std::endl;
|
||||
@@ -551,7 +565,7 @@ void dpi_print_results(int wid, int octet, const svBitVecVal* A_tile, const svBi
|
||||
}
|
||||
|
||||
steps[wid] += 1;
|
||||
if (steps[wid] % 64 == 0) {
|
||||
if (steps[wid] % 32 == 0) {
|
||||
steps[wid] = 0;
|
||||
std::cout << "warp " << wid << " finished wmma\n";
|
||||
std::cout << "A tile" << "\n";
|
||||
|
||||
@@ -391,7 +391,7 @@
|
||||
|
||||
// Tensor Core Latency
|
||||
`ifndef LATENCY_HMMA
|
||||
`define LATENCY_HMMA 8
|
||||
`define LATENCY_HMMA 4
|
||||
`endif
|
||||
|
||||
// Icache Configurable Knobs //////////////////////////////////////////////////
|
||||
|
||||
@@ -14,7 +14,7 @@
|
||||
`ifndef VX_PLATFORM_VH
|
||||
`define VX_PLATFORM_VH
|
||||
|
||||
// synthesis only
|
||||
// enable synthesizable build if SIMULATION not explicitly defined
|
||||
`ifndef SIMULATION
|
||||
`define SYNTHESIS
|
||||
`define NDEBUG
|
||||
|
||||
@@ -42,7 +42,7 @@ module VX_alu_unit #(
|
||||
|
||||
`RESET_RELAY (dispatch_reset, reset);
|
||||
|
||||
VX_dispatch_unit #(
|
||||
VX_dispatch_unit_sane #(
|
||||
.BLOCK_SIZE (BLOCK_SIZE),
|
||||
.NUM_LANES (NUM_LANES),
|
||||
.OUT_REG (PARTIAL_BW ? 1 : 0)
|
||||
|
||||
@@ -545,6 +545,12 @@ module VX_decode #(
|
||||
`INST_EXT4: begin
|
||||
ex_type = `EX_TENSOR;
|
||||
op_type = `INST_TENSOR_HMMA;
|
||||
// tensor core macroop is encoded as r-type
|
||||
use_rd = 1;
|
||||
`USED_IREG (rd);
|
||||
`USED_IREG (rs1);
|
||||
`USED_IREG (rs2);
|
||||
`USED_IREG (rs3);
|
||||
end
|
||||
`endif
|
||||
default:;
|
||||
|
||||
274
hw/rtl/core/VX_dispatch_unit_sane.sv
Normal file
274
hw/rtl/core/VX_dispatch_unit_sane.sv
Normal file
@@ -0,0 +1,274 @@
|
||||
// Copyright © 2019-2023
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
`include "VX_define.vh"
|
||||
|
||||
module VX_dispatch_unit_sane import VX_gpu_pkg::*; #(
|
||||
parameter BLOCK_SIZE = 1,
|
||||
parameter NUM_LANES = 1,
|
||||
parameter OUT_REG = 0,
|
||||
parameter MAX_FANOUT = `MAX_FANOUT
|
||||
) (
|
||||
input wire clk,
|
||||
input wire reset,
|
||||
|
||||
// inputs
|
||||
VX_dispatch_if.slave dispatch_if [`ISSUE_WIDTH],
|
||||
|
||||
// outputs
|
||||
VX_execute_if.master execute_if [BLOCK_SIZE]
|
||||
|
||||
);
|
||||
`STATIC_ASSERT ((`NUM_THREADS == NUM_LANES * (`NUM_THREADS / NUM_LANES)), ("invalid parameter"))
|
||||
localparam BLOCK_SIZE_W = `LOG2UP(BLOCK_SIZE);
|
||||
localparam NUM_PACKETS = `NUM_THREADS / NUM_LANES;
|
||||
localparam PID_BITS = `CLOG2(NUM_PACKETS);
|
||||
localparam PID_WIDTH = `UP(PID_BITS);
|
||||
localparam BATCH_COUNT = `ISSUE_WIDTH / BLOCK_SIZE;
|
||||
localparam BATCH_COUNT_W= `LOG2UP(BATCH_COUNT);
|
||||
localparam ISSUE_W = `LOG2UP(`ISSUE_WIDTH);
|
||||
localparam IN_DATAW = `UUID_WIDTH + ISSUE_WIS_W + `NUM_THREADS + `INST_OP_BITS + `INST_MOD_BITS + 1 + 1 + 1 + `XLEN + `XLEN + `NR_BITS + `NT_WIDTH + (3 * `NUM_THREADS * `XLEN);
|
||||
localparam OUT_DATAW = `UUID_WIDTH + `NW_WIDTH + NUM_LANES + `INST_OP_BITS + `INST_MOD_BITS + 1 + 1 + 1 + `XLEN + `XLEN + `NR_BITS + `NT_WIDTH + (3 * NUM_LANES * `XLEN) + PID_WIDTH + 1 + 1;
|
||||
localparam FANOUT_ENABLE= (`NUM_THREADS > (MAX_FANOUT + MAX_FANOUT/2));
|
||||
|
||||
localparam DATA_TMASK_OFF = IN_DATAW - (`UUID_WIDTH + ISSUE_WIS_W + `NUM_THREADS);
|
||||
localparam DATA_REGS_OFF = 0;
|
||||
|
||||
wire [`ISSUE_WIDTH-1:0] dispatch_valid;
|
||||
wire [`ISSUE_WIDTH-1:0][IN_DATAW-1:0] dispatch_data;
|
||||
wire [`ISSUE_WIDTH-1:0] dispatch_ready;
|
||||
|
||||
for (genvar i = 0; i < `ISSUE_WIDTH; ++i) begin
|
||||
assign dispatch_valid[i] = dispatch_if[i].valid;
|
||||
assign dispatch_data[i] = dispatch_if[i].data;
|
||||
assign dispatch_if[i].ready = dispatch_ready[i];
|
||||
end
|
||||
|
||||
wire [BLOCK_SIZE-1:0][ISSUE_W-1:0] issue_indices;
|
||||
wire [BLOCK_SIZE-1:0] block_ready;
|
||||
wire [BLOCK_SIZE-1:0][NUM_LANES-1:0] block_tmask;
|
||||
wire [BLOCK_SIZE-1:0][2:0][NUM_LANES-1:0][`XLEN-1:0] block_regs;
|
||||
wire [BLOCK_SIZE-1:0][PID_WIDTH-1:0] block_pid;
|
||||
wire [BLOCK_SIZE-1:0] block_sop;
|
||||
wire [BLOCK_SIZE-1:0] block_eop;
|
||||
wire [BLOCK_SIZE-1:0] block_done;
|
||||
|
||||
wire batch_done = (& block_done);
|
||||
|
||||
logic [BATCH_COUNT_W-1:0] batch_idx;
|
||||
// if (BATCH_COUNT != 1) begin
|
||||
// always @(posedge clk) begin
|
||||
// if (reset) begin
|
||||
// batch_idx <= '0;
|
||||
// end else begin
|
||||
// batch_idx <= batch_idx + BATCH_COUNT_W'(batch_done);
|
||||
// end
|
||||
// end
|
||||
// end else begin
|
||||
// assign batch_idx = 0;
|
||||
// `UNUSED_VAR(batch_done)
|
||||
// end
|
||||
|
||||
// group dispatch_valid by blocks
|
||||
wire [BATCH_COUNT-1:0] batch_valids;
|
||||
for (genvar i = 0; i < BATCH_COUNT; ++i) begin
|
||||
assign batch_valids[i] = |(dispatch_valid[(BLOCK_SIZE * i) +: BLOCK_SIZE]);
|
||||
end
|
||||
|
||||
// elect the leftmost-valid batch for the dispatch
|
||||
wire dispatch_any_valid;
|
||||
VX_lzc_rr #(
|
||||
.N (BATCH_COUNT)
|
||||
) batch_select (
|
||||
.clk (clk),
|
||||
.reset (reset),
|
||||
.data_in (batch_valids),
|
||||
.data_out (batch_idx),
|
||||
.valid_out (dispatch_any_valid)
|
||||
);
|
||||
|
||||
for (genvar block_idx = 0; block_idx < BLOCK_SIZE; ++block_idx) begin
|
||||
|
||||
wire [ISSUE_W-1:0] issue_idx = ISSUE_W'(batch_idx * BLOCK_SIZE) + ISSUE_W'(block_idx);
|
||||
assign issue_indices[block_idx] = issue_idx;
|
||||
|
||||
wire valid_p, ready_p;
|
||||
|
||||
if (`NUM_THREADS != NUM_LANES) begin
|
||||
reg [NUM_PACKETS-1:0] sent_mask_p;
|
||||
wire [PID_WIDTH-1:0] start_p_n, start_p, end_p;
|
||||
wire dispatch_valid_r;
|
||||
reg is_first_p;
|
||||
|
||||
wire fire_p = valid_p && ready_p;
|
||||
|
||||
wire is_last_p = (start_p == end_p);
|
||||
|
||||
wire fire_eop = fire_p && is_last_p;
|
||||
|
||||
always @(posedge clk) begin
|
||||
if (reset) begin
|
||||
sent_mask_p <= '0;
|
||||
is_first_p <= 1;
|
||||
end else begin
|
||||
if ((BATCH_COUNT != 1) ? batch_done : fire_eop) begin
|
||||
sent_mask_p <= '0;
|
||||
is_first_p <= 1;
|
||||
end else if (fire_p) begin
|
||||
sent_mask_p[start_p] <= 1;
|
||||
is_first_p <= 0;
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
wire [NUM_PACKETS-1:0][NUM_LANES-1:0] per_packet_tmask;
|
||||
wire [NUM_PACKETS-1:0][2:0][NUM_LANES-1:0][`XLEN-1:0] per_packet_regs;
|
||||
|
||||
wire [`NUM_THREADS-1:0] dispatch_tmask = dispatch_data[issue_idx][DATA_TMASK_OFF +: `NUM_THREADS];
|
||||
wire [`NUM_THREADS-1:0][`XLEN-1:0] dispatch_rs1_data = dispatch_data[issue_idx][DATA_REGS_OFF + 2 * `NUM_THREADS * `XLEN +: `NUM_THREADS * `XLEN];
|
||||
wire [`NUM_THREADS-1:0][`XLEN-1:0] dispatch_rs2_data = dispatch_data[issue_idx][DATA_REGS_OFF + 1 * `NUM_THREADS * `XLEN +: `NUM_THREADS * `XLEN];
|
||||
wire [`NUM_THREADS-1:0][`XLEN-1:0] dispatch_rs3_data = dispatch_data[issue_idx][DATA_REGS_OFF + 0 * `NUM_THREADS * `XLEN +: `NUM_THREADS * `XLEN];
|
||||
|
||||
for (genvar i = 0; i < NUM_PACKETS; ++i) begin
|
||||
for (genvar j = 0; j < NUM_LANES; ++j) begin
|
||||
localparam k = i * NUM_LANES + j;
|
||||
assign per_packet_tmask[i][j] = dispatch_tmask[k];
|
||||
assign per_packet_regs[i][0][j] = dispatch_rs1_data[k];
|
||||
assign per_packet_regs[i][1][j] = dispatch_rs2_data[k];
|
||||
assign per_packet_regs[i][2][j] = dispatch_rs3_data[k];
|
||||
end
|
||||
end
|
||||
|
||||
wire [NUM_PACKETS-1:0] packet_valids;
|
||||
wire [NUM_PACKETS-1:0][PID_WIDTH-1:0] packet_ids;
|
||||
|
||||
for (genvar i = 0; i < NUM_PACKETS; ++i) begin
|
||||
assign packet_valids[i] = (| per_packet_tmask[i]);
|
||||
assign packet_ids[i] = PID_WIDTH'(i);
|
||||
end
|
||||
|
||||
VX_find_first #(
|
||||
.N (NUM_PACKETS),
|
||||
.DATAW (PID_WIDTH),
|
||||
.REVERSE (0)
|
||||
) find_first (
|
||||
.valid_in (packet_valids & ~sent_mask_p),
|
||||
.data_in (packet_ids),
|
||||
.data_out (start_p_n),
|
||||
`UNUSED_PIN (valid_out)
|
||||
);
|
||||
|
||||
VX_find_first #(
|
||||
.N (NUM_PACKETS),
|
||||
.DATAW (PID_WIDTH),
|
||||
.REVERSE (1)
|
||||
) find_last (
|
||||
.valid_in (packet_valids),
|
||||
.data_in (packet_ids),
|
||||
.data_out (end_p),
|
||||
`UNUSED_PIN (valid_out)
|
||||
);
|
||||
|
||||
VX_pipe_register #(
|
||||
.DATAW (1 + PID_WIDTH),
|
||||
.RESETW (1),
|
||||
.DEPTH (FANOUT_ENABLE ? 1 : 0)
|
||||
) pipe_reg (
|
||||
.clk (clk),
|
||||
.reset (reset || fire_p), // should flush on fire
|
||||
.enable (1'b1),
|
||||
.data_in ({dispatch_valid[issue_idx], start_p_n}),
|
||||
.data_out ({dispatch_valid_r, start_p})
|
||||
);
|
||||
|
||||
wire [NUM_LANES-1:0] tmask_p = per_packet_tmask[start_p];
|
||||
wire [2:0][NUM_LANES-1:0][`XLEN-1:0] regs_p = per_packet_regs[start_p];
|
||||
|
||||
wire block_enable = (BATCH_COUNT == 1 || ~(& sent_mask_p));
|
||||
|
||||
assign valid_p = dispatch_valid_r && block_enable;
|
||||
assign block_tmask[block_idx] = tmask_p;
|
||||
assign block_regs[block_idx] = regs_p;
|
||||
assign block_pid[block_idx] = start_p;
|
||||
assign block_sop[block_idx] = is_first_p;
|
||||
assign block_eop[block_idx] = is_last_p;
|
||||
if (FANOUT_ENABLE) begin
|
||||
assign block_ready[block_idx] = dispatch_valid_r && ready_p && block_enable;
|
||||
end else begin
|
||||
assign block_ready[block_idx] = ready_p && block_enable;
|
||||
end
|
||||
assign block_done[block_idx] = ~dispatch_valid[issue_idx] || fire_eop;
|
||||
end else begin
|
||||
assign valid_p = dispatch_valid[issue_idx];
|
||||
assign block_tmask[block_idx] = dispatch_data[issue_idx][DATA_TMASK_OFF +: `NUM_THREADS];
|
||||
assign block_regs[block_idx][0] = dispatch_data[issue_idx][DATA_REGS_OFF + 2 * `NUM_THREADS * `XLEN +: `NUM_THREADS * `XLEN];
|
||||
assign block_regs[block_idx][1] = dispatch_data[issue_idx][DATA_REGS_OFF + 1 * `NUM_THREADS * `XLEN +: `NUM_THREADS * `XLEN];
|
||||
assign block_regs[block_idx][2] = dispatch_data[issue_idx][DATA_REGS_OFF + 0 * `NUM_THREADS * `XLEN +: `NUM_THREADS * `XLEN];
|
||||
assign block_pid[block_idx] = '0;
|
||||
assign block_sop[block_idx] = 1'b1;
|
||||
assign block_eop[block_idx] = 1'b1;
|
||||
assign block_ready[block_idx] = ready_p;
|
||||
assign block_done[block_idx] = ~valid_p || ready_p;
|
||||
end
|
||||
|
||||
wire [ISSUE_ISW_W-1:0] isw;
|
||||
if (BATCH_COUNT != 1) begin
|
||||
if (BLOCK_SIZE != 1) begin
|
||||
assign isw = {batch_idx, BLOCK_SIZE_W'(block_idx)};
|
||||
end else begin
|
||||
assign isw = batch_idx;
|
||||
end
|
||||
end else begin
|
||||
assign isw = block_idx;
|
||||
end
|
||||
|
||||
`RESET_RELAY(buf_out_reset, reset);
|
||||
|
||||
wire [`NW_WIDTH-1:0] block_wid = wis_to_wid(dispatch_data[issue_idx][DATA_TMASK_OFF+`NUM_THREADS +: ISSUE_WIS_W], isw);
|
||||
|
||||
VX_elastic_buffer #(
|
||||
.DATAW (OUT_DATAW),
|
||||
.SIZE (`OUT_REG_TO_EB_SIZE(OUT_REG)),
|
||||
.OUT_REG (`OUT_REG_TO_EB_REG(OUT_REG))
|
||||
) buf_out (
|
||||
.clk (clk),
|
||||
.reset (buf_out_reset),
|
||||
.valid_in (valid_p),
|
||||
.ready_in (ready_p),
|
||||
.data_in ({
|
||||
dispatch_data[issue_idx][IN_DATAW-1 : DATA_TMASK_OFF+`NUM_THREADS+ISSUE_WIS_W],
|
||||
block_wid,
|
||||
block_tmask[block_idx],
|
||||
dispatch_data[issue_idx][DATA_TMASK_OFF-1 : DATA_REGS_OFF + 3 * `NUM_THREADS * `XLEN],
|
||||
block_regs[block_idx][0],
|
||||
block_regs[block_idx][1],
|
||||
block_regs[block_idx][2],
|
||||
block_pid[block_idx],
|
||||
block_sop[block_idx],
|
||||
block_eop[block_idx]}),
|
||||
.data_out (execute_if[block_idx].data),
|
||||
.valid_out (execute_if[block_idx].valid),
|
||||
.ready_out (execute_if[block_idx].ready)
|
||||
);
|
||||
end
|
||||
|
||||
reg [`ISSUE_WIDTH-1:0] ready_in;
|
||||
always @(*) begin
|
||||
ready_in = 0;
|
||||
for (integer i = 0; i < BLOCK_SIZE; ++i) begin
|
||||
ready_in[issue_indices[i]] = block_ready[i] && block_eop[i];
|
||||
end
|
||||
end
|
||||
assign dispatch_ready = ready_in;
|
||||
|
||||
endmodule
|
||||
@@ -39,7 +39,7 @@ module VX_fpu_unit import VX_fpu_pkg::*; #(
|
||||
|
||||
`RESET_RELAY (dispatch_reset, reset);
|
||||
|
||||
VX_dispatch_unit #(
|
||||
VX_dispatch_unit_sane #(
|
||||
.BLOCK_SIZE (BLOCK_SIZE),
|
||||
.NUM_LANES (NUM_LANES),
|
||||
.OUT_REG (PARTIAL_BW ? 1 : 0)
|
||||
|
||||
@@ -49,7 +49,7 @@ module VX_lsu_unit import VX_gpu_pkg::*; #(
|
||||
|
||||
`RESET_RELAY (dispatch_reset, reset);
|
||||
|
||||
VX_dispatch_unit #(
|
||||
VX_dispatch_unit_sane #(
|
||||
.BLOCK_SIZE (BLOCK_SIZE),
|
||||
.NUM_LANES (NUM_LANES),
|
||||
.OUT_REG (1)
|
||||
@@ -596,6 +596,31 @@ module VX_lsu_unit import VX_gpu_pkg::*; #(
|
||||
.commit_out_if (commit_if)
|
||||
);
|
||||
|
||||
`ifdef PERF_ENABLE
|
||||
wire [`CLOG2(NUM_LANES+1)-1:0] perf_rsp_tmask_valids_per_cycle;
|
||||
wire [`CLOG2(NUM_LANES+1)-1:0] perf_rsp_tmask_total_per_cycle;
|
||||
reg [`PERF_CTR_BITS-1:0] perf_rsp_tmask_valids;
|
||||
reg [`PERF_CTR_BITS-1:0] perf_rsp_tmask_total;
|
||||
reg [`PERF_CTR_BITS-1:0] perf_rsp_fires;
|
||||
|
||||
`POP_COUNT(perf_rsp_tmask_valids_per_cycle, rsp_tmask);
|
||||
assign perf_rsp_tmask_total_per_cycle = NUM_LANES;
|
||||
|
||||
always @(posedge clk) begin
|
||||
if (reset) begin
|
||||
perf_rsp_tmask_valids <= '0;
|
||||
perf_rsp_tmask_total <= '0;
|
||||
perf_rsp_fires <= '0;
|
||||
end else begin
|
||||
if (mem_rsp_fire) begin
|
||||
perf_rsp_tmask_valids <= perf_rsp_tmask_valids + perf_rsp_tmask_valids_per_cycle;
|
||||
perf_rsp_tmask_total <= perf_rsp_tmask_total + perf_rsp_tmask_total_per_cycle;
|
||||
perf_rsp_fires <= perf_rsp_fires + 1'b1;
|
||||
end
|
||||
end
|
||||
end
|
||||
`endif
|
||||
|
||||
`ifdef DBG_SCOPE_LSU
|
||||
if (CORE_ID == 0) begin
|
||||
`ifdef SCOPE
|
||||
|
||||
@@ -66,6 +66,7 @@ module VX_smem_unit import VX_gpu_pkg::*; #(
|
||||
.req_valid (smem_req_valid),
|
||||
.req_rw (smem_req_rw),
|
||||
.req_byteen (smem_req_byteen),
|
||||
// FIXME: synthesis complains undriven when USE_EXTERNAL_SMEM
|
||||
.req_addr (smem_req_addr),
|
||||
.req_data (smem_req_data),
|
||||
.req_tag (smem_req_tag),
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
`ifdef EXT_T_ENABLE
|
||||
`include "VX_fpu_define.vh"
|
||||
|
||||
module VX_tensor_core #(
|
||||
module VX_tensor_core import VX_gpu_pkg::*; #(
|
||||
|
||||
) (
|
||||
input clk,
|
||||
@@ -10,17 +10,54 @@ module VX_tensor_core #(
|
||||
VX_dispatch_if.slave dispatch_if [`ISSUE_WIDTH],
|
||||
VX_commit_if.master commit_if [`ISSUE_WIDTH]
|
||||
);
|
||||
`STATIC_ASSERT(`NUM_THREADS == 32, ("tensor core requires # of threads in a warp to be 32 (try running w/ CONFIGS=\"-DNUM_THREADS=32\")"));
|
||||
|
||||
for (genvar i = 0; i < `ISSUE_WIDTH; ++i) begin
|
||||
localparam BLOCK_SIZE = 1;
|
||||
localparam NUM_LANES = `NUM_THREADS;
|
||||
// localparam PARTIAL_BW = (BLOCK_SIZE != `ISSUE_WIDTH) || (NUM_LANES != `NUM_THREADS);
|
||||
localparam PARTIAL_BW = 1;
|
||||
|
||||
VX_execute_if #(
|
||||
.NUM_LANES (NUM_LANES)
|
||||
) execute_if[BLOCK_SIZE]();
|
||||
|
||||
`RESET_RELAY (dispatch_reset, reset);
|
||||
|
||||
VX_dispatch_unit_sane #(
|
||||
.BLOCK_SIZE (BLOCK_SIZE),
|
||||
.NUM_LANES (NUM_LANES),
|
||||
.OUT_REG (PARTIAL_BW ? 1 : 0)
|
||||
) dispatch_unit (
|
||||
.clk (clk),
|
||||
.reset (dispatch_reset),
|
||||
.dispatch_if(dispatch_if),
|
||||
.execute_if (execute_if)
|
||||
);
|
||||
|
||||
VX_commit_if #(
|
||||
.NUM_LANES (NUM_LANES)
|
||||
) commit_block_if[BLOCK_SIZE]();
|
||||
|
||||
`RESET_RELAY (commit_reset, reset);
|
||||
|
||||
VX_gather_unit #(
|
||||
.BLOCK_SIZE (BLOCK_SIZE),
|
||||
.NUM_LANES (NUM_LANES),
|
||||
.OUT_REG (PARTIAL_BW ? 3 : 0) // FIXME: why 3?
|
||||
) gather_unit (
|
||||
.clk (clk),
|
||||
.reset (commit_reset),
|
||||
.commit_in_if (commit_block_if),
|
||||
.commit_out_if (commit_if)
|
||||
);
|
||||
|
||||
for (genvar block_idx = 0; block_idx < BLOCK_SIZE; ++block_idx) begin
|
||||
VX_tensor_core_warp #(
|
||||
.ISW(i)
|
||||
.ISW(1) // FIXME: not block_idx
|
||||
) tensor_core (
|
||||
.clk(clk),
|
||||
.reset(reset),
|
||||
|
||||
.dispatch_if(dispatch_if[i]),
|
||||
.commit_if(commit_if[i])
|
||||
.execute_if(execute_if[block_idx]),
|
||||
.commit_if(commit_block_if[block_idx])
|
||||
);
|
||||
end
|
||||
|
||||
@@ -32,37 +69,53 @@ module VX_tensor_core_warp import VX_gpu_pkg::*; #(
|
||||
input clk,
|
||||
input reset,
|
||||
|
||||
VX_dispatch_if.slave dispatch_if,
|
||||
VX_execute_if.slave execute_if,
|
||||
VX_commit_if.master commit_if
|
||||
);
|
||||
wire [1:0] step = 2'(dispatch_if.data.op_type);
|
||||
logic [3:0] octet_results_valid;
|
||||
logic [3:0] octet_results_ready;
|
||||
logic [3:0] octet_operands_ready;
|
||||
localparam NUM_OCTETS = (`NUM_THREADS / 8);
|
||||
// offet in the lane numbers that get mapped to the two threadgroups in an
|
||||
// octet. E.g. two tgs map lane 0-3 and lane 16-19 -> 16
|
||||
// FIXME: not sure this is the right logic. just filling in what works
|
||||
localparam LANE_OFFSET_THREADGROUP = (4 * NUM_OCTETS);
|
||||
// this is only a rule of thumb
|
||||
localparam METADATA_QUEUE_DEPTH = 2 * `LATENCY_HMMA;
|
||||
|
||||
wire [1:0] step = 2'(execute_if.data.op_type);
|
||||
// op_mod is reused to indicate instruction's id in pair
|
||||
wire last_in_pair = (execute_if.data.op_mod == `INST_MOD_BITS'(1));
|
||||
|
||||
logic [NUM_OCTETS-1:0] octet_results_valid;
|
||||
logic [NUM_OCTETS-1:0] octet_results_ready;
|
||||
logic [NUM_OCTETS-1:0] octet_operands_ready;
|
||||
// FIXME: should be NUM_LANES?
|
||||
logic [`NUM_THREADS-1:0][`XLEN-1:0] wb_data_0;
|
||||
logic [`NUM_THREADS-1:0][`XLEN-1:0] wb_data_1;
|
||||
|
||||
assign dispatch_if.ready = &octet_operands_ready;
|
||||
wire [`NW_WIDTH-1:0] wb_wid;
|
||||
|
||||
// valid signal synced between the functional units (octet) and the
|
||||
// metadata queue
|
||||
wire operands_valid_synced;
|
||||
|
||||
`ifdef EXT_T_ENABLE
|
||||
for (genvar i = 0; i < 4/*octets*/; ++i) begin
|
||||
for (genvar i = 0; i < NUM_OCTETS; ++i) begin
|
||||
`else
|
||||
for (genvar i = 0; i < 0; ++i) begin
|
||||
`endif
|
||||
// lane-to-octet mapping; see figure 13 of the paper
|
||||
wire [7:0][31:0] octet_A = {
|
||||
dispatch_if.data.rs1_data[16+4*i +: 4], dispatch_if.data.rs1_data[4*i +: 4]
|
||||
execute_if.data.rs1_data[LANE_OFFSET_THREADGROUP + 4*i +: 4], execute_if.data.rs1_data[4*i +: 4]
|
||||
};
|
||||
wire [7:0][31:0] octet_B = {
|
||||
dispatch_if.data.rs2_data[16+4*i +: 4], dispatch_if.data.rs2_data[4*i +: 4]
|
||||
execute_if.data.rs2_data[LANE_OFFSET_THREADGROUP + 4*i +: 4], execute_if.data.rs2_data[4*i +: 4]
|
||||
};
|
||||
wire [7:0][31:0] octet_C = {
|
||||
dispatch_if.data.rs3_data[16+4*i +: 4], dispatch_if.data.rs3_data[4*i +: 4]
|
||||
execute_if.data.rs3_data[LANE_OFFSET_THREADGROUP + 4*i +: 4], execute_if.data.rs3_data[4*i +: 4]
|
||||
};
|
||||
|
||||
logic [3:0][3:0][31:0] octet_D;
|
||||
logic result_valid;
|
||||
logic result_ready;
|
||||
|
||||
VX_tensor_octet #(
|
||||
.ISW(ISW),
|
||||
.OCTET(i)
|
||||
@@ -73,12 +126,14 @@ module VX_tensor_core_warp import VX_gpu_pkg::*; #(
|
||||
.A_in(octet_A),
|
||||
.B_in(octet_B),
|
||||
.C_in(octet_C),
|
||||
.operands_valid(dispatch_if.valid),
|
||||
.operands_valid(operands_valid_synced),
|
||||
.operands_wid(execute_if.data.wid),
|
||||
.operands_last_in_pair(last_in_pair),
|
||||
.operands_step(step),
|
||||
.operands_ready(octet_operands_ready[i]),
|
||||
|
||||
.step(step),
|
||||
|
||||
.D_out(octet_D),
|
||||
.D_wid(wb_wid),
|
||||
.result_valid(result_valid),
|
||||
.result_ready(result_ready)
|
||||
);
|
||||
@@ -100,15 +155,15 @@ module VX_tensor_core_warp import VX_gpu_pkg::*; #(
|
||||
assign wb_data_1[4*i+2] = octet_D[0][3];
|
||||
assign wb_data_1[4*i+3] = octet_D[1][3];
|
||||
|
||||
assign wb_data_0[4*i+16+0] = octet_D[2][0];
|
||||
assign wb_data_0[4*i+16+1] = octet_D[3][0];
|
||||
assign wb_data_0[4*i+16+2] = octet_D[2][2];
|
||||
assign wb_data_0[4*i+16+3] = octet_D[3][2];
|
||||
assign wb_data_0[4*i+LANE_OFFSET_THREADGROUP+0] = octet_D[2][0];
|
||||
assign wb_data_0[4*i+LANE_OFFSET_THREADGROUP+1] = octet_D[3][0];
|
||||
assign wb_data_0[4*i+LANE_OFFSET_THREADGROUP+2] = octet_D[2][2];
|
||||
assign wb_data_0[4*i+LANE_OFFSET_THREADGROUP+3] = octet_D[3][2];
|
||||
|
||||
assign wb_data_1[4*i+16+0] = octet_D[2][1];
|
||||
assign wb_data_1[4*i+16+1] = octet_D[3][1];
|
||||
assign wb_data_1[4*i+16+2] = octet_D[2][3];
|
||||
assign wb_data_1[4*i+16+3] = octet_D[3][3];
|
||||
assign wb_data_1[4*i+LANE_OFFSET_THREADGROUP+0] = octet_D[2][1];
|
||||
assign wb_data_1[4*i+LANE_OFFSET_THREADGROUP+1] = octet_D[3][1];
|
||||
assign wb_data_1[4*i+LANE_OFFSET_THREADGROUP+2] = octet_D[2][3];
|
||||
assign wb_data_1[4*i+LANE_OFFSET_THREADGROUP+3] = octet_D[3][3];
|
||||
end
|
||||
|
||||
/* commit_if.data_t parts that we need to keep around:
|
||||
@@ -122,44 +177,95 @@ module VX_tensor_core_warp import VX_gpu_pkg::*; #(
|
||||
|
||||
localparam DATAW = `UUID_WIDTH + `NW_WIDTH + `NUM_THREADS + `XLEN + 1 + `NR_BITS;
|
||||
|
||||
wire dispatch_if_fire = dispatch_if.valid && dispatch_if.ready;
|
||||
wire commit_if_fire = commit_if.valid && commit_if.ready;
|
||||
wire [DATAW-1:0] dispatch_if_data_enq = {
|
||||
dispatch_if.data.uuid,
|
||||
wis_to_wid(dispatch_if.data.wis, ISW),
|
||||
dispatch_if.data.tmask,
|
||||
dispatch_if.data.PC,
|
||||
dispatch_if.data.wb,
|
||||
dispatch_if.data.rd
|
||||
wire operand_enq_fire = operands_valid_synced && execute_if.ready;
|
||||
wire commit_if_ready_override;
|
||||
wire commit_if_fire = commit_if.valid && commit_if_ready_override;
|
||||
wire [DATAW-1:0] execute_if_data_enq = {
|
||||
execute_if.data.uuid,
|
||||
execute_if.data.wid,
|
||||
execute_if.data.tmask,
|
||||
execute_if.data.PC,
|
||||
execute_if.data.wb,
|
||||
execute_if.data.rd
|
||||
// pid/sop/eop set later
|
||||
};
|
||||
|
||||
wire [DATAW-1:0] dispatch_if_data_deq;
|
||||
wire [`NUM_WARPS-1:0][DATAW-1:0] execute_if_data_deq;
|
||||
|
||||
// this is probably a little oversized
|
||||
VX_fifo_queue #(
|
||||
.DATAW(DATAW),
|
||||
.DEPTH(16)
|
||||
) pending_uops (
|
||||
.clk(clk),
|
||||
.reset(reset),
|
||||
.push(dispatch_if_fire),
|
||||
.pop(commit_if_fire),
|
||||
.data_in(dispatch_if_data_enq),
|
||||
.data_out(dispatch_if_data_deq),
|
||||
`UNUSED_PIN(empty),
|
||||
`UNUSED_PIN(alm_empty),
|
||||
`UNUSED_PIN(full), // should be impossible to overflow
|
||||
`UNUSED_PIN(alm_full),
|
||||
`UNUSED_PIN(size)
|
||||
);
|
||||
wire [`NUM_WARPS-1:0] metadata_queue_fulls;
|
||||
// OR not AND, we don't want any warp full
|
||||
wire metadata_queue_full = |(metadata_queue_fulls);
|
||||
|
||||
// need to make sure both metadata and octet issue queues are in sync
|
||||
assign operands_valid_synced = execute_if.valid && !metadata_queue_full;
|
||||
assign execute_if.ready = &(octet_operands_ready) && !metadata_queue_full;
|
||||
|
||||
for (genvar i = 0; i < `NUM_WARPS; i++) begin
|
||||
// Metadata queue for commit_if. This simply copies execute_if's
|
||||
// metadata and pops them in conjunction with commit fire.
|
||||
//
|
||||
// This has to be separated per-warp, as otherwise requests from
|
||||
// multiple warps can be enqueued interleaved, which makes it hard to
|
||||
// ensure two consecutive dequeues are associated with the same warp for
|
||||
// commit. (FIXME: this is not strictly necessary though.)
|
||||
|
||||
wire enq = operand_enq_fire && (execute_if.data.wid == `NW_WIDTH'(i));
|
||||
wire deq = commit_if_fire && ( wb_wid == `NW_WIDTH'(i));
|
||||
|
||||
VX_fifo_queue #(
|
||||
.DATAW(DATAW),
|
||||
.DEPTH(METADATA_QUEUE_DEPTH)
|
||||
) pending_uops (
|
||||
.clk(clk),
|
||||
.reset(reset),
|
||||
.push(enq),
|
||||
.pop(deq),
|
||||
.data_in(execute_if_data_enq),
|
||||
.data_out(execute_if_data_deq[i]),
|
||||
`UNUSED_PIN(empty),
|
||||
`UNUSED_PIN(alm_empty),
|
||||
.full(metadata_queue_fulls[i]),
|
||||
`UNUSED_PIN(alm_full),
|
||||
`UNUSED_PIN(size)
|
||||
);
|
||||
end
|
||||
|
||||
// this shouldn't really happen unless there's a big contention over
|
||||
// the commit stage
|
||||
`RUNTIME_ASSERT(!(!reset && metadata_queue_full), ("tensor core uop queue is full!"))
|
||||
|
||||
// unlike execute which can be interleaved between warps, commit is
|
||||
// serialized and completed one-warp-by-warp, therefore we only need to
|
||||
// keep one subcommit state bit unlike for `substeps`
|
||||
logic subcommit, subcommit_n;
|
||||
|
||||
wire all_valid = (& octet_results_valid);
|
||||
|
||||
// define this to inject artificial commit backpressure for debugging
|
||||
// `define TENSOR_INJECT_COMMIT_BACKPRESSURE
|
||||
`ifndef TENSOR_INJECT_COMMIT_BACKPRESSURE
|
||||
assign commit_if.valid = all_valid;
|
||||
assign commit_if_ready_override = commit_if.ready;
|
||||
`else
|
||||
logic [1:0] counter;
|
||||
always @(posedge clk) begin
|
||||
if (reset) begin
|
||||
counter <= '0;
|
||||
end else begin
|
||||
if (all_valid) begin
|
||||
counter <= counter + 1'b1;
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
assign commit_if.valid = all_valid && (counter == 2'b0);
|
||||
assign commit_if_ready_override = commit_if.ready && (counter == 2'b0);
|
||||
`endif
|
||||
|
||||
localparam COMMIT_DATAW = `UUID_WIDTH + `NW_WIDTH + `NUM_THREADS + `XLEN + 1 + `NR_BITS + (`NUM_THREADS * `XLEN) + 1 + 1 + 1;
|
||||
wire [COMMIT_DATAW-1:0] commit_if_data = {
|
||||
dispatch_if_data_deq, /* uuid ~ rd */
|
||||
execute_if_data_deq[wb_wid], /* uuid ~ rd */
|
||||
// execute_if_data_deq, /* uuid ~ rd */
|
||||
subcommit == 1'b0 ? wb_data_0 : wb_data_1, /* data */
|
||||
1'b0, /* pid */
|
||||
1'b1, /* sop */
|
||||
@@ -196,22 +302,25 @@ module VX_tensor_octet #(
|
||||
input clk,
|
||||
input reset,
|
||||
|
||||
input [7:0][31:0] A_in,
|
||||
input [7:0][31:0] B_in,
|
||||
input [7:0][31:0] C_in,
|
||||
input operands_valid, // we have to backpressure due to there potentially being contention over commit
|
||||
output operands_ready,
|
||||
|
||||
input [1:0] step,
|
||||
input [7:0][31:0] A_in,
|
||||
input [7:0][31:0] B_in,
|
||||
input [7:0][31:0] C_in,
|
||||
input operands_valid,
|
||||
input [`NW_WIDTH-1:0] operands_wid,
|
||||
input operands_last_in_pair,
|
||||
input [1:0] operands_step,
|
||||
// we have to backpressure due to there potentially being contention over commit
|
||||
output operands_ready,
|
||||
|
||||
output [3:0][3:0][31:0] D_out,
|
||||
output [`NW_WIDTH-1:0] D_wid,
|
||||
output result_valid,
|
||||
input result_ready
|
||||
);
|
||||
// 512 bits/octet * 4 octets per warp
|
||||
logic [3:0][31:0] A_buffer, A_buffer_n;
|
||||
logic [3:0][31:0] B_buffer, B_buffer_n;
|
||||
logic [7:0][31:0] C_buffer, C_buffer_n;
|
||||
logic [`NUM_WARPS-1:0][3:0][31:0] A_buffer, A_buffer_n;
|
||||
logic [`NUM_WARPS-1:0][3:0][31:0] B_buffer, B_buffer_n;
|
||||
logic [`NUM_WARPS-1:0][7:0][31:0] C_buffer, C_buffer_n;
|
||||
|
||||
// half the inputs are buffered, half are not (instead coming straight
|
||||
// from operand bus) unlike the real tensor core.
|
||||
@@ -219,41 +328,95 @@ module VX_tensor_octet #(
|
||||
logic [3:0][31:0] A_half;
|
||||
logic [3:0][31:0] B_half;
|
||||
logic [7:0][31:0] C_half;
|
||||
always @(*) begin
|
||||
logic [3:0][31:0] A_half_buf;
|
||||
logic [3:0][31:0] B_half_buf;
|
||||
logic [7:0][31:0] C_half_buf;
|
||||
|
||||
|
||||
logic [`NUM_WARPS-1:0] substeps;
|
||||
logic [`NUM_WARPS-1:0] substeps_n;
|
||||
|
||||
wire [7:0][31:0] A_in_buf;
|
||||
wire [7:0][31:0] B_in_buf;
|
||||
wire [7:0][31:0] C_in_buf;
|
||||
wire operands_valid_buf;
|
||||
wire operands_ready_buf;
|
||||
wire [`NW_WIDTH-1:0] operands_wid_buf;
|
||||
wire operands_last_in_pair_buf;
|
||||
wire [1:0] operands_step_buf;
|
||||
|
||||
assign A_in_buf = A_in;
|
||||
assign B_in_buf = B_in;
|
||||
assign C_in_buf = C_in;
|
||||
assign operands_step_buf = operands_step;
|
||||
assign operands_wid_buf = operands_wid;
|
||||
assign operands_last_in_pair_buf = operands_last_in_pair;
|
||||
assign operands_valid_buf = operands_valid;
|
||||
assign operands_ready = operands_ready_buf;
|
||||
|
||||
typedef struct {
|
||||
logic [3:0][31:0] A_half;
|
||||
logic [3:0][31:0] B_half;
|
||||
logic [7:0][31:0] C_half;
|
||||
} half_t;
|
||||
|
||||
function half_t get_operand_half(
|
||||
input logic [1:0] step,
|
||||
input logic [7:0][31:0] A_in,
|
||||
input logic [7:0][31:0] B_in,
|
||||
input logic [7:0][31:0] C_in
|
||||
);
|
||||
half_t half;
|
||||
// note that not all lanes participate at every step
|
||||
case (step)
|
||||
2'b00: begin
|
||||
A_half = { A_in[5:4], A_in[1:0] };
|
||||
B_half = B_in[3:0];
|
||||
// Two A_in segments correspond to two 2x2 subtiles of A read
|
||||
// by two threadgroups: [0:2,0:2] and [4:6,0:2] in Step 0 of
|
||||
// Figure 10(b). B_in OTOH is shared by two threadgroups.
|
||||
// Note k-dimension is shrunk from 4 to 2.
|
||||
half.A_half = { A_in[5:4], A_in[1:0] };
|
||||
half.B_half = B_in[3:0];
|
||||
end
|
||||
2'b01: begin
|
||||
A_half = { A_in[7:6], A_in[3:2] };
|
||||
B_half = B_in[3:0];
|
||||
half.A_half = { A_in[7:6], A_in[3:2] };
|
||||
half.B_half = B_in[3:0];
|
||||
end
|
||||
2'b10: begin
|
||||
A_half = { A_in[5:4], A_in[1:0] };
|
||||
B_half = B_in[7:4];
|
||||
half.A_half = { A_in[5:4], A_in[1:0] };
|
||||
half.B_half = B_in[7:4];
|
||||
end
|
||||
2'b11: begin
|
||||
A_half = { A_in[7:6], A_in[3:2] };
|
||||
B_half = B_in[7:4];
|
||||
half.A_half = { A_in[7:6], A_in[3:2] };
|
||||
half.B_half = B_in[7:4];
|
||||
end
|
||||
endcase
|
||||
C_half = C_in;
|
||||
end
|
||||
half.C_half = C_in;
|
||||
return half;
|
||||
endfunction
|
||||
|
||||
logic substep;
|
||||
wire substep_n = (operands_ready && operands_valid) ? ~substep : substep;
|
||||
half_t halves;
|
||||
half_t halves_buf;
|
||||
assign halves = get_operand_half(operands_step, A_in, B_in, C_in);
|
||||
assign halves_buf = get_operand_half(operands_step_buf, A_in_buf, B_in_buf, C_in_buf);
|
||||
|
||||
wire do_hmma = operands_ready_buf && operands_valid_buf && operands_last_in_pair_buf;
|
||||
// wire operands_first_in_pair_fire = operands_ready && operands_valid && (!operands_last_in_pair);
|
||||
wire operands_first_in_pair_fire = operands_ready_buf && operands_valid_buf && (!operands_last_in_pair_buf);
|
||||
|
||||
always @(*) begin
|
||||
A_buffer_n = A_buffer;
|
||||
B_buffer_n = B_buffer;
|
||||
C_buffer_n = C_buffer;
|
||||
substeps_n = substeps;
|
||||
|
||||
if (substep == 1'b0) begin
|
||||
A_buffer_n = A_half;
|
||||
B_buffer_n = B_half;
|
||||
C_buffer_n = C_half;
|
||||
if (operands_first_in_pair_fire) begin
|
||||
substeps_n[operands_wid_buf] = 1'b1; // ready for hmma
|
||||
A_buffer_n[operands_wid_buf] = halves_buf.A_half;
|
||||
B_buffer_n[operands_wid_buf] = halves_buf.B_half;
|
||||
C_buffer_n[operands_wid_buf] = halves_buf.C_half;
|
||||
end
|
||||
if (do_hmma) begin
|
||||
substeps_n[operands_wid_buf] = 1'b0; // finished hmma, ready for next operand
|
||||
end
|
||||
end
|
||||
|
||||
@@ -262,61 +425,113 @@ module VX_tensor_octet #(
|
||||
A_buffer <= '0;
|
||||
B_buffer <= '0;
|
||||
C_buffer <= '0;
|
||||
substep <= '0;
|
||||
substeps <= '0;
|
||||
end
|
||||
else begin
|
||||
A_buffer <= A_buffer_n;
|
||||
B_buffer <= B_buffer_n;
|
||||
C_buffer <= C_buffer_n;
|
||||
substep <= substep_n;
|
||||
substeps <= substeps_n;
|
||||
end
|
||||
end
|
||||
|
||||
wire stall = result_valid && ~result_ready;
|
||||
assign operands_ready = ~stall;
|
||||
wire outbuf_ready_in;
|
||||
wire hmma_ready;
|
||||
assign operands_ready_buf = hmma_ready;
|
||||
|
||||
// A is 4x2 fp32 matrix
|
||||
wire [3:0][1:0][31:0] A_tile = {
|
||||
{ A_half[3], A_buffer[3] },
|
||||
{ A_half[2], A_buffer[2] },
|
||||
{ A_half[1], A_buffer[1] },
|
||||
{ A_half[0], A_buffer[0] }
|
||||
{ halves_buf.A_half[3], A_buffer[operands_wid_buf][3] },
|
||||
{ halves_buf.A_half[2], A_buffer[operands_wid_buf][2] },
|
||||
{ halves_buf.A_half[1], A_buffer[operands_wid_buf][1] },
|
||||
{ halves_buf.A_half[0], A_buffer[operands_wid_buf][0] }
|
||||
};
|
||||
// B is 2x4 fp32 matrix
|
||||
wire [1:0][3:0][31:0] B_tile = {
|
||||
B_half, B_buffer
|
||||
halves_buf.B_half, B_buffer[operands_wid_buf]
|
||||
};
|
||||
// C is 4x4 fp32 matrix
|
||||
logic [3:0][3:0][31:0] C_tile;
|
||||
logic [3:0][3:0][31:0] D_tile;
|
||||
logic [`NW_WIDTH-1:0] D_wid_dpu;
|
||||
|
||||
always @(*) begin
|
||||
C_tile = {
|
||||
C_half[7], C_buffer[7], C_half[5], C_buffer[5],
|
||||
C_half[6], C_buffer[6], C_half[4], C_buffer[4],
|
||||
C_half[3], C_buffer[3], C_half[1], C_buffer[1],
|
||||
C_half[2], C_buffer[2], C_half[0], C_buffer[0]
|
||||
};
|
||||
C_tile[3] = { halves_buf.C_half[7], C_buffer[operands_wid_buf][7], halves_buf.C_half[5], C_buffer[operands_wid_buf][5] };
|
||||
C_tile[2] = { halves_buf.C_half[6], C_buffer[operands_wid_buf][6], halves_buf.C_half[4], C_buffer[operands_wid_buf][4] };
|
||||
C_tile[1] = { halves_buf.C_half[3], C_buffer[operands_wid_buf][3], halves_buf.C_half[1], C_buffer[operands_wid_buf][1] };
|
||||
C_tile[0] = { halves_buf.C_half[2], C_buffer[operands_wid_buf][2], halves_buf.C_half[0], C_buffer[operands_wid_buf][0] };
|
||||
end
|
||||
|
||||
wire do_hmma = (substep == 1'b1 && operands_valid && operands_ready);
|
||||
wire dpu_valid;
|
||||
|
||||
// this does (m,n,k)=(4,4,2) matmul, modeling compute of a single octet
|
||||
VX_tensor_dpu #(
|
||||
.ISW(ISW),
|
||||
.OCTET(OCTET)
|
||||
.OCTET(OCTET),
|
||||
.ISSUE_QUEUE_DEPTH(4 /*@perf: arbtirary*/)
|
||||
) dpu (
|
||||
.clk(clk),
|
||||
.reset(reset),
|
||||
|
||||
.stall(stall),
|
||||
|
||||
.valid_in(do_hmma),
|
||||
.ready_in(hmma_ready),
|
||||
.A_tile(A_tile),
|
||||
.B_tile(B_tile),
|
||||
.C_tile(C_tile),
|
||||
.wid(operands_wid_buf),
|
||||
|
||||
.valid_out(result_valid),
|
||||
.D_tile(D_out)
|
||||
.valid_out(dpu_valid),
|
||||
.ready_out(outbuf_ready_in),
|
||||
.D_tile(D_tile),
|
||||
.D_wid(D_wid_dpu)
|
||||
);
|
||||
|
||||
wire outbuf_empty;
|
||||
wire outbuf_full;
|
||||
// backpressure from commit
|
||||
assign outbuf_ready_in = ~outbuf_full;
|
||||
assign result_valid = ~outbuf_empty;
|
||||
|
||||
wire outbuf_enq = outbuf_ready_in && dpu_valid;
|
||||
wire outbuf_deq = result_valid && result_ready;
|
||||
|
||||
// buffer to stage the result D tile for 2 cycles until commit/writeback
|
||||
// is complete. This decouples the irregular dpu output traffic from the
|
||||
// regular, every-2-cycle commit traffic to ensure the commit pipeline is
|
||||
// used more efficiently.
|
||||
// FIXME: unnecessary?
|
||||
VX_fifo_queue #(
|
||||
.DATAW ($bits(D_wid) + $bits(D_out)),
|
||||
.DEPTH (2 /* arbitrary */)
|
||||
) output_buffer (
|
||||
.clk (clk),
|
||||
.reset (reset),
|
||||
.push (outbuf_enq),
|
||||
.pop (outbuf_deq),
|
||||
.data_in ({D_wid_dpu, D_tile}),
|
||||
.data_out ({D_wid, D_out}),
|
||||
.empty (outbuf_empty),
|
||||
`UNUSED_PIN(alm_empty),
|
||||
.full (outbuf_full), // should be impossible to overflow
|
||||
`UNUSED_PIN(alm_full),
|
||||
`UNUSED_PIN(size)
|
||||
);
|
||||
|
||||
// FIXME: this shouldn't be necessary
|
||||
`RUNTIME_ASSERT(reset || !outbuf_full, ("dpu result queue is full!"))
|
||||
|
||||
`ifdef PERF_ENABLE
|
||||
logic [`PERF_CTR_BITS-1:0] perf_tensor_dpu_total;
|
||||
|
||||
always @(posedge clk) begin
|
||||
if (reset) begin
|
||||
perf_tensor_dpu_total <= '0;
|
||||
end else begin
|
||||
if (do_hmma) begin
|
||||
perf_tensor_dpu_total <= perf_tensor_dpu_total + 2'd2;
|
||||
end
|
||||
end
|
||||
end
|
||||
`endif
|
||||
endmodule
|
||||
`endif
|
||||
|
||||
49
hw/rtl/core/VX_tensor_ucode_8lanes.vh
Normal file
49
hw/rtl/core/VX_tensor_ucode_8lanes.vh
Normal file
@@ -0,0 +1,49 @@
|
||||
// uop metadata (sequencing, next state), execution metadata (EX_TYPE, OP_TYPE, OP_MOD), wb, use pc, use imm, pc, imm, rd, rs1, rs2, rs3
|
||||
HMMA_SET0_STEP0_0: begin
|
||||
uop = {NEXT, HMMA_SET0_STEP0_1, `EX_BITS'(`EX_TENSOR), `INST_OP_BITS'(0), `INST_MOD_BITS'(0), 1'b1, 1'b0, 1'b0, 32'b0, 32'b0, `FREG(16), `FREG(0), `FREG(8), `FREG(16)};
|
||||
end
|
||||
HMMA_SET0_STEP0_1: begin
|
||||
uop = {NEXT, HMMA_SET0_STEP1_0, `EX_BITS'(`EX_TENSOR), `INST_OP_BITS'(0), `INST_MOD_BITS'(1), 1'b1, 1'b0, 1'b0, 32'b0, 32'b0, `FREG(17), `FREG(1), `FREG(9), `FREG(17)};
|
||||
end
|
||||
HMMA_SET0_STEP1_0: begin
|
||||
uop = {NEXT, HMMA_SET0_STEP1_1, `EX_BITS'(`EX_TENSOR), `INST_OP_BITS'(1), `INST_MOD_BITS'(0), 1'b1, 1'b0, 1'b0, 32'b0, 32'b0, `FREG(18), `FREG(0), `FREG(8), `FREG(18)};
|
||||
end
|
||||
HMMA_SET0_STEP1_1: begin
|
||||
uop = {NEXT, HMMA_SET0_STEP2_0, `EX_BITS'(`EX_TENSOR), `INST_OP_BITS'(1), `INST_MOD_BITS'(1), 1'b1, 1'b0, 1'b0, 32'b0, 32'b0, `FREG(19), `FREG(1), `FREG(9), `FREG(19)};
|
||||
end
|
||||
HMMA_SET0_STEP2_0: begin
|
||||
uop = {NEXT, HMMA_SET0_STEP2_1, `EX_BITS'(`EX_TENSOR), `INST_OP_BITS'(2), `INST_MOD_BITS'(0), 1'b1, 1'b0, 1'b0, 32'b0, 32'b0, `FREG(20), `FREG(0), `FREG(8), `FREG(20)};
|
||||
end
|
||||
HMMA_SET0_STEP2_1: begin
|
||||
uop = {NEXT, HMMA_SET0_STEP3_0, `EX_BITS'(`EX_TENSOR), `INST_OP_BITS'(2), `INST_MOD_BITS'(1), 1'b1, 1'b0, 1'b0, 32'b0, 32'b0, `FREG(21), `FREG(1), `FREG(9), `FREG(21)};
|
||||
end
|
||||
HMMA_SET0_STEP3_0: begin
|
||||
uop = {NEXT, HMMA_SET0_STEP3_1, `EX_BITS'(`EX_TENSOR), `INST_OP_BITS'(3), `INST_MOD_BITS'(0), 1'b1, 1'b0, 1'b0, 32'b0, 32'b0, `FREG(22), `FREG(0), `FREG(8), `FREG(22)};
|
||||
end
|
||||
HMMA_SET0_STEP3_1: begin
|
||||
uop = {NEXT, HMMA_SET1_STEP0_0, `EX_BITS'(`EX_TENSOR), `INST_OP_BITS'(3), `INST_MOD_BITS'(1), 1'b1, 1'b0, 1'b0, 32'b0, 32'b0, `FREG(23), `FREG(1), `FREG(9), `FREG(23)};
|
||||
end
|
||||
HMMA_SET1_STEP0_0: begin
|
||||
uop = {NEXT, HMMA_SET1_STEP0_1, `EX_BITS'(`EX_TENSOR), `INST_OP_BITS'(0), `INST_MOD_BITS'(0), 1'b1, 1'b0, 1'b0, 32'b0, 32'b0, `FREG(16), `FREG(2), `FREG(10), `FREG(16)};
|
||||
end
|
||||
HMMA_SET1_STEP0_1: begin
|
||||
uop = {NEXT, HMMA_SET1_STEP1_0, `EX_BITS'(`EX_TENSOR), `INST_OP_BITS'(0), `INST_MOD_BITS'(1), 1'b1, 1'b0, 1'b0, 32'b0, 32'b0, `FREG(17), `FREG(3), `FREG(11), `FREG(17)};
|
||||
end
|
||||
HMMA_SET1_STEP1_0: begin
|
||||
uop = {NEXT, HMMA_SET1_STEP1_1, `EX_BITS'(`EX_TENSOR), `INST_OP_BITS'(1), `INST_MOD_BITS'(0), 1'b1, 1'b0, 1'b0, 32'b0, 32'b0, `FREG(18), `FREG(2), `FREG(10), `FREG(18)};
|
||||
end
|
||||
HMMA_SET1_STEP1_1: begin
|
||||
uop = {NEXT, HMMA_SET1_STEP2_0, `EX_BITS'(`EX_TENSOR), `INST_OP_BITS'(1), `INST_MOD_BITS'(1), 1'b1, 1'b0, 1'b0, 32'b0, 32'b0, `FREG(19), `FREG(3), `FREG(11), `FREG(19)};
|
||||
end
|
||||
HMMA_SET1_STEP2_0: begin
|
||||
uop = {NEXT, HMMA_SET1_STEP2_1, `EX_BITS'(`EX_TENSOR), `INST_OP_BITS'(2), `INST_MOD_BITS'(0), 1'b1, 1'b0, 1'b0, 32'b0, 32'b0, `FREG(20), `FREG(2), `FREG(10), `FREG(20)};
|
||||
end
|
||||
HMMA_SET1_STEP2_1: begin
|
||||
uop = {NEXT, HMMA_SET1_STEP3_0, `EX_BITS'(`EX_TENSOR), `INST_OP_BITS'(2), `INST_MOD_BITS'(1), 1'b1, 1'b0, 1'b0, 32'b0, 32'b0, `FREG(21), `FREG(3), `FREG(11), `FREG(21)};
|
||||
end
|
||||
HMMA_SET1_STEP3_0: begin
|
||||
uop = {NEXT, HMMA_SET1_STEP3_1, `EX_BITS'(`EX_TENSOR), `INST_OP_BITS'(3), `INST_MOD_BITS'(0), 1'b1, 1'b0, 1'b0, 32'b0, 32'b0, `FREG(22), `FREG(2), `FREG(10), `FREG(22)};
|
||||
end
|
||||
HMMA_SET1_STEP3_1: begin
|
||||
uop = {FINISH, HMMA_SET0_STEP0_0, `EX_BITS'(`EX_TENSOR), `INST_OP_BITS'(3), `INST_MOD_BITS'(1), 1'b1, 1'b0, 1'b0, 32'b1, 32'b1, `FREG(23), `FREG(3), `FREG(11), `FREG(23)};
|
||||
end
|
||||
@@ -14,10 +14,9 @@ module VX_uop_sequencer import VX_gpu_pkg::*; (
|
||||
localparam UOP_TABLE_SIZE = 64;
|
||||
localparam UPC_BITS = `CLOG2(UOP_TABLE_SIZE);
|
||||
|
||||
localparam NEXT = 2'b00;
|
||||
localparam FINISH = 2'b01;
|
||||
|
||||
localparam UBR_BITS = 2;
|
||||
localparam NEXT = UBR_BITS'(2'b00);
|
||||
localparam FINISH = UBR_BITS'(2'b01);
|
||||
|
||||
// uop metadata (sequencing, next state), execution metadata (EX_TYPE, OP_TYPE, OP_MOD), wb, use pc, use imm, pc, imm, rd, rs1, rs2, rs3
|
||||
localparam UOP_TABLE_WIDTH = UBR_BITS + UPC_BITS + `EX_BITS + `INST_OP_BITS + `INST_MOD_BITS + 1 + 1 + 1 + `XLEN + `XLEN + (`NR_BITS * 4);
|
||||
@@ -122,7 +121,17 @@ module VX_uop_sequencer import VX_gpu_pkg::*; (
|
||||
// passthrough when !use_uop
|
||||
assign ibuffer_if.valid = use_uop ? 1'b1 : uop_sequencer_if.valid;
|
||||
assign uop_sequencer_if.ready = use_uop ? (uop_fire && ubr == FINISH) : ibuffer_if.ready;
|
||||
assign ibuffer_if.data = use_uop ? ibuffer_output : uop_sequencer_if.data;
|
||||
|
||||
always @(*) begin
|
||||
ibuffer_if.data = use_uop ? ibuffer_output : uop_sequencer_if.data;
|
||||
|
||||
if (uop_sequencer_if.valid && use_uop &&
|
||||
uop_sequencer_if.data.rd == `NR_BITS'(1)) begin
|
||||
// a little sketchy? but shouldn't create any loop
|
||||
ibuffer_if.data.rd = ibuffer_if.data.rd + `NR_BITS'(8); // FIXME: 8 is hardcoded
|
||||
ibuffer_if.data.rs3 = ibuffer_if.data.rs3 + `NR_BITS'(8);
|
||||
end
|
||||
end
|
||||
|
||||
always @(posedge clk) begin
|
||||
if (uop_start) begin
|
||||
|
||||
@@ -3,44 +3,274 @@
|
||||
|
||||
module VX_tensor_dpu #(
|
||||
parameter ISW,
|
||||
parameter OCTET
|
||||
parameter OCTET,
|
||||
// @perf: has big impact on throughput. A rule of thumb is to set it to
|
||||
// the pipeline length of FEDPs in order to make sure there are enough
|
||||
// entries to fully saturate the pipeline, but this is still rough
|
||||
parameter ISSUE_QUEUE_DEPTH = `LATENCY_HMMA
|
||||
) (
|
||||
input clk,
|
||||
input reset,
|
||||
|
||||
input stall,
|
||||
|
||||
input valid_in,
|
||||
output ready_in,
|
||||
input [3:0][1:0][31:0] A_tile,
|
||||
input [1:0][3:0][31:0] B_tile,
|
||||
input [3:0][3:0][31:0] C_tile,
|
||||
input [`NW_WIDTH-1:0] wid,
|
||||
|
||||
output valid_out,
|
||||
output [3:0][3:0][31:0] D_tile
|
||||
input ready_out,
|
||||
output [3:0][3:0][31:0] D_tile,
|
||||
output [`NW_WIDTH-1:0] D_wid
|
||||
);
|
||||
logic [3:0][3:0][31:0] result_hmma;
|
||||
// logic [3:0][3:0][31:0] result_hmma;
|
||||
|
||||
// always @(*) begin
|
||||
// dpi_hmma(valid_in, A_tile, B_tile, C_tile, result_hmma);
|
||||
// end
|
||||
|
||||
// logic ready_reg;
|
||||
// always @(posedge clk) begin
|
||||
// if (reset) begin
|
||||
// ready_reg <= '1;
|
||||
// end else if (valid_in && ready_in) begin
|
||||
// ready_reg <= '0;
|
||||
// dpi_print_results(int'(ISW), int'(OCTET), A_tile, B_tile, C_tile, result_hmma);
|
||||
// end else if (valid_out && ready_out) begin
|
||||
// ready_reg <= '1;
|
||||
// end
|
||||
// end
|
||||
|
||||
// // fixed-latency queue
|
||||
// VX_shift_register #(
|
||||
// .DATAW (1 + $bits(wid)/* + $bits(D_tile)*/),
|
||||
// .DEPTH (`LATENCY_HMMA + 1),
|
||||
// .RESETW (1)
|
||||
// ) shift_reg (
|
||||
// .clk (clk),
|
||||
// .reset (reset),
|
||||
// .enable (ready_out),
|
||||
// .data_in ({valid_in && ready_in, wid /*, result_hmma*/}),
|
||||
// .data_out ({valid_out, D_wid/*, D_tile */})
|
||||
// );
|
||||
|
||||
// ready as soon as valid_out
|
||||
// assign ready_in = ready_reg || valid_out;
|
||||
|
||||
// fully pipelined; ready_in is coupled to ready_out by immediately
|
||||
// stalling
|
||||
// assign ready_in = ready_out;
|
||||
|
||||
logic synced_fire;
|
||||
assign synced_fire = valid_in && ready_in;
|
||||
|
||||
logic [1:0] threadgroup_valids;
|
||||
logic [1:0] threadgroup_readys;
|
||||
// B_tile is shared across the two threadgroups; see Figure 13
|
||||
VX_tensor_threadgroup #(
|
||||
.ISSUE_QUEUE_DEPTH(ISSUE_QUEUE_DEPTH)
|
||||
) threadgroup_0 (
|
||||
.clk (clk),
|
||||
.reset (reset),
|
||||
.valid_in (synced_fire),
|
||||
.ready_in (threadgroup_readys[0]),
|
||||
.stall (!ready_out),
|
||||
.A_frag (A_tile[1:0]),
|
||||
.B_frag (B_tile),
|
||||
.C_frag (C_tile[1:0]),
|
||||
.valid_out (threadgroup_valids[0]),
|
||||
.D_frag (D_tile[1:0])
|
||||
);
|
||||
VX_tensor_threadgroup #(
|
||||
.ISSUE_QUEUE_DEPTH(ISSUE_QUEUE_DEPTH)
|
||||
) threadgroup_1 (
|
||||
.clk (clk),
|
||||
.reset (reset),
|
||||
.valid_in (synced_fire),
|
||||
.ready_in (threadgroup_readys[1]),
|
||||
.stall (!ready_out),
|
||||
.A_frag (A_tile[3:2]),
|
||||
.B_frag (B_tile),
|
||||
.C_frag (C_tile[3:2]),
|
||||
.valid_out (threadgroup_valids[1]),
|
||||
.D_frag (D_tile[3:2])
|
||||
);
|
||||
|
||||
wire empty;
|
||||
wire full;
|
||||
wire enq = valid_in && ready_in;
|
||||
wire deq = valid_out && ready_out;
|
||||
|
||||
assign ready_in = &(threadgroup_readys) && !full;
|
||||
assign valid_out = &(threadgroup_valids);
|
||||
|
||||
// need to pass along warp id's to do multithreading
|
||||
VX_fifo_queue #(
|
||||
.DATAW ($bits(wid)),
|
||||
// @perf: seems to require deeper depth than the FEDP issue queues to
|
||||
// not cause stalls.
|
||||
.DEPTH (2 * ISSUE_QUEUE_DEPTH)
|
||||
) wid_queue (
|
||||
.clk (clk),
|
||||
.reset (reset),
|
||||
.push (enq),
|
||||
.pop (deq),
|
||||
.data_in (wid),
|
||||
.data_out (D_wid),
|
||||
.empty (empty),
|
||||
`UNUSED_PIN(alm_empty),
|
||||
.full (full),
|
||||
`UNUSED_PIN(alm_full),
|
||||
`UNUSED_PIN(size)
|
||||
);
|
||||
|
||||
`RUNTIME_ASSERT(reset || !(deq && empty),
|
||||
("dequeueing from empty warp id queue!"))
|
||||
endmodule
|
||||
|
||||
// does (m,n,k) = (2,4,2) matmul compute over 2 cycles.
|
||||
// matches Figure 10(b) of the paper.
|
||||
module VX_tensor_threadgroup #(
|
||||
parameter ISSUE_QUEUE_DEPTH
|
||||
) (
|
||||
input clk,
|
||||
input reset,
|
||||
|
||||
input valid_in,
|
||||
output ready_in,
|
||||
input stall,
|
||||
input [1:0][1:0][31:0] A_frag,
|
||||
input [1:0][3:0][31:0] B_frag,
|
||||
input [1:0][3:0][31:0] C_frag,
|
||||
|
||||
output valid_out,
|
||||
output [1:0][3:0][31:0] D_frag
|
||||
);
|
||||
wire [1:0][1:0][31:0] A_frag_buf;
|
||||
wire [1:0][3:0][31:0] B_frag_buf;
|
||||
wire [1:0][3:0][31:0] C_frag_buf;
|
||||
|
||||
wire valid_buf;
|
||||
wire ready_buf;
|
||||
|
||||
wire enq = valid_in && ready_in;
|
||||
wire deq = valid_buf && ready_buf;
|
||||
wire empty;
|
||||
wire full;
|
||||
assign ready_in = !full;
|
||||
assign valid_buf = !empty;
|
||||
|
||||
// 'Issue queue' for the FEDP units.
|
||||
// This exists to decouple the execution of the dot-product unit from
|
||||
// the operand arrival. Operands from execute_if can arrive
|
||||
// intermittently according to the frontend's behavior, and since the dpu
|
||||
// can also stall for a fixed initiation latency, we need to decouple the
|
||||
// two to efficiently feed the dpu.
|
||||
//
|
||||
// TODO: better queue design possible; e.g. B_frag is shared by two
|
||||
// threadgroups, so we need only 1 queue per octet for B
|
||||
VX_fifo_queue #(
|
||||
.DATAW ($bits(A_frag) + $bits(B_frag) + $bits(C_frag)),
|
||||
.DEPTH (ISSUE_QUEUE_DEPTH)
|
||||
) input_buffer (
|
||||
.clk (clk),
|
||||
.reset (reset),
|
||||
.push (enq),
|
||||
.pop (deq),
|
||||
.data_in ({A_frag, B_frag, C_frag}),
|
||||
.data_out ({A_frag_buf, B_frag_buf, C_frag_buf}),
|
||||
.empty (empty),
|
||||
`UNUSED_PIN(alm_empty),
|
||||
.full (full),
|
||||
`UNUSED_PIN(alm_full),
|
||||
`UNUSED_PIN(size)
|
||||
);
|
||||
|
||||
logic [3:0] fedp_valids;
|
||||
wire fedp_valid_out = &(fedp_valids);
|
||||
wire fedp_ready_out = !stall;
|
||||
wire fedp_fire_out = fedp_valid_out && fedp_ready_out;
|
||||
|
||||
wire fedp_valid_in = valid_buf;
|
||||
wire fedp_ready_in = fedp_ready_out; // coupled
|
||||
wire fedp_fire_in = fedp_valid_in && fedp_ready_in;
|
||||
|
||||
// 0: FEDP uses first half from input_buffer
|
||||
// 1: FEDP uses last half and pops input_buffer
|
||||
logic step_in;
|
||||
// 0: FEDP produces first half of D_frag
|
||||
// 1: FEDP produces last half of D_frag and asserts valid_out
|
||||
logic step_out;
|
||||
assign ready_buf = fedp_fire_in && (step_in == 1'b1);
|
||||
|
||||
// latch the first-half result of D_frag
|
||||
logic [3:0][31:0] D_reg, D_reg_n;
|
||||
wire [3:0][31:0] D_half;
|
||||
always @(*) begin
|
||||
dpi_hmma(valid_in, A_tile, B_tile, C_tile, result_hmma);
|
||||
D_reg_n = D_reg;
|
||||
if (fedp_fire_out) begin
|
||||
if (step_out == 1'b0) begin
|
||||
D_reg_n = D_half;
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
always @(posedge clk) begin
|
||||
if (~reset && valid_in) begin
|
||||
dpi_print_results(int'(ISW), int'(OCTET), A_tile, B_tile, C_tile, result_hmma);
|
||||
if (reset) begin
|
||||
step_in <= '0;
|
||||
step_out <= '0;
|
||||
|
||||
D_reg <= '0;
|
||||
end else begin
|
||||
if (fedp_fire_in) begin
|
||||
step_in <= ~step_in;
|
||||
end
|
||||
if (fedp_fire_out) begin
|
||||
step_out <= ~step_out;
|
||||
end
|
||||
|
||||
D_reg <= D_reg_n;
|
||||
end
|
||||
end
|
||||
|
||||
|
||||
VX_shift_register #(
|
||||
.DATAW (1 + $bits(D_tile)),
|
||||
.DEPTH (`LATENCY_HMMA),
|
||||
.RESETW (1)
|
||||
) shift_reg (
|
||||
.clk (clk),
|
||||
.reset (reset),
|
||||
.enable (~stall),
|
||||
.data_in ({valid_in, result_hmma}),
|
||||
.data_out ({valid_out, D_tile})
|
||||
);
|
||||
// TODO: Instead of latching half-result and constructing a full D tile,
|
||||
// we should be able to send these half fragments down to commit stage
|
||||
// immediately, saving flop space
|
||||
assign D_frag[0][0] = D_reg[0];
|
||||
assign D_frag[0][2] = D_reg[1];
|
||||
assign D_frag[1][0] = D_reg[2];
|
||||
assign D_frag[1][2] = D_reg[3];
|
||||
assign D_frag[0][1] = D_half[0];
|
||||
assign D_frag[0][3] = D_half[1];
|
||||
assign D_frag[1][1] = D_half[2];
|
||||
assign D_frag[1][3] = D_half[3];
|
||||
|
||||
// 4 FEDPs per threadgroup
|
||||
for (genvar i = 0; i < 4; ++i) begin
|
||||
localparam int d_row = i / 2;
|
||||
localparam int d_col = (i % 2) * 2;
|
||||
// four-element dot product (FEDP) unit
|
||||
TensorDotProductUnit fedp (
|
||||
.clock (clk),
|
||||
.reset (reset),
|
||||
.io_in_valid (fedp_fire_in),
|
||||
.io_in_bits_a_0 (A_frag_buf[d_row][0]),
|
||||
.io_in_bits_a_1 (A_frag_buf[d_row][1]),
|
||||
.io_in_bits_a_2 (32'h0),
|
||||
.io_in_bits_a_3 (32'h0),
|
||||
.io_in_bits_b_0 (step_in == 1'b0 ? B_frag_buf[0][d_col] : B_frag_buf[0][d_col + 1]),
|
||||
.io_in_bits_b_1 (step_in == 1'b0 ? B_frag_buf[1][d_col] : B_frag_buf[1][d_col + 1]),
|
||||
.io_in_bits_b_2 (32'h0),
|
||||
.io_in_bits_b_3 (32'h0),
|
||||
.io_in_bits_c (step_in == 1'b0 ? C_frag_buf[d_row][d_col] : C_frag_buf[d_row][d_col + 1]),
|
||||
.io_stall (stall),
|
||||
.io_out_valid (fedp_valids[i]),
|
||||
.io_out_bits_data (D_half[i])
|
||||
);
|
||||
end
|
||||
|
||||
assign valid_out = fedp_valid_out && (step_out == 1'b1);
|
||||
endmodule
|
||||
|
||||
`endif
|
||||
|
||||
Reference in New Issue
Block a user