Implement WU architecture support

This commit is contained in:
2026-05-25 19:25:05 +08:00
parent 323ed7d7e9
commit 0ad87bde81
35 changed files with 3303 additions and 472 deletions

View File

@@ -14,7 +14,8 @@
`include "VX_define.vh"
module VX_schedule import VX_gpu_pkg::*; #(
parameter CORE_ID = 0
parameter CORE_ID = 0,
parameter NUM_BRANCHES = `NUM_ALU_BLOCKS
) (
input wire clk,
input wire reset,
@@ -28,12 +29,20 @@ module VX_schedule import VX_gpu_pkg::*; #(
// inputsdecode_if
VX_warp_ctl_if.slave warp_ctl_if,
VX_branch_ctl_if.slave branch_ctl_if [`NUM_ALU_BLOCKS],
VX_branch_ctl_if.slave branch_ctl_if [NUM_BRANCHES],
VX_decode_sched_if.slave decode_sched_if,
VX_commit_sched_if.slave commit_sched_if,
`ifdef EXT_T_ENABLE
input wire tensor_csr_unlock_valid,
input wire [`NW_WIDTH-1:0] tensor_csr_unlock_wid,
input wire tensor_tmc_valid,
input wire [`NW_WIDTH-1:0] tensor_tmc_wid,
input wire [`NUM_THREADS-1:0] tensor_tmc_tmask,
`endif
// outputs
VX_schedule_if.master schedule_if,
VX_schedule_if.master scalar_schedule_if,
VX_schedule_if.master tensor_schedule_if,
`ifdef GBAR_ENABLE
VX_gbar_bus_if.master gbar_bus_if,
`endif
@@ -50,11 +59,10 @@ module VX_schedule import VX_gpu_pkg::*; #(
reg [`NUM_WARPS-1:0][`NUM_THREADS-1:0] thread_masks, thread_masks_n;
reg [`NUM_WARPS-1:0][`XLEN-1:0] warp_pcs, warp_pcs_n;
wire [`NW_WIDTH-1:0] schedule_wid;
wire [`NUM_THREADS-1:0] schedule_tmask;
wire [`XLEN-1:0] schedule_pc;
wire schedule_valid;
wire schedule_ready;
wire scalar_schedule_fire = scalar_schedule_if.valid && scalar_schedule_if.ready;
wire tensor_schedule_fire = tensor_schedule_if.valid && tensor_schedule_if.ready;
wire schedule_fire_any = scalar_schedule_fire || tensor_schedule_fire;
wire [`NW_WIDTH-1:0] schedule_fire_wid = tensor_schedule_fire ? tensor_schedule_if.data.wid : scalar_schedule_if.data.wid;
// split/join
wire join_valid;
@@ -68,15 +76,14 @@ module VX_schedule import VX_gpu_pkg::*; #(
reg [`NUM_WARPS-1:0][`UUID_WIDTH-1:0] issued_instrs;
wire schedule_fire = schedule_valid && schedule_ready;
wire schedule_if_fire = schedule_if.valid && schedule_if.ready;
wire schedule_if_fire = schedule_fire_any;
// branch
wire [`NUM_ALU_BLOCKS-1:0] branch_valid;
wire [`NUM_ALU_BLOCKS-1:0][`NW_WIDTH-1:0] branch_wid;
wire [`NUM_ALU_BLOCKS-1:0] branch_taken;
wire [`NUM_ALU_BLOCKS-1:0][`XLEN-1:0] branch_dest;
for (genvar i = 0; i < `NUM_ALU_BLOCKS; ++i) begin
wire [NUM_BRANCHES-1:0] branch_valid;
wire [NUM_BRANCHES-1:0][`NW_WIDTH-1:0] branch_wid;
wire [NUM_BRANCHES-1:0] branch_taken;
wire [NUM_BRANCHES-1:0][`XLEN-1:0] branch_dest;
for (genvar i = 0; i < NUM_BRANCHES; ++i) begin
assign branch_valid[i] = branch_ctl_if[i].valid;
assign branch_wid[i] = branch_ctl_if[i].wid;
assign branch_taken[i] = branch_ctl_if[i].taken;
@@ -87,7 +94,13 @@ module VX_schedule import VX_gpu_pkg::*; #(
reg [`NUM_BARRIERS-1:0][`NUM_WARPS-1:0] barrier_masks, barrier_masks_n;
reg [`NUM_WARPS-1:0] barrier_stalls, barrier_stalls_n;
wire [`CLOG2(`NUM_WARPS+1)-1:0] active_barrier_count;
wire [`NUM_WARPS-1:0] curr_barrier_mask;
wire [`NUM_WARPS-1:0] curr_barrier_mask;
wire [`NUM_WARPS-1:0] curr_barrier_mask_with_self;
wire [`NUM_WARPS-1:0] scalar_warp_mask;
wire [`NUM_WARPS-1:0] tensor_warp_mask;
wire [`NUM_WARPS-1:0] barrier_domain_mask;
wire [`NUM_WARPS-1:0] barrier_arrived_mask;
wire [`CLOG2(`NUM_WARPS+1)-1:0] barrier_arrived_count;
`ifdef GBAR_ENABLE
reg [`NUM_WARPS-1:0] curr_barrier_mask_n;
reg gbar_req_valid;
@@ -95,8 +108,21 @@ module VX_schedule import VX_gpu_pkg::*; #(
reg [`NC_WIDTH-1:0] gbar_req_size_m1;
`endif
for (genvar i = 0; i < `NUM_WARPS; ++i) begin
assign scalar_warp_mask[i] = `IS_SCALAR_WARP(i);
assign tensor_warp_mask[i] = `IS_TENSOR_WARP(i);
end
assign curr_barrier_mask = barrier_masks[warp_ctl_if.barrier.id];
assign curr_barrier_mask_with_self = curr_barrier_mask | (`NUM_WARPS'(1) << warp_ctl_if.wid);
assign barrier_domain_mask =
(warp_ctl_if.barrier.domain == BARRIER_SCALAR) ? (active_warps & scalar_warp_mask) :
(warp_ctl_if.barrier.domain == BARRIER_TENSOR) ? (active_warps & tensor_warp_mask) :
(warp_ctl_if.barrier.domain == BARRIER_MASK) ? (active_warps & warp_ctl_if.barrier.mask) :
active_warps;
assign barrier_arrived_mask = curr_barrier_mask_with_self & barrier_domain_mask;
`POP_COUNT(active_barrier_count, curr_barrier_mask);
`POP_COUNT(barrier_arrived_count, barrier_arrived_mask);
`UNUSED_VAR (active_barrier_count)
always @(*) begin
@@ -152,9 +178,11 @@ module VX_schedule import VX_gpu_pkg::*; #(
`endif
if (warp_ctl_if.valid && warp_ctl_if.barrier.valid) begin
if (~warp_ctl_if.barrier.is_global
&& (active_barrier_count[`NW_WIDTH-1:0] == warp_ctl_if.barrier.size_m1[`NW_WIDTH-1:0])) begin
&& ((warp_ctl_if.barrier.domain == BARRIER_MASK)
? ((barrier_arrived_mask & warp_ctl_if.barrier.mask) == warp_ctl_if.barrier.mask)
: (barrier_arrived_count[`NW_WIDTH-1:0] == (warp_ctl_if.barrier.size_m1[`NW_WIDTH-1:0] + `NW_WIDTH'(1))))) begin
barrier_masks_n[warp_ctl_if.barrier.id] = '0;
barrier_stalls_n &= ~barrier_masks[warp_ctl_if.barrier.id];
barrier_stalls_n &= ~barrier_arrived_mask;
end else begin
barrier_masks_n[warp_ctl_if.barrier.id][warp_ctl_if.wid] = 1;
barrier_stalls_n[warp_ctl_if.wid] = 1;
@@ -186,7 +214,7 @@ module VX_schedule import VX_gpu_pkg::*; #(
`endif
// Branch handling
for (integer i = 0; i < `NUM_ALU_BLOCKS; ++i) begin
for (integer i = 0; i < NUM_BRANCHES; ++i) begin
if (branch_valid[i]) begin
if (branch_taken[i]) begin
warp_pcs_n[branch_wid[i]] = branch_dest[i];
@@ -205,14 +233,31 @@ module VX_schedule import VX_gpu_pkg::*; #(
stalled_warps_n[sched_csr_if.unlock_wid] = 0;
end
`ifdef EXT_T_ENABLE
// Tensor control handles a minimal CSR-read/TMC subset without
// reusing the scalar SFU.
if (tensor_csr_unlock_valid) begin
stalled_warps_n[tensor_csr_unlock_wid] = 0;
end
if (tensor_tmc_valid) begin
active_warps_n[tensor_tmc_wid] = (tensor_tmc_tmask != 0);
thread_masks_n[tensor_tmc_wid] = tensor_tmc_tmask;
stalled_warps_n[tensor_tmc_wid] = 0;
end
`endif
// stall the warp until decode stage
if (schedule_fire) begin
stalled_warps_n[schedule_wid] = 1;
if (schedule_fire_any) begin
stalled_warps_n[schedule_fire_wid] = 1;
end
// advance PC
if (schedule_if_fire) begin
warp_pcs_n[schedule_if.data.wid] = schedule_if.data.PC + 4;
if (scalar_schedule_fire) begin
warp_pcs_n[scalar_schedule_if.data.wid] = scalar_schedule_if.data.PC + 4;
end
if (tensor_schedule_fire) begin
warp_pcs_n[tensor_schedule_if.data.wid] = tensor_schedule_if.data.PC + 4;
end
end
@@ -251,9 +296,9 @@ module VX_schedule import VX_gpu_pkg::*; #(
`ifdef GBAR_CLUSTER_ENABLE
// engage cluster barrier as soon as the barrier count is
// fulfilled, instead of requiring all warps to be synchronized
&& (active_barrier_count[`NW_WIDTH-1:0] == warp_ctl_if.barrier.size_m1[`NW_WIDTH-1:0])) begin
&& (barrier_arrived_count[`NW_WIDTH-1:0] == (warp_ctl_if.barrier.size_m1[`NW_WIDTH-1:0] + `NW_WIDTH'(1)))) begin
`else
&& (curr_barrier_mask_n == active_warps)) begin
&& (barrier_arrived_mask == barrier_domain_mask)) begin
`endif
gbar_req_valid <= 1;
gbar_req_id <= warp_ctl_if.barrier.id;
@@ -264,8 +309,11 @@ module VX_schedule import VX_gpu_pkg::*; #(
end
`endif
if (schedule_if_fire) begin
issued_instrs[schedule_if.data.wid] <= issued_instrs[schedule_if.data.wid] + `UUID_WIDTH'(1);
if (scalar_schedule_fire) begin
issued_instrs[scalar_schedule_if.data.wid] <= issued_instrs[scalar_schedule_if.data.wid] + `UUID_WIDTH'(1);
end
if (tensor_schedule_fire) begin
issued_instrs[tensor_schedule_if.data.wid] <= issued_instrs[tensor_schedule_if.data.wid] + `UUID_WIDTH'(1);
end
if (busy) begin
@@ -309,15 +357,33 @@ module VX_schedule import VX_gpu_pkg::*; #(
// schedule the next ready warp
wire [`NUM_WARPS-1:0] ready_warps = active_warps & ~(stalled_warps | barrier_stalls);
wire [`NUM_WARPS-1:0] scalar_ready_warps = ready_warps & scalar_warp_mask;
wire [`NUM_WARPS-1:0] tensor_ready_warps = ready_warps & tensor_warp_mask;
wire [`NW_WIDTH-1:0] scalar_schedule_wid;
wire [`NW_WIDTH-1:0] tensor_schedule_wid;
wire scalar_schedule_valid;
wire tensor_schedule_valid;
wire scalar_schedule_ready;
wire tensor_schedule_ready;
VX_lzc_rr #(
.N (`NUM_WARPS)
) wid_select (
) scalar_wid_select (
.clk (clk),
.reset (reset),
.data_in (ready_warps),
.data_out (schedule_wid),
.valid_out (schedule_valid)
.data_in (scalar_ready_warps),
.data_out (scalar_schedule_wid),
.valid_out (scalar_schedule_valid)
);
VX_lzc_rr #(
.N (`NUM_WARPS)
) tensor_wid_select (
.clk (clk),
.reset (reset),
.data_in (tensor_ready_warps),
.data_out (tensor_schedule_wid),
.valid_out (tensor_schedule_valid)
);
wire [`NUM_WARPS-1:0][(`NUM_THREADS + `XLEN)-1:0] schedule_data;
@@ -325,47 +391,78 @@ module VX_schedule import VX_gpu_pkg::*; #(
assign schedule_data[i] = {thread_masks[i], warp_pcs[i]};
end
assign {schedule_tmask, schedule_pc} = {
schedule_data[schedule_wid][(`NUM_THREADS + `XLEN)-1:(`NUM_THREADS + `XLEN)-4],
schedule_data[schedule_wid][(`NUM_THREADS + `XLEN)-5:0]
};
`ifndef NDEBUG
localparam GNW_WIDTH = `LOG2UP(`NUM_CLUSTERS * `NUM_CORES * `NUM_WARPS);
reg [`UUID_WIDTH-1:0] instr_uuid;
wire [GNW_WIDTH-1:0] g_wid = (GNW_WIDTH'(CORE_ID) << `NW_BITS) + GNW_WIDTH'(schedule_wid);
`ifdef SV_DPI
always @(posedge clk) begin
if (reset) begin
instr_uuid <= `UUID_WIDTH'(dpi_uuid_gen(1, 0, 0));
end else if (schedule_fire) begin
instr_uuid <= `UUID_WIDTH'(dpi_uuid_gen(0, 32'(g_wid), 64'(schedule_pc)));
end
end
function automatic [`UUID_WIDTH-1:0] schedule_uuid (
input logic [`NW_WIDTH-1:0] wid,
input logic [`XLEN-1:0] pc
);
logic [GNW_WIDTH-1:0] g_wid;
begin
g_wid = (GNW_WIDTH'(CORE_ID) << `NW_BITS) + GNW_WIDTH'(wid);
schedule_uuid = `UUID_WIDTH'({g_wid, 16'(pc)});
end
endfunction
`else
wire [GNW_WIDTH+16-1:0] w_uuid = {g_wid, 16'(schedule_pc)};
always @(*) begin
instr_uuid = `UUID_WIDTH'(w_uuid);
end
`endif
`else
wire [`UUID_WIDTH-1:0] instr_uuid = '0;
function automatic [`UUID_WIDTH-1:0] schedule_uuid (
input logic [`NW_WIDTH-1:0] wid,
input logic [`XLEN-1:0] pc
);
begin
`UNUSED_VAR (wid)
`UNUSED_VAR (pc)
schedule_uuid = '0;
end
endfunction
`endif
VX_elastic_buffer #(
.DATAW (`NUM_THREADS + `XLEN + `NW_WIDTH)
) out_buf (
.DATAW (`NUM_THREADS + `XLEN + `NW_WIDTH),
.SIZE (0)
) scalar_out_buf (
.clk (clk),
.reset (reset),
.valid_in (schedule_valid),
.ready_in (schedule_ready),
.data_in ({schedule_tmask, schedule_pc, schedule_wid}),
.data_out ({schedule_if.data.tmask, schedule_if.data.PC, schedule_if.data.wid}),
.valid_out (schedule_if.valid),
.ready_out (schedule_if.ready)
.valid_in (!reset && scalar_schedule_valid),
.ready_in (scalar_schedule_ready),
.data_in ({schedule_data[scalar_schedule_wid], scalar_schedule_wid}),
.data_out ({scalar_schedule_if.data.tmask, scalar_schedule_if.data.PC, scalar_schedule_if.data.wid}),
.valid_out (scalar_schedule_if.valid),
.ready_out (scalar_schedule_if.ready)
);
assign schedule_if.data.uuid = instr_uuid;
VX_elastic_buffer #(
.DATAW (`NUM_THREADS + `XLEN + `NW_WIDTH),
.SIZE (0)
) tensor_out_buf (
.clk (clk),
.reset (reset),
.valid_in (!reset && tensor_schedule_valid),
.ready_in (tensor_schedule_ready),
.data_in ({schedule_data[tensor_schedule_wid], tensor_schedule_wid}),
.data_out ({tensor_schedule_if.data.tmask, tensor_schedule_if.data.PC, tensor_schedule_if.data.wid}),
.valid_out (tensor_schedule_if.valid),
.ready_out (tensor_schedule_if.ready)
);
assign scalar_schedule_if.data.uuid = schedule_uuid(scalar_schedule_if.data.wid, scalar_schedule_if.data.PC);
assign tensor_schedule_if.data.uuid = schedule_uuid(tensor_schedule_if.data.wid, tensor_schedule_if.data.PC);
`RUNTIME_ASSERT(
!(scalar_schedule_fire && tensor_schedule_fire),
("%t: *** core%0d-schedule-two-domain-fire-with-single-fetch", $time, CORE_ID)
)
`RUNTIME_ASSERT(
!scalar_schedule_if.valid || `IS_SCALAR_WARP(scalar_schedule_if.data.wid),
("%t: *** core%0d-scalar-scheduler-issued-tensor-warp wid=%0d",
$time, CORE_ID, scalar_schedule_if.data.wid)
)
`RUNTIME_ASSERT(
!tensor_schedule_if.valid || `IS_TENSOR_WARP(tensor_schedule_if.data.wid),
("%t: *** core%0d-tensor-scheduler-issued-scalar-warp wid=%0d",
$time, CORE_ID, tensor_schedule_if.data.wid)
)
`RESET_RELAY (pending_instr_reset, reset);
@@ -377,8 +474,8 @@ module VX_schedule import VX_gpu_pkg::*; #(
) pending_instr(
.clk (clk),
.reset (pending_instr_reset),
.incr (schedule_if_fire),
.incr_wid (schedule_if.data.wid),
.incr (decode_sched_if.valid),
.incr_wid (decode_sched_if.wid),
.decr (commit_sched_if.committed),
.decr_wid (commit_sched_if.committed_wid),
.alm_empty_wid (sched_csr_if.alm_empty_wid),
@@ -413,13 +510,30 @@ module VX_schedule import VX_gpu_pkg::*; #(
end
`RUNTIME_ASSERT(timeout_ctr < `STALL_TIMEOUT, ("%t: *** core%0d-scheduler-timeout: stalled_warps=%b", $time, CORE_ID, stalled_warps));
`RUNTIME_ASSERT(
!(warp_ctl_if.valid && warp_ctl_if.barrier.valid) || barrier_domain_mask != '0,
("%t: *** core%0d-invalid-barrier-empty-domain: wid=%0d id=%0d domain=%0d active=%b mask=%b",
$time, CORE_ID, warp_ctl_if.wid, warp_ctl_if.barrier.id, warp_ctl_if.barrier.domain, active_warps, warp_ctl_if.barrier.mask)
)
`RUNTIME_ASSERT(
!(warp_ctl_if.valid && warp_ctl_if.barrier.valid) || barrier_domain_mask[warp_ctl_if.wid],
("%t: *** core%0d-invalid-barrier-wid-domain: wid=%0d id=%0d domain=%0d active=%b mask=%b",
$time, CORE_ID, warp_ctl_if.wid, warp_ctl_if.barrier.id, warp_ctl_if.barrier.domain, active_warps, warp_ctl_if.barrier.mask)
)
`ifdef PERF_ENABLE
reg [`PERF_CTR_BITS-1:0] perf_sched_idles;
reg [`PERF_CTR_BITS-1:0] perf_sched_stalls;
reg [`PERF_CTR_BITS-1:0] perf_sched_barrier_idles;
reg [`PERF_CTR_BITS-1:0] perf_scalar_sched_ready_cycles;
reg [`PERF_CTR_BITS-1:0] perf_tensor_sched_ready_cycles;
reg [`PERF_CTR_BITS-1:0] perf_scalar_sched_issued_cycles;
reg [`PERF_CTR_BITS-1:0] perf_tensor_sched_issued_cycles;
wire schedule_idle = ~schedule_valid;
wire schedule_stall = schedule_if.valid && ~schedule_if.ready;
wire schedule_idle = ~(scalar_schedule_if.valid || tensor_schedule_if.valid);
wire schedule_stall = (scalar_schedule_if.valid && ~scalar_schedule_if.ready)
|| (tensor_schedule_if.valid && ~tensor_schedule_if.ready);
wire [`CLOG2(`NUM_WARPS+1)-1:0] schedule_barrier_idle;
`POP_COUNT(schedule_barrier_idle, barrier_stalls);
@@ -427,17 +541,29 @@ module VX_schedule import VX_gpu_pkg::*; #(
if (reset) begin
perf_sched_idles <= '0;
perf_sched_barrier_idles <= '0;
perf_sched_stalls <= '0;
perf_sched_stalls <= '0;
perf_scalar_sched_ready_cycles <= '0;
perf_tensor_sched_ready_cycles <= '0;
perf_scalar_sched_issued_cycles <= '0;
perf_tensor_sched_issued_cycles <= '0;
end else begin
perf_sched_idles <= perf_sched_idles + `PERF_CTR_BITS'(schedule_idle);
perf_sched_barrier_idles <= perf_sched_barrier_idles + `PERF_CTR_BITS'(schedule_barrier_idle);
perf_sched_stalls <= perf_sched_stalls + `PERF_CTR_BITS'(schedule_stall);
perf_scalar_sched_ready_cycles <= perf_scalar_sched_ready_cycles + `PERF_CTR_BITS'(scalar_schedule_valid);
perf_tensor_sched_ready_cycles <= perf_tensor_sched_ready_cycles + `PERF_CTR_BITS'(tensor_schedule_valid);
perf_scalar_sched_issued_cycles <= perf_scalar_sched_issued_cycles + `PERF_CTR_BITS'(scalar_schedule_fire);
perf_tensor_sched_issued_cycles <= perf_tensor_sched_issued_cycles + `PERF_CTR_BITS'(tensor_schedule_fire);
end
end
assign perf_schedule_if.sched_idles = perf_sched_idles;
assign perf_schedule_if.sched_barrier_idles = perf_sched_barrier_idles;
assign perf_schedule_if.sched_stalls = perf_sched_stalls;
assign perf_schedule_if.sched_stalls = perf_sched_stalls;
assign perf_schedule_if.scalar_sched_ready_cycles = perf_scalar_sched_ready_cycles;
assign perf_schedule_if.tensor_sched_ready_cycles = perf_tensor_sched_ready_cycles;
assign perf_schedule_if.scalar_sched_issued_cycles = perf_scalar_sched_issued_cycles;
assign perf_schedule_if.tensor_sched_issued_cycles = perf_tensor_sched_issued_cycles;
`endif
endmodule