Merge branch 'ae' into ae-volta

This commit is contained in:
Hansung Kim
2025-01-30 01:48:05 -08:00
4 changed files with 13 additions and 9 deletions

View File

@@ -26,7 +26,8 @@ ln -sf input.b.rand.fp32.seqlen1024headdim64.row.bin input.b.bin
ln -sf input.c.rand.fp32.seqlen1024headdim64.row.bin input.c.bin ln -sf input.c.rand.fp32.seqlen1024headdim64.row.bin input.c.bin
for arch in "${archs[@]}"; do for arch in "${archs[@]}"; do
git checkout ae-$arch git checkout ae-flash-$arch
git pull
# re-compile libvortexrt.a # re-compile libvortexrt.a
# FIXME after restructure # FIXME after restructure
@@ -34,13 +35,11 @@ for arch in "${archs[@]}"; do
make make
popd popd
for dim in "${dims[@]}"; do echo "compiling flash_attn kernel for $arch with seqlen 1024, headdim 64"
echo "compiling flash_attn kernel for $arch with seqlen 1024, headdim 64"
# touch source file to force re-building, as the Makefile does not track # touch source file to force re-building, as the Makefile does not track
# binary changes # binary changes
touch kernel.cpp touch kernel.cpp
make CONFIG=flash.$arch.seqlen1024.headdim64 make CONFIG=flash.$arch.seqlen1024.headdim64
done
done done

View File

@@ -1,5 +1,9 @@
#!/bin/sh #!/bin/sh
# hopper and virgo has the same SIMT configurations
git checkout ae-hopper
git pull
if [ ! -f input.a.rand01.fp16.m256n256k256.row.bin ]; then if [ ! -f input.a.rand01.fp16.m256n256k256.row.bin ]; then
echo "input binaries not found, generating operands" echo "input binaries not found, generating operands"
python3 generate_operands.py python3 generate_operands.py

View File

@@ -53,7 +53,7 @@
#define PRINTF(...) sprintf(PRINT_BUF, __VA_ARGS__) #define PRINTF(...) sprintf(PRINT_BUF, __VA_ARGS__)
// #define PRINTF(...) vx_printf(__VA_ARGS__) // #define PRINTF(...) vx_printf(__VA_ARGS__)
#define SWISH(beta, x) ((x) / (1 + exp(-(beta) * (x)))) #define SWISH(beta, x) ((x) / (1 + exp(-(beta) * (x))))
// #define POWER #define POWER
typedef uint16_t smem_elem_t; typedef uint16_t smem_elem_t;
// typedef float smem_elem_t; // typedef float smem_elem_t;

View File

@@ -53,6 +53,7 @@ done
for arch in "${archs[@]}"; do for arch in "${archs[@]}"; do
git checkout ae-$arch git checkout ae-$arch
git pull
# re-compile libvortexrt.a # re-compile libvortexrt.a
# FIXME after restructure # FIXME after restructure