sgemm_impl: Add tiling params for hopper tensor core

This commit is contained in:
Hansung Kim
2024-10-23 19:50:18 -07:00
parent 68cd6455fe
commit 6417a625b1

View File

@@ -17,6 +17,10 @@ using float_type = float;
using float_type = float16_t;
#endif
// Generate kernel for the Hopper-style SMEM-decoupled tensor core. This uses
// asynchronous HGMMA and HGMMA_WAIT instructions.
#define TENSOR_HOPPER 1
// Constraints on parameters:
// * Memory:
// (BM + BN) * BK * sizeof(T) <= sharedmem size.
@@ -29,6 +33,31 @@ using float_type = float16_t;
// (BM*BN) / (TM*TN) == threadblock size >= NT * CORES_PER_CLUSTER
// * Combining BM * BK >= (BM*BN) / (TM*TN) == threadblock yields
// BM <= BK*TM*TN
#if (TENSOR_HOPPER == 1)
#define BM 128
#define BN 64
#if (FP_SIZE == 32)
#define BK 64
#elif (FP_SIZE == 16)
#define BK 128
#else
#error "unsupported FP_SIZE"
#endif
#define WM 16
#define WN 16
#define TCM 16
#define TCN 16
#if (FP_SIZE == 32)
#define TCK 16
#elif (FP_SIZE == 16)
#define TCK 32
#else
#error "unsupported FP_SIZE"
#endif
#else // !HOPPER
#define BM ((NUM_CORES == 8) ? 128 : 64)
#define BN 64
#if (FP_SIZE == 32)
@@ -38,6 +67,7 @@ using float_type = float16_t;
#else
#error "unsupported FP_SIZE"
#endif
#define WM 16
#define WN 8
#define TCM 8
@@ -49,6 +79,8 @@ using float_type = float16_t;
#else
#error "unsupported FP_SIZE"
#endif
#endif
#define WMITER (WM / TCM)
#define WNITER (WN / TCN)
#define ELEM_PER_THREAD (WM * WN / NUM_THREADS)