From cdb8377b62af9b242ad9a55e371430c9cd283f04 Mon Sep 17 00:00:00 2001 From: Hansung Kim Date: Sun, 8 Sep 2024 16:09:06 -0700 Subject: [PATCH] flash: Do GEMM II in Gemmini; verify 1st iteration --- .../regression/flash_attention/flash_impl.hpp | 31 ++- .../flash_attention/kernel.gemmini.cpp | 252 +++++++----------- 2 files changed, 113 insertions(+), 170 deletions(-) diff --git a/tests/regression/flash_attention/flash_impl.hpp b/tests/regression/flash_attention/flash_impl.hpp index fd027553..46a62546 100644 --- a/tests/regression/flash_attention/flash_impl.hpp +++ b/tests/regression/flash_attention/flash_impl.hpp @@ -95,7 +95,7 @@ inline void thread_block_copy_rowmax(const float *src, float *dest, asm volatile("threadblock_copy_rowmax_finish_%=:" ::); } -template +template inline void thread_block_copy_tile(const float *src, float *dest, const uint32_t tid_in_threadblock, const uint32_t threads_per_threadblock, @@ -113,14 +113,18 @@ inline void thread_block_copy_tile(const float *src, float *dest, for (int row_offset = 0; row_offset < dim_row; row_offset += warps_in_threadblock) { const uint32_t row = row_offset + warp_id; - const uint32_t first_thread_offset = dim_col * row; constexpr uint32_t per_row_iter = dim_col / NUM_THREADS; - uint32_t thread_offset = first_thread_offset + tid_in_warp; #pragma GCC unroll for (int i = 0; i < per_row_iter; i++) { - dest[thread_offset] = src[thread_offset]; - thread_offset += NUM_THREADS; + const uint32_t col_offset = NUM_THREADS * i; + const uint32_t col = col_offset + tid_in_warp; + const auto [smem_row, smem_col] = + remap_to_gemmini_dma_layout(row, col); + const uint32_t smem_offset = B_COL * smem_row + smem_col; + const uint32_t gmem_offset = B_COL * row + col; + + dest[gmem_offset] = src[smem_offset]; } threadblock_barrier(threadblock_id_in_cluster, @@ -415,6 +419,7 @@ __attribute__((always_inline)) inline void thread_block_online_softmax( asm volatile("thread_block_online_softmax_finish_%=:" ::); } +template __attribute__((always_inline)) inline void thread_block_O_rescale( const float *smem_O_in, float *smem_O_out, const float *smem_O_row_scale, const uint32_t tid_in_threadblock, const uint32_t threads_per_threadblock, @@ -431,19 +436,21 @@ __attribute__((always_inline)) inline void thread_block_O_rescale( for (int row_offset = 0; row_offset < B_ROW; row_offset += warps_in_threadblock) { const uint32_t row = row_offset + warp_id; - const uint32_t first_thread_offset = B_COL * row; - constexpr uint32_t per_row_iter = B_COL / NUM_THREADS; - uint32_t thread_offset = first_thread_offset + tid_in_warp; + constexpr uint32_t per_row_iter = HEADDIM / NUM_THREADS; // Oi rescale // #pragma GCC unroll for (int i = 0; i < per_row_iter; i++) { - const float o = smem_O_in[thread_offset]; - const float scale = smem_O_row_scale[row]; - smem_O_out[thread_offset] = (o * scale); + const uint32_t col_offset = NUM_THREADS * i; + const uint32_t col = col_offset + tid_in_warp; + const auto [smem_row, smem_col] = + remap_to_gemmini_dma_layout(row, col); - thread_offset += NUM_THREADS; + const uint32_t offset = HEADDIM * smem_row + smem_col; + const float o = smem_O_in[offset]; + const float scale = smem_O_row_scale[row]; + smem_O_out[offset] = (o * scale); } } diff --git a/tests/regression/flash_attention/kernel.gemmini.cpp b/tests/regression/flash_attention/kernel.gemmini.cpp index 4a2c3133..d5c553d8 100644 --- a/tests/regression/flash_attention/kernel.gemmini.cpp +++ b/tests/regression/flash_attention/kernel.gemmini.cpp @@ -110,8 +110,10 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { smem_cursor += smem_QK_size; float *smem_S1 = smem_cursor; smem_cursor += smem_QK_size; - float *smem_P0 = smem_S0; // in-place update - float *smem_P1 = smem_S1; // in-place update + float *smem_P0 = smem_cursor; + smem_cursor += smem_QK_size; + float *smem_P1 = smem_cursor; + smem_cursor += smem_QK_size; float *smem_O0 = smem_cursor; smem_cursor += smem_O_size; float *smem_O1 = smem_cursor; @@ -135,6 +137,14 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { spad_addr_V1 + (smem_V_size * sizeof(float) / spad_addr_factor); constexpr uint32_t spad_addr_S1 = spad_addr_S0 + (smem_QK_size * sizeof(float) / spad_addr_factor); + constexpr uint32_t spad_addr_P0 = + spad_addr_S1 + (smem_QK_size * sizeof(float) / spad_addr_factor); + constexpr uint32_t spad_addr_P1 = + spad_addr_P0 + (smem_QK_size * sizeof(float) / spad_addr_factor); + constexpr uint32_t spad_addr_O0 = + spad_addr_P1 + (smem_QK_size * sizeof(float) / spad_addr_factor); + constexpr uint32_t spad_addr_O1 = + spad_addr_O0 + (smem_O_size * sizeof(float) / spad_addr_factor); // allocate rowmax/rowsum storage at the end of the sharedmem address space constexpr uint32_t smem_rowmax_size = B_ROW * ROWMAX_SETS; @@ -189,6 +199,9 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { constexpr uint32_t skips_mvout_spad = loop_matmul_skips(/*skip_lda=*/1, /*skip_ldb=*/1, /*skip_ldd=*/1, /*skip_ex=*/1, /*skip_stc=*/0); + constexpr uint32_t skips_matmul = + loop_matmul_skips(/*skip_lda=*/1, /*skip_ldb=*/1, /*skip_ldd=*/1, + /*skip_ex=*/0, /*skip_stc=*/1); if constexpr (GEMMINI_DMA) { if (tid_in_warpgroup == 0) { @@ -281,12 +294,13 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { const uint32_t k_tiles = (dim_seqlen / B_COL); for (uint32_t tile_k = 0; tile_k < k_tiles; tile_k++) { // select the correct double buffer by tile iteration - float *smem_Q = (tile_k & 1) ? smem_Q1 : smem_Q0; // FIXME - float *smem_K = (tile_k & 1) ? smem_K1 : smem_K0; // FIXME - float *smem_V = (tile_k & 1) ? smem_V1 : smem_V0; // FIXME - float *smem_S = (tile_k & 1) ? smem_S1 : smem_S0; // FIXME - float *smem_O = (tile_k & 1) ? smem_O1 : smem_O0; // FIXME - float *smem_P = smem_S; // FIXME + // FIXME do correct double buffering + float *smem_Q = (tile_k & 1) ? smem_Q1 : smem_Q0; + float *smem_K = (tile_k & 1) ? smem_K1 : smem_K0; + float *smem_V = (tile_k & 1) ? smem_V1 : smem_V0; + float *smem_S = (tile_k & 1) ? smem_S1 : smem_S0; + float *smem_P = (tile_k & 1) ? smem_P1 : smem_P0; + float *smem_O = (tile_k & 1) ? smem_O1 : smem_O0; float *smem_O_row_scale = (tile_k & 1) ? smem_O_row_scale_1 : smem_O_row_scale_0; float *smem_rowmax = (tile_k & 1) ? smem_rowmax_1 : smem_rowmax_0; @@ -298,6 +312,8 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { const auto spad_addr_K = (tile_k & 1) ? spad_addr_K1 : spad_addr_K0; const auto spad_addr_V = (tile_k & 1) ? spad_addr_V1 : spad_addr_V0; const auto spad_addr_S = (tile_k & 1) ? spad_addr_S1 : spad_addr_S0; + const auto spad_addr_P = (tile_k & 1) ? spad_addr_P1 : spad_addr_P0; + const auto spad_addr_O = spad_addr_O0; // NOTE: there's only single O tile // GEMM I: S = Q*K // @@ -320,7 +336,7 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { gemmini_fence(); gemmini_fence(); -#if 0 // TODO +#if 0 // TODO: speed up mvout to SMEM // loop_ws variant that skips configuring strides #define gemmini_loop_ws(I, J, K, pad_I, pad_J, pad_K, A, B, D, C, A_stride, B_stride, D_stride, C_stride, A_transpose, B_transpose, full_C, low_D, ex_accumulate, act, a_spad_id, b_spad_id, is_resadd) \ { \ @@ -350,7 +366,7 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { } } - // thread reconvergence + // reconverge from mmio divergence threadblock_barrier(warpgroup_id_in_cluster, warps_per_warpgroup_per_core); asm volatile("gemm_qk_finish_%=:" ::); @@ -358,13 +374,13 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { if constexpr (DEBUG) { if (warpgroup_id == 0) { if (tile_k == 0) { - thread_block_copy_tile(smem_S, gmem_tmp_d0, - tid_in_warpgroup, threads_per_warpgroup, - warpgroup_id_in_cluster); + thread_block_copy_tile( + smem_S, gmem_tmp_d0, tid_in_warpgroup, threads_per_warpgroup, + warpgroup_id_in_cluster); } else if (tile_k == 1) { - thread_block_copy_tile(smem_S, gmem_tmp_d1, - tid_in_warpgroup, threads_per_warpgroup, - warpgroup_id_in_cluster); + thread_block_copy_tile( + smem_S, gmem_tmp_d1, tid_in_warpgroup, threads_per_warpgroup, + warpgroup_id_in_cluster); } threadblock_barrier(warpgroup_id_in_cluster, @@ -408,7 +424,6 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { } } -#if 0 // data movement for K and V // // Q stays in SMEM for the entire loop @@ -463,9 +478,9 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { threadblock_barrier(global_barrier_id, warps_per_threadblock_per_core); // Oi rescale - thread_block_O_rescale(smem_O, smem_O /*in-place*/, - smem_O_row_scale, tid_in_warpgroup, - threads_per_warpgroup, warpgroup_id_in_cluster); + thread_block_O_rescale( + smem_O, smem_O /*in-place*/, smem_O_row_scale, tid_in_warpgroup, + threads_per_warpgroup, warpgroup_id_in_cluster); // rescale-to-PV-GEMM barrier threadblock_barrier(warpgroup_id_in_cluster, warps_per_warpgroup_per_core); @@ -474,19 +489,19 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { if (warpgroup_id == 0) { // O before PV if (tile_k == 0) { - thread_block_copy_tile(smem_P, gmem_tmp_d2, tid_in_warpgroup, - threads_per_warpgroup, - warpgroup_id_in_cluster); - thread_block_copy_tile(smem_O, gmem_tmp_d4, tid_in_warpgroup, - threads_per_warpgroup, - warpgroup_id_in_cluster); + thread_block_copy_tile( + smem_P, gmem_tmp_d2, tid_in_warpgroup, threads_per_warpgroup, + warpgroup_id_in_cluster); + thread_block_copy_tile( + smem_O, gmem_tmp_d4, tid_in_warpgroup, threads_per_warpgroup, + warpgroup_id_in_cluster); } else if (tile_k == 1) { - thread_block_copy_tile(smem_P, gmem_tmp_d3, tid_in_warpgroup, - threads_per_warpgroup, - warpgroup_id_in_cluster); - thread_block_copy_tile(smem_O, gmem_tmp_d5, tid_in_warpgroup, - threads_per_warpgroup, - warpgroup_id_in_cluster); + thread_block_copy_tile( + smem_P, gmem_tmp_d3, tid_in_warpgroup, threads_per_warpgroup, + warpgroup_id_in_cluster); + thread_block_copy_tile( + smem_O, gmem_tmp_d5, tid_in_warpgroup, threads_per_warpgroup, + warpgroup_id_in_cluster); } threadblock_barrier(warpgroup_id_in_cluster, @@ -498,134 +513,54 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { asm volatile("gemm_pv_start_%=:" ::); - if constexpr (!WARP_SPECIALIZED) { - // clear out accumulators before GEMM - initialize_accum_regs<0>(); - initialize_accum_regs<1>(); - - if constexpr (GEMMINI_DMA) { - thread_block_gemm_single_tile< - float, MemLayout::K_major /* P matrix is row-major */, - MemLayout::block_row_major, B_ROW, HEADDIM, B_COL, - /*leading_dim_a=*/0, /*leading_dim_b=*/0, - /*load_accum=*/true, - /*write_to_smem=*/true>( - smem_P, smem_V, smem_O /*load accum*/, smem_O, tid_in_warpgroup, - threads_per_warpgroup, warpgroups_per_cluster, - warpgroup_id_in_cluster); + if (tid_in_warpgroup == 0) { +#if 0 + if (tile_k == 0) { + gemmini_fence(); + GEMMINI_CISC_CMD_I(0); + } else if (tile_k & 1) { + gemmini_fence(); + GEMMINI_CISC_CMD_I(2); } else { - thread_block_gemm_single_tile( - smem_P, smem_V, smem_O /*load accum*/, smem_O, tid_in_warpgroup, - threads_per_warpgroup, warpgroups_per_cluster, - warpgroup_id_in_cluster); - // FIXME: wrong but fast - // thread_block_gemm_single_tile( - // smem_P, smem_V, smem_O /*load accum*/, smem_O, - // tid_in_warpgroup, threads_per_warpgroup, - // warpgroups_per_cluster, warpgroup_id_in_cluster); + gemmini_fence(); + GEMMINI_CISC_CMD_I(1); } - } else { - // when warp-specialized, there's only enough warps to do 64x32 tile - // size so we need to do 2 GEMM calls - static_assert(B_ROW / 2 == 32, - "tile size assumption for warp-specialization not met"); +#else + // do matmul + // among other things, this also configures CONFIG_BOUNDS so that the + // DMA knows the full matrix dimensions + sp_tiled_matmul_full_spad_ws( + spad_addr_P, spad_addr_V, + /*spad_D=*/0, /*spad_C=*/spad_addr_O, + /*I=*/(B_ROW / DIM), /*J=*/(HEADDIM / DIM), /*K=*/(B_COL / DIM), + /*pad_I=*/0, /*pad_J=*/0, /*pad_K=*/0, + /*a_transpose=*/0, /*b_transpose=*/0, /*full_C=*/0, /*low_D=*/0, + /*acc=*/0, /*act=*/NO_ACTIVATION, /*skips=*/skips_matmul); +#endif - // assumes smem_P is K-major - float *smem_P_half0 = smem_P; - float *smem_P_half1 = smem_P + (B_ROW / 2) * B_COL; - float *smem_O_half0 = smem_O; - float *smem_O_half1 = smem_O + (B_ROW / 2) * HEADDIM; + gemmini_fence(); + gemmini_fence(); + gemmini_fence(); + gemmini_fence(); - // clear out accumulators before GEMM - initialize_accum_regs<0>(); - initialize_accum_regs<1>(); + // mvout to SMEM + // GEMMINI_CISC_CMD_I(9); + sp_tiled_matmul_full_spad_ws( + /*spad_A=*/spad_addr_P, /*spad_B=*/spad_addr_V, + /*spad_D=*/0, /*spad_C=*/spad_addr_O, + /*I=*/(B_ROW / DIM), /*J=*/(HEADDIM / DIM), /*K=*/(B_COL/ DIM), + /*pad_I=*/0, /*pad_J=*/0, /*pad_K=*/0, + /*a_transpose=*/0, /*b_transpose=*/0, /*full_C=*/0, /*low_D=*/0, + /*acc=*/0, /*act=*/NO_ACTIVATION, /*skips=*/skips_mvout_spad); + gemmini_fence(); - // split by rows into 2 chunks - if constexpr (GEMMINI_DMA) { - if constexpr (GEMMINI_DMA_FAST) { - thread_block_gemm_single_tile( - smem_P_half0, smem_V, smem_O_half0 /*load accum*/, smem_O_half0, - tid_in_warpgroup, threads_per_warpgroup, warpgroups_per_cluster, - warpgroup_id_in_cluster); - } else { - thread_block_gemm_single_tile< - float, MemLayout::K_major /* P matrix is row-major */, - MemLayout::block_row_major, B_ROW / 2, HEADDIM, B_COL, - /*leading_dim_a=*/0, - /*leading_dim_b=*/0, - /*load_accum=*/true, - /*write_to_smem=*/true>( - smem_P_half0, smem_V, smem_O_half0 /*load accum*/, smem_O_half0, - tid_in_warpgroup, threads_per_warpgroup, warpgroups_per_cluster, - warpgroup_id_in_cluster); - } - } else { - thread_block_gemm_single_tile< - float, MemLayout::K_major, MemLayout::MN_major, B_ROW / 2, HEADDIM, - B_COL, /*leading_dim_a=*/0, /*leading_dim_b=*/0, - /*load_accum=*/true, - /*write_to_smem=*/true>( - smem_P_half0, smem_V, smem_O_half0 /*load accum*/, smem_O_half0, - tid_in_warpgroup, threads_per_warpgroup, warpgroups_per_cluster, - warpgroup_id_in_cluster); - } - - initialize_accum_regs<0>(); - initialize_accum_regs<1>(); - - if constexpr (GEMMINI_DMA) { - if constexpr (GEMMINI_DMA_FAST) { - thread_block_gemm_single_tile( - smem_P_half1, smem_V, smem_O_half1 /*load accum*/, smem_O_half1, - tid_in_warpgroup, threads_per_warpgroup, warpgroups_per_cluster, - warpgroup_id_in_cluster); - } else { - thread_block_gemm_single_tile< - float, MemLayout::K_major /* P matrix is row-major */, - MemLayout::block_row_major, B_ROW / 2, HEADDIM, B_COL, - /*leading_dim_a=*/0, - /*leading_dim_b=*/0, - /*load_accum=*/true, - /*write_to_smem=*/true>( - smem_P_half1, smem_V, smem_O_half1 /*load accum*/, smem_O_half1, - tid_in_warpgroup, threads_per_warpgroup, warpgroups_per_cluster, - warpgroup_id_in_cluster); - } - } else { - thread_block_gemm_single_tile< - float, MemLayout::K_major, MemLayout::MN_major, B_ROW / 2, HEADDIM, - B_COL, /*leading_dim_a=*/0, /*leading_dim_b=*/0, - /*load_accum=*/true, - /*write_to_smem=*/true>( - smem_P_half1, smem_V, smem_O_half1 /*load accum*/, smem_O_half1, - tid_in_warpgroup, threads_per_warpgroup, warpgroups_per_cluster, - warpgroup_id_in_cluster); + if constexpr (DEBUG) { + // for copy-out to GMEM + gemmini_fence(); } } + // reconverge from mmio divergence threadblock_barrier(warpgroup_id_in_cluster, warps_per_warpgroup_per_core); asm volatile("gemm_pv_finish_%=:" ::); @@ -634,19 +569,20 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { if (warpgroup_id == 0) { // O after PV if (tile_k == 0) { - thread_block_copy_tile(smem_O, gmem_tmp_d6, tid_in_warpgroup, - threads_per_warpgroup, - warpgroup_id_in_cluster); + thread_block_copy_tile( + smem_O, gmem_tmp_d6, tid_in_warpgroup, threads_per_warpgroup, + warpgroup_id_in_cluster); } else if (tile_k == 1) { - thread_block_copy_tile(smem_O, gmem_tmp_d7, tid_in_warpgroup, - threads_per_warpgroup, - warpgroup_id_in_cluster); + thread_block_copy_tile( + smem_O, gmem_tmp_d7, tid_in_warpgroup, threads_per_warpgroup, + warpgroup_id_in_cluster); } threadblock_barrier(warpgroup_id_in_cluster, warps_per_warpgroup_per_core); } } +#if 0 #endif }