sgemm_impl: Parameterize BK/TCK by FP_SIZE

This commit is contained in:
Hansung Kim
2024-08-15 20:33:33 -07:00
parent fd2ff6208d
commit a1858e0c80

View File

@@ -6,6 +6,18 @@
#include "include/gemmini.h" #include "include/gemmini.h"
#include "gemmini_mmio.h" #include "gemmini_mmio.h"
#define FP_SIZE 32
// "fake" fp16 type that only has the correct data width.
using float16_t = uint16_t;
#if (FP_SIZE == 32)
using float_type = float;
#elif (FP_SIZE == 16)
using float_type = float16_t;
#endif
// Constraints on parameters: // Constraints on parameters:
// * Memory: // * Memory:
// (BM + BN) * BK * sizeof(T) <= sharedmem size. // (BM + BN) * BK * sizeof(T) <= sharedmem size.
@@ -20,12 +32,24 @@
// BM <= BK*TM*TN // BM <= BK*TM*TN
#define BM 64 #define BM 64
#define BN 64 #define BN 64
#if (FP_SIZE == 32)
#define BK 64
#elif (FP_SIZE == 16)
#define BK 128 #define BK 128
#else
#error "unsupported FP_SIZE"
#endif
#define WM 16 #define WM 16
#define WN 8 #define WN 8
#define TCM 8 #define TCM 8
#define TCN 8 #define TCN 8
#if (FP_SIZE == 32)
#define TCK 8
#elif (FP_SIZE == 16)
#define TCK 16 #define TCK 16
#else
#error "unsupported FP_SIZE"
#endif
#define WMITER (WM / TCM) #define WMITER (WM / TCM)
#define WNITER (WN / TCN) #define WNITER (WN / TCN)
#define ELEM_PER_THREAD (WM * WN / NUM_THREADS) #define ELEM_PER_THREAD (WM * WN / NUM_THREADS)
@@ -82,9 +106,6 @@
#error Unsupported smem size #error Unsupported smem size
#endif #endif
// "fake" fp16 type that only has the correct data width.
using float16_t = uint16_t;
inline constexpr void map_operand_32lanes(const int tid, int &row, int &col) { inline constexpr void map_operand_32lanes(const int tid, int &row, int &col) {
const int tg = tid / 4; const int tg = tid / 4;
@@ -416,7 +437,7 @@ inline void global_dmem_load(const uint32_t dim_m, const uint32_t dim_n, const u
const uint32_t local_b_col = tid_in_threadblock % BN_adjusted; const uint32_t local_b_col = tid_in_threadblock % BN_adjusted;
// FIXME: need fix for fp16? // FIXME: need fix for fp16?
constexpr uint32_t threads_in_threadblock = (BM * BN) / ELEM_PER_THREAD; constexpr uint32_t threads_per_threadblock = (BM * BN) / ELEM_PER_THREAD;
// Data move from GMEM to SMEM // Data move from GMEM to SMEM
// //
@@ -433,7 +454,7 @@ inline void global_dmem_load(const uint32_t dim_m, const uint32_t dim_n, const u
const uint32_t global_a_row = k_adjusted + local_as_row; const uint32_t global_a_row = k_adjusted + local_as_row;
const uint32_t global_a_col = BM * block_m + local_as_col; const uint32_t global_a_col = BM * block_m + local_as_col;
// number of rows a full TB can read at a time // number of rows a full TB can read at a time
constexpr uint32_t row_stride_as = threads_in_threadblock / BM; constexpr uint32_t row_stride_as = threads_per_threadblock / BM;
const float *global_a = reinterpret_cast<const float *>(A) + const float *global_a = reinterpret_cast<const float *>(A) +
dim_m * global_a_row + global_a_col; dim_m * global_a_row + global_a_col;
volatile float *local_a_tmp = reinterpret_cast<volatile float *>(local_a) + volatile float *local_a_tmp = reinterpret_cast<volatile float *>(local_a) +
@@ -452,7 +473,7 @@ inline void global_dmem_load(const uint32_t dim_m, const uint32_t dim_n, const u
if constexpr (!GMEM_COALESCED_A) { if constexpr (!GMEM_COALESCED_A) {
// !GMEM_COALESCED_A: threads do uncoalesced read from neighboring row in // !GMEM_COALESCED_A: threads do uncoalesced read from neighboring row in
// GMEM, writes to neighboring cols in SMEM // GMEM, writes to neighboring cols in SMEM
constexpr uint32_t row_stride_as = threads_in_threadblock / BM; constexpr uint32_t row_stride_as = threads_per_threadblock / BM;
const uint32_t global_a_row = BM * threadblock_id_y + local_as_col; const uint32_t global_a_row = BM * threadblock_id_y + local_as_col;
const float *global_a = const float *global_a =
reinterpret_cast<float *>(A) + dim_k_adjusted * global_a_row + (k_adjusted + local_as_row); reinterpret_cast<float *>(A) + dim_k_adjusted * global_a_row + (k_adjusted + local_as_row);
@@ -508,7 +529,7 @@ inline void global_dmem_load(const uint32_t dim_m, const uint32_t dim_n, const u
local_a_tmp += BM * row_stride_as * 8; local_a_tmp += BM * row_stride_as * 8;
} }
} else { } else {
constexpr uint32_t row_stride_a = threads_in_threadblock / BK_adjusted; constexpr uint32_t row_stride_a = threads_per_threadblock / BK_adjusted;
const uint32_t global_a_row = BM * threadblock_id_y + local_a_row; const uint32_t global_a_row = BM * threadblock_id_y + local_a_row;
const float *global_a = reinterpret_cast<const float *>(A) + const float *global_a = reinterpret_cast<const float *>(A) +
dim_k_adjusted * global_a_row + dim_k_adjusted * global_a_row +
@@ -566,7 +587,7 @@ inline void global_dmem_load(const uint32_t dim_m, const uint32_t dim_n, const u
} // end move A } // end move A
// move B // move B
constexpr uint32_t row_stride_b = threads_in_threadblock / BN_adjusted; constexpr uint32_t row_stride_b = threads_per_threadblock / BN_adjusted;
const uint32_t global_b_col = BN_adjusted * threadblock_id_x + local_b_col; const uint32_t global_b_col = BN_adjusted * threadblock_id_x + local_b_col;
// NOTE: not k_adjusted here; k is along the row dimension which is not // NOTE: not k_adjusted here; k is along the row dimension which is not
// compressed for fp16 // compressed for fp16
@@ -648,7 +669,8 @@ inline void thread_block_gemm(const T *A, const T *B, float *C,
const uint32_t warp_row = warp_id_in_warpgroup / (BN / WN); const uint32_t warp_row = warp_id_in_warpgroup / (BN / WN);
const uint32_t warp_col = warp_id_in_warpgroup % (BN / WN); const uint32_t warp_col = warp_id_in_warpgroup % (BN / WN);
const uint32_t tid_in_warp = tid_in_threadblock % NUM_THREADS; const uint32_t tid_in_warp = tid_in_threadblock % NUM_THREADS;
const uint32_t warps_per_threadblock_per_core = NUM_WARPS / threads_per_threadblock; const uint32_t warps_per_threadblock_per_core =
NUM_WARPS / threads_per_threadblock;
volatile T *local_a = reinterpret_cast<T *>(sharedmem_per_threadblock); volatile T *local_a = reinterpret_cast<T *>(sharedmem_per_threadblock);
constexpr size_t local_a_elems = (BM * BK); constexpr size_t local_a_elems = (BM * BK);