diff --git a/tests/regression/flash_attention/kernel.cpp b/tests/regression/flash_attention/kernel.cpp index 52947c73..4295e69d 100644 --- a/tests/regression/flash_attention/kernel.cpp +++ b/tests/regression/flash_attention/kernel.cpp @@ -161,9 +161,9 @@ __attribute__((always_inline)) inline void thread_block_online_softmax( float *smem_rowmax_this = smem_rowmax + B_ROW; - for (int warp_offset = 0; warp_offset < B_ROW; - warp_offset += warps_in_threadblock) { - const uint32_t row = warp_offset + warp_id; + for (int row_offset = 0; row_offset < B_ROW; + row_offset += warps_in_threadblock) { + const uint32_t row = row_offset + warp_id; const uint32_t first_thread_offset = B_COL * row; // rowmax @@ -208,11 +208,13 @@ __attribute__((always_inline)) inline void thread_block_online_softmax( threadblock_barrier(threadblock_id_in_cluster, warps_per_threadblock_per_core); +// #define PARALLEL_ROWMAX +#ifndef PARALLEL_ROWMAX // elect 0-th thread to reduce all other thread's values in the warp if (tid_in_warp == 0) { float rowmax = per_thread_max; - for (int iter = 1; iter < NUM_THREADS; iter++) { - float other = warp_smem[iter]; + for (int i = 1; i < NUM_THREADS; i++) { + float other = warp_smem[i]; asm volatile("fmax.s %0, %1, %2" : "=f"(rowmax) : "f"(rowmax), "f"(other)); @@ -230,9 +232,33 @@ __attribute__((always_inline)) inline void thread_block_online_softmax( smem_rowmax[row] = rowmax; } +#else + + if (warp_id < warps_in_threadblock / NUM_THREADS) { + const uint32_t row = row_offset + NUM_THREADS * warp_id + tid_in_warp; + float *const thread_smem = smem_scratchpad + (tid_in_warp * NUM_THREADS); + float rowmax = FLT_MIN; +#pragma GCC unroll + for (int i = 0; i < NUM_THREADS; i++) { + const float f = thread_smem[i]; + asm volatile("fmax.s %0, %1, %2" : "=f"(rowmax) : "f"(rowmax), "f"(f)); + } + smem_rowmax_this[row] = rowmax; + + // update previous rowmax + // i.e. mi_new = max(mi, mij) + float prev_rowmax = smem_rowmax[row]; + // stage prev rowmax in scratchpad for warp-wide broadcast + thread_smem[0] = prev_rowmax; + asm volatile("fmax.s %0, %1, %2" + : "=f"(rowmax) + : "f"(rowmax), "f"(prev_rowmax)); + smem_rowmax[row] = rowmax; + } #endif - // FIXME: unnecessary? +#endif // DUMB_ROWMAX + threadblock_barrier(threadblock_id_in_cluster, warps_per_threadblock_per_core);