From a1858e0c803377d4e70e9e3964bfa4d9285f5d4e Mon Sep 17 00:00:00 2001 From: Hansung Kim Date: Thu, 15 Aug 2024 20:33:33 -0700 Subject: [PATCH] sgemm_impl: Parameterize BK/TCK by FP_SIZE --- tests/regression/sgemm_tcore/sgemm_impl.hpp | 40 ++++++++++++++++----- 1 file changed, 31 insertions(+), 9 deletions(-) diff --git a/tests/regression/sgemm_tcore/sgemm_impl.hpp b/tests/regression/sgemm_tcore/sgemm_impl.hpp index 8e22b16a..14ba1760 100644 --- a/tests/regression/sgemm_tcore/sgemm_impl.hpp +++ b/tests/regression/sgemm_tcore/sgemm_impl.hpp @@ -6,6 +6,18 @@ #include "include/gemmini.h" #include "gemmini_mmio.h" +#define FP_SIZE 32 + +// "fake" fp16 type that only has the correct data width. +using float16_t = uint16_t; + +#if (FP_SIZE == 32) +using float_type = float; +#elif (FP_SIZE == 16) + +using float_type = float16_t; +#endif + // Constraints on parameters: // * Memory: // (BM + BN) * BK * sizeof(T) <= sharedmem size. @@ -20,12 +32,24 @@ // BM <= BK*TM*TN #define BM 64 #define BN 64 +#if (FP_SIZE == 32) +#define BK 64 +#elif (FP_SIZE == 16) #define BK 128 +#else +#error "unsupported FP_SIZE" +#endif #define WM 16 #define WN 8 #define TCM 8 #define TCN 8 +#if (FP_SIZE == 32) +#define TCK 8 +#elif (FP_SIZE == 16) #define TCK 16 +#else +#error "unsupported FP_SIZE" +#endif #define WMITER (WM / TCM) #define WNITER (WN / TCN) #define ELEM_PER_THREAD (WM * WN / NUM_THREADS) @@ -82,9 +106,6 @@ #error Unsupported smem size #endif -// "fake" fp16 type that only has the correct data width. -using float16_t = uint16_t; - inline constexpr void map_operand_32lanes(const int tid, int &row, int &col) { const int tg = tid / 4; @@ -416,7 +437,7 @@ inline void global_dmem_load(const uint32_t dim_m, const uint32_t dim_n, const u const uint32_t local_b_col = tid_in_threadblock % BN_adjusted; // FIXME: need fix for fp16? - constexpr uint32_t threads_in_threadblock = (BM * BN) / ELEM_PER_THREAD; + constexpr uint32_t threads_per_threadblock = (BM * BN) / ELEM_PER_THREAD; // Data move from GMEM to SMEM // @@ -433,7 +454,7 @@ inline void global_dmem_load(const uint32_t dim_m, const uint32_t dim_n, const u const uint32_t global_a_row = k_adjusted + local_as_row; const uint32_t global_a_col = BM * block_m + local_as_col; // number of rows a full TB can read at a time - constexpr uint32_t row_stride_as = threads_in_threadblock / BM; + constexpr uint32_t row_stride_as = threads_per_threadblock / BM; const float *global_a = reinterpret_cast(A) + dim_m * global_a_row + global_a_col; volatile float *local_a_tmp = reinterpret_cast(local_a) + @@ -452,7 +473,7 @@ inline void global_dmem_load(const uint32_t dim_m, const uint32_t dim_n, const u if constexpr (!GMEM_COALESCED_A) { // !GMEM_COALESCED_A: threads do uncoalesced read from neighboring row in // GMEM, writes to neighboring cols in SMEM - constexpr uint32_t row_stride_as = threads_in_threadblock / BM; + constexpr uint32_t row_stride_as = threads_per_threadblock / BM; const uint32_t global_a_row = BM * threadblock_id_y + local_as_col; const float *global_a = reinterpret_cast(A) + dim_k_adjusted * global_a_row + (k_adjusted + local_as_row); @@ -508,7 +529,7 @@ inline void global_dmem_load(const uint32_t dim_m, const uint32_t dim_n, const u local_a_tmp += BM * row_stride_as * 8; } } else { - constexpr uint32_t row_stride_a = threads_in_threadblock / BK_adjusted; + constexpr uint32_t row_stride_a = threads_per_threadblock / BK_adjusted; const uint32_t global_a_row = BM * threadblock_id_y + local_a_row; const float *global_a = reinterpret_cast(A) + dim_k_adjusted * global_a_row + @@ -566,7 +587,7 @@ inline void global_dmem_load(const uint32_t dim_m, const uint32_t dim_n, const u } // end move A // move B - constexpr uint32_t row_stride_b = threads_in_threadblock / BN_adjusted; + constexpr uint32_t row_stride_b = threads_per_threadblock / BN_adjusted; const uint32_t global_b_col = BN_adjusted * threadblock_id_x + local_b_col; // NOTE: not k_adjusted here; k is along the row dimension which is not // compressed for fp16 @@ -648,7 +669,8 @@ inline void thread_block_gemm(const T *A, const T *B, float *C, const uint32_t warp_row = warp_id_in_warpgroup / (BN / WN); const uint32_t warp_col = warp_id_in_warpgroup % (BN / WN); const uint32_t tid_in_warp = tid_in_threadblock % NUM_THREADS; - const uint32_t warps_per_threadblock_per_core = NUM_WARPS / threads_per_threadblock; + const uint32_t warps_per_threadblock_per_core = + NUM_WARPS / threads_per_threadblock; volatile T *local_a = reinterpret_cast(sharedmem_per_threadblock); constexpr size_t local_a_elems = (BM * BK);