10 Commits

Author SHA1 Message Date
Virgo-AE Eval
37da0e2aa8 Merge branch 'ae' into ae-flash-ampere 2025-02-07 14:53:00 -08:00
Richard Yan
179cbbcc69 Merge branch 'ae' into ae-flash-ampere 2025-01-31 03:53:51 -08:00
Richard Yan
9a9f8549d8 Merge branch 'ae' into ae-flash-ampere 2025-01-30 23:42:26 -08:00
Hansung Kim
66c09d3db2 Merge branch 'ae' into ae-flash-ampere 2025-01-30 13:24:58 -08:00
Hansung Kim
dcb8549722 Merge branch 'ae' into ae-flash-ampere 2025-01-30 01:49:02 -08:00
Hansung Kim
a368eb2dae Fix build target for flash ampere 2025-01-30 01:21:37 -08:00
Hansung Kim
7f4adfaaa2 Increase SMEM size for flash 2025-01-30 01:20:17 -08:00
Hansung Kim
30001b7677 Merge branch 'ae' into ae-flash-ampere 2025-01-30 01:16:21 -08:00
Hansung Kim
5ab2cc4334 Switch to fp32 for flash 2025-01-30 01:12:55 -08:00
Hansung Kim
bc64474114 Update config for flash ampere 2025-01-30 01:06:54 -08:00
7 changed files with 11 additions and 11 deletions

View File

@@ -1,8 +1,8 @@
PROJECT = flash_attention PROJECT = flash_attention
# VX_SRCS = kernel.cpp VX_SRCS = kernel.cpp
# VX_SRCS = kernel.gemmini.warpspec.cpp # VX_SRCS = kernel.gemmini.warpspec.cpp
VX_SRCS = kernel.gemmini.cpp # VX_SRCS = kernel.gemmini.cpp
VX_INCLUDES = flash_impl.hpp ../sgemm_tcore/sgemm_impl.hpp VX_INCLUDES = flash_impl.hpp ../sgemm_tcore/sgemm_impl.hpp
OPTS ?= -n16 OPTS ?= -n16

View File

@@ -17,9 +17,9 @@ constexpr uint32_t ROWMAX_SETS = 3;
// constexpr bool WARP_SPECIALIZED = true; // constexpr bool WARP_SPECIALIZED = true;
// constexpr bool GEMMINI_WARP_SPECIALIZED = false; // constexpr bool GEMMINI_WARP_SPECIALIZED = false;
// constexpr bool TENSOR_CORE = true; // constexpr bool TENSOR_CORE = true;
constexpr bool WARP_SPECIALIZED = false; constexpr bool WARP_SPECIALIZED = true;
constexpr bool GEMMINI_WARP_SPECIALIZED = false; constexpr bool GEMMINI_WARP_SPECIALIZED = false;
constexpr bool TENSOR_CORE = false; constexpr bool TENSOR_CORE = true;
// temporary safety stop for wrong configs // temporary safety stop for wrong configs
static_assert(NUM_CORES == 4); static_assert(NUM_CORES == 4);

View File

@@ -6,7 +6,7 @@
#include "include/gemmini.h" #include "include/gemmini.h"
#include "gemmini_mmio.h" #include "gemmini_mmio.h"
#define FP_SIZE 16 #define FP_SIZE 32
// "fake" fp16 type that only has the correct data width. // "fake" fp16 type that only has the correct data width.
using float16_t = uint16_t; using float16_t = uint16_t;
@@ -110,7 +110,7 @@ static_assert(WMITER * WNITER * TCM * TCN * NUM_WARPS * CORES_PER_CLUSTER ==
// result matrix will be stored in a swizzled form in the global memory. // result matrix will be stored in a swizzled form in the global memory.
#define WMMA_STORE_FAST 0 #define WMMA_STORE_FAST 0
#define GEMMINI_DMA 0 #define GEMMINI_DMA 1
#define GEMMINI_DMA_FAST 1 #define GEMMINI_DMA_FAST 1
#if SMEM_SIZE == 0x4000 #if SMEM_SIZE == 0x4000
#define SMEM_ADDR_Q0 ((float * const) 0xff000000) #define SMEM_ADDR_Q0 ((float * const) 0xff000000)

View File

@@ -1 +1 @@
gemmini_params.dim16fp16.h gemmini_params.dim8fp32.h

View File

@@ -84,7 +84,7 @@
#endif #endif
#ifndef NUM_CORES #ifndef NUM_CORES
#define NUM_CORES 8 #define NUM_CORES 4
#endif #endif
#ifndef NUM_WARPS #ifndef NUM_WARPS

View File

@@ -12,9 +12,9 @@
// 64KB // 64KB
// #define SMEM_SIZE 0x10000 // #define SMEM_SIZE 0x10000
// 128KB (FP16 GEMM) // 128KB (FP16 GEMM)
#define SMEM_SIZE 0x20000 // #define SMEM_SIZE 0x20000
// 256KB (FlashAttention) // 256KB (FlashAttention)
// #define SMEM_SIZE 0x40000 #define SMEM_SIZE 0x40000
#define SMEM_MASK (SMEM_SIZE - 1) #define SMEM_MASK (SMEM_SIZE - 1)
#define SMEM_ADDR_END (SMEM_BASE + SMEM_SIZE) #define SMEM_ADDR_END (SMEM_BASE + SMEM_SIZE)

View File

@@ -18,7 +18,7 @@
#include <stdio.h> #include <stdio.h>
#ifndef CORES_PER_CLUSTER #ifndef CORES_PER_CLUSTER
#define CORES_PER_CLUSTER 8 #define CORES_PER_CLUSTER 4
#endif #endif
#ifdef __cplusplus #ifdef __cplusplus