sgemm_tcore: Fix address overlap for DMA
Enforce square shapes of tiles in smem. TODO need to configure loop bounds correctly.
This commit is contained in:
@@ -2,7 +2,7 @@ PROJECT = sgemm_tcore
|
||||
|
||||
SRCS = main.cpp common.h
|
||||
|
||||
VX_SRCS = kernel.cpp
|
||||
VX_SRCS = kernel.activation.cpp
|
||||
|
||||
OPTS ?= -n16
|
||||
|
||||
|
||||
@@ -32,7 +32,7 @@
|
||||
#error Unsupported smem size
|
||||
#endif
|
||||
|
||||
#define WARP_SPECIALIZED 0
|
||||
#define WARP_SPECIALIZED 1
|
||||
static_assert(
|
||||
!WARP_SPECIALIZED || GEMMINI_DMA,
|
||||
"warp specialization is currently only supported with GEMMINI_DMA");
|
||||
@@ -284,12 +284,15 @@ inline void thread_block_gemm(kernel_arg_t *__UNIFORM__ arg,
|
||||
const uint32_t warp_col = warp_id_in_warpgroup % (BN / WN);
|
||||
const uint32_t tid_in_warp = tid_in_threadblock % NUM_THREADS;
|
||||
|
||||
// layout: local_a -- local_a_buf -- local_b -- local_b_buf
|
||||
volatile float *local_a = sharedmem_per_threadblock;
|
||||
constexpr size_t local_a_elems = (BM * BK);
|
||||
volatile float *local_a_buf = local_a + local_a_elems;
|
||||
|
||||
volatile float *local_b = local_a_buf + local_a_elems;
|
||||
constexpr size_t local_b_elems = (BK * BN);
|
||||
// set local B tile size to be the same as local A size (BM * BK), since DMA
|
||||
// is currently only configured for square-shape tiles. FIXME.
|
||||
constexpr size_t local_b_elems = (BK * BM);
|
||||
volatile float *local_b_buf = local_a_buf + local_b_elems;
|
||||
|
||||
constexpr uint32_t skips =
|
||||
@@ -618,7 +621,9 @@ inline void thread_block_gemm(kernel_arg_t *__UNIFORM__ arg,
|
||||
asm volatile("add t6, t6, %0" :: "r"(dim_n * sizeof(float)));
|
||||
}
|
||||
|
||||
threadblock_barrier(0/*global threadblock in cluster*/, threadblock_dim_y);
|
||||
// global barrier that synchronizes both warpgroups at every M-N
|
||||
// iteration
|
||||
threadblock_barrier(0/*all warpgroups*/, threadblock_dim_y);
|
||||
|
||||
#pragma GCC unroll 2
|
||||
for (int wm_iter = 0; wm_iter < WMITER; wm_iter++) {
|
||||
@@ -669,7 +674,7 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
|
||||
// "static" shared memory allocation. This would determine threadblock
|
||||
// occupancy of a single cluster
|
||||
float *sharedmem_per_threadblock =
|
||||
(float *)DEV_SMEM_START_ADDR + 2/*overkill for non-dma*/ * (2 * BM * BK) *
|
||||
(float *)DEV_SMEM_START_ADDR + 2/*overkill for non-dma*/ * ((BM + BN) * BK) *
|
||||
threadblock_id_in_cluster;
|
||||
|
||||
thread_block_gemm(arg, tid_in_threadblock, threads_per_threadblock,
|
||||
|
||||
@@ -19,7 +19,7 @@
|
||||
// * Combining BM * BK >= (BM*BN) / (TM*TN) == threadblock yields
|
||||
// BM <= BK*TM*TN
|
||||
#define BM 64
|
||||
#define BN 64
|
||||
#define BN 32
|
||||
#define BK 64
|
||||
#define WM 16
|
||||
#define WN 8
|
||||
|
||||
Reference in New Issue
Block a user