flash: Fix online softmax for warp-specialized
Note: now that threads_per_threadblock is passed as a compile-time constant, the compiler tends to fully unroll loops, which can cause a lot of stack spills. TODO: fix the GEMM part.
This commit is contained in:
@@ -13,14 +13,22 @@
|
||||
// FIXME
|
||||
#define HEADDIM B_COL
|
||||
|
||||
constexpr uint32_t ROWMAX_SETS = 3;
|
||||
constexpr bool DEBUG = true;
|
||||
constexpr bool DOUBLE_BUF = false;
|
||||
constexpr bool DOUBLE_BUF = true;
|
||||
|
||||
// temporary safety stop for wrong configs
|
||||
static_assert(NUM_CORES == 4);
|
||||
static_assert(NUM_THREADS == 8);
|
||||
static_assert(NUM_WARPS == 8);
|
||||
|
||||
inline void thread_block_init_sharedmem(const uint32_t tid_in_threadblock,
|
||||
const uint32_t threads_per_threadblock,
|
||||
float *smem_O,
|
||||
float *smem_rowmax,
|
||||
float *smem_rowsum) {
|
||||
asm volatile("threadblock_init_sharedmem_start_%=:" ::);
|
||||
|
||||
const uint32_t tid_in_warp = tid_in_threadblock % NUM_THREADS;
|
||||
const uint32_t warp_id = tid_in_threadblock / NUM_THREADS;
|
||||
const uint32_t warps_in_threadblock = threads_per_threadblock / NUM_THREADS;
|
||||
@@ -36,26 +44,30 @@ inline void thread_block_init_sharedmem(const uint32_t tid_in_threadblock,
|
||||
constexpr uint32_t needed_warps = B_ROW / NUM_THREADS;
|
||||
if (warp_id < needed_warps /* more warps in HW than needed? */) {
|
||||
uint32_t offset = NUM_THREADS * warp_id + tid_in_warp;
|
||||
// mi, mi~, minew
|
||||
smem_rowmax[offset] = FLT_MIN;
|
||||
smem_rowmax[offset + B_ROW] = FLT_MIN;
|
||||
smem_rowmax[offset + 2 * B_ROW] = FLT_MIN;
|
||||
#pragma GCC unroll
|
||||
for (int i = 0; i < ROWMAX_SETS; i++) {
|
||||
smem_rowmax[offset + i * ROWMAX_SETS] = FLT_MIN;
|
||||
}
|
||||
smem_rowsum[offset] = 0.0f;
|
||||
}
|
||||
|
||||
// each warp clears out a row of smem_O
|
||||
// FIXME: dedup this pattern
|
||||
for (int warp_offset = 0; warp_offset < B_COL;
|
||||
warp_offset += warps_in_threadblock) {
|
||||
const uint32_t row = warp_offset + warp_id;
|
||||
#pragma GCC unroll 1
|
||||
for (int row_offset = 0; row_offset < B_COL;
|
||||
row_offset += warps_in_threadblock) {
|
||||
const uint32_t row = row_offset + warp_id;
|
||||
uint32_t thread_offset = HEADDIM * row + tid_in_warp;
|
||||
constexpr uint32_t per_row_iter = HEADDIM / NUM_THREADS;
|
||||
const float one = 0.0f;
|
||||
#pragma GCC unroll
|
||||
for (int i = 0; i < per_row_iter; i++) {
|
||||
smem_O[thread_offset] = 0.0f;
|
||||
thread_offset += NUM_THREADS;
|
||||
}
|
||||
}
|
||||
|
||||
asm volatile("threadblock_init_sharedmem_finish_%=:" ::);
|
||||
}
|
||||
|
||||
inline void thread_block_copy_rowmax(const float *src, float *dest,
|
||||
@@ -97,9 +109,10 @@ inline void thread_block_copy_tile(const float *src, float *dest,
|
||||
warps_in_threadblock / CORES_PER_CLUSTER;
|
||||
|
||||
// FIXME: dedup this pattern
|
||||
for (int warp_offset = 0; warp_offset < B_ROW;
|
||||
warp_offset += warps_in_threadblock) {
|
||||
const uint32_t row = warp_offset + warp_id;
|
||||
#pragma GCC unroll 1
|
||||
for (int row_offset = 0; row_offset < B_ROW;
|
||||
row_offset += warps_in_threadblock) {
|
||||
const uint32_t row = row_offset + warp_id;
|
||||
const uint32_t first_thread_offset = B_COL * row;
|
||||
|
||||
constexpr uint32_t per_row_iter = B_COL / NUM_THREADS;
|
||||
@@ -163,6 +176,7 @@ __attribute__((always_inline)) inline void thread_block_online_softmax(
|
||||
|
||||
float *smem_rowmax_this = smem_rowmax + B_ROW;
|
||||
|
||||
#pragma GCC unroll 1
|
||||
for (int row_offset = 0; row_offset < B_ROW;
|
||||
row_offset += warps_in_threadblock) {
|
||||
const uint32_t row = row_offset + warp_id;
|
||||
@@ -171,27 +185,46 @@ __attribute__((always_inline)) inline void thread_block_online_softmax(
|
||||
// rowmax
|
||||
//
|
||||
// two-level tree reduction: reduce each row into NUM_THREADS intermediate
|
||||
// maxes, then reduce it to one global max
|
||||
// maxes, then reduce it down to one row max
|
||||
// one warp handles one row in tile
|
||||
|
||||
constexpr uint32_t per_row_iter = B_COL / NUM_THREADS;
|
||||
uint32_t thread_offset = first_thread_offset + tid_in_warp;
|
||||
// FIXME: threadblock_id needs to be in here too
|
||||
float *warp_smem = smem_scratchpad + (warp_id * NUM_THREADS);
|
||||
|
||||
// #define DUMB_ROWMAX
|
||||
#ifdef DUMB_ROWMAX
|
||||
// FIXME remove
|
||||
threadblock_barrier(threadblock_id_in_cluster,
|
||||
warps_per_threadblock_per_core);
|
||||
|
||||
// no tree reduction; a single thread in a warp does serialized max across
|
||||
// the entire row
|
||||
if (tid_in_warp == 0) {
|
||||
float max = S[first_thread_offset];
|
||||
#pragma GCC unroll
|
||||
float rowmax = smem_S[first_thread_offset];
|
||||
#pragma GCC unroll 16
|
||||
for (int i = 0; i < B_COL; i++) {
|
||||
asm volatile("fmax.s %0, %1, %2"
|
||||
: "=f"(max)
|
||||
: "f"(max), "f"(S[first_thread_offset + i]));
|
||||
: "=f"(rowmax)
|
||||
: "f"(rowmax), "f"(smem_S[first_thread_offset + i]));
|
||||
}
|
||||
smem_rowmax[row] = max;
|
||||
smem_rowmax_this[row] = rowmax;
|
||||
|
||||
// update previous rowmax
|
||||
// i.e. mi_new = max(mi, mij)
|
||||
float prev_rowmax = smem_rowmax[row];
|
||||
// stage prev rowmax in scratchpad for warp-wide broadcast
|
||||
warp_smem[0] = prev_rowmax;
|
||||
asm volatile("fmax.s %0, %1, %2"
|
||||
: "=f"(rowmax)
|
||||
: "f"(rowmax), "f"(prev_rowmax));
|
||||
smem_rowmax[row] = rowmax;
|
||||
}
|
||||
|
||||
#else
|
||||
static_assert((B_COL % NUM_THREADS) == 0,
|
||||
"B_COL must be a multiple of NUM_THREADS");
|
||||
constexpr uint32_t per_row_iter = B_COL / NUM_THREADS;
|
||||
uint32_t thread_offset = first_thread_offset + tid_in_warp;
|
||||
float per_thread_max = FLT_MIN;
|
||||
#pragma GCC unroll
|
||||
for (int i = 0; i < per_row_iter; i++) {
|
||||
@@ -202,8 +235,6 @@ __attribute__((always_inline)) inline void thread_block_online_softmax(
|
||||
thread_offset += NUM_THREADS;
|
||||
}
|
||||
// stage per-thread max value in smem
|
||||
// FIXME: threadblock_id needs to be in here too
|
||||
float *warp_smem = smem_scratchpad + (warp_id * NUM_THREADS);
|
||||
warp_smem[tid_in_warp] = per_thread_max;
|
||||
|
||||
// sync writes to warp_smem
|
||||
@@ -233,9 +264,7 @@ __attribute__((always_inline)) inline void thread_block_online_softmax(
|
||||
: "f"(rowmax), "f"(prev_rowmax));
|
||||
smem_rowmax[row] = rowmax;
|
||||
}
|
||||
|
||||
#else
|
||||
|
||||
if (warp_id < warps_in_threadblock / NUM_THREADS) {
|
||||
const uint32_t row = row_offset + NUM_THREADS * warp_id + tid_in_warp;
|
||||
float *const thread_smem = smem_scratchpad + (tid_in_warp * NUM_THREADS);
|
||||
@@ -257,8 +286,7 @@ __attribute__((always_inline)) inline void thread_block_online_softmax(
|
||||
: "f"(rowmax), "f"(prev_rowmax));
|
||||
smem_rowmax[row] = rowmax;
|
||||
}
|
||||
#endif
|
||||
|
||||
#endif // PARALLEL_ROWMAX
|
||||
#endif // DUMB_ROWMAX
|
||||
|
||||
threadblock_barrier(threadblock_id_in_cluster,
|
||||
@@ -404,16 +432,17 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
|
||||
#endif
|
||||
|
||||
// FIXME: headdim not considered
|
||||
uint32_t threads_per_threadblock =
|
||||
constexpr uint32_t threads_per_threadblock_theoretical =
|
||||
(B_ROW * B_COL) / (ELEM_PER_THREAD) / (DOUBLE_BUF ? 2 : 1);
|
||||
const uint32_t hw_threads_per_cluster =
|
||||
cores_per_cluster * vx_num_threads() * vx_num_warps();
|
||||
constexpr uint32_t hw_threads_per_cluster =
|
||||
CORES_PER_CLUSTER * NUM_THREADS * NUM_WARPS;
|
||||
// cap maximum threadblock size to # of HW threads in cluster, to prevent
|
||||
// multiple "wave" invocations which slows down the kernel
|
||||
if (threads_per_threadblock > hw_threads_per_cluster) {
|
||||
threads_per_threadblock = hw_threads_per_cluster;
|
||||
}
|
||||
const uint32_t threadblocks_per_cluster =
|
||||
constexpr uint32_t threads_per_threadblock =
|
||||
(threads_per_threadblock_theoretical > hw_threads_per_cluster)
|
||||
? hw_threads_per_cluster
|
||||
: threads_per_threadblock_theoretical;
|
||||
constexpr uint32_t threadblocks_per_cluster =
|
||||
hw_threads_per_cluster / threads_per_threadblock;
|
||||
|
||||
const int threadblock_id = task_id / threads_per_threadblock;
|
||||
@@ -452,7 +481,7 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
|
||||
smem_QK_size + smem_V_size;
|
||||
|
||||
// allocate rowmax/rowsum storage at the end of the sharedmem address space
|
||||
constexpr uint32_t smem_rowmax_size = B_ROW * 3 /* mi, mi~, minew */;
|
||||
constexpr uint32_t smem_rowmax_size = B_ROW * ROWMAX_SETS;
|
||||
constexpr uint32_t smem_rowsum_size = B_ROW;
|
||||
float *smem_rowmax =
|
||||
reinterpret_cast<float *>(SMEM_ADDR_END) - smem_rowmax_size;
|
||||
@@ -505,16 +534,16 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
|
||||
|
||||
// load Q
|
||||
load_tile_to_smem<float, MemLayout::MN_major, MemLayout::MN_major, B_ROW,
|
||||
HEADDIM>(
|
||||
HEADDIM, threads_per_threadblock>(
|
||||
dim_seqlen, 0 /*FIXME: only work on first B_ROW rows of Q for now*/,
|
||||
0 /* always 0 because dim_k == headdim */, gmem_Q, smem_Q,
|
||||
tid_in_threadblock);
|
||||
|
||||
// load K
|
||||
load_tile_to_smem<float, MemLayout::MN_major, MemLayout::MN_major, B_COL,
|
||||
HEADDIM>(dim_seqlen, tile_k,
|
||||
0 /* always 0 because dim_k == headdim */,
|
||||
gmem_K, smem_K, tid_in_threadblock);
|
||||
HEADDIM, threads_per_threadblock>(
|
||||
dim_seqlen, tile_k, 0 /* always 0 because dim_k == headdim */, gmem_K,
|
||||
smem_K, tid_in_threadblock);
|
||||
|
||||
// GMEM->SMEM and compute barrier
|
||||
threadblock_barrier(threadblock_id_in_cluster,
|
||||
@@ -533,8 +562,9 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
|
||||
} else {
|
||||
// load Q*K
|
||||
load_tile_to_smem<float, MemLayout::K_major, MemLayout::K_major, B_COL,
|
||||
HEADDIM>(dim_seqlen, 0, tile_k, gmem_Q /*=gmem_S*/,
|
||||
smem_S, tid_in_threadblock);
|
||||
HEADDIM, threads_per_threadblock>(
|
||||
dim_seqlen, 0, tile_k, gmem_Q /*=gmem_S*/, smem_S,
|
||||
tid_in_threadblock);
|
||||
// the above should be equivalent to:
|
||||
// load_tile_to_smem<float, MemLayout::MN_major, MemLayout::MN_major,
|
||||
// B_COL,
|
||||
@@ -598,7 +628,7 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
|
||||
|
||||
// V dimension is [seqlen, headdim], stored N(headdim)-major
|
||||
load_tile_to_smem<float, MemLayout::MN_major, MemLayout::MN_major, B_COL,
|
||||
HEADDIM>(
|
||||
HEADDIM, threads_per_threadblock>(
|
||||
HEADDIM, 0 /* 0 because always reads the full N-dimension */, tile_k,
|
||||
gmem_V, smem_V, tid_in_threadblock);
|
||||
|
||||
|
||||
Reference in New Issue
Block a user