sgemm_impl: Parameterize BK/TCK by FP_SIZE
This commit is contained in:
@@ -6,6 +6,18 @@
|
|||||||
#include "include/gemmini.h"
|
#include "include/gemmini.h"
|
||||||
#include "gemmini_mmio.h"
|
#include "gemmini_mmio.h"
|
||||||
|
|
||||||
|
// Width, in bits, of the element type selected below.
// Only 32 and 16 are handled; any other value is rejected at compile time.
#define FP_SIZE 32

// "fake" fp16 type that only has the correct data width.
// It is an alias of uint16_t, so it provides 16-bit storage only.
using float16_t = uint16_t;

// Element type chosen by FP_SIZE.
#if (FP_SIZE == 32)
using float_type = float;
#elif (FP_SIZE == 16)
using float_type = float16_t;
#else
// Stop here with a clear message instead of leaving float_type undeclared.
#error "unsupported FP_SIZE"
#endif
|
||||||
|
|
||||||
// Constraints on parameters:
|
// Constraints on parameters:
|
||||||
// * Memory:
|
// * Memory:
|
||||||
// (BM + BN) * BK * sizeof(T) <= sharedmem size.
|
// (BM + BN) * BK * sizeof(T) <= sharedmem size.
|
||||||
@@ -20,12 +32,24 @@
|
|||||||
// BM <= BK*TM*TN
|
// BM <= BK*TM*TN
|
||||||
// Tile extents that do not depend on the element width.
#define BM 64
#define BN 64
#define WM 16
#define WN 8
#define TCM 8
#define TCN 8

// K-extents scale inversely with the element width: halving FP_SIZE from 32
// to 16 doubles both BK (64 -> 128) and TCK (8 -> 16), so the byte length
// along K stays the same for either element size.
#if (FP_SIZE == 32)
#define BK 64
#define TCK 8
#elif (FP_SIZE == 16)
#define BK 128
#define TCK 16
#else
#error "unsupported FP_SIZE"
#endif

