generate_matrix.py: switch to fp16 rand, generate row-major A
This commit is contained in:
@@ -46,7 +46,7 @@ def pack_fp16_by_row(array):
|
||||
if __name__ == "__main__":
|
||||
M, N, K = parse_mnk()
|
||||
|
||||
rand = False
|
||||
rand = True
|
||||
if not rand:
|
||||
A_array = np.arange(M * K).reshape([M, K])
|
||||
B_array = np.arange(K * N).reshape([K, N])
|
||||
@@ -77,12 +77,16 @@ if __name__ == "__main__":
|
||||
|
||||
np.savez("abc", A_array=A_array, B_array=B_array, C_array=C_array)
|
||||
|
||||
fp16 = False
|
||||
fp16 = True
|
||||
if fp16:
|
||||
A_packed = pack_fp16_by_row(A_array)
|
||||
A_swizzled = A_packed.reshape([-1, M * 2])
|
||||
A_swizzled.astype('float16').tofile("input.a.row.bin")
|
||||
AT_packed = A_packed.transpose([1, 0, 2])
|
||||
AT_swizzled = AT_packed.reshape([-1, M * 2])
|
||||
AT_swizzled.astype('float16').tofile("input.a.col.bin")
|
||||
print('A:')
|
||||
print(A_swizzled)
|
||||
print('AT:')
|
||||
print(AT_swizzled)
|
||||
B_packed = pack_fp16_by_column(B_array)
|
||||
|
||||
Reference in New Issue
Block a user