flash: Fix overlap in smem alloc for P tile
This commit is contained in:
@@ -523,13 +523,11 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
|
||||
|
||||
float *smem_Q = reinterpret_cast<float *>(smem_per_threadblock);
|
||||
float *smem_K = smem_Q + smem_Q_size;
|
||||
// in-place multiplication of QK into Q
|
||||
float *smem_S = reinterpret_cast<float *>(smem_per_threadblock);
|
||||
float *smem_P0 = smem_S; // in-place update from S to P
|
||||
float *smem_O = smem_S + smem_QK_size;
|
||||
float *smem_P0 = reinterpret_cast<float *>(DEV_FAKE_SMEM_START_ADDR);
|
||||
float *smem_P1 = smem_P0 + smem_QK_size;
|
||||
float *smem_O = smem_P1 + smem_QK_size;
|
||||
float *smem_V0 =
|
||||
reinterpret_cast<float *>(DEV_FAKE_SMEM_START_ADDR) + smem_QK_size;
|
||||
float *smem_V0 = smem_P1 + smem_QK_size;
|
||||
float *smem_V1 = smem_V0 + smem_QK_size;
|
||||
|
||||
// allocate rowmax/rowsum storage at the end of the sharedmem address space
|
||||
@@ -566,6 +564,8 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
|
||||
float *gmem_tmp_d3 = reinterpret_cast<float *>(0xd3000000UL);
|
||||
float *gmem_tmp_d4 = reinterpret_cast<float *>(0xd4000000UL);
|
||||
float *gmem_tmp_d5 = reinterpret_cast<float *>(0xd5000000UL);
|
||||
float *gmem_tmp_d6 = reinterpret_cast<float *>(0xd6000000UL);
|
||||
float *gmem_tmp_d7 = reinterpret_cast<float *>(0xd7000000UL);
|
||||
float *gmem_tmp_e0 = reinterpret_cast<float *>(0xe0000000UL);
|
||||
float *gmem_tmp_e1 = reinterpret_cast<float *>(0xe1000000UL);
|
||||
float *gmem_tmp_e2 = reinterpret_cast<float *>(0xe2000000UL);
|
||||
@@ -662,19 +662,19 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
|
||||
|
||||
if constexpr (DEBUG) {
|
||||
if (tile_k_ == 0) {
|
||||
thread_block_copy_tile(smem_P_produce, gmem_tmp_d0,
|
||||
tid_in_warpgroup, threads_per_warpgroup,
|
||||
warpgroup_id_in_cluster);
|
||||
// thread_block_copy_tile(smem_P_produce, gmem_tmp_d0,
|
||||
// tid_in_warpgroup, threads_per_warpgroup,
|
||||
// warpgroup_id_in_cluster);
|
||||
thread_block_copy_rowmax(smem_rowmax, gmem_tmp_e0, tid_in_warpgroup,
|
||||
threads_per_warpgroup,
|
||||
warpgroup_id_in_cluster);
|
||||
thread_block_copy_rowmax(smem_rowsum, gmem_tmp_e2, tid_in_warpgroup,
|
||||
threads_per_warpgroup,
|
||||
warpgroup_id_in_cluster);
|
||||
} else if (tile_k_ == k_tiles - 1) {
|
||||
thread_block_copy_tile(smem_P_produce, gmem_tmp_d1,
|
||||
tid_in_warpgroup, threads_per_warpgroup,
|
||||
warpgroup_id_in_cluster);
|
||||
} else if (tile_k_ == 1) {
|
||||
// thread_block_copy_tile(smem_P_produce, gmem_tmp_d1,
|
||||
// tid_in_warpgroup, threads_per_warpgroup,
|
||||
// warpgroup_id_in_cluster);
|
||||
thread_block_copy_rowmax(smem_rowmax, gmem_tmp_e1, tid_in_warpgroup,
|
||||
threads_per_warpgroup,
|
||||
warpgroup_id_in_cluster);
|
||||
@@ -698,10 +698,6 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
|
||||
|
||||
// GEMM II: O = O + P*V
|
||||
|
||||
// clear out accumulators
|
||||
initialize_accum_regs<0>();
|
||||
initialize_accum_regs<1>();
|
||||
|
||||
// V dimension is [seqlen, headdim], stored N(headdim)-major
|
||||
load_tile_to_smem<float, MemLayout::MN_major, MemLayout::MN_major, B_COL,
|
||||
HEADDIM, threads_per_warpgroup>(
|
||||
@@ -724,10 +720,22 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
|
||||
if constexpr (DEBUG) {
|
||||
// O before PV
|
||||
if (tile_k_ == 0) {
|
||||
thread_block_copy_tile(smem_P_consume, gmem_tmp_d0,
|
||||
tid_in_warpgroup, threads_per_warpgroup,
|
||||
warpgroup_id_in_cluster);
|
||||
thread_block_copy_tile(smem_V_consume, gmem_tmp_d6,
|
||||
tid_in_warpgroup, threads_per_warpgroup,
|
||||
warpgroup_id_in_cluster);
|
||||
thread_block_copy_tile(smem_O, gmem_tmp_d2, tid_in_warpgroup,
|
||||
threads_per_warpgroup,
|
||||
warpgroup_id_in_cluster);
|
||||
} else if (tile_k_ == k_tiles - 1) {
|
||||
} else if (tile_k_ == 1) {
|
||||
thread_block_copy_tile(smem_P_consume, gmem_tmp_d1,
|
||||
tid_in_warpgroup, threads_per_warpgroup,
|
||||
warpgroup_id_in_cluster);
|
||||
thread_block_copy_tile(smem_V_consume, gmem_tmp_d7,
|
||||
tid_in_warpgroup, threads_per_warpgroup,
|
||||
warpgroup_id_in_cluster);
|
||||
thread_block_copy_tile(smem_O, gmem_tmp_d3, tid_in_warpgroup,
|
||||
threads_per_warpgroup,
|
||||
warpgroup_id_in_cluster);
|
||||
@@ -738,6 +746,10 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
|
||||
}
|
||||
|
||||
if constexpr (!DOUBLE_BUF) {
|
||||
// clear out accumulators
|
||||
initialize_accum_regs<0>();
|
||||
initialize_accum_regs<1>();
|
||||
|
||||
thread_block_gemm_single_tile<float, MemLayout::K_major,
|
||||
MemLayout::MN_major, B_ROW, HEADDIM,
|
||||
B_COL,
|
||||
@@ -768,6 +780,10 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
|
||||
float *smem_O_half0 = smem_O;
|
||||
float *smem_O_half1 = smem_O + (B_ROW / 2) * HEADDIM;
|
||||
|
||||
// clear out accumulators
|
||||
initialize_accum_regs<0>();
|
||||
initialize_accum_regs<1>();
|
||||
|
||||
// split by rows into 2 chunks
|
||||
thread_block_gemm_single_tile<float, MemLayout::K_major,
|
||||
MemLayout::MN_major, B_ROW / 2, HEADDIM,
|
||||
@@ -797,7 +813,7 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
|
||||
thread_block_copy_tile(smem_O, gmem_tmp_d4, tid_in_warpgroup,
|
||||
threads_per_warpgroup,
|
||||
warpgroup_id_in_cluster);
|
||||
} else if (tile_k_ == k_tiles - 1) {
|
||||
} else if (tile_k_ == 1) {
|
||||
thread_block_copy_tile(smem_O, gmem_tmp_d5, tid_in_warpgroup,
|
||||
threads_per_warpgroup,
|
||||
warpgroup_id_in_cluster);
|
||||
@@ -813,7 +829,7 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
|
||||
// threadblock_barrier(threadblock_id_in_cluster,
|
||||
// warps_per_threadblock_per_core);
|
||||
threadblock_barrier(3, // FIXME
|
||||
8);
|
||||
NUM_WARPS);
|
||||
}
|
||||
|
||||
asm volatile ("tile_loop_finish_%=:" :: );
|
||||
|
||||
Reference in New Issue
Block a user