Compare commits
15 Commits
ae-volta
...
ae-flash-v
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
c24585570d | ||
|
|
8071faf7c2 | ||
|
|
a4bd41392c | ||
|
|
692f3dddff | ||
|
|
c75ed0d531 | ||
|
|
96500e0abc | ||
|
|
4f12227327 | ||
|
|
efd2d232fe | ||
|
|
b97df2ce6a | ||
|
|
e4f8f3481c | ||
|
|
c7f713c71e | ||
|
|
b06e345706 | ||
|
|
8a635b5fcb | ||
|
|
f23b2a3fcc | ||
|
|
ac34a8f5f5 |
@@ -95,6 +95,8 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
|
|||||||
constexpr uint32_t quartile = (128 << 10) >> 2; // 128KB / 4
|
constexpr uint32_t quartile = (128 << 10) >> 2; // 128KB / 4
|
||||||
static_assert((quartile * 4) == SMEM_SIZE, "wrong quartile constant");
|
static_assert((quartile * 4) == SMEM_SIZE, "wrong quartile constant");
|
||||||
|
|
||||||
|
MARK_BEG();
|
||||||
|
|
||||||
constexpr uint32_t smem_a_offset = 0;
|
constexpr uint32_t smem_a_offset = 0;
|
||||||
constexpr uint32_t smem_a_dbuf_offset = 1 * quartile;
|
constexpr uint32_t smem_a_dbuf_offset = 1 * quartile;
|
||||||
constexpr uint32_t smem_b_offset =
|
constexpr uint32_t smem_b_offset =
|
||||||
@@ -119,6 +121,8 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
|
|||||||
threadblocks_per_cluster, threadblock_id_in_cluster,
|
threadblocks_per_cluster, threadblock_id_in_cluster,
|
||||||
sharedmem_per_threadblock);
|
sharedmem_per_threadblock);
|
||||||
|
|
||||||
|
MARK_END();
|
||||||
|
|
||||||
float *gmem_tmp_d0 = reinterpret_cast<float *>(0xd0000000UL);
|
float *gmem_tmp_d0 = reinterpret_cast<float *>(0xd0000000UL);
|
||||||
float *gmem_tmp_d1 = reinterpret_cast<float *>(0xd1000000UL);
|
float *gmem_tmp_d1 = reinterpret_cast<float *>(0xd1000000UL);
|
||||||
float *gmem_tmp_d2 = reinterpret_cast<float *>(0xd2000000UL);
|
float *gmem_tmp_d2 = reinterpret_cast<float *>(0xd2000000UL);
|
||||||
|
|||||||
@@ -6,7 +6,7 @@
|
|||||||
#include "include/gemmini.h"
|
#include "include/gemmini.h"
|
||||||
#include "gemmini_mmio.h"
|
#include "gemmini_mmio.h"
|
||||||
|
|
||||||
#define FP_SIZE 16
|
#define FP_SIZE 32
|
||||||
|
|
||||||
// "fake" fp16 type that only has the correct data width.
|
// "fake" fp16 type that only has the correct data width.
|
||||||
using float16_t = uint16_t;
|
using float16_t = uint16_t;
|
||||||
|
|||||||
@@ -1 +1 @@
|
|||||||
gemmini_params.dim16fp16.h
|
gemmini_params.dim8fp32.h
|
||||||
@@ -84,7 +84,7 @@
|
|||||||
#endif
|
#endif
|
||||||
|
|
||||||
#ifndef NUM_CORES
|
#ifndef NUM_CORES
|
||||||
#define NUM_CORES 8
|
#define NUM_CORES 4
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
#ifndef NUM_WARPS
|
#ifndef NUM_WARPS
|
||||||
|
|||||||
@@ -12,9 +12,9 @@
|
|||||||
// 64KB
|
// 64KB
|
||||||
// #define SMEM_SIZE 0x10000
|
// #define SMEM_SIZE 0x10000
|
||||||
// 128KB (FP16 GEMM)
|
// 128KB (FP16 GEMM)
|
||||||
#define SMEM_SIZE 0x20000
|
// #define SMEM_SIZE 0x20000
|
||||||
// 256KB (FlashAttention)
|
// 256KB (FlashAttention)
|
||||||
// #define SMEM_SIZE 0x40000
|
#define SMEM_SIZE 0x40000
|
||||||
|
|
||||||
#define SMEM_MASK (SMEM_SIZE - 1)
|
#define SMEM_MASK (SMEM_SIZE - 1)
|
||||||
#define SMEM_ADDR_END (SMEM_BASE + SMEM_SIZE)
|
#define SMEM_ADDR_END (SMEM_BASE + SMEM_SIZE)
|
||||||
|
|||||||
@@ -18,7 +18,7 @@
|
|||||||
#include <stdio.h>
|
#include <stdio.h>
|
||||||
|
|
||||||
#ifndef CORES_PER_CLUSTER
|
#ifndef CORES_PER_CLUSTER
|
||||||
#define CORES_PER_CLUSTER 8
|
#define CORES_PER_CLUSTER 4
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
#ifdef __cplusplus
|
#ifdef __cplusplus
|
||||||
|
|||||||
Reference in New Issue
Block a user