sgemm_impl: Fix wrong barrier count; add barrier for write_to_smem
This commit is contained in:
@@ -568,7 +568,9 @@ template <typename T,
|
||||
__attribute__((always_inline)) inline void
|
||||
thread_block_gemm_single_tile(const T *local_a, const T *local_b, T *local_c,
|
||||
const uint32_t tid_in_threadblock,
|
||||
const uint32_t threads_per_threadblock) {
|
||||
const uint32_t threads_per_threadblock,
|
||||
const uint32_t threadblocks_per_cluster,
|
||||
const uint32_t threadblock_id_in_cluster) {
|
||||
// no double-buffering
|
||||
// FIXME: duplicated from thread_block_gemm
|
||||
const uint32_t threads_per_warpgroup = threads_per_threadblock;
|
||||
@@ -576,6 +578,8 @@ thread_block_gemm_single_tile(const T *local_a, const T *local_b, T *local_c,
|
||||
const uint32_t warp_row = warp_id_in_warpgroup / (BN / WN);
|
||||
const uint32_t warp_col = warp_id_in_warpgroup % (BN / WN);
|
||||
const uint32_t tid_in_warp = tid_in_threadblock % NUM_THREADS;
|
||||
const uint32_t warps_per_threadblock_per_core =
|
||||
NUM_WARPS / threadblocks_per_cluster;
|
||||
|
||||
#pragma GCC unroll 1
|
||||
for (int i = 0; i < BK_LOOP; i++) {
|
||||
@@ -608,6 +612,11 @@ thread_block_gemm_single_tile(const T *local_a, const T *local_b, T *local_c,
|
||||
}
|
||||
|
||||
if constexpr (write_to_smem) {
|
||||
// need to protect smem reads in the earlier step from the writes below,
|
||||
// especially when the destination smem address overlaps with the input
|
||||
threadblock_barrier(threadblock_id_in_cluster,
|
||||
warps_per_threadblock_per_core);
|
||||
|
||||
#pragma GCC unroll
|
||||
for (int wm_iter = 0; wm_iter < WMITER; wm_iter++) {
|
||||
#pragma GCC unroll
|
||||
@@ -655,7 +664,7 @@ inline void thread_block_gemm(const T *A, const T *B, float *C,
|
||||
const uint32_t warp_col = warp_id_in_warpgroup % (BN / WN);
|
||||
const uint32_t tid_in_warp = tid_in_threadblock % NUM_THREADS;
|
||||
const uint32_t warps_per_threadblock_per_core =
|
||||
NUM_WARPS / threads_per_threadblock;
|
||||
NUM_WARPS / threadblocks_per_cluster;
|
||||
|
||||
T *local_a = reinterpret_cast<T *>(sharedmem_per_threadblock + smem_a_offset);
|
||||
T *local_a_buf =
|
||||
@@ -858,7 +867,8 @@ inline void thread_block_gemm(const T *A, const T *B, float *C,
|
||||
/*write_to_smem=*/false>(
|
||||
local_a_consume, local_b_consume,
|
||||
static_cast<T *>(nullptr) /*ignore*/, tid_in_threadblock,
|
||||
threads_per_threadblock);
|
||||
threads_per_threadblock, threadblocks_per_cluster,
|
||||
threadblock_id_in_cluster);
|
||||
|
||||
if constexpr (GEMMINI_DMA) {
|
||||
// Call gemmini fence at the end of the loop to overlap dma & wmma.
|
||||
|
||||
Reference in New Issue
Block a user