flash: Incomplete parallel stage-2 rowmax
This commit is contained in:
@@ -161,9 +161,9 @@ __attribute__((always_inline)) inline void thread_block_online_softmax(
|
|||||||
|
|
||||||
float *smem_rowmax_this = smem_rowmax + B_ROW;
|
float *smem_rowmax_this = smem_rowmax + B_ROW;
|
||||||
|
|
||||||
for (int warp_offset = 0; warp_offset < B_ROW;
|
for (int row_offset = 0; row_offset < B_ROW;
|
||||||
warp_offset += warps_in_threadblock) {
|
row_offset += warps_in_threadblock) {
|
||||||
const uint32_t row = warp_offset + warp_id;
|
const uint32_t row = row_offset + warp_id;
|
||||||
const uint32_t first_thread_offset = B_COL * row;
|
const uint32_t first_thread_offset = B_COL * row;
|
||||||
|
|
||||||
// rowmax
|
// rowmax
|
||||||
@@ -208,11 +208,13 @@ __attribute__((always_inline)) inline void thread_block_online_softmax(
|
|||||||
threadblock_barrier(threadblock_id_in_cluster,
|
threadblock_barrier(threadblock_id_in_cluster,
|
||||||
warps_per_threadblock_per_core);
|
warps_per_threadblock_per_core);
|
||||||
|
|
||||||
|
// #define PARALLEL_ROWMAX
|
||||||
|
#ifndef PARALLEL_ROWMAX
|
||||||
// elect 0-th thread to reduce all other threads' values in the warp
|
// elect 0-th thread to reduce all other threads' values in the warp
|
||||||
if (tid_in_warp == 0) {
|
if (tid_in_warp == 0) {
|
||||||
float rowmax = per_thread_max;
|
float rowmax = per_thread_max;
|
||||||
for (int iter = 1; iter < NUM_THREADS; iter++) {
|
for (int i = 1; i < NUM_THREADS; i++) {
|
||||||
float other = warp_smem[iter];
|
float other = warp_smem[i];
|
||||||
asm volatile("fmax.s %0, %1, %2"
|
asm volatile("fmax.s %0, %1, %2"
|
||||||
: "=f"(rowmax)
|
: "=f"(rowmax)
|
||||||
: "f"(rowmax), "f"(other));
|
: "f"(rowmax), "f"(other));
|
||||||
@@ -230,9 +232,33 @@ __attribute__((always_inline)) inline void thread_block_online_softmax(
|
|||||||
smem_rowmax[row] = rowmax;
|
smem_rowmax[row] = rowmax;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#else
|
||||||
|
|
||||||
|
if (warp_id < warps_in_threadblock / NUM_THREADS) {
|
||||||
|
const uint32_t row = row_offset + NUM_THREADS * warp_id + tid_in_warp;
|
||||||
|
float *const thread_smem = smem_scratchpad + (tid_in_warp * NUM_THREADS);
|
||||||
|
float rowmax = FLT_MIN;
|
||||||
|
#pragma GCC unroll
|
||||||
|
for (int i = 0; i < NUM_THREADS; i++) {
|
||||||
|
const float f = thread_smem[i];
|
||||||
|
asm volatile("fmax.s %0, %1, %2" : "=f"(rowmax) : "f"(rowmax), "f"(f));
|
||||||
|
}
|
||||||
|
smem_rowmax_this[row] = rowmax;
|
||||||
|
|
||||||
|
// update previous rowmax
|
||||||
|
// i.e. mi_new = max(mi, mij)
|
||||||
|
float prev_rowmax = smem_rowmax[row];
|
||||||
|
// stage prev rowmax in scratchpad for warp-wide broadcast
|
||||||
|
thread_smem[0] = prev_rowmax;
|
||||||
|
asm volatile("fmax.s %0, %1, %2"
|
||||||
|
: "=f"(rowmax)
|
||||||
|
: "f"(rowmax), "f"(prev_rowmax));
|
||||||
|
smem_rowmax[row] = rowmax;
|
||||||
|
}
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
// FIXME: unnecessary?
|
#endif // DUMB_ROWMAX
|
||||||
|
|
||||||
threadblock_barrier(threadblock_id_in_cluster,
|
threadblock_barrier(threadblock_id_in_cluster,
|
||||||
warps_per_threadblock_per_core);
|
warps_per_threadblock_per_core);
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user