Fix overlapping shared memory (smem) in rowmax
This commit is contained in:
@@ -43,8 +43,8 @@ inline void thread_block_flashattn(float *S, float *gmem,
|
||||
const uint32_t first_thread_offset = Bcol * row;
|
||||
uint32_t thread_offset = first_thread_offset + tid_in_warp;
|
||||
|
||||
constexpr uint32_t load_iter = Bcol / NUM_THREADS;
|
||||
float curr_max = S[first_thread_offset];
|
||||
constexpr uint32_t load_iter = Bcol / NUM_THREADS;
|
||||
#pragma GCC unroll
|
||||
for (int iter = 0; iter < load_iter; iter++) {
|
||||
asm volatile("fmax.s %0, %1, %2"
|
||||
@@ -53,7 +53,7 @@ inline void thread_block_flashattn(float *S, float *gmem,
|
||||
thread_offset += NUM_THREADS;
|
||||
}
|
||||
// get max value across the same-warp threads using smem
|
||||
-float *warp_smem = S + (row * NUM_THREADS);
|
||||
+float *warp_smem = S + (2 * Brow * Bcol) + (row * NUM_THREADS);
|
||||
warp_smem[tid_in_warp] = curr_max;
|
||||
|
||||
// sync writes to warp_smem
|
||||
|
||||
Reference in New Issue
Block a user