Compare commits
13 Commits
ae-flash-a
...
ae-volta
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
326141b11f | ||
|
|
d893780594 | ||
|
|
9b7c22a7e9 | ||
|
|
b1ebabef26 | ||
|
|
9f524538a4 | ||
|
|
51ebe18ebb | ||
|
|
7b0a95034b | ||
|
|
c240069147 | ||
|
|
d86c33acf3 | ||
|
|
b49e8a293c | ||
|
|
19731b8e2f | ||
|
|
afc69507a3 | ||
|
|
6e279c905f |
@@ -1,8 +1,8 @@
|
|||||||
PROJECT = flash_attention
|
PROJECT = flash_attention
|
||||||
|
|
||||||
VX_SRCS = kernel.cpp
|
# VX_SRCS = kernel.cpp
|
||||||
# VX_SRCS = kernel.gemmini.warpspec.cpp
|
# VX_SRCS = kernel.gemmini.warpspec.cpp
|
||||||
# VX_SRCS = kernel.gemmini.cpp
|
VX_SRCS = kernel.gemmini.cpp
|
||||||
VX_INCLUDES = flash_impl.hpp ../sgemm_tcore/sgemm_impl.hpp
|
VX_INCLUDES = flash_impl.hpp ../sgemm_tcore/sgemm_impl.hpp
|
||||||
|
|
||||||
OPTS ?= -n16
|
OPTS ?= -n16
|
||||||
|
|||||||
@@ -17,9 +17,9 @@ constexpr uint32_t ROWMAX_SETS = 3;
|
|||||||
// constexpr bool WARP_SPECIALIZED = true;
|
// constexpr bool WARP_SPECIALIZED = true;
|
||||||
// constexpr bool GEMMINI_WARP_SPECIALIZED = false;
|
// constexpr bool GEMMINI_WARP_SPECIALIZED = false;
|
||||||
// constexpr bool TENSOR_CORE = true;
|
// constexpr bool TENSOR_CORE = true;
|
||||||
constexpr bool WARP_SPECIALIZED = true;
|
constexpr bool WARP_SPECIALIZED = false;
|
||||||
constexpr bool GEMMINI_WARP_SPECIALIZED = false;
|
constexpr bool GEMMINI_WARP_SPECIALIZED = false;
|
||||||
constexpr bool TENSOR_CORE = true;
|
constexpr bool TENSOR_CORE = false;
|
||||||
|
|
||||||
// temporary safety stop for wrong configs
|
// temporary safety stop for wrong configs
|
||||||
static_assert(NUM_CORES == 4);
|
static_assert(NUM_CORES == 4);
|
||||||
|
|||||||
@@ -6,7 +6,7 @@
|
|||||||
#include "include/gemmini.h"
|
#include "include/gemmini.h"
|
||||||
#include "gemmini_mmio.h"
|
#include "gemmini_mmio.h"
|
||||||
|
|
||||||
#define FP_SIZE 32
|
#define FP_SIZE 16
|
||||||
|
|
||||||
// "fake" fp16 type that only has the correct data width.
|
// "fake" fp16 type that only has the correct data width.
|
||||||
using float16_t = uint16_t;
|
using float16_t = uint16_t;
|
||||||
@@ -110,7 +110,7 @@ static_assert(WMITER * WNITER * TCM * TCN * NUM_WARPS * CORES_PER_CLUSTER ==
|
|||||||
// result matrix will be stored in a swizzled form in the global memory.
|
// result matrix will be stored in a swizzled form in the global memory.
|
||||||
#define WMMA_STORE_FAST 0
|
#define WMMA_STORE_FAST 0
|
||||||
|
|
||||||
#define GEMMINI_DMA 1
|
#define GEMMINI_DMA 0
|
||||||
#define GEMMINI_DMA_FAST 1
|
#define GEMMINI_DMA_FAST 1
|
||||||
#if SMEM_SIZE == 0x4000
|
#if SMEM_SIZE == 0x4000
|
||||||
#define SMEM_ADDR_Q0 ((float * const) 0xff000000)
|
#define SMEM_ADDR_Q0 ((float * const) 0xff000000)
|
||||||
|
|||||||
@@ -1 +1 @@
|
|||||||
gemmini_params.dim8fp32.h
|
gemmini_params.dim16fp16.h
|
||||||
@@ -84,7 +84,7 @@
|
|||||||
#endif
|
#endif
|
||||||
|
|
||||||
#ifndef NUM_CORES
|
#ifndef NUM_CORES
|
||||||
#define NUM_CORES 4
|
#define NUM_CORES 8
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
#ifndef NUM_WARPS
|
#ifndef NUM_WARPS
|
||||||
|
|||||||
@@ -12,9 +12,9 @@
|
|||||||
// 64KB
|
// 64KB
|
||||||
// #define SMEM_SIZE 0x10000
|
// #define SMEM_SIZE 0x10000
|
||||||
// 128KB (FP16 GEMM)
|
// 128KB (FP16 GEMM)
|
||||||
// #define SMEM_SIZE 0x20000
|
#define SMEM_SIZE 0x20000
|
||||||
// 256KB (FlashAttention)
|
// 256KB (FlashAttention)
|
||||||
#define SMEM_SIZE 0x40000
|
// #define SMEM_SIZE 0x40000
|
||||||
|
|
||||||
#define SMEM_MASK (SMEM_SIZE - 1)
|
#define SMEM_MASK (SMEM_SIZE - 1)
|
||||||
#define SMEM_ADDR_END (SMEM_BASE + SMEM_SIZE)
|
#define SMEM_ADDR_END (SMEM_BASE + SMEM_SIZE)
|
||||||
|
|||||||
@@ -18,7 +18,7 @@
|
|||||||
#include <stdio.h>
|
#include <stdio.h>
|
||||||
|
|
||||||
#ifndef CORES_PER_CLUSTER
|
#ifndef CORES_PER_CLUSTER
|
||||||
#define CORES_PER_CLUSTER 4
|
#define CORES_PER_CLUSTER 8
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
#ifdef __cplusplus
|
#ifdef __cplusplus
|
||||||
|
|||||||
Reference in New Issue
Block a user