From a2ea27b2b522bcd3e45e18d8d67a201fb71aa204 Mon Sep 17 00:00:00 2001 From: Hansung Kim Date: Tue, 27 Feb 2024 15:46:02 -0800 Subject: [PATCH] vx_spawn: Add spawn_tasks_contiguous_all_stub Spawns tasks so that the threads in a warp see contiguous thread_id values, unlike the original variant, where each thread was allocated a range of thread_id values spanning the number of batches. E.g. in a 4-thread config, instead of mapping IDs (0,2,4,6)->(1,3,5,7), map (0,1,2,3)->(4,5,6,7). TODO: remaining logic is not yet implemented. --- kernel/src/vx_spawn.c | 26 ++++++++++++++++++++++++-- 1 file changed, 24 insertions(+), 2 deletions(-) diff --git a/kernel/src/vx_spawn.c b/kernel/src/vx_spawn.c index fd8258e1..eb0bdb90 100644 --- a/kernel/src/vx_spawn.c +++ b/kernel/src/vx_spawn.c @@ -74,6 +74,27 @@ static void __attribute__ ((noinline)) spawn_tasks_all_stub() { } } +static void __attribute__ ((noinline)) spawn_tasks_contiguous_all_stub() { + int NT = vx_num_threads(); + int NW = vx_num_warps(); + int cid = vx_core_id(); + int wid = vx_warp_id(); + int tid = vx_thread_id(); + + wspawn_tasks_args_t* p_wspawn_args = (wspawn_tasks_args_t*)g_wspawn_args[cid]; + + // FIXME: handle RW + int waves = p_wspawn_args->NWs; + int offset = p_wspawn_args->offset + (NT * wid + tid); + + vx_spawn_tasks_cb callback = p_wspawn_args->callback; + void* arg = p_wspawn_args->arg; + for (int wave_id = 0; wave_id < waves; ++wave_id) { + int task_id = offset + (wave_id * NT * NW); + callback(task_id, arg); + } +} + static void __attribute__ ((noinline)) spawn_tasks_rem_stub() { int cid = vx_core_id(); int tid = vx_thread_id(); @@ -88,7 +109,8 @@ static void __attribute__ ((noinline)) spawn_tasks_all_cb() { vx_tmc(-1); // call stub routine - spawn_tasks_all_stub(); + // spawn_tasks_all_stub(); + spawn_tasks_contiguous_all_stub(); // disable warp vx_tmc_zero(); @@ -141,7 +163,7 @@ void vx_spawn_tasks(int num_tasks, vx_spawn_tasks_cb callback , void * arg) { vx_tmc(-1); // call stub routine -
spawn_tasks_all_stub(); + spawn_tasks_contiguous_all_stub(); // back to single-threaded vx_tmc_one();