Update Wu architecture kernel implementations and runtime library
This commit is contained in:
@@ -84,15 +84,32 @@
|
||||
#endif
|
||||
|
||||
// Default value for NUM_CORES; applies only when the build has not already
// defined it. A single definition is kept inside the guard (value 1) so the
// macro is never redefined with a conflicting value.
#ifndef NUM_CORES
#define NUM_CORES 1
#endif
|
||||
|
||||
// Default value for NUM_WARPS; applies only when the build has not already
// defined it. A single definition is kept inside the guard (value 4) so the
// macro is never redefined with a conflicting value.
#ifndef NUM_WARPS
#define NUM_WARPS 4
#endif
|
||||
|
||||
// Default value for NUM_TENSOR_WARPS (2); skipped when the build has already
// supplied its own definition.
#if !defined(NUM_TENSOR_WARPS)
#define NUM_TENSOR_WARPS 2
#endif
|
||||
|
||||
// Warp ids are partitioned by value: ids below NUM_SCALAR_WARPS are scalar
// warps, ids at or above it are tensor warps.
// Both operands are parenthesized so that expression-valued overrides
// (e.g. -DNUM_TENSOR_WARPS=1+1) still subtract as a whole.
#define NUM_SCALAR_WARPS ((NUM_WARPS) - (NUM_TENSOR_WARPS))

// Non-zero when warp id `wid` is in the scalar group.
#define IS_SCALAR_WARP(wid) ((wid) < NUM_SCALAR_WARPS)

