flash: Change timing for QKV move
Verified with warp_specialized false; true remains to be fixed.
This commit is contained in:
@@ -15,7 +15,7 @@
|
||||
|
||||
constexpr uint32_t ROWMAX_SETS = 3;
|
||||
constexpr bool DEBUG = true;
|
||||
constexpr bool WARP_SPECIALIZED = true;
|
||||
constexpr bool WARP_SPECIALIZED = false;
|
||||
|
||||
constexpr uint32_t DEV_FAKE_SMEM_START_ADDR = 0xf0000000;
|
||||
|
||||
@@ -490,8 +490,10 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
|
||||
threads_per_threadblock / NUM_THREADS;
|
||||
|
||||
// warpgroup context
|
||||
constexpr uint32_t threads_per_warpgroup = threads_per_threadblock / 2;
|
||||
constexpr uint32_t warpgroups_per_cluster = threadblocks_per_cluster * 2;
|
||||
constexpr uint32_t threads_per_warpgroup =
|
||||
threads_per_threadblock / (WARP_SPECIALIZED ? 2 : 1);
|
||||
constexpr uint32_t warpgroups_per_cluster =
|
||||
threadblocks_per_cluster * (WARP_SPECIALIZED ? 2 : 1);
|
||||
const uint32_t warps_per_warpgroup_per_core =
|
||||
NUM_WARPS / warpgroups_per_cluster;
|
||||
const uint32_t warpgroup_id = task_id / threads_per_warpgroup;
|
||||
@@ -507,6 +509,25 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
|
||||
const uint32_t dim_seqlen = arg->dim_seqlen;
|
||||
const uint32_t dim_headdim = arg->dim_headdim;
|
||||
|
||||
// get global memory addresses from kernel arguments
|
||||
const float *gmem_Q = reinterpret_cast<float *>(arg->addr_q);
|
||||
const float *gmem_K = reinterpret_cast<float *>(arg->addr_k);
|
||||
const float *gmem_V = reinterpret_cast<float *>(arg->addr_v);
|
||||
float *gmem_O = reinterpret_cast<float *>(arg->addr_o);
|
||||
|
||||
float *gmem_tmp_d0 = reinterpret_cast<float *>(0xd0000000UL);
|
||||
float *gmem_tmp_d1 = reinterpret_cast<float *>(0xd1000000UL);
|
||||
float *gmem_tmp_d2 = reinterpret_cast<float *>(0xd2000000UL);
|
||||
float *gmem_tmp_d3 = reinterpret_cast<float *>(0xd3000000UL);
|
||||
float *gmem_tmp_d4 = reinterpret_cast<float *>(0xd4000000UL);
|
||||
float *gmem_tmp_d5 = reinterpret_cast<float *>(0xd5000000UL);
|
||||
float *gmem_tmp_d6 = reinterpret_cast<float *>(0xd6000000UL);
|
||||
float *gmem_tmp_d7 = reinterpret_cast<float *>(0xd7000000UL);
|
||||
float *gmem_tmp_e0 = reinterpret_cast<float *>(0xe0000000UL);
|
||||
float *gmem_tmp_e1 = reinterpret_cast<float *>(0xe1000000UL);
|
||||
float *gmem_tmp_e2 = reinterpret_cast<float *>(0xe2000000UL);
|
||||
float *gmem_tmp_e3 = reinterpret_cast<float *>(0xe3000000UL);
|
||||
|
||||
// static shared memory allocation
|
||||
constexpr uint32_t smem_Q_size = B_ROW * HEADDIM;
|
||||
constexpr uint32_t smem_K_size = B_COL * HEADDIM;
|
||||
@@ -572,32 +593,23 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
|
||||
smem_cursor -= smem_scratchpad_size;
|
||||
float *smem_scratchpad_1 = smem_cursor;
|
||||
|
||||
// select the correct buffer by warpgroup
|
||||
float *smem_Q = (warpgroup_id % 2) ? smem_Q1 : smem_Q0;
|
||||
float *smem_K = (warpgroup_id % 2) ? smem_K1 : smem_K0;
|
||||
float *smem_V = (warpgroup_id % 2) ? smem_V1 : smem_V0;
|
||||
float *smem_S = (warpgroup_id % 2) ? smem_S1 : smem_S0;
|
||||
float *smem_O = (warpgroup_id % 2) ? smem_O1 : smem_O0;
|
||||
float *smem_P = smem_S;
|
||||
float *smem_O_row_scale =
|
||||
(warpgroup_id % 2) ? smem_O_row_scale_1 : smem_O_row_scale_0;
|
||||
float *smem_rowmax = (warpgroup_id % 2) ? smem_rowmax_1 : smem_rowmax_0;
|
||||
float *smem_rowsum = (warpgroup_id % 2) ? smem_rowsum_1 : smem_rowsum_0;
|
||||
float *smem_scratchpad =
|
||||
(warpgroup_id % 2) ? smem_scratchpad_1 : smem_scratchpad_0;
|
||||
|
||||
// initialize rowmax/rowsum values in sharedmem
|
||||
if (warpgroup_id == 0) {
|
||||
thread_block_init_sharedmem(tid_in_warpgroup, threads_per_warpgroup, smem_O0,
|
||||
smem_rowmax_0, smem_rowsum_0, smem_O_row_scale_0);
|
||||
} else {
|
||||
thread_block_init_sharedmem(tid_in_warpgroup, threads_per_warpgroup, smem_O1,
|
||||
smem_rowmax_1, smem_rowsum_1, smem_O_row_scale_1);
|
||||
}
|
||||
|
||||
const float *gmem_Q = reinterpret_cast<float *>(arg->addr_q);
|
||||
const float *gmem_K = reinterpret_cast<float *>(arg->addr_k);
|
||||
const float *gmem_V = reinterpret_cast<float *>(arg->addr_v);
|
||||
float *gmem_O = reinterpret_cast<float *>(arg->addr_o);
|
||||
|
||||
float *gmem_tmp_d0 = reinterpret_cast<float *>(0xd0000000UL);
|
||||
float *gmem_tmp_d1 = reinterpret_cast<float *>(0xd1000000UL);
|
||||
float *gmem_tmp_d2 = reinterpret_cast<float *>(0xd2000000UL);
|
||||
float *gmem_tmp_d3 = reinterpret_cast<float *>(0xd3000000UL);
|
||||
float *gmem_tmp_d4 = reinterpret_cast<float *>(0xd4000000UL);
|
||||
float *gmem_tmp_d5 = reinterpret_cast<float *>(0xd5000000UL);
|
||||
float *gmem_tmp_d6 = reinterpret_cast<float *>(0xd6000000UL);
|
||||
float *gmem_tmp_d7 = reinterpret_cast<float *>(0xd7000000UL);
|
||||
float *gmem_tmp_e0 = reinterpret_cast<float *>(0xe0000000UL);
|
||||
float *gmem_tmp_e1 = reinterpret_cast<float *>(0xe1000000UL);
|
||||
float *gmem_tmp_e2 = reinterpret_cast<float *>(0xe2000000UL);
|
||||
float *gmem_tmp_e3 = reinterpret_cast<float *>(0xe3000000UL);
|
||||
thread_block_init_sharedmem(tid_in_warpgroup, threads_per_warpgroup, smem_O,
|
||||
smem_rowmax, smem_rowsum, smem_O_row_scale);
|
||||
|
||||
constexpr uint32_t global_barrier_id = NUM_WARPS - 1; // arbitrary
|
||||
|
||||
@@ -606,13 +618,38 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
|
||||
threadblock_barrier(global_barrier_id, warps_per_threadblock_per_core);
|
||||
}
|
||||
|
||||
// read Q and K into SMEM before the loop starts
|
||||
//
|
||||
static_assert(B_ROW == B_COL, "currently only supports square tiles");
|
||||
|
||||
// load Q; this stays in SMEM for the entire loop
|
||||
if constexpr (!WARP_SPECIALIZED) {
|
||||
load_tile_to_smem<float, MemLayout::MN_major, MemLayout::MN_major, B_ROW,
|
||||
HEADDIM, threads_per_warpgroup>(
|
||||
dim_seqlen, warpgroup_id, 0 /* dim_k == headdim */, gmem_Q, smem_Q,
|
||||
tid_in_warpgroup);
|
||||
} else {
|
||||
// FIXME: transpose to K-major in SMEM for correctness
|
||||
load_tile_to_smem<float, MemLayout::K_major, MemLayout::K_major, B_ROW,
|
||||
HEADDIM, threads_per_warpgroup>(
|
||||
dim_seqlen, warpgroup_id, 0 /* dim_k == headdim */, gmem_Q, smem_Q,
|
||||
tid_in_warpgroup);
|
||||
}
|
||||
|
||||
// load K
|
||||
load_tile_to_smem<float, MemLayout::MN_major, MemLayout::MN_major, B_COL,
|
||||
HEADDIM, threads_per_warpgroup>(
|
||||
dim_seqlen, /*tile_k=*/0, 0 /* dim_k == headdim */, gmem_K, smem_K,
|
||||
tid_in_warpgroup);
|
||||
|
||||
// protect write to SMEM
|
||||
threadblock_barrier(warpgroup_id_in_cluster, warps_per_warpgroup_per_core);
|
||||
|
||||
asm volatile ("tile_loop_start_%=:" :: );
|
||||
|
||||
// "inner loop" along the columns of K^T
|
||||
const uint32_t k_tiles = (dim_seqlen / B_COL);
|
||||
for (uint32_t tile_k = 0; tile_k < k_tiles; tile_k++) {
|
||||
asm volatile ("buf_select_start_%=:" :: );
|
||||
|
||||
// float *smem_P_produce = (tile_k % 2) ? smem_P0 : smem_P1;
|
||||
// float *smem_P_consume = (tile_k % 2) ? smem_P1 : smem_P0;
|
||||
// float *smem_V_produce = (tile_k % 2) ? smem_V0 : smem_V1;
|
||||
@@ -622,67 +659,87 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
|
||||
// float *smem_O_row_scale_consume =
|
||||
// (tile_k % 2) ? smem_O_row_scale_1 : smem_O_row_scale_0;
|
||||
|
||||
float *smem_Q = (warpgroup_id % 2) ? smem_Q1 : smem_Q0;
|
||||
float *smem_K = (warpgroup_id % 2) ? smem_K1 : smem_K0;
|
||||
float *smem_V = (warpgroup_id % 2) ? smem_V1 : smem_V0;
|
||||
float *smem_S = (warpgroup_id % 2) ? smem_S1 : smem_S0;
|
||||
float *smem_O = (warpgroup_id % 2) ? smem_O1 : smem_O0;
|
||||
float *smem_P = smem_S;
|
||||
float *smem_O_row_scale =
|
||||
(warpgroup_id % 2) ? smem_O_row_scale_1 : smem_O_row_scale_0;
|
||||
float *smem_rowmax = (warpgroup_id % 2) ? smem_rowmax_1 : smem_rowmax_0;
|
||||
float *smem_rowsum = (warpgroup_id % 2) ? smem_rowsum_1 : smem_rowsum_0;
|
||||
float *smem_scratchpad =
|
||||
(warpgroup_id % 2) ? smem_scratchpad_1 : smem_scratchpad_0;
|
||||
|
||||
asm volatile ("buf_select_finish_%=:" :: );
|
||||
|
||||
const uint32_t tile_k_ = tile_k;
|
||||
|
||||
constexpr bool skip_gemm_qk = true;
|
||||
constexpr bool skip_gemm_qk = false;
|
||||
if constexpr (!skip_gemm_qk) {
|
||||
static_assert(B_ROW == B_COL, "currently only supports square tiles");
|
||||
|
||||
// load Q
|
||||
load_tile_to_smem<float, MemLayout::MN_major, MemLayout::MN_major, B_ROW,
|
||||
HEADDIM, threads_per_warpgroup>(
|
||||
dim_seqlen, 0 /*FIXME: only work on first B_ROW rows of Q for now*/,
|
||||
0 /* always 0 because dim_k == headdim */, gmem_Q, smem_Q,
|
||||
tid_in_warpgroup);
|
||||
|
||||
// load K
|
||||
load_tile_to_smem<float, MemLayout::MN_major, MemLayout::MN_major, B_COL,
|
||||
HEADDIM, threads_per_warpgroup>(
|
||||
dim_seqlen, tile_k_, 0 /* always 0 because dim_k == headdim */,
|
||||
gmem_K, smem_K, tid_in_warpgroup);
|
||||
|
||||
// GMEM->SMEM and compute barrier
|
||||
threadblock_barrier(warpgroup_id_in_cluster,
|
||||
warps_per_warpgroup_per_core);
|
||||
|
||||
// clear out accumulators before GEMM
|
||||
initialize_accum_regs<0>();
|
||||
initialize_accum_regs<1>();
|
||||
|
||||
// GEMM I: S = Q*K
|
||||
thread_block_gemm_single_tile<float, MemLayout::MN_major,
|
||||
MemLayout::MN_major, B_ROW, B_COL, HEADDIM,
|
||||
/*load_accum=*/false,
|
||||
/*write_to_smem=*/true>(
|
||||
smem_Q, smem_K, nullptr /*ignore accum*/, smem_S, tid_in_warpgroup,
|
||||
threads_per_warpgroup, warpgroups_per_cluster,
|
||||
warpgroup_id_in_cluster);
|
||||
//
|
||||
// FIXME: deduplicate this between GEMM II
|
||||
if constexpr (!WARP_SPECIALIZED) {
|
||||
// clear out accumulators before GEMM
|
||||
initialize_accum_regs<0>();
|
||||
initialize_accum_regs<1>();
|
||||
|
||||
thread_block_gemm_single_tile<float, MemLayout::MN_major,
|
||||
MemLayout::MN_major, B_ROW, B_COL,
|
||||
HEADDIM,
|
||||
/*load_accum=*/false,
|
||||
/*write_to_smem=*/true>(
|
||||
smem_Q, smem_K, nullptr /*ignore accum*/, smem_S, tid_in_warpgroup,
|
||||
threads_per_warpgroup, warpgroups_per_cluster,
|
||||
warpgroup_id_in_cluster);
|
||||
} else {
|
||||
// when warp-specialized, there's only enough warps to do 64x32 tile
|
||||
// size so we need to do 2 GEMM calls
|
||||
static_assert(B_ROW / 2 == 32,
|
||||
"tile size assumption for warp-specialization not met");
|
||||
|
||||
// assumes smem_Q is K-major
|
||||
// FIXME: fix this to MN-major
|
||||
float *smem_Q_half0 = smem_Q;
|
||||
float *smem_Q_half1 = smem_Q + (B_ROW / 2) * HEADDIM;
|
||||
float *smem_S_half0 = smem_S;
|
||||
float *smem_S_half1 = smem_S + (B_ROW / 2) * B_COL;
|
||||
|
||||
// clear out accumulators before GEMM
|
||||
initialize_accum_regs<0>();
|
||||
initialize_accum_regs<1>();
|
||||
|
||||
// split by rows into 2 chunks
|
||||
thread_block_gemm_single_tile<float, MemLayout::K_major /*FIXME*/,
|
||||
MemLayout::MN_major, B_ROW / 2, B_COL,
|
||||
HEADDIM,
|
||||
/*load_accum=*/false,
|
||||
/*write_to_smem=*/true>(
|
||||
smem_Q_half0, smem_K, nullptr /*ignore accum*/, smem_S_half0,
|
||||
tid_in_warpgroup, threads_per_warpgroup, warpgroups_per_cluster,
|
||||
warpgroup_id_in_cluster);
|
||||
thread_block_gemm_single_tile<float, MemLayout::K_major /*FIXME*/,
|
||||
MemLayout::MN_major, B_ROW / 2, B_COL,
|
||||
HEADDIM,
|
||||
/*load_accum=*/false,
|
||||
/*write_to_smem=*/true>(
|
||||
smem_Q_half1, smem_K, nullptr /*ignore accum*/, smem_S_half1,
|
||||
tid_in_warpgroup, threads_per_warpgroup, warpgroups_per_cluster,
|
||||
warpgroup_id_in_cluster);
|
||||
}
|
||||
} else {
|
||||
// load Q*K
|
||||
load_tile_to_smem<float, MemLayout::K_major, MemLayout::K_major, B_COL,
|
||||
HEADDIM, threads_per_warpgroup>(
|
||||
dim_seqlen, warpgroup_id /* parallelize across rows */, tile_k_,
|
||||
gmem_Q /*=gmem_S*/, smem_S, tid_in_warpgroup);
|
||||
dim_seqlen, warpgroup_id /* parallelize across rows */, tile_k,
|
||||
gmem_Q /*contains S*/, smem_S, tid_in_warpgroup);
|
||||
}
|
||||
|
||||
// protect GEMM result writes (smem_S) before softmax
|
||||
// protect write to SMEM (smem_S) before softmax
|
||||
threadblock_barrier(warpgroup_id_in_cluster, warps_per_warpgroup_per_core);
|
||||
|
||||
if constexpr (DEBUG) {
|
||||
if (warpgroup_id == 0) {
|
||||
if (tile_k == 0) {
|
||||
thread_block_copy_tile(smem_S, gmem_tmp_d0,
|
||||
tid_in_warpgroup, threads_per_warpgroup,
|
||||
warpgroup_id_in_cluster);
|
||||
} else if (tile_k == 1) {
|
||||
thread_block_copy_tile(smem_S, gmem_tmp_d1,
|
||||
tid_in_warpgroup, threads_per_warpgroup,
|
||||
warpgroup_id_in_cluster);
|
||||
}
|
||||
|
||||
threadblock_barrier(warpgroup_id_in_cluster,
|
||||
warps_per_warpgroup_per_core);
|
||||
}
|
||||
}
|
||||
|
||||
// inter-warpgroup barrier before online softmax
|
||||
threadblock_barrier(global_barrier_id, warps_per_threadblock_per_core);
|
||||
|
||||
@@ -693,32 +750,36 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
|
||||
smem_scratchpad, smem_rowmax, smem_rowsum,
|
||||
smem_O_row_scale);
|
||||
|
||||
// TODO: put the data movement for QKV here for inter-warpgroup
|
||||
// data movement for K and V
|
||||
//
|
||||
// Q stays in SMEM for the entire loop
|
||||
//
|
||||
// load K for the next iteration
|
||||
load_tile_to_smem<float, MemLayout::MN_major, MemLayout::MN_major, B_COL,
|
||||
HEADDIM, threads_per_warpgroup>(
|
||||
dim_seqlen, tile_k + 1, 0 /* dim_k == headdim */, gmem_K, smem_K,
|
||||
tid_in_warpgroup);
|
||||
|
||||
// load V for the current iteration
|
||||
// V dimension is [seqlen, headdim], stored N(headdim)-major
|
||||
load_tile_to_smem<float, MemLayout::MN_major, MemLayout::MN_major, B_COL,
|
||||
HEADDIM, threads_per_warpgroup>(
|
||||
HEADDIM, 0 /* full N-dimension */, tile_k_, gmem_V, smem_V,
|
||||
HEADDIM, 0 /* full N-dimension */, tile_k, gmem_V, smem_V,
|
||||
tid_in_warpgroup);
|
||||
|
||||
// protect write to SMEM
|
||||
threadblock_barrier(warpgroup_id_in_cluster, warps_per_warpgroup_per_core);
|
||||
|
||||
if constexpr (DEBUG) {
|
||||
if (warpgroup_id == 0) {
|
||||
if (tile_k_ == 0) {
|
||||
// thread_block_copy_tile(smem_P, gmem_tmp_d0,
|
||||
// tid_in_warpgroup, threads_per_warpgroup,
|
||||
// warpgroup_id_in_cluster);
|
||||
if (tile_k == 0) {
|
||||
thread_block_copy_rowmax(smem_rowmax, gmem_tmp_e0, tid_in_warpgroup,
|
||||
threads_per_warpgroup,
|
||||
warpgroup_id_in_cluster);
|
||||
thread_block_copy_rowmax(smem_rowsum, gmem_tmp_e2, tid_in_warpgroup,
|
||||
threads_per_warpgroup,
|
||||
warpgroup_id_in_cluster);
|
||||
} else if (tile_k_ == 1) {
|
||||
// thread_block_copy_tile(smem_P, gmem_tmp_d1,
|
||||
// tid_in_warpgroup, threads_per_warpgroup,
|
||||
// warpgroup_id_in_cluster);
|
||||
} else if (tile_k == 1) {
|
||||
thread_block_copy_rowmax(smem_rowmax, gmem_tmp_e1, tid_in_warpgroup,
|
||||
threads_per_warpgroup,
|
||||
warpgroup_id_in_cluster);
|
||||
@@ -748,24 +809,18 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
|
||||
if constexpr (DEBUG) {
|
||||
if (warpgroup_id == 0) {
|
||||
// O before PV
|
||||
if (tile_k_ == 0) {
|
||||
thread_block_copy_tile(smem_P, gmem_tmp_d0, tid_in_warpgroup,
|
||||
if (tile_k == 0) {
|
||||
thread_block_copy_tile(smem_P, gmem_tmp_d2, tid_in_warpgroup,
|
||||
threads_per_warpgroup,
|
||||
warpgroup_id_in_cluster);
|
||||
thread_block_copy_tile(smem_V, gmem_tmp_d6, tid_in_warpgroup,
|
||||
thread_block_copy_tile(smem_O, gmem_tmp_d4, tid_in_warpgroup,
|
||||
threads_per_warpgroup,
|
||||
warpgroup_id_in_cluster);
|
||||
thread_block_copy_tile(smem_O, gmem_tmp_d2, tid_in_warpgroup,
|
||||
} else if (tile_k == 1) {
|
||||
thread_block_copy_tile(smem_P, gmem_tmp_d3, tid_in_warpgroup,
|
||||
threads_per_warpgroup,
|
||||
warpgroup_id_in_cluster);
|
||||
} else if (tile_k_ == 1) {
|
||||
thread_block_copy_tile(smem_P, gmem_tmp_d1, tid_in_warpgroup,
|
||||
threads_per_warpgroup,
|
||||
warpgroup_id_in_cluster);
|
||||
thread_block_copy_tile(smem_V, gmem_tmp_d7, tid_in_warpgroup,
|
||||
threads_per_warpgroup,
|
||||
warpgroup_id_in_cluster);
|
||||
thread_block_copy_tile(smem_O, gmem_tmp_d3, tid_in_warpgroup,
|
||||
thread_block_copy_tile(smem_O, gmem_tmp_d5, tid_in_warpgroup,
|
||||
threads_per_warpgroup,
|
||||
warpgroup_id_in_cluster);
|
||||
}
|
||||
@@ -838,12 +893,12 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
|
||||
if constexpr (DEBUG) {
|
||||
if (warpgroup_id == 0) {
|
||||
// O after PV
|
||||
if (tile_k_ == 0) {
|
||||
thread_block_copy_tile(smem_O, gmem_tmp_d4, tid_in_warpgroup,
|
||||
if (tile_k == 0) {
|
||||
thread_block_copy_tile(smem_O, gmem_tmp_d6, tid_in_warpgroup,
|
||||
threads_per_warpgroup,
|
||||
warpgroup_id_in_cluster);
|
||||
} else if (tile_k_ == 1) {
|
||||
thread_block_copy_tile(smem_O, gmem_tmp_d5, tid_in_warpgroup,
|
||||
} else if (tile_k == 1) {
|
||||
thread_block_copy_tile(smem_O, gmem_tmp_d7, tid_in_warpgroup,
|
||||
threads_per_warpgroup,
|
||||
warpgroup_id_in_cluster);
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user