sgemm_impl: Add tiling params for hopper tensor core
This commit is contained in:
@@ -17,6 +17,10 @@ using float_type = float;
|
||||
using float_type = float16_t;
|
||||
#endif
|
||||
|
||||
// Generate kernel for the Hopper-style SMEM-decoupled tensor core. This uses
|
||||
// asynchronous HGMMA and HGMMA_WAIT instructions.
|
||||
// Compile-time switch: when this is 1, the `#if (TENSOR_HOPPER == 1)` branch
// further down supplies the tiling parameters; otherwise its #else branch does.
#define TENSOR_HOPPER 1
|
||||
|
||||
// Constraints on parameters:
|
||||
// * Memory:
|
||||
// (BM + BN) * BK * sizeof(T) <= sharedmem size.
|
||||
@@ -29,6 +33,31 @@ using float_type = float16_t;
|
||||
// (BM*BN) / (TM*TN) == threadblock size >= NT * CORES_PER_CLUSTER
|
||||
// * Combining BM * BK >= (BM*BN) / (TM*TN) == threadblock yields
|
||||
// BN <= BK*TM*TN
|
||||
#if (TENSOR_HOPPER == 1)
|
||||
// Tiling parameters BM and BN for the TENSOR_HOPPER == 1 configuration.
// Both appear in the memory and threadblock-size constraints listed above.
// NOTE(review): presumably the per-threadblock tile extents along M and N --
// confirm against the kernel that consumes them.
#define BM 128
|
||||
#define BN 64
|
||||
// Select BK for the TENSOR_HOPPER == 1 configuration from FP_SIZE:
//   FP_SIZE == 32 -> BK = 64
//   FP_SIZE == 16 -> BK = 128
// Any other FP_SIZE stops the build with an #error.
// NOTE(review): assuming FP_SIZE is the element width in bits, both choices
// give BK * (FP_SIZE / 8) == 256 bytes, so the shared-memory bound
// (BM + BN) * BK * sizeof(T) stays the same for both precisions -- confirm.
#if (FP_SIZE == 32)
|
||||
#define BK 64
|
||||
#elif (FP_SIZE == 16)
|
||||
#define BK 128
|
||||
#else
|
||||
#error "unsupported FP_SIZE"
|
||||
#endif
|
||||
|
||||
// WM/WN and TCM/TCN for the TENSOR_HOPPER == 1 configuration. All four are 16
// here, so WMITER (WM / TCM) and WNITER (WN / TCN), defined after the
// conditional, both evaluate to 1.
// NOTE(review): presumably WM/WN are per-warp tile extents and TCM/TCN are the
// extents of one tensor-core operation -- confirm against the kernel.
#define WM 16
|
||||
#define WN 16
|
||||
#define TCM 16
|
||||
#define TCN 16
|
||||
// Select TCK for the TENSOR_HOPPER == 1 configuration from FP_SIZE:
//   FP_SIZE == 32 -> TCK = 16
//   FP_SIZE == 16 -> TCK = 32
// Any other FP_SIZE stops the build with an #error.
// NOTE(review): assuming FP_SIZE is the element width in bits, both choices
// give TCK * (FP_SIZE / 8) == 64 bytes -- confirm this is the intended
// invariant.
#if (FP_SIZE == 32)
|
||||
#define TCK 16
|
||||
#elif (FP_SIZE == 16)
|
||||
#define TCK 32
|
||||
#else
|
||||
#error "unsupported FP_SIZE"
|
||||
#endif
|
||||
|
||||
#else // !HOPPER
|
||||
|
||||
// BM and BN for the non-Hopper configuration (the #else branch of the
// TENSOR_HOPPER check). BM is 128 when NUM_CORES equals 8 and 64 for any
// other NUM_CORES value; BN is always 64. NUM_CORES is not defined in this
// part of the file.
#define BM ((NUM_CORES == 8) ? 128 : 64)
|
||||
#define BN 64
|
||||
#if (FP_SIZE == 32)
|
||||
@@ -38,6 +67,7 @@ using float_type = float16_t;
|
||||
#else
|
||||
#error "unsupported FP_SIZE"
|
||||
#endif
|
||||
|
||||
// WM, WN and TCM for the non-Hopper configuration. With WM = 16 and TCM = 8,
// WMITER (WM / TCM), defined after the conditional, evaluates to 2.
// NOTE(review): presumably WM/WN are per-warp tile extents and TCM is the M
// extent of one tensor-core operation -- confirm against the kernel.
#define WM 16
|
||||
#define WN 8
|
||||
#define TCM 8
|
||||
@@ -49,6 +79,8 @@ using float_type = float16_t;
|
||||
#else
|
||||
#error "unsupported FP_SIZE"
|
||||
#endif
|
||||
#endif
|
||||
|
||||
// Quantities derived from the tiling parameters chosen above; these apply to
// both the Hopper and non-Hopper configurations.
//   WMITER          - WM divided by TCM.
//   WNITER          - WN divided by TCN.
//   ELEM_PER_THREAD - WM * WN divided by NUM_THREADS (NUM_THREADS is not
//                     defined in this part of the file).
// With integer operands these divisions truncate, so the parameters are
// expected to divide evenly.
#define WMITER (WM / TCM)
|
||||
#define WNITER (WN / TCN)
|
||||
#define ELEM_PER_THREAD (WM * WN / NUM_THREADS)
|
||||
|
||||
Reference in New Issue
Block a user