Compare commits
10 Commits
ae-ampere
...
ae-flash-a
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
37da0e2aa8 | ||
|
|
179cbbcc69 | ||
|
|
9a9f8549d8 | ||
|
|
66c09d3db2 | ||
|
|
dcb8549722 | ||
|
|
a368eb2dae | ||
|
|
7f4adfaaa2 | ||
|
|
30001b7677 | ||
|
|
5ab2cc4334 | ||
|
|
bc64474114 |
@@ -1,8 +1,8 @@
|
||||
PROJECT = flash_attention
|
||||
|
||||
# VX_SRCS = kernel.cpp
|
||||
VX_SRCS = kernel.cpp
|
||||
# VX_SRCS = kernel.gemmini.warpspec.cpp
|
||||
VX_SRCS = kernel.gemmini.cpp
|
||||
# VX_SRCS = kernel.gemmini.cpp
|
||||
VX_INCLUDES = flash_impl.hpp ../sgemm_tcore/sgemm_impl.hpp
|
||||
|
||||
OPTS ?= -n16
|
||||
|
||||
@@ -17,9 +17,9 @@ constexpr uint32_t ROWMAX_SETS = 3;
|
||||
// constexpr bool WARP_SPECIALIZED = true;
|
||||
// constexpr bool GEMMINI_WARP_SPECIALIZED = false;
|
||||
// constexpr bool TENSOR_CORE = true;
|
||||
constexpr bool WARP_SPECIALIZED = false;
|
||||
constexpr bool WARP_SPECIALIZED = true;
|
||||
constexpr bool GEMMINI_WARP_SPECIALIZED = false;
|
||||
constexpr bool TENSOR_CORE = false;
|
||||
constexpr bool TENSOR_CORE = true;
|
||||
|
||||
// temporary safety stop for wrong configs
|
||||
static_assert(NUM_CORES == 4);
|
||||
|
||||
@@ -6,7 +6,7 @@
|
||||
#include "include/gemmini.h"
|
||||
#include "gemmini_mmio.h"
|
||||
|
||||
#define FP_SIZE 16
|
||||
#define FP_SIZE 32
|
||||
|
||||
// "fake" fp16 type that only has the correct data width.
|
||||
using float16_t = uint16_t;
|
||||
|
||||
@@ -1 +1 @@
|
||||
gemmini_params.dim16fp16.h
|
||||
gemmini_params.dim8fp32.h
|
||||
@@ -84,7 +84,7 @@
|
||||
#endif
|
||||
|
||||
#ifndef NUM_CORES
|
||||
#define NUM_CORES 8
|
||||
#define NUM_CORES 4
|
||||
#endif
|
||||
|
||||
#ifndef NUM_WARPS
|
||||
|
||||
@@ -12,9 +12,9 @@
|
||||
// 64KB
|
||||
// #define SMEM_SIZE 0x10000
|
||||
// 128KB (FP16 GEMM)
|
||||
#define SMEM_SIZE 0x20000
|
||||
// #define SMEM_SIZE 0x20000
|
||||
// 256KB (FlashAttention)
|
||||
// #define SMEM_SIZE 0x40000
|
||||
#define SMEM_SIZE 0x40000
|
||||
|
||||
#define SMEM_MASK (SMEM_SIZE - 1)
|
||||
#define SMEM_ADDR_END (SMEM_BASE + SMEM_SIZE)
|
||||
|
||||
@@ -18,7 +18,7 @@
|
||||
#include <stdio.h>
|
||||
|
||||
#ifndef CORES_PER_CLUSTER
|
||||
#define CORES_PER_CLUSTER 8
|
||||
#define CORES_PER_CLUSTER 4
|
||||
#endif
|
||||
|
||||
#ifdef __cplusplus
|
||||
|
||||
Reference in New Issue
Block a user