flash: Incomplete parallel stage-2 rowmax
This commit is contained in:
@@ -161,9 +161,9 @@ __attribute__((always_inline)) inline void thread_block_online_softmax(
|
||||
|
||||
float *smem_rowmax_this = smem_rowmax + B_ROW;
|
||||
|
||||
for (int warp_offset = 0; warp_offset < B_ROW;
|
||||
warp_offset += warps_in_threadblock) {
|
||||
const uint32_t row = warp_offset + warp_id;
|
||||
for (int row_offset = 0; row_offset < B_ROW;
|
||||
row_offset += warps_in_threadblock) {
|
||||
const uint32_t row = row_offset + warp_id;
|
||||
const uint32_t first_thread_offset = B_COL * row;
|
||||
|
||||
// rowmax
|
||||
@@ -208,11 +208,13 @@ __attribute__((always_inline)) inline void thread_block_online_softmax(
|
||||
threadblock_barrier(threadblock_id_in_cluster,
|
||||
warps_per_threadblock_per_core);
|
||||
|
||||
// #define PARALLEL_ROWMAX
|
||||
#ifndef PARALLEL_ROWMAX
|
||||
// elect 0-th thread to reduce all other thread's values in the warp
|
||||
if (tid_in_warp == 0) {
|
||||
float rowmax = per_thread_max;
|
||||
for (int iter = 1; iter < NUM_THREADS; iter++) {
|
||||
float other = warp_smem[iter];
|
||||
for (int i = 1; i < NUM_THREADS; i++) {
|
||||
float other = warp_smem[i];
|
||||
asm volatile("fmax.s %0, %1, %2"
|
||||
: "=f"(rowmax)
|
||||
: "f"(rowmax), "f"(other));
|
||||
@@ -230,9 +232,33 @@ __attribute__((always_inline)) inline void thread_block_online_softmax(
|
||||
smem_rowmax[row] = rowmax;
|
||||
}
|
||||
|
||||
#else
|
||||
|
||||
if (warp_id < warps_in_threadblock / NUM_THREADS) {
|
||||
const uint32_t row = row_offset + NUM_THREADS * warp_id + tid_in_warp;
|
||||
float *const thread_smem = smem_scratchpad + (tid_in_warp * NUM_THREADS);
|
||||
float rowmax = FLT_MIN;
|
||||
#pragma GCC unroll
|
||||
for (int i = 0; i < NUM_THREADS; i++) {
|
||||
const float f = thread_smem[i];
|
||||
asm volatile("fmax.s %0, %1, %2" : "=f"(rowmax) : "f"(rowmax), "f"(f));
|
||||
}
|
||||
smem_rowmax_this[row] = rowmax;
|
||||
|
||||
// update previous rowmax
|
||||
// i.e. mi_new = max(mi, mij)
|
||||
float prev_rowmax = smem_rowmax[row];
|
||||
// stage prev rowmax in scratchpad for warp-wide broadcast
|
||||
thread_smem[0] = prev_rowmax;
|
||||
asm volatile("fmax.s %0, %1, %2"
|
||||
: "=f"(rowmax)
|
||||
: "f"(rowmax), "f"(prev_rowmax));
|
||||
smem_rowmax[row] = rowmax;
|
||||
}
|
||||
#endif
|
||||
|
||||
// FIXME: unnecessary?
|
||||
#endif // DUMB_ROWMAX
|
||||
|
||||
threadblock_barrier(threadblock_id_in_cluster,
|
||||
warps_per_threadblock_per_core);
|
||||
|
||||
|
||||
Reference in New Issue
Block a user