From 471f89e37171003618ada99af6ae922d6e42fdb5 Mon Sep 17 00:00:00 2001 From: Hansung Kim Date: Thu, 30 Jan 2025 01:02:12 -0800 Subject: [PATCH 1/2] Add arg binary for flash --- kernels/flash_attention/args.bin | 1 + kernels/flash_attention/args.seq1024.headdim64.bin | Bin 0 -> 40 bytes kernels/flash_attention/args.seq128.headdim64.bin | Bin 0 -> 40 bytes kernels/flash_attention/args.seq192.headdim64.bin | Bin 0 -> 40 bytes kernels/flash_attention/args.seq64.headdim64.bin | Bin 0 -> 40 bytes 5 files changed, 1 insertion(+) create mode 120000 kernels/flash_attention/args.bin create mode 100644 kernels/flash_attention/args.seq1024.headdim64.bin create mode 100644 kernels/flash_attention/args.seq128.headdim64.bin create mode 100644 kernels/flash_attention/args.seq192.headdim64.bin create mode 100644 kernels/flash_attention/args.seq64.headdim64.bin diff --git a/kernels/flash_attention/args.bin b/kernels/flash_attention/args.bin new file mode 120000 index 00000000..11ccab60 --- /dev/null +++ b/kernels/flash_attention/args.bin @@ -0,0 +1 @@ +args.seq1024.headdim64.bin \ No newline at end of file diff --git a/kernels/flash_attention/args.seq1024.headdim64.bin b/kernels/flash_attention/args.seq1024.headdim64.bin new file mode 100644 index 0000000000000000000000000000000000000000..c7088ecae30f209e2d911461e57d3d7a20e06e07 GIT binary patch literal 40 acmZQzVPJ4z0D}b(ieVv?UIe8NfM@_U69VV} literal 0 HcmV?d00001 diff --git a/kernels/flash_attention/args.seq128.headdim64.bin b/kernels/flash_attention/args.seq128.headdim64.bin new file mode 100644 index 0000000000000000000000000000000000000000..a6117d7293ad6b843f257cbc161b802d608d11a1 GIT binary patch literal 40 ccmZo*U|?_nVjx%mCK(tOLg__N`T&Rq07!QOWB>pF literal 0 HcmV?d00001 diff --git a/kernels/flash_attention/args.seq192.headdim64.bin b/kernels/flash_attention/args.seq192.headdim64.bin new file mode 100644 index 0000000000000000000000000000000000000000..531a6b3eb6d782c57a596b6bd12f767a58870a44 GIT binary patch literal 40 ccmX@Wz`)=D#6YkBOfoPmgwl(k^Z^hJ08)1YqyPW_ literal 0 HcmV?d00001 diff --git a/kernels/flash_attention/args.seq64.headdim64.bin b/kernels/flash_attention/args.seq64.headdim64.bin new file mode 100644 index 0000000000000000000000000000000000000000..1171fc6ebdf199490f3a90d1ae35788c766633aa GIT binary patch literal 40 ccmZ=@U|?_nVjx%mCK(tOLg__N`T&Rq06upEBme*a literal 0 HcmV?d00001 From b73147cd064751d492430dc221014c8b7234b549 Mon Sep 17 00:00:00 2001 From: Hansung Kim Date: Thu, 30 Jan 2025 01:04:20 -0800 Subject: [PATCH 2/2] Add compile and operand generate script for flash --- kernels/flash_attention/compile_flash.sh | 46 +++++++ kernels/flash_attention/flash_attn.py | 159 +++++++++++++++++++++++ 2 files changed, 205 insertions(+) create mode 100755 kernels/flash_attention/compile_flash.sh create mode 100644 kernels/flash_attention/flash_attn.py diff --git a/kernels/flash_attention/compile_flash.sh b/kernels/flash_attention/compile_flash.sh new file mode 100755 index 00000000..95dcb233 --- /dev/null +++ b/kernels/flash_attention/compile_flash.sh @@ -0,0 +1,46 @@ +#!/bin/bash + +archs=("ampere" "virgo") + +if [ -z "$TOOLDIR" ]; then + echo "error: \$TOOLDIR not set. Did you run source ci/toolchain_env.sh?" + exit 1 +fi + +check_exists() { + if ! [ -f "$1" ]; then + echo "error: looked for file $1 that does not exist." + exit 1 + fi +} + +# generate operands +echo "generating flash_attn operands for seqlen 1024, headdim 64" +python3 flash_attn.py 1024 64 64 +mv -v input.a.col.bin input.a.rand.fp32.seqlen1024headdim64.col.bin +mv -v input.a.row.bin input.a.rand.fp32.seqlen1024headdim64.row.bin +mv -v input.b.bin input.b.rand.fp32.seqlen1024headdim64.row.bin +mv -v input.c.bin input.c.rand.fp32.seqlen1024headdim64.row.bin +ln -sf input.a.rand.fp32.seqlen1024headdim64.row.bin input.a.bin +ln -sf input.b.rand.fp32.seqlen1024headdim64.row.bin input.b.bin +ln -sf input.c.rand.fp32.seqlen1024headdim64.row.bin input.c.bin + +for arch in "${archs[@]}"; do + git checkout ae-$arch + + # re-compile libvortexrt.a + # FIXME after restructure + pushd ../../lib + make + popd + + for dim in "${dims[@]}"; do + echo "compiling flash_attn kernel for $arch with seqlen 1024, headdim 64" + + # touch source file to force re-building, as the Makefile does not track + # binary changes + touch kernel.cpp + + make CONFIG=flash.$arch.seqlen1024.headdim64 + done +done diff --git a/kernels/flash_attention/flash_attn.py b/kernels/flash_attention/flash_attn.py new file mode 100644 index 00000000..5d8dca15 --- /dev/null +++ b/kernels/flash_attention/flash_attn.py @@ -0,0 +1,159 @@ +import sys +import numpy as np + +def parse_mnk(): + if len(sys.argv) != 4: + print(f"usage: {sys.argv[0]} dimM dimN dimK", file=sys.stderr) + sys.exit(1) + m = int(sys.argv[1]) + n = int(sys.argv[2]) + k = int(sys.argv[3]) + return (m, n, k) + + +# Reorder array in a way that groups two adjacent elements along the column to +# be now adjacent along the row. This way, when the resulting fp16 array is +# read in column-major order with 32-bit granularity, the fp16 elements will be +# read in the same order as regular fp32 elements in column-major. +# +# For example: +# [[1 2] +# [3 4] +# [5 6] +# [7 8]] +# becomes +# [[1 3 2 4] +# [5 7 6 8]] +def pack_fp16_by_column(array): + rows = array.shape[0] + cols = array.shape[1] + + T = array.transpose([1, 0]) + T_packed = T.reshape([cols, -1, 2]) + result = T_packed.transpose([1, 0, 2]) + return result + + +# Do the same as pack_fp16_by_column, but for every two elements along the row. +def pack_fp16_by_row(array): + rows = array.shape[0] + cols = array.shape[1] + + result = array.reshape([rows, -1, 2]) + return result + + +if __name__ == "__main__": + seqlen, _, headdim = parse_mnk() + + rand = True + if not rand: + A_array = np.arange(seqlen * headdim).reshape([seqlen, headdim]) + B_array = np.arange(headdim * seqlen).reshape([headdim, seqlen]) + C_array = np.arange(seqlen * seqlen).reshape([seqlen, headdim]) + else: + np.random.seed(0) + A_array = np.random.rand(seqlen, headdim) - 0.5 + B_array = np.random.rand(headdim, seqlen) - 0.5 + C_array = np.random.rand(seqlen, headdim) - 0.5 + # C_array = np.zeros([M, N]) + + fp16 = False + if fp16: + A_packed = pack_fp16_by_row(A_array) + AT_packed = A_packed.transpose([1, 0, 2]) + AT_array = AT_packed.reshape([-1, seqlen * 2]) + AT_array.astype('float16').tofile("input.a.col.bin") + # print('AT:') + # print(AT_array) + B_packed = pack_fp16_by_column(B_array) + B_array = B_packed.reshape([-1, headdim * 2]) + B_array.astype('float16').tofile("input.b.row.bin") + # print('B:') + # print(B_array) + else: + A_array.astype('float32').tofile("input.a.row.bin") + AT_array = A_array.transpose([1, 0]) + AT_array.astype('float32').tofile("input.a.col.bin") + B_array.astype('float32').tofile("input.b.bin") + C_array.astype('float32').tofile("input.c.bin") + # print('AT:') + # print(AT_array) + # print('B:') + # print(B_array) + + assert((seqlen % 64) == 0) + + Br = 64 + Bc = Br + + rowmax = np.zeros([Br]) + rowsum = np.zeros([Br]) + O = np.zeros([Br, headdim]) + + def exp2(x): + return (x**2) / 2.0 + x + 1.0 + + full_S = A_array @ B_array + full_S_T = full_S.transpose([1, 0]) + full_S.astype('float32').tofile("full_S.bin") + + col_to_save = 0 + + for col in range(0, seqlen, Bc): + print(f"tile iteration {col}~{col + Bc} ======================================") + + # FIXME: only work with the first 64 rows of Q for now + Q_tile = A_array[0:64, :] + K_tile = B_array[:, col:col+Bc] + + S = Q_tile @ K_tile + if col == col_to_save: + print('S_expected:') + print(S) + S.astype('float32').tofile("S_expected.bin") + + # generate rowmax result in online softmax + rowmax_this = np.max(S, axis=1) + rowmax_prev = rowmax.copy() + rowmax = np.maximum(rowmax, rowmax_this) + if col == col_to_save: + rowmax.astype('float32').tofile("rowmax.bin") + + # subtrace rowmax from each row by broadcasting + # (placeholder for exp) + x = S - rowmax[:, np.newaxis] + P = exp2(x) + # for i in range(3, 4): + # P += (x**i) / np.math.factorial(i) + # P = np.exp(exp) + # print('P error:') + # print(P / np.exp(x)) + if col == col_to_save: + print('P_expected:') + print(P) + P.astype('float32').tofile("P_expected.bin") + P.transpose([1, 0]).astype('float32').tofile("P_expected.col.bin") + + rowsum_this = np.sum(P, axis=1) + x = rowmax_prev - rowmax_this + rowsum = exp2(x) * rowsum + rowsum_this + if col == col_to_save: + rowsum.astype('float32').tofile("rowsum.bin") + + x = rowmax_prev - rowmax + O = O / (exp2(x)[:, np.newaxis]) + if col == col_to_save: + print('O_before_PV:') + print(O) + O.astype('float32').tofile("O_before_PV.bin") + + V = C_array[col:col+Bc, :] + if col == col_to_save: + V.astype('float32').tofile("V_expected.bin") + # O = P.transpose([1, 0]) @ V + O = O + P @ V + if col == col_to_save: + print('O_after_PV:') + print(O) + O.astype('float32').tofile("O_after_PV.bin")