sgemm_wg: Document tiling parameter constraints

This commit is contained in:
Hansung Kim
2024-03-28 18:17:00 -07:00
parent 9673db4e8c
commit a9b0814211

View File

@@ -4,11 +4,20 @@
#include <vx_spawn.h>
#include "common.h"
// Constraints on parameters:
// * Memory:
// (BM + BN) * BK * sizeof(float) <= sharedmem size.
// BM * BK == BN * BK >= threadblock size >= NT * CORES_PER_CLUSTER
// When larger, the kernel runs a sequential loop to read into sharedmem;
// but smaller case is not handled.
// * Compute:
// ( M* N) / (TM*TN) == grid size >= NC*NW*NT
// (BM*BN) / (TM*TN) == threadblock size >= NT * CORES_PER_CLUSTER
// * Combining BM * BK >= (BM*BN) / (TM*TN) == threadblock yields
// BM <= BK*TM*TN.
#define BM 8
#define BN BM
#define BK 2
// #define TM (BM/BK)
// #define TN (BN/BK)
#define TM 2
#define TN 2
@@ -82,7 +91,9 @@ void thread_block_gemm(kernel_arg_t *__UNIFORM__ arg,
threadblock_barrier(tid_in_threadblock, threadblock_id_in_core,
threadblock_dim_y);
// Compute single tile*tile matmul
for (uint32_t local_k = 0; local_k < BK; local_k++) {
// First, pump data from SMEM->RF
#pragma GCC unroll TM
for (uint32_t res_idx_m = 0; res_idx_m < TM; res_idx_m++) {
reg_a[res_idx_m] =
@@ -94,7 +105,7 @@ void thread_block_gemm(kernel_arg_t *__UNIFORM__ arg,
local_b[BN * local_k + (TN * local_c_col + res_idx_n)];
}
// Compute multiple result elements (TM) per thread
// Next, compute multiple result elements (TM*TN) by reusing data in RF
#pragma GCC unroll TM
for (uint32_t res_idx_m = 0; res_idx_m < TM; res_idx_m++) {
#pragma GCC unroll TN
@@ -113,6 +124,7 @@ void thread_block_gemm(kernel_arg_t *__UNIFORM__ arg,
threadblock_dim_y);
}
// Store result data from RF to GMEM
#pragma GCC unroll TM
for (uint32_t res_idx_m = 0; res_idx_m < TM; res_idx_m++) {
#pragma GCC unroll TN