flash: Fix grid size to hw cluster size
Verified fast config, minus the barrier stall at the end.
This commit is contained in:
@@ -676,15 +676,10 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
|
||||
int main() {
|
||||
kernel_arg_t *arg = (kernel_arg_t *)KERNEL_ARG_DEV_MEM_ADDR;
|
||||
|
||||
// FIXME: use actual seqlen/headdim
|
||||
const uint32_t problem_size = (B_ROW * B_COL) / (ELEM_PER_THREAD);
|
||||
const uint32_t hw_threads_per_cluster =
|
||||
CORES_PER_CLUSTER * vx_num_threads() * vx_num_warps();
|
||||
// prevent launching more threads than the necessary problem size
|
||||
// TODO: this does not take into account multiple clusters
|
||||
const uint32_t grid_size = (problem_size > hw_threads_per_cluster)
|
||||
? hw_threads_per_cluster
|
||||
: problem_size;
|
||||
// fix to 1 threadblock per cluster
|
||||
const uint32_t grid_size = hw_threads_per_cluster;
|
||||
|
||||
#ifdef RADIANCE
|
||||
vx_spawn_tasks_cluster(grid_size, (vx_spawn_tasks_cb)kernel_body, arg);
|
||||
|
||||
@@ -538,6 +538,7 @@ wmma_store(const int thread_in_warp, const int warp_col, const int warp_row,
|
||||
|
||||
// Synchronization helper: issues a compiler memory clobber, then calls
// vx_fence(), then calls vx_barrier(barrier_id, count), in that order.
//
// barrier_id: forwarded unchanged to vx_barrier (presumably selects which
//             hardware barrier to use; TODO confirm against vx_intrinsics).
// count:      forwarded unchanged to vx_barrier (presumably the number of
//             participants that must arrive; TODO confirm units).
//
// NOTE(review): `convergent` is the Clang attribute that keeps the optimizer
// from making calls to this function control-dependent on additional
// conditions; presumably required because all participants must reach the
// barrier together. Confirm with the toolchain docs.
__attribute__((convergent)) inline void
|
||||
threadblock_barrier(const uint32_t barrier_id, const uint32_t count) {
|
||||
// Empty asm with a "memory" clobber: emits no instruction, but stops the
// compiler from reordering or caching memory accesses across this point.
asm volatile("" ::: "memory");
|
||||
// NOTE(review): presumably a hardware memory fence so earlier stores are
// visible before the barrier; semantics are defined outside this view.
vx_fence();
|
||||
// Hand both arguments to the barrier primitive as-is.
vx_barrier(barrier_id, count);
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user