From 6bdc6af60780f6daabe8e0345f0d59f162f943b5 Mon Sep 17 00:00:00 2001 From: Hansung Kim Date: Thu, 30 Jan 2025 01:15:57 -0800 Subject: [PATCH 1/2] Fix branch name and dims for flash script --- kernels/flash_attention/compile_flash.sh | 14 ++++++-------- 1 file changed, 6 insertions(+), 8 deletions(-) diff --git a/kernels/flash_attention/compile_flash.sh b/kernels/flash_attention/compile_flash.sh index 95dcb233..5808508c 100755 --- a/kernels/flash_attention/compile_flash.sh +++ b/kernels/flash_attention/compile_flash.sh @@ -26,7 +26,7 @@ ln -sf input.b.rand.fp32.seqlen1024headdim64.row.bin input.b.bin ln -sf input.c.rand.fp32.seqlen1024headdim64.row.bin input.c.bin for arch in "${archs[@]}"; do - git checkout ae-$arch + git checkout ae-flash-$arch # re-compile libvortexrt.a # FIXME after restructure @@ -34,13 +34,11 @@ for arch in "${archs[@]}"; do make popd - for dim in "${dims[@]}"; do - echo "compiling flash_attn kernel for $arch with seqlen 1024, headdim 64" + echo "compiling flash_attn kernel for $arch with seqlen 1024, headdim 64" - # touch source file to force re-building, as the Makefile does not track - # binary changes - touch kernel.cpp + # touch source file to force re-building, as the Makefile does not track + # binary changes + touch kernel.cpp - make CONFIG=flash.$arch.seqlen1024.headdim64 - done + make CONFIG=flash.$arch.seqlen1024.headdim64 done From dde3602046cbeb9b0ad081ab206bc9af438e767c Mon Sep 17 00:00:00 2001 From: Richard Yan Date: Thu, 30 Jan 2025 01:34:22 -0800 Subject: [PATCH 2/2] disable prints for virgo gemm --- kernels/sgemm_gemmini_dma/kernel.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/kernels/sgemm_gemmini_dma/kernel.cpp b/kernels/sgemm_gemmini_dma/kernel.cpp index 8ae36bd8..5b6c3f71 100644 --- a/kernels/sgemm_gemmini_dma/kernel.cpp +++ b/kernels/sgemm_gemmini_dma/kernel.cpp @@ -53,7 +53,7 @@ #define PRINTF(...) sprintf(PRINT_BUF, __VA_ARGS__) // #define PRINTF(...) vx_printf(__VA_ARGS__) #define SWISH(beta, x) ((x) / (1 + exp(-(beta) * (x)))) -// #define POWER +#define POWER typedef uint16_t smem_elem_t; // typedef float smem_elem_t;