Merge branch 'ae' into ae-volta
This commit is contained in:
@@ -26,7 +26,8 @@ ln -sf input.b.rand.fp32.seqlen1024headdim64.row.bin input.b.bin
|
|||||||
ln -sf input.c.rand.fp32.seqlen1024headdim64.row.bin input.c.bin
|
ln -sf input.c.rand.fp32.seqlen1024headdim64.row.bin input.c.bin
|
||||||
|
|
||||||
for arch in "${archs[@]}"; do
|
for arch in "${archs[@]}"; do
|
||||||
git checkout ae-$arch
|
git checkout ae-flash-$arch
|
||||||
|
git pull
|
||||||
|
|
||||||
# re-compile libvortexrt.a
|
# re-compile libvortexrt.a
|
||||||
# FIXME after restructure
|
# FIXME after restructure
|
||||||
@@ -34,13 +35,11 @@ for arch in "${archs[@]}"; do
|
|||||||
make
|
make
|
||||||
popd
|
popd
|
||||||
|
|
||||||
for dim in "${dims[@]}"; do
|
echo "compiling flash_attn kernel for $arch with seqlen 1024, headdim 64"
|
||||||
echo "compiling flash_attn kernel for $arch with seqlen 1024, headdim 64"
|
|
||||||
|
|
||||||
# touch source file to force re-building, as the Makefile does not track
|
# touch source file to force re-building, as the Makefile does not track
|
||||||
# binary changes
|
# binary changes
|
||||||
touch kernel.cpp
|
touch kernel.cpp
|
||||||
|
|
||||||
make CONFIG=flash.$arch.seqlen1024.headdim64
|
make CONFIG=flash.$arch.seqlen1024.headdim64
|
||||||
done
|
|
||||||
done
|
done
|
||||||
|
|||||||
@@ -1,5 +1,9 @@
|
|||||||
#!/bin/sh
|
#!/bin/sh
|
||||||
|
|
||||||
|
# hopper and virgo has the same SIMT configurations
|
||||||
|
git checkout ae-hopper
|
||||||
|
git pull
|
||||||
|
|
||||||
if [ ! -f input.a.rand01.fp16.m256n256k256.row.bin ]; then
|
if [ ! -f input.a.rand01.fp16.m256n256k256.row.bin ]; then
|
||||||
echo "input binaries not found, generating operands"
|
echo "input binaries not found, generating operands"
|
||||||
python3 generate_operands.py
|
python3 generate_operands.py
|
||||||
|
|||||||
@@ -53,7 +53,7 @@
|
|||||||
#define PRINTF(...) sprintf(PRINT_BUF, __VA_ARGS__)
|
#define PRINTF(...) sprintf(PRINT_BUF, __VA_ARGS__)
|
||||||
// #define PRINTF(...) vx_printf(__VA_ARGS__)
|
// #define PRINTF(...) vx_printf(__VA_ARGS__)
|
||||||
#define SWISH(beta, x) ((x) / (1 + exp(-(beta) * (x))))
|
#define SWISH(beta, x) ((x) / (1 + exp(-(beta) * (x))))
|
||||||
// #define POWER
|
#define POWER
|
||||||
|
|
||||||
typedef uint16_t smem_elem_t;
|
typedef uint16_t smem_elem_t;
|
||||||
// typedef float smem_elem_t;
|
// typedef float smem_elem_t;
|
||||||
|
|||||||
@@ -53,6 +53,7 @@ done
|
|||||||
|
|
||||||
for arch in "${archs[@]}"; do
|
for arch in "${archs[@]}"; do
|
||||||
git checkout ae-$arch
|
git checkout ae-$arch
|
||||||
|
git pull
|
||||||
|
|
||||||
# re-compile libvortexrt.a
|
# re-compile libvortexrt.a
|
||||||
# FIXME after restructure
|
# FIXME after restructure
|
||||||
|
|||||||
Reference in New Issue
Block a user