diff --git a/tests/kernel/tensor/flash_attn.py b/tests/kernel/tensor/flash_attn.py index 9c934adb..dfe92c5f 100644 --- a/tests/kernel/tensor/flash_attn.py +++ b/tests/kernel/tensor/flash_attn.py @@ -98,7 +98,7 @@ if __name__ == "__main__": full_S_T = full_S.transpose([1, 0]) full_S.astype('float32').tofile("full_S.bin") - col_to_save = 128 + col_to_save = 0 for col in range(0, seqlen, Bc): print(f"tile iteration {col}~{col + Bc} ======================================") @@ -148,6 +148,8 @@ if __name__ == "__main__": O.astype('float32').tofile("O_before_PV.bin") V = C_array[col:col+Bc, :] + if col == col_to_save: + V.astype('float32').tofile("V_expected.bin") # O = P.transpose([1, 0]) @ V O = O + P @ V if col == col_to_save: