Fix range for

i'm a python noob
This commit is contained in:
Hansung Kim
2024-08-19 21:19:16 -07:00
parent 09afd43904
commit 2f7fb372f1

View File

@@ -64,23 +64,23 @@ if __name__ == "__main__":
AT_packed = A_packed.transpose([1, 0, 2])
AT_array = AT_packed.reshape([-1, seqlen * 2])
AT_array.astype('float16').tofile("input.a.col.bin")
print('AT:')
print(AT_array)
# print('AT:')
# print(AT_array)
B_packed = pack_fp16_by_column(B_array)
B_array = B_packed.reshape([-1, headdim * 2])
B_array.astype('float16').tofile("input.b.row.bin")
print('B:')
print(B_array)
# print('B:')
# print(B_array)
else:
A_array.astype('float32').tofile("input.a.row.bin")
AT_array = A_array.transpose([1, 0])
AT_array.astype('float32').tofile("input.a.col.bin")
B_array.astype('float32').tofile("input.b.bin")
C_array.astype('float32').tofile("input.c.bin")
print('AT:')
print(AT_array)
print('B:')
print(B_array)
# print('AT:')
# print(AT_array)
# print('B:')
# print(B_array)
assert((seqlen % 64) == 0)
@@ -94,7 +94,7 @@ if __name__ == "__main__":
def exp2(x):
return (x**2) / 2.0 + x + 1.0
for col in range(0, Bc, seqlen):
for col in range(0, seqlen, Bc):
print(f"tile iteration {col}~{col + Bc} ======================================")
# FIXME: only work with the first 64 rows of Q for now