flash: Fix barrier stall with DEBUG

Verified for up to P_expected on 2nd iter; O_before_PV is partially
correct
This commit is contained in:
Hansung Kim
2024-09-09 17:02:05 -07:00
parent b652e25945
commit a17edac875
2 changed files with 40 additions and 33 deletions

View File

@@ -89,8 +89,9 @@ inline void thread_block_copy_rowmax(const float *src, float *dest,
dest[offset] = src[offset];
}
threadblock_barrier(threadblock_id_in_cluster,
warps_per_threadblock_per_core);
// threadblock_barrier(threadblock_id_in_cluster,
// warps_per_threadblock_per_core);
threadblock_barrier(1, 7);
asm volatile("threadblock_copy_rowmax_finish_%=:" ::);
}
@@ -127,8 +128,9 @@ inline void thread_block_copy_tile(const float *src, float *dest,
dest[gmem_offset] = src[smem_offset];
}
threadblock_barrier(threadblock_id_in_cluster,
warps_per_threadblock_per_core);
// threadblock_barrier(threadblock_id_in_cluster,
// warps_per_threadblock_per_core);
threadblock_barrier(1, 7);
}
asm volatile("threadblock_copy_tile_finish_%=:" ::);

View File

@@ -8,7 +8,7 @@
#include "gemmini_mmio.h"
#include "flash_impl.hpp"
constexpr bool DEBUG = false;
constexpr bool DEBUG = true;
static_assert(GEMMINI_DMA && !WARP_SPECIALIZED,
"GEMMINI_DMA should be set and WARP_SPECIALIZED unset");
@@ -528,8 +528,10 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
asm volatile("move_k_v_finish_%=:" ::);
// // intra-warpgroup barrier
// // FIXME hardcoded
// NOTE: cannot put barrier here; thread 1-7 in warp 0 will skip the
// branch and call this barrier earlier than when thread 0 finishes.
// Since tmask is not considered, that will be a barrier resolve done too
// early
// threadblock_barrier(0, 1);
} else /* warp_id != 0 */ {
@@ -538,6 +540,24 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
{
const uint32_t tile_k_ = tile_k - 1;
if constexpr (DEBUG) {
// verify S = Q*K
if (warpgroup_id == 0) {
if (tile_k == 0) {
thread_block_copy_tile<B_ROW, B_COL, GEMMINI_DMA>(
smem_S_produce, gmem_tmp_d0, tid_in_warpgroup_simt,
threads_per_warpgroup_simt, warpgroup_id_simt);
} else if (tile_k == 1) {
thread_block_copy_tile<B_ROW, B_COL, GEMMINI_DMA>(
smem_S_produce, gmem_tmp_d1, tid_in_warpgroup_simt,
threads_per_warpgroup_simt, warpgroup_id_simt);
}
threadblock_barrier(barrier_id_simt, barrier_count_simt);
}
}
// Online softmax
//
thread_block_online_softmax</*block_row_major=*/GEMMINI_DMA>(
@@ -550,25 +570,26 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
if constexpr (DEBUG) {
if (warpgroup_id == 0) {
if (tile_k_ == 0) {
thread_block_copy_rowmax(smem_rowmax, gmem_tmp_e0,
tid_in_warpgroup_simt, threads_per_warpgroup,
warpgroup_id_in_cluster);
thread_block_copy_rowmax(smem_rowsum, gmem_tmp_e2,
tid_in_warpgroup_simt, threads_per_warpgroup,
warpgroup_id_in_cluster);
thread_block_copy_rowmax(
smem_rowmax, gmem_tmp_e0, tid_in_warpgroup_simt,
threads_per_warpgroup_simt, warpgroup_id_simt);
thread_block_copy_rowmax(
smem_rowsum, gmem_tmp_e2, tid_in_warpgroup_simt,
threads_per_warpgroup_simt, warpgroup_id_simt);
} else if (tile_k_ == 1) {
thread_block_copy_rowmax(smem_rowmax, gmem_tmp_e1,
tid_in_warpgroup_simt, threads_per_warpgroup,
warpgroup_id_in_cluster);
tid_in_warpgroup_simt, threads_per_warpgroup_simt,
warpgroup_id_simt);
thread_block_copy_rowmax(smem_rowsum, gmem_tmp_e3,
tid_in_warpgroup_simt, threads_per_warpgroup,
warpgroup_id_in_cluster);
tid_in_warpgroup_simt, threads_per_warpgroup_simt,
warpgroup_id_simt);
}
threadblock_barrier(barrier_id_simt, barrier_count_simt);
}
}
// FIXME: put synchronization with GEMM II here
#if 0
// fence GEMM II to make sure dependency on O tile is settled
if (tid_in_warpgroup == 0) {
@@ -662,22 +683,6 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
warps_per_warpgroup_per_core);
#endif
if constexpr (DEBUG) {
if (warpgroup_id == 0) {
if (tile_k == 0) {
thread_block_copy_tile<B_ROW, B_COL, GEMMINI_DMA>(
smem_S_produce, gmem_tmp_d0, tid_in_warpgroup_simt,
threads_per_warpgroup_simt, warpgroup_id_simt);
} else if (tile_k == 1) {
thread_block_copy_tile<B_ROW, B_COL, GEMMINI_DMA>(
smem_S_produce, gmem_tmp_d1, tid_in_warpgroup_simt,
threads_per_warpgroup_simt, warpgroup_id_simt);
}
threadblock_barrier(barrier_id_simt, barrier_count_simt);
}
}
// intra-warpgroup barrier
threadblock_barrier(barrier_id_simt, barrier_count_simt);
}