// Non-zero when warp id `wid` is in the tensor group.
#define IS_TENSOR_WARP(wid) ((wid) >= NUM_SCALAR_WARPS)
|
||||
|
||||
// Default value for TENSOR_NUM_GPRS (8); skipped when already defined.
// NOTE(review): presumably the general-purpose register count for tensor
// warps -- confirm against the hardware configuration.
#if !defined(TENSOR_NUM_GPRS)
#define TENSOR_NUM_GPRS 8
#endif
|
||||
|
||||
// Default value for TENSOR_NUM_FPRS (8); skipped when already defined.
// NOTE(review): presumably the floating-point register count for tensor
// warps -- confirm against the hardware configuration.
#if !defined(TENSOR_NUM_FPRS)
#define TENSOR_NUM_FPRS 8
#endif
|
||||
|
||||
// Default value for NUM_THREADS; applies only when the build has not already
// defined it. A single definition is kept inside the guard (value 4) so the
// macro is never redefined with a conflicting value.
#ifndef NUM_THREADS
#define NUM_THREADS 4
#endif
|
||||
|
||||
#ifndef NUM_BARRIERS
|
||||
@@ -682,4 +699,3 @@
|
||||
#define IMPLEMENTATION_ID 0
|
||||
|
||||
#endif // VX_CONFIG_VH
|
||||
|
||||
|
||||
@@ -136,6 +136,19 @@ inline void vx_wspawn(unsigned num_warps, vx_wspawn_pfn func_ptr) {
|
||||
asm volatile (".insn r %0, 1, 0, x0, %1, %2" :: "i"(RISCV_CUSTOM0), "r"(num_warps), "r"(func_ptr));
|
||||
}
|
||||
|
||||
// Spawn an explicit warp mask. The current warp bit is ignored by hardware.
|
||||
inline void vx_wspawn_mask(unsigned warp_mask, vx_wspawn_pfn func_ptr) {
|
||||
asm volatile (".insn r %0, 6, 0, x0, %1, %2" :: "i"(RISCV_CUSTOM0), "r"(warp_mask), "r"(func_ptr));
|
||||
}
|
||||
|
||||
// Spawn only scalar warps: bits of `warp_mask` at positions
// NUM_SCALAR_WARPS and above are cleared before calling vx_wspawn_mask().
//   warp_mask: requested warp mask
//   func_ptr:  entry point forwarded unchanged
inline void vx_spawn_scalar(unsigned warp_mask, vx_wspawn_pfn func_ptr) {
  // Low NUM_SCALAR_WARPS bits set.
  const unsigned scalar_bits = (1u << NUM_SCALAR_WARPS) - 1u;
  vx_wspawn_mask(warp_mask & scalar_bits, func_ptr);
}
|
||||
|
||||
// Spawn only tensor warps: `warp_mask` is restricted to the NUM_TENSOR_WARPS
// bits that start at bit position NUM_SCALAR_WARPS before calling
// vx_wspawn_mask().
//   warp_mask: requested warp mask
//   func_ptr:  entry point forwarded unchanged
inline void vx_spawn_tensor(unsigned warp_mask, vx_wspawn_pfn func_ptr) {
  // NUM_TENSOR_WARPS consecutive bits, shifted above the scalar range.
  const unsigned low_bits = (1u << NUM_TENSOR_WARPS) - 1u;
  const unsigned tensor_bits = low_bits << NUM_SCALAR_WARPS;
  vx_wspawn_mask(warp_mask & tensor_bits, func_ptr);
}
|
||||
|
||||
// Split on a predicate
|
||||
inline unsigned vx_split(unsigned predicate) {
|
||||
unsigned ret;
|
||||
@@ -151,7 +164,34 @@ inline void vx_join(unsigned stack_ptr) {
|
||||
// Warp Barrier
|
||||
__attribute__((convergent))
|
||||
inline void vx_barrier(unsigned barried_id, unsigned num_warps) {
|
||||
asm volatile (".insn r %0, 4, 0, x0, %1, %2" :: "i"(RISCV_CUSTOM0), "r"(barried_id), "r"(num_warps));
|
||||
unsigned scalar_warps = (num_warps > NUM_SCALAR_WARPS) ? NUM_SCALAR_WARPS : num_warps;
|
||||
asm volatile (".insn r %0, 4, 0, x0, %1, %2" :: "i"(RISCV_CUSTOM0), "r"(barried_id), "r"(scalar_warps));
|
||||
}
|
||||
|
||||
// Barrier domain encoding: vx_barrier_domain() ORs (domain << SHIFT) into the
// barrier id operand.
#define VX_BARRIER_DOMAIN_SHIFT  28

// Domain selector values.
#define VX_BARRIER_DOMAIN_ALL    0u
#define VX_BARRIER_DOMAIN_SCALAR 1u
#define VX_BARRIER_DOMAIN_TENSOR 2u
|
||||
|
||||
__attribute__((convergent))
|
||||
inline void vx_barrier_domain(unsigned barrier_id, unsigned num_warps, unsigned domain) {
|
||||
unsigned encoded_id = barrier_id | (domain << VX_BARRIER_DOMAIN_SHIFT);
|
||||
asm volatile (".insn r %0, 4, 0, x0, %1, %2" :: "i"(RISCV_CUSTOM0), "r"(encoded_id), "r"(num_warps));
|
||||
}
|
||||
|
||||
__attribute__((convergent))
|
||||
inline void vx_barrier_scalar(unsigned barrier_id, unsigned num_warps) {
|
||||
vx_barrier_domain(barrier_id, num_warps, VX_BARRIER_DOMAIN_SCALAR);
|
||||
}
|
||||
|
||||
__attribute__((convergent))
|
||||
inline void vx_barrier_tensor(unsigned barrier_id, unsigned num_warps) {
|
||||
vx_barrier_domain(barrier_id, num_warps, VX_BARRIER_DOMAIN_TENSOR);
|
||||
}
|
||||
|
||||
__attribute__((convergent))
|
||||
inline void vx_barrier_mask(unsigned barrier_id, unsigned warp_mask) {
|
||||
asm volatile (".insn r %0, 7, 0, x0, %1, %2" :: "i"(RISCV_CUSTOM0), "r"(barrier_id), "r"(warp_mask));
|
||||
}
|
||||
|
||||
// Return current thread identifier
|
||||
@@ -203,6 +243,22 @@ inline int vx_num_warps() {
|
||||
return ret;
|
||||
}
|
||||
|
||||
inline int vx_num_scalar_warps() {
|
||||
return NUM_SCALAR_WARPS;
|
||||
}
|
||||
|
||||
inline int vx_num_tensor_warps() {
|
||||
return NUM_TENSOR_WARPS;
|
||||
}
|
||||
|
||||
inline unsigned vx_scalar_warp_mask() {
|
||||
return (1u << NUM_SCALAR_WARPS) - 1u;
|
||||
}
|
||||
|
||||
inline unsigned vx_tensor_warp_mask() {
|
||||
return ((1u << NUM_TENSOR_WARPS) - 1u) << NUM_SCALAR_WARPS;
|
||||
}
|
||||
|
||||
// Return the number of cores per cluster
|
||||
inline int vx_num_cores() {
|
||||
int ret;
|
||||
|
||||
Reference in New Issue
Block a user