diff --git a/kernel/include/gemmini_mmio.h b/kernel/include/gemmini_mmio.h index ed55236c..0dec66b0 100644 --- a/kernel/include/gemmini_mmio.h +++ b/kernel/include/gemmini_mmio.h @@ -9,9 +9,9 @@ // #define SMEM_SIZE 0x4000 // 64KB // #define SMEM_SIZE 0x10000 -// 128KB +// 128KB (FP16 GEMM) // #define SMEM_SIZE 0x20000 -// 256KB +// 256KB (FlashAttention) #define SMEM_SIZE 0x40000 #define SMEM_MASK (SMEM_SIZE - 1) diff --git a/tests/regression/flash_attention/flash_impl.hpp b/tests/regression/flash_attention/flash_impl.hpp index 47e21c70..2e2bd693 100644 --- a/tests/regression/flash_attention/flash_impl.hpp +++ b/tests/regression/flash_attention/flash_impl.hpp @@ -11,8 +11,10 @@ #define ROW_REMAINDER_LOGIC constexpr uint32_t ROWMAX_SETS = 3; -constexpr bool WARP_SPECIALIZED = true; -constexpr bool TENSOR_CORE = true; +// constexpr bool WARP_SPECIALIZED = true; +// constexpr bool TENSOR_CORE = true; +constexpr bool WARP_SPECIALIZED = false; +constexpr bool TENSOR_CORE = false; // temporary safety stop for wrong configs static_assert(NUM_CORES == 4); diff --git a/tests/regression/sgemm_tcore/sgemm_impl.hpp b/tests/regression/sgemm_tcore/sgemm_impl.hpp index aaf66492..989c5df9 100644 --- a/tests/regression/sgemm_tcore/sgemm_impl.hpp +++ b/tests/regression/sgemm_tcore/sgemm_impl.hpp @@ -6,7 +6,7 @@ #include "include/gemmini.h" #include "gemmini_mmio.h" -#define FP_SIZE 16 +#define FP_SIZE 32 // "fake" fp16 type that only has the correct data width. using float16_t = uint16_t; @@ -19,7 +19,7 @@ using float_type = float16_t; // Generate kernel for the Hopper-style SMEM-decoupled tensor core. This uses // asynchronous HGMMA and HGMMA_WAIT instructions. -#define TENSOR_HOPPER 1 +#define TENSOR_HOPPER 0 // Constraints on parameters: // * Memory: @@ -110,7 +110,7 @@ static_assert(WMITER * WNITER * TCM * TCN * NUM_WARPS * CORES_PER_CLUSTER == // result matrix will be stored in a swizzled form in the global memory. #define WMMA_STORE_FAST 1 -#define GEMMINI_DMA 0 +#define GEMMINI_DMA 1 #define GEMMINI_DMA_FAST 1 #if SMEM_SIZE == 0x4000 #define SMEM_ADDR_Q0 ((float * const) 0xff000000)