flash: data copy func for easy debugging

This commit is contained in:
Hansung Kim
2024-08-19 21:41:37 -07:00
parent 2f7fb372f1
commit df3c41aa0d

View File

@@ -73,6 +73,37 @@ inline float exponential_taylor_term(const float x) {
return res;
}
inline void thread_block_copy_data(const float *src, float *dest,
const uint32_t tid_in_threadblock,
const uint32_t threads_per_threadblock,
const uint32_t threadblocks_per_cluster,
const uint32_t threadblock_id_in_cluster) {
const uint32_t tid_in_warp = tid_in_threadblock % NUM_THREADS;
const uint32_t warp_id = tid_in_threadblock / NUM_THREADS;
const uint32_t warps_in_threadblock = threads_per_threadblock / NUM_THREADS;
const uint32_t warps_per_threadblock_per_core =
NUM_WARPS / threadblocks_per_cluster;
for (int warp_offset = 0; warp_offset < B_ROW;
warp_offset += warps_in_threadblock) {
const uint32_t row = warp_offset + warp_id;
const uint32_t first_thread_offset = B_COL * row;
constexpr uint32_t per_row_iter = B_COL / NUM_THREADS;
uint32_t thread_offset = first_thread_offset + tid_in_warp;
float per_thread_max = FLT_MIN;
#pragma GCC unroll
for (int i = 0; i < per_row_iter; i++) {
const float f = src[thread_offset];
dest[thread_offset] = f;
thread_offset += NUM_THREADS;
}
threadblock_barrier(threadblock_id_in_cluster,
warps_per_threadblock_per_core);
}
}
inline void thread_block_online_softmax(
const float *smem_S, float *smem_O, float *smem_P,
const uint32_t tid_in_threadblock, const uint32_t threads_per_threadblock,
@@ -97,9 +128,6 @@ inline void thread_block_online_softmax(
// asm volatile("fmv.s %0, f22" : "=f"(ft[6]));
// asm volatile("fmv.s %0, f23" : "=f"(ft[7]));
volatile float *gmem_tmp0 = reinterpret_cast<volatile float *>(0xd0000000UL);
volatile float *gmem_tmp1 = reinterpret_cast<volatile float *>(0xe0000000UL);
float *smem_rowmax_prev = smem_rowmax;
float *smem_rowmax_new = smem_rowmax + B_ROW;
float *smem_rowmax_this = smem_rowmax + 2 * B_ROW;
@@ -201,9 +229,6 @@ inline void thread_block_online_softmax(
for (int i = 0; i < exp_per_row_iter; i++) {
float f0 = smem_S[thread_offset];
// check Q*K result
gmem_tmp0[thread_offset] = f0;
f0 -= rowmax_new;
// 2nd-order Taylor approximation
@@ -214,7 +239,6 @@ inline void thread_block_online_softmax(
// Store S transposed to the shared memory
smem_P[thread_offset] = exp;
gmem_tmp1[thread_offset] = exp;
thread_offset += NUM_THREADS;
}
@@ -389,85 +413,97 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
const float *gmem_K = reinterpret_cast<float *>(arg->addr_k);
const float *gmem_V = reinterpret_cast<float *>(arg->addr_v);
float *gmem_O = reinterpret_cast<float *>(arg->addr_o);
float *gmem_tmp0 = reinterpret_cast<float *>(0xd0000000UL);
// "inner loop" along the columns of K^T
for (uint32_t tile_k = 0; tile_k < (dim_seqlen / B_COL); tile_k++) {
// #define SKIP_GEMM
#ifndef SKIP_GEMM
#if 0
thread_block_gemm<float_type, /*write_to_gmem=*/true>(
(const float_type *)arg->addr_q, (const float_type *)arg->addr_k,
(float *)smem_S /*write result to SMEM */, B_ROW, B_COL,
HEADDIM, tid_in_threadblock, threads_per_threadblock,
threadblocks_per_cluster, threadblock_id_in_cluster,
smem_per_threadblock);
// clear out accumulators
initialize_accum_regs<0>();
initialize_accum_regs<1>();
static_assert(B_ROW == B_COL, "currently only supports square tiles");
// load Q
load_tile_to_smem<float, MemLayout::MN_major, MemLayout::MN_major, B_ROW,
HEADDIM>(
dim_seqlen, 0 /*FIXME: only work on first B_ROW rows of Q for now*/,
0 /* always 0 because dim_k == headdim */, gmem_Q, smem_Q,
tid_in_threadblock);
// load K
load_tile_to_smem<float, MemLayout::MN_major, MemLayout::MN_major, B_COL,
HEADDIM>(dim_seqlen, tile_k,
0 /* always 0 because dim_k == headdim */,
gmem_K, smem_K, tid_in_threadblock);
// GMEM->SMEM and compute barrier
threadblock_barrier(threadblock_id_in_cluster,
warps_per_threadblock_per_core);
// GEMM I: S = Q*K
thread_block_gemm_single_tile<float, MemLayout::MN_major,
MemLayout::MN_major,
/*load_accum=*/false,
/*write_to_smem=*/true>(
smem_Q, smem_K, nullptr /*ignore accum*/, smem_S, tid_in_threadblock,
threads_per_threadblock, threadblocks_per_cluster,
threadblock_id_in_cluster);
// protect GEMM result writes (smem_S) before softmax
threadblock_barrier(threadblock_id_in_cluster,
warps_per_threadblock_per_core);
const float *tile_S = (float *)smem_S;
#else
// clear out accumulators
initialize_accum_regs<0>();
initialize_accum_regs<1>();
// load Q
static_assert(B_ROW == B_COL, "currently only supports square tiles");
load_tile_to_smem<float, MemLayout::MN_major, MemLayout::MN_major, B_ROW,
HEADDIM>(B_ROW, 0, 0, gmem_Q, smem_Q, tid_in_threadblock);
// load K
load_tile_to_smem<float, MemLayout::MN_major, MemLayout::MN_major, B_COL,
HEADDIM>(B_COL, 0, 0, gmem_K, smem_K, tid_in_threadblock);
// GMEM->SMEM and compute barrier
threadblock_barrier(threadblock_id_in_cluster,
warps_per_threadblock_per_core);
// GEMM I: S = Q*K
thread_block_gemm_single_tile<float, MemLayout::MN_major, MemLayout::MN_major,
/*load_accum=*/false,
/*write_to_smem=*/true>(
smem_Q, smem_K, nullptr /*ignore accum*/, smem_S, tid_in_threadblock,
threads_per_threadblock, threadblocks_per_cluster,
threadblock_id_in_cluster);
float *tile_S = (float *)arg->addr_q;
#endif
// protect GEMM result writes (smem_S) before softmax
threadblock_barrier(threadblock_id_in_cluster,
warps_per_threadblock_per_core);
thread_block_online_softmax(
tile_S, smem_O, smem_P, tid_in_threadblock, threads_per_threadblock,
threadblocks_per_cluster, threadblock_id_in_cluster, smem_scratchpad,
smem_rowmax, smem_rowsum);
const float *tile_S = (float *)smem_S;
#else
float *tile_S = (float *)arg->addr_q;
#endif
// FIXME unnecessary?
threadblock_barrier(threadblock_id_in_cluster,
warps_per_threadblock_per_core);
thread_block_online_softmax(tile_S, smem_O, smem_P, tid_in_threadblock,
threads_per_threadblock, threadblocks_per_cluster,
threadblock_id_in_cluster, smem_scratchpad,
smem_rowmax, smem_rowsum);
// GEMM II: O = O + P*V
// FIXME unnecessary?
threadblock_barrier(threadblock_id_in_cluster,
warps_per_threadblock_per_core);
// clear out accumulators
initialize_accum_regs<0>();
initialize_accum_regs<1>();
// GEMM II: O = O + P*V
load_tile_to_smem<float, MemLayout::MN_major, MemLayout::MN_major, BN, BK>(
B_COL, 0, 0, gmem_V, smem_V, tid_in_threadblock);
// clear out accumulators
initialize_accum_regs<0>();
initialize_accum_regs<1>();
threadblock_barrier(threadblock_id_in_cluster,
warps_per_threadblock_per_core);
load_tile_to_smem<float, MemLayout::MN_major, MemLayout::MN_major, BN, BK>(
B_COL, 0, 0, gmem_V, smem_V, tid_in_threadblock);
// FIXME: support MN_major for A for ideal performance
thread_block_gemm_single_tile<float, MemLayout::K_major,
MemLayout::MN_major,
/*load_accum=*/true,
/*write_to_smem=*/true>(
smem_P, smem_V, smem_O /*load accum*/, smem_O,
tid_in_threadblock, threads_per_threadblock, threadblocks_per_cluster,
threadblock_id_in_cluster);
threadblock_barrier(threadblock_id_in_cluster,
warps_per_threadblock_per_core);
threadblock_barrier(threadblock_id_in_cluster,
warps_per_threadblock_per_core);
}
// FIXME: support MN_major for A for ideal performance
thread_block_gemm_single_tile<float, MemLayout::K_major, MemLayout::MN_major,
/*load_accum=*/false,
/*write_to_smem=*/true>(
smem_P, smem_V, smem_O, gmem_O /*smem_O*/,
tid_in_threadblock, threads_per_threadblock, threadblocks_per_cluster,
threadblock_id_in_cluster);
float *gmem_tmp0 = reinterpret_cast<float *>(0xd0000000UL);
float *gmem_tmp1 = reinterpret_cast<float *>(0xe0000000UL);
threadblock_barrier(threadblock_id_in_cluster,
warps_per_threadblock_per_core);
// copy out tile data to GMEM for debugging
thread_block_copy_data(smem_P, gmem_tmp0, tid_in_threadblock,
threads_per_threadblock, threadblocks_per_cluster,
threadblock_id_in_cluster);
thread_block_copy_data(smem_O, gmem_tmp1, tid_in_threadblock,
threads_per_threadblock, threadblocks_per_cluster,
threadblock_id_in_cluster);
}
int main() {