diff --git a/tests/regression/flash_attention/kernel.cpp b/tests/regression/flash_attention/kernel.cpp index d627e413..92955e2f 100644 --- a/tests/regression/flash_attention/kernel.cpp +++ b/tests/regression/flash_attention/kernel.cpp @@ -105,7 +105,6 @@ inline void thread_block_online_softmax( : "f"(max), "f"(S[first_thread_offset + i])); } smem_rowmax[row] = max; - gmem_tmp0[row] = max; } #else @@ -188,7 +187,7 @@ inline void thread_block_online_softmax( // Store S transposed to the shared memory - smem_S[thread_offset] = f0; + smem_P[thread_offset] = f0; // S[thread_offset + 1] = f1; gmem_tmp1[thread_offset] = f0; @@ -206,7 +205,7 @@ inline void thread_block_online_softmax( float per_thread_sum = 0.0f; #pragma GCC unroll for (int i = 0; i < per_row_iter; i++) { - per_thread_sum += smem_S[thread_offset]; + per_thread_sum += smem_P[thread_offset]; thread_offset += NUM_THREADS; } // stage per-thread sum value in smem @@ -355,6 +354,9 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { float *tile_S = (float *)arg->addr_q; #endif + // FIXME: V is stored in d0000000 for debugging purpose + const float *gmem_V = reinterpret_cast(arg->addr_k); + thread_block_online_softmax( tile_S, smem_O, smem_P, tid_in_threadblock, threads_per_threadblock, threadblock_id_in_cluster, smem_scratchpad, smem_rowmax, smem_rowsum); @@ -365,9 +367,20 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { float *gmem_tmp2 = reinterpret_cast(0xf0000000UL); - thread_block_gemm_single_tile( + load_tile_to_smem( + B_COL, 0 /*FIXME*/, 0 /*FIXME*/, gmem_V, smem_V, tid_in_threadblock); + + threadblock_barrier(threadblock_id_in_cluster, + warps_per_threadblock_per_core); + + // FIXME: support MN_major for A for ideal performance + thread_block_gemm_single_tile( smem_P, smem_V, gmem_tmp2 /*smem_O*/, tid_in_threadblock, threads_per_threadblock); + + threadblock_barrier(threadblock_id_in_cluster, + warps_per_threadblock_per_core); } int main() {