diff --git a/hw/rtl/VX_gpu_unit.v b/hw/rtl/VX_gpu_unit.v index 85a4db1b..4d8b9168 100644 --- a/hw/rtl/VX_gpu_unit.v +++ b/hw/rtl/VX_gpu_unit.v @@ -31,23 +31,30 @@ module VX_gpu_unit #( wire is_bar = (gpu_req_if.op_type == `INST_GPU_BAR); wire is_pred = (gpu_req_if.op_type == `INST_GPU_PRED); + wire [31:0] rs1_data = gpu_req_if.rs1_data[gpu_req_if.tid]; + + wire [`NUM_THREADS-1:0] taken_tmask; + wire [`NUM_THREADS-1:0] not_taken_tmask; + + for (genvar i = 0; i < `NUM_THREADS; i++) begin + wire taken = gpu_req_if.rs1_data[i][0]; + assign taken_tmask[i] = gpu_req_if.tmask[i] & taken; + assign not_taken_tmask[i] = gpu_req_if.tmask[i] & ~taken; + end + // tmc - wire [`NUM_THREADS-1:0] pred_cond; - for (genvar i = 0; i < `NUM_THREADS; i++) begin - assign pred_cond[i] = gpu_req_if.tmask[i] && gpu_req_if.rs1_data[i][0]; - end - wire [`NUM_THREADS-1:0] pred = (pred_cond != 0) ? pred_cond : gpu_req_if.tmask; + wire [`NUM_THREADS-1:0] pred_mask = (taken_tmask != 0) ? taken_tmask : gpu_req_if.tmask; assign tmc.valid = is_tmc || is_pred; - assign tmc.tmask = is_pred ? pred : `NUM_THREADS'(gpu_req_if.rs1_data[gpu_req_if.tid]); + assign tmc.tmask = is_pred ? pred_mask : rs1_data[`NUM_THREADS-1:0]; // wspawn wire [31:0] wspawn_pc = gpu_req_if.rs2_data; wire [`NUM_WARPS-1:0] wspawn_wmask; for (genvar i = 0; i < `NUM_WARPS; i++) begin - assign wspawn_wmask[i] = (i < gpu_req_if.rs1_data[gpu_req_if.tid]); + assign wspawn_wmask[i] = (i < rs1_data); end assign wspawn.valid = is_wspawn; assign wspawn.wmask = wspawn_wmask; @@ -55,25 +62,16 @@ module VX_gpu_unit #( // split - wire [`NUM_THREADS-1:0] split_then_tmask; - wire [`NUM_THREADS-1:0] split_else_tmask; - - for (genvar i = 0; i < `NUM_THREADS; i++) begin - wire taken = gpu_req_if.rs1_data[i][0]; - assign split_then_tmask[i] = gpu_req_if.tmask[i] & taken; - assign split_else_tmask[i] = gpu_req_if.tmask[i] & ~taken; - end - assign split.valid = is_split; - assign split.diverged = (| split_then_tmask) && (| split_else_tmask); - assign split.then_tmask = split_then_tmask; - assign split.else_tmask = split_else_tmask; + assign split.diverged = (| taken_tmask) && (| not_taken_tmask); + assign split.then_tmask = taken_tmask; + assign split.else_tmask = not_taken_tmask; assign split.pc = gpu_req_if.next_PC; // barrier assign barrier.valid = is_bar; - assign barrier.id = gpu_req_if.rs1_data[gpu_req_if.tid][`NB_BITS-1:0]; + assign barrier.id = rs1_data[`NB_BITS-1:0]; assign barrier.size_m1 = (`NW_BITS)'(gpu_req_if.rs2_data - 1); // output @@ -89,7 +87,7 @@ module VX_gpu_unit #( .enable (!stall), .data_in ({gpu_req_if.valid, gpu_req_if.wid, gpu_req_if.tmask, gpu_req_if.PC, gpu_req_if.rd, gpu_req_if.wb, tmc, wspawn, split, barrier}), .data_out ({gpu_commit_if.valid, gpu_commit_if.wid, gpu_commit_if.tmask, gpu_commit_if.PC, gpu_commit_if.rd, gpu_commit_if.wb, warp_ctl_if.tmc, warp_ctl_if.wspawn, warp_ctl_if.split, warp_ctl_if.barrier}) - ); + ); assign gpu_commit_if.eop = 1'b1; diff --git a/hw/rtl/VX_warp_sched.v b/hw/rtl/VX_warp_sched.v index 6c3d5e23..79eb629a 100644 --- a/hw/rtl/VX_warp_sched.v +++ b/hw/rtl/VX_warp_sched.v @@ -74,30 +74,29 @@ module VX_warp_sched #( active_warps[0] <= '1; thread_masks[0] <= '1; end else begin - if (warp_ctl_if.valid) begin - if (warp_ctl_if.wspawn.valid) begin - use_wspawn <= warp_ctl_if.wspawn.wmask & (~`NUM_WARPS'(1)); - wspawn_pc <= warp_ctl_if.wspawn.pc; - end else begin - stalled_warps[warp_ctl_if.wid] <= 0; - end + if (warp_ctl_if.valid && warp_ctl_if.wspawn.valid) begin + use_wspawn <= warp_ctl_if.wspawn.wmask & (~`NUM_WARPS'(1)); + wspawn_pc <= warp_ctl_if.wspawn.pc; + end - if (warp_ctl_if.barrier.valid) begin - if (reached_barrier_limit) begin - barrier_masks[warp_ctl_if.barrier.id] <= 0; - end else begin - barrier_masks[warp_ctl_if.barrier.id][warp_ctl_if.wid] <= 1; - end + if (warp_ctl_if.valid && warp_ctl_if.barrier.valid) begin + stalled_warps[warp_ctl_if.wid] <= 0; + if (reached_barrier_limit) begin + barrier_masks[warp_ctl_if.barrier.id] <= 0; + end else begin + barrier_masks[warp_ctl_if.barrier.id][warp_ctl_if.wid] <= 1; end - - if (warp_ctl_if.tmc.valid) begin - thread_masks[warp_ctl_if.wid] <= warp_ctl_if.tmc.tmask; - end - - if (warp_ctl_if.split.valid) begin - if (warp_ctl_if.split.diverged) begin - thread_masks[warp_ctl_if.wid] <= warp_ctl_if.split.then_tmask; - end + end + + if (warp_ctl_if.valid && warp_ctl_if.tmc.valid) begin + thread_masks[warp_ctl_if.wid] <= warp_ctl_if.tmc.tmask; + stalled_warps[warp_ctl_if.wid] <= 0; + end + + if (warp_ctl_if.valid && warp_ctl_if.split.valid) begin + stalled_warps[warp_ctl_if.wid] <= 0; + if (warp_ctl_if.split.diverged) begin + thread_masks[warp_ctl_if.wid] <= warp_ctl_if.split.then_tmask; end end diff --git a/runtime/include/vx_intrinsics.h b/runtime/include/vx_intrinsics.h index df07ccae..a1869318 100644 --- a/runtime/include/vx_intrinsics.h +++ b/runtime/include/vx_intrinsics.h @@ -53,8 +53,8 @@ extern "C" { }) // Set thread mask -inline void vx_tmc(unsigned mask) { - asm volatile (".insn s 0x6b, 0, x0, 0(%0)" :: "r"(mask)); +inline void vx_tmc(unsigned thread_mask) { + asm volatile (".insn s 0x6b, 0, x0, 0(%0)" :: "r"(thread_mask)); } // Set thread predicate