adding predicate instruction

This commit is contained in:
Blaise Tine
2021-08-31 03:23:59 -04:00
parent 6caf674163
commit c162ce526f
7 changed files with 57 additions and 32 deletions

View File

@@ -349,7 +349,7 @@ module VX_decode #(
ex_type = `EX_GPU; ex_type = `EX_GPU;
case (func3) case (func3)
3'h0: begin 3'h0: begin
op_type = `OP_BITS'(`GPU_TMC); op_type = `OP_BITS'(rs2 ? `GPU_PRED : `GPU_TMC);
is_wstall = 1; is_wstall = 1;
`USED_IREG (rs1); `USED_IREG (rs1);
end end

View File

@@ -185,7 +185,7 @@
`define GPU_SPLIT 3'h2 `define GPU_SPLIT 3'h2
`define GPU_JOIN 3'h3 `define GPU_JOIN 3'h3
`define GPU_BAR 3'h4 `define GPU_BAR 3'h4
`define GPU_OTHER 3'h7 `define GPU_PRED 3'h5
`define GPU_BITS 3 `define GPU_BITS 3
`define GPU_OP(x) x[`GPU_BITS-1:0] `define GPU_OP(x) x[`GPU_BITS-1:0]

View File

@@ -19,6 +19,7 @@ module VX_gpu_unit #(
`UNUSED_PARAM (CORE_ID) `UNUSED_PARAM (CORE_ID)
`UNUSED_VAR (clk) `UNUSED_VAR (clk)
`UNUSED_VAR (reset) `UNUSED_VAR (reset)
`UNUSED_VAR (gpu_req_if.op_mod)
gpu_tmc_t tmc; gpu_tmc_t tmc;
gpu_wspawn_t wspawn; gpu_wspawn_t wspawn;
@@ -29,11 +30,18 @@ module VX_gpu_unit #(
wire is_tmc = (gpu_req_if.op_type == `GPU_TMC); wire is_tmc = (gpu_req_if.op_type == `GPU_TMC);
wire is_split = (gpu_req_if.op_type == `GPU_SPLIT); wire is_split = (gpu_req_if.op_type == `GPU_SPLIT);
wire is_bar = (gpu_req_if.op_type == `GPU_BAR); wire is_bar = (gpu_req_if.op_type == `GPU_BAR);
wire is_pred = (gpu_req_if.op_type == `GPU_PRED);
// tmc // tmc
assign tmc.valid = is_tmc; wire [`NUM_THREADS-1:0] pred_cond;
assign tmc.tmask = `NUM_THREADS'(gpu_req_if.rs1_data[gpu_req_if.tid]); for (genvar i = 0; i < `NUM_THREADS; i++) begin
assign pred_cond[i] = gpu_req_if.tmask[i] && gpu_req_if.rs1_data[i][0];
end
wire [`NUM_THREADS-1:0] pred = (pred_cond != 0) ? pred_cond : gpu_req_if.tmask;
assign tmc.valid = is_tmc || is_pred;
assign tmc.tmask = is_pred ? pred : `NUM_THREADS'(gpu_req_if.rs1_data[gpu_req_if.tid]);
// wspawn // wspawn

View File

@@ -128,6 +128,7 @@ task print_ex_op (
`GPU_SPLIT: dpi_trace("SPLIT"); `GPU_SPLIT: dpi_trace("SPLIT");
`GPU_JOIN: dpi_trace("JOIN"); `GPU_JOIN: dpi_trace("JOIN");
`GPU_BAR: dpi_trace("BAR"); `GPU_BAR: dpi_trace("BAR");
`GPU_BAR: dpi_trace("PRED");
default: dpi_trace("?"); default: dpi_trace("?");
endcase endcase
end end

View File

@@ -74,13 +74,15 @@ module VX_warp_sched #(
active_warps[0] <= '1; active_warps[0] <= '1;
thread_masks[0] <= '1; thread_masks[0] <= '1;
end else begin end else begin
if (warp_ctl_if.valid && warp_ctl_if.wspawn.valid) begin if (warp_ctl_if.valid) begin
if (warp_ctl_if.wspawn.valid) begin
use_wspawn <= warp_ctl_if.wspawn.wmask & (~`NUM_WARPS'(1)); use_wspawn <= warp_ctl_if.wspawn.wmask & (~`NUM_WARPS'(1));
wspawn_pc <= warp_ctl_if.wspawn.pc; wspawn_pc <= warp_ctl_if.wspawn.pc;
end else begin
stalled_warps[warp_ctl_if.wid] <= 0;
end end
if (warp_ctl_if.valid && warp_ctl_if.barrier.valid) begin if (warp_ctl_if.barrier.valid) begin
stalled_warps[warp_ctl_if.wid] <= 0;
if (reached_barrier_limit) begin if (reached_barrier_limit) begin
barrier_masks[warp_ctl_if.barrier.id] <= 0; barrier_masks[warp_ctl_if.barrier.id] <= 0;
end else begin end else begin
@@ -88,18 +90,16 @@ module VX_warp_sched #(
end end
end end
if (warp_ctl_if.valid && warp_ctl_if.tmc.valid) begin if (warp_ctl_if.tmc.valid) begin
thread_masks[warp_ctl_if.wid] <= warp_ctl_if.tmc.tmask; thread_masks[warp_ctl_if.wid] <= warp_ctl_if.tmc.tmask;
stalled_warps[warp_ctl_if.wid] <= 0;
end end
if (warp_ctl_if.valid && warp_ctl_if.split.valid) begin if (warp_ctl_if.split.valid) begin
stalled_warps[warp_ctl_if.wid] <= 0;
if (warp_ctl_if.split.diverged) begin if (warp_ctl_if.split.diverged) begin
thread_masks[warp_ctl_if.wid] <= warp_ctl_if.split.then_tmask; thread_masks[warp_ctl_if.wid] <= warp_ctl_if.split.then_tmask;
end end
end end
end
// Branch // Branch
if (branch_ctl_if.valid) begin if (branch_ctl_if.valid) begin

View File

@@ -53,8 +53,13 @@ extern "C" {
}) })
// Set thread mask // Set thread mask
inline void vx_tmc(unsigned num_threads) { inline void vx_tmc(unsigned mask) {
asm volatile (".insn s 0x6b, 0, x0, 0(%0)" :: "r"(num_threads)); asm volatile (".insn s 0x6b, 0, x0, 0(%0)" :: "r"(mask));
}
// Set thread predicate
inline void vx_pred(unsigned condition) {
asm volatile (".insn s 0x6b, 0, x1, 0(%0)" :: "r"(condition));
} }
typedef void (*vx_wspawn_pfn)(); typedef void (*vx_wspawn_pfn)();

View File

@@ -817,10 +817,21 @@ void Warp::execute(const Instr &instr, Pipeline *pipeline) {
switch (func3) { switch (func3) {
case 0: { case 0: {
// TMC // TMC
if (rsrc1) {
// predicate mode
ThreadMask pred;
for (int i = 0; i < num_threads; ++i) {
pred[i] = tmask_[i] ? (iRegFile_[i][rsrc0] != 0) : 0;
}
if (pred.any()) {
tmask_ &= pred;
}
} else {
tmask_.reset(); tmask_.reset();
for (int i = 0; i < num_threads; ++i) { for (int i = 0; i < num_threads; ++i) {
tmask_[i] = rsdata[0] & (1 << i); tmask_[i] = rsdata[0] & (1 << i);
} }
}
D(3, "*** TMC " << tmask_); D(3, "*** TMC " << tmask_);
active_ = tmask_.any(); active_ = tmask_.any();
pipeline->stall_warp = true; pipeline->stall_warp = true;