Fix overlapping smem in rowmax

This commit is contained in:
Hansung Kim
2024-08-14 21:09:47 -07:00
parent 692d028afd
commit 9cabe3413b

View File

@@ -43,8 +43,8 @@ inline void thread_block_flashattn(float *S, float *gmem,
const uint32_t first_thread_offset = Bcol * row;
uint32_t thread_offset = first_thread_offset + tid_in_warp;
constexpr uint32_t load_iter = Bcol / NUM_THREADS;
float curr_max = S[first_thread_offset];
constexpr uint32_t load_iter = Bcol / NUM_THREADS;
#pragma GCC unroll
for (int iter = 0; iter < load_iter; iter++) {
asm volatile("fmax.s %0, %1, %2"
@@ -53,7 +53,7 @@ inline void thread_block_flashattn(float *S, float *gmem,
thread_offset += NUM_THREADS;
}
// get max value across the same-warp threads using smem
float *warp_smem = S + (row * NUM_THREADS);
float *warp_smem = S + (2 * Brow * Bcol) + (row * NUM_THREADS);
warp_smem[tid_in_warp] = curr_max;
// sync writes to warp_smem