10 Commits

Author SHA1 Message Date
Virgo-AE Eval
37da0e2aa8 Merge branch 'ae' into ae-flash-ampere 2025-02-07 14:53:00 -08:00
Richard Yan
179cbbcc69 Merge branch 'ae' into ae-flash-ampere 2025-01-31 03:53:51 -08:00
Richard Yan
9a9f8549d8 Merge branch 'ae' into ae-flash-ampere 2025-01-30 23:42:26 -08:00
Hansung Kim
66c09d3db2 Merge branch 'ae' into ae-flash-ampere 2025-01-30 13:24:58 -08:00
Hansung Kim
dcb8549722 Merge branch 'ae' into ae-flash-ampere 2025-01-30 01:49:02 -08:00
Hansung Kim
a368eb2dae Fix build target for flash ampere 2025-01-30 01:21:37 -08:00
Hansung Kim
7f4adfaaa2 Increase SMEM size for flash 2025-01-30 01:20:17 -08:00
Hansung Kim
30001b7677 Merge branch 'ae' into ae-flash-ampere 2025-01-30 01:16:21 -08:00
Hansung Kim
5ab2cc4334 Switch to fp32 for flash 2025-01-30 01:12:55 -08:00
Hansung Kim
bc64474114 Update config for flash ampere 2025-01-30 01:06:54 -08:00
3 changed files with 4 additions and 8 deletions

View File

@@ -1,8 +1,8 @@
PROJECT = flash_attention
# VX_SRCS = kernel.cpp
VX_SRCS = kernel.cpp
# VX_SRCS = kernel.gemmini.warpspec.cpp
VX_SRCS = kernel.gemmini.cpp
# VX_SRCS = kernel.gemmini.cpp
VX_INCLUDES = flash_impl.hpp ../sgemm_tcore/sgemm_impl.hpp
OPTS ?= -n16

View File

@@ -17,9 +17,9 @@ constexpr uint32_t ROWMAX_SETS = 3;
// constexpr bool WARP_SPECIALIZED = true;
// constexpr bool GEMMINI_WARP_SPECIALIZED = false;
// constexpr bool TENSOR_CORE = true;
constexpr bool WARP_SPECIALIZED = false;
constexpr bool WARP_SPECIALIZED = true;
constexpr bool GEMMINI_WARP_SPECIALIZED = false;
constexpr bool TENSOR_CORE = false;
constexpr bool TENSOR_CORE = true;
// temporary safety stop for wrong configs
static_assert(NUM_CORES == 4);

View File

@@ -95,8 +95,6 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
constexpr uint32_t quartile = (128 << 10) >> 2; // 128KB / 4
static_assert((quartile * 4) == SMEM_SIZE, "wrong quartile constant");
MARK_BEG();
constexpr uint32_t smem_a_offset = 0;
constexpr uint32_t smem_a_dbuf_offset = 1 * quartile;
constexpr uint32_t smem_b_offset =
@@ -121,8 +119,6 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
threadblocks_per_cluster, threadblock_id_in_cluster,
sharedmem_per_threadblock);
MARK_END();
float *gmem_tmp_d0 = reinterpret_cast<float *>(0xd0000000UL);
float *gmem_tmp_d1 = reinterpret_cast<float *>(0xd1000000UL);
float *gmem_tmp_d2 = reinterpret_cast<float *>(0xd2000000UL);