Fix build target for flash ampere
This commit is contained in:
@@ -1,8 +1,8 @@
|
||||
PROJECT = flash_attention
|
||||
|
||||
# VX_SRCS = kernel.cpp
|
||||
VX_SRCS = kernel.cpp
|
||||
# VX_SRCS = kernel.gemmini.warpspec.cpp
|
||||
VX_SRCS = kernel.gemmini.cpp
|
||||
# VX_SRCS = kernel.gemmini.cpp
|
||||
VX_INCLUDES = flash_impl.hpp ../sgemm_tcore/sgemm_impl.hpp
|
||||
|
||||
OPTS ?= -n16
|
||||
|
||||
@@ -17,9 +17,9 @@ constexpr uint32_t ROWMAX_SETS = 3;
|
||||
// constexpr bool WARP_SPECIALIZED = true;
|
||||
// constexpr bool GEMMINI_WARP_SPECIALIZED = false;
|
||||
// constexpr bool TENSOR_CORE = true;
|
||||
constexpr bool WARP_SPECIALIZED = false;
|
||||
constexpr bool WARP_SPECIALIZED = true;
|
||||
constexpr bool GEMMINI_WARP_SPECIALIZED = false;
|
||||
constexpr bool TENSOR_CORE = false;
|
||||
constexpr bool TENSOR_CORE = true;
|
||||
|
||||
// temporary safety stop for wrong configs
|
||||
static_assert(NUM_CORES == 4);
|
||||
|
||||
Reference in New Issue
Block a user