diff --git a/tests/regression/flash_attention/flash_impl.hpp b/tests/regression/flash_attention/flash_impl.hpp index 48a0068f..fd027553 100644 --- a/tests/regression/flash_attention/flash_impl.hpp +++ b/tests/regression/flash_attention/flash_impl.hpp @@ -9,12 +9,10 @@ #define HEADDIM 64 constexpr uint32_t ROWMAX_SETS = 3; -constexpr bool DEBUG = true; constexpr bool WARP_SPECIALIZED = false; constexpr uint32_t DEV_FAKE_SMEM_START_ADDR = 0xf0000000; -constexpr bool GEMMINI_DMA_FAST = false; constexpr bool Q_IS_K_MAJOR = true; // temporary safety stop for wrong configs diff --git a/tests/regression/flash_attention/kernel.gemmini.cpp b/tests/regression/flash_attention/kernel.gemmini.cpp index 4572c921..4a2c3133 100644 --- a/tests/regression/flash_attention/kernel.gemmini.cpp +++ b/tests/regression/flash_attention/kernel.gemmini.cpp @@ -8,6 +8,8 @@ #include "gemmini_mmio.h" #include "flash_impl.hpp" +constexpr bool DEBUG = true; + static_assert(GEMMINI_DMA && !WARP_SPECIALIZED, "GEMMINI_DMA should be set and WARP_SPECIALIZED unset"); @@ -227,7 +229,7 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { ROCC_INSTRUCTION_RS1_RS2(XCUSTOM_ACC, (uint64_t)(gmem_Q_tile), (uint64_t)(gmem_K_tile), k_LOOP_WS_CONFIG_ADDRS_AB) // configure address strides for the DMA - GEMMINI_CISC_CMD_R((dim_seqlen << 16) | (HEADDIM << 8) | + GEMMINI_CISC_CMD_R((dim_seqlen << 20) | (HEADDIM << 8) | 8 /*k_LOOP_WS_CONFIG_STRIDES_AB*/); gemmini_fence(); @@ -313,7 +315,6 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { GEMMINI_CISC_CMD_I(1); } - // mvout to SMEM gemmini_fence(); gemmini_fence(); gemmini_fence(); @@ -332,6 +333,7 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { } #endif + // mvout to SMEM // GEMMINI_CISC_CMD_I(9); sp_tiled_matmul_full_spad_ws( /*spad_A=*/spad_addr_Q, /*spad_B=*/spad_addr_K, @@ -427,7 +429,7 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { k_LOOP_WS_CONFIG_ADDRS_AB) // configure address strides for the DMA // FIXME: unnecessary? - GEMMINI_CISC_CMD_R((HEADDIM /*V*/ << 16) | (dim_seqlen /*KT*/ << 8) | + GEMMINI_CISC_CMD_R((HEADDIM /*V*/ << 20) | (dim_seqlen /*KT*/ << 8) | 8 /*k_LOOP_WS_CONFIG_STRIDES_AB*/); gemmini_fence();