flash: Fix grid size to hw cluster size

Verified fast config, minus the barrier stall at the end.
This commit is contained in:
Hansung Kim
2024-09-09 15:43:31 -07:00
parent 829af5d429
commit d31c8ffd7d
2 changed files with 3 additions and 7 deletions

View File

@@ -676,15 +676,10 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
int main() { int main() {
kernel_arg_t *arg = (kernel_arg_t *)KERNEL_ARG_DEV_MEM_ADDR; kernel_arg_t *arg = (kernel_arg_t *)KERNEL_ARG_DEV_MEM_ADDR;
// FIXME:: use actuall seqlen/headdim
const uint32_t problem_size = (B_ROW * B_COL) / (ELEM_PER_THREAD);
const uint32_t hw_threads_per_cluster = const uint32_t hw_threads_per_cluster =
CORES_PER_CLUSTER * vx_num_threads() * vx_num_warps(); CORES_PER_CLUSTER * vx_num_threads() * vx_num_warps();
// prevent launching more threads than the necessary problem size // fix to 1 threadblock per cluster
// TODO: this does not take into account multiple clusters const uint32_t grid_size = hw_threads_per_cluster;
const uint32_t grid_size = (problem_size > hw_threads_per_cluster)
? hw_threads_per_cluster
: problem_size;
#ifdef RADIANCE #ifdef RADIANCE
vx_spawn_tasks_cluster(grid_size, (vx_spawn_tasks_cb)kernel_body, arg); vx_spawn_tasks_cluster(grid_size, (vx_spawn_tasks_cb)kernel_body, arg);

View File

@@ -538,6 +538,7 @@ wmma_store(const int thread_in_warp, const int warp_col, const int warp_row,
__attribute__((convergent)) inline void __attribute__((convergent)) inline void
threadblock_barrier(const uint32_t barrier_id, const uint32_t count) { threadblock_barrier(const uint32_t barrier_id, const uint32_t count) {
asm volatile("" ::: "memory");
vx_fence(); vx_fence();
vx_barrier(barrier_id, count); vx_barrier(barrier_id, count);
} }