diff --git a/tests/regression/flash_attention/kernel.cpp b/tests/regression/flash_attention/kernel.cpp index 2b1fea33..5e6f5b9b 100644 --- a/tests/regression/flash_attention/kernel.cpp +++ b/tests/regression/flash_attention/kernel.cpp @@ -545,7 +545,7 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { // NOTE: out-of bounds is not checked // TODO: reduce this from B_ROW to NUM_WARPS constexpr uint32_t smem_scratchpad_size = - B_ROW * NUM_THREADS * 2 /*arbitrary slack*/; + threads_per_warpgroup * 2 /*arbitrary slack*/; float *smem_scratchpad = smem_O_row_scale_1 - smem_scratchpad_size; // initialize rowmax/rowsum values in sharedmem