flash: Reduce smem use for rowmax; verify result
This commit is contained in:
@@ -13,6 +13,8 @@
|
||||
// FIXME
|
||||
#define HEADDIM B_COL
|
||||
|
||||
constexpr bool DEBUG = true;
|
||||
|
||||
inline void thread_block_init_sharedmem(const uint32_t tid_in_threadblock,
|
||||
const uint32_t threads_per_threadblock,
|
||||
float *smem_O,
|
||||
@@ -53,6 +55,66 @@ inline void thread_block_init_sharedmem(const uint32_t tid_in_threadblock,
|
||||
}
|
||||
}
|
||||
|
||||
inline void thread_block_copy_rowmax(const float *src, float *dest,
|
||||
const uint32_t tid_in_threadblock,
|
||||
const uint32_t threads_per_threadblock,
|
||||
const uint32_t threadblocks_per_cluster,
|
||||
const uint32_t threadblock_id_in_cluster) {
|
||||
asm volatile("threadblock_copy_rowmax_start_%=:" ::);
|
||||
|
||||
const uint32_t tid_in_warp = tid_in_threadblock % NUM_THREADS;
|
||||
const uint32_t warp_id = tid_in_threadblock / NUM_THREADS;
|
||||
const uint32_t warps_in_threadblock = threads_per_threadblock / NUM_THREADS;
|
||||
const uint32_t warps_per_threadblock_per_core =
|
||||
NUM_WARPS / threadblocks_per_cluster;
|
||||
|
||||
constexpr uint32_t num_warps = B_ROW / NUM_THREADS;
|
||||
if (warp_id < num_warps) {
|
||||
uint32_t offset = NUM_THREADS * warp_id + tid_in_warp;
|
||||
dest[offset] = src[offset];
|
||||
}
|
||||
|
||||
threadblock_barrier(threadblock_id_in_cluster,
|
||||
warps_per_threadblock_per_core);
|
||||
|
||||
asm volatile("threadblock_copy_rowmax_finish_%=:" ::);
|
||||
}
|
||||
|
||||
inline void thread_block_copy_tile(const float *src, float *dest,
|
||||
const uint32_t tid_in_threadblock,
|
||||
const uint32_t threads_per_threadblock,
|
||||
const uint32_t threadblocks_per_cluster,
|
||||
const uint32_t threadblock_id_in_cluster) {
|
||||
asm volatile("threadblock_copy_tile_start_%=:" ::);
|
||||
|
||||
const uint32_t tid_in_warp = tid_in_threadblock % NUM_THREADS;
|
||||
const uint32_t warp_id = tid_in_threadblock / NUM_THREADS;
|
||||
const uint32_t warps_in_threadblock = threads_per_threadblock / NUM_THREADS;
|
||||
const uint32_t warps_per_threadblock_per_core =
|
||||
NUM_WARPS / threadblocks_per_cluster;
|
||||
|
||||
// FIXME: dedup this pattern
|
||||
for (int warp_offset = 0; warp_offset < B_ROW;
|
||||
warp_offset += warps_in_threadblock) {
|
||||
const uint32_t row = warp_offset + warp_id;
|
||||
const uint32_t first_thread_offset = B_COL * row;
|
||||
|
||||
constexpr uint32_t per_row_iter = B_COL / NUM_THREADS;
|
||||
uint32_t thread_offset = first_thread_offset + tid_in_warp;
|
||||
float per_thread_max = FLT_MIN;
|
||||
#pragma GCC unroll
|
||||
for (int i = 0; i < per_row_iter; i++) {
|
||||
dest[thread_offset] = src[thread_offset];
|
||||
thread_offset += NUM_THREADS;
|
||||
}
|
||||
|
||||
threadblock_barrier(threadblock_id_in_cluster,
|
||||
warps_per_threadblock_per_core);
|
||||
}
|
||||
|
||||
asm volatile("threadblock_copy_tile_finish_%=:" ::);
|
||||
}
|
||||
|
||||
template <int order>
|
||||
inline float exponential_taylor_term(const float x) {
|
||||
asm volatile("exponential_taylor_term_start_%=:" ::);
|
||||
@@ -73,38 +135,7 @@ inline float exponential_taylor_term(const float x) {
|
||||
return res;
|
||||
}
|
||||
|
||||
inline void thread_block_copy_data(const float *src, float *dest,
|
||||
const uint32_t tid_in_threadblock,
|
||||
const uint32_t threads_per_threadblock,
|
||||
const uint32_t threadblocks_per_cluster,
|
||||
const uint32_t threadblock_id_in_cluster) {
|
||||
const uint32_t tid_in_warp = tid_in_threadblock % NUM_THREADS;
|
||||
const uint32_t warp_id = tid_in_threadblock / NUM_THREADS;
|
||||
const uint32_t warps_in_threadblock = threads_per_threadblock / NUM_THREADS;
|
||||
const uint32_t warps_per_threadblock_per_core =
|
||||
NUM_WARPS / threadblocks_per_cluster;
|
||||
|
||||
for (int warp_offset = 0; warp_offset < B_ROW;
|
||||
warp_offset += warps_in_threadblock) {
|
||||
const uint32_t row = warp_offset + warp_id;
|
||||
const uint32_t first_thread_offset = B_COL * row;
|
||||
|
||||
constexpr uint32_t per_row_iter = B_COL / NUM_THREADS;
|
||||
uint32_t thread_offset = first_thread_offset + tid_in_warp;
|
||||
float per_thread_max = FLT_MIN;
|
||||
#pragma GCC unroll
|
||||
for (int i = 0; i < per_row_iter; i++) {
|
||||
const float f = src[thread_offset];
|
||||
dest[thread_offset] = f;
|
||||
thread_offset += NUM_THREADS;
|
||||
}
|
||||
|
||||
threadblock_barrier(threadblock_id_in_cluster,
|
||||
warps_per_threadblock_per_core);
|
||||
}
|
||||
}
|
||||
|
||||
inline void thread_block_online_softmax(
|
||||
__attribute__((always_inline)) inline void thread_block_online_softmax(
|
||||
const float *smem_S, float *smem_O, float *smem_P,
|
||||
const uint32_t tid_in_threadblock, const uint32_t threads_per_threadblock,
|
||||
const uint32_t threadblocks_per_cluster,
|
||||
@@ -128,9 +159,7 @@ inline void thread_block_online_softmax(
|
||||
// asm volatile("fmv.s %0, f22" : "=f"(ft[6]));
|
||||
// asm volatile("fmv.s %0, f23" : "=f"(ft[7]));
|
||||
|
||||
float *smem_rowmax_prev = smem_rowmax;
|
||||
float *smem_rowmax_new = smem_rowmax + B_ROW;
|
||||
float *smem_rowmax_this = smem_rowmax + 2 * B_ROW;
|
||||
float *smem_rowmax_this = smem_rowmax + B_ROW;
|
||||
|
||||
for (int warp_offset = 0; warp_offset < B_ROW;
|
||||
warp_offset += warps_in_threadblock) {
|
||||
@@ -192,26 +221,34 @@ inline void thread_block_online_softmax(
|
||||
|
||||
// update previous rowmax
|
||||
// i.e. mi_new = max(mi, mij)
|
||||
float prev_rowmax = smem_rowmax_prev[row];
|
||||
float prev_rowmax = smem_rowmax[row];
|
||||
// stage prev rowmax in scratchpad for warp-wide broadcast
|
||||
warp_smem[0] = prev_rowmax;
|
||||
asm volatile("fmax.s %0, %1, %2"
|
||||
: "=f"(rowmax)
|
||||
: "f"(rowmax), "f"(prev_rowmax));
|
||||
smem_rowmax_new[row] = rowmax;
|
||||
smem_rowmax[row] = rowmax;
|
||||
}
|
||||
|
||||
#endif
|
||||
|
||||
// FIXME: unnecessary?
|
||||
threadblock_barrier(threadblock_id_in_cluster,
|
||||
warps_per_threadblock_per_core);
|
||||
|
||||
// broadcast prev rowmax to all threads in the warp
|
||||
// NOTE: memory consistency is a little sketchy here
|
||||
const float rowmax_prev = warp_smem[0];
|
||||
const float rowmax_this = smem_rowmax_this[row];
|
||||
|
||||
// exponential
|
||||
//
|
||||
// B_ROW / (B_ROW * B_COL / (exp_elem * threads_per_threadblock))
|
||||
// const uint32_t row_stride =
|
||||
// (exp_elem_per_thread * threads_per_threadblock) / B_COL;
|
||||
|
||||
// broadcast rowmax to all threads in the warp
|
||||
const float rowmax_new = smem_rowmax_new[row];
|
||||
// broadcast updated rowmax to all threads in the warp
|
||||
const float rowmax_new = smem_rowmax[row];
|
||||
|
||||
// each thread computes two fp32 elements, downconverts it to fp16, then
|
||||
// packs them into one fp32
|
||||
@@ -279,8 +316,9 @@ inline void thread_block_online_softmax(
|
||||
rowsum += other;
|
||||
}
|
||||
|
||||
const float mi_prev = smem_rowmax_prev[row];
|
||||
const float mi_this = smem_rowmax_this[row];
|
||||
const float mi_prev = rowmax_prev;
|
||||
// TODO: replace this with a register?
|
||||
const float mi_this = rowmax_this;
|
||||
|
||||
const float x = mi_prev - mi_this;
|
||||
// 2nd-order Taylor approximation
|
||||
@@ -309,8 +347,8 @@ inline void thread_block_online_softmax(
|
||||
for (int i = 0; i < per_row_iter; i++) {
|
||||
float o = smem_O[thread_offset];
|
||||
|
||||
const float mi_prev = smem_rowmax_prev[row];
|
||||
const float mi_new = smem_rowmax_new[row];
|
||||
const float mi_prev = rowmax_prev;
|
||||
const float mi_new = rowmax_new;
|
||||
|
||||
const float x = mi_prev - mi_new;
|
||||
// 2nd-order Taylor approximation
|
||||
@@ -398,9 +436,10 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
|
||||
// sharedmem "scratchpad" area to put temporary data, e.g. for tree reduction
|
||||
// in rowsum
|
||||
// NOTE: out-of bounds is not checked
|
||||
// TODO: reduce this from B_ROW to NUM_WARPS
|
||||
constexpr uint32_t smem_scratchpad_size =
|
||||
B_ROW * NUM_THREADS * 2 /*arbitrary slack*/;
|
||||
float *smem_scratchpad = smem_rowmax - smem_scratchpad_size;
|
||||
float *smem_scratchpad = smem_rowsum - smem_scratchpad_size;
|
||||
|
||||
const uint32_t warps_per_threadblock_per_core =
|
||||
NUM_WARPS / threadblocks_per_cluster;
|
||||
@@ -414,6 +453,17 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
|
||||
const float *gmem_V = reinterpret_cast<float *>(arg->addr_v);
|
||||
float *gmem_O = reinterpret_cast<float *>(arg->addr_o);
|
||||
|
||||
float *gmem_tmp_d0 = reinterpret_cast<float *>(0xd0000000UL);
|
||||
float *gmem_tmp_d1 = reinterpret_cast<float *>(0xd1000000UL);
|
||||
float *gmem_tmp_d2 = reinterpret_cast<float *>(0xd2000000UL);
|
||||
float *gmem_tmp_d3 = reinterpret_cast<float *>(0xd3000000UL);
|
||||
float *gmem_tmp_d4 = reinterpret_cast<float *>(0xd4000000UL);
|
||||
float *gmem_tmp_d5 = reinterpret_cast<float *>(0xd5000000UL);
|
||||
float *gmem_tmp_e0 = reinterpret_cast<float *>(0xe0000000UL);
|
||||
float *gmem_tmp_e1 = reinterpret_cast<float *>(0xe1000000UL);
|
||||
float *gmem_tmp_e2 = reinterpret_cast<float *>(0xe2000000UL);
|
||||
float *gmem_tmp_e3 = reinterpret_cast<float *>(0xe3000000UL);
|
||||
|
||||
// "inner loop" along the columns of K^T
|
||||
for (uint32_t tile_k = 0; tile_k < (dim_seqlen / B_COL); tile_k++) {
|
||||
|
||||
@@ -469,6 +519,43 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
|
||||
threadblock_barrier(threadblock_id_in_cluster,
|
||||
warps_per_threadblock_per_core);
|
||||
|
||||
if constexpr (DEBUG) {
|
||||
if (tile_k == 0) {
|
||||
thread_block_copy_tile(
|
||||
smem_P, gmem_tmp_d0, tid_in_threadblock, threads_per_threadblock,
|
||||
threadblocks_per_cluster, threadblock_id_in_cluster);
|
||||
thread_block_copy_tile(
|
||||
smem_O, gmem_tmp_d2, tid_in_threadblock, threads_per_threadblock,
|
||||
threadblocks_per_cluster, threadblock_id_in_cluster);
|
||||
thread_block_copy_rowmax(smem_rowmax, gmem_tmp_e0, tid_in_threadblock,
|
||||
threads_per_threadblock,
|
||||
threadblocks_per_cluster,
|
||||
threadblock_id_in_cluster);
|
||||
thread_block_copy_rowmax(smem_rowsum, gmem_tmp_e2, tid_in_threadblock,
|
||||
threads_per_threadblock,
|
||||
threadblocks_per_cluster,
|
||||
threadblock_id_in_cluster);
|
||||
} else if (tile_k == 1) {
|
||||
thread_block_copy_tile(
|
||||
smem_P, gmem_tmp_d1, tid_in_threadblock, threads_per_threadblock,
|
||||
threadblocks_per_cluster, threadblock_id_in_cluster);
|
||||
thread_block_copy_tile(
|
||||
smem_O, gmem_tmp_d3, tid_in_threadblock, threads_per_threadblock,
|
||||
threadblocks_per_cluster, threadblock_id_in_cluster);
|
||||
thread_block_copy_rowmax(smem_rowmax, gmem_tmp_e1, tid_in_threadblock,
|
||||
threads_per_threadblock,
|
||||
threadblocks_per_cluster,
|
||||
threadblock_id_in_cluster);
|
||||
thread_block_copy_rowmax(smem_rowsum, gmem_tmp_e3, tid_in_threadblock,
|
||||
threads_per_threadblock,
|
||||
threadblocks_per_cluster,
|
||||
threadblock_id_in_cluster);
|
||||
}
|
||||
|
||||
threadblock_barrier(threadblock_id_in_cluster,
|
||||
warps_per_threadblock_per_core);
|
||||
}
|
||||
|
||||
// GEMM II: O = O + P*V
|
||||
|
||||
// clear out accumulators
|
||||
@@ -495,18 +582,22 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
|
||||
|
||||
threadblock_barrier(threadblock_id_in_cluster,
|
||||
warps_per_threadblock_per_core);
|
||||
|
||||
if constexpr (DEBUG) {
|
||||
if (tile_k == 0) {
|
||||
thread_block_copy_tile(
|
||||
smem_O, gmem_tmp_d4, tid_in_threadblock, threads_per_threadblock,
|
||||
threadblocks_per_cluster, threadblock_id_in_cluster);
|
||||
} else if (tile_k == 1) {
|
||||
thread_block_copy_tile(
|
||||
smem_O, gmem_tmp_d5, tid_in_threadblock, threads_per_threadblock,
|
||||
threadblocks_per_cluster, threadblock_id_in_cluster);
|
||||
}
|
||||
|
||||
threadblock_barrier(threadblock_id_in_cluster,
|
||||
warps_per_threadblock_per_core);
|
||||
}
|
||||
}
|
||||
|
||||
float *gmem_tmp0 = reinterpret_cast<float *>(0xd0000000UL);
|
||||
float *gmem_tmp1 = reinterpret_cast<float *>(0xe0000000UL);
|
||||
|
||||
// copy out tile data to GMEM for debugging
|
||||
thread_block_copy_data(smem_P, gmem_tmp0, tid_in_threadblock,
|
||||
threads_per_threadblock, threadblocks_per_cluster,
|
||||
threadblock_id_in_cluster);
|
||||
thread_block_copy_data(smem_O, gmem_tmp1, tid_in_threadblock,
|
||||
threads_per_threadblock, threadblocks_per_cluster,
|
||||
threadblock_id_in_cluster);
|
||||
}
|
||||
|
||||
int main() {
|
||||
|
||||
Reference in New Issue
Block a user