13 Commits

Author SHA1 Message Date
Virgo-AE Eval
326141b11f Merge branch 'ae' into ae-volta 2025-02-07 14:52:25 -08:00
Richard Yan
d893780594 Merge branch 'ae' into ae-volta 2025-01-31 03:53:26 -08:00
Richard Yan
9b7c22a7e9 Merge branch 'ae' into ae-volta 2025-01-30 23:41:16 -08:00
Richard Yan
b1ebabef26 Merge branch 'ae' into ae-volta 2025-01-30 15:35:22 -08:00
Hansung Kim
9f524538a4 Merge branch 'ae' into ae-volta 2025-01-30 13:24:46 -08:00
Hansung Kim
51ebe18ebb Merge remote-tracking branch 'origin/ae-volta' into ae-volta 2025-01-30 01:48:25 -08:00
Hansung Kim
7b0a95034b Merge branch 'ae' into ae-volta 2025-01-30 01:48:05 -08:00
Richard Yan
c240069147 Merge branch 'ae' into ae-volta 2025-01-30 01:35:35 -08:00
Hansung Kim
d86c33acf3 Merge branch 'ae' into ae-volta 2025-01-30 01:05:27 -08:00
Hansung Kim
b49e8a293c Merge branch 'ae' into ae-volta 2025-01-30 00:49:19 -08:00
Hansung Kim
19731b8e2f Merge branch 'ae' into ae-volta 2025-01-30 00:35:00 -08:00
Richard Yan
afc69507a3 Merge branch 'ae' into ae-volta 2025-01-29 23:31:34 -08:00
Richard Yan
6e279c905f volta change 2025-01-29 22:16:39 -08:00
6 changed files with 7 additions and 11 deletions

View File

@@ -95,8 +95,6 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
constexpr uint32_t quartile = (128 << 10) >> 2; // 128KB / 4
static_assert((quartile * 4) == SMEM_SIZE, "wrong quartile constant");
MARK_BEG();
constexpr uint32_t smem_a_offset = 0;
constexpr uint32_t smem_a_dbuf_offset = 1 * quartile;
constexpr uint32_t smem_b_offset =
@@ -121,8 +119,6 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
threadblocks_per_cluster, threadblock_id_in_cluster,
sharedmem_per_threadblock);
MARK_END();
float *gmem_tmp_d0 = reinterpret_cast<float *>(0xd0000000UL);
float *gmem_tmp_d1 = reinterpret_cast<float *>(0xd1000000UL);
float *gmem_tmp_d2 = reinterpret_cast<float *>(0xd2000000UL);

View File

@@ -6,7 +6,7 @@
#include "include/gemmini.h"
#include "gemmini_mmio.h"
#define FP_SIZE 32
#define FP_SIZE 16
// "fake" fp16 type that only has the correct data width.
using float16_t = uint16_t;
@@ -110,7 +110,7 @@ static_assert(WMITER * WNITER * TCM * TCN * NUM_WARPS * CORES_PER_CLUSTER ==
// result matrix will be stored in a swizzled form in the global memory.
#define WMMA_STORE_FAST 0
#define GEMMINI_DMA 1
#define GEMMINI_DMA 0
#define GEMMINI_DMA_FAST 1
#if SMEM_SIZE == 0x4000
#define SMEM_ADDR_Q0 ((float * const) 0xff000000)

View File

@@ -1 +1 @@
gemmini_params.dim8fp32.h
gemmini_params.dim16fp16.h

View File

@@ -84,7 +84,7 @@
#endif
#ifndef NUM_CORES
#define NUM_CORES 4
#define NUM_CORES 8
#endif
#ifndef NUM_WARPS

View File

@@ -12,9 +12,9 @@
// 64KB
// #define SMEM_SIZE 0x10000
// 128KB (FP16 GEMM)
// #define SMEM_SIZE 0x20000
#define SMEM_SIZE 0x20000
// 256KB (FlashAttention)
#define SMEM_SIZE 0x40000
// #define SMEM_SIZE 0x40000
#define SMEM_MASK (SMEM_SIZE - 1)
#define SMEM_ADDR_END (SMEM_BASE + SMEM_SIZE)

View File

@@ -18,7 +18,7 @@
#include <stdio.h>
#ifndef CORES_PER_CLUSTER
#define CORES_PER_CLUSTER 4
#define CORES_PER_CLUSTER 8
#endif
#ifdef __cplusplus