flash: Incomplete parallel stage-2 rowmax

This commit is contained in:
Hansung Kim
2024-08-29 13:29:00 -07:00
parent 4260bf7d6e
commit 5ba06dfd9d

View File

@@ -161,9 +161,9 @@ __attribute__((always_inline)) inline void thread_block_online_softmax(
float *smem_rowmax_this = smem_rowmax + B_ROW; float *smem_rowmax_this = smem_rowmax + B_ROW;
for (int warp_offset = 0; warp_offset < B_ROW; for (int row_offset = 0; row_offset < B_ROW;
warp_offset += warps_in_threadblock) { row_offset += warps_in_threadblock) {
const uint32_t row = warp_offset + warp_id; const uint32_t row = row_offset + warp_id;
const uint32_t first_thread_offset = B_COL * row; const uint32_t first_thread_offset = B_COL * row;
// rowmax // rowmax
@@ -208,11 +208,13 @@ __attribute__((always_inline)) inline void thread_block_online_softmax(
threadblock_barrier(threadblock_id_in_cluster, threadblock_barrier(threadblock_id_in_cluster,
warps_per_threadblock_per_core); warps_per_threadblock_per_core);
// #define PARALLEL_ROWMAX
#ifndef PARALLEL_ROWMAX
// elect 0-th thread to reduce all other thread's values in the warp // elect 0-th thread to reduce all other thread's values in the warp
if (tid_in_warp == 0) { if (tid_in_warp == 0) {
float rowmax = per_thread_max; float rowmax = per_thread_max;
for (int iter = 1; iter < NUM_THREADS; iter++) { for (int i = 1; i < NUM_THREADS; i++) {
float other = warp_smem[iter]; float other = warp_smem[i];
asm volatile("fmax.s %0, %1, %2" asm volatile("fmax.s %0, %1, %2"
: "=f"(rowmax) : "=f"(rowmax)
: "f"(rowmax), "f"(other)); : "f"(rowmax), "f"(other));
@@ -230,9 +232,33 @@ __attribute__((always_inline)) inline void thread_block_online_softmax(
smem_rowmax[row] = rowmax; smem_rowmax[row] = rowmax;
} }
#else
if (warp_id < warps_in_threadblock / NUM_THREADS) {
const uint32_t row = row_offset + NUM_THREADS * warp_id + tid_in_warp;
float *const thread_smem = smem_scratchpad + (tid_in_warp * NUM_THREADS);
float rowmax = FLT_MIN;
#pragma GCC unroll
for (int i = 0; i < NUM_THREADS; i++) {
const float f = thread_smem[i];
asm volatile("fmax.s %0, %1, %2" : "=f"(rowmax) : "f"(rowmax), "f"(f));
}
smem_rowmax_this[row] = rowmax;
// update previous rowmax
// i.e. mi_new = max(mi, mij)
float prev_rowmax = smem_rowmax[row];
// stage prev rowmax in scratchpad for warp-wide broadcast
thread_smem[0] = prev_rowmax;
asm volatile("fmax.s %0, %1, %2"
: "=f"(rowmax)
: "f"(rowmax), "f"(prev_rowmax));
smem_rowmax[row] = rowmax;
}
#endif #endif
// FIXME: unnecessary? #endif // DUMB_ROWMAX
threadblock_barrier(threadblock_id_in_cluster, threadblock_barrier(threadblock_id_in_cluster,
warps_per_threadblock_per_core); warps_per_threadblock_per_core);