flash.py: Write V to file

This commit is contained in:
Hansung Kim
2024-09-01 18:17:05 -07:00
parent 6cc1b5ca37
commit f7603b18d3

View File

@@ -98,7 +98,7 @@ if __name__ == "__main__":
full_S_T = full_S.transpose([1, 0])
full_S.astype('float32').tofile("full_S.bin")
col_to_save = 128
col_to_save = 0
for col in range(0, seqlen, Bc):
print(f"tile iteration {col}~{col + Bc} ======================================")
@@ -148,6 +148,8 @@ if __name__ == "__main__":
O.astype('float32').tofile("O_before_PV.bin")
V = C_array[col:col+Bc, :]
if col == col_to_save:
V.astype('float32').tofile("V_expected.bin")
# O = P.transpose([1, 0]) @ V
O = O + P @ V
if col == col_to_save: