vx_spawn.c: Create separate vx_spawn_tasks_contiguous

This commit is contained in:
Hansung Kim
2024-03-27 15:38:52 -07:00
parent fa6adceb7e
commit 870846f20f
2 changed files with 104 additions and 2 deletions

View File

@@ -53,6 +53,7 @@ void vx_spawn_kernel(context_t * ctx, vx_spawn_kernel_cb callback, void * arg);
void vx_spawn_tasks(int num_tasks, vx_spawn_tasks_cb callback, void * arg);
void vx_spawn_tasks_cluster(int num_tasks, vx_spawn_tasks_cb callback, void * arg);
// Spawn num_tasks tasks, calling callback(task_id, arg) for each one.
// Work is split so that every active core receives one contiguous range of
// task ids; the last active core additionally takes the division remainder.
void vx_spawn_tasks_contiguous(int num_tasks, vx_spawn_tasks_cb callback , void * arg);
void vx_serial(vx_serial_cb callback, void * arg);

View File

@@ -74,6 +74,26 @@ static void __attribute__ ((noinline)) spawn_tasks_all_stub() {
}
}
static void __attribute__ ((noinline)) spawn_tasks_contiguous_all_stub() {
int NT = vx_num_threads();
int NW = vx_num_warps();
int cid = vx_core_id();
int wid = vx_warp_id();
int tid = vx_thread_id();
wspawn_tasks_args_t* p_wspawn_args = (wspawn_tasks_args_t*)g_wspawn_args[cid];
int waves = p_wspawn_args->NWs + (wid < p_wspawn_args->RWs);
int offset = p_wspawn_args->offset + (NT * wid + tid);
vx_spawn_tasks_cb callback = p_wspawn_args->callback;
void* arg = p_wspawn_args->arg;
for (int wave_id = 0; wave_id < waves; ++wave_id) {
int task_id = offset + (wave_id * NT * NW);
callback(task_id, arg);
}
}
static void __attribute__ ((noinline)) spawn_tasks_cluster_all_stub() {
int NT = vx_num_threads();
int NW = vx_num_warps();
@@ -87,7 +107,6 @@ static void __attribute__ ((noinline)) spawn_tasks_cluster_all_stub() {
wspawn_tasks_args_t* p_wspawn_args = (wspawn_tasks_args_t*)g_wspawn_args[cid];
// FIXME: handle RW
int waves = p_wspawn_args->NWs + (wid < p_wspawn_args->RWs);
int offset = p_wspawn_args->offset + (NT * wid_in_cluster + tid);
@@ -129,12 +148,22 @@ static void __attribute__ ((noinline)) spawn_tasks_cluster_rem_stub() {
(p_wspawn_args->callback)(task_id, p_wspawn_args->arg);
}
// Warp entry point passed to vx_wspawn() by vx_spawn_tasks_contiguous().
// Runs the contiguous "all threads" stub on the spawned warp and then
// shuts the warp down. The three steps below must stay in this order.
static void __attribute__ ((noinline)) spawn_tasks_contiguous_all_cb() {
// activate all threads (thread mask -1) before doing any per-thread work
vx_tmc(-1);
// call stub routine: every thread runs its share of the tasks
spawn_tasks_contiguous_all_stub();
// disable warp once its tasks are done
vx_tmc_zero();
}
static void __attribute__ ((noinline)) spawn_tasks_cluster_all_cb() {
// activate all threads
vx_tmc(-1);
// call stub routine
// spawn_tasks_all_stub();
spawn_tasks_cluster_all_stub();
// disable warp
@@ -243,6 +272,78 @@ void vx_spawn_tasks_cluster(int num_tasks, vx_spawn_tasks_cb callback, void *arg
}
}
/*
 * Spawn `num_tasks` tasks, giving each active core one contiguous range of
 * task ids.
 *
 * num_tasks  number of tasks to run; values <= 0 run nothing.
 * callback   function invoked as callback(task_id, arg) for each task.
 * arg        opaque pointer forwarded unchanged to every callback call.
 *
 * Every core calls this function. Cores that are not needed for the
 * requested amount of work return immediately. Each remaining core handles
 * task ids starting at core_id * (num_tasks / active_cores); the last active
 * core additionally takes the division remainder.
 *
 * On a core the work is split in two phases:
 *   1. full warps: all NT threads of up to NW warps run
 *      spawn_tasks_contiguous_all_stub(), possibly for several waves;
 *   2. remainder: the leftover (< NT) tasks run on the calling warp with a
 *      partial thread mask via spawn_tasks_rem_stub().
 */
void vx_spawn_tasks_contiguous(int num_tasks, vx_spawn_tasks_cb callback , void * arg) {
  // Nothing to run. A negative count must be rejected as well: it would
  // yield a negative remainder (rT) below and make the (1 << rT) thread-mask
  // computation a shift by a negative amount, which is undefined behaviour.
  if (num_tasks <= 0)
    return;

  // device specs
  int NC = vx_num_cores();
  int NW = vx_num_warps();
  int NT = vx_num_threads();

  // current core id
  int core_id = vx_core_id();
  if (core_id >= NUM_CORES_MAX)
    return;

  // calculate necessary active cores: one core per full WT-sized batch of
  // tasks, at least one, at most the number of cores available
  int WT = NW * NT;
  int nC = (num_tasks > WT) ? (num_tasks / WT) : 1;
  int nc = MIN(nC, NC);
  if (core_id >= nc)
    return; // terminate extra cores

  // number of tasks per core
  int tasks_per_core = num_tasks / nc;
  int tasks_per_core_n1 = tasks_per_core;
  if (core_id == (nc-1)) {
    int rem = num_tasks - (nc * tasks_per_core);
    tasks_per_core_n1 += rem; // last core also executes remaining tasks
  }

  // number of tasks per warp
  int TW = tasks_per_core_n1 / NT;      // occupied warps
  int rT = tasks_per_core_n1 - TW * NT; // remaining threads
  int fW = 1, rW = 0;
  if (TW >= NW) {
    fW = TW / NW;                       // full warps iterations
    rW = TW - fW * NW;                  // remaining warps
  }

  // The argument block lives on this stack frame and is published through
  // g_wspawn_args so that the stubs can look it up by core id. It is only
  // meant to be read while this call is in progress.
  wspawn_tasks_args_t wspawn_args = { callback, arg, core_id * tasks_per_core, fW, rW };
  g_wspawn_args[core_id] = &wspawn_args;

  if (TW >= 1) {
    // execute callback on other warps
    int nw = MIN(TW, NW);
    vx_wspawn(nw, spawn_tasks_contiguous_all_cb);

    // activate all threads
    vx_tmc(-1);

    // call stub routine
    spawn_tasks_contiguous_all_stub();

    // back to single-threaded
    vx_tmc_one();

    // wait for spawn warps to terminate
    vx_wspawn_wait();
  }

  if (rT != 0) {
    // adjust offset so it points at the first task not covered by phase 1
    wspawn_args.offset += (tasks_per_core_n1 - rT);

    // activate remaining threads: the low rT bits of the thread mask
    int tmask = (1 << rT) - 1;
    vx_tmc(tmask);

    // call stub routine (defined elsewhere in this file)
    spawn_tasks_rem_stub();

    // back to single-threaded
    vx_tmc_one();
  }
}
void vx_spawn_tasks(int num_tasks, vx_spawn_tasks_cb callback , void * arg) {
// device specs
int NC = vx_num_cores();