Fix range for
i'm a python noob
This commit is contained in:
@@ -64,23 +64,23 @@ if __name__ == "__main__":
|
||||
AT_packed = A_packed.transpose([1, 0, 2])
|
||||
AT_array = AT_packed.reshape([-1, seqlen * 2])
|
||||
AT_array.astype('float16').tofile("input.a.col.bin")
|
||||
print('AT:')
|
||||
print(AT_array)
|
||||
# print('AT:')
|
||||
# print(AT_array)
|
||||
B_packed = pack_fp16_by_column(B_array)
|
||||
B_array = B_packed.reshape([-1, headdim * 2])
|
||||
B_array.astype('float16').tofile("input.b.row.bin")
|
||||
print('B:')
|
||||
print(B_array)
|
||||
# print('B:')
|
||||
# print(B_array)
|
||||
else:
|
||||
A_array.astype('float32').tofile("input.a.row.bin")
|
||||
AT_array = A_array.transpose([1, 0])
|
||||
AT_array.astype('float32').tofile("input.a.col.bin")
|
||||
B_array.astype('float32').tofile("input.b.bin")
|
||||
C_array.astype('float32').tofile("input.c.bin")
|
||||
print('AT:')
|
||||
print(AT_array)
|
||||
print('B:')
|
||||
print(B_array)
|
||||
# print('AT:')
|
||||
# print(AT_array)
|
||||
# print('B:')
|
||||
# print(B_array)
|
||||
|
||||
assert((seqlen % 64) == 0)
|
||||
|
||||
@@ -94,7 +94,7 @@ if __name__ == "__main__":
|
||||
def exp2(x):
|
||||
return (x**2) / 2.0 + x + 1.0
|
||||
|
||||
for col in range(0, Bc, seqlen):
|
||||
for col in range(0, seqlen, Bc):
|
||||
print(f"tile iteration {col}~{col + Bc} ======================================")
|
||||
|
||||
# FIXME: only work with the first 64 rows of Q for now
|
||||
|
||||
Reference in New Issue
Block a user