flash: Fix barrier stall with DEBUG

Verified for up to P_expected on 2nd iter; O_before_PV is partially
correct
This commit is contained in:
Hansung Kim
2024-09-09 17:02:05 -07:00
parent b652e25945
commit a17edac875
2 changed files with 40 additions and 33 deletions

View File

@@ -89,8 +89,9 @@ inline void thread_block_copy_rowmax(const float *src, float *dest,
dest[offset] = src[offset]; dest[offset] = src[offset];
} }
threadblock_barrier(threadblock_id_in_cluster, // threadblock_barrier(threadblock_id_in_cluster,
warps_per_threadblock_per_core); // warps_per_threadblock_per_core);
threadblock_barrier(1, 7);
asm volatile("threadblock_copy_rowmax_finish_%=:" ::); asm volatile("threadblock_copy_rowmax_finish_%=:" ::);
} }
@@ -127,8 +128,9 @@ inline void thread_block_copy_tile(const float *src, float *dest,
dest[gmem_offset] = src[smem_offset]; dest[gmem_offset] = src[smem_offset];
} }
threadblock_barrier(threadblock_id_in_cluster, // threadblock_barrier(threadblock_id_in_cluster,
warps_per_threadblock_per_core); // warps_per_threadblock_per_core);
threadblock_barrier(1, 7);
} }
asm volatile("threadblock_copy_tile_finish_%=:" ::); asm volatile("threadblock_copy_tile_finish_%=:" ::);

View File

@@ -8,7 +8,7 @@
#include "gemmini_mmio.h" #include "gemmini_mmio.h"
#include "flash_impl.hpp" #include "flash_impl.hpp"
constexpr bool DEBUG = false; constexpr bool DEBUG = true;
static_assert(GEMMINI_DMA && !WARP_SPECIALIZED, static_assert(GEMMINI_DMA && !WARP_SPECIALIZED,
"GEMMINI_DMA should be set and WARP_SPECIALIZED unset"); "GEMMINI_DMA should be set and WARP_SPECIALIZED unset");
@@ -528,8 +528,10 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
asm volatile("move_k_v_finish_%=:" ::); asm volatile("move_k_v_finish_%=:" ::);
// // intra-warpgroup barrier // NOTE: cannot put barrier here; thread 1-7 in warp 0 will skip the
// // FIXME hardcoded // branch and call this barrier earlier than when thread 0 finishes.
// Since tmask is not considered, that will be a barrier resolve done too
// early
// threadblock_barrier(0, 1); // threadblock_barrier(0, 1);
} else /* warp_id != 0 */ { } else /* warp_id != 0 */ {
@@ -538,6 +540,24 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
{ {
const uint32_t tile_k_ = tile_k - 1; const uint32_t tile_k_ = tile_k - 1;
if constexpr (DEBUG) {
// verify S = Q*K
if (warpgroup_id == 0) {
if (tile_k == 0) {
thread_block_copy_tile<B_ROW, B_COL, GEMMINI_DMA>(
smem_S_produce, gmem_tmp_d0, tid_in_warpgroup_simt,
threads_per_warpgroup_simt, warpgroup_id_simt);
} else if (tile_k == 1) {
thread_block_copy_tile<B_ROW, B_COL, GEMMINI_DMA>(
smem_S_produce, gmem_tmp_d1, tid_in_warpgroup_simt,
threads_per_warpgroup_simt, warpgroup_id_simt);
}
threadblock_barrier(barrier_id_simt, barrier_count_simt);
}
}
// Online softmax // Online softmax
// //
thread_block_online_softmax</*block_row_major=*/GEMMINI_DMA>( thread_block_online_softmax</*block_row_major=*/GEMMINI_DMA>(
@@ -550,25 +570,26 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
if constexpr (DEBUG) { if constexpr (DEBUG) {
if (warpgroup_id == 0) { if (warpgroup_id == 0) {
if (tile_k_ == 0) { if (tile_k_ == 0) {
thread_block_copy_rowmax(smem_rowmax, gmem_tmp_e0, thread_block_copy_rowmax(
tid_in_warpgroup_simt, threads_per_warpgroup, smem_rowmax, gmem_tmp_e0, tid_in_warpgroup_simt,
warpgroup_id_in_cluster); threads_per_warpgroup_simt, warpgroup_id_simt);
thread_block_copy_rowmax(smem_rowsum, gmem_tmp_e2, thread_block_copy_rowmax(
tid_in_warpgroup_simt, threads_per_warpgroup, smem_rowsum, gmem_tmp_e2, tid_in_warpgroup_simt,
warpgroup_id_in_cluster); threads_per_warpgroup_simt, warpgroup_id_simt);
} else if (tile_k_ == 1) { } else if (tile_k_ == 1) {
thread_block_copy_rowmax(smem_rowmax, gmem_tmp_e1, thread_block_copy_rowmax(smem_rowmax, gmem_tmp_e1,
tid_in_warpgroup_simt, threads_per_warpgroup, tid_in_warpgroup_simt, threads_per_warpgroup_simt,
warpgroup_id_in_cluster); warpgroup_id_simt);
thread_block_copy_rowmax(smem_rowsum, gmem_tmp_e3, thread_block_copy_rowmax(smem_rowsum, gmem_tmp_e3,
tid_in_warpgroup_simt, threads_per_warpgroup, tid_in_warpgroup_simt, threads_per_warpgroup_simt,
warpgroup_id_in_cluster); warpgroup_id_simt);
} }
threadblock_barrier(barrier_id_simt, barrier_count_simt); threadblock_barrier(barrier_id_simt, barrier_count_simt);
} }
} }
// FIXME: put synchronization with GEMM II here
#if 0 #if 0
// fence GEMM II to make sure dependency on O tile is settled // fence GEMM II to make sure dependency on O tile is settled
if (tid_in_warpgroup == 0) { if (tid_in_warpgroup == 0) {
@@ -662,22 +683,6 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
warps_per_warpgroup_per_core); warps_per_warpgroup_per_core);
#endif #endif
if constexpr (DEBUG) {
if (warpgroup_id == 0) {
if (tile_k == 0) {
thread_block_copy_tile<B_ROW, B_COL, GEMMINI_DMA>(
smem_S_produce, gmem_tmp_d0, tid_in_warpgroup_simt,
threads_per_warpgroup_simt, warpgroup_id_simt);
} else if (tile_k == 1) {
thread_block_copy_tile<B_ROW, B_COL, GEMMINI_DMA>(
smem_S_produce, gmem_tmp_d1, tid_in_warpgroup_simt,
threads_per_warpgroup_simt, warpgroup_id_simt);
}
threadblock_barrier(barrier_id_simt, barrier_count_simt);
}
}
// intra-warpgroup barrier // intra-warpgroup barrier
threadblock_barrier(barrier_id_simt, barrier_count_simt); threadblock_barrier(barrier_id_simt, barrier_count_simt);
} }