diff --git a/tests/regression/flash_attention/kernel.cpp b/tests/regression/flash_attention/kernel.cpp index dd6cb4f3..b1086e18 100644 --- a/tests/regression/flash_attention/kernel.cpp +++ b/tests/regression/flash_attention/kernel.cpp @@ -475,8 +475,11 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { initialize_accum_regs<0>(); initialize_accum_regs<1>(); - load_tile_to_smem( - B_COL, 0, 0, gmem_V, smem_V, tid_in_threadblock); + // V dimension is [seqlen, headdim], stored N(headdim)-major + load_tile_to_smem( + HEADDIM, 0 /* 0 because always reads the full N-dimension */, + tile_k * B_COL, gmem_V, smem_V, tid_in_threadblock); threadblock_barrier(threadblock_id_in_cluster, warps_per_threadblock_per_core); diff --git a/tests/regression/flash_attention/main.cpp b/tests/regression/flash_attention/main.cpp index 3747bbb6..b1b8d522 100644 --- a/tests/regression/flash_attention/main.cpp +++ b/tests/regression/flash_attention/main.cpp @@ -99,7 +99,7 @@ int main(int argc, char *argv[]) { std::cout << "open device connection" << std::endl; RT_CHECK(vx_dev_open(&device)); - uint32_t dim_seqlen = 64; + uint32_t dim_seqlen = 128; uint32_t dim_headdim = 64; using float_type = half;