From e02892ab7de11f22bccb26c84127814122b515be Mon Sep 17 00:00:00 2001 From: Hansung Kim Date: Sat, 7 Sep 2024 17:49:37 -0700 Subject: [PATCH] flash: Fix DMA for up to GEMM II yeah --- tests/regression/flash_attention/kernel.cpp | 194 ++++++++++++-------- 1 file changed, 122 insertions(+), 72 deletions(-) diff --git a/tests/regression/flash_attention/kernel.cpp b/tests/regression/flash_attention/kernel.cpp index 64e28939..8da963a6 100644 --- a/tests/regression/flash_attention/kernel.cpp +++ b/tests/regression/flash_attention/kernel.cpp @@ -168,16 +168,6 @@ __attribute__((always_inline)) inline void thread_block_online_softmax( const uint32_t warps_per_threadblock_per_core = warps_in_threadblock / CORES_PER_CLUSTER; - // float ft[8]; - // asm volatile("fmv.s %0, f16" : "=f"(ft[0])); - // asm volatile("fmv.s %0, f17" : "=f"(ft[1])); - // asm volatile("fmv.s %0, f18" : "=f"(ft[2])); - // asm volatile("fmv.s %0, f19" : "=f"(ft[3])); - // asm volatile("fmv.s %0, f20" : "=f"(ft[4])); - // asm volatile("fmv.s %0, f21" : "=f"(ft[5])); - // asm volatile("fmv.s %0, f22" : "=f"(ft[6])); - // asm volatile("fmv.s %0, f23" : "=f"(ft[7])); - float *smem_rowmax_this = smem_rowmax + B_ROW; #pragma GCC unroll 1 @@ -541,6 +531,7 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { uint8_t *smem_per_threadblock = reinterpret_cast( DEV_SMEM_START_ADDR); float *smem_cursor = reinterpret_cast(smem_per_threadblock); + // float *smem_cursor = reinterpret_cast(DEV_FAKE_SMEM_START_ADDR); float *smem_Q0 = smem_cursor; smem_cursor += smem_Q_size; float *smem_Q1 = smem_cursor; @@ -587,31 +578,33 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { constexpr uint32_t smem_rowmax_size = B_ROW * ROWMAX_SETS; constexpr uint32_t smem_rowsum_size = B_ROW; constexpr uint32_t smem_O_row_scale_size = B_ROW; - smem_cursor = reinterpret_cast(SMEM_ADDR_END); + // smem_cursor = reinterpret_cast(DEV_FAKE_SMEM_START_ADDR + SMEM_SIZE); + smem_cursor = reinterpret_cast(0xff038000); - smem_cursor -= smem_rowmax_size; float *smem_rowmax_0 = smem_cursor; - smem_cursor -= smem_rowmax_size; + smem_cursor += smem_rowmax_size; float *smem_rowmax_1 = smem_cursor; - smem_cursor -= smem_rowsum_size; + smem_cursor += smem_rowmax_size; float *smem_rowsum_0 = smem_cursor; - smem_cursor -= smem_rowsum_size; + smem_cursor += smem_rowsum_size; float *smem_rowsum_1 = smem_cursor; - smem_cursor -= smem_O_row_scale_size; + smem_cursor += smem_rowsum_size; float *smem_O_row_scale_0 = smem_cursor; - smem_cursor -= smem_O_row_scale_size; + smem_cursor += smem_O_row_scale_size; float *smem_O_row_scale_1 = smem_cursor; + smem_cursor += smem_O_row_scale_size; // sharedmem "scratchpad" area to put temporary data, e.g. for tree reduction // in rowsum // NOTE: out-of bounds is not checked // TODO: reduce this from B_ROW to NUM_WARPS constexpr uint32_t smem_scratchpad_size = - threads_per_warpgroup * 2 /*arbitrary slack*/; - smem_cursor -= smem_scratchpad_size; + B_ROW * NUM_THREADS * 2 /*arbitrary slack*/; + // threads_per_warpgroup * 2 /*arbitrary slack*/; float *smem_scratchpad_0 = smem_cursor; - smem_cursor -= smem_scratchpad_size; + smem_cursor += smem_scratchpad_size; float *smem_scratchpad_1 = smem_cursor; + smem_cursor += smem_scratchpad_size; // select the correct buffer by warpgroup float *smem_Q = (warpgroup_id % 2) ? smem_Q1 : smem_Q0; @@ -628,19 +621,24 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { (warpgroup_id % 2) ? smem_scratchpad_1 : smem_scratchpad_0; // initialize rowmax/rowsum values in sharedmem - // thread_block_init_sharedmem(tid_in_warpgroup, threads_per_warpgroup, smem_O, - // smem_rowmax, smem_rowsum, smem_O_row_scale); + thread_block_init_sharedmem(tid_in_warpgroup, threads_per_warpgroup, smem_O, + smem_rowmax, smem_rowsum, smem_O_row_scale); constexpr uint32_t global_barrier_id = NUM_WARPS - 1; // arbitrary // delay warpgroup 0 by 1 iteration to do ping-pong scheduling - // if (warpgroup_id == 1) { - // threadblock_barrier(global_barrier_id, warps_per_threadblock_per_core); - // } + if (warpgroup_id == 1) { + threadblock_barrier(global_barrier_id, warps_per_threadblock_per_core); + } static_assert(!GEMMINI_DMA || Q_IS_K_MAJOR, "DMA code assumes Q matrix is stored K-major"); + // skip everything except DMA in the loop FSM + constexpr uint32_t skips = + loop_matmul_skips(/*skip_lda=*/0, /*skip_ldb=*/0, /*skip_ldd=*/1, + /*skip_ex=*/1, /*skip_stc=*/1); + if constexpr (GEMMINI_DMA) { if (tid_in_warpgroup == 0) { gemmini_extended_config_ex(WEIGHT_STATIONARY, 0, 0, 1, 0, 0); @@ -680,8 +678,6 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { ROCC_INSTRUCTION_RS1_RS2(XCUSTOM_ACC, (uint64_t)(gmem_Q), (uint64_t)(gmem_K), k_LOOP_WS_CONFIG_ADDRS_AB) // configure address strides for the DMA - // GEMMINI_CISC_CMD_R((B_COL << 16) | (HEADDIM << 8) | - // 8 /*k_LOOP_WS_CONFIG_STRIDES_AB*/); GEMMINI_CISC_CMD_R((dim_seqlen << 16) | (HEADDIM << 8) | 8 /*k_LOOP_WS_CONFIG_STRIDES_AB*/); gemmini_fence(); @@ -691,11 +687,8 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { GEMMINI_CISC_CMD_I(9); gemmini_fence(); #else - // skip everything except DMA in the loop FSM - constexpr uint32_t skips = - loop_matmul_skips(/*skip_lda=*/0, /*skip_ldb=*/0, /*skip_ldd=*/1, - /*skip_ex=*/1, /*skip_stc=*/1); - + // do DMA + // // among other things, this also configures CONFIG_BOUNDS so that the // DMA knows the full matrix dimensions sp_tiled_matmul_full_spad_ws( @@ -707,6 +700,15 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { /*acc=*/0, /*act=*/NO_ACTIVATION, /*skips=*/skips); gemmini_fence(); #endif + + // re-configure DMA for K and V load that will later happen in the loop + // GMEM addr stride for K + gemmini_extended3_config_ld(dim_seqlen * sizeof(elem_t), MVIN_SCALE_IDENTITY, + false, 0); + // GMEM addr stride for V + gemmini_extended3_config_ld(HEADDIM * sizeof(elem_t), MVIN_SCALE_IDENTITY, + false, 1); + gemmini_fence(); } asm volatile("dma_move_end_%=:" ::); @@ -767,7 +769,16 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { initialize_accum_regs<0>(); initialize_accum_regs<1>(); - if constexpr (Q_IS_K_MAJOR) { + if constexpr (GEMMINI_DMA) { + thread_block_gemm_single_tile< + float, MemLayout::block_row_major, MemLayout::block_row_major, + B_ROW, B_COL, HEADDIM, /*leading_dim_a=*/0, /*leading_dim_b=*/0, + /*load_accum=*/false, + /*write_to_smem=*/true>( + smem_Q, smem_K, nullptr /*ignore accum*/, smem_S, + tid_in_warpgroup, threads_per_warpgroup, warpgroups_per_cluster, + warpgroup_id_in_cluster); + } else if constexpr (Q_IS_K_MAJOR) { thread_block_gemm_single_tile< float, MemLayout::K_major, MemLayout::MN_major, B_ROW, B_COL, HEADDIM, /*leading_dim_a=*/0, /*leading_dim_b=*/0, @@ -803,6 +814,7 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { initialize_accum_regs<1>(); // split by rows into 2 chunks + // TODO: GEMMINI_DMA if constexpr (Q_IS_K_MAJOR) { thread_block_gemm_single_tile< float, MemLayout::K_major, MemLayout::MN_major, B_ROW / 2, B_COL, @@ -826,6 +838,7 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { initialize_accum_regs<0>(); initialize_accum_regs<1>(); + // TODO: GEMMINI_DMA if constexpr (Q_IS_K_MAJOR) { thread_block_gemm_single_tile< float, MemLayout::K_major, MemLayout::MN_major, B_ROW / 2, B_COL, @@ -877,7 +890,6 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { // inter-warpgroup barrier before online softmax threadblock_barrier(global_barrier_id, warps_per_threadblock_per_core); -#if 0 // Online softmax // thread_block_online_softmax(smem_S, smem_P, tid_in_warpgroup, @@ -885,22 +897,52 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { smem_scratchpad, smem_rowmax, smem_rowsum, smem_O_row_scale); + // FIXME: unnecessary? + threadblock_barrier(warpgroup_id_in_cluster, warps_per_warpgroup_per_core); + // data movement for K and V // // Q stays in SMEM for the entire loop - // - // load K for the next iteration - load_tile_to_smem( - dim_seqlen, tile_k + 1, 0 /* dim_k == headdim */, gmem_K, smem_K, - tid_in_warpgroup); + if constexpr (GEMMINI_DMA) { + if (tid_in_threadblock == 0) { + // configure GMEM addresses for K and V tiles + // load K for the next iteration + const float *gmem_K_tile = gmem_K + (B_COL * (tile_k + 1)); + // load V for the current iteration + const float *gmem_V_tile = gmem_V + (HEADDIM * B_COL * tile_k); + ROCC_INSTRUCTION_RS1_RS2(XCUSTOM_ACC, (uint64_t)(gmem_K_tile), + (uint64_t)(gmem_V_tile), + k_LOOP_WS_CONFIG_ADDRS_AB) + // configure address strides for the DMA + // FIXME: unnecessary? + GEMMINI_CISC_CMD_R((HEADDIM /*V*/ << 16) | (dim_seqlen /*KT*/ << 8) | + 8 /*k_LOOP_WS_CONFIG_STRIDES_AB*/); + gemmini_fence(); - // load V for the current iteration - // V dimension is [seqlen, headdim], stored N(headdim)-major - load_tile_to_smem( - HEADDIM, 0 /* full N-dimension */, tile_k, gmem_V, smem_V, - tid_in_warpgroup); + // do DMA + sp_tiled_matmul_full_spad_ws( + spad_addr_K0, spad_addr_V0, + /*spad_D=*/0, /*spad_C=*/spad_addr_S0, + /*I=*/(HEADDIM / DIM), /*J=*/(HEADDIM / DIM), /*K=*/(B_COL / DIM), + /*pad_I=*/0, /*pad_J=*/0, /*pad_K=*/0, + /*a_transpose=*/0, /*b_transpose=*/0, /*full_C=*/0, /*low_D=*/0, + /*acc=*/0, /*act=*/NO_ACTIVATION, /*skips=*/skips); + gemmini_fence(); + } + } else { + // load K for the next iteration + load_tile_to_smem( + dim_seqlen, tile_k + 1, 0 /* dim_k == headdim */, gmem_K, smem_K, + tid_in_warpgroup); + + // load V for the current iteration + // V dimension is [seqlen, headdim], stored N(headdim)-major + load_tile_to_smem( + HEADDIM, 0 /* full N-dimension */, tile_k, gmem_V, smem_V, + tid_in_warpgroup); + } // protect write to SMEM threadblock_barrier(warpgroup_id_in_cluster, warps_per_warpgroup_per_core); @@ -970,25 +1012,38 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { initialize_accum_regs<0>(); initialize_accum_regs<1>(); - thread_block_gemm_single_tile( - smem_P, smem_V, smem_O /*load accum*/, smem_O, tid_in_warpgroup, - threads_per_warpgroup, warpgroups_per_cluster, - warpgroup_id_in_cluster); - - // FIXME: wrong but fast - // thread_block_gemm_single_tile( - // smem_P, smem_V, smem_O /*load accum*/, smem_O, - // tid_in_warpgroup, threads_per_warpgroup, - // warpgroups_per_cluster, warpgroup_id_in_cluster); + if constexpr (GEMMINI_DMA) { + thread_block_gemm_single_tile( + smem_P, smem_V, smem_O /*load accum*/, smem_O, tid_in_warpgroup, + threads_per_warpgroup, warpgroups_per_cluster, + warpgroup_id_in_cluster); + } else { + thread_block_gemm_single_tile( + smem_P, smem_V, smem_O /*load accum*/, smem_O, tid_in_warpgroup, + threads_per_warpgroup, warpgroups_per_cluster, + warpgroup_id_in_cluster); + // FIXME: wrong but fast + // thread_block_gemm_single_tile( + // smem_P, smem_V, smem_O /*load accum*/, smem_O, + // tid_in_warpgroup, threads_per_warpgroup, + // warpgroups_per_cluster, warpgroup_id_in_cluster); + } } else { // when warp-specialized, there's only enough warps to do 64x32 tile // size so we need to do 2 GEMM calls @@ -1006,6 +1061,7 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { initialize_accum_regs<1>(); // split by rows into 2 chunks + // TODO: GEMMINI_DMA thread_block_gemm_single_tile< float, MemLayout::K_major, MemLayout::MN_major, B_ROW / 2, HEADDIM, B_COL, /*leading_dim_a=*/0, /*leading_dim_b=*/0, @@ -1047,13 +1103,7 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { warps_per_warpgroup_per_core); } } - - tile_iter_end: - // synchronize progress of two warpgroups - // threadblock_barrier(threadblock_id_in_cluster, - // warps_per_threadblock_per_core); - // threadblock_barrier(3, // FIXME - // NUM_WARPS); +#if 0 #endif }