flash: Fix overlap in smem alloc for P tile

This commit is contained in:
Hansung Kim
2024-08-31 15:18:14 -07:00
parent bdd6e6a9ce
commit 817cc9a5a5

View File

@@ -523,13 +523,11 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
float *smem_Q = reinterpret_cast<float *>(smem_per_threadblock);
float *smem_K = smem_Q + smem_Q_size;
// in-place multiplication of QK into Q
float *smem_S = reinterpret_cast<float *>(smem_per_threadblock);
float *smem_P0 = smem_S; // in-place update from S to P
float *smem_O = smem_S + smem_QK_size;
float *smem_P0 = reinterpret_cast<float *>(DEV_FAKE_SMEM_START_ADDR);
float *smem_P1 = smem_P0 + smem_QK_size;
float *smem_O = smem_P1 + smem_QK_size;
float *smem_V0 =
reinterpret_cast<float *>(DEV_FAKE_SMEM_START_ADDR) + smem_QK_size;
float *smem_V0 = smem_P1 + smem_QK_size;
float *smem_V1 = smem_V0 + smem_QK_size;
// allocate rowmax/rowsum storage at the end of the sharedmem address space
@@ -566,6 +564,8 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
float *gmem_tmp_d3 = reinterpret_cast<float *>(0xd3000000UL);
float *gmem_tmp_d4 = reinterpret_cast<float *>(0xd4000000UL);
float *gmem_tmp_d5 = reinterpret_cast<float *>(0xd5000000UL);
float *gmem_tmp_d6 = reinterpret_cast<float *>(0xd6000000UL);
float *gmem_tmp_d7 = reinterpret_cast<float *>(0xd7000000UL);
float *gmem_tmp_e0 = reinterpret_cast<float *>(0xe0000000UL);
float *gmem_tmp_e1 = reinterpret_cast<float *>(0xe1000000UL);
float *gmem_tmp_e2 = reinterpret_cast<float *>(0xe2000000UL);
@@ -662,19 +662,19 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
if constexpr (DEBUG) {
if (tile_k_ == 0) {
thread_block_copy_tile(smem_P_produce, gmem_tmp_d0,
tid_in_warpgroup, threads_per_warpgroup,
warpgroup_id_in_cluster);
// thread_block_copy_tile(smem_P_produce, gmem_tmp_d0,
// tid_in_warpgroup, threads_per_warpgroup,
// warpgroup_id_in_cluster);
thread_block_copy_rowmax(smem_rowmax, gmem_tmp_e0, tid_in_warpgroup,
threads_per_warpgroup,
warpgroup_id_in_cluster);
thread_block_copy_rowmax(smem_rowsum, gmem_tmp_e2, tid_in_warpgroup,
threads_per_warpgroup,
warpgroup_id_in_cluster);
} else if (tile_k_ == k_tiles - 1) {
thread_block_copy_tile(smem_P_produce, gmem_tmp_d1,
tid_in_warpgroup, threads_per_warpgroup,
warpgroup_id_in_cluster);
} else if (tile_k_ == 1) {
// thread_block_copy_tile(smem_P_produce, gmem_tmp_d1,
// tid_in_warpgroup, threads_per_warpgroup,
// warpgroup_id_in_cluster);
thread_block_copy_rowmax(smem_rowmax, gmem_tmp_e1, tid_in_warpgroup,
threads_per_warpgroup,
warpgroup_id_in_cluster);
@@ -698,10 +698,6 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
// GEMM II: O = O + P*V
// clear out accumulators
initialize_accum_regs<0>();
initialize_accum_regs<1>();
// V dimension is [seqlen, headdim], stored N(headdim)-major
load_tile_to_smem<float, MemLayout::MN_major, MemLayout::MN_major, B_COL,
HEADDIM, threads_per_warpgroup>(
@@ -724,10 +720,22 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
if constexpr (DEBUG) {
// O before PV
if (tile_k_ == 0) {
thread_block_copy_tile(smem_P_consume, gmem_tmp_d0,
tid_in_warpgroup, threads_per_warpgroup,
warpgroup_id_in_cluster);
thread_block_copy_tile(smem_V_consume, gmem_tmp_d6,
tid_in_warpgroup, threads_per_warpgroup,
warpgroup_id_in_cluster);
thread_block_copy_tile(smem_O, gmem_tmp_d2, tid_in_warpgroup,
threads_per_warpgroup,
warpgroup_id_in_cluster);
} else if (tile_k_ == k_tiles - 1) {
} else if (tile_k_ == 1) {
thread_block_copy_tile(smem_P_consume, gmem_tmp_d1,
tid_in_warpgroup, threads_per_warpgroup,
warpgroup_id_in_cluster);
thread_block_copy_tile(smem_V_consume, gmem_tmp_d7,
tid_in_warpgroup, threads_per_warpgroup,
warpgroup_id_in_cluster);
thread_block_copy_tile(smem_O, gmem_tmp_d3, tid_in_warpgroup,
threads_per_warpgroup,
warpgroup_id_in_cluster);
@@ -738,6 +746,10 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
}
if constexpr (!DOUBLE_BUF) {
// clear out accumulators
initialize_accum_regs<0>();
initialize_accum_regs<1>();
thread_block_gemm_single_tile<float, MemLayout::K_major,
MemLayout::MN_major, B_ROW, HEADDIM,
B_COL,
@@ -768,6 +780,10 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
float *smem_O_half0 = smem_O;
float *smem_O_half1 = smem_O + (B_ROW / 2) * HEADDIM;
// clear out accumulators
initialize_accum_regs<0>();
initialize_accum_regs<1>();
// split by rows into 2 chunks
thread_block_gemm_single_tile<float, MemLayout::K_major,
MemLayout::MN_major, B_ROW / 2, HEADDIM,
@@ -797,7 +813,7 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
thread_block_copy_tile(smem_O, gmem_tmp_d4, tid_in_warpgroup,
threads_per_warpgroup,
warpgroup_id_in_cluster);
} else if (tile_k_ == k_tiles - 1) {
} else if (tile_k_ == 1) {
thread_block_copy_tile(smem_O, gmem_tmp_d5, tid_in_warpgroup,
threads_per_warpgroup,
warpgroup_id_in_cluster);
@@ -813,7 +829,7 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
// threadblock_barrier(threadblock_id_in_cluster,
// warps_per_threadblock_per_core);
threadblock_barrier(3, // FIXME
8);
NUM_WARPS);
}
asm volatile ("tile_loop_finish_%=:" :: );