sgemm_impl: Fix wrong barrier count; add barrier for write_to_smem

This commit is contained in:
Hansung Kim
2024-08-19 15:33:23 -07:00
parent e93e54cdec
commit 4aba018733

View File

@@ -568,7 +568,9 @@ template <typename T,
__attribute__((always_inline)) inline void
thread_block_gemm_single_tile(const T *local_a, const T *local_b, T *local_c,
const uint32_t tid_in_threadblock,
const uint32_t threads_per_threadblock) {
const uint32_t threads_per_threadblock,
const uint32_t threadblocks_per_cluster,
const uint32_t threadblock_id_in_cluster) {
// no double-buffering
// FIXME: duplicated from thread_block_gemm
const uint32_t threads_per_warpgroup = threads_per_threadblock;
@@ -576,6 +578,8 @@ thread_block_gemm_single_tile(const T *local_a, const T *local_b, T *local_c,
const uint32_t warp_row = warp_id_in_warpgroup / (BN / WN);
const uint32_t warp_col = warp_id_in_warpgroup % (BN / WN);
const uint32_t tid_in_warp = tid_in_threadblock % NUM_THREADS;
const uint32_t warps_per_threadblock_per_core =
NUM_WARPS / threadblocks_per_cluster;
#pragma GCC unroll 1
for (int i = 0; i < BK_LOOP; i++) {
@@ -608,6 +612,11 @@ thread_block_gemm_single_tile(const T *local_a, const T *local_b, T *local_c,
}
if constexpr (write_to_smem) {
// need to protect smem reads in the earlier step from writes in below,
// especially when the destination smem address overlaps with the input
threadblock_barrier(threadblock_id_in_cluster,
warps_per_threadblock_per_core);
#pragma GCC unroll
for (int wm_iter = 0; wm_iter < WMITER; wm_iter++) {
#pragma GCC unroll
@@ -655,7 +664,7 @@ inline void thread_block_gemm(const T *A, const T *B, float *C,
const uint32_t warp_col = warp_id_in_warpgroup % (BN / WN);
const uint32_t tid_in_warp = tid_in_threadblock % NUM_THREADS;
const uint32_t warps_per_threadblock_per_core =
NUM_WARPS / threads_per_threadblock;
NUM_WARPS / threadblocks_per_cluster;
T *local_a = reinterpret_cast<T *>(sharedmem_per_threadblock + smem_a_offset);
T *local_a_buf =
@@ -858,7 +867,8 @@ inline void thread_block_gemm(const T *A, const T *B, float *C,
/*write_to_smem=*/false>(
local_a_consume, local_b_consume,
static_cast<T *>(nullptr) /*ignore*/, tid_in_threadblock,
threads_per_threadblock);
threads_per_threadblock, threadblocks_per_cluster,
threadblock_id_in_cluster);
if constexpr (GEMMINI_DMA) {
// Call gemmini fence at the end of the loop to overlap dma & wmma.