sgemm_tcore: Fix address overlap for DMA

Enforce square shapes of tiles in smem.  TODO need to configure loop
bounds correctly.
This commit is contained in:
Hansung Kim
2024-06-18 15:06:07 -07:00
parent 36b02ad595
commit 50b843d8c4
3 changed files with 11 additions and 6 deletions

View File

@@ -2,7 +2,7 @@ PROJECT = sgemm_tcore
SRCS = main.cpp common.h SRCS = main.cpp common.h
VX_SRCS = kernel.cpp VX_SRCS = kernel.activation.cpp
OPTS ?= -n16 OPTS ?= -n16

View File

@@ -32,7 +32,7 @@
#error Unsupported smem size #error Unsupported smem size
#endif #endif
#define WARP_SPECIALIZED 0 #define WARP_SPECIALIZED 1
static_assert( static_assert(
!WARP_SPECIALIZED || GEMMINI_DMA, !WARP_SPECIALIZED || GEMMINI_DMA,
"warp specialization is currently only supported with GEMMINI_DMA"); "warp specialization is currently only supported with GEMMINI_DMA");
@@ -284,12 +284,15 @@ inline void thread_block_gemm(kernel_arg_t *__UNIFORM__ arg,
const uint32_t warp_col = warp_id_in_warpgroup % (BN / WN); const uint32_t warp_col = warp_id_in_warpgroup % (BN / WN);
const uint32_t tid_in_warp = tid_in_threadblock % NUM_THREADS; const uint32_t tid_in_warp = tid_in_threadblock % NUM_THREADS;
// layout: local_a -- local_a_buf -- local_b -- local_b_buf
volatile float *local_a = sharedmem_per_threadblock; volatile float *local_a = sharedmem_per_threadblock;
constexpr size_t local_a_elems = (BM * BK); constexpr size_t local_a_elems = (BM * BK);
volatile float *local_a_buf = local_a + local_a_elems; volatile float *local_a_buf = local_a + local_a_elems;
volatile float *local_b = local_a_buf + local_a_elems; volatile float *local_b = local_a_buf + local_a_elems;
constexpr size_t local_b_elems = (BK * BN); // set local B tile size to be the same as local A size (BM * BK), since DMA
// is currently only configured for square-shape tiles. FIXME.
constexpr size_t local_b_elems = (BK * BM);
volatile float *local_b_buf = local_a_buf + local_b_elems; volatile float *local_b_buf = local_a_buf + local_b_elems;
constexpr uint32_t skips = constexpr uint32_t skips =
@@ -618,7 +621,9 @@ inline void thread_block_gemm(kernel_arg_t *__UNIFORM__ arg,
asm volatile("add t6, t6, %0" :: "r"(dim_n * sizeof(float))); asm volatile("add t6, t6, %0" :: "r"(dim_n * sizeof(float)));
} }
threadblock_barrier(0/*global threadblock in cluster*/, threadblock_dim_y); // global barrier that synchronizes both warpgroups at every M-N
// iteration
threadblock_barrier(0/*all warpgroups*/, threadblock_dim_y);
#pragma GCC unroll 2 #pragma GCC unroll 2
for (int wm_iter = 0; wm_iter < WMITER; wm_iter++) { for (int wm_iter = 0; wm_iter < WMITER; wm_iter++) {
@@ -669,7 +674,7 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
// "static" shared memory allocation. This would determine threadblock // "static" shared memory allocation. This would determine threadblock
// occupancy of a single cluster // occupancy of a single cluster
float *sharedmem_per_threadblock = float *sharedmem_per_threadblock =
(float *)DEV_SMEM_START_ADDR + 2/*overkill for non-dma*/ * (2 * BM * BK) * (float *)DEV_SMEM_START_ADDR + 2/*overkill for non-dma*/ * ((BM + BN) * BK) *
threadblock_id_in_cluster; threadblock_id_in_cluster;
thread_block_gemm(arg, tid_in_threadblock, threads_per_threadblock, thread_block_gemm(arg, tid_in_threadblock, threads_per_threadblock,

View File

@@ -19,7 +19,7 @@
// * Combining BM * BK >= (BM*BN) / (TM*TN) == threadblock yields // * Combining BM * BK >= (BM*BN) / (TM*TN) == threadblock yields
// BM <= BK*TM*TN // BM <= BK*TM*TN
#define BM 64 #define BM 64
#define BN 64 #define BN 32
#define BK 64 #define BK 64
#define WM 16 #define WM 16
#define WN 8 #define WN 8