flash: Add early return for warp-indivisible row iter

This commit is contained in:
Hansung Kim
2024-09-11 00:56:09 -07:00
parent 068d48534e
commit 18cf0e73cd

View File

@@ -8,6 +8,8 @@
#define B_COL 64
#define HEADDIM 64
#define ROW_REMAINDER_LOGIC
constexpr uint32_t ROWMAX_SETS = 3;
constexpr bool WARP_SPECIALIZED = false;
@@ -56,6 +58,14 @@ inline void thread_block_init_sharedmem(const uint32_t tid_in_threadblock,
for (int row_offset = 0; row_offset < B_COL;
row_offset += warps_in_threadblock) {
const uint32_t row = row_offset + warp_id;
#ifdef ROW_REMAINDER_LOGIC
if (row >= B_ROW) {
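// WARNING: the number of barrier calls in this branch has to exactly match
// the number outside of the branch to prevent stalls! FIXME: prove this more
// rigorously.
// NOTE(review): the enclosing loop is bounded by B_COL, but this guard
// compares against B_ROW -- confirm that this is intended.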
continue;
}
#endif
uint32_t thread_offset = HEADDIM * row + tid_in_warp;
constexpr uint32_t per_row_iter = HEADDIM / NUM_THREADS;
const float one = 0.0f;
@@ -114,6 +124,14 @@ inline void thread_block_copy_tile(const float *src, float *dest,
for (int row_offset = 0; row_offset < dim_row;
row_offset += warps_in_threadblock) {
const uint32_t row = row_offset + warp_id;
#ifdef ROW_REMAINDER_LOGIC
if (row >= B_ROW) {
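// WARNING: the number of barrier calls in this branch has to exactly match
// the number outside of the branch to prevent stalls! FIXME: prove this more
// rigorously.
// NOTE(review): the enclosing loop is bounded by dim_row, but this guard
// compares against B_ROW -- confirm that this is intended.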
threadblock_barrier(1, 7);
continue;
}
#endif
constexpr uint32_t per_row_iter = dim_col / NUM_THREADS;
#pragma GCC unroll
@@ -176,19 +194,21 @@ __attribute__((always_inline)) inline void thread_block_online_softmax(
for (int row_offset = 0; row_offset < B_ROW;
row_offset += warps_in_threadblock) {
const uint32_t row = row_offset + warp_id;
#ifdef ROW_REMAINDER_LOGIC
// if the number of warps doesn't exactly divide the number of rows,
// early-exit to prevent out-of-bounds access
// if (row >= B_ROW) {
// // WARNING: the number of barrier calls have to exactly match that in the
// // outside of the branch to prevent stalls!! FIXME better proof this.
// threadblock_barrier(1, 7);
// threadblock_barrier(1, 7);
// threadblock_barrier(1, 7);
// threadblock_barrier(1, 7);
// threadblock_barrier(1, 7);
// threadblock_barrier(1, 7);
// continue;
// }
if (row >= B_ROW) {
// WARNING: the number of barrier calls in this branch has to exactly match
// the number outside of the branch to prevent stalls! FIXME: prove this more
// rigorously.
threadblock_barrier(1, 7);
threadblock_barrier(1, 7);
threadblock_barrier(1, 7);
threadblock_barrier(1, 7);
threadblock_barrier(1, 7);
threadblock_barrier(1, 7);
continue;
}
#endif
const uint32_t first_thread_offset = B_COL * row;
// rowmax
@@ -456,6 +476,14 @@ __attribute__((always_inline)) inline void thread_block_O_rescale(
for (int row_offset = 0; row_offset < B_ROW;
row_offset += warps_in_threadblock) {
const uint32_t row = row_offset + warp_id;
#ifdef ROW_REMAINDER_LOGIC
if (row >= B_ROW) {
// WARNING: the number of barrier calls in this branch has to exactly match
// the number outside of the branch to prevent stalls! FIXME: prove this more
// rigorously.
continue;
}
#endif
constexpr uint32_t per_row_iter = HEADDIM / NUM_THREADS;
// Oi rescale
@@ -474,6 +502,9 @@ __attribute__((always_inline)) inline void thread_block_O_rescale(
}
}
// reconverge after warp divergence
threadblock_barrier(1, 7);
asm volatile("thread_block_O_rescale_finish_%=:" ::);
}