// Derived quantities: TCM-sized steps per WM, TCN-sized steps per WN, and
// the share of a WM x WN tile that falls to each of NUM_THREADS threads.
#define WMITER (WM / TCM)
#define WNITER (WN / TCN)
#define ELEM_PER_THREAD (WM * WN / NUM_THREADS)
|
||||||
@@ -82,9 +106,6 @@
|
|||||||
#error Unsupported smem size
|
#error Unsupported smem size
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
// "fake" fp16 type that only has the correct data width.
|
|
||||||
using float16_t = uint16_t;
|
|
||||||
|
|
||||||
inline constexpr void map_operand_32lanes(const int tid, int &row, int &col) {
|
inline constexpr void map_operand_32lanes(const int tid, int &row, int &col) {
|
||||||
const int tg = tid / 4;
|
const int tg = tid / 4;
|
||||||
|
|
||||||
@@ -416,7 +437,7 @@ inline void global_dmem_load(const uint32_t dim_m, const uint32_t dim_n, const u
|
|||||||
const uint32_t local_b_col = tid_in_threadblock % BN_adjusted;
|
const uint32_t local_b_col = tid_in_threadblock % BN_adjusted;
|
||||||
|
|
||||||
// FIXME: need fix for fp16?
|
// FIXME: need fix for fp16?
|
||||||
constexpr uint32_t threads_in_threadblock = (BM * BN) / ELEM_PER_THREAD;
|
constexpr uint32_t threads_per_threadblock = (BM * BN) / ELEM_PER_THREAD;
|
||||||
|
|
||||||
// Data move from GMEM to SMEM
|
// Data move from GMEM to SMEM
|
||||||
//
|
//
|
||||||
@@ -433,7 +454,7 @@ inline void global_dmem_load(const uint32_t dim_m, const uint32_t dim_n, const u
|
|||||||
const uint32_t global_a_row = k_adjusted + local_as_row;
|
const uint32_t global_a_row = k_adjusted + local_as_row;
|
||||||
const uint32_t global_a_col = BM * block_m + local_as_col;
|
const uint32_t global_a_col = BM * block_m + local_as_col;
|
||||||
// number of rows a full TB can read at a time
|
// number of rows a full TB can read at a time
|
||||||
constexpr uint32_t row_stride_as = threads_in_threadblock / BM;
|
constexpr uint32_t row_stride_as = threads_per_threadblock / BM;
|
||||||
const float *global_a = reinterpret_cast<const float *>(A) +
|
const float *global_a = reinterpret_cast<const float *>(A) +
|
||||||
dim_m * global_a_row + global_a_col;
|
dim_m * global_a_row + global_a_col;
|
||||||
volatile float *local_a_tmp = reinterpret_cast<volatile float *>(local_a) +
|
volatile float *local_a_tmp = reinterpret_cast<volatile float *>(local_a) +
|
||||||
@@ -452,7 +473,7 @@ inline void global_dmem_load(const uint32_t dim_m, const uint32_t dim_n, const u
|
|||||||
if constexpr (!GMEM_COALESCED_A) {
|
if constexpr (!GMEM_COALESCED_A) {
|
||||||
// !GMEM_COALESCED_A: threads do uncoalesced read from neighboring row in
|
// !GMEM_COALESCED_A: threads do uncoalesced read from neighboring row in
|
||||||
// GMEM, writes to neighboring cols in SMEM
|
// GMEM, writes to neighboring cols in SMEM
|
||||||
constexpr uint32_t row_stride_as = threads_in_threadblock / BM;
|
constexpr uint32_t row_stride_as = threads_per_threadblock / BM;
|
||||||
const uint32_t global_a_row = BM * threadblock_id_y + local_as_col;
|
const uint32_t global_a_row = BM * threadblock_id_y + local_as_col;
|
||||||
const float *global_a =
|
const float *global_a =
|
||||||
reinterpret_cast<float *>(A) + dim_k_adjusted * global_a_row + (k_adjusted + local_as_row);
|
reinterpret_cast<float *>(A) + dim_k_adjusted * global_a_row + (k_adjusted + local_as_row);
|
||||||
@@ -508,7 +529,7 @@ inline void global_dmem_load(const uint32_t dim_m, const uint32_t dim_n, const u
|
|||||||
local_a_tmp += BM * row_stride_as * 8;
|
local_a_tmp += BM * row_stride_as * 8;
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
constexpr uint32_t row_stride_a = threads_in_threadblock / BK_adjusted;
|
constexpr uint32_t row_stride_a = threads_per_threadblock / BK_adjusted;
|
||||||
const uint32_t global_a_row = BM * threadblock_id_y + local_a_row;
|
const uint32_t global_a_row = BM * threadblock_id_y + local_a_row;
|
||||||
const float *global_a = reinterpret_cast<const float *>(A) +
|
const float *global_a = reinterpret_cast<const float *>(A) +
|
||||||
dim_k_adjusted * global_a_row +
|
dim_k_adjusted * global_a_row +
|
||||||
@@ -566,7 +587,7 @@ inline void global_dmem_load(const uint32_t dim_m, const uint32_t dim_n, const u
|
|||||||
} // end move A
|
} // end move A
|
||||||
|
|
||||||
// move B
|
// move B
|
||||||
constexpr uint32_t row_stride_b = threads_in_threadblock / BN_adjusted;
|
constexpr uint32_t row_stride_b = threads_per_threadblock / BN_adjusted;
|
||||||
const uint32_t global_b_col = BN_adjusted * threadblock_id_x + local_b_col;
|
const uint32_t global_b_col = BN_adjusted * threadblock_id_x + local_b_col;
|
||||||
// NOTE: not k_adjusted here; k is along the row dimension which is not
|
// NOTE: not k_adjusted here; k is along the row dimension which is not
|
||||||
// compressed for fp16
|
// compressed for fp16
|
||||||
@@ -648,7 +669,8 @@ inline void thread_block_gemm(const T *A, const T *B, float *C,
|
|||||||
const uint32_t warp_row = warp_id_in_warpgroup / (BN / WN);
|
const uint32_t warp_row = warp_id_in_warpgroup / (BN / WN);
|
||||||
const uint32_t warp_col = warp_id_in_warpgroup % (BN / WN);
|
const uint32_t warp_col = warp_id_in_warpgroup % (BN / WN);
|
||||||
const uint32_t tid_in_warp = tid_in_threadblock % NUM_THREADS;
|
const uint32_t tid_in_warp = tid_in_threadblock % NUM_THREADS;
|
||||||
const uint32_t warps_per_threadblock_per_core = NUM_WARPS / threads_per_threadblock;
|
const uint32_t warps_per_threadblock_per_core =
|
||||||
|
NUM_WARPS / threads_per_threadblock;
|
||||||
|
|
||||||
volatile T *local_a = reinterpret_cast<T *>(sharedmem_per_threadblock);
|
volatile T *local_a = reinterpret_cast<T *>(sharedmem_per_threadblock);
|
||||||
constexpr size_t local_a_elems = (BM * BK);
|
constexpr size_t local_a_elems = (BM * BK);
|
||||||
|
|||||||
Reference in New Issue
Block a user