flash: Fix barrier stall with DEBUG
Verified up to P_expected on the 2nd iteration; O_before_PV is only partially correct
This commit is contained in:
@@ -89,8 +89,9 @@ inline void thread_block_copy_rowmax(const float *src, float *dest,
|
|||||||
dest[offset] = src[offset];
|
dest[offset] = src[offset];
|
||||||
}
|
}
|
||||||
|
|
||||||
threadblock_barrier(threadblock_id_in_cluster,
|
// threadblock_barrier(threadblock_id_in_cluster,
|
||||||
warps_per_threadblock_per_core);
|
// warps_per_threadblock_per_core);
|
||||||
|
threadblock_barrier(1, 7);
|
||||||
|
|
||||||
asm volatile("threadblock_copy_rowmax_finish_%=:" ::);
|
asm volatile("threadblock_copy_rowmax_finish_%=:" ::);
|
||||||
}
|
}
|
||||||
@@ -127,8 +128,9 @@ inline void thread_block_copy_tile(const float *src, float *dest,
|
|||||||
dest[gmem_offset] = src[smem_offset];
|
dest[gmem_offset] = src[smem_offset];
|
||||||
}
|
}
|
||||||
|
|
||||||
threadblock_barrier(threadblock_id_in_cluster,
|
// threadblock_barrier(threadblock_id_in_cluster,
|
||||||
warps_per_threadblock_per_core);
|
// warps_per_threadblock_per_core);
|
||||||
|
threadblock_barrier(1, 7);
|
||||||
}
|
}
|
||||||
|
|
||||||
asm volatile("threadblock_copy_tile_finish_%=:" ::);
|
asm volatile("threadblock_copy_tile_finish_%=:" ::);
|
||||||
|
|||||||
@@ -8,7 +8,7 @@
|
|||||||
#include "gemmini_mmio.h"
|
#include "gemmini_mmio.h"
|
||||||
#include "flash_impl.hpp"
|
#include "flash_impl.hpp"
|
||||||
|
|
||||||
constexpr bool DEBUG = false;
|
constexpr bool DEBUG = true;
|
||||||
|
|
||||||
static_assert(GEMMINI_DMA && !WARP_SPECIALIZED,
|
static_assert(GEMMINI_DMA && !WARP_SPECIALIZED,
|
||||||
"GEMMINI_DMA should be set and WARP_SPECIALIZED unset");
|
"GEMMINI_DMA should be set and WARP_SPECIALIZED unset");
|
||||||
@@ -528,8 +528,10 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
|
|||||||
|
|
||||||
asm volatile("move_k_v_finish_%=:" ::);
|
asm volatile("move_k_v_finish_%=:" ::);
|
||||||
|
|
||||||
// // intra-warpgroup barrier
|
// NOTE: cannot put a barrier here; threads 1-7 in warp 0 will skip the
|
||||||
// // FIXME hardcoded
|
// branch and reach this barrier before thread 0 finishes.
|
||||||
|
// Since tmask is not considered, the barrier would then be resolved too
|
||||||
|
// early.
|
||||||
// threadblock_barrier(0, 1);
|
// threadblock_barrier(0, 1);
|
||||||
|
|
||||||
} else /* warp_id != 0 */ {
|
} else /* warp_id != 0 */ {
|
||||||
@@ -538,6 +540,24 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
|
|||||||
{
|
{
|
||||||
const uint32_t tile_k_ = tile_k - 1;
|
const uint32_t tile_k_ = tile_k - 1;
|
||||||
|
|
||||||
|
if constexpr (DEBUG) {
|
||||||
|
// verify S = Q*K
|
||||||
|
|
||||||
|
if (warpgroup_id == 0) {
|
||||||
|
if (tile_k == 0) {
|
||||||
|
thread_block_copy_tile<B_ROW, B_COL, GEMMINI_DMA>(
|
||||||
|
smem_S_produce, gmem_tmp_d0, tid_in_warpgroup_simt,
|
||||||
|
threads_per_warpgroup_simt, warpgroup_id_simt);
|
||||||
|
} else if (tile_k == 1) {
|
||||||
|
thread_block_copy_tile<B_ROW, B_COL, GEMMINI_DMA>(
|
||||||
|
smem_S_produce, gmem_tmp_d1, tid_in_warpgroup_simt,
|
||||||
|
threads_per_warpgroup_simt, warpgroup_id_simt);
|
||||||
|
}
|
||||||
|
|
||||||
|
threadblock_barrier(barrier_id_simt, barrier_count_simt);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// Online softmax
|
// Online softmax
|
||||||
//
|
//
|
||||||
thread_block_online_softmax</*block_row_major=*/GEMMINI_DMA>(
|
thread_block_online_softmax</*block_row_major=*/GEMMINI_DMA>(
|
||||||
@@ -550,25 +570,26 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
|
|||||||
if constexpr (DEBUG) {
|
if constexpr (DEBUG) {
|
||||||
if (warpgroup_id == 0) {
|
if (warpgroup_id == 0) {
|
||||||
if (tile_k_ == 0) {
|
if (tile_k_ == 0) {
|
||||||
thread_block_copy_rowmax(smem_rowmax, gmem_tmp_e0,
|
thread_block_copy_rowmax(
|
||||||
tid_in_warpgroup_simt, threads_per_warpgroup,
|
smem_rowmax, gmem_tmp_e0, tid_in_warpgroup_simt,
|
||||||
warpgroup_id_in_cluster);
|
threads_per_warpgroup_simt, warpgroup_id_simt);
|
||||||
thread_block_copy_rowmax(smem_rowsum, gmem_tmp_e2,
|
thread_block_copy_rowmax(
|
||||||
tid_in_warpgroup_simt, threads_per_warpgroup,
|
smem_rowsum, gmem_tmp_e2, tid_in_warpgroup_simt,
|
||||||
warpgroup_id_in_cluster);
|
threads_per_warpgroup_simt, warpgroup_id_simt);
|
||||||
} else if (tile_k_ == 1) {
|
} else if (tile_k_ == 1) {
|
||||||
thread_block_copy_rowmax(smem_rowmax, gmem_tmp_e1,
|
thread_block_copy_rowmax(smem_rowmax, gmem_tmp_e1,
|
||||||
tid_in_warpgroup_simt, threads_per_warpgroup,
|
tid_in_warpgroup_simt, threads_per_warpgroup_simt,
|
||||||
warpgroup_id_in_cluster);
|
warpgroup_id_simt);
|
||||||
thread_block_copy_rowmax(smem_rowsum, gmem_tmp_e3,
|
thread_block_copy_rowmax(smem_rowsum, gmem_tmp_e3,
|
||||||
tid_in_warpgroup_simt, threads_per_warpgroup,
|
tid_in_warpgroup_simt, threads_per_warpgroup_simt,
|
||||||
warpgroup_id_in_cluster);
|
warpgroup_id_simt);
|
||||||
}
|
}
|
||||||
|
|
||||||
threadblock_barrier(barrier_id_simt, barrier_count_simt);
|
threadblock_barrier(barrier_id_simt, barrier_count_simt);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// FIXME: put synchronization with GEMM II here
|
||||||
#if 0
|
#if 0
|
||||||
// fence GEMM II to make sure dependency on O tile is settled
|
// fence GEMM II to make sure dependency on O tile is settled
|
||||||
if (tid_in_warpgroup == 0) {
|
if (tid_in_warpgroup == 0) {
|
||||||
@@ -662,22 +683,6 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
|
|||||||
warps_per_warpgroup_per_core);
|
warps_per_warpgroup_per_core);
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
if constexpr (DEBUG) {
|
|
||||||
if (warpgroup_id == 0) {
|
|
||||||
if (tile_k == 0) {
|
|
||||||
thread_block_copy_tile<B_ROW, B_COL, GEMMINI_DMA>(
|
|
||||||
smem_S_produce, gmem_tmp_d0, tid_in_warpgroup_simt,
|
|
||||||
threads_per_warpgroup_simt, warpgroup_id_simt);
|
|
||||||
} else if (tile_k == 1) {
|
|
||||||
thread_block_copy_tile<B_ROW, B_COL, GEMMINI_DMA>(
|
|
||||||
smem_S_produce, gmem_tmp_d1, tid_in_warpgroup_simt,
|
|
||||||
threads_per_warpgroup_simt, warpgroup_id_simt);
|
|
||||||
}
|
|
||||||
|
|
||||||
threadblock_barrier(barrier_id_simt, barrier_count_simt);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// intra-warpgroup barrier
|
// intra-warpgroup barrier
|
||||||
threadblock_barrier(barrier_id_simt, barrier_count_simt);
|
threadblock_barrier(barrier_id_simt, barrier_count_simt);
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user