vx_spawn: Add spawn_tasks_contiguous_all_stub

Spawns tasks in a way that the threads in a warp see contiguous
thread_id, unlike the original variant where each thread were allocated
a range of thread_id that spans the number of batches.

E.g. in a 4-thread config, instead of mapping IDs (0,2,4,6)->(1,3,5,7),
map (0,1,2,3)->(4,5,6,7).

TODO remaining logic not implemented.
This commit is contained in:
Hansung Kim
2024-02-27 15:46:02 -08:00
parent 2b1b5fe537
commit a2ea27b2b5

View File

@@ -74,6 +74,27 @@ static void __attribute__ ((noinline)) spawn_tasks_all_stub() {
}
}
static void __attribute__ ((noinline)) spawn_tasks_contiguous_all_stub() {
int NT = vx_num_threads();
int NW = vx_num_warps();
int cid = vx_core_id();
int wid = vx_warp_id();
int tid = vx_thread_id();
wspawn_tasks_args_t* p_wspawn_args = (wspawn_tasks_args_t*)g_wspawn_args[cid];
// FIXME: handle RW
int waves = p_wspawn_args->NWs;
int offset = p_wspawn_args->offset + (NT * wid + tid);
vx_spawn_tasks_cb callback = p_wspawn_args->callback;
void* arg = p_wspawn_args->arg;
for (int wave_id = 0; wave_id < waves; ++wave_id) {
int task_id = offset + (wave_id * NT * NW);
callback(task_id, arg);
}
}
static void __attribute__ ((noinline)) spawn_tasks_rem_stub() {
int cid = vx_core_id();
int tid = vx_thread_id();
@@ -88,7 +109,8 @@ static void __attribute__ ((noinline)) spawn_tasks_all_cb() {
vx_tmc(-1);
// call stub routine
spawn_tasks_all_stub();
// spawn_tasks_all_stub();
spawn_tasks_contiguous_all_stub();
// disable warp
vx_tmc_zero();
@@ -141,7 +163,7 @@ void vx_spawn_tasks(int num_tasks, vx_spawn_tasks_cb callback , void * arg) {
vx_tmc(-1);
// call stub routine
spawn_tasks_all_stub();
spawn_tasks_contiguous_all_stub();
// back to single-threaded
vx_tmc_one();