sgemm_tcore: Fix address overlap for DMA
Enforce square tile shapes in shared memory (smem). TODO: the loop bounds still need to be configured correctly.
This commit is contained in:
@@ -2,7 +2,7 @@ PROJECT = sgemm_tcore
|
|||||||
|
|
||||||
SRCS = main.cpp common.h
|
SRCS = main.cpp common.h
|
||||||
|
|
||||||
VX_SRCS = kernel.cpp
|
VX_SRCS = kernel.activation.cpp
|
||||||
|
|
||||||
OPTS ?= -n16
|
OPTS ?= -n16
|
||||||
|
|
||||||
|
|||||||
@@ -32,7 +32,7 @@
|
|||||||
#error Unsupported smem size
|
#error Unsupported smem size
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
#define WARP_SPECIALIZED 0
|
#define WARP_SPECIALIZED 1
|
||||||
static_assert(
|
static_assert(
|
||||||
!WARP_SPECIALIZED || GEMMINI_DMA,
|
!WARP_SPECIALIZED || GEMMINI_DMA,
|
||||||
"warp specialization is currently only supported with GEMMINI_DMA");
|
"warp specialization is currently only supported with GEMMINI_DMA");
|
||||||
@@ -284,12 +284,15 @@ inline void thread_block_gemm(kernel_arg_t *__UNIFORM__ arg,
|
|||||||
const uint32_t warp_col = warp_id_in_warpgroup % (BN / WN);
|
const uint32_t warp_col = warp_id_in_warpgroup % (BN / WN);
|
||||||
const uint32_t tid_in_warp = tid_in_threadblock % NUM_THREADS;
|
const uint32_t tid_in_warp = tid_in_threadblock % NUM_THREADS;
|
||||||
|
|
||||||
|
// layout: local_a -- local_a_buf -- local_b -- local_b_buf
|
||||||
volatile float *local_a = sharedmem_per_threadblock;
|
volatile float *local_a = sharedmem_per_threadblock;
|
||||||
constexpr size_t local_a_elems = (BM * BK);
|
constexpr size_t local_a_elems = (BM * BK);
|
||||||
volatile float *local_a_buf = local_a + local_a_elems;
|
volatile float *local_a_buf = local_a + local_a_elems;
|
||||||
|
|
||||||
volatile float *local_b = local_a_buf + local_a_elems;
|
volatile float *local_b = local_a_buf + local_a_elems;
|
||||||
constexpr size_t local_b_elems = (BK * BN);
|
// set local B tile size to be the same as local A size (BM * BK), since DMA
|
||||||
|
// is currently only configured for square-shape tiles. FIXME.
|
||||||
|
constexpr size_t local_b_elems = (BK * BM);
|
||||||
volatile float *local_b_buf = local_a_buf + local_b_elems;
|
volatile float *local_b_buf = local_a_buf + local_b_elems;
|
||||||
|
|
||||||
constexpr uint32_t skips =
|
constexpr uint32_t skips =
|
||||||
@@ -618,7 +621,9 @@ inline void thread_block_gemm(kernel_arg_t *__UNIFORM__ arg,
|
|||||||
asm volatile("add t6, t6, %0" :: "r"(dim_n * sizeof(float)));
|
asm volatile("add t6, t6, %0" :: "r"(dim_n * sizeof(float)));
|
||||||
}
|
}
|
||||||
|
|
||||||
threadblock_barrier(0/*global threadblock in cluster*/, threadblock_dim_y);
|
// global barrier that synchronizes both warpgroups at every M-N
|
||||||
|
// iteration
|
||||||
|
threadblock_barrier(0/*all warpgroups*/, threadblock_dim_y);
|
||||||
|
|
||||||
#pragma GCC unroll 2
|
#pragma GCC unroll 2
|
||||||
for (int wm_iter = 0; wm_iter < WMITER; wm_iter++) {
|
for (int wm_iter = 0; wm_iter < WMITER; wm_iter++) {
|
||||||
@@ -669,7 +674,7 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
|
|||||||
// "static" shared memory allocation. This would determine threadblock
|
// "static" shared memory allocation. This would determine threadblock
|
||||||
// occupancy of a single cluster
|
// occupancy of a single cluster
|
||||||
float *sharedmem_per_threadblock =
|
float *sharedmem_per_threadblock =
|
||||||
(float *)DEV_SMEM_START_ADDR + 2/*overkill for non-dma*/ * (2 * BM * BK) *
|
(float *)DEV_SMEM_START_ADDR + 2/*overkill for non-dma*/ * ((BM + BN) * BK) *
|
||||||
threadblock_id_in_cluster;
|
threadblock_id_in_cluster;
|
||||||
|
|
||||||
thread_block_gemm(arg, tid_in_threadblock, threads_per_threadblock,
|
thread_block_gemm(arg, tid_in_threadblock, threads_per_threadblock,
|
||||||
|
|||||||
@@ -19,7 +19,7 @@
|
|||||||
// * Combining BM * BK >= (BM*BN) / (TM*TN) == threadblock yields
|
// * Combining BM * BK >= (BM*BN) / (TM*TN) == threadblock yields
|
||||||
// BM <= BK*TM*TN
|
// BM <= BK*TM*TN
|
||||||
#define BM 64
|
#define BM 64
|
||||||
#define BN 64
|
#define BN 32
|
||||||
#define BK 64
|
#define BK 64
|
||||||
#define WM 16
|
#define WM 16
|
||||||
#define WN 8
|
#define WN 8
|
||||||
|
|||||||
Reference in New Issue
Block a user