diff --git a/kernels/flash_attention/args.bin b/kernels/flash_attention/args.bin new file mode 120000 index 00000000..11ccab60 --- /dev/null +++ b/kernels/flash_attention/args.bin @@ -0,0 +1 @@ +args.seq1024.headdim64.bin \ No newline at end of file diff --git a/kernels/flash_attention/args.seq1024.headdim64.bin b/kernels/flash_attention/args.seq1024.headdim64.bin new file mode 100644 index 00000000..c7088eca Binary files /dev/null and b/kernels/flash_attention/args.seq1024.headdim64.bin differ diff --git a/kernels/flash_attention/args.seq128.headdim64.bin b/kernels/flash_attention/args.seq128.headdim64.bin new file mode 100644 index 00000000..a6117d72 Binary files /dev/null and b/kernels/flash_attention/args.seq128.headdim64.bin differ diff --git a/kernels/flash_attention/args.seq192.headdim64.bin b/kernels/flash_attention/args.seq192.headdim64.bin new file mode 100644 index 00000000..531a6b3e Binary files /dev/null and b/kernels/flash_attention/args.seq192.headdim64.bin differ diff --git a/kernels/flash_attention/args.seq64.headdim64.bin b/kernels/flash_attention/args.seq64.headdim64.bin new file mode 100644 index 00000000..1171fc6e Binary files /dev/null and b/kernels/flash_attention/args.seq64.headdim64.bin differ diff --git a/kernels/flash_attention/compile_flash.sh b/kernels/flash_attention/compile_flash.sh new file mode 100755 index 00000000..5808508c --- /dev/null +++ b/kernels/flash_attention/compile_flash.sh @@ -0,0 +1,44 @@ +#!/bin/bash + +archs=("ampere" "virgo") + +if [ -z "$TOOLDIR" ]; then + echo "error: \$TOOLDIR not set. Did you run source ci/toolchain_env.sh?" + exit 1 +fi + +check_exists() { + if ! [ -f "$1" ]; then + echo "error: looked for file $1 that does not exist." + exit 1 + fi +} + +# generate operands +echo "generating flash_attn operands for seqlen 1024, headdim 64" +python3 flash_attn.py 1024 64 64 +mv -v input.a.col.bin input.a.rand.fp32.seqlen1024headdim64.col.bin +mv -v input.a.row.bin input.a.rand.fp32.seqlen1024headdim64.row.bin +mv -v input.b.bin input.b.rand.fp32.seqlen1024headdim64.row.bin +mv -v input.c.bin input.c.rand.fp32.seqlen1024headdim64.row.bin +ln -sf input.a.rand.fp32.seqlen1024headdim64.row.bin input.a.bin +ln -sf input.b.rand.fp32.seqlen1024headdim64.row.bin input.b.bin +ln -sf input.c.rand.fp32.seqlen1024headdim64.row.bin input.c.bin + +for arch in "${archs[@]}"; do + git checkout ae-flash-$arch + + # re-compile libvortexrt.a + # FIXME after restructure + pushd ../../lib + make + popd + + echo "compiling flash_attn kernel for $arch with seqlen 1024, headdim 64" + + # touch source file to force re-building, as the Makefile does not track + # binary changes + touch kernel.cpp + + make CONFIG=flash.$arch.seqlen1024.headdim64 +done diff --git a/kernels/flash_attention/flash_attn.py b/kernels/flash_attention/flash_attn.py new file mode 100644 index 00000000..5d8dca15 --- /dev/null +++ b/kernels/flash_attention/flash_attn.py @@ -0,0 +1,159 @@ +import sys +import numpy as np + +def parse_mnk(): + if len(sys.argv) != 4: + print(f"usage: {sys.argv[0]} dimM dimN dimK", file=sys.stderr) + sys.exit(1) + m = int(sys.argv[1]) + n = int(sys.argv[2]) + k = int(sys.argv[3]) + return (m, n, k) + + +# Reorder array in a way that groups two adjacent elements along the column to +# be now adjacent along the row. This way, when the resulting fp16 array is +# read in column-major order with 32-bit granularity, the fp16 elements will be +# read in the same order as regular fp32 elements in column-major. +# +# For example: +# [[1 2] +# [3 4] +# [5 6] +# [7 8]] +# becomes +# [[1 3 2 4] +# [5 7 6 8]] +def pack_fp16_by_column(array): + rows = array.shape[0] + cols = array.shape[1] + + T = array.transpose([1, 0]) + T_packed = T.reshape([cols, -1, 2]) + result = T_packed.transpose([1, 0, 2]) + return result + + +# Do the same as pack_fp16_by_column, but for every two elements along the row. +def pack_fp16_by_row(array): + rows = array.shape[0] + cols = array.shape[1] + + result = array.reshape([rows, -1, 2]) + return result + + +if __name__ == "__main__": + seqlen, _, headdim = parse_mnk() + + rand = True + if not rand: + A_array = np.arange(seqlen * headdim).reshape([seqlen, headdim]) + B_array = np.arange(headdim * seqlen).reshape([headdim, seqlen]) + C_array = np.arange(seqlen * seqlen).reshape([seqlen, headdim]) + else: + np.random.seed(0) + A_array = np.random.rand(seqlen, headdim) - 0.5 + B_array = np.random.rand(headdim, seqlen) - 0.5 + C_array = np.random.rand(seqlen, headdim) - 0.5 + # C_array = np.zeros([M, N]) + + fp16 = False + if fp16: + A_packed = pack_fp16_by_row(A_array) + AT_packed = A_packed.transpose([1, 0, 2]) + AT_array = AT_packed.reshape([-1, seqlen * 2]) + AT_array.astype('float16').tofile("input.a.col.bin") + # print('AT:') + # print(AT_array) + B_packed = pack_fp16_by_column(B_array) + B_array = B_packed.reshape([-1, headdim * 2]) + B_array.astype('float16').tofile("input.b.row.bin") + # print('B:') + # print(B_array) + else: + A_array.astype('float32').tofile("input.a.row.bin") + AT_array = A_array.transpose([1, 0]) + AT_array.astype('float32').tofile("input.a.col.bin") + B_array.astype('float32').tofile("input.b.bin") + C_array.astype('float32').tofile("input.c.bin") + # print('AT:') + # print(AT_array) + # print('B:') + # print(B_array) + + assert((seqlen % 64) == 0) + + Br = 64 + Bc = Br + + rowmax = np.zeros([Br]) + rowsum = np.zeros([Br]) + O = np.zeros([Br, headdim]) + + def exp2(x): + return (x**2) / 2.0 + x + 1.0 + + full_S = A_array @ B_array + full_S_T = full_S.transpose([1, 0]) + full_S.astype('float32').tofile("full_S.bin") + + col_to_save = 0 + + for col in range(0, seqlen, Bc): + print(f"tile iteration {col}~{col + Bc} ======================================") + + # FIXME: only work with the first 64 rows of Q for now + Q_tile = A_array[0:64, :] + K_tile = B_array[:, col:col+Bc] + + S = Q_tile @ K_tile + if col == col_to_save: + print('S_expected:') + print(S) + S.astype('float32').tofile("S_expected.bin") + + # generate rowmax result in online softmax + rowmax_this = np.max(S, axis=1) + rowmax_prev = rowmax.copy() + rowmax = np.maximum(rowmax, rowmax_this) + if col == col_to_save: + rowmax.astype('float32').tofile("rowmax.bin") + + # subtrace rowmax from each row by broadcasting + # (placeholder for exp) + x = S - rowmax[:, np.newaxis] + P = exp2(x) + # for i in range(3, 4): + # P += (x**i) / np.math.factorial(i) + # P = np.exp(exp) + # print('P error:') + # print(P / np.exp(x)) + if col == col_to_save: + print('P_expected:') + print(P) + P.astype('float32').tofile("P_expected.bin") + P.transpose([1, 0]).astype('float32').tofile("P_expected.col.bin") + + rowsum_this = np.sum(P, axis=1) + x = rowmax_prev - rowmax_this + rowsum = exp2(x) * rowsum + rowsum_this + if col == col_to_save: + rowsum.astype('float32').tofile("rowsum.bin") + + x = rowmax_prev - rowmax + O = O / (exp2(x)[:, np.newaxis]) + if col == col_to_save: + print('O_before_PV:') + print(O) + O.astype('float32').tofile("O_before_PV.bin") + + V = C_array[col:col+Bc, :] + if col == col_to_save: + V.astype('float32').tofile("V_expected.bin") + # O = P.transpose([1, 0]) @ V + O = O + P @ V + if col == col_to_save: + print('O_after_PV:') + print(O) + O.astype('float32').tofile("O_after_PV.bin") diff --git a/kernels/sgemm_tcore/compile_tcore.sh b/kernels/sgemm_tcore/compile_tcore.sh index 11bfbaf8..128eb175 100755 --- a/kernels/sgemm_tcore/compile_tcore.sh +++ b/kernels/sgemm_tcore/compile_tcore.sh @@ -41,12 +41,22 @@ check_exists() { fi } +# generate operands +for dim in "${dims[@]}"; do + echo "generating operands for dim $dim" + python3 generate_operands.py $dim $dim $dim + mv -v input.a.col.bin input.a.rand01.fp16.m${dim}n${dim}k${dim}.col.swizzle_fp16.bin + mv -v input.a.row.bin input.a.rand01.fp16.m${dim}n${dim}k${dim}.row.swizzle_fp16.bin + mv -v input.b.row.bin input.b.rand01.fp16.m${dim}n${dim}k${dim}.row.bin + mv -v input.b.row.swizzled.bin input.b.rand01.fp16.m${dim}n${dim}k${dim}.row.swizzle_fp16.bin +done + for arch in "${archs[@]}"; do git checkout ae-$arch # re-compile libvortexrt.a # FIXME after restructure - pushd ../../libs + pushd ../../lib make popd diff --git a/kernels/sgemm_tcore/generate_operands.py b/kernels/sgemm_tcore/generate_operands.py new file mode 100644 index 00000000..bb429232 --- /dev/null +++ b/kernels/sgemm_tcore/generate_operands.py @@ -0,0 +1,116 @@ +import sys +import numpy as np + +def parse_mnk(): + if len(sys.argv) != 4: + print(f"usage: {sys.argv[0]} dimM dimN dimK", file=sys.stderr) + sys.exit(1) + m = int(sys.argv[1]) + n = int(sys.argv[2]) + k = int(sys.argv[3]) + return (m, n, k) + + +# Reorder array in a way that groups two adjacent elements along the column to +# be now adjacent along the row. This way, when the resulting fp16 array is +# read in column-major order with 32-bit granularity, the fp16 elements will be +# read in the same order as regular fp32 elements in column-major. +# +# For example: +# [[1 2] +# [3 4] +# [5 6] +# [7 8]] +# becomes +# [[1 3 2 4] +# [5 7 6 8]] +def pack_fp16_by_column(array): + rows = array.shape[0] + cols = array.shape[1] + + T = array.transpose([1, 0]) + T_packed = T.reshape([cols, -1, 2]) + result = T_packed.transpose([1, 0, 2]) + return result + + +# Do the same as pack_fp16_by_column, but for every two elements along the row. +def pack_fp16_by_row(array): + rows = array.shape[0] + cols = array.shape[1] + + result = array.reshape([rows, -1, 2]) + return result + + +if __name__ == "__main__": + M, N, K = parse_mnk() + + rand = True + if not rand: + A_array = np.arange(M * K).reshape([M, K]) + B_array = np.arange(K * N).reshape([K, N]) + # C_array = np.arange(M * N).reshape([M, N]) + C_array = np.zeros([M, N]) + else: + np.random.seed(0) + A_array = np.random.rand(M, K) + B_array = np.random.rand(K, N) + C_array = np.random.rand(N, K) + # C_array = np.zeros([M, N]) + + with open('a_matrix.h', 'w') as f: + for i in range(A_array.shape[0]): + for j in range(A_array.shape[1]): + f.write(f'{A_array[i,j]:f}f, ') + f.write('\n') + with open('b_matrix.h', 'w') as f: + for i in range(B_array.shape[0]): + for j in range(B_array.shape[1]): + f.write(f'{B_array[i,j]:f}f, ') + f.write('\n') + with open('c_matrix.h', 'w') as f: + for i in range(C_array.shape[0]): + for j in range(C_array.shape[1]): + f.write(f'{C_array[i,j]:f}f, ') + f.write('\n') + + np.savez("abc", A_array=A_array, B_array=B_array, C_array=C_array) + + fp16 = True + if fp16: + A_packed = pack_fp16_by_row(A_array) + A_swizzled = A_packed.reshape([-1, M * 2]) + A_swizzled.astype('float16').tofile("input.a.row.bin") + AT_packed = A_packed.transpose([1, 0, 2]) + AT_swizzled = AT_packed.reshape([-1, M * 2]) + AT_swizzled.astype('float16').tofile("input.a.col.bin") + print('A:') + print(A_swizzled) + print('AT:') + print(AT_swizzled) + B_array.astype('float16').tofile("input.b.row.bin") + # B_packed_row = pack_fp16_by_row(B_array) + # B_packed_row = B_packed_row.reshape([-1, N * 2]) + # B_packed_row.astype('float16').tofile("input.b.row.bin") + B_packed = pack_fp16_by_column(B_array) + B_swizzled = B_packed.reshape([-1, N * 2]) + B_swizzled.astype('float16').tofile("input.b.row.swizzled.bin") + print('B:') + print(B_swizzled) + else: + A_array.astype('float32').tofile("input.a.row.bin") + AT_array = A_array.transpose([1, 0]) + AT_array.astype('float32').tofile("input.a.col.bin") + B_array.astype('float32').tofile("input.b.bin") + C_array.astype('float32').tofile("input.c.bin") + print('AT:') + print(AT_array) + print('B:') + print(B_array) + + D_expected = A_array @ B_array + D_expected.astype('float32').tofile("d_expected.bin") + print('D_expected:') + print(D_expected) +