software updaet for new thread mask design
This commit is contained in:
@@ -48,20 +48,20 @@ module VX_gpu_unit #(
|
|||||||
|
|
||||||
// split
|
// split
|
||||||
|
|
||||||
wire [`NUM_THREADS-1:0] split_then_mask;
|
wire [`NUM_THREADS-1:0] split_then_tmask;
|
||||||
wire [`NUM_THREADS-1:0] split_else_mask;
|
wire [`NUM_THREADS-1:0] split_else_tmask;
|
||||||
|
|
||||||
for (genvar i = 0; i < `NUM_THREADS; i++) begin
|
for (genvar i = 0; i < `NUM_THREADS; i++) begin
|
||||||
wire taken = gpu_req_if.rs1_data[i][gpu_req_if.tid];
|
wire taken = gpu_req_if.rs1_data[i][0];
|
||||||
assign split_then_mask[i] = gpu_req_if.tmask[i] & taken;
|
assign split_then_tmask[i] = gpu_req_if.tmask[i] & taken;
|
||||||
assign split_else_mask[i] = gpu_req_if.tmask[i] & ~taken;
|
assign split_else_tmask[i] = gpu_req_if.tmask[i] & ~taken;
|
||||||
end
|
end
|
||||||
|
|
||||||
assign split.valid = is_split;
|
assign split.valid = is_split;
|
||||||
assign split.diverged = (| split_then_mask) && (| split_else_mask);
|
assign split.diverged = (| split_then_tmask) && (| split_else_tmask);
|
||||||
assign split.then_mask = split_then_mask;
|
assign split.then_tmask = split_then_tmask;
|
||||||
assign split.else_mask = split_else_mask;
|
assign split.else_tmask = split_else_tmask;
|
||||||
assign split.pc = gpu_req_if.next_PC;
|
assign split.pc = gpu_req_if.next_PC;
|
||||||
|
|
||||||
// barrier
|
// barrier
|
||||||
|
|
||||||
|
|||||||
@@ -41,8 +41,8 @@ typedef struct packed {
|
|||||||
typedef struct packed {
|
typedef struct packed {
|
||||||
logic valid;
|
logic valid;
|
||||||
logic diverged;
|
logic diverged;
|
||||||
logic [`NUM_THREADS-1:0] then_mask;
|
logic [`NUM_THREADS-1:0] then_tmask;
|
||||||
logic [`NUM_THREADS-1:0] else_mask;
|
logic [`NUM_THREADS-1:0] else_tmask;
|
||||||
logic [31:0] pc;
|
logic [31:0] pc;
|
||||||
} gpu_split_t;
|
} gpu_split_t;
|
||||||
|
|
||||||
|
|||||||
@@ -49,20 +49,22 @@ module VX_warp_sched #(
|
|||||||
|
|
||||||
wire ifetch_rsp_fire = ifetch_rsp_if.valid && ifetch_rsp_if.ready;
|
wire ifetch_rsp_fire = ifetch_rsp_if.valid && ifetch_rsp_if.ready;
|
||||||
|
|
||||||
|
wire tmc_active = (warp_ctl_if.tmc.tmask != 0);
|
||||||
|
|
||||||
always @(*) begin
|
always @(*) begin
|
||||||
active_warps_n = active_warps;
|
active_warps_n = active_warps;
|
||||||
if (warp_ctl_if.valid && warp_ctl_if.wspawn.valid) begin
|
if (warp_ctl_if.valid && warp_ctl_if.wspawn.valid) begin
|
||||||
active_warps_n = warp_ctl_if.wspawn.wmask;
|
active_warps_n = warp_ctl_if.wspawn.wmask;
|
||||||
end
|
end
|
||||||
if (warp_ctl_if.valid && warp_ctl_if.tmc.valid) begin
|
if (warp_ctl_if.valid && warp_ctl_if.tmc.valid) begin
|
||||||
active_warps_n[warp_ctl_if.wid] = (warp_ctl_if.tmc.tmask != 0);
|
active_warps_n[warp_ctl_if.wid] = tmc_active;
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
|
||||||
always @(*) begin
|
always @(*) begin
|
||||||
schedule_table_n = schedule_table;
|
schedule_table_n = schedule_table;
|
||||||
if (warp_ctl_if.valid && warp_ctl_if.tmc.valid) begin
|
if (warp_ctl_if.valid && warp_ctl_if.tmc.valid) begin
|
||||||
schedule_table_n[warp_ctl_if.wid] = (warp_ctl_if.tmc.tmask != 0);
|
schedule_table_n[warp_ctl_if.wid] = tmc_active;
|
||||||
end
|
end
|
||||||
if (warp_scheduled) begin // remove scheduled warp (round-robin)
|
if (warp_scheduled) begin // remove scheduled warp (round-robin)
|
||||||
schedule_table_n[scheduled_warp] = 0;
|
schedule_table_n[scheduled_warp] = 0;
|
||||||
@@ -104,12 +106,12 @@ module VX_warp_sched #(
|
|||||||
barrier_stall_mask[warp_ctl_if.barrier.id][warp_ctl_if.wid] <= 1;
|
barrier_stall_mask[warp_ctl_if.barrier.id][warp_ctl_if.wid] <= 1;
|
||||||
end
|
end
|
||||||
end else if (warp_ctl_if.valid && warp_ctl_if.tmc.valid) begin
|
end else if (warp_ctl_if.valid && warp_ctl_if.tmc.valid) begin
|
||||||
thread_masks[warp_ctl_if.wid] <= warp_ctl_if.tmc.tmask;
|
thread_masks[warp_ctl_if.wid] <= warp_ctl_if.tmc.tmask;
|
||||||
stalled_warps[warp_ctl_if.wid] <= 0;
|
stalled_warps[warp_ctl_if.wid] <= 0;
|
||||||
end else if (warp_ctl_if.valid && warp_ctl_if.split.valid) begin
|
end else if (warp_ctl_if.valid && warp_ctl_if.split.valid) begin
|
||||||
stalled_warps[warp_ctl_if.wid] <= 0;
|
stalled_warps[warp_ctl_if.wid] <= 0;
|
||||||
if (warp_ctl_if.split.diverged) begin
|
if (warp_ctl_if.split.diverged) begin
|
||||||
thread_masks[warp_ctl_if.wid] <= warp_ctl_if.split.then_mask;
|
thread_masks[warp_ctl_if.wid] <= warp_ctl_if.split.then_tmask;
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
|
||||||
@@ -179,6 +181,8 @@ module VX_warp_sched #(
|
|||||||
|
|
||||||
wire [(1+32+`NUM_THREADS-1):0] ipdom [`NUM_WARPS-1:0];
|
wire [(1+32+`NUM_THREADS-1):0] ipdom [`NUM_WARPS-1:0];
|
||||||
|
|
||||||
|
wire [`NUM_THREADS-1:0] curr_tmask = thread_masks[warp_ctl_if.wid];
|
||||||
|
|
||||||
for (genvar i = 0; i < `NUM_WARPS; i++) begin
|
for (genvar i = 0; i < `NUM_WARPS; i++) begin
|
||||||
wire push = warp_ctl_if.valid
|
wire push = warp_ctl_if.valid
|
||||||
&& warp_ctl_if.split.valid
|
&& warp_ctl_if.split.valid
|
||||||
@@ -186,9 +190,9 @@ module VX_warp_sched #(
|
|||||||
|
|
||||||
wire pop = join_if.valid && (i == join_if.wid);
|
wire pop = join_if.valid && (i == join_if.wid);
|
||||||
|
|
||||||
wire [`NUM_THREADS-1:0] else_mask = warp_ctl_if.split.diverged ? warp_ctl_if.split.else_mask : thread_masks[warp_ctl_if.wid];
|
wire [`NUM_THREADS-1:0] else_tmask = warp_ctl_if.split.diverged ? warp_ctl_if.split.else_tmask : curr_tmask;
|
||||||
wire [(1+32+`NUM_THREADS-1):0] q_end = {1'b0, 32'b0, thread_masks[warp_ctl_if.wid]};
|
wire [(1+32+`NUM_THREADS-1):0] q_end = {1'b0, 32'b0, curr_tmask};
|
||||||
wire [(1+32+`NUM_THREADS-1):0] q_else = {1'b1, warp_ctl_if.split.pc, else_mask};
|
wire [(1+32+`NUM_THREADS-1):0] q_else = {1'b1, warp_ctl_if.split.pc, else_tmask};
|
||||||
|
|
||||||
VX_ipdom_stack #(
|
VX_ipdom_stack #(
|
||||||
.WIDTH (1+32+`NUM_THREADS),
|
.WIDTH (1+32+`NUM_THREADS),
|
||||||
|
|||||||
@@ -43,7 +43,8 @@ inline int fast_log2(int x) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
static void spawn_tasks_callback() {
|
static void spawn_tasks_callback() {
|
||||||
vx_tmc(vx_num_threads());
|
// activate all threads
|
||||||
|
vx_tmc(-1);
|
||||||
|
|
||||||
int core_id = vx_core_id();
|
int core_id = vx_core_id();
|
||||||
int wid = vx_warp_id();
|
int wid = vx_warp_id();
|
||||||
@@ -60,11 +61,13 @@ static void spawn_tasks_callback() {
|
|||||||
(p_wspawn_args->callback)(task_id, p_wspawn_args->arg);
|
(p_wspawn_args->callback)(task_id, p_wspawn_args->arg);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// set warp0 to single-threaded and stop other warps
|
||||||
vx_tmc(0 == wid);
|
vx_tmc(0 == wid);
|
||||||
}
|
}
|
||||||
|
|
||||||
void spawn_remaining_tasks_callback(int nthreads) {
|
void spawn_remaining_tasks_callback(int thread_mask) {
|
||||||
vx_tmc(nthreads);
|
// activate threads
|
||||||
|
vx_tmc(thread_mask);
|
||||||
|
|
||||||
int core_id = vx_core_id();
|
int core_id = vx_core_id();
|
||||||
int tid = vx_thread_gid();
|
int tid = vx_thread_gid();
|
||||||
@@ -74,6 +77,7 @@ void spawn_remaining_tasks_callback(int nthreads) {
|
|||||||
int task_id = p_wspawn_args->offset + tid;
|
int task_id = p_wspawn_args->offset + tid;
|
||||||
(p_wspawn_args->callback)(task_id, p_wspawn_args->arg);
|
(p_wspawn_args->callback)(task_id, p_wspawn_args->arg);
|
||||||
|
|
||||||
|
// back to single-threaded
|
||||||
vx_tmc(1);
|
vx_tmc(1);
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -132,7 +136,8 @@ void vx_spawn_tasks(int num_tasks, vx_spawn_tasks_cb callback , void * arg) {
|
|||||||
///////////////////////////////////////////////////////////////////////////////
|
///////////////////////////////////////////////////////////////////////////////
|
||||||
|
|
||||||
static void spawn_kernel_callback() {
|
static void spawn_kernel_callback() {
|
||||||
vx_tmc(vx_num_threads());
|
// activate all threads
|
||||||
|
vx_tmc(-1);
|
||||||
|
|
||||||
int core_id = vx_core_id();
|
int core_id = vx_core_id();
|
||||||
int wid = vx_warp_id();
|
int wid = vx_warp_id();
|
||||||
@@ -162,11 +167,13 @@ static void spawn_kernel_callback() {
|
|||||||
(p_wspawn_args->callback)(p_wspawn_args->arg, p_wspawn_args->ctx, gid0, gid1, gid2);
|
(p_wspawn_args->callback)(p_wspawn_args->arg, p_wspawn_args->ctx, gid0, gid1, gid2);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// set warp0 to single-threaded and stop other warps
|
||||||
vx_tmc(0 == wid);
|
vx_tmc(0 == wid);
|
||||||
}
|
}
|
||||||
|
|
||||||
static void spawn_kernel_remaining_callback(int nthreads) {
|
static void spawn_kernel_remaining_callback(int thread_mask) {
|
||||||
vx_tmc(nthreads);
|
// activate threads
|
||||||
|
vx_tmc(thread_mask);
|
||||||
|
|
||||||
int core_id = vx_core_id();
|
int core_id = vx_core_id();
|
||||||
int tid = vx_thread_gid();
|
int tid = vx_thread_gid();
|
||||||
@@ -190,6 +197,7 @@ static void spawn_kernel_remaining_callback(int nthreads) {
|
|||||||
|
|
||||||
(p_wspawn_args->callback)(p_wspawn_args->arg, p_wspawn_args->ctx, gid0, gid1, gid2);
|
(p_wspawn_args->callback)(p_wspawn_args->arg, p_wspawn_args->ctx, gid0, gid1, gid2);
|
||||||
|
|
||||||
|
// back to single-threaded
|
||||||
vx_tmc(1);
|
vx_tmc(1);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -59,7 +59,7 @@ label_exit_next:
|
|||||||
.global vx_set_sp
|
.global vx_set_sp
|
||||||
vx_set_sp:
|
vx_set_sp:
|
||||||
# activate all threads
|
# activate all threads
|
||||||
csrr a0, CSR_NT # get num threads
|
li a0, -1
|
||||||
.insn s 0x6b, 0, x0, 0(a0) # tmc a0
|
.insn s 0x6b, 0, x0, 0(a0) # tmc a0
|
||||||
|
|
||||||
# set global pointer register
|
# set global pointer register
|
||||||
|
|||||||
@@ -267,6 +267,9 @@ Word Core::get_csr(Addr addr, int tid, int wid) {
|
|||||||
} else if (addr == CSR_GCID) {
|
} else if (addr == CSR_GCID) {
|
||||||
// Processor coreID
|
// Processor coreID
|
||||||
return id_;
|
return id_;
|
||||||
|
} else if (addr == CSR_TMASK) {
|
||||||
|
// Processor coreID
|
||||||
|
return warps_.at(wid)->getTmask();
|
||||||
} else if (addr == CSR_NT) {
|
} else if (addr == CSR_NT) {
|
||||||
// Number of threads per warp
|
// Number of threads per warp
|
||||||
return arch_.num_threads();
|
return arch_.num_threads();
|
||||||
|
|||||||
@@ -817,10 +817,9 @@ void Warp::execute(const Instr &instr, Pipeline *pipeline) {
|
|||||||
switch (func3) {
|
switch (func3) {
|
||||||
case 0: {
|
case 0: {
|
||||||
// TMC
|
// TMC
|
||||||
int active_threads = std::min<int>(rsdata[0], num_threads);
|
|
||||||
tmask_.reset();
|
tmask_.reset();
|
||||||
for (int i = 0; i < active_threads; ++i) {
|
for (size_t i = 0; i < tmask_.size(); ++i) {
|
||||||
tmask_[i] = true;
|
tmask_[i] = rsdata[0] & (1 << i);
|
||||||
}
|
}
|
||||||
active_ = tmask_.any();
|
active_ = tmask_.any();
|
||||||
pipeline->stall_warp = true;
|
pipeline->stall_warp = true;
|
||||||
|
|||||||
@@ -74,6 +74,12 @@ public:
|
|||||||
active_ = tmask_.any();
|
active_ = tmask_.any();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Word getTmask() const {
|
||||||
|
if (active_)
|
||||||
|
return tmask_.to_ulong();
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
|
|
||||||
Word getIRegValue(int reg) const {
|
Word getIRegValue(int reg) const {
|
||||||
return iRegFile_[0][reg];
|
return iRegFile_[0][reg];
|
||||||
}
|
}
|
||||||
|
|||||||
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
Binary file not shown.
File diff suppressed because it is too large
Load Diff
Binary file not shown.
Binary file not shown.
Binary file not shown.
File diff suppressed because it is too large
Load Diff
Binary file not shown.
File diff suppressed because it is too large
Load Diff
Binary file not shown.
File diff suppressed because it is too large
Load Diff
Binary file not shown.
File diff suppressed because it is too large
Load Diff
Binary file not shown.
File diff suppressed because it is too large
Load Diff
Reference in New Issue
Block a user