flash: Add DOUBLE_BUF compile-time param (wip)
This commit is contained in:
@@ -14,6 +14,7 @@
|
||||
#define HEADDIM B_COL
|
||||
|
||||
constexpr bool DEBUG = true;
|
||||
constexpr bool DOUBLE_BUF = false;
|
||||
|
||||
inline void thread_block_init_sharedmem(const uint32_t tid_in_threadblock,
|
||||
const uint32_t threads_per_threadblock,
|
||||
@@ -26,12 +27,14 @@ inline void thread_block_init_sharedmem(const uint32_t tid_in_threadblock,
|
||||
|
||||
static_assert((B_ROW % NUM_THREADS) == 0,
|
||||
"B_ROW must be a multiple of NUM_THREADS");
|
||||
// FIXME: this shouldn't be necessary
|
||||
static_assert(B_ROW < (NUM_THREADS * CORES_PER_CLUSTER * NUM_WARPS),
|
||||
static_assert(B_ROW < (NUM_THREADS * CORES_PER_CLUSTER *
|
||||
(NUM_WARPS / (DOUBLE_BUF ? 2 : 1))),
|
||||
"not enough warps to initialize rowmax/rowsum");
|
||||
|
||||
constexpr uint32_t num_warps = B_ROW / NUM_THREADS;
|
||||
if (warp_id < num_warps) {
|
||||
// each thread initializes one element in rowmax/rowsum
|
||||
// multiple warps participate for the whole vector
|
||||
constexpr uint32_t needed_warps = B_ROW / NUM_THREADS;
|
||||
if (warp_id < needed_warps /* more warps in HW than needed? */) {
|
||||
uint32_t offset = NUM_THREADS * warp_id + tid_in_warp;
|
||||
// mi, mi~, minew
|
||||
smem_rowmax[offset] = FLT_MIN;
|
||||
@@ -40,10 +43,10 @@ inline void thread_block_init_sharedmem(const uint32_t tid_in_threadblock,
|
||||
smem_rowsum[offset] = 0.0f;
|
||||
}
|
||||
|
||||
// each warp clears out a row of smem_O
|
||||
// FIXME: dedup this pattern
|
||||
for (int warp_offset = 0; warp_offset < B_COL;
|
||||
warp_offset += warps_in_threadblock) {
|
||||
// each warp clears out a row of smem_O
|
||||
const uint32_t row = warp_offset + warp_id;
|
||||
uint32_t thread_offset = HEADDIM * row + tid_in_warp;
|
||||
constexpr uint32_t per_row_iter = HEADDIM / NUM_THREADS;
|
||||
@@ -58,7 +61,6 @@ inline void thread_block_init_sharedmem(const uint32_t tid_in_threadblock,
|
||||
inline void thread_block_copy_rowmax(const float *src, float *dest,
|
||||
const uint32_t tid_in_threadblock,
|
||||
const uint32_t threads_per_threadblock,
|
||||
const uint32_t threadblocks_per_cluster,
|
||||
const uint32_t threadblock_id_in_cluster) {
|
||||
asm volatile("threadblock_copy_rowmax_start_%=:" ::);
|
||||
|
||||
@@ -66,8 +68,10 @@ inline void thread_block_copy_rowmax(const float *src, float *dest,
|
||||
const uint32_t warp_id = tid_in_threadblock / NUM_THREADS;
|
||||
const uint32_t warps_in_threadblock = threads_per_threadblock / NUM_THREADS;
|
||||
const uint32_t warps_per_threadblock_per_core =
|
||||
NUM_WARPS / threadblocks_per_cluster;
|
||||
warps_in_threadblock / CORES_PER_CLUSTER;
|
||||
|
||||
// each thread copies one element in rowmax
|
||||
// multiple warps participate for the whole vector
|
||||
constexpr uint32_t num_warps = B_ROW / NUM_THREADS;
|
||||
if (warp_id < num_warps) {
|
||||
uint32_t offset = NUM_THREADS * warp_id + tid_in_warp;
|
||||
@@ -83,7 +87,6 @@ inline void thread_block_copy_rowmax(const float *src, float *dest,
|
||||
inline void thread_block_copy_tile(const float *src, float *dest,
|
||||
const uint32_t tid_in_threadblock,
|
||||
const uint32_t threads_per_threadblock,
|
||||
const uint32_t threadblocks_per_cluster,
|
||||
const uint32_t threadblock_id_in_cluster) {
|
||||
asm volatile("threadblock_copy_tile_start_%=:" ::);
|
||||
|
||||
@@ -91,7 +94,7 @@ inline void thread_block_copy_tile(const float *src, float *dest,
|
||||
const uint32_t warp_id = tid_in_threadblock / NUM_THREADS;
|
||||
const uint32_t warps_in_threadblock = threads_per_threadblock / NUM_THREADS;
|
||||
const uint32_t warps_per_threadblock_per_core =
|
||||
NUM_WARPS / threadblocks_per_cluster;
|
||||
warps_in_threadblock / CORES_PER_CLUSTER;
|
||||
|
||||
// FIXME: dedup this pattern
|
||||
for (int warp_offset = 0; warp_offset < B_ROW;
|
||||
@@ -138,7 +141,6 @@ inline float exponential_taylor_term(const float x) {
|
||||
__attribute__((always_inline)) inline void thread_block_online_softmax(
|
||||
const float *smem_S, float *smem_O, float *smem_P,
|
||||
const uint32_t tid_in_threadblock, const uint32_t threads_per_threadblock,
|
||||
const uint32_t threadblocks_per_cluster,
|
||||
const uint32_t threadblock_id_in_cluster, float *smem_scratchpad,
|
||||
float *smem_rowmax, float *smem_rowsum) {
|
||||
asm volatile("thread_block_online_softmax_start_%=:" ::);
|
||||
@@ -147,7 +149,7 @@ __attribute__((always_inline)) inline void thread_block_online_softmax(
|
||||
const uint32_t warp_id = tid_in_threadblock / NUM_THREADS;
|
||||
const uint32_t warps_in_threadblock = threads_per_threadblock / NUM_THREADS;
|
||||
const uint32_t warps_per_threadblock_per_core =
|
||||
NUM_WARPS / threadblocks_per_cluster;
|
||||
warps_in_threadblock / CORES_PER_CLUSTER;
|
||||
|
||||
// float ft[8];
|
||||
// asm volatile("fmv.s %0, f16" : "=f"(ft[0]));
|
||||
@@ -402,7 +404,8 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
|
||||
#endif
|
||||
|
||||
// FIXME: headdim not considered
|
||||
uint32_t threads_per_threadblock = (B_ROW * B_COL) / (ELEM_PER_THREAD);
|
||||
uint32_t threads_per_threadblock =
|
||||
(B_ROW * B_COL) / (ELEM_PER_THREAD) / (DOUBLE_BUF ? 2 : 1);
|
||||
const uint32_t hw_threads_per_cluster =
|
||||
cores_per_cluster * vx_num_threads() * vx_num_warps();
|
||||
// cap maximum threadblock size to # of HW threads in cluster, to prevent
|
||||
@@ -418,6 +421,11 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
|
||||
threadblock_id % threadblocks_per_cluster;
|
||||
const int tid_in_threadblock = task_id % threads_per_threadblock;
|
||||
|
||||
// FIXME do proper software pipelining
|
||||
if (DOUBLE_BUF && threadblock_id != 0) {
|
||||
return;
|
||||
}
|
||||
|
||||
const uint32_t dim_seqlen = arg->dim_seqlen;
|
||||
const uint32_t dim_headdim = arg->dim_headdim;
|
||||
|
||||
@@ -528,7 +536,8 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
|
||||
HEADDIM>(dim_seqlen, 0, tile_k, gmem_Q /*=gmem_S*/,
|
||||
smem_S, tid_in_threadblock);
|
||||
// the above should be equivalent to:
|
||||
// load_tile_to_smem<float, MemLayout::MN_major, MemLayout::MN_major, B_COL,
|
||||
// load_tile_to_smem<float, MemLayout::MN_major, MemLayout::MN_major,
|
||||
// B_COL,
|
||||
// HEADDIM>(dim_seqlen, tile_k, 0, gmem_Q /*=gmem_S*/,
|
||||
// smem_S, tid_in_threadblock);
|
||||
|
||||
@@ -541,7 +550,7 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
|
||||
|
||||
thread_block_online_softmax(
|
||||
smem_S, smem_O, smem_P, tid_in_threadblock, threads_per_threadblock,
|
||||
threadblocks_per_cluster, threadblock_id_in_cluster, smem_scratchpad,
|
||||
threadblock_id_in_cluster, smem_scratchpad,
|
||||
smem_rowmax, smem_rowsum);
|
||||
|
||||
// FIXME unnecessary?
|
||||
@@ -550,34 +559,30 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
|
||||
|
||||
if constexpr (DEBUG) {
|
||||
if (tile_k == 0) {
|
||||
thread_block_copy_tile(
|
||||
smem_P, gmem_tmp_d0, tid_in_threadblock, threads_per_threadblock,
|
||||
threadblocks_per_cluster, threadblock_id_in_cluster);
|
||||
thread_block_copy_tile(
|
||||
smem_O, gmem_tmp_d2, tid_in_threadblock, threads_per_threadblock,
|
||||
threadblocks_per_cluster, threadblock_id_in_cluster);
|
||||
thread_block_copy_tile(smem_P, gmem_tmp_d0, tid_in_threadblock,
|
||||
threads_per_threadblock,
|
||||
threadblock_id_in_cluster);
|
||||
thread_block_copy_tile(smem_O, gmem_tmp_d2, tid_in_threadblock,
|
||||
threads_per_threadblock,
|
||||
threadblock_id_in_cluster);
|
||||
thread_block_copy_rowmax(smem_rowmax, gmem_tmp_e0, tid_in_threadblock,
|
||||
threads_per_threadblock,
|
||||
threadblocks_per_cluster,
|
||||
threadblock_id_in_cluster);
|
||||
thread_block_copy_rowmax(smem_rowsum, gmem_tmp_e2, tid_in_threadblock,
|
||||
threads_per_threadblock,
|
||||
threadblocks_per_cluster,
|
||||
threadblock_id_in_cluster);
|
||||
} else if (tile_k == k_tiles - 1) {
|
||||
thread_block_copy_tile(
|
||||
smem_P, gmem_tmp_d1, tid_in_threadblock, threads_per_threadblock,
|
||||
threadblocks_per_cluster, threadblock_id_in_cluster);
|
||||
thread_block_copy_tile(
|
||||
smem_O, gmem_tmp_d3, tid_in_threadblock, threads_per_threadblock,
|
||||
threadblocks_per_cluster, threadblock_id_in_cluster);
|
||||
thread_block_copy_tile(smem_P, gmem_tmp_d1, tid_in_threadblock,
|
||||
threads_per_threadblock,
|
||||
threadblock_id_in_cluster);
|
||||
thread_block_copy_tile(smem_O, gmem_tmp_d3, tid_in_threadblock,
|
||||
threads_per_threadblock,
|
||||
threadblock_id_in_cluster);
|
||||
thread_block_copy_rowmax(smem_rowmax, gmem_tmp_e1, tid_in_threadblock,
|
||||
threads_per_threadblock,
|
||||
threadblocks_per_cluster,
|
||||
threadblock_id_in_cluster);
|
||||
thread_block_copy_rowmax(smem_rowsum, gmem_tmp_e3, tid_in_threadblock,
|
||||
threads_per_threadblock,
|
||||
threadblocks_per_cluster,
|
||||
threadblock_id_in_cluster);
|
||||
}
|
||||
|
||||
@@ -601,12 +606,11 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
|
||||
warps_per_threadblock_per_core);
|
||||
|
||||
thread_block_gemm_single_tile<float, MemLayout::K_major,
|
||||
MemLayout::MN_major,
|
||||
B_ROW, HEADDIM, B_COL,
|
||||
MemLayout::MN_major, B_ROW, HEADDIM, B_COL,
|
||||
/*load_accum=*/true,
|
||||
/*write_to_smem=*/true>(
|
||||
smem_P, smem_V, smem_O /*load accum*/, smem_O,
|
||||
tid_in_threadblock, threads_per_threadblock, threadblocks_per_cluster,
|
||||
smem_P, smem_V, smem_O /*load accum*/, smem_O, tid_in_threadblock,
|
||||
threads_per_threadblock, threadblocks_per_cluster,
|
||||
threadblock_id_in_cluster);
|
||||
// FIXME: wrong but fast
|
||||
// thread_block_gemm_single_tile<float, MemLayout::MN_major,
|
||||
@@ -623,13 +627,13 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
|
||||
|
||||
if constexpr (DEBUG) {
|
||||
if (tile_k == 0) {
|
||||
thread_block_copy_tile(
|
||||
smem_O, gmem_tmp_d4, tid_in_threadblock, threads_per_threadblock,
|
||||
threadblocks_per_cluster, threadblock_id_in_cluster);
|
||||
thread_block_copy_tile(smem_O, gmem_tmp_d4, tid_in_threadblock,
|
||||
threads_per_threadblock,
|
||||
threadblock_id_in_cluster);
|
||||
} else if (tile_k == k_tiles - 1) {
|
||||
thread_block_copy_tile(
|
||||
smem_O, gmem_tmp_d5, tid_in_threadblock, threads_per_threadblock,
|
||||
threadblocks_per_cluster, threadblock_id_in_cluster);
|
||||
thread_block_copy_tile(smem_O, gmem_tmp_d5, tid_in_threadblock,
|
||||
threads_per_threadblock,
|
||||
threadblock_id_in_cluster);
|
||||
}
|
||||
|
||||
threadblock_barrier(threadblock_id_in_cluster,
|
||||
|
||||
Reference in New Issue
Block a user