flash: Add missing accum reg init and fix barrier count
This commit is contained in:
@@ -54,8 +54,9 @@ inline void thread_block_init_sharedmem(const uint32_t tid_in_threadblock,
|
|||||||
}
|
}
|
||||||
|
|
||||||
inline void thread_block_online_softmax(
|
inline void thread_block_online_softmax(
|
||||||
float *smem_S, float *smem_O, float *smem_P,
|
const float *smem_S, float *smem_O, float *smem_P,
|
||||||
const uint32_t tid_in_threadblock, const uint32_t threads_per_threadblock,
|
const uint32_t tid_in_threadblock, const uint32_t threads_per_threadblock,
|
||||||
|
const uint32_t threadblocks_per_cluster,
|
||||||
const uint32_t threadblock_id_in_cluster, float *smem_scratchpad,
|
const uint32_t threadblock_id_in_cluster, float *smem_scratchpad,
|
||||||
float *smem_rowmax, float *smem_rowsum) {
|
float *smem_rowmax, float *smem_rowsum) {
|
||||||
asm volatile("thread_block_flashattn_start_%=:" ::);
|
asm volatile("thread_block_flashattn_start_%=:" ::);
|
||||||
@@ -64,7 +65,7 @@ inline void thread_block_online_softmax(
|
|||||||
const uint32_t warp_id = tid_in_threadblock / NUM_THREADS;
|
const uint32_t warp_id = tid_in_threadblock / NUM_THREADS;
|
||||||
const uint32_t warps_in_threadblock = threads_per_threadblock / NUM_THREADS;
|
const uint32_t warps_in_threadblock = threads_per_threadblock / NUM_THREADS;
|
||||||
const uint32_t warps_per_threadblock_per_core =
|
const uint32_t warps_per_threadblock_per_core =
|
||||||
NUM_WARPS / threads_per_threadblock;
|
NUM_WARPS / threadblocks_per_cluster;
|
||||||
|
|
||||||
// float ft[8];
|
// float ft[8];
|
||||||
// asm volatile("fmv.s %0, f16" : "=f"(ft[0]));
|
// asm volatile("fmv.s %0, f16" : "=f"(ft[0]));
|
||||||
@@ -148,7 +149,6 @@ inline void thread_block_online_softmax(
|
|||||||
: "=f"(rowmax)
|
: "=f"(rowmax)
|
||||||
: "f"(rowmax), "f"(prev_rowmax));
|
: "f"(rowmax), "f"(prev_rowmax));
|
||||||
smem_rowmax_new[row] = rowmax;
|
smem_rowmax_new[row] = rowmax;
|
||||||
gmem_tmp0[row] = rowmax;
|
|
||||||
}
|
}
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
@@ -177,18 +177,16 @@ inline void thread_block_online_softmax(
|
|||||||
#pragma GCC unroll
|
#pragma GCC unroll
|
||||||
for (int i = 0; i < exp_per_row_iter; i++) {
|
for (int i = 0; i < exp_per_row_iter; i++) {
|
||||||
float f0 = smem_S[thread_offset];
|
float f0 = smem_S[thread_offset];
|
||||||
// float f1 = S[thread_offset + 1];
|
|
||||||
|
// check Q*K result
|
||||||
|
gmem_tmp0[thread_offset] = f0;;
|
||||||
|
|
||||||
// FIXME: placeholder for proper exp
|
// FIXME: placeholder for proper exp
|
||||||
f0 -= rowmax_new;
|
f0 -= rowmax_new;
|
||||||
// f1 -= rowmax_new;
|
|
||||||
// float16_t h0 = NN_float_to_half(f0);
|
|
||||||
// float16_t h1 = NN_float_to_half(f1);
|
|
||||||
|
|
||||||
// Store S transposed to the shared memory
|
// Store S transposed to the shared memory
|
||||||
|
|
||||||
smem_P[thread_offset] = f0;
|
smem_P[thread_offset] = f0;
|
||||||
// S[thread_offset + 1] = f1;
|
|
||||||
gmem_tmp1[thread_offset] = f0;
|
gmem_tmp1[thread_offset] = f0;
|
||||||
|
|
||||||
thread_offset += NUM_THREADS;
|
thread_offset += NUM_THREADS;
|
||||||
@@ -261,7 +259,6 @@ inline void thread_block_online_softmax(
|
|||||||
|
|
||||||
threadblock_barrier(threadblock_id_in_cluster,
|
threadblock_barrier(threadblock_id_in_cluster,
|
||||||
warps_per_threadblock_per_core);
|
warps_per_threadblock_per_core);
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
asm volatile("thread_block_flashattn_finish_%=:" ::);
|
asm volatile("thread_block_flashattn_finish_%=:" ::);
|
||||||
@@ -299,15 +296,19 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
|
|||||||
|
|
||||||
// "static" shared memory allocation. This would determine maximum
|
// "static" shared memory allocation. This would determine maximum
|
||||||
// threadblock occupancy in a cluster
|
// threadblock occupancy in a cluster
|
||||||
const uint32_t smem_QK_size = B_ROW * B_COL;
|
constexpr uint32_t smem_Q_size = B_ROW * HEADDIM;
|
||||||
const uint32_t smem_V_size = B_COL * HEADDIM;
|
constexpr uint32_t smem_QK_size = B_ROW * B_COL;
|
||||||
const uint32_t smem_O_size = B_COL * HEADDIM;
|
constexpr uint32_t smem_V_size = B_COL * HEADDIM;
|
||||||
|
constexpr uint32_t smem_O_size = B_COL * HEADDIM;
|
||||||
uint8_t *smem_per_threadblock = reinterpret_cast<uint8_t *>(
|
uint8_t *smem_per_threadblock = reinterpret_cast<uint8_t *>(
|
||||||
DEV_SMEM_START_ADDR +
|
DEV_SMEM_START_ADDR +
|
||||||
sizeof(float_type) *
|
sizeof(float_type) *
|
||||||
(smem_QK_size + smem_V_size + smem_O_size) *
|
(smem_QK_size + smem_V_size + smem_O_size) *
|
||||||
threadblock_id_in_cluster);
|
threadblock_id_in_cluster);
|
||||||
|
|
||||||
|
float *smem_Q = reinterpret_cast<float *>(smem_per_threadblock);
|
||||||
|
float *smem_K = smem_Q + smem_Q_size;
|
||||||
|
// in-place multiplication of QK into Q
|
||||||
float *smem_S = reinterpret_cast<float *>(smem_per_threadblock);
|
float *smem_S = reinterpret_cast<float *>(smem_per_threadblock);
|
||||||
float *smem_P = smem_S; // in-place update from S to P
|
float *smem_P = smem_S; // in-place update from S to P
|
||||||
float *smem_V =
|
float *smem_V =
|
||||||
@@ -330,42 +331,73 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
|
|||||||
float *smem_scratchpad = smem_rowmax - smem_scratchpad_size;
|
float *smem_scratchpad = smem_rowmax - smem_scratchpad_size;
|
||||||
|
|
||||||
const uint32_t warps_per_threadblock_per_core =
|
const uint32_t warps_per_threadblock_per_core =
|
||||||
NUM_WARPS / threads_per_threadblock;
|
NUM_WARPS / threadblocks_per_cluster;
|
||||||
|
|
||||||
// initialize rowmax/rowsum values in sharedmem
|
// initialize rowmax/rowsum values in sharedmem
|
||||||
thread_block_init_sharedmem(tid_in_threadblock, threads_per_threadblock,
|
thread_block_init_sharedmem(tid_in_threadblock, threads_per_threadblock,
|
||||||
smem_O, smem_rowmax, smem_rowsum);
|
smem_O, smem_rowmax, smem_rowsum);
|
||||||
|
|
||||||
#define SKIP_GEMM
|
const float *gmem_Q = reinterpret_cast<float *>(arg->addr_q);
|
||||||
|
const float *gmem_K = reinterpret_cast<float *>(arg->addr_k);
|
||||||
|
const float *gmem_V = reinterpret_cast<float *>(arg->addr_v);
|
||||||
|
float *gmem_O = reinterpret_cast<float *>(arg->addr_o);
|
||||||
|
float *gmem_tmp0 = reinterpret_cast<float *>(0xd0000000UL);
|
||||||
|
|
||||||
|
// #define SKIP_GEMM
|
||||||
#ifndef SKIP_GEMM
|
#ifndef SKIP_GEMM
|
||||||
|
#if 0
|
||||||
thread_block_gemm<float_type, /*write_to_gmem=*/true>(
|
thread_block_gemm<float_type, /*write_to_gmem=*/true>(
|
||||||
(const float_type *)arg->addr_q, (const float_type *)arg->addr_k,
|
(const float_type *)arg->addr_q, (const float_type *)arg->addr_k,
|
||||||
(float *)smem_S /*write result to SMEM */, arg->dim_m, arg->dim_n,
|
(float *)smem_S /*write result to SMEM */, B_ROW, B_COL,
|
||||||
arg->dim_k, tid_in_threadblock, threads_per_threadblock,
|
HEADDIM, tid_in_threadblock, threads_per_threadblock,
|
||||||
threadblocks_per_cluster, threadblock_id_in_cluster,
|
threadblocks_per_cluster, threadblock_id_in_cluster,
|
||||||
smem_per_threadblock);
|
smem_per_threadblock);
|
||||||
|
|
||||||
// protect writes of GEMM results before softmax
|
#else
|
||||||
|
|
||||||
|
// clear out accumulators
|
||||||
|
initialize_accum_regs<0>();
|
||||||
|
initialize_accum_regs<1>();
|
||||||
|
|
||||||
|
// load Q
|
||||||
|
static_assert(B_ROW == B_COL, "currently only supports square tiles");
|
||||||
|
load_tile_to_smem<float, MemLayout::MN_major, MemLayout::MN_major, B_ROW,
|
||||||
|
HEADDIM>(B_ROW, 0, 0, gmem_Q, smem_Q, tid_in_threadblock);
|
||||||
|
// load K
|
||||||
|
load_tile_to_smem<float, MemLayout::MN_major, MemLayout::MN_major, B_COL,
|
||||||
|
HEADDIM>(B_COL, 0, 0, gmem_K, smem_K, tid_in_threadblock);
|
||||||
|
|
||||||
|
// GMEM->SMEM and compute barrier
|
||||||
threadblock_barrier(threadblock_id_in_cluster,
|
threadblock_barrier(threadblock_id_in_cluster,
|
||||||
warps_per_threadblock_per_core);
|
warps_per_threadblock_per_core);
|
||||||
|
|
||||||
float *tile_S = (float *)smem_S;
|
thread_block_gemm_single_tile<float, MemLayout::MN_major, MemLayout::MN_major,
|
||||||
|
/*write_to_smem=*/true>(
|
||||||
|
smem_Q, smem_K, smem_S, tid_in_threadblock, threads_per_threadblock,
|
||||||
|
threadblocks_per_cluster, threadblock_id_in_cluster);
|
||||||
|
#endif
|
||||||
|
|
||||||
|
// protect GEMM result writes (smem_S) before softmax
|
||||||
|
threadblock_barrier(threadblock_id_in_cluster,
|
||||||
|
warps_per_threadblock_per_core);
|
||||||
|
|
||||||
|
const float *tile_S = (float *)smem_S;
|
||||||
#else
|
#else
|
||||||
float *tile_S = (float *)arg->addr_q;
|
float *tile_S = (float *)arg->addr_q;
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
// FIXME: V is stored in d0000000 for debugging purpose
|
thread_block_online_softmax(tile_S, smem_O, smem_P, tid_in_threadblock,
|
||||||
const float *gmem_V = reinterpret_cast<float *>(arg->addr_k);
|
threads_per_threadblock, threadblocks_per_cluster,
|
||||||
|
threadblock_id_in_cluster, smem_scratchpad,
|
||||||
thread_block_online_softmax(
|
smem_rowmax, smem_rowsum);
|
||||||
tile_S, smem_O, smem_P, tid_in_threadblock, threads_per_threadblock,
|
|
||||||
threadblock_id_in_cluster, smem_scratchpad, smem_rowmax, smem_rowsum);
|
|
||||||
|
|
||||||
// FIXME unnecessary?
|
// FIXME unnecessary?
|
||||||
threadblock_barrier(threadblock_id_in_cluster,
|
threadblock_barrier(threadblock_id_in_cluster,
|
||||||
warps_per_threadblock_per_core);
|
warps_per_threadblock_per_core);
|
||||||
|
|
||||||
float *gmem_tmp2 = reinterpret_cast<float *>(0xf0000000UL);
|
// clear out accumulators
|
||||||
|
initialize_accum_regs<0>();
|
||||||
|
initialize_accum_regs<1>();
|
||||||
|
|
||||||
load_tile_to_smem<float, MemLayout::MN_major, MemLayout::MN_major, BN, BK>(
|
load_tile_to_smem<float, MemLayout::MN_major, MemLayout::MN_major, BN, BK>(
|
||||||
B_COL, 0 /*FIXME*/, 0 /*FIXME*/, gmem_V, smem_V, tid_in_threadblock);
|
B_COL, 0 /*FIXME*/, 0 /*FIXME*/, gmem_V, smem_V, tid_in_threadblock);
|
||||||
@@ -376,8 +408,9 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
|
|||||||
// FIXME: support MN_major for A for ideal performance
|
// FIXME: support MN_major for A for ideal performance
|
||||||
thread_block_gemm_single_tile<float, MemLayout::K_major, MemLayout::MN_major,
|
thread_block_gemm_single_tile<float, MemLayout::K_major, MemLayout::MN_major,
|
||||||
/*write_to_smem=*/true>(
|
/*write_to_smem=*/true>(
|
||||||
smem_P, smem_V, gmem_tmp2 /*smem_O*/, tid_in_threadblock,
|
smem_P, smem_V, gmem_O /*smem_O*/, tid_in_threadblock,
|
||||||
threads_per_threadblock);
|
threads_per_threadblock, threadblocks_per_cluster,
|
||||||
|
threadblock_id_in_cluster);
|
||||||
|
|
||||||
threadblock_barrier(threadblock_id_in_cluster,
|
threadblock_barrier(threadblock_id_in_cluster,
|
||||||
warps_per_threadblock_per_core);
|
warps_per_threadblock_per_core);
|
||||||
|
|||||||
Reference in New Issue
Block a user