flash: Add DOUBLE_BUF compile-time param (wip)

This commit is contained in:
Hansung Kim
2024-08-29 14:18:32 -07:00
parent 5ba06dfd9d
commit fd1ab358fa

View File

@@ -14,6 +14,7 @@
#define HEADDIM B_COL
constexpr bool DEBUG = true;
constexpr bool DOUBLE_BUF = false;
inline void thread_block_init_sharedmem(const uint32_t tid_in_threadblock,
const uint32_t threads_per_threadblock,
@@ -26,12 +27,14 @@ inline void thread_block_init_sharedmem(const uint32_t tid_in_threadblock,
static_assert((B_ROW % NUM_THREADS) == 0,
"B_ROW must be a multiple of NUM_THREADS");
// FIXME: this shouldn't be necessary
static_assert(B_ROW < (NUM_THREADS * CORES_PER_CLUSTER * NUM_WARPS),
static_assert(B_ROW < (NUM_THREADS * CORES_PER_CLUSTER *
(NUM_WARPS / (DOUBLE_BUF ? 2 : 1))),
"not enough warps to initialize rowmax/rowsum");
constexpr uint32_t num_warps = B_ROW / NUM_THREADS;
if (warp_id < num_warps) {
// each thread initializes one element in rowmax/rowsum
// multiple warps participate for the whole vector
constexpr uint32_t needed_warps = B_ROW / NUM_THREADS;
if (warp_id < needed_warps /* more warps in HW than needed? */) {
uint32_t offset = NUM_THREADS * warp_id + tid_in_warp;
// mi, mi~, minew
smem_rowmax[offset] = FLT_MIN;
@@ -40,10 +43,10 @@ inline void thread_block_init_sharedmem(const uint32_t tid_in_threadblock,
smem_rowsum[offset] = 0.0f;
}
// each warp clears out a row of smem_O
// FIXME: dedup this pattern
for (int warp_offset = 0; warp_offset < B_COL;
warp_offset += warps_in_threadblock) {
// each warp clears out a row of smem_O
const uint32_t row = warp_offset + warp_id;
uint32_t thread_offset = HEADDIM * row + tid_in_warp;
constexpr uint32_t per_row_iter = HEADDIM / NUM_THREADS;
@@ -58,7 +61,6 @@ inline void thread_block_init_sharedmem(const uint32_t tid_in_threadblock,
inline void thread_block_copy_rowmax(const float *src, float *dest,
const uint32_t tid_in_threadblock,
const uint32_t threads_per_threadblock,
const uint32_t threadblocks_per_cluster,
const uint32_t threadblock_id_in_cluster) {
asm volatile("threadblock_copy_rowmax_start_%=:" ::);
@@ -66,8 +68,10 @@ inline void thread_block_copy_rowmax(const float *src, float *dest,
const uint32_t warp_id = tid_in_threadblock / NUM_THREADS;
const uint32_t warps_in_threadblock = threads_per_threadblock / NUM_THREADS;
const uint32_t warps_per_threadblock_per_core =
NUM_WARPS / threadblocks_per_cluster;
warps_in_threadblock / CORES_PER_CLUSTER;
// each thread copies one element in rowmax
// multiple warps participate for the whole vector
constexpr uint32_t num_warps = B_ROW / NUM_THREADS;
if (warp_id < num_warps) {
uint32_t offset = NUM_THREADS * warp_id + tid_in_warp;
@@ -83,7 +87,6 @@ inline void thread_block_copy_rowmax(const float *src, float *dest,
inline void thread_block_copy_tile(const float *src, float *dest,
const uint32_t tid_in_threadblock,
const uint32_t threads_per_threadblock,
const uint32_t threadblocks_per_cluster,
const uint32_t threadblock_id_in_cluster) {
asm volatile("threadblock_copy_tile_start_%=:" ::);
@@ -91,7 +94,7 @@ inline void thread_block_copy_tile(const float *src, float *dest,
const uint32_t warp_id = tid_in_threadblock / NUM_THREADS;
const uint32_t warps_in_threadblock = threads_per_threadblock / NUM_THREADS;
const uint32_t warps_per_threadblock_per_core =
NUM_WARPS / threadblocks_per_cluster;
warps_in_threadblock / CORES_PER_CLUSTER;
// FIXME: dedup this pattern
for (int warp_offset = 0; warp_offset < B_ROW;
@@ -138,7 +141,6 @@ inline float exponential_taylor_term(const float x) {
__attribute__((always_inline)) inline void thread_block_online_softmax(
const float *smem_S, float *smem_O, float *smem_P,
const uint32_t tid_in_threadblock, const uint32_t threads_per_threadblock,
const uint32_t threadblocks_per_cluster,
const uint32_t threadblock_id_in_cluster, float *smem_scratchpad,
float *smem_rowmax, float *smem_rowsum) {
asm volatile("thread_block_online_softmax_start_%=:" ::);
@@ -147,7 +149,7 @@ __attribute__((always_inline)) inline void thread_block_online_softmax(
const uint32_t warp_id = tid_in_threadblock / NUM_THREADS;
const uint32_t warps_in_threadblock = threads_per_threadblock / NUM_THREADS;
const uint32_t warps_per_threadblock_per_core =
NUM_WARPS / threadblocks_per_cluster;
warps_in_threadblock / CORES_PER_CLUSTER;
// float ft[8];
// asm volatile("fmv.s %0, f16" : "=f"(ft[0]));
@@ -402,7 +404,8 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
#endif
// FIXME: headdim not considered
uint32_t threads_per_threadblock = (B_ROW * B_COL) / (ELEM_PER_THREAD);
uint32_t threads_per_threadblock =
(B_ROW * B_COL) / (ELEM_PER_THREAD) / (DOUBLE_BUF ? 2 : 1);
const uint32_t hw_threads_per_cluster =
cores_per_cluster * vx_num_threads() * vx_num_warps();
// cap maximum threadblock size to # of HW threads in cluster, to prevent
@@ -418,6 +421,11 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
threadblock_id % threadblocks_per_cluster;
const int tid_in_threadblock = task_id % threads_per_threadblock;
// FIXME do proper software pipelining
if (DOUBLE_BUF && threadblock_id != 0) {
return;
}
const uint32_t dim_seqlen = arg->dim_seqlen;
const uint32_t dim_headdim = arg->dim_headdim;
@@ -528,7 +536,8 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
HEADDIM>(dim_seqlen, 0, tile_k, gmem_Q /*=gmem_S*/,
smem_S, tid_in_threadblock);
// the above should be equivalent to:
// load_tile_to_smem<float, MemLayout::MN_major, MemLayout::MN_major, B_COL,
// load_tile_to_smem<float, MemLayout::MN_major, MemLayout::MN_major,
// B_COL,
// HEADDIM>(dim_seqlen, tile_k, 0, gmem_Q /*=gmem_S*/,
// smem_S, tid_in_threadblock);
@@ -541,7 +550,7 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
thread_block_online_softmax(
smem_S, smem_O, smem_P, tid_in_threadblock, threads_per_threadblock,
threadblocks_per_cluster, threadblock_id_in_cluster, smem_scratchpad,
threadblock_id_in_cluster, smem_scratchpad,
smem_rowmax, smem_rowsum);
// FIXME unnecessary?
@@ -550,34 +559,30 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
if constexpr (DEBUG) {
if (tile_k == 0) {
thread_block_copy_tile(
smem_P, gmem_tmp_d0, tid_in_threadblock, threads_per_threadblock,
threadblocks_per_cluster, threadblock_id_in_cluster);
thread_block_copy_tile(
smem_O, gmem_tmp_d2, tid_in_threadblock, threads_per_threadblock,
threadblocks_per_cluster, threadblock_id_in_cluster);
thread_block_copy_tile(smem_P, gmem_tmp_d0, tid_in_threadblock,
threads_per_threadblock,
threadblock_id_in_cluster);
thread_block_copy_tile(smem_O, gmem_tmp_d2, tid_in_threadblock,
threads_per_threadblock,
threadblock_id_in_cluster);
thread_block_copy_rowmax(smem_rowmax, gmem_tmp_e0, tid_in_threadblock,
threads_per_threadblock,
threadblocks_per_cluster,
threadblock_id_in_cluster);
thread_block_copy_rowmax(smem_rowsum, gmem_tmp_e2, tid_in_threadblock,
threads_per_threadblock,
threadblocks_per_cluster,
threadblock_id_in_cluster);
} else if (tile_k == k_tiles - 1) {
thread_block_copy_tile(
smem_P, gmem_tmp_d1, tid_in_threadblock, threads_per_threadblock,
threadblocks_per_cluster, threadblock_id_in_cluster);
thread_block_copy_tile(
smem_O, gmem_tmp_d3, tid_in_threadblock, threads_per_threadblock,
threadblocks_per_cluster, threadblock_id_in_cluster);
thread_block_copy_tile(smem_P, gmem_tmp_d1, tid_in_threadblock,
threads_per_threadblock,
threadblock_id_in_cluster);
thread_block_copy_tile(smem_O, gmem_tmp_d3, tid_in_threadblock,
threads_per_threadblock,
threadblock_id_in_cluster);
thread_block_copy_rowmax(smem_rowmax, gmem_tmp_e1, tid_in_threadblock,
threads_per_threadblock,
threadblocks_per_cluster,
threadblock_id_in_cluster);
thread_block_copy_rowmax(smem_rowsum, gmem_tmp_e3, tid_in_threadblock,
threads_per_threadblock,
threadblocks_per_cluster,
threadblock_id_in_cluster);
}
@@ -601,12 +606,11 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
warps_per_threadblock_per_core);
thread_block_gemm_single_tile<float, MemLayout::K_major,
MemLayout::MN_major,
B_ROW, HEADDIM, B_COL,
MemLayout::MN_major, B_ROW, HEADDIM, B_COL,
/*load_accum=*/true,
/*write_to_smem=*/true>(
smem_P, smem_V, smem_O /*load accum*/, smem_O,
tid_in_threadblock, threads_per_threadblock, threadblocks_per_cluster,
smem_P, smem_V, smem_O /*load accum*/, smem_O, tid_in_threadblock,
threads_per_threadblock, threadblocks_per_cluster,
threadblock_id_in_cluster);
// FIXME: wrong but fast
// thread_block_gemm_single_tile<float, MemLayout::MN_major,
@@ -623,13 +627,13 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
if constexpr (DEBUG) {
if (tile_k == 0) {
thread_block_copy_tile(
smem_O, gmem_tmp_d4, tid_in_threadblock, threads_per_threadblock,
threadblocks_per_cluster, threadblock_id_in_cluster);
thread_block_copy_tile(smem_O, gmem_tmp_d4, tid_in_threadblock,
threads_per_threadblock,
threadblock_id_in_cluster);
} else if (tile_k == k_tiles - 1) {
thread_block_copy_tile(
smem_O, gmem_tmp_d5, tid_in_threadblock, threads_per_threadblock,
threadblocks_per_cluster, threadblock_id_in_cluster);
thread_block_copy_tile(smem_O, gmem_tmp_d5, tid_in_threadblock,
threads_per_threadblock,
threadblock_id_in_cluster);
}
threadblock_barrier(threadblock_id_in_cluster,