sgemm_tcore: Unpack arg params, remove threadblock_dim_y

thread_block_gemm is meant to be reusable, so it should not depend on the
layout of the kernel arg struct. It now takes the matrix pointers and
dimensions as explicit parameters instead.

threadblock_dim_y was ambiguous, and its name did not match how it was used:
it actually held the number of warps that participate in a barrier.
This commit is contained in:
Hansung Kim
2024-08-14 20:34:49 -07:00
parent 70919c39c9
commit 014f7cd06f
2 changed files with 18 additions and 26 deletions

View File

@@ -31,7 +31,6 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
const uint32_t threadblocks_per_cluster =
hw_threads_per_cluster / threads_per_threadblock;
const uint32_t threadblock_dim_y = vx_num_warps() / threadblocks_per_cluster;
const int threadblock_id = task_id / threads_per_threadblock;
const int threadblock_id_in_cluster =
threadblock_id % threadblocks_per_cluster;
@@ -51,11 +50,10 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
DEV_SMEM_START_ADDR + sizeof(float_type) * 2 /*overkill for non-dma*/ *
(2 * BM * BK) * threadblock_id_in_cluster);
thread_block_gemm<float_type>(
arg, tid_in_threadblock, threads_per_threadblock, threadblock_dim_y,
/*threadblock_id_x, threadblock_id_y,*/
threadblocks_per_cluster,
// threadblock_id,
thread_block_gemm<float_type, /*write_to_gmem=*/true>(
(const float_type *)arg->addr_a, (const float_type *)arg->addr_b,
(float *)arg->addr_c, arg->dim_m, arg->dim_n, arg->dim_k,
tid_in_threadblock, threads_per_threadblock, threadblocks_per_cluster,
threadblock_id_in_cluster, sharedmem_per_threadblock);
}

View File

@@ -28,7 +28,7 @@
#define TCK 16
#define WMITER (WM / TCM)
#define WNITER (WN / TCN)
#define ELEM_PER_THREAD (WMITER * WNITER * (TCM * TCN) / NUM_THREADS)
#define ELEM_PER_THREAD (WM * WN / NUM_THREADS)
// FIXME: NUM_THREADS and NUM_WARPS hardcoded
#if ((BM * BN / ELEM_PER_THREAD) > (CORES_PER_CLUSTER * 8 * 8))
#error "threadblock size too big for cluster"
@@ -433,8 +433,6 @@ inline void global_dmem_load(const uint32_t dim_m, const uint32_t dim_n, const u
const uint32_t global_a_row = k_adjusted + local_as_row;
const uint32_t global_a_col = BM * block_m + local_as_col;
// number of rows a full TB can read at a time
// this is equivalent to threadblock_dim_y (assuming threadblock_dim_x ==
// BK)
constexpr uint32_t row_stride_as = threads_in_threadblock / BM;
const float *global_a = reinterpret_cast<const float *>(A) +
dim_m * global_a_row + global_a_col;
@@ -628,23 +626,15 @@ inline void global_dmem_load(const uint32_t dim_m, const uint32_t dim_n, const u
}
template <typename T, bool write_to_gmem = true>
inline void thread_block_gemm(kernel_arg_t *__UNIFORM__ arg,
inline void thread_block_gemm(const T *A, const T *B, float *C,
const uint32_t dim_m,
const uint32_t dim_n,
const uint32_t dim_k,
const uint32_t tid_in_threadblock,
const uint32_t threads_per_threadblock,
const uint32_t threadblock_dim_y,
/*const uint32_t threadblock_id_x,
const uint32_t threadblock_id_y,*/
const uint32_t threadblocks_per_cluster,
const uint32_t threadblock_id_in_cluster,
uint8_t *sharedmem_per_threadblock) {
const T *A = (const T *)arg->addr_a;
const T *B = (const T *)arg->addr_b;
float *C = (float *)arg->addr_c;
const uint32_t dim_m = arg->dim_m;
const uint32_t dim_n = arg->dim_n;
const uint32_t dim_k = arg->dim_k;
const uint32_t local_a_row = tid_in_threadblock / BK;
const uint32_t local_a_col = tid_in_threadblock % BK;
const uint32_t local_as_row = tid_in_threadblock / BM;
@@ -658,6 +648,7 @@ inline void thread_block_gemm(kernel_arg_t *__UNIFORM__ arg,
const uint32_t warp_row = warp_id_in_warpgroup / (BN / WN);
const uint32_t warp_col = warp_id_in_warpgroup % (BN / WN);
const uint32_t tid_in_warp = tid_in_threadblock % NUM_THREADS;
const uint32_t warps_per_threadblock_per_core = NUM_WARPS / threads_per_threadblock;
volatile T *local_a = reinterpret_cast<T *>(sharedmem_per_threadblock);
constexpr size_t local_a_elems = (BM * BK);
@@ -737,7 +728,8 @@ inline void thread_block_gemm(kernel_arg_t *__UNIFORM__ arg,
#endif
}
threadblock_barrier(threadblock_id_in_cluster, threadblock_dim_y);
threadblock_barrier(threadblock_id_in_cluster,
warps_per_threadblock_per_core);
}
#pragma GCC unroll 1
@@ -814,7 +806,8 @@ inline void thread_block_gemm(kernel_arg_t *__UNIFORM__ arg,
global_dmem_load<T>(dim_m, dim_n, dim_k, block_k * BK, A, B, local_a,
local_b, tid_in_threadblock, block_n, block_m);
threadblock_barrier(threadblock_id_in_cluster, threadblock_dim_y);
threadblock_barrier(threadblock_id_in_cluster,
warps_per_threadblock_per_core);
#endif
// consumer code: SMEM->RF and compute
@@ -865,14 +858,15 @@ inline void thread_block_gemm(kernel_arg_t *__UNIFORM__ arg,
if constexpr (GEMMINI_DMA) {
// Call gemmini fence at the end of the loop to overlap dma & wmma.
// Hopefully by this time, dma would have finished so that this is a
// no-op
// Usually, by this time, dma has finished the copy so that this
// becomes a no-op.
if (tid_in_threadblock == 0) {
gemmini_fence();
}
}
threadblock_barrier(threadblock_id_in_cluster, threadblock_dim_y);
threadblock_barrier(threadblock_id_in_cluster,
warps_per_threadblock_per_core);
}
if constexpr (write_to_gmem) {