diff --git a/runtime/include/vx_print.h b/runtime/include/vx_print.h index c0ecd392..fde2cbe6 100644 --- a/runtime/include/vx_print.h +++ b/runtime/include/vx_print.h @@ -9,7 +9,10 @@ extern "C" { int vx_vprintf(const char* format, va_list va); int vx_printf(const char * format, ...); -int vx_putchar(int c); + +void vx_putchar(int c); +void vx_putint(int value, int base); +void vx_putfloat(float value, int precision); #ifdef __cplusplus } diff --git a/runtime/include/vx_spawn.h b/runtime/include/vx_spawn.h index 9d246e06..905f22a5 100644 --- a/runtime/include/vx_spawn.h +++ b/runtime/include/vx_spawn.h @@ -8,7 +8,7 @@ extern "C" { #endif -struct context_t { +typedef struct { uint32_t num_groups[3]; uint32_t global_offset[3]; uint32_t local_size[3]; @@ -16,11 +16,11 @@ struct context_t { uint32_t *printf_buffer_position; uint32_t printf_buffer_capacity; uint32_t work_dim; -}; +} context_t; typedef void (*vx_spawn_kernel_cb) ( const void * /* arg */, - const struct context_t * /* context */, + const context_t * /* context */, uint32_t /* group_x */, uint32_t /* group_y */, uint32_t /* group_z */ @@ -28,9 +28,9 @@ typedef void (*vx_spawn_kernel_cb) ( typedef void (*vx_spawn_tasks_cb)(int task_id, void *arg); -typedef void (*vx_serial_cb)(int task_id, void *arg); +typedef void (*vx_serial_cb)(void *arg); -void vx_spawn_kernel(struct context_t * ctx, vx_spawn_kernel_cb callback, void * arg); +void vx_spawn_kernel(context_t * ctx, vx_spawn_kernel_cb callback, void * arg); void vx_spawn_tasks(int num_tasks, vx_spawn_tasks_cb callback, void * arg); diff --git a/runtime/src/vx_print.c b/runtime/src/vx_print.c index e3e93190..b43cdd4a 100644 --- a/runtime/src/vx_print.c +++ b/runtime/src/vx_print.c @@ -4,28 +4,37 @@ #include #include #include +#include #ifdef __cplusplus extern "C" { #endif -struct printf_arg_t { +typedef struct { const char* format; - va_list va; + va_list* va; int ret; -}; +} printf_arg_t; -static void __printf_callback(int task_id, void* arg) { - struct printf_arg_t* p_arg = (struct printf_arg_t*)(arg); - p_arg->ret = vprintf(p_arg->format, p_arg->va); +typedef struct { + int value; + int base; +} putint_arg_t; + +typedef struct { + float value; + int precision; +} putfloat_arg_t; + +static void __printf_cb(printf_arg_t* arg) { + arg->ret = vprintf(arg->format, *arg->va); } int vx_vprintf(const char* format, va_list va) { - // need to execute 'vprintf' single-threaded due to potential thread-data dependency - struct printf_arg_t arg; + printf_arg_t arg; arg.format = format; - arg.va = va; - vx_serial(__printf_callback, &arg); + arg.va = &va; + vx_serial(__printf_cb, &arg); return arg.ret; } @@ -38,6 +47,45 @@ int vx_printf(const char * format, ...) { return ret; } +static void __putint_cb(const putint_arg_t* arg) { + char tmp[33]; + float value = arg->value; + int base = arg->base; + itoa(value, tmp, base); + for (int i = 0; i < 33; ++i) { + int c = tmp[i]; + if (!c) break; + vx_putchar(c); + } +} + +void vx_putint(int value, int base) { + putint_arg_t arg; + arg.value = value; + arg.base = base; + vx_serial(__putint_cb, &arg); +} + +static void __putfloat_cb(const putfloat_arg_t* arg) { + float value = arg->value; + int precision = arg->precision; + int ipart = (int)value; + vx_putint(ipart, 10); + if (precision != 0) { + vx_putchar('.'); + float frac = value - (float)ipart; + float fscaled = frac * pow(10, precision); + vx_putint((int)fscaled, 10); + } +} + +void vx_putfloat(float value, int precision) { + putfloat_arg_t arg; + arg.value = value; + arg.precision = precision; + vx_serial(__putfloat_cb, &arg); +} + #ifdef __cplusplus } #endif \ No newline at end of file diff --git a/runtime/src/vx_spawn.S b/runtime/src/vx_spawn.S index cf9caa48..d2388ce9 100644 --- a/runtime/src/vx_spawn.S +++ b/runtime/src/vx_spawn.S @@ -1,3 +1,5 @@ +#include + .type vx_serial, @function .global vx_serial vx_serial: @@ -8,23 +10,22 @@ vx_serial: sw s2, 8(sp) sw s1, 4(sp) sw s0, 0(sp) - mv s4, a0 # callback - mv s3, a1 # arg - csrr s2, 0xfc0 # NT - csrr s1, 0xcc0 # tid - li s0, 0 # index + mv s4, a0 # s4 <- callback + mv s3, a1 # s3 <- arg + csrr s2, CSR_NT # s2 <- NT + csrr s1, CSR_WTID # s1 <- tid + li s0, 0 # s0 <- index label_loop: sub t0, s0, s1 - snez t0, t0 - .insn s 0x6b, 2, x0, 0(t0) # split t0 + seqz t1, t0 # (index != tid) + .insn s 0x6b, 2, x0, 0(t1) # split t0 bnez t0, label_join - mv a0, s0 # a0 <- index - mv a1, s3 # a1 <- arg - jalr s4 # callback(index, arg) + mv a0, s3 # a0 <- arg + jalr s4 # callback(arg) label_join: .insn s 0x6b, 3, x0, 0(x0) # join - addi s0, s0, 1 - blt s0, s2, label_loop + addi s0, s0, 1 # index++ + blt s0, s2, label_loop # loop back lw ra, 20(sp) lw s4, 16(sp) lw s3, 12(sp) diff --git a/runtime/src/vx_spawn.c b/runtime/src/vx_spawn.c index 94fd4a31..eb8be09a 100644 --- a/runtime/src/vx_spawn.c +++ b/runtime/src/vx_spawn.c @@ -20,7 +20,7 @@ typedef struct { } wspawn_tasks_args_t; typedef struct { - struct context_t * ctx; + context_t * ctx; vx_spawn_kernel_cb callback; void * arg; int offset; @@ -44,10 +44,7 @@ inline int fast_log2(int x) { return (*(int*)(&f)>>23) - 127; } -static void spawn_tasks_callback() { - // activate all threads - vx_tmc(-1); - +static void __attribute__ ((noinline)) spawn_tasks_all_stub() { int core_id = vx_core_id(); int wid = vx_warp_id(); int tid = vx_thread_id(); @@ -65,15 +62,9 @@ static void spawn_tasks_callback() { // wait for all warps to complete vx_barrier(0, p_wspawn_args->NW); - - // set warp0 to single-threaded and stop other warps - vx_tmc(0 == wid); } -void spawn_remaining_tasks_callback(int thread_mask) { - // activate threads - vx_tmc(thread_mask); - +static void __attribute__ ((noinline)) spawn_tasks_rem_stub() { int core_id = vx_core_id(); int tid = vx_thread_gid(); @@ -81,6 +72,26 @@ void spawn_remaining_tasks_callback(int thread_mask) { int task_id = p_wspawn_args->offset + tid; (p_wspawn_args->callback)(task_id, p_wspawn_args->arg); +} + +static void spawn_tasks_all_cb() { + // activate all threads + vx_tmc(-1); + + // call stub routine + spawn_tasks_all_stub(); + + // set warp0 to single-threaded and stop other warps + int wid = vx_warp_id(); + vx_tmc(0 == wid); +} + +static void spawn_tasks_rem_cb(int thread_mask) { + // activate threads + vx_tmc(thread_mask); + + // call stub routine + spawn_tasks_rem_stub(); // back to single-threaded vx_tmc(1); @@ -128,24 +139,21 @@ void vx_spawn_tasks(int num_tasks, vx_spawn_tasks_cb callback , void * arg) { if (nW >= 1) { int nw = MIN(nW, NW); wspawn_args.NW = nw; - vx_wspawn(nw, spawn_tasks_callback); - spawn_tasks_callback(); + vx_wspawn(nw, spawn_tasks_all_cb); + spawn_tasks_all_cb(); } //-- if (rT != 0) { wspawn_args.offset = tasks_per_core0 - rT; int tmask = (1 << rT) - 1; - spawn_remaining_tasks_callback(tmask); + spawn_tasks_rem_cb(tmask); } } /////////////////////////////////////////////////////////////////////////////// -static void spawn_kernel_callback() { - // activate all threads - vx_tmc(-1); - +static void __attribute__ ((noinline)) spawn_kernel_all_stub() { int core_id = vx_core_id(); int wid = vx_warp_id(); int tid = vx_thread_id(); @@ -176,15 +184,9 @@ static void spawn_kernel_callback() { // wait for all warps to complete vx_barrier(0, p_wspawn_args->NW); - - // set warp0 to single-threaded and stop other warps - vx_tmc(0 == wid); } -static void spawn_kernel_remaining_callback(int thread_mask) { - // activate threads - vx_tmc(thread_mask); - +static void __attribute__ ((noinline)) spawn_kernel_rem_stub() { int core_id = vx_core_id(); int tid = vx_thread_gid(); @@ -206,12 +208,32 @@ static void spawn_kernel_remaining_callback(int thread_mask) { int gid2 = p_wspawn_args->ctx->global_offset[2] + k; (p_wspawn_args->callback)(p_wspawn_args->arg, p_wspawn_args->ctx, gid0, gid1, gid2); +} + +static void spawn_kernel_all_cb() { + // activate all threads + vx_tmc(-1); + + // call stub routine + spawn_kernel_all_stub(); + + // set warp0 to single-threaded and stop other warps + int wid = vx_warp_id(); + vx_tmc(0 == wid); +} + +static void spawn_kernel_rem_cb(int thread_mask) { + // activate threads + vx_tmc(thread_mask); + + // call stub routine + spawn_kernel_rem_stub(); // back to single-threaded vx_tmc(1); } -void vx_spawn_kernel(struct context_t * ctx, vx_spawn_kernel_cb callback, void * arg) { +void vx_spawn_kernel(context_t * ctx, vx_spawn_kernel_cb callback, void * arg) { // total number of WGs int X = ctx->num_groups[0]; int Y = ctx->num_groups[1]; @@ -268,15 +290,15 @@ void vx_spawn_kernel(struct context_t * ctx, vx_spawn_kernel_cb callback, void * if (nW >= 1) { int nw = MIN(nW, NW); wspawn_args.NW = nw; - vx_wspawn(nw, spawn_kernel_callback); - spawn_kernel_callback(); + vx_wspawn(nw, spawn_kernel_all_cb); + spawn_kernel_all_cb(); } //-- if (rT != 0) { wspawn_args.offset = wgs_per_core0 - rT; int tmask = (1 << rT) - 1; - spawn_kernel_remaining_callback(tmask); + spawn_kernel_rem_cb(tmask); } }