Merge branch 'ae' into ae-volta
This commit is contained in:
@@ -26,7 +26,7 @@ ln -sf input.b.rand.fp32.seqlen1024headdim64.row.bin input.b.bin
|
||||
ln -sf input.c.rand.fp32.seqlen1024headdim64.row.bin input.c.bin
|
||||
|
||||
for arch in "${archs[@]}"; do
|
||||
git checkout ae-$arch
|
||||
git checkout ae-flash-$arch
|
||||
|
||||
# re-compile libvortexrt.a
|
||||
# FIXME after restructure
|
||||
@@ -34,7 +34,6 @@ for arch in "${archs[@]}"; do
|
||||
make
|
||||
popd
|
||||
|
||||
for dim in "${dims[@]}"; do
|
||||
echo "compiling flash_attn kernel for $arch with seqlen 1024, headdim 64"
|
||||
|
||||
# touch source file to force re-building, as the Makefile does not track
|
||||
@@ -43,4 +42,3 @@ for arch in "${archs[@]}"; do
|
||||
|
||||
make CONFIG=flash.$arch.seqlen1024.headdim64
|
||||
done
|
||||
done
|
||||
|
||||
@@ -53,7 +53,7 @@
|
||||
#define PRINTF(...) sprintf(PRINT_BUF, __VA_ARGS__)
|
||||
// #define PRINTF(...) vx_printf(__VA_ARGS__)
|
||||
#define SWISH(beta, x) ((x) / (1 + exp(-(beta) * (x))))
|
||||
// #define POWER
|
||||
#define POWER
|
||||
|
||||
typedef uint16_t smem_elem_t;
|
||||
// typedef float smem_elem_t;
|
||||
|
||||
Reference in New Issue
Block a user