flash: Reduce smem use for rowmax; verify result

This commit is contained in:
Hansung Kim
2024-08-20 14:34:45 -07:00
parent d8d5df64e6
commit 615d36a5c2

View File

@@ -13,6 +13,8 @@
// FIXME
#define HEADDIM B_COL
constexpr bool DEBUG = true;
inline void thread_block_init_sharedmem(const uint32_t tid_in_threadblock,
const uint32_t threads_per_threadblock,
float *smem_O,
@@ -53,6 +55,66 @@ inline void thread_block_init_sharedmem(const uint32_t tid_in_threadblock,
}
}
inline void thread_block_copy_rowmax(const float *src, float *dest,
const uint32_t tid_in_threadblock,
const uint32_t threads_per_threadblock,
const uint32_t threadblocks_per_cluster,
const uint32_t threadblock_id_in_cluster) {
asm volatile("threadblock_copy_rowmax_start_%=:" ::);
const uint32_t tid_in_warp = tid_in_threadblock % NUM_THREADS;
const uint32_t warp_id = tid_in_threadblock / NUM_THREADS;
const uint32_t warps_in_threadblock = threads_per_threadblock / NUM_THREADS;
const uint32_t warps_per_threadblock_per_core =
NUM_WARPS / threadblocks_per_cluster;
constexpr uint32_t num_warps = B_ROW / NUM_THREADS;
if (warp_id < num_warps) {
uint32_t offset = NUM_THREADS * warp_id + tid_in_warp;
dest[offset] = src[offset];
}
threadblock_barrier(threadblock_id_in_cluster,
warps_per_threadblock_per_core);
asm volatile("threadblock_copy_rowmax_finish_%=:" ::);
}
inline void thread_block_copy_tile(const float *src, float *dest,
const uint32_t tid_in_threadblock,
const uint32_t threads_per_threadblock,
const uint32_t threadblocks_per_cluster,
const uint32_t threadblock_id_in_cluster) {
asm volatile("threadblock_copy_tile_start_%=:" ::);
const uint32_t tid_in_warp = tid_in_threadblock % NUM_THREADS;
const uint32_t warp_id = tid_in_threadblock / NUM_THREADS;
const uint32_t warps_in_threadblock = threads_per_threadblock / NUM_THREADS;
const uint32_t warps_per_threadblock_per_core =
NUM_WARPS / threadblocks_per_cluster;
// FIXME: dedup this pattern
for (int warp_offset = 0; warp_offset < B_ROW;
warp_offset += warps_in_threadblock) {
const uint32_t row = warp_offset + warp_id;
const uint32_t first_thread_offset = B_COL * row;
constexpr uint32_t per_row_iter = B_COL / NUM_THREADS;
uint32_t thread_offset = first_thread_offset + tid_in_warp;
float per_thread_max = FLT_MIN;
#pragma GCC unroll
for (int i = 0; i < per_row_iter; i++) {
dest[thread_offset] = src[thread_offset];
thread_offset += NUM_THREADS;
}
threadblock_barrier(threadblock_id_in_cluster,
warps_per_threadblock_per_core);
}
asm volatile("threadblock_copy_tile_finish_%=:" ::);
}
template <int order>
inline float exponential_taylor_term(const float x) {
asm volatile("exponential_taylor_term_start_%=:" ::);
@@ -73,38 +135,7 @@ inline float exponential_taylor_term(const float x) {
return res;
}
inline void thread_block_copy_data(const float *src, float *dest,
const uint32_t tid_in_threadblock,
const uint32_t threads_per_threadblock,
const uint32_t threadblocks_per_cluster,
const uint32_t threadblock_id_in_cluster) {
const uint32_t tid_in_warp = tid_in_threadblock % NUM_THREADS;
const uint32_t warp_id = tid_in_threadblock / NUM_THREADS;
const uint32_t warps_in_threadblock = threads_per_threadblock / NUM_THREADS;
const uint32_t warps_per_threadblock_per_core =
NUM_WARPS / threadblocks_per_cluster;
for (int warp_offset = 0; warp_offset < B_ROW;
warp_offset += warps_in_threadblock) {
const uint32_t row = warp_offset + warp_id;
const uint32_t first_thread_offset = B_COL * row;
constexpr uint32_t per_row_iter = B_COL / NUM_THREADS;
uint32_t thread_offset = first_thread_offset + tid_in_warp;
float per_thread_max = FLT_MIN;
#pragma GCC unroll
for (int i = 0; i < per_row_iter; i++) {
const float f = src[thread_offset];
dest[thread_offset] = f;
thread_offset += NUM_THREADS;
}
threadblock_barrier(threadblock_id_in_cluster,
warps_per_threadblock_per_core);
}
}
inline void thread_block_online_softmax(
__attribute__((always_inline)) inline void thread_block_online_softmax(
const float *smem_S, float *smem_O, float *smem_P,
const uint32_t tid_in_threadblock, const uint32_t threads_per_threadblock,
const uint32_t threadblocks_per_cluster,
@@ -128,9 +159,7 @@ inline void thread_block_online_softmax(
// asm volatile("fmv.s %0, f22" : "=f"(ft[6]));
// asm volatile("fmv.s %0, f23" : "=f"(ft[7]));
float *smem_rowmax_prev = smem_rowmax;
float *smem_rowmax_new = smem_rowmax + B_ROW;
float *smem_rowmax_this = smem_rowmax + 2 * B_ROW;
float *smem_rowmax_this = smem_rowmax + B_ROW;
for (int warp_offset = 0; warp_offset < B_ROW;
warp_offset += warps_in_threadblock) {
@@ -192,26 +221,34 @@ inline void thread_block_online_softmax(
// update previous rowmax
// i.e. mi_new = max(mi, mij)
float prev_rowmax = smem_rowmax_prev[row];
float prev_rowmax = smem_rowmax[row];
// stage prev rowmax in scratchpad for warp-wide broadcast
warp_smem[0] = prev_rowmax;
asm volatile("fmax.s %0, %1, %2"
: "=f"(rowmax)
: "f"(rowmax), "f"(prev_rowmax));
smem_rowmax_new[row] = rowmax;
smem_rowmax[row] = rowmax;
}
#endif
// FIXME: unnecessary?
threadblock_barrier(threadblock_id_in_cluster,
warps_per_threadblock_per_core);
// broadcast prev rowmax to all threads in the warp
// NOTE: memory consistency is a little sketchy here
const float rowmax_prev = warp_smem[0];
const float rowmax_this = smem_rowmax_this[row];
// exponential
//
// B_ROW / (B_ROW * B_COL / (exp_elem * threads_per_threadblock))
// const uint32_t row_stride =
// (exp_elem_per_thread * threads_per_threadblock) / B_COL;
// broadcast rowmax to all threads in the warp
const float rowmax_new = smem_rowmax_new[row];
// broadcast updated rowmax to all threads in the warp
const float rowmax_new = smem_rowmax[row];
// each thread computes two fp32 elements, downconverts it to fp16, then
// packs them into one fp32
@@ -279,8 +316,9 @@ inline void thread_block_online_softmax(
rowsum += other;
}
const float mi_prev = smem_rowmax_prev[row];
const float mi_this = smem_rowmax_this[row];
const float mi_prev = rowmax_prev;
// TODO: replace this with a register?
const float mi_this = rowmax_this;
const float x = mi_prev - mi_this;
// 2nd-order Taylor approximation
@@ -309,8 +347,8 @@ inline void thread_block_online_softmax(
for (int i = 0; i < per_row_iter; i++) {
float o = smem_O[thread_offset];
const float mi_prev = smem_rowmax_prev[row];
const float mi_new = smem_rowmax_new[row];
const float mi_prev = rowmax_prev;
const float mi_new = rowmax_new;
const float x = mi_prev - mi_new;
// 2nd-order Taylor approximation
@@ -398,9 +436,10 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
// sharedmem "scratchpad" area to put temporary data, e.g. for tree reduction
// in rowsum
// NOTE: out-of bounds is not checked
// TODO: reduce this from B_ROW to NUM_WARPS
constexpr uint32_t smem_scratchpad_size =
B_ROW * NUM_THREADS * 2 /*arbitrary slack*/;
float *smem_scratchpad = smem_rowmax - smem_scratchpad_size;
float *smem_scratchpad = smem_rowsum - smem_scratchpad_size;
const uint32_t warps_per_threadblock_per_core =
NUM_WARPS / threadblocks_per_cluster;
@@ -414,6 +453,17 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
const float *gmem_V = reinterpret_cast<float *>(arg->addr_v);
float *gmem_O = reinterpret_cast<float *>(arg->addr_o);
float *gmem_tmp_d0 = reinterpret_cast<float *>(0xd0000000UL);
float *gmem_tmp_d1 = reinterpret_cast<float *>(0xd1000000UL);
float *gmem_tmp_d2 = reinterpret_cast<float *>(0xd2000000UL);
float *gmem_tmp_d3 = reinterpret_cast<float *>(0xd3000000UL);
float *gmem_tmp_d4 = reinterpret_cast<float *>(0xd4000000UL);
float *gmem_tmp_d5 = reinterpret_cast<float *>(0xd5000000UL);
float *gmem_tmp_e0 = reinterpret_cast<float *>(0xe0000000UL);
float *gmem_tmp_e1 = reinterpret_cast<float *>(0xe1000000UL);
float *gmem_tmp_e2 = reinterpret_cast<float *>(0xe2000000UL);
float *gmem_tmp_e3 = reinterpret_cast<float *>(0xe3000000UL);
// "inner loop" along the columns of K^T
for (uint32_t tile_k = 0; tile_k < (dim_seqlen / B_COL); tile_k++) {
@@ -469,6 +519,43 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
threadblock_barrier(threadblock_id_in_cluster,
warps_per_threadblock_per_core);
if constexpr (DEBUG) {
if (tile_k == 0) {
thread_block_copy_tile(
smem_P, gmem_tmp_d0, tid_in_threadblock, threads_per_threadblock,
threadblocks_per_cluster, threadblock_id_in_cluster);
thread_block_copy_tile(
smem_O, gmem_tmp_d2, tid_in_threadblock, threads_per_threadblock,
threadblocks_per_cluster, threadblock_id_in_cluster);
thread_block_copy_rowmax(smem_rowmax, gmem_tmp_e0, tid_in_threadblock,
threads_per_threadblock,
threadblocks_per_cluster,
threadblock_id_in_cluster);
thread_block_copy_rowmax(smem_rowsum, gmem_tmp_e2, tid_in_threadblock,
threads_per_threadblock,
threadblocks_per_cluster,
threadblock_id_in_cluster);
} else if (tile_k == 1) {
thread_block_copy_tile(
smem_P, gmem_tmp_d1, tid_in_threadblock, threads_per_threadblock,
threadblocks_per_cluster, threadblock_id_in_cluster);
thread_block_copy_tile(
smem_O, gmem_tmp_d3, tid_in_threadblock, threads_per_threadblock,
threadblocks_per_cluster, threadblock_id_in_cluster);
thread_block_copy_rowmax(smem_rowmax, gmem_tmp_e1, tid_in_threadblock,
threads_per_threadblock,
threadblocks_per_cluster,
threadblock_id_in_cluster);
thread_block_copy_rowmax(smem_rowsum, gmem_tmp_e3, tid_in_threadblock,
threads_per_threadblock,
threadblocks_per_cluster,
threadblock_id_in_cluster);
}
threadblock_barrier(threadblock_id_in_cluster,
warps_per_threadblock_per_core);
}
// GEMM II: O = O + P*V
// clear out accumulators
@@ -495,18 +582,22 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
threadblock_barrier(threadblock_id_in_cluster,
warps_per_threadblock_per_core);
if constexpr (DEBUG) {
if (tile_k == 0) {
thread_block_copy_tile(
smem_O, gmem_tmp_d4, tid_in_threadblock, threads_per_threadblock,
threadblocks_per_cluster, threadblock_id_in_cluster);
} else if (tile_k == 1) {
thread_block_copy_tile(
smem_O, gmem_tmp_d5, tid_in_threadblock, threads_per_threadblock,
threadblocks_per_cluster, threadblock_id_in_cluster);
}
threadblock_barrier(threadblock_id_in_cluster,
warps_per_threadblock_per_core);
}
}
float *gmem_tmp0 = reinterpret_cast<float *>(0xd0000000UL);
float *gmem_tmp1 = reinterpret_cast<float *>(0xe0000000UL);
// copy out tile data to GMEM for debugging
thread_block_copy_data(smem_P, gmem_tmp0, tid_in_threadblock,
threads_per_threadblock, threadblocks_per_cluster,
threadblock_id_in_cluster);
thread_block_copy_data(smem_O, gmem_tmp1, tid_in_threadblock,
threads_per_threadblock, threadblocks_per_cluster,
threadblock_id_in_cluster);
}
int main() {