sgemm_impl: Parameterize BK/TCK by FP_SIZE
This commit is contained in:
@@ -6,6 +6,18 @@
|
||||
#include "include/gemmini.h"
|
||||
#include "gemmini_mmio.h"
|
||||
|
||||
#define FP_SIZE 32
|
||||
|
||||
// "fake" fp16 type that only has the correct data width.
|
||||
using float16_t = uint16_t;
|
||||
|
||||
#if (FP_SIZE == 32)
|
||||
using float_type = float;
|
||||
#elif (FP_SIZE == 16)
|
||||
|
||||
using float_type = float16_t;
|
||||
#endif
|
||||
|
||||
// Constraints on parameters:
|
||||
// * Memory:
|
||||
// (BM + BN) * BK * sizeof(T) <= sharedmem size.
|
||||
@@ -20,12 +32,24 @@
|
||||
// BM <= BK*TM*TN
|
||||
#define BM 64
|
||||
#define BN 64
|
||||
#if (FP_SIZE == 32)
|
||||
#define BK 64
|
||||
#elif (FP_SIZE == 16)
|
||||
#define BK 128
|
||||
#else
|
||||
#error "unsupported FP_SIZE"
|
||||
#endif
|
||||
#define WM 16
|
||||
#define WN 8
|
||||
#define TCM 8
|
||||
#define TCN 8
|
||||
#if (FP_SIZE == 32)
|
||||
#define TCK 8
|
||||
#elif (FP_SIZE == 16)
|
||||
#define TCK 16
|
||||
#else
|
||||
#error "unsupported FP_SIZE"
|
||||
#endif
|
||||
#define WMITER (WM / TCM)
|
||||
#define WNITER (WN / TCN)
|
||||
#define ELEM_PER_THREAD (WM * WN / NUM_THREADS)
|
||||
@@ -82,9 +106,6 @@
|
||||
#error Unsupported smem size
|
||||
#endif
|
||||
|
||||
// "fake" fp16 type that only has the correct data width.
|
||||
using float16_t = uint16_t;
|
||||
|
||||
inline constexpr void map_operand_32lanes(const int tid, int &row, int &col) {
|
||||
const int tg = tid / 4;
|
||||
|
||||
@@ -416,7 +437,7 @@ inline void global_dmem_load(const uint32_t dim_m, const uint32_t dim_n, const u
|
||||
const uint32_t local_b_col = tid_in_threadblock % BN_adjusted;
|
||||
|
||||
// FIXME: does this need to be adjusted for fp16?
|
||||
constexpr uint32_t threads_in_threadblock = (BM * BN) / ELEM_PER_THREAD;
|
||||
constexpr uint32_t threads_per_threadblock = (BM * BN) / ELEM_PER_THREAD;
|
||||
|
||||
// Data move from GMEM to SMEM
|
||||
//
|
||||
@@ -433,7 +454,7 @@ inline void global_dmem_load(const uint32_t dim_m, const uint32_t dim_n, const u
|
||||
const uint32_t global_a_row = k_adjusted + local_as_row;
|
||||
const uint32_t global_a_col = BM * block_m + local_as_col;
|
||||
// number of rows a full threadblock (TB) can read at a time
|
||||
constexpr uint32_t row_stride_as = threads_in_threadblock / BM;
|
||||
constexpr uint32_t row_stride_as = threads_per_threadblock / BM;
|
||||
const float *global_a = reinterpret_cast<const float *>(A) +
|
||||
dim_m * global_a_row + global_a_col;
|
||||
volatile float *local_a_tmp = reinterpret_cast<volatile float *>(local_a) +
|
||||
@@ -452,7 +473,7 @@ inline void global_dmem_load(const uint32_t dim_m, const uint32_t dim_n, const u
|
||||
if constexpr (!GMEM_COALESCED_A) {
|
||||
// !GMEM_COALESCED_A: threads do uncoalesced reads from neighboring rows in
|
||||
// GMEM, writes to neighboring cols in SMEM
|
||||
constexpr uint32_t row_stride_as = threads_in_threadblock / BM;
|
||||
constexpr uint32_t row_stride_as = threads_per_threadblock / BM;
|
||||
const uint32_t global_a_row = BM * threadblock_id_y + local_as_col;
|
||||
const float *global_a =
|
||||
reinterpret_cast<float *>(A) + dim_k_adjusted * global_a_row + (k_adjusted + local_as_row);
|
||||
@@ -508,7 +529,7 @@ inline void global_dmem_load(const uint32_t dim_m, const uint32_t dim_n, const u
|
||||
local_a_tmp += BM * row_stride_as * 8;
|
||||
}
|
||||
} else {
|
||||
constexpr uint32_t row_stride_a = threads_in_threadblock / BK_adjusted;
|
||||
constexpr uint32_t row_stride_a = threads_per_threadblock / BK_adjusted;
|
||||
const uint32_t global_a_row = BM * threadblock_id_y + local_a_row;
|
||||
const float *global_a = reinterpret_cast<const float *>(A) +
|
||||
dim_k_adjusted * global_a_row +
|
||||
@@ -566,7 +587,7 @@ inline void global_dmem_load(const uint32_t dim_m, const uint32_t dim_n, const u
|
||||
} // end move A
|
||||
|
||||
// move B
|
||||
constexpr uint32_t row_stride_b = threads_in_threadblock / BN_adjusted;
|
||||
constexpr uint32_t row_stride_b = threads_per_threadblock / BN_adjusted;
|
||||
const uint32_t global_b_col = BN_adjusted * threadblock_id_x + local_b_col;
|
||||
// NOTE: not k_adjusted here; k runs along the row dimension, which is not
|
||||
// compressed for fp16
|
||||
@@ -648,7 +669,8 @@ inline void thread_block_gemm(const T *A, const T *B, float *C,
|
||||
const uint32_t warp_row = warp_id_in_warpgroup / (BN / WN);
|
||||
const uint32_t warp_col = warp_id_in_warpgroup % (BN / WN);
|
||||
const uint32_t tid_in_warp = tid_in_threadblock % NUM_THREADS;
|
||||
const uint32_t warps_per_threadblock_per_core = NUM_WARPS / threads_per_threadblock;
|
||||
const uint32_t warps_per_threadblock_per_core =
|
||||
NUM_WARPS / threads_per_threadblock;
|
||||
|
||||
volatile T *local_a = reinterpret_cast<T *>(sharedmem_per_threadblock);
|
||||
constexpr size_t local_a_elems = (BM * BK);
|
||||
|
||||
Reference in New Issue
Block a user