diff --git a/hw/rtl/VX_decode.v b/hw/rtl/VX_decode.v index ee810959..d4253208 100644 --- a/hw/rtl/VX_decode.v +++ b/hw/rtl/VX_decode.v @@ -349,7 +349,7 @@ module VX_decode #( ex_type = `EX_GPU; case (func3) 3'h0: begin - op_type = `INST_OP_BITS'(`INST_GPU_TMC); + op_type = rs2[0] ? `INST_OP_BITS'(`INST_GPU_PRED) : `INST_OP_BITS'(`INST_GPU_TMC); is_wstall = 1; `USED_IREG (rs1); end diff --git a/hw/rtl/VX_define.vh b/hw/rtl/VX_define.vh index a0530688..46a6a406 100644 --- a/hw/rtl/VX_define.vh +++ b/hw/rtl/VX_define.vh @@ -185,7 +185,7 @@ `define INST_GPU_SPLIT 3'h2 `define INST_GPU_JOIN 3'h3 `define INST_GPU_BAR 3'h4 -`define INST_GPU_OTHER 3'h7 +`define INST_GPU_PRED 3'h5 `define INST_GPU_BITS 3 /////////////////////////////////////////////////////////////////////////////// diff --git a/hw/rtl/VX_gpu_unit.v b/hw/rtl/VX_gpu_unit.v index e63f8e1b..85a4db1b 100644 --- a/hw/rtl/VX_gpu_unit.v +++ b/hw/rtl/VX_gpu_unit.v @@ -29,11 +29,18 @@ module VX_gpu_unit #( wire is_tmc = (gpu_req_if.op_type == `INST_GPU_TMC); wire is_split = (gpu_req_if.op_type == `INST_GPU_SPLIT); wire is_bar = (gpu_req_if.op_type == `INST_GPU_BAR); + wire is_pred = (gpu_req_if.op_type == `INST_GPU_PRED); // tmc - assign tmc.valid = is_tmc; - assign tmc.tmask = `NUM_THREADS'(gpu_req_if.rs1_data[gpu_req_if.tid]); + wire [`NUM_THREADS-1:0] pred_cond; + for (genvar i = 0; i < `NUM_THREADS; i++) begin + assign pred_cond[i] = gpu_req_if.tmask[i] && gpu_req_if.rs1_data[i][0]; + end + wire [`NUM_THREADS-1:0] pred = (pred_cond != 0) ? pred_cond : gpu_req_if.tmask; + + assign tmc.valid = is_tmc || is_pred; + assign tmc.tmask = is_pred ? pred : `NUM_THREADS'(gpu_req_if.rs1_data[gpu_req_if.tid]); // wspawn diff --git a/hw/rtl/VX_print_instr.vh b/hw/rtl/VX_print_instr.vh index 24fc73a0..614a7d46 100644 --- a/hw/rtl/VX_print_instr.vh +++ b/hw/rtl/VX_print_instr.vh @@ -136,6 +136,7 @@ task print_ex_op ( `INST_GPU_SPLIT: dpi_trace("SPLIT"); `INST_GPU_JOIN: dpi_trace("JOIN"); `INST_GPU_BAR: dpi_trace("BAR"); + `INST_GPU_PRED: dpi_trace("PRED"); default: dpi_trace("?"); endcase end diff --git a/hw/rtl/VX_warp_sched.v b/hw/rtl/VX_warp_sched.v index 63bd7528..6c3d5e23 100644 --- a/hw/rtl/VX_warp_sched.v +++ b/hw/rtl/VX_warp_sched.v @@ -74,33 +74,33 @@ module VX_warp_sched #( active_warps[0] <= '1; thread_masks[0] <= '1; end else begin - if (warp_ctl_if.valid && warp_ctl_if.wspawn.valid) begin - use_wspawn <= warp_ctl_if.wspawn.wmask & (~`NUM_WARPS'(1)); - wspawn_pc <= warp_ctl_if.wspawn.pc; - end - - if (warp_ctl_if.valid && warp_ctl_if.barrier.valid) begin - stalled_warps[warp_ctl_if.wid] <= 0; - if (reached_barrier_limit) begin - barrier_masks[warp_ctl_if.barrier.id] <= 0; + if (warp_ctl_if.valid) begin + if (warp_ctl_if.wspawn.valid) begin + use_wspawn <= warp_ctl_if.wspawn.wmask & (~`NUM_WARPS'(1)); + wspawn_pc <= warp_ctl_if.wspawn.pc; end else begin - barrier_masks[warp_ctl_if.barrier.id][warp_ctl_if.wid] <= 1; + stalled_warps[warp_ctl_if.wid] <= 0; end - end - - if (warp_ctl_if.valid && warp_ctl_if.tmc.valid) begin - thread_masks[warp_ctl_if.wid] <= warp_ctl_if.tmc.tmask; - stalled_warps[warp_ctl_if.wid] <= 0; - end - - if (warp_ctl_if.valid && warp_ctl_if.split.valid) begin - stalled_warps[warp_ctl_if.wid] <= 0; - if (warp_ctl_if.split.diverged) begin - thread_masks[warp_ctl_if.wid] <= warp_ctl_if.split.then_tmask; + + if (warp_ctl_if.barrier.valid) begin + if (reached_barrier_limit) begin + barrier_masks[warp_ctl_if.barrier.id] <= 0; + end else begin + barrier_masks[warp_ctl_if.barrier.id][warp_ctl_if.wid] <= 1; + end + end + + if (warp_ctl_if.tmc.valid) begin + thread_masks[warp_ctl_if.wid] <= warp_ctl_if.tmc.tmask; + end + + if (warp_ctl_if.split.valid) begin + if (warp_ctl_if.split.diverged) begin + thread_masks[warp_ctl_if.wid] <= warp_ctl_if.split.then_tmask; + end end end - // Branch if (branch_ctl_if.valid) begin if (branch_ctl_if.taken) begin diff --git a/runtime/include/vx_intrinsics.h b/runtime/include/vx_intrinsics.h index 87a123ba..df07ccae 100644 --- a/runtime/include/vx_intrinsics.h +++ b/runtime/include/vx_intrinsics.h @@ -53,8 +53,13 @@ extern "C" { }) // Set thread mask -inline void vx_tmc(unsigned num_threads) { - asm volatile (".insn s 0x6b, 0, x0, 0(%0)" :: "r"(num_threads)); +inline void vx_tmc(unsigned mask) { + asm volatile (".insn s 0x6b, 0, x0, 0(%0)" :: "r"(mask)); +} + +// Set thread predicate +inline void vx_pred(unsigned condition) { + asm volatile (".insn s 0x6b, 0, x1, 0(%0)" :: "r"(condition)); } typedef void (*vx_wspawn_pfn)(); diff --git a/simX/execute.cpp b/simX/execute.cpp index f255f66e..f29ea150 100644 --- a/simX/execute.cpp +++ b/simX/execute.cpp @@ -816,10 +816,21 @@ void Warp::execute(const Instr &instr, Pipeline *pipeline) { case GPGPU: switch (func3) { case 0: { - // TMC - tmask_.reset(); - for (int i = 0; i < num_threads; ++i) { - tmask_[i] = rsdata[0] & (1 << i); + // TMC + if (rsrc1) { + // predicate mode + ThreadMask pred; + for (int i = 0; i < num_threads; ++i) { + pred[i] = tmask_[i] ? (iRegFile_[i][rsrc0] != 0) : 0; + } + if (pred.any()) { + tmask_ &= pred; + } + } else { + tmask_.reset(); + for (int i = 0; i < num_threads; ++i) { + tmask_[i] = rsdata[0] & (1 << i); + } } D(3, "*** TMC " << tmask_); active_ = tmask_.any();