sgemm_tcore: Fix address overlap for DMA

Enforce square shapes of tiles in smem.  TODO need to configure loop
bounds correctly.
This commit is contained in:
Hansung Kim
2024-06-18 15:06:07 -07:00
parent 36b02ad595
commit 50b843d8c4
3 changed files with 11 additions and 6 deletions

View File

@@ -2,7 +2,7 @@ PROJECT = sgemm_tcore
SRCS = main.cpp common.h
VX_SRCS = kernel.cpp
VX_SRCS = kernel.activation.cpp
OPTS ?= -n16

View File

@@ -32,7 +32,7 @@
#error Unsupported smem size
#endif
#define WARP_SPECIALIZED 0
#define WARP_SPECIALIZED 1
static_assert(
!WARP_SPECIALIZED || GEMMINI_DMA,
"warp specialization is currently only supported with GEMMINI_DMA");
@@ -284,12 +284,15 @@ inline void thread_block_gemm(kernel_arg_t *__UNIFORM__ arg,
const uint32_t warp_col = warp_id_in_warpgroup % (BN / WN);
const uint32_t tid_in_warp = tid_in_threadblock % NUM_THREADS;
// layout: local_a -- local_a_buf -- local_b -- local_b_buf
volatile float *local_a = sharedmem_per_threadblock;
constexpr size_t local_a_elems = (BM * BK);
volatile float *local_a_buf = local_a + local_a_elems;
volatile float *local_b = local_a_buf + local_a_elems;
constexpr size_t local_b_elems = (BK * BN);
// set local B tile size to be the same as local A size (BM * BK), since DMA
// is currently only configured for square-shape tiles. FIXME.
constexpr size_t local_b_elems = (BK * BM);
volatile float *local_b_buf = local_a_buf + local_b_elems;
constexpr uint32_t skips =
@@ -618,7 +621,9 @@ inline void thread_block_gemm(kernel_arg_t *__UNIFORM__ arg,
asm volatile("add t6, t6, %0" :: "r"(dim_n * sizeof(float)));
}
threadblock_barrier(0/*global threadblock in cluster*/, threadblock_dim_y);
// global barrier that synchronizes both warpgroups at every M-N
// iteration
threadblock_barrier(0/*all warpgroups*/, threadblock_dim_y);
#pragma GCC unroll 2
for (int wm_iter = 0; wm_iter < WMITER; wm_iter++) {
@@ -669,7 +674,7 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
// "static" shared memory allocation. This would determine threadblock
// occupancy of a single cluster
float *sharedmem_per_threadblock =
(float *)DEV_SMEM_START_ADDR + 2/*overkill for non-dma*/ * (2 * BM * BK) *
(float *)DEV_SMEM_START_ADDR + 2/*overkill for non-dma*/ * ((BM + BN) * BK) *
threadblock_id_in_cluster;
thread_block_gemm(arg, tid_in_threadblock, threads_per_threadblock,

View File

@@ -19,7 +19,7 @@
// * Combining BM * BK >= (BM*BN) / (TM*TN) == threadblock yields
// BM <= BK*TM*TN
#define BM 64
#define BN 64
#define BN 32
#define BK 64
#define WM 16
#define WN 8