Add WU architecture kernel cases
This commit is contained in:
26
kernels/wu_arch/Makefile
Normal file
26
kernels/wu_arch/Makefile
Normal file
@@ -0,0 +1,26 @@
|
||||
PROJECT = wu_arch
|
||||
|
||||
VX_SRCS = kernel.cpp
|
||||
|
||||
OPTS ?= -n1
|
||||
|
||||
WU_VARIANT_DUMPS = \
|
||||
kernel.radiance.barriers.dump
|
||||
|
||||
all: kernel.radiance.dump $(WU_VARIANT_DUMPS)
|
||||
|
||||
include ../common.mk
|
||||
|
||||
kernel.radiance.barriers.dump: kernel.radiance.barriers.elf
|
||||
$(VX_DP) -D $< > $@
|
||||
|
||||
kernel.radiance.barriers.elf: $(VX_SRCS) $(VX_INCLUDES) $(BINFILES)
|
||||
$(VX_CXX) $(VX_CFLAGS) $(VX_SRCS) $(VX_LDFLAGS) -DRADIANCE -DWU_RUN_DOMAIN_BARRIERS -o $@
|
||||
$(OBJCOPY) --set-section-flags .operand.a=$(OBJCOPY_FLAGS) $@
|
||||
$(OBJCOPY) --set-section-flags .operand.b=$(OBJCOPY_FLAGS) $@
|
||||
$(OBJCOPY) --set-section-flags .operand.c=$(OBJCOPY_FLAGS) $@
|
||||
$(OBJCOPY) --set-section-flags .args=$(OBJCOPY_FLAGS) $@
|
||||
$(OBJCOPY) --update-section .operand.a=input.a.bin $@ || true
|
||||
$(OBJCOPY) --update-section .operand.b=input.b.bin $@ || true
|
||||
$(OBJCOPY) --update-section .operand.c=input.c.bin $@ || true
|
||||
$(OBJCOPY) --update-section .args=args.bin $@ || true
|
||||
1
kernels/wu_arch/args.bin
Normal file
1
kernels/wu_arch/args.bin
Normal file
@@ -0,0 +1 @@
|
||||
0
|
||||
1
kernels/wu_arch/input.a.bin
Normal file
1
kernels/wu_arch/input.a.bin
Normal file
@@ -0,0 +1 @@
|
||||
0
|
||||
1
kernels/wu_arch/input.b.bin
Normal file
1
kernels/wu_arch/input.b.bin
Normal file
@@ -0,0 +1 @@
|
||||
0
|
||||
1
kernels/wu_arch/input.c.bin
Normal file
1
kernels/wu_arch/input.c.bin
Normal file
@@ -0,0 +1 @@
|
||||
0
|
||||
173
kernels/wu_arch/kernel.cpp
Normal file
173
kernels/wu_arch/kernel.cpp
Normal file
@@ -0,0 +1,173 @@
|
||||
#include <stdint.h>
|
||||
#include <vx_intrinsics.h>
|
||||
|
||||
#define DEV_SMEM_START_ADDR 0xff000000u
|
||||
#define MAX_WARPS 8
|
||||
#define MINIMAL_INIT_WORDS 4
|
||||
#define WU_SCALAR_SPIN 32
|
||||
#define WU_TENSOR_SPIN 32
|
||||
#define WU_WAIT_SPIN 8192
|
||||
#define WU_STATUS_DONE 0x600du
|
||||
#define WU_STATUS_SCALAR_BASE 0x5100u
|
||||
#define WU_STATUS_TENSOR_BASE 0x7100u
|
||||
#define WU_BARRIER_SCALAR 0u
|
||||
#define WU_BARRIER_MASKED 1u
|
||||
#define WU_BARRIER_TENSOR 2u
|
||||
|
||||
extern "C" {
|
||||
volatile uint32_t g_status[MAX_WARPS] __attribute__((aligned(32)));
|
||||
volatile uint32_t g_scalar_seen[MAX_WARPS] __attribute__((aligned(32)));
|
||||
volatile uint32_t g_tensor_seen[MAX_WARPS] __attribute__((aligned(32)));
|
||||
volatile uint32_t g_spin_sink[MAX_WARPS] __attribute__((aligned(32)));
|
||||
extern volatile uint64_t tohost;
|
||||
}
|
||||
|
||||
extern "C" void vx_perf_dump() {}
|
||||
|
||||
static inline void wu_report_tohost(uint32_t exit_code) {
|
||||
asm volatile("fence rw, rw" ::: "memory");
|
||||
tohost = (static_cast<uint64_t>(exit_code) << 1) | 1u;
|
||||
asm volatile("fence rw, rw" ::: "memory");
|
||||
}
|
||||
|
||||
extern "C" int wu_main();
|
||||
|
||||
extern "C" void __attribute__((naked, section(".init"), used)) _start() {
|
||||
asm volatile(
|
||||
".option push\n\t"
|
||||
".option norelax\n\t"
|
||||
"la gp, __global_pointer\n\t"
|
||||
".option pop\n\t"
|
||||
"csrr t0, %[csr_core]\n\t"
|
||||
"bnez t0, 2f\n\t"
|
||||
"li sp, %[stack_base]\n\t"
|
||||
"call wu_main\n\t"
|
||||
"mv gp, a0\n\t"
|
||||
"2:\n\t"
|
||||
".insn r %[custom0], 0, 0, x0, x0, x0\n\t"
|
||||
"1: j 1b\n\t"
|
||||
:
|
||||
: [csr_core] "i"(VX_CSR_CORE_ID),
|
||||
[stack_base] "i"(STACK_BASE_ADDR),
|
||||
[custom0] "i"(RISCV_CUSTOM0)
|
||||
: "memory");
|
||||
}
|
||||
|
||||
extern "C" void scalar_worker() {
|
||||
const uint32_t wid = static_cast<uint32_t>(vx_warp_id());
|
||||
const uint32_t tid = static_cast<uint32_t>(vx_thread_id());
|
||||
volatile uint32_t mix = wid + 1u;
|
||||
|
||||
#ifdef WU_RUN_DOMAIN_BARRIERS
|
||||
vx_barrier_scalar(WU_BARRIER_SCALAR, NUM_SCALAR_WARPS);
|
||||
vx_barrier_mask(WU_BARRIER_MASKED, vx_scalar_warp_mask());
|
||||
#endif
|
||||
|
||||
for (uint32_t i = 0; i < WU_SCALAR_SPIN; ++i)
|
||||
mix = (mix << 1) ^ (i + wid);
|
||||
|
||||
if (tid == 0 && wid < MAX_WARPS) {
|
||||
g_spin_sink[wid] = mix;
|
||||
g_scalar_seen[wid] = WU_STATUS_SCALAR_BASE | wid;
|
||||
}
|
||||
|
||||
vx_tmc_zero();
|
||||
while (1) {}
|
||||
}
|
||||
|
||||
extern "C" void __attribute__((naked, noinline, used)) tensor_worker() {
|
||||
asm volatile(
|
||||
"csrr x5, %[csr_wid]\n\t"
|
||||
#ifdef WU_RUN_DOMAIN_BARRIERS
|
||||
"li x1, (%[bar_tensor] | (%[domain_tensor] << %[domain_shift]))\n\t"
|
||||
"li x2, %[num_tensor]\n\t"
|
||||
".insn r %[custom0], 4, 0, x0, x1, x2\n\t"
|
||||
#endif
|
||||
".insn r %[custom3], 3, 0, x0, x0, x0\n\t"
|
||||
"li x7, %[tensor_spin]\n\t"
|
||||
"1:\n\t"
|
||||
"addi x7, x7, -1\n\t"
|
||||
"bnez x7, 1b\n\t"
|
||||
"slli x5, x5, 2\n\t"
|
||||
"la x6, g_tensor_seen\n\t"
|
||||
"add x6, x6, x5\n\t"
|
||||
"li x7, %[tensor_base]\n\t"
|
||||
"srli x5, x5, 2\n\t"
|
||||
"or x7, x7, x5\n\t"
|
||||
"sw x7, 0(x6)\n\t"
|
||||
".insn r %[custom0], 0, 0, x0, x0, x0\n\t"
|
||||
"2: j 2b\n\t"
|
||||
:
|
||||
: [csr_wid] "i"(VX_CSR_WARP_ID),
|
||||
[custom0] "i"(RISCV_CUSTOM0),
|
||||
[custom3] "i"(RISCV_CUSTOM3),
|
||||
[bar_tensor] "i"(WU_BARRIER_TENSOR),
|
||||
[domain_tensor] "i"(VX_BARRIER_DOMAIN_TENSOR),
|
||||
[domain_shift] "i"(VX_BARRIER_DOMAIN_SHIFT),
|
||||
[num_tensor] "i"(NUM_TENSOR_WARPS),
|
||||
[tensor_spin] "i"(WU_TENSOR_SPIN),
|
||||
[tensor_base] "i"(WU_STATUS_TENSOR_BASE)
|
||||
: "memory");
|
||||
}
|
||||
|
||||
static void init_state() {
|
||||
g_status[0] = 0;
|
||||
for (uint32_t i = 0; i < MAX_WARPS; ++i) {
|
||||
g_scalar_seen[i] = 0;
|
||||
g_tensor_seen[i] = 0;
|
||||
}
|
||||
|
||||
volatile uint32_t *smem =
|
||||
reinterpret_cast<volatile uint32_t *>(DEV_SMEM_START_ADDR);
|
||||
for (uint32_t i = 0; i < MINIMAL_INIT_WORDS; ++i)
|
||||
smem[i] = 0x100u + i;
|
||||
}
|
||||
|
||||
static int wait_for_wu_completion() {
|
||||
for (uint32_t spin = 0; spin < WU_WAIT_SPIN; ++spin) {
|
||||
uint32_t done = 1;
|
||||
for (uint32_t wid = 0; wid < NUM_SCALAR_WARPS; ++wid)
|
||||
done &= (g_scalar_seen[wid] == (WU_STATUS_SCALAR_BASE | wid));
|
||||
for (uint32_t wid = NUM_SCALAR_WARPS; wid < NUM_WARPS; ++wid)
|
||||
done &= (g_tensor_seen[wid] == (WU_STATUS_TENSOR_BASE | wid));
|
||||
if (done)
|
||||
return 0;
|
||||
}
|
||||
return 1;
|
||||
}
|
||||
|
||||
extern "C" int wu_main() {
|
||||
if (vx_core_id() != 0 || vx_warp_id() != 0)
|
||||
return 0;
|
||||
if (vx_thread_id() != 0)
|
||||
return 0;
|
||||
|
||||
init_state();
|
||||
|
||||
const uint32_t other_scalar_warps = vx_scalar_warp_mask() & ~1u;
|
||||
if (other_scalar_warps != 0)
|
||||
vx_spawn_scalar(other_scalar_warps, scalar_worker);
|
||||
|
||||
vx_spawn_tensor(vx_tensor_warp_mask(), tensor_worker);
|
||||
|
||||
#ifdef WU_RUN_DOMAIN_BARRIERS
|
||||
vx_barrier_scalar(WU_BARRIER_SCALAR, NUM_SCALAR_WARPS);
|
||||
vx_barrier_mask(WU_BARRIER_MASKED, vx_scalar_warp_mask());
|
||||
#endif
|
||||
|
||||
volatile uint32_t mix = 1;
|
||||
for (uint32_t i = 0; i < WU_SCALAR_SPIN; ++i)
|
||||
mix = (mix << 1) ^ (i + 3u);
|
||||
g_spin_sink[0] = mix;
|
||||
g_scalar_seen[0] = WU_STATUS_SCALAR_BASE;
|
||||
|
||||
if (wait_for_wu_completion() != 0) {
|
||||
g_status[0] = 0xe001u;
|
||||
wu_report_tohost(1);
|
||||
return 1;
|
||||
}
|
||||
|
||||
g_status[0] = WU_STATUS_DONE;
|
||||
wu_report_tohost(0);
|
||||
return 0;
|
||||
}
|
||||
9
kernels/wu_arch_hgemm/Makefile
Normal file
9
kernels/wu_arch_hgemm/Makefile
Normal file
@@ -0,0 +1,9 @@
|
||||
PROJECT = wu_arch_hgemm
|
||||
|
||||
VX_SRCS = kernel.cpp
|
||||
OPTS ?= -n1
|
||||
|
||||
include ../common.mk
|
||||
|
||||
args.bin input.a.bin input.b.bin input.c.bin: ../wu_arch_cases/zero.bin
|
||||
cp $< $@
|
||||
8
kernels/wu_arch_hgemm/README.md
Normal file
8
kernels/wu_arch_hgemm/README.md
Normal file
@@ -0,0 +1,8 @@
|
||||
# wu_arch_hgemm
|
||||
|
||||
Tensor-warp HGEMM smoke test for the Wu split scalar/tensor warp configuration.
|
||||
|
||||
Scalar warp 0 initializes the shared-memory B operand, spawns only the tensor
|
||||
warp mask, waits for tensor warps `NUM_SCALAR_WARPS..NUM_WARPS-1`, and reports
|
||||
completion through `tohost`. Tensor warps execute the Blackwell custom HGEMM
|
||||
instruction sequence and then stop themselves.
|
||||
1
kernels/wu_arch_hgemm/args.bin
Normal file
1
kernels/wu_arch_hgemm/args.bin
Normal file
@@ -0,0 +1 @@
|
||||
0
|
||||
1
kernels/wu_arch_hgemm/input.a.bin
Normal file
1
kernels/wu_arch_hgemm/input.a.bin
Normal file
@@ -0,0 +1 @@
|
||||
0
|
||||
1
kernels/wu_arch_hgemm/input.b.bin
Normal file
1
kernels/wu_arch_hgemm/input.b.bin
Normal file
@@ -0,0 +1 @@
|
||||
0
|
||||
1
kernels/wu_arch_hgemm/input.c.bin
Normal file
1
kernels/wu_arch_hgemm/input.c.bin
Normal file
@@ -0,0 +1 @@
|
||||
0
|
||||
87
kernels/wu_arch_hgemm/kernel.cpp
Normal file
87
kernels/wu_arch_hgemm/kernel.cpp
Normal file
@@ -0,0 +1,87 @@
|
||||
#include "../wu_arch_cases/common_wu_min.h"
|
||||
|
||||
#define DEV_SMEM_START_ADDR 0xff000000u
|
||||
#define WU_CASE_TENSOR_HGEMM_BASE 0x7500u
|
||||
|
||||
#define BW_REP2(x) x, x
|
||||
#define BW_REP4(x) BW_REP2(x), BW_REP2(x)
|
||||
#define BW_REP8(x) BW_REP4(x), BW_REP4(x)
|
||||
|
||||
extern "C" {
|
||||
volatile uint32_t g_hgemm_a_row[8] __attribute__((aligned(32))) = {
|
||||
BW_REP8(0x3c003c00u)};
|
||||
volatile uint32_t g_hgemm_b_row[8] __attribute__((aligned(32))) = {
|
||||
BW_REP8(0x40004000u)};
|
||||
volatile uint32_t g_hgemm_c_row[8] __attribute__((aligned(32))) = {
|
||||
BW_REP8(0x3f800000u)};
|
||||
}
|
||||
|
||||
#undef BW_REP2
|
||||
#undef BW_REP4
|
||||
#undef BW_REP8
|
||||
|
||||
extern "C" void __attribute__((naked, noinline, used)) tensor_hgemm_worker() {
|
||||
asm volatile(
|
||||
"csrr x5, %[csr_wid]\n\t"
|
||||
"slli x1, x5, 11\n\t"
|
||||
"addi x2, x1, 1024\n\t"
|
||||
"la x6, g_hgemm_a_row\n\t"
|
||||
"la x3, g_hgemm_c_row\n\t"
|
||||
"li x7, 0\n\t"
|
||||
"1:\n\t"
|
||||
"add x4, x1, x7\n\t"
|
||||
".insn r %[custom3], 2, 0, x0, x4, x6\n\t"
|
||||
"add x4, x2, x7\n\t"
|
||||
".insn r %[custom3], 2, 0, x0, x4, x3\n\t"
|
||||
"addi x7, x7, 32\n\t"
|
||||
"li x4, 1024\n\t"
|
||||
"blt x7, x4, 1b\n\t"
|
||||
".insn r %[custom3], 3, 0, x0, x0, x0\n\t"
|
||||
"li x4, %[smem_base]\n\t"
|
||||
".insn r %[custom3], 0, 0, x2, x1, x4\n\t"
|
||||
".insn r %[custom3], 1, 0, x0, x0, x0\n\t"
|
||||
"csrr x5, %[csr_wid]\n\t"
|
||||
"slli x6, x5, 2\n\t"
|
||||
"la x7, g_seen\n\t"
|
||||
"add x7, x7, x6\n\t"
|
||||
"li x6, %[hgemm_base]\n\t"
|
||||
"or x6, x6, x5\n\t"
|
||||
"sw x6, 0(x7)\n\t"
|
||||
".insn r %[custom0], 0, 0, x0, x0, x0\n\t"
|
||||
"2: j 2b\n\t"
|
||||
:
|
||||
: [csr_wid] "i"(VX_CSR_WARP_ID),
|
||||
[custom0] "i"(RISCV_CUSTOM0),
|
||||
[custom3] "i"(RISCV_CUSTOM3),
|
||||
[smem_base] "i"(DEV_SMEM_START_ADDR),
|
||||
[hgemm_base] "i"(WU_CASE_TENSOR_HGEMM_BASE)
|
||||
: "memory");
|
||||
}
|
||||
|
||||
extern "C" int wu_main() {
|
||||
if (!wu_is_leader()) {
|
||||
return 0;
|
||||
}
|
||||
|
||||
wu_case_reset();
|
||||
|
||||
volatile uint32_t *smem_b =
|
||||
reinterpret_cast<volatile uint32_t *>(DEV_SMEM_START_ADDR);
|
||||
for (uint32_t frag = 0; frag < 32u; ++frag) {
|
||||
const uint32_t row = frag * 8u;
|
||||
for (uint32_t i = 0; i < 8u; ++i) {
|
||||
smem_b[row + i] = g_hgemm_b_row[i];
|
||||
}
|
||||
}
|
||||
|
||||
vx_spawn_tensor(vx_tensor_warp_mask(), tensor_hgemm_worker);
|
||||
|
||||
if (wu_wait_seen_range(NUM_SCALAR_WARPS, NUM_WARPS,
|
||||
WU_CASE_TENSOR_HGEMM_BASE) != 0) {
|
||||
wu_case_fail(0x09u);
|
||||
return 1;
|
||||
}
|
||||
|
||||
wu_case_pass();
|
||||
return 0;
|
||||
}
|
||||
Reference in New Issue
Block a user