Compare commits
15 Commits
ae-flash-a
...
ae-flash-v
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
c24585570d | ||
|
|
8071faf7c2 | ||
|
|
a4bd41392c | ||
|
|
692f3dddff | ||
|
|
c75ed0d531 | ||
|
|
96500e0abc | ||
|
|
4f12227327 | ||
|
|
efd2d232fe | ||
|
|
b97df2ce6a | ||
|
|
e4f8f3481c | ||
|
|
c7f713c71e | ||
|
|
b06e345706 | ||
|
|
8a635b5fcb | ||
|
|
f23b2a3fcc | ||
|
|
ac34a8f5f5 |
@@ -1,8 +1,8 @@
|
||||
PROJECT = flash_attention
|
||||
|
||||
VX_SRCS = kernel.cpp
|
||||
# VX_SRCS = kernel.cpp
|
||||
# VX_SRCS = kernel.gemmini.warpspec.cpp
|
||||
# VX_SRCS = kernel.gemmini.cpp
|
||||
VX_SRCS = kernel.gemmini.cpp
|
||||
VX_INCLUDES = flash_impl.hpp ../sgemm_tcore/sgemm_impl.hpp
|
||||
|
||||
OPTS ?= -n16
|
||||
|
||||
@@ -17,9 +17,9 @@ constexpr uint32_t ROWMAX_SETS = 3;
|
||||
// constexpr bool WARP_SPECIALIZED = true;
|
||||
// constexpr bool GEMMINI_WARP_SPECIALIZED = false;
|
||||
// constexpr bool TENSOR_CORE = true;
|
||||
constexpr bool WARP_SPECIALIZED = true;
|
||||
constexpr bool WARP_SPECIALIZED = false;
|
||||
constexpr bool GEMMINI_WARP_SPECIALIZED = false;
|
||||
constexpr bool TENSOR_CORE = true;
|
||||
constexpr bool TENSOR_CORE = false;
|
||||
|
||||
// temporary safety stop for wrong configs
|
||||
static_assert(NUM_CORES == 4);
|
||||
|
||||
@@ -95,6 +95,8 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
|
||||
constexpr uint32_t quartile = (128 << 10) >> 2; // 128KB / 4
|
||||
static_assert((quartile * 4) == SMEM_SIZE, "wrong quartile constant");
|
||||
|
||||
MARK_BEG();
|
||||
|
||||
constexpr uint32_t smem_a_offset = 0;
|
||||
constexpr uint32_t smem_a_dbuf_offset = 1 * quartile;
|
||||
constexpr uint32_t smem_b_offset =
|
||||
@@ -119,6 +121,8 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
|
||||
threadblocks_per_cluster, threadblock_id_in_cluster,
|
||||
sharedmem_per_threadblock);
|
||||
|
||||
MARK_END();
|
||||
|
||||
float *gmem_tmp_d0 = reinterpret_cast<float *>(0xd0000000UL);
|
||||
float *gmem_tmp_d1 = reinterpret_cast<float *>(0xd1000000UL);
|
||||
float *gmem_tmp_d2 = reinterpret_cast<float *>(0xd2000000UL);
|
||||
|
||||
Reference in New Issue
Block a user