flash: Fix barrier stall with DEBUG
Verified up to P_expected on the 2nd iteration; O_before_PV is partially correct
This commit is contained in:
@@ -89,8 +89,9 @@ inline void thread_block_copy_rowmax(const float *src, float *dest,
|
||||
dest[offset] = src[offset];
|
||||
}
|
||||
|
||||
threadblock_barrier(threadblock_id_in_cluster,
|
||||
warps_per_threadblock_per_core);
|
||||
// threadblock_barrier(threadblock_id_in_cluster,
|
||||
// warps_per_threadblock_per_core);
|
||||
threadblock_barrier(1, 7);
|
||||
|
||||
asm volatile("threadblock_copy_rowmax_finish_%=:" ::);
|
||||
}
|
||||
@@ -127,8 +128,9 @@ inline void thread_block_copy_tile(const float *src, float *dest,
|
||||
dest[gmem_offset] = src[smem_offset];
|
||||
}
|
||||
|
||||
threadblock_barrier(threadblock_id_in_cluster,
|
||||
warps_per_threadblock_per_core);
|
||||
// threadblock_barrier(threadblock_id_in_cluster,
|
||||
// warps_per_threadblock_per_core);
|
||||
threadblock_barrier(1, 7);
|
||||
}
|
||||
|
||||
asm volatile("threadblock_copy_tile_finish_%=:" ::);
|
||||
|
||||
@@ -8,7 +8,7 @@
|
||||
#include "gemmini_mmio.h"
|
||||
#include "flash_impl.hpp"
|
||||
|
||||
constexpr bool DEBUG = false;
|
||||
constexpr bool DEBUG = true;
|
||||
|
||||
static_assert(GEMMINI_DMA && !WARP_SPECIALIZED,
|
||||
"GEMMINI_DMA should be set and WARP_SPECIALIZED unset");
|
||||
@@ -528,8 +528,10 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
|
||||
|
||||
asm volatile("move_k_v_finish_%=:" ::);
|
||||
|
||||
// // intra-warpgroup barrier
|
||||
// // FIXME hardcoded
|
||||
// NOTE: cannot put barrier here; threads 1-7 in warp 0 will skip the
|
||||
// branch and reach this barrier before thread 0 finishes.
|
||||
// Since tmask is not considered, the barrier would then be resolved too
|
||||
// early
|
||||
// threadblock_barrier(0, 1);
|
||||
|
||||
} else /* warp_id != 0 */ {
|
||||
@@ -538,6 +540,24 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
|
||||
{
|
||||
const uint32_t tile_k_ = tile_k - 1;
|
||||
|
||||
if constexpr (DEBUG) {
|
||||
// verify S = Q*K
|
||||
|
||||
if (warpgroup_id == 0) {
|
||||
if (tile_k == 0) {
|
||||
thread_block_copy_tile<B_ROW, B_COL, GEMMINI_DMA>(
|
||||
smem_S_produce, gmem_tmp_d0, tid_in_warpgroup_simt,
|
||||
threads_per_warpgroup_simt, warpgroup_id_simt);
|
||||
} else if (tile_k == 1) {
|
||||
thread_block_copy_tile<B_ROW, B_COL, GEMMINI_DMA>(
|
||||
smem_S_produce, gmem_tmp_d1, tid_in_warpgroup_simt,
|
||||
threads_per_warpgroup_simt, warpgroup_id_simt);
|
||||
}
|
||||
|
||||
threadblock_barrier(barrier_id_simt, barrier_count_simt);
|
||||
}
|
||||
}
|
||||
|
||||
// Online softmax
|
||||
//
|
||||
thread_block_online_softmax</*block_row_major=*/GEMMINI_DMA>(
|
||||
@@ -550,25 +570,26 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
|
||||
if constexpr (DEBUG) {
|
||||
if (warpgroup_id == 0) {
|
||||
if (tile_k_ == 0) {
|
||||
thread_block_copy_rowmax(smem_rowmax, gmem_tmp_e0,
|
||||
tid_in_warpgroup_simt, threads_per_warpgroup,
|
||||
warpgroup_id_in_cluster);
|
||||
thread_block_copy_rowmax(smem_rowsum, gmem_tmp_e2,
|
||||
tid_in_warpgroup_simt, threads_per_warpgroup,
|
||||
warpgroup_id_in_cluster);
|
||||
thread_block_copy_rowmax(
|
||||
smem_rowmax, gmem_tmp_e0, tid_in_warpgroup_simt,
|
||||
threads_per_warpgroup_simt, warpgroup_id_simt);
|
||||
thread_block_copy_rowmax(
|
||||
smem_rowsum, gmem_tmp_e2, tid_in_warpgroup_simt,
|
||||
threads_per_warpgroup_simt, warpgroup_id_simt);
|
||||
} else if (tile_k_ == 1) {
|
||||
thread_block_copy_rowmax(smem_rowmax, gmem_tmp_e1,
|
||||
tid_in_warpgroup_simt, threads_per_warpgroup,
|
||||
warpgroup_id_in_cluster);
|
||||
tid_in_warpgroup_simt, threads_per_warpgroup_simt,
|
||||
warpgroup_id_simt);
|
||||
thread_block_copy_rowmax(smem_rowsum, gmem_tmp_e3,
|
||||
tid_in_warpgroup_simt, threads_per_warpgroup,
|
||||
warpgroup_id_in_cluster);
|
||||
tid_in_warpgroup_simt, threads_per_warpgroup_simt,
|
||||
warpgroup_id_simt);
|
||||
}
|
||||
|
||||
threadblock_barrier(barrier_id_simt, barrier_count_simt);
|
||||
}
|
||||
}
|
||||
|
||||
// FIXME: put synchronization with GEMM II here
|
||||
#if 0
|
||||
// fence GEMM II to make sure dependency on O tile is settled
|
||||
if (tid_in_warpgroup == 0) {
|
||||
@@ -662,22 +683,6 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
|
||||
warps_per_warpgroup_per_core);
|
||||
#endif
|
||||
|
||||
if constexpr (DEBUG) {
|
||||
if (warpgroup_id == 0) {
|
||||
if (tile_k == 0) {
|
||||
thread_block_copy_tile<B_ROW, B_COL, GEMMINI_DMA>(
|
||||
smem_S_produce, gmem_tmp_d0, tid_in_warpgroup_simt,
|
||||
threads_per_warpgroup_simt, warpgroup_id_simt);
|
||||
} else if (tile_k == 1) {
|
||||
thread_block_copy_tile<B_ROW, B_COL, GEMMINI_DMA>(
|
||||
smem_S_produce, gmem_tmp_d1, tid_in_warpgroup_simt,
|
||||
threads_per_warpgroup_simt, warpgroup_id_simt);
|
||||
}
|
||||
|
||||
threadblock_barrier(barrier_id_simt, barrier_count_simt);
|
||||
}
|
||||
}
|
||||
|
||||
// intra-warpgroup barrier
|
||||
threadblock_barrier(barrier_id_simt, barrier_count_simt);
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user