From 2f7fb372f1f5210336402fd28293b04f8d150a7c Mon Sep 17 00:00:00 2001 From: Hansung Kim Date: Mon, 19 Aug 2024 21:19:16 -0700 Subject: [PATCH] Fix range for i'm a python noob --- tests/kernel/tensor/flash_attn.py | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/tests/kernel/tensor/flash_attn.py b/tests/kernel/tensor/flash_attn.py index 79f9e021..9934599d 100644 --- a/tests/kernel/tensor/flash_attn.py +++ b/tests/kernel/tensor/flash_attn.py @@ -64,23 +64,23 @@ if __name__ == "__main__": AT_packed = A_packed.transpose([1, 0, 2]) AT_array = AT_packed.reshape([-1, seqlen * 2]) AT_array.astype('float16').tofile("input.a.col.bin") - print('AT:') - print(AT_array) + # print('AT:') + # print(AT_array) B_packed = pack_fp16_by_column(B_array) B_array = B_packed.reshape([-1, headdim * 2]) B_array.astype('float16').tofile("input.b.row.bin") - print('B:') - print(B_array) + # print('B:') + # print(B_array) else: A_array.astype('float32').tofile("input.a.row.bin") AT_array = A_array.transpose([1, 0]) AT_array.astype('float32').tofile("input.a.col.bin") B_array.astype('float32').tofile("input.b.bin") C_array.astype('float32').tofile("input.c.bin") - print('AT:') - print(AT_array) - print('B:') - print(B_array) + # print('AT:') + # print(AT_array) + # print('B:') + # print(B_array) assert((seqlen % 64) == 0) @@ -94,7 +94,7 @@ if __name__ == "__main__": def exp2(x): return (x**2) / 2.0 + x + 1.0 - for col in range(0, Bc, seqlen): + for col in range(0, seqlen, Bc): print(f"tile iteration {col}~{col + Bc} ======================================") # FIXME: only work with the first 64 rows of Q for now