Merge branch 'kernels' into tensor_core
This commit is contained in:
@@ -51,10 +51,10 @@ $(PROJECT).dump: $(PROJECT).a
|
||||
%.S.o: src/%.S
|
||||
$(CC) $(CFLAGS) -c $< -o $@
|
||||
|
||||
%.cpp.o: src/%.cpp
|
||||
%.cpp.o: src/%.cpp include/vx_spawn.h
|
||||
$(CXX) $(CFLAGS) -c $< -o $@
|
||||
|
||||
%.c.o: src/%.c
|
||||
%.c.o: src/%.c include/vx_spawn.h
|
||||
$(CC) $(CFLAGS) -c $< -o $@
|
||||
|
||||
$(PROJECT).a: $(OBJS)
|
||||
|
||||
164
kernel/include/gemmini_mmio.h
Normal file
164
kernel/include/gemmini_mmio.h
Normal file
@@ -0,0 +1,164 @@
|
||||
#ifndef GEMMINI_MMIO_H
|
||||
#define GEMMINI_MMIO_H
|
||||
#ifndef GEMMINI_PARAMS_H
|
||||
#error INCLUDE GEMMINI.H FIRST
|
||||
#endif
|
||||
|
||||
#define SMEM_BASE 0xff000000
|
||||
#define SMEM_SIZE 0x4000
|
||||
#define SMEM_MASK (SMEM_SIZE - 1)
|
||||
#define SMEM_ADDR_END 0xff008000
|
||||
|
||||
#define SPAD_BASE 0x0
|
||||
#define SPAD_ROW_SIZE (DIM * sizeof(elem_t))
|
||||
#define SPAD_NUM_ROWS (SMEM_SIZE / SPAD_ROW_SIZE)
|
||||
#define SPAD_MASK (SPAD_NUM_ROWS - 1)
|
||||
|
||||
#define PRINT_BUF ((char *) (SMEM_ADDR_END))
|
||||
#define GEMMINI_RS1_ADDR 0xff007010
|
||||
#define GEMMINI_RS2_ADDR 0xff007018
|
||||
#define GEMMINI_INST_ADDR 0xff007000
|
||||
#define GEMMINI_BUSY_ADDR 0xff007020
|
||||
|
||||
#define SMEM_TO_SPAD(smem_addr) (SPAD_BASE + ((smem_addr) & SMEM_MASK) / SPAD_ROW_SIZE)
|
||||
#define SPAD_TO_SMEM(spad_addr) (SMEM_BASE + ((spad_addr) & SPAD_MASK) * SPAD_ROW_SIZE)
|
||||
|
||||
// convert normal matrix i,j into tiled smem offset
|
||||
// top_in_tiles = i / DIM
|
||||
// left_in_tiles = j / DIM
|
||||
// num_tiles_before_current = top_in_tiles * (J / DIM) + left_in_tiles
|
||||
// smem_addr = num_tiles_before_current * DIM * DIM + (i % DIM) * DIM + (j % DIM)
|
||||
#define SMEM_MAT_OFFSET(i, j, J) \
|
||||
(((i) / DIM * (J) / DIM + (j) / DIM) * DIM * DIM + ((i) % DIM) * DIM + ((j) % DIM))
|
||||
|
||||
// #define fence() { for (int i = 0; i < 10; i++) *((volatile uint32_t *) (0xFFFF0000)) = 0xdeadbeef; }
|
||||
#undef gemmini_fence
|
||||
#define gemmini_fence() { while (*((volatile uint32_t *) GEMMINI_BUSY_ADDR)) asm volatile ("nop"); }
|
||||
|
||||
#undef ROCC_INSTRUCTION_RS1_RS2
|
||||
#define ROCC_INSTRUCTION_RS1_RS2(x, rs1, rs2, funct) { \
|
||||
/* printf("function %d\n", funct); */ \
|
||||
*((volatile uint64_t *) GEMMINI_RS1_ADDR) = (rs1); \
|
||||
*((volatile uint64_t *) GEMMINI_RS2_ADDR) = (rs2); \
|
||||
/* *((volatile uint32_t*) GEMMINI_RS2_ADDR) = (uint32_t) ((uint64_t) (rs2) & 0xFFFFFFFFULL); */ \
|
||||
/* *((volatile uint32_t*) (GEMMINI_RS2_ADDR + 4)) = (uint32_t) ((uint64_t) (rs2) >> 32); */ \
|
||||
/* gemmini_fence(); */ \
|
||||
*((volatile uint32_t*) GEMMINI_INST_ADDR) = (0x7B) | (0 << 7) | (3 << 12) | (1 << 15) | (2 << 20) | ((funct) << 25); \
|
||||
/* sprintf((char *) PRINT_BUF, "%llx %llx %d\n", rs1, rs2, funct); */ \
|
||||
}
|
||||
|
||||
#define sp_tiled_matmul_full_spad_ws(A_sp_addr_start, B_sp_addr_start, D_sp_addr_start, C_dst_sp_addr_start,\
|
||||
I, J, K, pad_I, pad_J, pad_K, a_transpose, b_transpose, full_C, low_D, acc, act, skips) \
|
||||
gemmini_loop_ws_spad(I, J, K, pad_I, pad_J, pad_K, A_sp_addr_start, (B_sp_addr_start) + (K) * (J) * DIM, NULL, \
|
||||
C_dst_sp_addr_start, a_transpose, b_transpose, full_C, low_D, acc, act, 0, 0, false, skips)
|
||||
|
||||
/* inline static void sp_tiled_matmul_full_spad_ws(const uint32_t A_sp_addr_start, const uint32_t B_sp_addr_start,
|
||||
const uint32_t D_sp_addr_start, const uint32_t C_dst_sp_addr_start,
|
||||
size_t I, size_t J, size_t K, size_t pad_I, size_t pad_J, size_t pad_K,
|
||||
bool a_transpose, bool b_transpose,
|
||||
bool full_C, bool low_D, bool acc,
|
||||
int act, int skip_mvout) {
|
||||
|
||||
gemmini_loop_ws_spad(I, J, K, pad_I, pad_J, pad_K,
|
||||
A_sp_addr_start, B_sp_addr_start + K * J * DIM, NULL, C_dst_sp_addr_start,
|
||||
a_transpose, b_transpose,
|
||||
full_C, low_D, acc,
|
||||
act, 0, 0, false, skip_mvout); */
|
||||
/*
|
||||
return;
|
||||
|
||||
|
||||
// const uint32_t A_sp_addr_start = 0;
|
||||
// const uint32_t B_sp_addr_start = BANK_NUM * BANK_ROWS - K * J * DIM;
|
||||
// const uint32_t D_sp_addr_start = 1 << (ADDR_LEN-1);
|
||||
const uint32_t C_sp_addr_start = 2 << (ADDR_LEN-2) | (full_C << (ADDR_LEN-3));
|
||||
// const int D_blocks = low_D ? (J <= MAX_BLOCK_LEN ? J : MAX_BLOCK_LEN) :
|
||||
// (J <= MAX_BLOCK_LEN_ACC ? J : MAX_BLOCK_LEN_ACC);
|
||||
const int C_blocks = 1; //full_C ? 1 : (J <= MAX_BLOCK_LEN ? J : MAX_BLOCK_LEN);
|
||||
// const size_t sizeof_D = low_D ? sizeof(elem_t) : sizeof(acc_t);
|
||||
const size_t sizeof_C = full_C ? sizeof(acc_t) : sizeof(elem_t);
|
||||
gemmini_fence();
|
||||
|
||||
if (a_transpose || b_transpose || (I < 4)) {
|
||||
for (size_t k = 0; k < K; k++) {
|
||||
for (size_t j = 0; j < J; j++) {
|
||||
for (size_t i = 0; i < I; i++) {
|
||||
const uint32_t A_sp_addr = a_transpose ? (A_sp_addr_start + (k*I + i)*DIM) :
|
||||
(A_sp_addr_start + (i*K + k)*DIM);
|
||||
const uint32_t B_sp_addr = b_transpose ? (B_sp_addr_start + (j*K + k)*DIM) :
|
||||
(B_sp_addr_start + (k*J + j)*DIM);
|
||||
const uint32_t C_sp_addr = C_sp_addr_start + (i*J + j)*DIM;
|
||||
// Compute
|
||||
uint32_t pre_sp_addr = i == 0 ? B_sp_addr : GARBAGE_ADDR;
|
||||
uint32_t out_sp_addr = C_sp_addr | ((k == 0 ? 0 : 1) << (ADDR_LEN-2));
|
||||
gemmini_extended_preload(pre_sp_addr, out_sp_addr, DIM, DIM, DIM, DIM);
|
||||
if (i == 0) { // First iteration
|
||||
gemmini_extended_compute_preloaded(A_sp_addr, GARBAGE_ADDR, DIM, DIM, DIM, DIM);
|
||||
} else { // All other iterations
|
||||
gemmini_extended_compute_accumulated(A_sp_addr, GARBAGE_ADDR, DIM, DIM, DIM, DIM);
|
||||
}
|
||||
if (k == K - 1) {
|
||||
// Move-out C (if not normalizing)
|
||||
// if (((act != LAYERNORM) && (act != SOFTMAX)) && (j == J-1 || j % C_blocks == C_blocks-1)) {
|
||||
const size_t rounded_j = j; // (j / C_blocks) * C_blocks;
|
||||
const uint32_t rounded_C_sp_addr = C_sp_addr; // C_sp_addr_start + (i*J + rounded_j)*DIM;
|
||||
|
||||
const uint32_t C_dst_sp_addr = ((uint32_t) C_dst_sp_addr_start) + (i * J + rounded_j) * DIM; // * DIM * sizeof_C;
|
||||
|
||||
// const size_t blocks = rounded_j + C_blocks <= J ? C_blocks : J-rounded_j;
|
||||
constexpr size_t cols = DIM; // blocks * DIM - (rounded_j + blocks >= J ? pad_J : 0);
|
||||
constexpr size_t rows = DIM; // DIM - (i == I - 1 ? pad_I : 0);
|
||||
|
||||
gemmini_extended_mvout_spad(C_dst_sp_addr, 1, rounded_C_sp_addr, cols, rows);
|
||||
// }
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
for (size_t k = 0; k < K; k++) {
|
||||
for (size_t j = 0; j < J; j++) {
|
||||
uint32_t A_sp_addr = A_sp_addr_start + k * DIM; // (i*K + k)*DIM;
|
||||
const uint32_t B_sp_addr = B_sp_addr_start + (k*J + j)*DIM;
|
||||
uint32_t C_sp_addr = C_sp_addr_start + j * DIM; // (i*J + j)*DIM;
|
||||
for (size_t i = 0; i < I; i += 4) {
|
||||
// Compute
|
||||
// constexpr uint32_t pre_sp_addr = i == 0 ? B_sp_addr : GARBAGE_ADDR;
|
||||
const uint32_t out_sp_addr = C_sp_addr | ((k == 0 ? 0 : 1) << (ADDR_LEN-2));
|
||||
if (i == 0) { // First iteration
|
||||
gemmini_extended_preload(B_sp_addr, out_sp_addr, DIM, DIM, DIM, DIM);
|
||||
gemmini_extended_compute_preloaded(A_sp_addr, GARBAGE_ADDR, DIM, DIM, DIM, DIM);
|
||||
gemmini_extended_preload(GARBAGE_ADDR, out_sp_addr + J * DIM, DIM, DIM, DIM, DIM);
|
||||
gemmini_extended_compute_accumulated(A_sp_addr + K * DIM, GARBAGE_ADDR, DIM, DIM, DIM, DIM);
|
||||
gemmini_extended_preload(GARBAGE_ADDR, out_sp_addr + 2 * J * DIM, DIM, DIM, DIM, DIM);
|
||||
gemmini_extended_compute_accumulated(A_sp_addr + 2 * K * DIM, GARBAGE_ADDR, DIM, DIM, DIM, DIM);
|
||||
gemmini_extended_preload(GARBAGE_ADDR, out_sp_addr + 3 * J * DIM, DIM, DIM, DIM, DIM);
|
||||
gemmini_extended_compute_accumulated(A_sp_addr + 3 * K * DIM, GARBAGE_ADDR, DIM, DIM, DIM, DIM);
|
||||
} else { // All other iterations
|
||||
gemmini_extended_preload(GARBAGE_ADDR, out_sp_addr, DIM, DIM, DIM, DIM);
|
||||
gemmini_extended_compute_accumulated(A_sp_addr, GARBAGE_ADDR, DIM, DIM, DIM, DIM);
|
||||
gemmini_extended_preload(GARBAGE_ADDR, out_sp_addr + J * DIM, DIM, DIM, DIM, DIM);
|
||||
gemmini_extended_compute_accumulated(A_sp_addr + K * DIM, GARBAGE_ADDR, DIM, DIM, DIM, DIM);
|
||||
gemmini_extended_preload(GARBAGE_ADDR, out_sp_addr + 2 * J * DIM, DIM, DIM, DIM, DIM);
|
||||
gemmini_extended_compute_accumulated(A_sp_addr + 2 * K * DIM, GARBAGE_ADDR, DIM, DIM, DIM, DIM);
|
||||
gemmini_extended_preload(GARBAGE_ADDR, out_sp_addr + 3 * J * DIM, DIM, DIM, DIM, DIM);
|
||||
gemmini_extended_compute_accumulated(A_sp_addr + 3 * K * DIM, GARBAGE_ADDR, DIM, DIM, DIM, DIM);
|
||||
}
|
||||
if (k == K - 1) {
|
||||
for (int x = 0; x < 3; x++) gemmini_fence();
|
||||
gemmini_extended_mvout_spad((uint32_t) C_dst_sp_addr_start + (i * J + j) * DIM, 1, C_sp_addr, DIM, DIM);
|
||||
gemmini_extended_mvout_spad((uint32_t) C_dst_sp_addr_start + ((i + 1) * J + j) * DIM, 1, C_sp_addr + J * DIM, DIM, DIM);
|
||||
gemmini_extended_mvout_spad((uint32_t) C_dst_sp_addr_start + ((i + 2) * J + j) * DIM, 1, C_sp_addr + 2 * J * DIM, DIM, DIM);
|
||||
gemmini_extended_mvout_spad((uint32_t) C_dst_sp_addr_start + ((i + 3) * J + j) * DIM, 1, C_sp_addr + 3 * J * DIM, DIM, DIM);
|
||||
}
|
||||
A_sp_addr += 4 * K * DIM;
|
||||
C_sp_addr += 4 * J * DIM;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
gemmini_fence();
|
||||
}*/
|
||||
|
||||
|
||||
#endif
|
||||
@@ -17,6 +17,10 @@
|
||||
#include <stdint.h>
|
||||
#include <stdio.h>
|
||||
|
||||
#ifndef CORES_PER_CLUSTER
|
||||
#define CORES_PER_CLUSTER 2
|
||||
#endif
|
||||
|
||||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
#endif
|
||||
@@ -48,6 +52,7 @@ void vx_wspawn_wait();
|
||||
void vx_spawn_kernel(context_t * ctx, vx_spawn_kernel_cb callback, void * arg);
|
||||
|
||||
void vx_spawn_tasks(int num_tasks, vx_spawn_tasks_cb callback, void * arg);
|
||||
void vx_spawn_tasks_cluster(int num_tasks, vx_spawn_tasks_cb callback, void * arg);
|
||||
void vx_spawn_tasks_contiguous(int num_tasks, vx_spawn_tasks_cb callback , void * arg);
|
||||
|
||||
void vx_serial(vx_serial_cb callback, void * arg);
|
||||
|
||||
@@ -7,6 +7,13 @@ OUTPUT_FORMAT("elf32-littleriscv", "elf32-littleriscv",
|
||||
"elf32-littleriscv")
|
||||
OUTPUT_ARCH(riscv)
|
||||
ENTRY(_start)
|
||||
|
||||
MEMORY {
|
||||
DRAM0 (rwx): ORIGIN = 0x80000000, LENGTH = 512M
|
||||
DRAM1 (rwx): ORIGIN = 0xa0000000, LENGTH = 32K
|
||||
DRAM2 (rwx): ORIGIN = 0xa1000000, LENGTH = 32K
|
||||
}
|
||||
|
||||
SECTIONS
|
||||
{
|
||||
. = STARTUP_ADDR;
|
||||
@@ -85,6 +92,7 @@ SECTIONS
|
||||
/* Adjust the address for the data segment. We want to adjust up to
|
||||
the same address within the page on the next page up. */
|
||||
. = DATA_SEGMENT_ALIGN (CONSTANT (MAXPAGESIZE), CONSTANT (COMMONPAGESIZE));
|
||||
|
||||
/* Exception handling */
|
||||
.eh_frame : ONLY_IF_RW { KEEP (*(.eh_frame)) *(.eh_frame.*) }
|
||||
.gnu_extab : ONLY_IF_RW { *(.gnu_extab) }
|
||||
@@ -166,6 +174,7 @@ SECTIONS
|
||||
*(.data .data.* .gnu.linkonce.d.*)
|
||||
SORT(CONSTRUCTORS)
|
||||
}
|
||||
|
||||
.data1 : { *(.data1) }
|
||||
.got : { *(.got.plt) *(.igot.plt) *(.got) *(.igot) }
|
||||
/* We want the small data sections together, so single-instruction offsets
|
||||
@@ -200,6 +209,7 @@ SECTIONS
|
||||
}
|
||||
. = ALIGN(32 / 8);
|
||||
. = SEGMENT_START("ldata-segment", .);
|
||||
|
||||
. = ALIGN(32 / 8);
|
||||
__BSS_END__ = .;
|
||||
__global_pointer = MIN(__SDATA_BEGIN__ + 0x800,
|
||||
@@ -249,4 +259,12 @@ SECTIONS
|
||||
.gnu.attributes 0 : { KEEP (*(.gnu.attributes)) }
|
||||
/DISCARD/ : { *(.note.GNU-stack) *(.gnu_debuglink) *(.gnu.lto_*) }
|
||||
|
||||
.operand.a : {
|
||||
*(.operand.a)
|
||||
. += 32K;
|
||||
}> DRAM1
|
||||
.operand.b : {
|
||||
*(.operand.b)
|
||||
. += 32K;
|
||||
}> DRAM2
|
||||
}
|
||||
|
||||
@@ -74,15 +74,6 @@ static void __attribute__ ((noinline)) spawn_tasks_all_stub() {
|
||||
}
|
||||
}
|
||||
|
||||
static void __attribute__ ((noinline)) spawn_tasks_rem_stub() {
|
||||
int cid = vx_core_id();
|
||||
int tid = vx_thread_id();
|
||||
|
||||
wspawn_tasks_args_t* p_wspawn_args = (wspawn_tasks_args_t*)g_wspawn_args[cid];
|
||||
int task_id = p_wspawn_args->offset + tid;
|
||||
(p_wspawn_args->callback)(task_id, p_wspawn_args->arg);
|
||||
}
|
||||
|
||||
static void __attribute__ ((noinline)) spawn_tasks_contiguous_all_stub() {
|
||||
int NT = vx_num_threads();
|
||||
int NW = vx_num_warps();
|
||||
@@ -103,6 +94,60 @@ static void __attribute__ ((noinline)) spawn_tasks_contiguous_all_stub() {
|
||||
}
|
||||
}
|
||||
|
||||
static void __attribute__ ((noinline)) spawn_tasks_cluster_all_stub() {
|
||||
int NT = vx_num_threads();
|
||||
int NW = vx_num_warps();
|
||||
int cid = vx_core_id();
|
||||
int wid = vx_warp_id();
|
||||
int tid = vx_thread_id();
|
||||
|
||||
const int core_id_in_cluster = cid % CORES_PER_CLUSTER;
|
||||
// round-robin warp_id allocation across cores in cluster
|
||||
const int wid_in_cluster = CORES_PER_CLUSTER * wid + core_id_in_cluster;
|
||||
|
||||
wspawn_tasks_args_t* p_wspawn_args = (wspawn_tasks_args_t*)g_wspawn_args[cid];
|
||||
|
||||
int waves = p_wspawn_args->NWs + (wid < p_wspawn_args->RWs);
|
||||
int offset = p_wspawn_args->offset + (NT * wid_in_cluster + tid);
|
||||
|
||||
vx_spawn_tasks_cb callback = p_wspawn_args->callback;
|
||||
void* arg = p_wspawn_args->arg;
|
||||
|
||||
// sequential iterations
|
||||
for (int wave_id = 0; wave_id < waves; ++wave_id) {
|
||||
int task_id = offset + (wave_id * NT * NW * CORES_PER_CLUSTER);
|
||||
callback(task_id, arg);
|
||||
}
|
||||
}
|
||||
|
||||
static void __attribute__ ((noinline)) spawn_tasks_rem_stub() {
|
||||
int cid = vx_core_id();
|
||||
int tid = vx_thread_id();
|
||||
|
||||
wspawn_tasks_args_t* p_wspawn_args = (wspawn_tasks_args_t*)g_wspawn_args[cid];
|
||||
int task_id = p_wspawn_args->offset + tid;
|
||||
(p_wspawn_args->callback)(task_id, p_wspawn_args->arg);
|
||||
}
|
||||
|
||||
static void __attribute__ ((noinline)) spawn_tasks_cluster_rem_stub() {
|
||||
int NT = vx_num_threads();
|
||||
int cid = vx_core_id();
|
||||
int tid = vx_thread_id();
|
||||
int wid = vx_warp_id();
|
||||
|
||||
const int core_id_in_cluster = cid % CORES_PER_CLUSTER;
|
||||
// round-robin warp_id allocation across cores in cluster
|
||||
const int wid_in_cluster = CORES_PER_CLUSTER * wid + core_id_in_cluster;
|
||||
|
||||
wspawn_tasks_args_t* p_wspawn_args = (wspawn_tasks_args_t*)g_wspawn_args[cid];
|
||||
// FIXME: This assumes that all cores but the last one are working with full
|
||||
// warps, and only the last core has a partially-filled warp.
|
||||
int offset = p_wspawn_args->offset + (NT * wid_in_cluster + tid);
|
||||
|
||||
int task_id = offset;
|
||||
(p_wspawn_args->callback)(task_id, p_wspawn_args->arg);
|
||||
}
|
||||
|
||||
static void __attribute__ ((noinline)) spawn_tasks_contiguous_all_cb() {
|
||||
// activate all threads
|
||||
vx_tmc(-1);
|
||||
@@ -111,11 +156,21 @@ static void __attribute__ ((noinline)) spawn_tasks_contiguous_all_cb() {
|
||||
spawn_tasks_contiguous_all_stub();
|
||||
|
||||
// disable warp
|
||||
// deadlock here on warps 1, 2, 3
|
||||
vx_tmc_zero();
|
||||
}
|
||||
|
||||
static void __attribute__ ((noinline)) spawn_tasks_all_cb() {
|
||||
static void __attribute__ ((noinline)) spawn_tasks_cluster_all_cb() {
|
||||
// activate all threads
|
||||
vx_tmc(-1);
|
||||
|
||||
// call stub routine
|
||||
spawn_tasks_cluster_all_stub();
|
||||
|
||||
// disable warp
|
||||
vx_tmc_zero();
|
||||
}
|
||||
|
||||
static void __attribute__ ((noinline)) spawn_tasks_all_cb() {
|
||||
// activate all threads
|
||||
vx_tmc(-1);
|
||||
|
||||
@@ -126,6 +181,98 @@ static void __attribute__ ((noinline)) spawn_tasks_all_cb() {
|
||||
vx_tmc_zero();
|
||||
}
|
||||
|
||||
// This function runs in every core, but with only 1 warp and 1 thread enabled.
|
||||
// The logic in this function figures out how many warps/threads this particular
|
||||
// core has to enable to fulfill an entire grid of computation.
|
||||
void vx_spawn_tasks_cluster(int num_tasks, vx_spawn_tasks_cb callback, void *arg) {
|
||||
// device specs
|
||||
const int NC = vx_num_cores();
|
||||
const int NW = vx_num_warps();
|
||||
const int NT = vx_num_threads();
|
||||
// NOTE: assumes divisible
|
||||
const int num_cluster = NC / CORES_PER_CLUSTER;
|
||||
|
||||
// current core id
|
||||
int core_id = vx_core_id();
|
||||
if (core_id >= NUM_CORES_MAX)
|
||||
return;
|
||||
const int cluster_id = core_id / CORES_PER_CLUSTER;
|
||||
const int core_id_in_cluster = core_id % CORES_PER_CLUSTER;
|
||||
|
||||
// Distribute threads equally across as many cores as possible, even if they
|
||||
// don't fill up NW*NT in a single core. This makes sure the warps get evenly
|
||||
// distributed in a single cluster
|
||||
//
|
||||
// TODO: Try to contain in a single cluster if possible?
|
||||
const int num_active_cores = (num_tasks + (NT - 1)) / NT;
|
||||
if (core_id >= num_active_cores)
|
||||
return; // terminate extra cores
|
||||
|
||||
// FIXME: assumes num_tasks is divisible by num_cluster
|
||||
const int num_tasks_this_cluster = num_tasks / num_cluster;
|
||||
const int num_full_warps = num_tasks_this_cluster / NT;
|
||||
const int rem_threads_in_last_warp = num_tasks_this_cluster % NT;
|
||||
// const int num_warps = (num_tasks_this_cluster + (NT - 1)) / NT;
|
||||
|
||||
int num_warps_this_core = num_full_warps / CORES_PER_CLUSTER;
|
||||
const int num_warps_in_last_row = num_full_warps % CORES_PER_CLUSTER;
|
||||
if (core_id_in_cluster < num_warps_in_last_row) {
|
||||
num_warps_this_core++;
|
||||
}
|
||||
// if 0, last warp is full-threads enabled
|
||||
int rem_threads_in_last_warp_this_core = 0;
|
||||
if (rem_threads_in_last_warp != 0) {
|
||||
if (core_id_in_cluster == num_warps_in_last_row - 1) {
|
||||
rem_threads_in_last_warp_this_core = rem_threads_in_last_warp;
|
||||
}
|
||||
}
|
||||
|
||||
// sequential iterations
|
||||
const int num_full_waves = num_warps_this_core / NW;
|
||||
const int rem_full_warps_in_last_wave = num_warps_this_core % NW;
|
||||
|
||||
const const int offset = cluster_id * num_tasks_this_cluster;
|
||||
wspawn_tasks_args_t wspawn_args = {callback, arg, offset, num_full_waves,
|
||||
rem_full_warps_in_last_wave};
|
||||
g_wspawn_args[core_id] = &wspawn_args;
|
||||
|
||||
if (num_warps_this_core > 0) {
|
||||
// execute callback on other warps
|
||||
const int nw = MIN(num_warps_this_core, NW);
|
||||
vx_wspawn(nw, spawn_tasks_cluster_all_cb);
|
||||
|
||||
// activate all threads
|
||||
vx_tmc(-1);
|
||||
|
||||
// call stub routine
|
||||
spawn_tasks_cluster_all_stub();
|
||||
|
||||
// back to single-threaded
|
||||
vx_tmc_one();
|
||||
|
||||
// wait for spawn warps to terminate
|
||||
vx_wspawn_wait();
|
||||
}
|
||||
|
||||
// TODO: Instead of launching an additional wave just to work on remaining
|
||||
// threads, handle this in the last wave amongst other full warps.
|
||||
if (rem_threads_in_last_warp != 0 && core_id_in_cluster == 0) {
|
||||
// adjust offset
|
||||
// FIXME: use rem_threads_in_last_warp_this_core
|
||||
wspawn_args.offset += (num_tasks_this_cluster - rem_threads_in_last_warp);
|
||||
|
||||
// activate remaining threads
|
||||
const int tmask = (1 << rem_threads_in_last_warp) - 1;
|
||||
vx_tmc(tmask);
|
||||
|
||||
// call stub routine
|
||||
spawn_tasks_cluster_rem_stub();
|
||||
|
||||
// back to single-threaded
|
||||
vx_tmc_one();
|
||||
}
|
||||
}
|
||||
|
||||
void vx_spawn_tasks_contiguous(int num_tasks, vx_spawn_tasks_cb callback , void * arg) {
|
||||
// device specs
|
||||
int NC = vx_num_cores();
|
||||
@@ -179,7 +326,6 @@ void vx_spawn_tasks_contiguous(int num_tasks, vx_spawn_tasks_cb callback , void
|
||||
vx_tmc_one();
|
||||
|
||||
// wait for spawn warps to terminate
|
||||
// deadlock here on warp 0!
|
||||
vx_wspawn_wait();
|
||||
}
|
||||
|
||||
|
||||
@@ -102,6 +102,8 @@ init_regs:
|
||||
#endif
|
||||
csrr t0, VX_CSR_MHARTID
|
||||
sll t1, t0, STACK_LOG2_SIZE
|
||||
sll t2, t0, 4
|
||||
add t1, t1, t2
|
||||
sub sp, sp, t1
|
||||
|
||||
# set thread pointer register
|
||||
|
||||
Reference in New Issue
Block a user