diff --git a/hw/rtl/VX_warp_sched.v b/hw/rtl/VX_warp_sched.v index e8bc6c82..bf8c979f 100644 --- a/hw/rtl/VX_warp_sched.v +++ b/hw/rtl/VX_warp_sched.v @@ -1,66 +1,65 @@ `include "VX_define.vh" module VX_warp_sched ( - input wire clk, // Clock - input wire reset, - input wire stall, + input wire clk, // Clock + input wire reset, + input wire stall, // Wspawn - input wire wspawn, - input wire[31:0] wsapwn_pc, - input wire[`NUM_WARPS-1:0] wspawn_new_active, + input wire wspawn, + input wire[31:0] wsapwn_pc, + input wire[`NUM_WARPS-1:0] wspawn_new_active, // CTM - input wire ctm, - input wire[`NUM_THREADS-1:0] ctm_mask, - input wire[`NW_BITS-1:0] ctm_warp_num, + input wire ctm, + input wire[`NUM_THREADS-1:0] ctm_mask, + input wire[`NW_BITS-1:0] ctm_warp_num, // WHALT - input wire whalt, - input wire[`NW_BITS-1:0] whalt_warp_num, + input wire whalt, + input wire[`NW_BITS-1:0] whalt_warp_num, - input wire is_barrier, + input wire is_barrier, `DEBUG_BEGIN - input wire[31:0] barrier_id, + input wire[31:0] barrier_id, `DEBUG_END - input wire[$clog2(`NUM_WARPS):0] num_warps, - input wire[`NW_BITS-1:0] barrier_warp_num, + input wire[$clog2(`NUM_WARPS):0] num_warps, + input wire[`NW_BITS-1:0] barrier_warp_num, // WSTALL - input wire wstall, - input wire[`NW_BITS-1:0] wstall_warp_num, + input wire wstall, + input wire[`NW_BITS-1:0] wstall_warp_num, // Split - input wire is_split, - input wire dont_split, - input wire[`NUM_THREADS-1:0] split_new_mask, - input wire[`NUM_THREADS-1:0] split_later_mask, - input wire[31:0] split_save_pc, - input wire[`NW_BITS-1:0] split_warp_num, + input wire is_split, + input wire dont_split, + input wire[`NUM_THREADS-1:0] split_new_mask, + input wire[`NUM_THREADS-1:0] split_later_mask, + input wire[31:0] split_save_pc, + input wire[`NW_BITS-1:0] split_warp_num, // Join - input wire is_join, - input wire[`NW_BITS-1:0] join_warp_num, + input wire is_join, + input wire[`NW_BITS-1:0] join_warp_num, // JAL - input wire jal, - input wire[31:0] jal_dest, - input wire[`NW_BITS-1:0] jal_warp_num, + input wire jal, + input wire[31:0] jal_dest, + input wire[`NW_BITS-1:0] jal_warp_num, // Branch - input wire branch_valid, - input wire branch_dir, - input wire[31:0] branch_dest, - input wire[`NW_BITS-1:0] branch_warp_num, + input wire branch_valid, + input wire branch_dir, + input wire[31:0] branch_dest, + input wire[`NW_BITS-1:0] branch_warp_num, - output wire[`NUM_THREADS-1:0] thread_mask, - output wire[`NW_BITS-1:0] warp_num, - output wire[31:0] warp_pc, - output wire ebreak, - output wire scheduled_warp, - - input wire[`NW_BITS-1:0] icache_stage_wid, - input wire[`NUM_THREADS-1:0] icache_stage_valids + output wire[`NUM_THREADS-1:0] thread_mask, + output wire[`NW_BITS-1:0] warp_num, + output wire[31:0] warp_pc, + output wire ebreak, + output wire scheduled_warp, + input wire[`NW_BITS-1:0] icache_stage_wid, + input wire[`NUM_THREADS-1:0] icache_stage_valids ); wire update_use_wspawn; wire update_visible_active; @@ -226,16 +225,21 @@ module VX_warp_sched ( end end - VX_countones #(.N(`NUM_WARPS)) barrier_count( + VX_countones #( + .N(`NUM_WARPS) + ) barrier_count ( .valids(curr_barrier_mask), .count (curr_barrier_count) - ); + ); - wire[$clog2(`NUM_WARPS):0] count_visible_active; - VX_countones #(.N(`NUM_WARPS)) num_visible( + wire [$clog2(`NUM_WARPS):0] count_visible_active; + + VX_countones #( + .N(`NUM_WARPS) + ) num_visible ( .valids(visible_active), .count (count_visible_active) - ); + ); // assign curr_barrier_count = $countones(curr_barrier_mask); @@ -254,17 +258,13 @@ module VX_warp_sched ( // end // end - assign update_visible_active = (count_visible_active < 1) && !(stall || wstall_this_cycle || hazard || is_join); wire[(1+32+`NUM_THREADS-1):0] q1 = {1'b1, 32'b0 , thread_masks[split_warp_num]}; wire[(1+32+`NUM_THREADS-1):0] q2 = {1'b0, split_save_pc , split_later_mask}; - assign {join_fall, join_pc, join_tm} = d[join_warp_num]; - - genvar curr_warp; generate for (curr_warp = 0; curr_warp < `NUM_WARPS; curr_warp = curr_warp + 1) begin : stacks @@ -273,7 +273,11 @@ module VX_warp_sched ( wire push = (is_split && !dont_split) && correct_warp_s; wire pop = is_join && correct_warp_j; - VX_generic_stack #(.WIDTH(1+32+`NUM_THREADS), .DEPTH($clog2(`NUM_THREADS)+1)) ipdom_stack( + + VX_generic_stack #( + .WIDTH(1+32+`NUM_THREADS), + .DEPTH($clog2(`NUM_THREADS)+1) + ) ipdom_stack( .clk (clk), .reset(reset), .push (push), @@ -308,11 +312,12 @@ module VX_warp_sched ( assign new_pc = warp_pc + 4; - assign use_active = (count_visible_active < 1) ? (warp_active & (~warp_stalled) & (~total_barrier_stall) & (~warp_lock)) : visible_active; // Choosing a warp to schedule - VX_priority_encoder choose_schedule( + VX_priority_encoder #( + .N(`NUM_WARPS) + ) choose_schedule ( .valids(use_active), .index (warp_to_schedule), .found (schedule) diff --git a/hw/rtl/libs/VX_priority_encoder.v b/hw/rtl/libs/VX_priority_encoder.v index 17b9d679..4e859f74 100644 --- a/hw/rtl/libs/VX_priority_encoder.v +++ b/hw/rtl/libs/VX_priority_encoder.v @@ -1,21 +1,28 @@ `include "VX_define.vh" -module VX_priority_encoder ( - input wire[`NUM_WARPS-1:0] valids, - output reg[`NW_BITS-1:0] index, - output reg found +module VX_priority_encoder #( + parameter N +) ( + input wire [N-1:0] valids, + output wire [`LOG2UP(N)-1:0] index, + output wire found ); + reg [`LOG2UP(N)-1:0] index_r; + reg found_r; integer i; always @(*) begin - index = 0; - found = 0; + index_r = 0; + found_r = 0; for (i = `NUM_WARPS-1; i >= 0; i = i - 1) begin if (valids[i]) begin - index = i[`NW_BITS-1:0]; - found = 1; + index_r = i[`NW_BITS-1:0]; + found_r = 1; end end end + + assign index = index_r; + assign found = found_r; endmodule \ No newline at end of file