diff --git a/kernels/flash_attention/Makefile b/kernels/flash_attention/Makefile index b7818e74..4c8e2f28 100644 --- a/kernels/flash_attention/Makefile +++ b/kernels/flash_attention/Makefile @@ -1,8 +1,8 @@ PROJECT = flash_attention -# VX_SRCS = kernel.cpp +VX_SRCS = kernel.cpp # VX_SRCS = kernel.gemmini.warpspec.cpp -VX_SRCS = kernel.gemmini.cpp +# VX_SRCS = kernel.gemmini.cpp VX_INCLUDES = flash_impl.hpp ../sgemm_tcore/sgemm_impl.hpp OPTS ?= -n16 diff --git a/kernels/flash_attention/flash_impl.hpp b/kernels/flash_attention/flash_impl.hpp index 6f210b75..05578b39 100644 --- a/kernels/flash_attention/flash_impl.hpp +++ b/kernels/flash_attention/flash_impl.hpp @@ -17,9 +17,9 @@ constexpr uint32_t ROWMAX_SETS = 3; // constexpr bool WARP_SPECIALIZED = true; // constexpr bool GEMMINI_WARP_SPECIALIZED = false; // constexpr bool TENSOR_CORE = true; -constexpr bool WARP_SPECIALIZED = false; +constexpr bool WARP_SPECIALIZED = true; constexpr bool GEMMINI_WARP_SPECIALIZED = false; -constexpr bool TENSOR_CORE = false; +constexpr bool TENSOR_CORE = true; // temporary safety stop for wrong configs static_assert(NUM_CORES == 4);