diff --git a/kernel/include/vx_spawn.h b/kernel/include/vx_spawn.h
index 321e3f83..06a85af7 100644
--- a/kernel/include/vx_spawn.h
+++ b/kernel/include/vx_spawn.h
@@ -50,6 +50,7 @@ void vx_wspawn_wait();
 void vx_spawn_kernel(context_t * ctx, vx_spawn_kernel_cb callback, void * arg);
 
 void vx_spawn_tasks(int num_tasks, vx_spawn_tasks_cb callback, void * arg);
+void vx_spawn_tasks_cluster(int num_tasks, vx_spawn_tasks_cb callback, void * arg);
 
 void vx_serial(vx_serial_cb callback, void * arg);
 
diff --git a/kernel/src/vx_spawn.c b/kernel/src/vx_spawn.c
index c57e55f2..04b58253 100644
--- a/kernel/src/vx_spawn.c
+++ b/kernel/src/vx_spawn.c
@@ -140,6 +140,109 @@ static void __attribute__ ((noinline)) spawn_tasks_all_cb() {
   vx_tmc_zero();
 }
 
+// Splits `num_tasks` across the available cores and launches, on the calling
+// core, the warps and threads needed for its share.
+//
+//   num_tasks  total number of tasks; values <= 0 spawn nothing
+//   callback   task routine, handed to the spawn stubs through wspawn_args
+//   arg        opaque pointer stored next to `callback`
+//
+// The calling core returns immediately when its id is >= NUM_CORES_MAX or when
+// it is not one of the active cores computed below.  wspawn_args lives on this
+// function's stack and is published via g_wspawn_args[core_id], so the pointer
+// is only meaningful until this function returns.
+void vx_spawn_tasks_cluster(int num_tasks, vx_spawn_tasks_cb callback, void *arg) {
+  // Nothing to run.  A negative count would also make the remainder below
+  // negative and turn the thread-mask shift into undefined behavior.
+  if (num_tasks <= 0)
+    return;
+
+  // device specs
+  int NC = vx_num_cores();
+  int NW = vx_num_warps();
+  int NT = vx_num_threads();
+
+  // current core id
+  int core_id = vx_core_id();
+  if (core_id >= NUM_CORES_MAX)
+    return;
+
+  // Distribute threads equally across as many cores as possible, even if they
+  // don't fill up NW*NT in a single core. This makes sure the warps get evenly
+  // distributed in a single cluster
+  //
+  // TODO: Try to contain in a single cluster if possible?
+  int num_active_cores = (num_tasks > NT) ? (num_tasks / NT) : 1;
+  num_active_cores = MIN(num_active_cores, NC);
+  if (core_id >= num_active_cores)
+    return; // terminate extra cores
+
+  int tasks_per_core = num_tasks / num_active_cores;
+  int tasks_per_core_last = tasks_per_core;
+  if (core_id == (num_active_cores - 1)) {
+    int rem = num_tasks % num_active_cores;
+    tasks_per_core_last += rem; // last core also executes remaining tasks
+  }
+
+  int num_full_warps = tasks_per_core_last / NT;
+  int rem_threads_in_last_warp = tasks_per_core_last % NT;
+  // sequential iterations
+  int num_full_waves = 1;
+  int rem_warps_in_last_wave = 0;
+  if (num_full_warps >= NW) {
+    // this division will result in the same value for both the last core and
+    // the rest
+    num_full_waves = num_full_warps / NW;
+    rem_warps_in_last_wave = num_full_warps % NW;
+  }
+
+  // The base offset depends only on the cluster id, so every core of a cluster
+  // publishes the same value.  NOTE(review): presumably the cluster stub adds
+  // the per-core/per-warp index itself -- confirm in the stub implementation.
+  int cluster_id = core_id / CORES_PER_CLUSTER;
+  const int tasks_per_cluster = tasks_per_core * CORES_PER_CLUSTER;
+  const int offset = cluster_id * tasks_per_cluster;
+  wspawn_tasks_args_t wspawn_args = {callback, arg, offset, num_full_waves, rem_warps_in_last_wave};
+  g_wspawn_args[core_id] = &wspawn_args;
+
+  if (num_full_warps >= 1) {
+    // execute callback on other warps
+    // NOTE(review): the spawned warps enter spawn_tasks_all_cb while this warp
+    // calls spawn_tasks_cluster_all_stub() -- confirm the pairing is intended.
+    int nw = MIN(num_full_warps, NW);
+    vx_wspawn(nw, spawn_tasks_all_cb);
+
+    // activate all threads
+    vx_tmc(-1);
+
+    // call stub routine
+    spawn_tasks_cluster_all_stub();
+
+    // back to single-threaded
+    vx_tmc_one();
+
+    // wait for spawn warps to terminate
+    vx_wspawn_wait();
+  }
+
+  if (rem_threads_in_last_warp != 0) {
+    // adjust offset
+    wspawn_args.offset += (tasks_per_core_last - rem_threads_in_last_warp);
+
+    // activate remaining threads; the unsigned literal keeps the shift defined
+    // even when the remainder reaches the sign-bit position
+    unsigned tmask = (1u << rem_threads_in_last_warp) - 1u;
+    vx_tmc(tmask);
+
+    // call stub routine
+    // FIXME: unimplemented for cluster!
+    spawn_tasks_rem_stub();
+
+    // back to single-threaded
+    vx_tmc_one();
+  }
+}
+
 void vx_spawn_tasks(int num_tasks, vx_spawn_tasks_cb callback , void * arg) {
   // device specs
   int NC = vx_num_cores();