From 9f067acdb9bc3fd088aa361dad0201a85a4c0f43 Mon Sep 17 00:00:00 2001 From: Hansung Kim Date: Thu, 5 Sep 2024 19:55:36 -0700 Subject: [PATCH 01/50] sgemm_impl: Remove #if 0, FP_SIZE 16 --- tests/regression/sgemm_tcore/sgemm_impl.hpp | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/tests/regression/sgemm_tcore/sgemm_impl.hpp b/tests/regression/sgemm_tcore/sgemm_impl.hpp index 2014b507..ac07a666 100644 --- a/tests/regression/sgemm_tcore/sgemm_impl.hpp +++ b/tests/regression/sgemm_tcore/sgemm_impl.hpp @@ -6,7 +6,7 @@ #include "include/gemmini.h" #include "gemmini_mmio.h" -#define FP_SIZE 32 +#define FP_SIZE 16 // "fake" fp16 type that only has the correct data width. using float16_t = uint16_t; @@ -1038,7 +1038,6 @@ inline void thread_block_gemm(const T *A, const T *B, float *C, warps_per_threadblock_per_core); #endif -#if 0 // consumer code: SMEM->RF and compute // ---------------------------------------------------------------------- // @perf: this loop spills to stack a lot because of all the flws in @@ -1087,7 +1086,6 @@ inline void thread_block_gemm(const T *A, const T *B, float *C, threadblock_barrier(threadblock_id_in_cluster, warps_per_threadblock_per_core); -#endif } if constexpr (write_to_gmem) { From d2f086344db2f1e72e821e0126690ef2f1f4d522 Mon Sep 17 00:00:00 2001 From: Hansung Kim Date: Sat, 7 Sep 2024 15:48:37 -0700 Subject: [PATCH 02/50] flash: Fix DMA addr stride, stop at S=Q*K --- tests/regression/flash_attention/kernel.cpp | 192 ++++++++++++-------- tests/regression/sgemm_tcore/sgemm_impl.hpp | 14 +- 2 files changed, 127 insertions(+), 79 deletions(-) diff --git a/tests/regression/flash_attention/kernel.cpp b/tests/regression/flash_attention/kernel.cpp index 78fc8969..64e28939 100644 --- a/tests/regression/flash_attention/kernel.cpp +++ b/tests/regression/flash_attention/kernel.cpp @@ -8,10 +8,9 @@ #include "include/gemmini.h" #include "gemmini_mmio.h" -#define B_ROW BM -#define B_COL BN -// FIXME -#define HEADDIM B_COL +#define 
B_ROW 64 +#define B_COL 64 +#define HEADDIM 64 constexpr uint32_t ROWMAX_SETS = 3; constexpr bool DEBUG = true; @@ -19,6 +18,8 @@ constexpr bool WARP_SPECIALIZED = false; constexpr uint32_t DEV_FAKE_SMEM_START_ADDR = 0xf0000000; +constexpr bool Q_IS_K_MAJOR = true; + // temporary safety stop for wrong configs static_assert(NUM_CORES == 4); static_assert(NUM_THREADS == 8); @@ -99,6 +100,7 @@ inline void thread_block_copy_rowmax(const float *src, float *dest, asm volatile("threadblock_copy_rowmax_finish_%=:" ::); } +template inline void thread_block_copy_tile(const float *src, float *dest, const uint32_t tid_in_threadblock, const uint32_t threads_per_threadblock, @@ -113,12 +115,12 @@ inline void thread_block_copy_tile(const float *src, float *dest, // FIXME: dedup this pattern #pragma GCC unroll 1 - for (int row_offset = 0; row_offset < B_ROW; + for (int row_offset = 0; row_offset < dim_row; row_offset += warps_in_threadblock) { const uint32_t row = row_offset + warp_id; - const uint32_t first_thread_offset = B_COL * row; + const uint32_t first_thread_offset = dim_col * row; - constexpr uint32_t per_row_iter = B_COL / NUM_THREADS; + constexpr uint32_t per_row_iter = dim_col / NUM_THREADS; uint32_t thread_offset = first_thread_offset + tid_in_warp; #pragma GCC unroll for (int i = 0; i < per_row_iter; i++) { @@ -533,12 +535,12 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { constexpr uint32_t smem_QK_size = B_ROW * B_COL; constexpr uint32_t smem_V_size = B_COL * HEADDIM; constexpr uint32_t smem_O_size = B_COL * HEADDIM; + static_assert( + threads_per_threadblock == NUM_WARPS * NUM_THREADS * CORES_PER_CLUSTER, + "flashattention kernel assumes 1 threadblock occupancy per cluster"); uint8_t *smem_per_threadblock = reinterpret_cast( - DEV_SMEM_START_ADDR + - sizeof(float_type) * - (smem_QK_size + smem_V_size + smem_O_size) * - threadblock_id_in_cluster); - float *smem_cursor = reinterpret_cast(DEV_FAKE_SMEM_START_ADDR); + DEV_SMEM_START_ADDR); + float 
*smem_cursor = reinterpret_cast(smem_per_threadblock); float *smem_Q0 = smem_cursor; smem_cursor += smem_Q_size; float *smem_Q1 = smem_cursor; @@ -563,6 +565,7 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { smem_cursor += smem_O_size; // NOTE: this has to match with smem_* + static_assert(sizeof(elem_t) == sizeof(float)); constexpr uint32_t spad_addr_factor = DIM * sizeof(elem_t); constexpr uint32_t spad_addr_Q0 = 0; constexpr uint32_t spad_addr_Q1 = @@ -635,15 +638,18 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { // threadblock_barrier(global_barrier_id, warps_per_threadblock_per_core); // } + static_assert(!GEMMINI_DMA || Q_IS_K_MAJOR, + "DMA code assumes Q matrix is stored K-major"); + if constexpr (GEMMINI_DMA) { if (tid_in_warpgroup == 0) { gemmini_extended_config_ex(WEIGHT_STATIONARY, 0, 0, 1, 0, 0); - // configure DMA for Q tile + // configure DMA for the full Q matrix gemmini_extended3_config_ld(HEADDIM * sizeof(elem_t), MVIN_SCALE_IDENTITY, false, 0); - // configure DMA for K tile - gemmini_extended3_config_ld(B_COL * sizeof(elem_t), MVIN_SCALE_IDENTITY, + // configure DMA for the full K matrix + gemmini_extended3_config_ld(dim_seqlen * sizeof(elem_t), MVIN_SCALE_IDENTITY, false, 1); // configure DMA for Q*K store gemmini_extended_config_st(B_COL * sizeof(elem_t), 0, @@ -652,12 +658,12 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { } } - // NOTE about barriers: placing barriers around thread-divergent branches may - // cause bugs, since the core doesn't check tmask for barriers. The compiler - // might decide to replicate vx_bar into both paths of a conditional branch, - // which will get evaluated twice along the split/join process and result in - // a different number of calls w.r.t other non-divergent warps and therefore - // stalls. + // NOTE about barriers: Placing barriers around thread-divergent branches may + // cause bugs, because the Vortex core doesn't check for tmask for barriers. 
+ // The compiler might decide to duplicate vx_bar into both paths of a + // conditional branch, which will get evaluated twice because of the way + // branches are handled in SIMT; this might result in stalls especially when + // other warps behave differently on the branch condition. // threadblock_barrier(warpgroup_id_in_cluster, warps_per_warpgroup_per_core); // move Q and K into SMEM before the loop starts @@ -674,13 +680,15 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { ROCC_INSTRUCTION_RS1_RS2(XCUSTOM_ACC, (uint64_t)(gmem_Q), (uint64_t)(gmem_K), k_LOOP_WS_CONFIG_ADDRS_AB) // configure address strides for the DMA - GEMMINI_CISC_CMD_R((B_COL << 16) | (HEADDIM << 8) | + // GEMMINI_CISC_CMD_R((B_COL << 16) | (HEADDIM << 8) | + // 8 /*k_LOOP_WS_CONFIG_STRIDES_AB*/); + GEMMINI_CISC_CMD_R((dim_seqlen << 16) | (HEADDIM << 8) | 8 /*k_LOOP_WS_CONFIG_STRIDES_AB*/); gemmini_fence(); -#define GEMMINI_DMA_CISC +// #define GEMMINI_DMA_CISC #ifdef GEMMINI_DMA_CISC - GEMMINI_CISC_CMD_I(10); + GEMMINI_CISC_CMD_I(9); gemmini_fence(); #else // skip everything except DMA in the loop FSM @@ -693,7 +701,7 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { sp_tiled_matmul_full_spad_ws( spad_addr_Q0, spad_addr_K0, /*spad_D=*/0, /*spad_C=*/spad_addr_S0, - /*I=*/(B_ROW / DIM), /*J=*/(HEADDIM / DIM), /*K=*/(B_COL / DIM), + /*I=*/(B_ROW / DIM), /*J=*/(B_COL / DIM), /*K=*/(HEADDIM / DIM), /*pad_I=*/0, /*pad_J=*/0, /*pad_K=*/0, /*a_transpose=*/0, /*b_transpose=*/0, /*full_C=*/0, /*low_D=*/0, /*acc=*/0, /*act=*/NO_ACTIVATION, /*skips=*/skips); @@ -704,10 +712,17 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { asm volatile("dma_move_end_%=:" ::); } else { // load Q; this stays in SMEM for the entire loop - load_tile_to_smem( - dim_seqlen, warpgroup_id, 0 /* dim_k == headdim */, gmem_Q, smem_Q, - tid_in_warpgroup); + if constexpr (Q_IS_K_MAJOR) { + load_tile_to_smem( + HEADDIM, warpgroup_id, 0 /* dim_k == headdim */, gmem_Q, smem_Q, + 
tid_in_warpgroup); + } else { + load_tile_to_smem( + dim_seqlen, warpgroup_id, 0 /* dim_k == headdim */, gmem_Q, smem_Q, + tid_in_warpgroup); + } // load K load_tile_to_smem(smem_Q0, gmem_tmp_d0, tid_in_warpgroup, + // threads_per_warpgroup, warpgroup_id_in_cluster); + // thread_block_copy_tile(smem_K0, gmem_tmp_d1, tid_in_warpgroup, + // threads_per_warpgroup, warpgroup_id_in_cluster); - // threadblock_barrier(warpgroup_id_in_cluster, warps_per_warpgroup_per_core); - } + // threadblock_barrier(warpgroup_id_in_cluster, warps_per_warpgroup_per_core); + // } -#if 0 asm volatile ("tile_loop_start_%=:" :: ); // "inner loop" along the columns of K^T @@ -751,25 +767,34 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { initialize_accum_regs<0>(); initialize_accum_regs<1>(); - thread_block_gemm_single_tile< - float, MemLayout::MN_major, MemLayout::MN_major, B_ROW, B_COL, - HEADDIM, /*leading_dim_a=*/0, /*leading_dim_b=*/0, - /*load_accum=*/false, - /*write_to_smem=*/true>( - smem_Q, smem_K, nullptr /*ignore accum*/, smem_S, tid_in_warpgroup, - threads_per_warpgroup, warpgroups_per_cluster, - warpgroup_id_in_cluster); + if constexpr (Q_IS_K_MAJOR) { + thread_block_gemm_single_tile< + float, MemLayout::K_major, MemLayout::MN_major, B_ROW, B_COL, + HEADDIM, /*leading_dim_a=*/0, /*leading_dim_b=*/0, + /*load_accum=*/false, + /*write_to_smem=*/true>( + smem_Q, smem_K, nullptr /*ignore accum*/, smem_S, + tid_in_warpgroup, threads_per_warpgroup, warpgroups_per_cluster, + warpgroup_id_in_cluster); + } else { + thread_block_gemm_single_tile< + float, MemLayout::MN_major, MemLayout::MN_major, B_ROW, B_COL, + HEADDIM, /*leading_dim_a=*/0, /*leading_dim_b=*/0, + /*load_accum=*/false, + /*write_to_smem=*/true>( + smem_Q, smem_K, nullptr /*ignore accum*/, smem_S, + tid_in_warpgroup, threads_per_warpgroup, warpgroups_per_cluster, + warpgroup_id_in_cluster); + } } else { // when warp-specialized, there's only enough warps to do 64x32 tile // size so we need to do 2 GEMM 
calls static_assert(B_ROW / 2 == 32, "tile size assumption for warp-specialization not met"); - // assumes smem_Q is K-major - // FIXME: fix this to MN-major float *smem_Q_half0 = smem_Q; - float *smem_Q_half1 = smem_Q + (B_ROW / 2); // MN-major - // float *smem_Q_half1 = smem_Q + (B_ROW / 2) * HEADDIM; // K-major + float *smem_Q_half1 = Q_IS_K_MAJOR ? smem_Q + (B_ROW / 2) * HEADDIM + : smem_Q + (B_ROW / 2); float *smem_S_half0 = smem_S; float *smem_S_half1 = smem_S + (B_ROW / 2) * B_COL; @@ -778,26 +803,48 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { initialize_accum_regs<1>(); // split by rows into 2 chunks - thread_block_gemm_single_tile< - float, MemLayout::MN_major, MemLayout::MN_major, B_ROW / 2, - B_COL, HEADDIM, /*leading_dim_a=*/B_ROW, /*leading_dim_b=*/0, - /*load_accum=*/false, - /*write_to_smem=*/true>( - smem_Q_half0, smem_K, nullptr /*ignore accum*/, smem_S_half0, - tid_in_warpgroup, threads_per_warpgroup, warpgroups_per_cluster, - warpgroup_id_in_cluster); + if constexpr (Q_IS_K_MAJOR) { + thread_block_gemm_single_tile< + float, MemLayout::K_major, MemLayout::MN_major, B_ROW / 2, B_COL, + HEADDIM, /*leading_dim_a=*/0, /*leading_dim_b=*/0, + /*load_accum=*/false, + /*write_to_smem=*/true>( + smem_Q_half0, smem_K, nullptr /*ignore accum*/, smem_S_half0, + tid_in_warpgroup, threads_per_warpgroup, warpgroups_per_cluster, + warpgroup_id_in_cluster); + } else { + thread_block_gemm_single_tile< + float, MemLayout::MN_major, MemLayout::MN_major, B_ROW / 2, B_COL, + HEADDIM, /*leading_dim_a=*/B_ROW, /*leading_dim_b=*/0, + /*load_accum=*/false, + /*write_to_smem=*/true>( + smem_Q_half0, smem_K, nullptr /*ignore accum*/, smem_S_half0, + tid_in_warpgroup, threads_per_warpgroup, warpgroups_per_cluster, + warpgroup_id_in_cluster); + } initialize_accum_regs<0>(); initialize_accum_regs<1>(); - thread_block_gemm_single_tile< - float, MemLayout::MN_major, MemLayout::MN_major, B_ROW / 2, - B_COL, HEADDIM, /*leading_dim_a=*/B_ROW, 
/*leading_dim_b=*/0, - /*load_accum=*/false, - /*write_to_smem=*/true>( - smem_Q_half1, smem_K, nullptr /*ignore accum*/, smem_S_half1, - tid_in_warpgroup, threads_per_warpgroup, warpgroups_per_cluster, - warpgroup_id_in_cluster); + if constexpr (Q_IS_K_MAJOR) { + thread_block_gemm_single_tile< + float, MemLayout::K_major, MemLayout::MN_major, B_ROW / 2, B_COL, + HEADDIM, /*leading_dim_a=*/0, /*leading_dim_b=*/0, + /*load_accum=*/false, + /*write_to_smem=*/true>( + smem_Q_half1, smem_K, nullptr /*ignore accum*/, smem_S_half1, + tid_in_warpgroup, threads_per_warpgroup, warpgroups_per_cluster, + warpgroup_id_in_cluster); + } else { + thread_block_gemm_single_tile< + float, MemLayout::MN_major, MemLayout::MN_major, B_ROW / 2, B_COL, + HEADDIM, /*leading_dim_a=*/B_ROW, /*leading_dim_b=*/0, + /*load_accum=*/false, + /*write_to_smem=*/true>( + smem_Q_half1, smem_K, nullptr /*ignore accum*/, smem_S_half1, + tid_in_warpgroup, threads_per_warpgroup, warpgroups_per_cluster, + warpgroup_id_in_cluster); + } } } else { // load Q*K @@ -813,11 +860,11 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { if constexpr (DEBUG) { if (warpgroup_id == 0) { if (tile_k == 0) { - thread_block_copy_tile(smem_S, gmem_tmp_d0, + thread_block_copy_tile(smem_S, gmem_tmp_d0, tid_in_warpgroup, threads_per_warpgroup, warpgroup_id_in_cluster); } else if (tile_k == 1) { - thread_block_copy_tile(smem_S, gmem_tmp_d1, + thread_block_copy_tile(smem_S, gmem_tmp_d1, tid_in_warpgroup, threads_per_warpgroup, warpgroup_id_in_cluster); } @@ -830,6 +877,7 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { // inter-warpgroup barrier before online softmax threadblock_barrier(global_barrier_id, warps_per_threadblock_per_core); +#if 0 // Online softmax // thread_block_online_softmax(smem_S, smem_P, tid_in_warpgroup, @@ -897,17 +945,17 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { if (warpgroup_id == 0) { // O before PV if (tile_k == 0) { - 
thread_block_copy_tile(smem_P, gmem_tmp_d2, tid_in_warpgroup, + thread_block_copy_tile(smem_P, gmem_tmp_d2, tid_in_warpgroup, threads_per_warpgroup, warpgroup_id_in_cluster); - thread_block_copy_tile(smem_O, gmem_tmp_d4, tid_in_warpgroup, + thread_block_copy_tile(smem_O, gmem_tmp_d4, tid_in_warpgroup, threads_per_warpgroup, warpgroup_id_in_cluster); } else if (tile_k == 1) { - thread_block_copy_tile(smem_P, gmem_tmp_d3, tid_in_warpgroup, + thread_block_copy_tile(smem_P, gmem_tmp_d3, tid_in_warpgroup, threads_per_warpgroup, warpgroup_id_in_cluster); - thread_block_copy_tile(smem_O, gmem_tmp_d5, tid_in_warpgroup, + thread_block_copy_tile(smem_O, gmem_tmp_d5, tid_in_warpgroup, threads_per_warpgroup, warpgroup_id_in_cluster); } @@ -986,11 +1034,11 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { if (warpgroup_id == 0) { // O after PV if (tile_k == 0) { - thread_block_copy_tile(smem_O, gmem_tmp_d6, tid_in_warpgroup, + thread_block_copy_tile(smem_O, gmem_tmp_d6, tid_in_warpgroup, threads_per_warpgroup, warpgroup_id_in_cluster); } else if (tile_k == 1) { - thread_block_copy_tile(smem_O, gmem_tmp_d7, tid_in_warpgroup, + thread_block_copy_tile(smem_O, gmem_tmp_d7, tid_in_warpgroup, threads_per_warpgroup, warpgroup_id_in_cluster); } @@ -1006,6 +1054,7 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { // warps_per_threadblock_per_core); // threadblock_barrier(3, // FIXME // NUM_WARPS); +#endif } asm volatile ("tile_loop_finish_%=:" :: ); @@ -1015,7 +1064,6 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { if (warpgroup_id == 0) { threadblock_barrier(global_barrier_id, warps_per_threadblock_per_core); } -#endif } int main() { diff --git a/tests/regression/sgemm_tcore/sgemm_impl.hpp b/tests/regression/sgemm_tcore/sgemm_impl.hpp index ac07a666..db8df789 100644 --- a/tests/regression/sgemm_tcore/sgemm_impl.hpp +++ b/tests/regression/sgemm_tcore/sgemm_impl.hpp @@ -6,7 +6,7 @@ #include "include/gemmini.h" #include "gemmini_mmio.h" 
-#define FP_SIZE 16 +#define FP_SIZE 32 // "fake" fp16 type that only has the correct data width. using float16_t = uint16_t; @@ -29,7 +29,7 @@ using float_type = float16_t; // (BM*BN) / (TM*TN) == threadblock size >= NT * CORES_PER_CLUSTER // * Combining BM * BK >= (BM*BN) / (TM*TN) == threadblock yields // BM <= BK*TM*TN -#define BM 128 +#define BM 64 #define BN 64 #if (FP_SIZE == 32) #define BK 64 @@ -62,18 +62,18 @@ static_assert(WMITER * WNITER * TCM * TCN * NUM_WARPS * CORES_PER_CLUSTER == #define BK_LOOP 1 // Whether to transpose smem A tile at GMEM->SMEM (produce), or SMEM->RF // (consume). This is because the tensor core expects the A tile to be stored -// in column-major order in SMEM, whereas it will be ultimately stored in -// row-major in the RF. +// in column-major order in SMEM, so a transpose is necessary if A was stored +// row-major in GMEM. // // For correctness, only one of either should be 1. E.g., PRODUCE 1 CONSUME 0 // generates the NN kernel where both A and B are stored row-major in GMEM. // To model the case where the A matrix is already stored column-major in GMEM, // set both to 0. #define TRANSPOSE_AT_PRODUCE 0 -#define TRANSPOSE_AT_CONSUME 0 +#define TRANSPOSE_AT_CONSUME 1 -#define GEMMINI_DMA 0 -#define GEMMINI_DMA_MN_MAJOR 1 +#define GEMMINI_DMA 1 +#define GEMMINI_DMA_MN_MAJOR 0 #if SMEM_SIZE == 0x4000 #define SMEM_ADDR_Q0 ((float * const) 0xff000000) #define SMEM_ADDR_Q1 ((float * const) 0xff001000) From ed9bf6f73e6d7339d5495503a7969e672b70b617 Mon Sep 17 00:00:00 2001 From: Hansung Kim Date: Sat, 7 Sep 2024 15:49:19 -0700 Subject: [PATCH 03/50] common.mk: Switch to -Os to prevent branch code duplication Prevents erroneous stalls at vx_bar. 
See comment in kernel.cpp --- tests/regression/common.mk | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/regression/common.mk b/tests/regression/common.mk index 50efc499..f000dcf6 100644 --- a/tests/regression/common.mk +++ b/tests/regression/common.mk @@ -48,7 +48,7 @@ VX_CP = $(LLVM_VORTEX)/bin/llvm-objcopy #VX_DP = $(RISCV_TOOLCHAIN_PATH)/bin/$(RISCV_PREFIX)-objdump #VX_CP = $(RISCV_TOOLCHAIN_PATH)/bin/$(RISCV_PREFIX)-objcopy -VX_CFLAGS += -v -O3 -std=c++17 +VX_CFLAGS += -v -Os -std=c++17 VX_CFLAGS += -mcmodel=medany -fno-rtti -fno-exceptions -nostartfiles -fdata-sections -ffunction-sections # comment out below for regression/basic, which uses GCC that doesn't # understand these flags From a967c262b144e98aed84f4de866b670f41e43f7b Mon Sep 17 00:00:00 2001 From: Hansung Kim Date: Sat, 7 Sep 2024 16:38:22 -0700 Subject: [PATCH 04/50] sgemm_impl: Add new block-row-major layout for DMA --- tests/regression/sgemm_tcore/sgemm_impl.hpp | 45 ++++++++++++--------- 1 file changed, 27 insertions(+), 18 deletions(-) diff --git a/tests/regression/sgemm_tcore/sgemm_impl.hpp b/tests/regression/sgemm_tcore/sgemm_impl.hpp index db8df789..e563f23c 100644 --- a/tests/regression/sgemm_tcore/sgemm_impl.hpp +++ b/tests/regression/sgemm_tcore/sgemm_impl.hpp @@ -70,10 +70,10 @@ static_assert(WMITER * WNITER * TCM * TCN * NUM_WARPS * CORES_PER_CLUSTER == // To model the case where the A matrix is already stored column-major in GMEM, // set both to 0. 
#define TRANSPOSE_AT_PRODUCE 0 -#define TRANSPOSE_AT_CONSUME 1 +#define TRANSPOSE_AT_CONSUME 0 #define GEMMINI_DMA 1 -#define GEMMINI_DMA_MN_MAJOR 0 +#define GEMMINI_DMA_FLEXIBLE_LAYOUT 0 #if SMEM_SIZE == 0x4000 #define SMEM_ADDR_Q0 ((float * const) 0xff000000) #define SMEM_ADDR_Q1 ((float * const) 0xff001000) @@ -101,6 +101,7 @@ static_assert(WMITER * WNITER * TCM * TCN * NUM_WARPS * CORES_PER_CLUSTER == enum class MemLayout { MN_major, K_major, + block_row_major, // Gemmini DMA }; inline constexpr void map_operand_32lanes(const int tid, int &row, int &col) { @@ -253,13 +254,14 @@ inline void wmma_load_a(volatile const T *smem_A, const int local_k, constexpr int packed_factor = (std::is_same_v ? 2 : 1); const int local_k_adjusted = local_k / packed_factor; - static_assert(!GEMMINI_DMA || (layout == MemLayout::K_major) || - GEMMINI_DMA_MN_MAJOR, - "GEMMINI_DMA only supported for K-major A tile"); + static_assert(!GEMMINI_DMA || (layout == MemLayout::block_row_major) || + GEMMINI_DMA_FLEXIBLE_LAYOUT, + "wrong memory layout selected for DMA"); static_assert((layout != MemLayout::K_major) || (FP_SIZE == 32), "fp16 is not really tested for K-major A layout"); - if constexpr (layout == MemLayout::K_major) { + if constexpr (layout == MemLayout::K_major || + layout == MemLayout::block_row_major) { constexpr int smem_A_cols = leading_dim; // f8-f15 stores a single row of A @@ -269,8 +271,9 @@ inline void wmma_load_a(volatile const T *smem_A, const int local_k, // if using Gemmini DMA, remap logical row/col to Gemmini's 2-level // block-row-major layout const auto [smem_row, smem_col] = - remap_to_gemmini_dma_layout(smem_logical_row, - smem_logical_col); + remap_to_gemmini_dma_layout(smem_logical_row, + smem_logical_col); const volatile uint8_t *smem_addr; smem_addr = reinterpret_cast( @@ -356,8 +359,9 @@ inline void wmma_load_b(const volatile T *smem_B, const int local_k, const int thread_in_warp) { asm volatile ("wmma_load_b_start_%=:" :: ); - static_assert(layout == 
MemLayout::MN_major, - "only N-major layout for the B tile is supported"); + static_assert( + layout == MemLayout::MN_major || layout == MemLayout::block_row_major, + "only N-major or block-row-major layout are supported for the B tile"); const int tid = thread_in_warp; const int tg = tid / 4; @@ -379,8 +383,9 @@ inline void wmma_load_b(const volatile T *smem_B, const int local_k, // if using Gemmini DMA, remap logical row/col to Gemmini's 2-level // block-row-major layout const auto [smem_row, smem_col] = - remap_to_gemmini_dma_layout(smem_logical_row, - smem_logical_col); + remap_to_gemmini_dma_layout(smem_logical_row, + smem_logical_col); const volatile uint8_t *smem_addr; smem_addr = reinterpret_cast( @@ -388,10 +393,10 @@ inline void wmma_load_b(const volatile T *smem_B, const int local_k, smem_B)[smem_B_cols * smem_row + smem_col]); // f8-f15 stores a single column of B // threads read from different columns; no bank conflicts - if constexpr (GEMMINI_DMA) { - // for GEMMINI_DMA, moving rows for the next 7 elements in the same column - // is the same as moving DIM elements forward in the memory because of the - // block-row-major layout + if constexpr (layout == MemLayout::block_row_major) { + // for the block-row-major layout, moving rows for the next 7 elements in + // the same column is the same as moving DIM elements forward in the memory + // because of the block-row-major layout asm volatile("flw f8, %0(%1)" :: "i"(DIM * 0 * sizeof(float)), "r"(smem_addr)); asm volatile("flw f9, %0(%1)" :: "i"(DIM * 1 * sizeof(float)), "r"(smem_addr)); asm volatile("flw f10, %0(%1)" :: "i"(DIM * 2 * sizeof(float)), "r"(smem_addr)); @@ -1064,8 +1069,12 @@ inline void thread_block_gemm(const T *A, const T *B, float *C, } constexpr MemLayout layout_a = - TRANSPOSE_AT_CONSUME ? 
MemLayout::K_major : MemLayout::MN_major; - thread_block_gemm_single_tile( From 863e92a85e61ff74820f6ce1926c061ea4c76c76 Mon Sep 17 00:00:00 2001 From: Hansung Kim Date: Sat, 7 Sep 2024 17:40:21 -0700 Subject: [PATCH 05/50] generate_matrix.py: Default to range, fp32 --- tests/kernel/tensor/generate_matrix.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/kernel/tensor/generate_matrix.py b/tests/kernel/tensor/generate_matrix.py index 796a6ea9..d54ece46 100644 --- a/tests/kernel/tensor/generate_matrix.py +++ b/tests/kernel/tensor/generate_matrix.py @@ -46,7 +46,7 @@ def pack_fp16_by_row(array): if __name__ == "__main__": M, N, K = parse_mnk() - rand = True + rand = False if not rand: A_array = np.arange(M * K).reshape([M, K]) B_array = np.arange(K * N).reshape([K, N]) @@ -77,7 +77,7 @@ if __name__ == "__main__": np.savez("abc", A_array=A_array, B_array=B_array, C_array=C_array) - fp16 = True + fp16 = False if fp16: A_packed = pack_fp16_by_row(A_array) AT_packed = A_packed.transpose([1, 0, 2]) From 4d6cdeb00b6164a6a8afe134af4180bfdd40b6d1 Mon Sep 17 00:00:00 2001 From: Hansung Kim Date: Sat, 7 Sep 2024 17:40:49 -0700 Subject: [PATCH 06/50] Fallback to 4 cores for flash --- hw/VX_config.h | 2 +- kernel/include/vx_spawn.h | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/hw/VX_config.h b/hw/VX_config.h index e7a6b559..63be0e93 100644 --- a/hw/VX_config.h +++ b/hw/VX_config.h @@ -84,7 +84,7 @@ #endif #ifndef NUM_CORES -#define NUM_CORES 8 +#define NUM_CORES 4 #endif #ifndef NUM_WARPS diff --git a/kernel/include/vx_spawn.h b/kernel/include/vx_spawn.h index 83052f30..db77e683 100644 --- a/kernel/include/vx_spawn.h +++ b/kernel/include/vx_spawn.h @@ -18,7 +18,7 @@ #include #ifndef CORES_PER_CLUSTER -#define CORES_PER_CLUSTER 8 +#define CORES_PER_CLUSTER 4 #endif #ifdef __cplusplus From e02892ab7de11f22bccb26c84127814122b515be Mon Sep 17 00:00:00 2001 From: Hansung Kim Date: Sat, 7 Sep 2024 17:49:37 -0700 Subject: [PATCH 07/50] 
flash: Fix DMA for up to GEMM II yeah --- tests/regression/flash_attention/kernel.cpp | 194 ++++++++++++-------- 1 file changed, 122 insertions(+), 72 deletions(-) diff --git a/tests/regression/flash_attention/kernel.cpp b/tests/regression/flash_attention/kernel.cpp index 64e28939..8da963a6 100644 --- a/tests/regression/flash_attention/kernel.cpp +++ b/tests/regression/flash_attention/kernel.cpp @@ -168,16 +168,6 @@ __attribute__((always_inline)) inline void thread_block_online_softmax( const uint32_t warps_per_threadblock_per_core = warps_in_threadblock / CORES_PER_CLUSTER; - // float ft[8]; - // asm volatile("fmv.s %0, f16" : "=f"(ft[0])); - // asm volatile("fmv.s %0, f17" : "=f"(ft[1])); - // asm volatile("fmv.s %0, f18" : "=f"(ft[2])); - // asm volatile("fmv.s %0, f19" : "=f"(ft[3])); - // asm volatile("fmv.s %0, f20" : "=f"(ft[4])); - // asm volatile("fmv.s %0, f21" : "=f"(ft[5])); - // asm volatile("fmv.s %0, f22" : "=f"(ft[6])); - // asm volatile("fmv.s %0, f23" : "=f"(ft[7])); - float *smem_rowmax_this = smem_rowmax + B_ROW; #pragma GCC unroll 1 @@ -541,6 +531,7 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { uint8_t *smem_per_threadblock = reinterpret_cast( DEV_SMEM_START_ADDR); float *smem_cursor = reinterpret_cast(smem_per_threadblock); + // float *smem_cursor = reinterpret_cast(DEV_FAKE_SMEM_START_ADDR); float *smem_Q0 = smem_cursor; smem_cursor += smem_Q_size; float *smem_Q1 = smem_cursor; @@ -587,31 +578,33 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { constexpr uint32_t smem_rowmax_size = B_ROW * ROWMAX_SETS; constexpr uint32_t smem_rowsum_size = B_ROW; constexpr uint32_t smem_O_row_scale_size = B_ROW; - smem_cursor = reinterpret_cast(SMEM_ADDR_END); + // smem_cursor = reinterpret_cast(DEV_FAKE_SMEM_START_ADDR + SMEM_SIZE); + smem_cursor = reinterpret_cast(0xff038000); - smem_cursor -= smem_rowmax_size; float *smem_rowmax_0 = smem_cursor; - smem_cursor -= smem_rowmax_size; + smem_cursor += smem_rowmax_size; float 
*smem_rowmax_1 = smem_cursor; - smem_cursor -= smem_rowsum_size; + smem_cursor += smem_rowmax_size; float *smem_rowsum_0 = smem_cursor; - smem_cursor -= smem_rowsum_size; + smem_cursor += smem_rowsum_size; float *smem_rowsum_1 = smem_cursor; - smem_cursor -= smem_O_row_scale_size; + smem_cursor += smem_rowsum_size; float *smem_O_row_scale_0 = smem_cursor; - smem_cursor -= smem_O_row_scale_size; + smem_cursor += smem_O_row_scale_size; float *smem_O_row_scale_1 = smem_cursor; + smem_cursor += smem_O_row_scale_size; // sharedmem "scratchpad" area to put temporary data, e.g. for tree reduction // in rowsum // NOTE: out-of bounds is not checked // TODO: reduce this from B_ROW to NUM_WARPS constexpr uint32_t smem_scratchpad_size = - threads_per_warpgroup * 2 /*arbitrary slack*/; - smem_cursor -= smem_scratchpad_size; + B_ROW * NUM_THREADS * 2 /*arbitrary slack*/; + // threads_per_warpgroup * 2 /*arbitrary slack*/; float *smem_scratchpad_0 = smem_cursor; - smem_cursor -= smem_scratchpad_size; + smem_cursor += smem_scratchpad_size; float *smem_scratchpad_1 = smem_cursor; + smem_cursor += smem_scratchpad_size; // select the correct buffer by warpgroup float *smem_Q = (warpgroup_id % 2) ? smem_Q1 : smem_Q0; @@ -628,19 +621,24 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { (warpgroup_id % 2) ? 
smem_scratchpad_1 : smem_scratchpad_0; // initialize rowmax/rowsum values in sharedmem - // thread_block_init_sharedmem(tid_in_warpgroup, threads_per_warpgroup, smem_O, - // smem_rowmax, smem_rowsum, smem_O_row_scale); + thread_block_init_sharedmem(tid_in_warpgroup, threads_per_warpgroup, smem_O, + smem_rowmax, smem_rowsum, smem_O_row_scale); constexpr uint32_t global_barrier_id = NUM_WARPS - 1; // arbitrary // delay warpgroup 0 by 1 iteration to do ping-pong scheduling - // if (warpgroup_id == 1) { - // threadblock_barrier(global_barrier_id, warps_per_threadblock_per_core); - // } + if (warpgroup_id == 1) { + threadblock_barrier(global_barrier_id, warps_per_threadblock_per_core); + } static_assert(!GEMMINI_DMA || Q_IS_K_MAJOR, "DMA code assumes Q matrix is stored K-major"); + // skip everything except DMA in the loop FSM + constexpr uint32_t skips = + loop_matmul_skips(/*skip_lda=*/0, /*skip_ldb=*/0, /*skip_ldd=*/1, + /*skip_ex=*/1, /*skip_stc=*/1); + if constexpr (GEMMINI_DMA) { if (tid_in_warpgroup == 0) { gemmini_extended_config_ex(WEIGHT_STATIONARY, 0, 0, 1, 0, 0); @@ -680,8 +678,6 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { ROCC_INSTRUCTION_RS1_RS2(XCUSTOM_ACC, (uint64_t)(gmem_Q), (uint64_t)(gmem_K), k_LOOP_WS_CONFIG_ADDRS_AB) // configure address strides for the DMA - // GEMMINI_CISC_CMD_R((B_COL << 16) | (HEADDIM << 8) | - // 8 /*k_LOOP_WS_CONFIG_STRIDES_AB*/); GEMMINI_CISC_CMD_R((dim_seqlen << 16) | (HEADDIM << 8) | 8 /*k_LOOP_WS_CONFIG_STRIDES_AB*/); gemmini_fence(); @@ -691,11 +687,8 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { GEMMINI_CISC_CMD_I(9); gemmini_fence(); #else - // skip everything except DMA in the loop FSM - constexpr uint32_t skips = - loop_matmul_skips(/*skip_lda=*/0, /*skip_ldb=*/0, /*skip_ldd=*/1, - /*skip_ex=*/1, /*skip_stc=*/1); - + // do DMA + // // among other things, this also configures CONFIG_BOUNDS so that the // DMA knows the full matrix dimensions sp_tiled_matmul_full_spad_ws( @@ 
-707,6 +700,15 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { /*acc=*/0, /*act=*/NO_ACTIVATION, /*skips=*/skips); gemmini_fence(); #endif + + // re-configure DMA for K and V load that will later happen in the loop + // GMEM addr stride for K + gemmini_extended3_config_ld(dim_seqlen * sizeof(elem_t), MVIN_SCALE_IDENTITY, + false, 0); + // GMEM addr stride for V + gemmini_extended3_config_ld(HEADDIM * sizeof(elem_t), MVIN_SCALE_IDENTITY, + false, 1); + gemmini_fence(); } asm volatile("dma_move_end_%=:" ::); @@ -767,7 +769,16 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { initialize_accum_regs<0>(); initialize_accum_regs<1>(); - if constexpr (Q_IS_K_MAJOR) { + if constexpr (GEMMINI_DMA) { + thread_block_gemm_single_tile< + float, MemLayout::block_row_major, MemLayout::block_row_major, + B_ROW, B_COL, HEADDIM, /*leading_dim_a=*/0, /*leading_dim_b=*/0, + /*load_accum=*/false, + /*write_to_smem=*/true>( + smem_Q, smem_K, nullptr /*ignore accum*/, smem_S, + tid_in_warpgroup, threads_per_warpgroup, warpgroups_per_cluster, + warpgroup_id_in_cluster); + } else if constexpr (Q_IS_K_MAJOR) { thread_block_gemm_single_tile< float, MemLayout::K_major, MemLayout::MN_major, B_ROW, B_COL, HEADDIM, /*leading_dim_a=*/0, /*leading_dim_b=*/0, @@ -803,6 +814,7 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { initialize_accum_regs<1>(); // split by rows into 2 chunks + // TODO: GEMMINI_DMA if constexpr (Q_IS_K_MAJOR) { thread_block_gemm_single_tile< float, MemLayout::K_major, MemLayout::MN_major, B_ROW / 2, B_COL, @@ -826,6 +838,7 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { initialize_accum_regs<0>(); initialize_accum_regs<1>(); + // TODO: GEMMINI_DMA if constexpr (Q_IS_K_MAJOR) { thread_block_gemm_single_tile< float, MemLayout::K_major, MemLayout::MN_major, B_ROW / 2, B_COL, @@ -877,7 +890,6 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { // inter-warpgroup barrier before online softmax 
threadblock_barrier(global_barrier_id, warps_per_threadblock_per_core); -#if 0 // Online softmax // thread_block_online_softmax(smem_S, smem_P, tid_in_warpgroup, @@ -885,22 +897,52 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { smem_scratchpad, smem_rowmax, smem_rowsum, smem_O_row_scale); + // FIXME: unnecessary? + threadblock_barrier(warpgroup_id_in_cluster, warps_per_warpgroup_per_core); + // data movement for K and V // // Q stays in SMEM for the entire loop - // - // load K for the next iteration - load_tile_to_smem( - dim_seqlen, tile_k + 1, 0 /* dim_k == headdim */, gmem_K, smem_K, - tid_in_warpgroup); + if constexpr (GEMMINI_DMA) { + if (tid_in_threadblock == 0) { + // configure GMEM addresses for K and V tiles + // load K for the next iteration + const float *gmem_K_tile = gmem_K + (B_COL * (tile_k + 1)); + // load V for the current iteration + const float *gmem_V_tile = gmem_V + (HEADDIM * B_COL * tile_k); + ROCC_INSTRUCTION_RS1_RS2(XCUSTOM_ACC, (uint64_t)(gmem_K_tile), + (uint64_t)(gmem_V_tile), + k_LOOP_WS_CONFIG_ADDRS_AB) + // configure address strides for the DMA + // FIXME: unnecessary? 
+ GEMMINI_CISC_CMD_R((HEADDIM /*V*/ << 16) | (dim_seqlen /*KT*/ << 8) | + 8 /*k_LOOP_WS_CONFIG_STRIDES_AB*/); + gemmini_fence(); - // load V for the current iteration - // V dimension is [seqlen, headdim], stored N(headdim)-major - load_tile_to_smem( - HEADDIM, 0 /* full N-dimension */, tile_k, gmem_V, smem_V, - tid_in_warpgroup); + // do DMA + sp_tiled_matmul_full_spad_ws( + spad_addr_K0, spad_addr_V0, + /*spad_D=*/0, /*spad_C=*/spad_addr_S0, + /*I=*/(HEADDIM / DIM), /*J=*/(HEADDIM / DIM), /*K=*/(B_COL / DIM), + /*pad_I=*/0, /*pad_J=*/0, /*pad_K=*/0, + /*a_transpose=*/0, /*b_transpose=*/0, /*full_C=*/0, /*low_D=*/0, + /*acc=*/0, /*act=*/NO_ACTIVATION, /*skips=*/skips); + gemmini_fence(); + } + } else { + // load K for the next iteration + load_tile_to_smem( + dim_seqlen, tile_k + 1, 0 /* dim_k == headdim */, gmem_K, smem_K, + tid_in_warpgroup); + + // load V for the current iteration + // V dimension is [seqlen, headdim], stored N(headdim)-major + load_tile_to_smem( + HEADDIM, 0 /* full N-dimension */, tile_k, gmem_V, smem_V, + tid_in_warpgroup); + } // protect write to SMEM threadblock_barrier(warpgroup_id_in_cluster, warps_per_warpgroup_per_core); @@ -970,25 +1012,38 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { initialize_accum_regs<0>(); initialize_accum_regs<1>(); - thread_block_gemm_single_tile( - smem_P, smem_V, smem_O /*load accum*/, smem_O, tid_in_warpgroup, - threads_per_warpgroup, warpgroups_per_cluster, - warpgroup_id_in_cluster); - - // FIXME: wrong but fast - // thread_block_gemm_single_tile( - // smem_P, smem_V, smem_O /*load accum*/, smem_O, - // tid_in_warpgroup, threads_per_warpgroup, - // warpgroups_per_cluster, warpgroup_id_in_cluster); + if constexpr (GEMMINI_DMA) { + thread_block_gemm_single_tile( + smem_P, smem_V, smem_O /*load accum*/, smem_O, tid_in_warpgroup, + threads_per_warpgroup, warpgroups_per_cluster, + warpgroup_id_in_cluster); + } else { + thread_block_gemm_single_tile( + smem_P, smem_V, smem_O /*load accum*/, 
smem_O, tid_in_warpgroup, + threads_per_warpgroup, warpgroups_per_cluster, + warpgroup_id_in_cluster); + // FIXME: wrong but fast + // thread_block_gemm_single_tile( + // smem_P, smem_V, smem_O /*load accum*/, smem_O, + // tid_in_warpgroup, threads_per_warpgroup, + // warpgroups_per_cluster, warpgroup_id_in_cluster); + } } else { // when warp-specialized, there's only enough warps to do 64x32 tile // size so we need to do 2 GEMM calls @@ -1006,6 +1061,7 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { initialize_accum_regs<1>(); // split by rows into 2 chunks + // TODO: GEMMINI_DMA thread_block_gemm_single_tile< float, MemLayout::K_major, MemLayout::MN_major, B_ROW / 2, HEADDIM, B_COL, /*leading_dim_a=*/0, /*leading_dim_b=*/0, @@ -1047,13 +1103,7 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { warps_per_warpgroup_per_core); } } - - tile_iter_end: - // synchronize progress of two warpgroups - // threadblock_barrier(threadblock_id_in_cluster, - // warps_per_threadblock_per_core); - // threadblock_barrier(3, // FIXME - // NUM_WARPS); +#if 0 #endif } From 33bc084c37a2d2e8b03ff4a85844bdf1f8936c34 Mon Sep 17 00:00:00 2001 From: Hansung Kim Date: Sat, 7 Sep 2024 19:50:04 -0700 Subject: [PATCH 08/50] flash: Fix DMA layout for GEMM II --- tests/regression/flash_attention/kernel.cpp | 20 +++++++++++--------- tests/regression/sgemm_tcore/sgemm_impl.hpp | 3 --- 2 files changed, 11 insertions(+), 12 deletions(-) diff --git a/tests/regression/flash_attention/kernel.cpp b/tests/regression/flash_attention/kernel.cpp index 8da963a6..10d8f555 100644 --- a/tests/regression/flash_attention/kernel.cpp +++ b/tests/regression/flash_attention/kernel.cpp @@ -578,7 +578,7 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { constexpr uint32_t smem_rowmax_size = B_ROW * ROWMAX_SETS; constexpr uint32_t smem_rowsum_size = B_ROW; constexpr uint32_t smem_O_row_scale_size = B_ROW; - // smem_cursor = reinterpret_cast(DEV_FAKE_SMEM_START_ADDR + 
SMEM_SIZE); + // FIXME: dangerous smem_cursor = reinterpret_cast(0xff038000); float *smem_rowmax_0 = smem_cursor; @@ -599,8 +599,7 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { // NOTE: out-of bounds is not checked // TODO: reduce this from B_ROW to NUM_WARPS constexpr uint32_t smem_scratchpad_size = - B_ROW * NUM_THREADS * 2 /*arbitrary slack*/; - // threads_per_warpgroup * 2 /*arbitrary slack*/; + threads_per_warpgroup * 2 /*arbitrary slack*/; float *smem_scratchpad_0 = smem_cursor; smem_cursor += smem_scratchpad_size; float *smem_scratchpad_1 = smem_cursor; @@ -1013,12 +1012,12 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { initialize_accum_regs<1>(); if constexpr (GEMMINI_DMA) { - thread_block_gemm_single_tile( + thread_block_gemm_single_tile< + float, MemLayout::K_major /* P matrix is row-major */, + MemLayout::block_row_major, B_ROW, HEADDIM, B_COL, + /*leading_dim_a=*/0, /*leading_dim_b=*/0, + /*load_accum=*/true, + /*write_to_smem=*/true>( smem_P, smem_V, smem_O /*load accum*/, smem_O, tid_in_warpgroup, threads_per_warpgroup, warpgroups_per_cluster, warpgroup_id_in_cluster); @@ -1045,6 +1044,9 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { // warpgroups_per_cluster, warpgroup_id_in_cluster); } } else { + static_assert(!WARP_SPECIALIZED || !GEMMINI_DMA, + "warp specialization unimplemented for dma"); + // when warp-specialized, there's only enough warps to do 64x32 tile // size so we need to do 2 GEMM calls static_assert(B_ROW / 2 == 32, diff --git a/tests/regression/sgemm_tcore/sgemm_impl.hpp b/tests/regression/sgemm_tcore/sgemm_impl.hpp index e563f23c..1bb7b893 100644 --- a/tests/regression/sgemm_tcore/sgemm_impl.hpp +++ b/tests/regression/sgemm_tcore/sgemm_impl.hpp @@ -254,9 +254,6 @@ inline void wmma_load_a(volatile const T *smem_A, const int local_k, constexpr int packed_factor = (std::is_same_v ? 
2 : 1); const int local_k_adjusted = local_k / packed_factor; - static_assert(!GEMMINI_DMA || (layout == MemLayout::block_row_major) || - GEMMINI_DMA_FLEXIBLE_LAYOUT, - "wrong memory layout selected for DMA"); static_assert((layout != MemLayout::K_major) || (FP_SIZE == 32), "fp16 is not really tested for K-major A layout"); From 8d32a03d09e630d62f57e08e82355f207bb2cb21 Mon Sep 17 00:00:00 2001 From: Hansung Kim Date: Sat, 7 Sep 2024 20:32:08 -0700 Subject: [PATCH 09/50] flash: Write DMA code for warp-specialized TODO: result unverified --- tests/regression/flash_attention/kernel.cpp | 133 +++++++++++++------- 1 file changed, 89 insertions(+), 44 deletions(-) diff --git a/tests/regression/flash_attention/kernel.cpp b/tests/regression/flash_attention/kernel.cpp index 10d8f555..13d743ea 100644 --- a/tests/regression/flash_attention/kernel.cpp +++ b/tests/regression/flash_attention/kernel.cpp @@ -14,7 +14,7 @@ constexpr uint32_t ROWMAX_SETS = 3; constexpr bool DEBUG = true; -constexpr bool WARP_SPECIALIZED = false; +constexpr bool WARP_SPECIALIZED = true; constexpr uint32_t DEV_FAKE_SMEM_START_ADDR = 0xf0000000; @@ -492,11 +492,6 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { warpgroup_id % warpgroups_per_cluster; const uint32_t tid_in_warpgroup = tid_in_threadblock % threads_per_warpgroup; - // FIXME do proper software pipelining - // if (WARP_SPECIALIZED && warpgroup_id_in_cluster != 1) { - // return; - // } - const uint32_t dim_seqlen = arg->dim_seqlen; const uint32_t dim_headdim = arg->dim_headdim; @@ -597,7 +592,6 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { // sharedmem "scratchpad" area to put temporary data, e.g. 
for tree reduction // in rowsum // NOTE: out-of bounds is not checked - // TODO: reduce this from B_ROW to NUM_WARPS constexpr uint32_t smem_scratchpad_size = threads_per_warpgroup * 2 /*arbitrary slack*/; float *smem_scratchpad_0 = smem_cursor; @@ -619,6 +613,11 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { float *smem_scratchpad = (warpgroup_id % 2) ? smem_scratchpad_1 : smem_scratchpad_0; + const auto spad_addr_Q = (warpgroup_id % 2) ? spad_addr_Q1 : spad_addr_Q0; + const auto spad_addr_K = (warpgroup_id % 2) ? spad_addr_K1 : spad_addr_K0; + const auto spad_addr_V = (warpgroup_id % 2) ? spad_addr_V1 : spad_addr_V0; + const auto spad_addr_S = (warpgroup_id % 2) ? spad_addr_S1 : spad_addr_S0; + // initialize rowmax/rowsum values in sharedmem thread_block_init_sharedmem(tid_in_warpgroup, threads_per_warpgroup, smem_O, smem_rowmax, smem_rowsum, smem_O_row_scale); @@ -626,7 +625,7 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { constexpr uint32_t global_barrier_id = NUM_WARPS - 1; // arbitrary // delay warpgroup 0 by 1 iteration to do ping-pong scheduling - if (warpgroup_id == 1) { + if (WARP_SPECIALIZED && warpgroup_id == 1) { threadblock_barrier(global_barrier_id, warps_per_threadblock_per_core); } @@ -667,15 +666,16 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { // static_assert(B_ROW == B_COL, "currently only supports square tiles"); - static_assert(warps_per_warpgroup_per_core == 8); // FIXME nocheckin - if constexpr (GEMMINI_DMA) { asm volatile("dma_move_start_%=:" ::); - if (tid_in_threadblock == 0) { + if (tid_in_warpgroup == 0) { + const float *gmem_Q_tile = gmem_Q + HEADDIM * B_ROW * warpgroup_id; + const float *gmem_K_tile = gmem_K; // configure the GMEM addresses for the DMA to read from - ROCC_INSTRUCTION_RS1_RS2(XCUSTOM_ACC, (uint64_t)(gmem_Q), - (uint64_t)(gmem_K), k_LOOP_WS_CONFIG_ADDRS_AB) + ROCC_INSTRUCTION_RS1_RS2(XCUSTOM_ACC, (uint64_t)(gmem_Q_tile), + (uint64_t)(gmem_K_tile), + 
k_LOOP_WS_CONFIG_ADDRS_AB) // configure address strides for the DMA GEMMINI_CISC_CMD_R((dim_seqlen << 16) | (HEADDIM << 8) | 8 /*k_LOOP_WS_CONFIG_STRIDES_AB*/); @@ -691,8 +691,8 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { // among other things, this also configures CONFIG_BOUNDS so that the // DMA knows the full matrix dimensions sp_tiled_matmul_full_spad_ws( - spad_addr_Q0, spad_addr_K0, - /*spad_D=*/0, /*spad_C=*/spad_addr_S0, + spad_addr_Q, spad_addr_K, + /*spad_D=*/0, /*spad_C=*/spad_addr_S, /*I=*/(B_ROW / DIM), /*J=*/(B_COL / DIM), /*K=*/(HEADDIM / DIM), /*pad_I=*/0, /*pad_J=*/0, /*pad_K=*/0, /*a_transpose=*/0, /*b_transpose=*/0, /*full_C=*/0, /*low_D=*/0, @@ -803,8 +803,9 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { "tile size assumption for warp-specialization not met"); float *smem_Q_half0 = smem_Q; - float *smem_Q_half1 = Q_IS_K_MAJOR ? smem_Q + (B_ROW / 2) * HEADDIM - : smem_Q + (B_ROW / 2); + float *smem_Q_half1 = (Q_IS_K_MAJOR || GEMMINI_DMA) + ? 
smem_Q + (B_ROW / 2) * HEADDIM + : smem_Q + (B_ROW / 2); float *smem_S_half0 = smem_S; float *smem_S_half1 = smem_S + (B_ROW / 2) * B_COL; @@ -813,8 +814,17 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { initialize_accum_regs<1>(); // split by rows into 2 chunks - // TODO: GEMMINI_DMA - if constexpr (Q_IS_K_MAJOR) { + if constexpr (GEMMINI_DMA) { + thread_block_gemm_single_tile( + smem_Q_half0, smem_K, nullptr /*ignore accum*/, smem_S_half0, + tid_in_warpgroup, threads_per_warpgroup, warpgroups_per_cluster, + warpgroup_id_in_cluster); + } else if constexpr (Q_IS_K_MAJOR) { thread_block_gemm_single_tile< float, MemLayout::K_major, MemLayout::MN_major, B_ROW / 2, B_COL, HEADDIM, /*leading_dim_a=*/0, /*leading_dim_b=*/0, @@ -837,8 +847,17 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { initialize_accum_regs<0>(); initialize_accum_regs<1>(); - // TODO: GEMMINI_DMA - if constexpr (Q_IS_K_MAJOR) { + if constexpr (GEMMINI_DMA) { + thread_block_gemm_single_tile( + smem_Q_half1, smem_K, nullptr /*ignore accum*/, smem_S_half1, + tid_in_warpgroup, threads_per_warpgroup, warpgroups_per_cluster, + warpgroup_id_in_cluster); + } else if constexpr (Q_IS_K_MAJOR) { thread_block_gemm_single_tile< float, MemLayout::K_major, MemLayout::MN_major, B_ROW / 2, B_COL, HEADDIM, /*leading_dim_a=*/0, /*leading_dim_b=*/0, @@ -903,7 +922,11 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { // // Q stays in SMEM for the entire loop if constexpr (GEMMINI_DMA) { - if (tid_in_threadblock == 0) { + // NOTE: Beware of race conditions; with warp specialization, we need to + // make sure below command code to DMA is not executed simultaneously + // from the two warpgroups (which will result in hardware fault). + // Currently the ping-pong scheduling scheme prevents that. 
+ if (tid_in_warpgroup == 0) { // configure GMEM addresses for K and V tiles // load K for the next iteration const float *gmem_K_tile = gmem_K + (B_COL * (tile_k + 1)); @@ -920,8 +943,8 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { // do DMA sp_tiled_matmul_full_spad_ws( - spad_addr_K0, spad_addr_V0, - /*spad_D=*/0, /*spad_C=*/spad_addr_S0, + spad_addr_K, spad_addr_V, + /*spad_D=*/0, /*spad_C=*/spad_addr_S, /*I=*/(HEADDIM / DIM), /*J=*/(HEADDIM / DIM), /*K=*/(B_COL / DIM), /*pad_I=*/0, /*pad_J=*/0, /*pad_K=*/0, /*a_transpose=*/0, /*b_transpose=*/0, /*full_C=*/0, /*low_D=*/0, @@ -1044,9 +1067,6 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { // warpgroups_per_cluster, warpgroup_id_in_cluster); } } else { - static_assert(!WARP_SPECIALIZED || !GEMMINI_DMA, - "warp specialization unimplemented for dma"); - // when warp-specialized, there's only enough warps to do 64x32 tile // size so we need to do 2 GEMM calls static_assert(B_ROW / 2 == 32, @@ -1063,27 +1083,52 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { initialize_accum_regs<1>(); // split by rows into 2 chunks - // TODO: GEMMINI_DMA - thread_block_gemm_single_tile< - float, MemLayout::K_major, MemLayout::MN_major, B_ROW / 2, HEADDIM, - B_COL, /*leading_dim_a=*/0, /*leading_dim_b=*/0, - /*load_accum=*/true, - /*write_to_smem=*/true>( - smem_P_half0, smem_V, smem_O_half0 /*load accum*/, smem_O_half0, - tid_in_warpgroup, threads_per_warpgroup, warpgroups_per_cluster, - warpgroup_id_in_cluster); + if constexpr (GEMMINI_DMA) { + thread_block_gemm_single_tile< + float, MemLayout::K_major /* P matrix is row-major */, + MemLayout::block_row_major, B_ROW / 2, HEADDIM, B_COL, + /*leading_dim_a=*/0, + /*leading_dim_b=*/0, + /*load_accum=*/true, + /*write_to_smem=*/true>( + smem_P_half0, smem_V, smem_O_half0 /*load accum*/, smem_O_half0, + tid_in_warpgroup, threads_per_warpgroup, warpgroups_per_cluster, + warpgroup_id_in_cluster); + } else { + 
thread_block_gemm_single_tile< + float, MemLayout::K_major, MemLayout::MN_major, B_ROW / 2, HEADDIM, + B_COL, /*leading_dim_a=*/0, /*leading_dim_b=*/0, + /*load_accum=*/true, + /*write_to_smem=*/true>( + smem_P_half0, smem_V, smem_O_half0 /*load accum*/, smem_O_half0, + tid_in_warpgroup, threads_per_warpgroup, warpgroups_per_cluster, + warpgroup_id_in_cluster); + } initialize_accum_regs<0>(); initialize_accum_regs<1>(); - thread_block_gemm_single_tile< - float, MemLayout::K_major, MemLayout::MN_major, B_ROW / 2, HEADDIM, - B_COL, /*leading_dim_a=*/0, /*leading_dim_b=*/0, - /*load_accum=*/true, - /*write_to_smem=*/true>( - smem_P_half1, smem_V, smem_O_half1 /*load accum*/, smem_O_half1, - tid_in_warpgroup, threads_per_warpgroup, warpgroups_per_cluster, - warpgroup_id_in_cluster); + if constexpr (GEMMINI_DMA) { + thread_block_gemm_single_tile< + float, MemLayout::K_major /* P matrix is row-major */, + MemLayout::block_row_major, B_ROW / 2, HEADDIM, B_COL, + /*leading_dim_a=*/0, + /*leading_dim_b=*/0, + /*load_accum=*/true, + /*write_to_smem=*/true>( + smem_P_half1, smem_V, smem_O_half1 /*load accum*/, smem_O_half1, + tid_in_warpgroup, threads_per_warpgroup, warpgroups_per_cluster, + warpgroup_id_in_cluster); + } else { + thread_block_gemm_single_tile< + float, MemLayout::K_major, MemLayout::MN_major, B_ROW / 2, HEADDIM, + B_COL, /*leading_dim_a=*/0, /*leading_dim_b=*/0, + /*load_accum=*/true, + /*write_to_smem=*/true>( + smem_P_half1, smem_V, smem_O_half1 /*load accum*/, smem_O_half1, + tid_in_warpgroup, threads_per_warpgroup, warpgroups_per_cluster, + warpgroup_id_in_cluster); + } } threadblock_barrier(warpgroup_id_in_cluster, warps_per_warpgroup_per_core); From 03308f8033e2deb75ff686952d2741fbfde6e0cc Mon Sep 17 00:00:00 2001 From: Hansung Kim Date: Sat, 7 Sep 2024 20:46:58 -0700 Subject: [PATCH 10/50] flash: Write fast config for DMA MAC utilization is 20-25% for the loop. 
--- tests/regression/flash_attention/kernel.cpp | 142 ++++++++++++++------ 1 file changed, 101 insertions(+), 41 deletions(-) diff --git a/tests/regression/flash_attention/kernel.cpp b/tests/regression/flash_attention/kernel.cpp index 13d743ea..15538ba2 100644 --- a/tests/regression/flash_attention/kernel.cpp +++ b/tests/regression/flash_attention/kernel.cpp @@ -13,11 +13,12 @@ #define HEADDIM 64 constexpr uint32_t ROWMAX_SETS = 3; -constexpr bool DEBUG = true; +constexpr bool DEBUG = false; constexpr bool WARP_SPECIALIZED = true; constexpr uint32_t DEV_FAKE_SMEM_START_ADDR = 0xf0000000; +constexpr bool GEMMINI_DMA_FAST = true; constexpr bool Q_IS_K_MAJOR = true; // temporary safety stop for wrong configs @@ -763,6 +764,7 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { // GEMM I: S = Q*K // // FIXME: deduplicate this between GEMM II + asm volatile("gemm_qk_start_%=:" ::); if constexpr (!WARP_SPECIALIZED) { // clear out accumulators before GEMM initialize_accum_regs<0>(); @@ -815,15 +817,27 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { // split by rows into 2 chunks if constexpr (GEMMINI_DMA) { - thread_block_gemm_single_tile( - smem_Q_half0, smem_K, nullptr /*ignore accum*/, smem_S_half0, - tid_in_warpgroup, threads_per_warpgroup, warpgroups_per_cluster, - warpgroup_id_in_cluster); + if constexpr (GEMMINI_DMA_FAST) { + thread_block_gemm_single_tile( + smem_Q_half0, smem_K, nullptr /*ignore accum*/, smem_S_half0, + tid_in_warpgroup, threads_per_warpgroup, warpgroups_per_cluster, + warpgroup_id_in_cluster); + } else { + thread_block_gemm_single_tile( + smem_Q_half0, smem_K, nullptr /*ignore accum*/, smem_S_half0, + tid_in_warpgroup, threads_per_warpgroup, warpgroups_per_cluster, + warpgroup_id_in_cluster); + } } else if constexpr (Q_IS_K_MAJOR) { thread_block_gemm_single_tile< float, MemLayout::K_major, MemLayout::MN_major, B_ROW / 2, B_COL, @@ -848,15 +862,27 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { 
initialize_accum_regs<1>(); if constexpr (GEMMINI_DMA) { - thread_block_gemm_single_tile( - smem_Q_half1, smem_K, nullptr /*ignore accum*/, smem_S_half1, - tid_in_warpgroup, threads_per_warpgroup, warpgroups_per_cluster, - warpgroup_id_in_cluster); + if constexpr (GEMMINI_DMA_FAST) { + thread_block_gemm_single_tile( + smem_Q_half1, smem_K, nullptr /*ignore accum*/, smem_S_half1, + tid_in_warpgroup, threads_per_warpgroup, warpgroups_per_cluster, + warpgroup_id_in_cluster); + } else { + thread_block_gemm_single_tile( + smem_Q_half1, smem_K, nullptr /*ignore accum*/, smem_S_half1, + tid_in_warpgroup, threads_per_warpgroup, warpgroups_per_cluster, + warpgroup_id_in_cluster); + } } else if constexpr (Q_IS_K_MAJOR) { thread_block_gemm_single_tile< float, MemLayout::K_major, MemLayout::MN_major, B_ROW / 2, B_COL, @@ -888,6 +914,8 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { // protect write to SMEM (smem_S) before softmax threadblock_barrier(warpgroup_id_in_cluster, warps_per_warpgroup_per_core); + asm volatile("gemm_qk_finish_%=:" ::); + if constexpr (DEBUG) { if (warpgroup_id == 0) { if (tile_k == 0) { @@ -921,6 +949,7 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { // data movement for K and V // // Q stays in SMEM for the entire loop + asm volatile("move_k_v_start_%=:" ::); if constexpr (GEMMINI_DMA) { // NOTE: Beware of race conditions; with warp specialization, we need to // make sure below command code to DMA is not executed simultaneously @@ -965,6 +994,7 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { HEADDIM, 0 /* full N-dimension */, tile_k, gmem_V, smem_V, tid_in_warpgroup); } + asm volatile("move_k_v_finish_%=:" ::); // protect write to SMEM threadblock_barrier(warpgroup_id_in_cluster, warps_per_warpgroup_per_core); @@ -995,8 +1025,6 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { // inter-warpgroup barrier before GEMM II threadblock_barrier(global_barrier_id, 
warps_per_threadblock_per_core); - // GEMM II: O = O + P*V - // Oi rescale thread_block_O_rescale(smem_O, smem_O /*in-place*/, smem_O_row_scale, tid_in_warpgroup, @@ -1029,6 +1057,10 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { } } + // GEMM II: O = O + P*V + + asm volatile("gemm_pv_start_%=:" ::); + if constexpr (!WARP_SPECIALIZED) { // clear out accumulators before GEMM initialize_accum_regs<0>(); @@ -1084,16 +1116,29 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { // split by rows into 2 chunks if constexpr (GEMMINI_DMA) { - thread_block_gemm_single_tile< - float, MemLayout::K_major /* P matrix is row-major */, - MemLayout::block_row_major, B_ROW / 2, HEADDIM, B_COL, - /*leading_dim_a=*/0, - /*leading_dim_b=*/0, - /*load_accum=*/true, - /*write_to_smem=*/true>( - smem_P_half0, smem_V, smem_O_half0 /*load accum*/, smem_O_half0, - tid_in_warpgroup, threads_per_warpgroup, warpgroups_per_cluster, - warpgroup_id_in_cluster); + if constexpr (GEMMINI_DMA_FAST) { + thread_block_gemm_single_tile( + smem_P_half0, smem_V, smem_O_half0 /*load accum*/, smem_O_half0, + tid_in_warpgroup, threads_per_warpgroup, warpgroups_per_cluster, + warpgroup_id_in_cluster); + } else { + thread_block_gemm_single_tile< + float, MemLayout::K_major /* P matrix is row-major */, + MemLayout::block_row_major, B_ROW / 2, HEADDIM, B_COL, + /*leading_dim_a=*/0, + /*leading_dim_b=*/0, + /*load_accum=*/true, + /*write_to_smem=*/true>( + smem_P_half0, smem_V, smem_O_half0 /*load accum*/, smem_O_half0, + tid_in_warpgroup, threads_per_warpgroup, warpgroups_per_cluster, + warpgroup_id_in_cluster); + } } else { thread_block_gemm_single_tile< float, MemLayout::K_major, MemLayout::MN_major, B_ROW / 2, HEADDIM, @@ -1109,16 +1154,29 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { initialize_accum_regs<1>(); if constexpr (GEMMINI_DMA) { - thread_block_gemm_single_tile< - float, MemLayout::K_major /* P matrix is row-major */, - MemLayout::block_row_major, 
B_ROW / 2, HEADDIM, B_COL, - /*leading_dim_a=*/0, - /*leading_dim_b=*/0, - /*load_accum=*/true, - /*write_to_smem=*/true>( - smem_P_half1, smem_V, smem_O_half1 /*load accum*/, smem_O_half1, - tid_in_warpgroup, threads_per_warpgroup, warpgroups_per_cluster, - warpgroup_id_in_cluster); + if constexpr (GEMMINI_DMA_FAST) { + thread_block_gemm_single_tile( + smem_P_half1, smem_V, smem_O_half1 /*load accum*/, smem_O_half1, + tid_in_warpgroup, threads_per_warpgroup, warpgroups_per_cluster, + warpgroup_id_in_cluster); + } else { + thread_block_gemm_single_tile< + float, MemLayout::K_major /* P matrix is row-major */, + MemLayout::block_row_major, B_ROW / 2, HEADDIM, B_COL, + /*leading_dim_a=*/0, + /*leading_dim_b=*/0, + /*load_accum=*/true, + /*write_to_smem=*/true>( + smem_P_half1, smem_V, smem_O_half1 /*load accum*/, smem_O_half1, + tid_in_warpgroup, threads_per_warpgroup, warpgroups_per_cluster, + warpgroup_id_in_cluster); + } } else { thread_block_gemm_single_tile< float, MemLayout::K_major, MemLayout::MN_major, B_ROW / 2, HEADDIM, @@ -1133,6 +1191,8 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { threadblock_barrier(warpgroup_id_in_cluster, warps_per_warpgroup_per_core); + asm volatile("gemm_pv_finish_%=:" ::); + if constexpr (DEBUG) { if (warpgroup_id == 0) { // O after PV From b3be271b8841c7ce7a3daa1af858d4e30206e5b8 Mon Sep 17 00:00:00 2001 From: Hansung Kim Date: Sat, 7 Sep 2024 21:16:35 -0700 Subject: [PATCH 11/50] flash: Split impl to header file --- .../regression/flash_attention/flash_impl.hpp | 446 ++++++++++++++++++ tests/regression/flash_attention/kernel.cpp | 441 +---------------- 2 files changed, 447 insertions(+), 440 deletions(-) create mode 100644 tests/regression/flash_attention/flash_impl.hpp diff --git a/tests/regression/flash_attention/flash_impl.hpp b/tests/regression/flash_attention/flash_impl.hpp new file mode 100644 index 00000000..0f53fd31 --- /dev/null +++ b/tests/regression/flash_attention/flash_impl.hpp @@ -0,0 +1,446 @@ 
+#ifndef _FLASH_IMPL_H_ +#define _FLASH_IMPL_H_ + +#include +#include + +#define B_ROW 64 +#define B_COL 64 +#define HEADDIM 64 + +constexpr uint32_t ROWMAX_SETS = 3; +constexpr bool DEBUG = true; +constexpr bool WARP_SPECIALIZED = true; + +constexpr uint32_t DEV_FAKE_SMEM_START_ADDR = 0xf0000000; + +constexpr bool GEMMINI_DMA_FAST = false; +constexpr bool Q_IS_K_MAJOR = true; + +// temporary safety stop for wrong configs +static_assert(NUM_CORES == 4); +static_assert(NUM_THREADS == 8); +static_assert(NUM_WARPS == 8); + +inline void thread_block_init_sharedmem(const uint32_t tid_in_threadblock, + const uint32_t threads_per_threadblock, + float *smem_O, float *smem_rowmax, + float *smem_rowsum, + float *smem_O_row_scale) { + asm volatile("threadblock_init_sharedmem_start_%=:" ::); + + const uint32_t tid_in_warp = tid_in_threadblock % NUM_THREADS; + const uint32_t warp_id = tid_in_threadblock / NUM_THREADS; + const uint32_t warps_in_threadblock = threads_per_threadblock / NUM_THREADS; + + static_assert((B_ROW % NUM_THREADS) == 0, + "B_ROW must be a multiple of NUM_THREADS"); + static_assert(B_ROW < (NUM_THREADS * CORES_PER_CLUSTER * + (NUM_WARPS / (WARP_SPECIALIZED ? 2 : 1))), + "not enough warps to initialize rowmax/rowsum"); + + // each thread initializes one element in rowmax/rowsum + // multiple warps participate for the whole vector + constexpr uint32_t needed_warps = B_ROW / NUM_THREADS; + if (warp_id < needed_warps /* more warps in HW than needed? 
*/) { + uint32_t offset = NUM_THREADS * warp_id + tid_in_warp; +#pragma GCC unroll + for (int i = 0; i < ROWMAX_SETS; i++) { + smem_rowmax[offset + i * ROWMAX_SETS] = FLT_MIN; + } + smem_rowsum[offset] = 0.0f; + smem_O_row_scale[offset] = 0.0f; + } + + // each warp clears out a row of smem_O + // FIXME: dedup this pattern +#pragma GCC unroll 1 + for (int row_offset = 0; row_offset < B_COL; + row_offset += warps_in_threadblock) { + const uint32_t row = row_offset + warp_id; + uint32_t thread_offset = HEADDIM * row + tid_in_warp; + constexpr uint32_t per_row_iter = HEADDIM / NUM_THREADS; + const float one = 0.0f; +#pragma GCC unroll + for (int i = 0; i < per_row_iter; i++) { + smem_O[thread_offset] = 0.0f; + thread_offset += NUM_THREADS; + } + } + + asm volatile("threadblock_init_sharedmem_finish_%=:" ::); +} + +inline void thread_block_copy_rowmax(const float *src, float *dest, + const uint32_t tid_in_threadblock, + const uint32_t threads_per_threadblock, + const uint32_t threadblock_id_in_cluster) { + asm volatile("threadblock_copy_rowmax_start_%=:" ::); + + const uint32_t tid_in_warp = tid_in_threadblock % NUM_THREADS; + const uint32_t warp_id = tid_in_threadblock / NUM_THREADS; + const uint32_t warps_in_threadblock = threads_per_threadblock / NUM_THREADS; + const uint32_t warps_per_threadblock_per_core = + warps_in_threadblock / CORES_PER_CLUSTER; + + // each thread copies one element in rowmax + // multiple warps participate for the whole vector + constexpr uint32_t num_warps = B_ROW / NUM_THREADS; + if (warp_id < num_warps) { + uint32_t offset = NUM_THREADS * warp_id + tid_in_warp; + dest[offset] = src[offset]; + } + + threadblock_barrier(threadblock_id_in_cluster, + warps_per_threadblock_per_core); + + asm volatile("threadblock_copy_rowmax_finish_%=:" ::); +} + +template +inline void thread_block_copy_tile(const float *src, float *dest, + const uint32_t tid_in_threadblock, + const uint32_t threads_per_threadblock, + const uint32_t threadblock_id_in_cluster) { 
+ asm volatile("threadblock_copy_tile_start_%=:" ::); + + const uint32_t tid_in_warp = tid_in_threadblock % NUM_THREADS; + const uint32_t warp_id = tid_in_threadblock / NUM_THREADS; + const uint32_t warps_in_threadblock = threads_per_threadblock / NUM_THREADS; + const uint32_t warps_per_threadblock_per_core = + warps_in_threadblock / CORES_PER_CLUSTER; + + // FIXME: dedup this pattern +#pragma GCC unroll 1 + for (int row_offset = 0; row_offset < dim_row; + row_offset += warps_in_threadblock) { + const uint32_t row = row_offset + warp_id; + const uint32_t first_thread_offset = dim_col * row; + + constexpr uint32_t per_row_iter = dim_col / NUM_THREADS; + uint32_t thread_offset = first_thread_offset + tid_in_warp; +#pragma GCC unroll + for (int i = 0; i < per_row_iter; i++) { + dest[thread_offset] = src[thread_offset]; + thread_offset += NUM_THREADS; + } + + threadblock_barrier(threadblock_id_in_cluster, + warps_per_threadblock_per_core); + } + + asm volatile("threadblock_copy_tile_finish_%=:" ::); +} + +template +inline float exponential_taylor_term(const float x) { + asm volatile("exponential_taylor_term_start_%=:" ::); + + float res = 1.0f; + + if constexpr (order == 1) { + res = x; + } else if constexpr (order == 2) { + res = x * x; + res /= 2.0f; + } else if constexpr (order == 3) { + res = x * x * x; + res /= 6.0f; + } + + asm volatile("exponential_taylor_term_end_%=:" ::); + return res; +} + +__attribute__((always_inline)) inline void thread_block_online_softmax( + const float *smem_S, float *smem_P, const uint32_t tid_in_threadblock, + const uint32_t threads_per_threadblock, + const uint32_t threadblock_id_in_cluster, float *smem_scratchpad, + float *smem_rowmax, float *smem_rowsum, float *smem_O_row_scale) { + asm volatile("thread_block_online_softmax_start_%=:" ::); + + const uint32_t tid_in_warp = tid_in_threadblock % NUM_THREADS; + const uint32_t warp_id = tid_in_threadblock / NUM_THREADS; + const uint32_t warps_in_threadblock = threads_per_threadblock / 
NUM_THREADS; + const uint32_t warps_per_threadblock_per_core = + warps_in_threadblock / CORES_PER_CLUSTER; + + float *smem_rowmax_this = smem_rowmax + B_ROW; + +#pragma GCC unroll 1 + for (int row_offset = 0; row_offset < B_ROW; + row_offset += warps_in_threadblock) { + const uint32_t row = row_offset + warp_id; + const uint32_t first_thread_offset = B_COL * row; + + // rowmax + // + // two-level tree reduction: reduce each row into NUM_THREADS intermediate + // maxes, then reduce it down to one row max + // one warp handles one row in tile + + constexpr uint32_t per_row_iter = B_COL / NUM_THREADS; + uint32_t thread_offset = first_thread_offset + tid_in_warp; + // FIXME: threadblock_id needs to be in here too + float *warp_smem = smem_scratchpad + (warp_id * NUM_THREADS); + +// #define DUMB_ROWMAX +#ifdef DUMB_ROWMAX + // FIXME remove + threadblock_barrier(threadblock_id_in_cluster, + warps_per_threadblock_per_core); + + // no tree reduction; a single thread in a warp does serialized max across + // the entire row + if (tid_in_warp == 0) { + float rowmax = smem_S[first_thread_offset]; +#pragma GCC unroll 16 + for (int i = 0; i < B_COL; i++) { + asm volatile("fmax.s %0, %1, %2" + : "=f"(rowmax) + : "f"(rowmax), "f"(smem_S[first_thread_offset + i])); + } + smem_rowmax_this[row] = rowmax; + + // update previous rowmax + // i.e. 
mi_new = max(mi, mij) + float prev_rowmax = smem_rowmax[row]; + // stage prev rowmax in scratchpad for warp-wide broadcast + warp_smem[0] = prev_rowmax; + asm volatile("fmax.s %0, %1, %2" + : "=f"(rowmax) + : "f"(rowmax), "f"(prev_rowmax)); + smem_rowmax[row] = rowmax; + } + +#else + static_assert((B_COL % NUM_THREADS) == 0, + "B_COL must be a multiple of NUM_THREADS"); + float per_thread_max = FLT_MIN; +#pragma GCC unroll + for (int i = 0; i < per_row_iter; i++) { + const float next = smem_S[thread_offset]; + asm volatile("fmax.s %0, %1, %2" + : "=f"(per_thread_max) + : "f"(per_thread_max), "f"(next)); + thread_offset += NUM_THREADS; + } + // stage per-thread max value in smem + warp_smem[tid_in_warp] = per_thread_max; + + // sync writes to warp_smem + threadblock_barrier(threadblock_id_in_cluster, + warps_per_threadblock_per_core); + +// #define PARALLEL_ROWMAX +#ifndef PARALLEL_ROWMAX + // elect 0-th thread to reduce all other thread's values in the warp + if (tid_in_warp == 0) { + float rowmax = per_thread_max; + for (int i = 1; i < NUM_THREADS; i++) { + float other = warp_smem[i]; + asm volatile("fmax.s %0, %1, %2" + : "=f"(rowmax) + : "f"(rowmax), "f"(other)); + } + smem_rowmax_this[row] = rowmax; + + // update previous rowmax + // i.e. 
mi_new = max(mi, mij) + float prev_rowmax = smem_rowmax[row]; + // stage prev rowmax in scratchpad for warp-wide broadcast + warp_smem[0] = prev_rowmax; + asm volatile("fmax.s %0, %1, %2" + : "=f"(rowmax) + : "f"(rowmax), "f"(prev_rowmax)); + smem_rowmax[row] = rowmax; + } +#else + if (warp_id < warps_in_threadblock / NUM_THREADS) { + const uint32_t row = row_offset + NUM_THREADS * warp_id + tid_in_warp; + float *const thread_smem = smem_scratchpad + (tid_in_warp * NUM_THREADS); + float rowmax = FLT_MIN; +#pragma GCC unroll + for (int i = 0; i < NUM_THREADS; i++) { + const float f = thread_smem[i]; + asm volatile("fmax.s %0, %1, %2" : "=f"(rowmax) : "f"(rowmax), "f"(f)); + } + smem_rowmax_this[row] = rowmax; + + // update previous rowmax + // i.e. mi_new = max(mi, mij) + float prev_rowmax = smem_rowmax[row]; + // stage prev rowmax in scratchpad for warp-wide broadcast + thread_smem[0] = prev_rowmax; + asm volatile("fmax.s %0, %1, %2" + : "=f"(rowmax) + : "f"(rowmax), "f"(prev_rowmax)); + smem_rowmax[row] = rowmax; + } +#endif // PARALLEL_ROWMAX +#endif // DUMB_ROWMAX + + threadblock_barrier(threadblock_id_in_cluster, + warps_per_threadblock_per_core); + + // broadcast prev rowmax to all threads in the warp + // NOTE: memory consistency is a little sketchy here + const float rowmax_prev = warp_smem[0]; + const float rowmax_this = smem_rowmax_this[row]; + + // exponential + // + // B_ROW / (B_ROW * B_COL / (exp_elem * threads_per_threadblock)) + // const uint32_t row_stride = + // (exp_elem_per_thread * threads_per_threadblock) / B_COL; + + // broadcast updated rowmax to all threads in the warp + const float rowmax_new = smem_rowmax[row]; + + asm volatile("flashattn_exp_p_start_%=:" ::); + + thread_offset = first_thread_offset + tid_in_warp; +#pragma GCC unroll + for (int i = 0; i < per_row_iter; i++) { + float f0 = smem_S[thread_offset]; + + f0 -= rowmax_new; + + // 2nd-order Taylor approximation + float exp = 1.0f; + exp += exponential_taylor_term<1>(f0); + exp += 
exponential_taylor_term<2>(f0); + + // Store S transposed to the shared memory + + smem_P[thread_offset] = exp; + + thread_offset += NUM_THREADS; + } + + asm volatile("flashattn_exp_p_end_%=:" ::); + + + threadblock_barrier(threadblock_id_in_cluster, + warps_per_threadblock_per_core); + + // rowsum + // + // two-level tree reduction, similar to rowmax + + asm volatile("flashattn_rowsum_start_%=:" ::); + + float per_thread_sum = 0.0f; + + thread_offset = first_thread_offset + tid_in_warp; +#pragma GCC unroll + for (int i = 0; i < per_row_iter; i++) { + per_thread_sum += smem_P[thread_offset]; + thread_offset += NUM_THREADS; + } + // stage per-thread sum value in smem + // FIXME: threadblock_id needs to be in here too + warp_smem = smem_scratchpad + (warp_id * NUM_THREADS); + warp_smem[tid_in_warp] = per_thread_sum; + + // sync writes to warp_smem + threadblock_barrier(threadblock_id_in_cluster, + warps_per_threadblock_per_core); + + // 0-th thread collects all other thread's values in the warp + if (tid_in_warp == 0) { + float rowsum = per_thread_sum; + for (int iter = 1; iter < NUM_THREADS; iter++) { + float other = warp_smem[iter]; + rowsum += other; + } + + const float mi_prev = rowmax_prev; + const float mi_this = rowmax_this; + + const float x = mi_prev - mi_this; + // 2nd-order Taylor approximation + float exp = 1.0f; + exp += exponential_taylor_term<1>(x); + exp += exponential_taylor_term<2>(x); + + // update rowsum + const float rowsum_prev = smem_rowsum[row]; + float rowsum_new = exp * rowsum_prev + rowsum; + + smem_rowsum[row] = rowsum_new; + } + + asm volatile("flashattn_rowsum_end_%=:" ::); + + threadblock_barrier(threadblock_id_in_cluster, + warps_per_threadblock_per_core); + + // compute Oi rescale factor + // FIXME: parallelize this across threads + // + asm volatile("flashattn_rescale_factor_start_%=:" ::); + + thread_offset = first_thread_offset + tid_in_warp; +#pragma GCC unroll + for (int i = 0; i < per_row_iter; i++) { + const float mi_prev = 
rowmax_prev; + const float mi_new = rowmax_new; + + const float x = mi_prev - mi_new; + // 2nd-order Taylor approximation + float exp = 1.0f; + exp += exponential_taylor_term<1>(x); + exp += exponential_taylor_term<2>(x); + + // @perf: div vs. expansion on e(-x)? + smem_O_row_scale[row] = 1.0f / exp; + + thread_offset += NUM_THREADS; + } + + asm volatile("flashattn_rescale_factor_end_%=:" ::); + + threadblock_barrier(threadblock_id_in_cluster, + warps_per_threadblock_per_core); + } + + asm volatile("thread_block_online_softmax_finish_%=:" ::); +} + +__attribute__((always_inline)) inline void thread_block_O_rescale( + const float *smem_O_in, float *smem_O_out, const float *smem_O_row_scale, + const uint32_t tid_in_threadblock, const uint32_t threads_per_threadblock, + const uint32_t threadblock_id_in_cluster) { + asm volatile("thread_block_O_rescale_start_%=:" ::); + + const uint32_t tid_in_warp = tid_in_threadblock % NUM_THREADS; + const uint32_t warp_id = tid_in_threadblock / NUM_THREADS; + const uint32_t warps_in_threadblock = threads_per_threadblock / NUM_THREADS; + const uint32_t warps_per_threadblock_per_core = + warps_in_threadblock / CORES_PER_CLUSTER; + +#pragma GCC unroll 1 + for (int row_offset = 0; row_offset < B_ROW; + row_offset += warps_in_threadblock) { + const uint32_t row = row_offset + warp_id; + const uint32_t first_thread_offset = B_COL * row; + constexpr uint32_t per_row_iter = B_COL / NUM_THREADS; + uint32_t thread_offset = first_thread_offset + tid_in_warp; + + // Oi rescale + // +#pragma GCC unroll + for (int i = 0; i < per_row_iter; i++) { + const float o = smem_O_in[thread_offset]; + const float scale = smem_O_row_scale[row]; + smem_O_out[thread_offset] = (o * scale); + + thread_offset += NUM_THREADS; + } + } + + asm volatile("thread_block_O_rescale_finish_%=:" ::); +} + +#endif diff --git a/tests/regression/flash_attention/kernel.cpp b/tests/regression/flash_attention/kernel.cpp index 15538ba2..ba0537e3 100644 --- 
a/tests/regression/flash_attention/kernel.cpp +++ b/tests/regression/flash_attention/kernel.cpp @@ -2,450 +2,11 @@ #include #include #include -#include #include "common.h" #include "sgemm_impl.hpp" #include "include/gemmini.h" #include "gemmini_mmio.h" - -#define B_ROW 64 -#define B_COL 64 -#define HEADDIM 64 - -constexpr uint32_t ROWMAX_SETS = 3; -constexpr bool DEBUG = false; -constexpr bool WARP_SPECIALIZED = true; - -constexpr uint32_t DEV_FAKE_SMEM_START_ADDR = 0xf0000000; - -constexpr bool GEMMINI_DMA_FAST = true; -constexpr bool Q_IS_K_MAJOR = true; - -// temporary safety stop for wrong configs -static_assert(NUM_CORES == 4); -static_assert(NUM_THREADS == 8); -static_assert(NUM_WARPS == 8); - -inline void thread_block_init_sharedmem(const uint32_t tid_in_threadblock, - const uint32_t threads_per_threadblock, - float *smem_O, float *smem_rowmax, - float *smem_rowsum, - float *smem_O_row_scale) { - asm volatile("threadblock_init_sharedmem_start_%=:" ::); - - const uint32_t tid_in_warp = tid_in_threadblock % NUM_THREADS; - const uint32_t warp_id = tid_in_threadblock / NUM_THREADS; - const uint32_t warps_in_threadblock = threads_per_threadblock / NUM_THREADS; - - static_assert((B_ROW % NUM_THREADS) == 0, - "B_ROW must be a multiple of NUM_THREADS"); - static_assert(B_ROW < (NUM_THREADS * CORES_PER_CLUSTER * - (NUM_WARPS / (WARP_SPECIALIZED ? 2 : 1))), - "not enough warps to initialize rowmax/rowsum"); - - // each thread initializes one element in rowmax/rowsum - // multiple warps participate for the whole vector - constexpr uint32_t needed_warps = B_ROW / NUM_THREADS; - if (warp_id < needed_warps /* more warps in HW than needed? 
*/) { - uint32_t offset = NUM_THREADS * warp_id + tid_in_warp; -#pragma GCC unroll - for (int i = 0; i < ROWMAX_SETS; i++) { - smem_rowmax[offset + i * ROWMAX_SETS] = FLT_MIN; - } - smem_rowsum[offset] = 0.0f; - smem_O_row_scale[offset] = 0.0f; - } - - // each warp clears out a row of smem_O - // FIXME: dedup this pattern -#pragma GCC unroll 1 - for (int row_offset = 0; row_offset < B_COL; - row_offset += warps_in_threadblock) { - const uint32_t row = row_offset + warp_id; - uint32_t thread_offset = HEADDIM * row + tid_in_warp; - constexpr uint32_t per_row_iter = HEADDIM / NUM_THREADS; - const float one = 0.0f; -#pragma GCC unroll - for (int i = 0; i < per_row_iter; i++) { - smem_O[thread_offset] = 0.0f; - thread_offset += NUM_THREADS; - } - } - - asm volatile("threadblock_init_sharedmem_finish_%=:" ::); -} - -inline void thread_block_copy_rowmax(const float *src, float *dest, - const uint32_t tid_in_threadblock, - const uint32_t threads_per_threadblock, - const uint32_t threadblock_id_in_cluster) { - asm volatile("threadblock_copy_rowmax_start_%=:" ::); - - const uint32_t tid_in_warp = tid_in_threadblock % NUM_THREADS; - const uint32_t warp_id = tid_in_threadblock / NUM_THREADS; - const uint32_t warps_in_threadblock = threads_per_threadblock / NUM_THREADS; - const uint32_t warps_per_threadblock_per_core = - warps_in_threadblock / CORES_PER_CLUSTER; - - // each thread copies one element in rowmax - // multiple warps participate for the whole vector - constexpr uint32_t num_warps = B_ROW / NUM_THREADS; - if (warp_id < num_warps) { - uint32_t offset = NUM_THREADS * warp_id + tid_in_warp; - dest[offset] = src[offset]; - } - - threadblock_barrier(threadblock_id_in_cluster, - warps_per_threadblock_per_core); - - asm volatile("threadblock_copy_rowmax_finish_%=:" ::); -} - -template -inline void thread_block_copy_tile(const float *src, float *dest, - const uint32_t tid_in_threadblock, - const uint32_t threads_per_threadblock, - const uint32_t threadblock_id_in_cluster) { 
- asm volatile("threadblock_copy_tile_start_%=:" ::); - - const uint32_t tid_in_warp = tid_in_threadblock % NUM_THREADS; - const uint32_t warp_id = tid_in_threadblock / NUM_THREADS; - const uint32_t warps_in_threadblock = threads_per_threadblock / NUM_THREADS; - const uint32_t warps_per_threadblock_per_core = - warps_in_threadblock / CORES_PER_CLUSTER; - - // FIXME: dedup this pattern -#pragma GCC unroll 1 - for (int row_offset = 0; row_offset < dim_row; - row_offset += warps_in_threadblock) { - const uint32_t row = row_offset + warp_id; - const uint32_t first_thread_offset = dim_col * row; - - constexpr uint32_t per_row_iter = dim_col / NUM_THREADS; - uint32_t thread_offset = first_thread_offset + tid_in_warp; -#pragma GCC unroll - for (int i = 0; i < per_row_iter; i++) { - dest[thread_offset] = src[thread_offset]; - thread_offset += NUM_THREADS; - } - - threadblock_barrier(threadblock_id_in_cluster, - warps_per_threadblock_per_core); - } - - asm volatile("threadblock_copy_tile_finish_%=:" ::); -} - -template -inline float exponential_taylor_term(const float x) { - asm volatile("exponential_taylor_term_start_%=:" ::); - - float res = 1.0f; - - if constexpr (order == 1) { - res = x; - } else if constexpr (order == 2) { - res = x * x; - res /= 2.0f; - } else if constexpr (order == 3) { - res = x * x * x; - res /= 6.0f; - } - - asm volatile("exponential_taylor_term_end_%=:" ::); - return res; -} - -__attribute__((always_inline)) inline void thread_block_online_softmax( - const float *smem_S, float *smem_P, const uint32_t tid_in_threadblock, - const uint32_t threads_per_threadblock, - const uint32_t threadblock_id_in_cluster, float *smem_scratchpad, - float *smem_rowmax, float *smem_rowsum, float *smem_O_row_scale) { - asm volatile("thread_block_online_softmax_start_%=:" ::); - - const uint32_t tid_in_warp = tid_in_threadblock % NUM_THREADS; - const uint32_t warp_id = tid_in_threadblock / NUM_THREADS; - const uint32_t warps_in_threadblock = threads_per_threadblock / 
NUM_THREADS; - const uint32_t warps_per_threadblock_per_core = - warps_in_threadblock / CORES_PER_CLUSTER; - - float *smem_rowmax_this = smem_rowmax + B_ROW; - -#pragma GCC unroll 1 - for (int row_offset = 0; row_offset < B_ROW; - row_offset += warps_in_threadblock) { - const uint32_t row = row_offset + warp_id; - const uint32_t first_thread_offset = B_COL * row; - - // rowmax - // - // two-level tree reduction: reduce each row into NUM_THREADS intermediate - // maxes, then reduce it down to one row max - // one warp handles one row in tile - - constexpr uint32_t per_row_iter = B_COL / NUM_THREADS; - uint32_t thread_offset = first_thread_offset + tid_in_warp; - // FIXME: threadblock_id needs to be in here too - float *warp_smem = smem_scratchpad + (warp_id * NUM_THREADS); - -// #define DUMB_ROWMAX -#ifdef DUMB_ROWMAX - // FIXME remove - threadblock_barrier(threadblock_id_in_cluster, - warps_per_threadblock_per_core); - - // no tree reduction; a single thread in a warp does serialized max across - // the entire row - if (tid_in_warp == 0) { - float rowmax = smem_S[first_thread_offset]; -#pragma GCC unroll 16 - for (int i = 0; i < B_COL; i++) { - asm volatile("fmax.s %0, %1, %2" - : "=f"(rowmax) - : "f"(rowmax), "f"(smem_S[first_thread_offset + i])); - } - smem_rowmax_this[row] = rowmax; - - // update previous rowmax - // i.e. 
mi_new = max(mi, mij) - float prev_rowmax = smem_rowmax[row]; - // stage prev rowmax in scratchpad for warp-wide broadcast - warp_smem[0] = prev_rowmax; - asm volatile("fmax.s %0, %1, %2" - : "=f"(rowmax) - : "f"(rowmax), "f"(prev_rowmax)); - smem_rowmax[row] = rowmax; - } - -#else - static_assert((B_COL % NUM_THREADS) == 0, - "B_COL must be a multiple of NUM_THREADS"); - float per_thread_max = FLT_MIN; -#pragma GCC unroll - for (int i = 0; i < per_row_iter; i++) { - const float next = smem_S[thread_offset]; - asm volatile("fmax.s %0, %1, %2" - : "=f"(per_thread_max) - : "f"(per_thread_max), "f"(next)); - thread_offset += NUM_THREADS; - } - // stage per-thread max value in smem - warp_smem[tid_in_warp] = per_thread_max; - - // sync writes to warp_smem - threadblock_barrier(threadblock_id_in_cluster, - warps_per_threadblock_per_core); - -// #define PARALLEL_ROWMAX -#ifndef PARALLEL_ROWMAX - // elect 0-th thread to reduce all other thread's values in the warp - if (tid_in_warp == 0) { - float rowmax = per_thread_max; - for (int i = 1; i < NUM_THREADS; i++) { - float other = warp_smem[i]; - asm volatile("fmax.s %0, %1, %2" - : "=f"(rowmax) - : "f"(rowmax), "f"(other)); - } - smem_rowmax_this[row] = rowmax; - - // update previous rowmax - // i.e. 
mi_new = max(mi, mij) - float prev_rowmax = smem_rowmax[row]; - // stage prev rowmax in scratchpad for warp-wide broadcast - warp_smem[0] = prev_rowmax; - asm volatile("fmax.s %0, %1, %2" - : "=f"(rowmax) - : "f"(rowmax), "f"(prev_rowmax)); - smem_rowmax[row] = rowmax; - } -#else - if (warp_id < warps_in_threadblock / NUM_THREADS) { - const uint32_t row = row_offset + NUM_THREADS * warp_id + tid_in_warp; - float *const thread_smem = smem_scratchpad + (tid_in_warp * NUM_THREADS); - float rowmax = FLT_MIN; -#pragma GCC unroll - for (int i = 0; i < NUM_THREADS; i++) { - const float f = thread_smem[i]; - asm volatile("fmax.s %0, %1, %2" : "=f"(rowmax) : "f"(rowmax), "f"(f)); - } - smem_rowmax_this[row] = rowmax; - - // update previous rowmax - // i.e. mi_new = max(mi, mij) - float prev_rowmax = smem_rowmax[row]; - // stage prev rowmax in scratchpad for warp-wide broadcast - thread_smem[0] = prev_rowmax; - asm volatile("fmax.s %0, %1, %2" - : "=f"(rowmax) - : "f"(rowmax), "f"(prev_rowmax)); - smem_rowmax[row] = rowmax; - } -#endif // PARALLEL_ROWMAX -#endif // DUMB_ROWMAX - - threadblock_barrier(threadblock_id_in_cluster, - warps_per_threadblock_per_core); - - // broadcast prev rowmax to all threads in the warp - // NOTE: memory consistency is a little sketchy here - const float rowmax_prev = warp_smem[0]; - const float rowmax_this = smem_rowmax_this[row]; - - // exponential - // - // B_ROW / (B_ROW * B_COL / (exp_elem * threads_per_threadblock)) - // const uint32_t row_stride = - // (exp_elem_per_thread * threads_per_threadblock) / B_COL; - - // broadcast updated rowmax to all threads in the warp - const float rowmax_new = smem_rowmax[row]; - - asm volatile("flashattn_exp_p_start_%=:" ::); - - thread_offset = first_thread_offset + tid_in_warp; -#pragma GCC unroll - for (int i = 0; i < per_row_iter; i++) { - float f0 = smem_S[thread_offset]; - - f0 -= rowmax_new; - - // 2nd-order Taylor approximation - float exp = 1.0f; - exp += exponential_taylor_term<1>(f0); - exp += 
exponential_taylor_term<2>(f0); - - // Store S transposed to the shared memory - - smem_P[thread_offset] = exp; - - thread_offset += NUM_THREADS; - } - - asm volatile("flashattn_exp_p_end_%=:" ::); - - - threadblock_barrier(threadblock_id_in_cluster, - warps_per_threadblock_per_core); - - // rowsum - // - // two-level tree reduction, similar to rowmax - - asm volatile("flashattn_rowsum_start_%=:" ::); - - float per_thread_sum = 0.0f; - - thread_offset = first_thread_offset + tid_in_warp; -#pragma GCC unroll - for (int i = 0; i < per_row_iter; i++) { - per_thread_sum += smem_P[thread_offset]; - thread_offset += NUM_THREADS; - } - // stage per-thread sum value in smem - // FIXME: threadblock_id needs to be in here too - warp_smem = smem_scratchpad + (warp_id * NUM_THREADS); - warp_smem[tid_in_warp] = per_thread_sum; - - // sync writes to warp_smem - threadblock_barrier(threadblock_id_in_cluster, - warps_per_threadblock_per_core); - - // 0-th thread collects all other thread's values in the warp - if (tid_in_warp == 0) { - float rowsum = per_thread_sum; - for (int iter = 1; iter < NUM_THREADS; iter++) { - float other = warp_smem[iter]; - rowsum += other; - } - - const float mi_prev = rowmax_prev; - const float mi_this = rowmax_this; - - const float x = mi_prev - mi_this; - // 2nd-order Taylor approximation - float exp = 1.0f; - exp += exponential_taylor_term<1>(x); - exp += exponential_taylor_term<2>(x); - - // update rowsum - const float rowsum_prev = smem_rowsum[row]; - float rowsum_new = exp * rowsum_prev + rowsum; - - smem_rowsum[row] = rowsum_new; - } - - asm volatile("flashattn_rowsum_end_%=:" ::); - - threadblock_barrier(threadblock_id_in_cluster, - warps_per_threadblock_per_core); - - // compute Oi rescale factor - // FIXME: parallelize this across threads - // - asm volatile("flashattn_rescale_factor_start_%=:" ::); - - thread_offset = first_thread_offset + tid_in_warp; -#pragma GCC unroll - for (int i = 0; i < per_row_iter; i++) { - const float mi_prev = 
rowmax_prev; - const float mi_new = rowmax_new; - - const float x = mi_prev - mi_new; - // 2nd-order Taylor approximation - float exp = 1.0f; - exp += exponential_taylor_term<1>(x); - exp += exponential_taylor_term<2>(x); - - // @perf: div vs. expansion on e(-x)? - smem_O_row_scale[row] = 1.0f / exp; - - thread_offset += NUM_THREADS; - } - - asm volatile("flashattn_rescale_factor_end_%=:" ::); - - threadblock_barrier(threadblock_id_in_cluster, - warps_per_threadblock_per_core); - } - - asm volatile("thread_block_online_softmax_finish_%=:" ::); -} - -__attribute__((always_inline)) inline void thread_block_O_rescale( - const float *smem_O_in, float *smem_O_out, const float *smem_O_row_scale, - const uint32_t tid_in_threadblock, const uint32_t threads_per_threadblock, - const uint32_t threadblock_id_in_cluster) { - asm volatile("thread_block_O_rescale_start_%=:" ::); - - const uint32_t tid_in_warp = tid_in_threadblock % NUM_THREADS; - const uint32_t warp_id = tid_in_threadblock / NUM_THREADS; - const uint32_t warps_in_threadblock = threads_per_threadblock / NUM_THREADS; - const uint32_t warps_per_threadblock_per_core = - warps_in_threadblock / CORES_PER_CLUSTER; - -#pragma GCC unroll 1 - for (int row_offset = 0; row_offset < B_ROW; - row_offset += warps_in_threadblock) { - const uint32_t row = row_offset + warp_id; - const uint32_t first_thread_offset = B_COL * row; - constexpr uint32_t per_row_iter = B_COL / NUM_THREADS; - uint32_t thread_offset = first_thread_offset + tid_in_warp; - - // Oi rescale - // -#pragma GCC unroll - for (int i = 0; i < per_row_iter; i++) { - const float o = smem_O_in[thread_offset]; - const float scale = smem_O_row_scale[row]; - smem_O_out[thread_offset] = (o * scale); - - thread_offset += NUM_THREADS; - } - } - - asm volatile("thread_block_O_rescale_finish_%=:" ::); -} +#include "flash_impl.hpp" void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { // @perf: All threads are running these compute whose result is mostly same From 
2e1485877db7d252c804f8ae6b146f9501d2ac12 Mon Sep 17 00:00:00 2001 From: Hansung Kim Date: Sat, 7 Sep 2024 22:40:50 -0700 Subject: [PATCH 12/50] flash: Add Gemmini-accelerated kernel --- tests/regression/flash_attention/Makefile | 2 +- .../regression/flash_attention/flash_impl.hpp | 2 +- tests/regression/flash_attention/kernel.cpp | 4 +- .../flash_attention/kernel.gemmini.cpp | 684 ++++++++++++++++++ 4 files changed, 689 insertions(+), 3 deletions(-) create mode 100644 tests/regression/flash_attention/kernel.gemmini.cpp diff --git a/tests/regression/flash_attention/Makefile b/tests/regression/flash_attention/Makefile index 4f49f927..4d4fcad1 100644 --- a/tests/regression/flash_attention/Makefile +++ b/tests/regression/flash_attention/Makefile @@ -2,7 +2,7 @@ PROJECT = flash_attention SRCS = main.cpp common.h -VX_SRCS = kernel.cpp +VX_SRCS = kernel.gemmini.cpp VX_INCLUDES = ../sgemm_tcore/sgemm_impl.hpp OPTS ?= -n16 diff --git a/tests/regression/flash_attention/flash_impl.hpp b/tests/regression/flash_attention/flash_impl.hpp index 0f53fd31..423ebd69 100644 --- a/tests/regression/flash_attention/flash_impl.hpp +++ b/tests/regression/flash_attention/flash_impl.hpp @@ -10,7 +10,7 @@ constexpr uint32_t ROWMAX_SETS = 3; constexpr bool DEBUG = true; -constexpr bool WARP_SPECIALIZED = true; +constexpr bool WARP_SPECIALIZED = false; constexpr uint32_t DEV_FAKE_SMEM_START_ADDR = 0xf0000000; diff --git a/tests/regression/flash_attention/kernel.cpp b/tests/regression/flash_attention/kernel.cpp index ba0537e3..9eee2b60 100644 --- a/tests/regression/flash_attention/kernel.cpp +++ b/tests/regression/flash_attention/kernel.cpp @@ -320,12 +320,13 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { // float *smem_O_row_scale_consume = // (tile_k % 2) ? 
smem_O_row_scale_1 : smem_O_row_scale_0; + asm volatile("gemm_qk_start_%=:" ::); + constexpr bool skip_gemm_qk = false; if constexpr (!skip_gemm_qk) { // GEMM I: S = Q*K // // FIXME: deduplicate this between GEMM II - asm volatile("gemm_qk_start_%=:" ::); if constexpr (!WARP_SPECIALIZED) { // clear out accumulators before GEMM initialize_accum_regs<0>(); @@ -587,6 +588,7 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { threadblock_barrier(global_barrier_id, warps_per_threadblock_per_core); // Oi rescale + // TODO: move this back to after softmax for better load-balancing thread_block_O_rescale(smem_O, smem_O /*in-place*/, smem_O_row_scale, tid_in_warpgroup, threads_per_warpgroup, warpgroup_id_in_cluster); diff --git a/tests/regression/flash_attention/kernel.gemmini.cpp b/tests/regression/flash_attention/kernel.gemmini.cpp new file mode 100644 index 00000000..0df0cf87 --- /dev/null +++ b/tests/regression/flash_attention/kernel.gemmini.cpp @@ -0,0 +1,684 @@ +#include +#include +#include +#include +#include "common.h" +#include "sgemm_impl.hpp" +#include "include/gemmini.h" +#include "gemmini_mmio.h" +#include "flash_impl.hpp" + +static_assert(GEMMINI_DMA && !WARP_SPECIALIZED, + "GEMMINI_DMA should be set and WARP_SPECIALIZED unset"); + +void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { + // @perf: All threads are running these compute whose result is mostly same + // across the threadblock + +#ifdef RADIANCE + constexpr uint32_t cores_per_cluster = CORES_PER_CLUSTER; +#else + constexpr uint32_t cores_per_cluster = 1; +#endif + + // FIXME: headdim not considered + constexpr uint32_t threads_per_threadblock_theoretical = + (B_ROW * B_COL) / (ELEM_PER_THREAD); + constexpr uint32_t hw_threads_per_cluster = + CORES_PER_CLUSTER * NUM_THREADS * NUM_WARPS; + // cap maximum threadblock size to # of HW threads in cluster, to prevent + // multiple "wave" invocations which slows down the kernel + constexpr uint32_t threads_per_threadblock = + 
(threads_per_threadblock_theoretical > hw_threads_per_cluster) + ? hw_threads_per_cluster + : threads_per_threadblock_theoretical; + constexpr uint32_t threadblocks_per_cluster = + hw_threads_per_cluster / threads_per_threadblock; + constexpr uint32_t warps_per_threadblock_per_core = + NUM_WARPS / threadblocks_per_cluster; + + const uint32_t threadblock_id = task_id / threads_per_threadblock; + const uint32_t threadblock_id_in_cluster = + threadblock_id % threadblocks_per_cluster; + const uint32_t tid_in_threadblock = task_id % threads_per_threadblock; + const uint32_t warp_id = tid_in_threadblock / NUM_THREADS; + constexpr uint32_t warps_in_threadblock = + threads_per_threadblock / NUM_THREADS; + + // warpgroup context + constexpr uint32_t threads_per_warpgroup = + threads_per_threadblock / (WARP_SPECIALIZED ? 2 : 1); + constexpr uint32_t warpgroups_per_cluster = + threadblocks_per_cluster * (WARP_SPECIALIZED ? 2 : 1); + const uint32_t warps_per_warpgroup_per_core = + NUM_WARPS / warpgroups_per_cluster; + const uint32_t warpgroup_id = task_id / threads_per_warpgroup; + const uint32_t warpgroup_id_in_cluster = + warpgroup_id % warpgroups_per_cluster; + const uint32_t tid_in_warpgroup = tid_in_threadblock % threads_per_warpgroup; + + const uint32_t dim_seqlen = arg->dim_seqlen; + const uint32_t dim_headdim = arg->dim_headdim; + + // get global memory addresses from kernel arguments + const float *gmem_Q = reinterpret_cast(arg->addr_q); + const float *gmem_K = reinterpret_cast(arg->addr_k); + const float *gmem_V = reinterpret_cast(arg->addr_v); + float *gmem_O = reinterpret_cast(arg->addr_o); + + float *gmem_tmp_d0 = reinterpret_cast(0xd0000000UL); + float *gmem_tmp_d1 = reinterpret_cast(0xd1000000UL); + float *gmem_tmp_d2 = reinterpret_cast(0xd2000000UL); + float *gmem_tmp_d3 = reinterpret_cast(0xd3000000UL); + float *gmem_tmp_d4 = reinterpret_cast(0xd4000000UL); + float *gmem_tmp_d5 = reinterpret_cast(0xd5000000UL); + float *gmem_tmp_d6 = 
reinterpret_cast(0xd6000000UL); + float *gmem_tmp_d7 = reinterpret_cast(0xd7000000UL); + float *gmem_tmp_e0 = reinterpret_cast(0xe0000000UL); + float *gmem_tmp_e1 = reinterpret_cast(0xe1000000UL); + float *gmem_tmp_e2 = reinterpret_cast(0xe2000000UL); + float *gmem_tmp_e3 = reinterpret_cast(0xe3000000UL); + + // static shared memory allocation + constexpr uint32_t smem_Q_size = B_ROW * HEADDIM; + constexpr uint32_t smem_K_size = B_COL * HEADDIM; + constexpr uint32_t smem_QK_size = B_ROW * B_COL; + constexpr uint32_t smem_V_size = B_COL * HEADDIM; + constexpr uint32_t smem_O_size = B_COL * HEADDIM; + static_assert( + threads_per_threadblock == NUM_WARPS * NUM_THREADS * CORES_PER_CLUSTER, + "flashattention kernel assumes 1 threadblock occupancy per cluster"); + uint8_t *smem_per_threadblock = reinterpret_cast( + DEV_SMEM_START_ADDR); + float *smem_cursor = reinterpret_cast(smem_per_threadblock); + // float *smem_cursor = reinterpret_cast(DEV_FAKE_SMEM_START_ADDR); + float *smem_Q0 = smem_cursor; + smem_cursor += smem_Q_size; + float *smem_Q1 = smem_cursor; + smem_cursor += smem_Q_size; + float *smem_K0 = smem_cursor; + smem_cursor += smem_K_size; + float *smem_K1 = smem_cursor; + smem_cursor += smem_K_size; + float *smem_V0 = smem_cursor; + smem_cursor += smem_V_size; + float *smem_V1 = smem_cursor; + smem_cursor += smem_V_size; + float *smem_S0 = smem_cursor; + smem_cursor += smem_QK_size; + float *smem_S1 = smem_cursor; + smem_cursor += smem_QK_size; + float *smem_P0 = smem_S0; // in-place update + float *smem_P1 = smem_S1; // in-place update + float *smem_O0 = smem_cursor; + smem_cursor += smem_O_size; + float *smem_O1 = smem_cursor; + smem_cursor += smem_O_size; + + // NOTE: this has to match with smem_* + static_assert(sizeof(elem_t) == sizeof(float)); + constexpr uint32_t spad_addr_factor = DIM * sizeof(elem_t); + constexpr uint32_t spad_addr_Q0 = 0; + constexpr uint32_t spad_addr_Q1 = + spad_addr_Q0 + (smem_Q_size * sizeof(float) / spad_addr_factor); + 
constexpr uint32_t spad_addr_K0 = + spad_addr_Q1 + (smem_Q_size * sizeof(float) / spad_addr_factor); + constexpr uint32_t spad_addr_K1 = + spad_addr_K0 + (smem_K_size * sizeof(float) / spad_addr_factor); + constexpr uint32_t spad_addr_V0 = + spad_addr_K1 + (smem_K_size * sizeof(float) / spad_addr_factor); + constexpr uint32_t spad_addr_V1 = + spad_addr_V0 + (smem_V_size * sizeof(float) / spad_addr_factor); + constexpr uint32_t spad_addr_S0 = + spad_addr_V1 + (smem_V_size * sizeof(float) / spad_addr_factor); + constexpr uint32_t spad_addr_S1 = + spad_addr_S0 + (smem_QK_size * sizeof(float) / spad_addr_factor); + + // allocate rowmax/rowsum storage at the end of the sharedmem address space + constexpr uint32_t smem_rowmax_size = B_ROW * ROWMAX_SETS; + constexpr uint32_t smem_rowsum_size = B_ROW; + constexpr uint32_t smem_O_row_scale_size = B_ROW; + // FIXME: dangerous + smem_cursor = reinterpret_cast(0xff038000); + + float *smem_rowmax_0 = smem_cursor; + smem_cursor += smem_rowmax_size; + float *smem_rowmax_1 = smem_cursor; + smem_cursor += smem_rowmax_size; + float *smem_rowsum_0 = smem_cursor; + smem_cursor += smem_rowsum_size; + float *smem_rowsum_1 = smem_cursor; + smem_cursor += smem_rowsum_size; + float *smem_O_row_scale_0 = smem_cursor; + smem_cursor += smem_O_row_scale_size; + float *smem_O_row_scale_1 = smem_cursor; + smem_cursor += smem_O_row_scale_size; + + // sharedmem "scratchpad" area to put temporary data, e.g. 
for tree reduction + // in rowsum + // NOTE: out-of bounds is not checked + constexpr uint32_t smem_scratchpad_size = + threads_per_warpgroup * 2 /*arbitrary slack*/; + float *smem_scratchpad_0 = smem_cursor; + smem_cursor += smem_scratchpad_size; + float *smem_scratchpad_1 = smem_cursor; + smem_cursor += smem_scratchpad_size; + + // initialize rowmax/rowsum values in sharedmem + thread_block_init_sharedmem(tid_in_warpgroup, threads_per_warpgroup, smem_O0, + smem_rowmax_0, smem_rowsum_0, smem_O_row_scale_0); + thread_block_init_sharedmem(tid_in_warpgroup, threads_per_warpgroup, smem_O1, + smem_rowmax_1, smem_rowsum_1, smem_O_row_scale_1); + + constexpr uint32_t global_barrier_id = NUM_WARPS - 1; // arbitrary + + // delay warpgroup 0 by 1 iteration to do ping-pong scheduling + if (WARP_SPECIALIZED && warpgroup_id == 1) { + threadblock_barrier(global_barrier_id, warps_per_threadblock_per_core); + } + + static_assert(!GEMMINI_DMA || Q_IS_K_MAJOR, + "DMA code assumes Q matrix is stored K-major"); + + // skip everything except DMA in the loop FSM + constexpr uint32_t skips = + loop_matmul_skips(/*skip_lda=*/0, /*skip_ldb=*/0, /*skip_ldd=*/1, + /*skip_ex=*/1, /*skip_stc=*/1); + constexpr uint32_t skips_mvout_spad = + loop_matmul_skips(/*skip_lda=*/1, /*skip_ldb=*/1, /*skip_ldd=*/1, + /*skip_ex=*/1, /*skip_stc=*/0); + + if constexpr (GEMMINI_DMA) { + if (tid_in_warpgroup == 0) { + gemmini_extended_config_ex(WEIGHT_STATIONARY, 0, 0, 1, 0, 0); + + // configure DMA for the full Q matrix + gemmini_extended3_config_ld(HEADDIM * sizeof(elem_t), MVIN_SCALE_IDENTITY, + false, 0); + // configure DMA for the full K matrix + gemmini_extended3_config_ld(dim_seqlen * sizeof(elem_t), MVIN_SCALE_IDENTITY, + false, 1); + // configure DMA for Q*K store + gemmini_extended_config_st(B_COL * sizeof(elem_t), 0, + MVIN_SCALE_IDENTITY); + gemmini_fence(); + } + } + + // NOTE about barriers: Placing barriers around thread-divergent branches may + // cause bugs, because the Vortex core doesn't 
check for tmask for barriers. + // The compiler might decide to duplicate vx_bar into both paths of a + // conditional branch, which will get evaluated twice because of the way + // branches are handled in SIMT; this might result in stalls especially when + // other warps behave differently on the branch condition. + // threadblock_barrier(warpgroup_id_in_cluster, warps_per_warpgroup_per_core); + + // move Q and K into SMEM before the loop starts + // + static_assert(B_ROW == B_COL, "currently only supports square tiles"); + + asm volatile("dma_move_start_%=:" ::); + + if (tid_in_warpgroup == 0) { + // make sure to read from the correct row of Q + const float *gmem_Q_tile = gmem_Q + HEADDIM * B_ROW * warpgroup_id; + const float *gmem_K_tile = gmem_K; + // configure the GMEM addresses for the DMA to read from + ROCC_INSTRUCTION_RS1_RS2(XCUSTOM_ACC, (uint64_t)(gmem_Q_tile), + (uint64_t)(gmem_K_tile), k_LOOP_WS_CONFIG_ADDRS_AB) + // configure address strides for the DMA + GEMMINI_CISC_CMD_R((dim_seqlen << 16) | (HEADDIM << 8) | + 8 /*k_LOOP_WS_CONFIG_STRIDES_AB*/); + gemmini_fence(); + +#define GEMMINI_DMA_CISC +#ifdef GEMMINI_DMA_CISC + GEMMINI_CISC_CMD_I(10); + gemmini_fence(); +#else + // do DMA + // + // among other things, this also configures CONFIG_BOUNDS so that the + // DMA knows the full matrix dimensions + sp_tiled_matmul_full_spad_ws( + spad_addr_Q, spad_addr_K, + /*spad_D=*/0, /*spad_C=*/spad_addr_S, + /*I=*/(B_ROW / DIM), /*J=*/(B_COL / DIM), /*K=*/(HEADDIM / DIM), + /*pad_I=*/0, /*pad_J=*/0, /*pad_K=*/0, + /*a_transpose=*/0, /*b_transpose=*/0, /*full_C=*/0, /*low_D=*/0, + /*acc=*/0, /*act=*/NO_ACTIVATION, /*skips=*/skips); + gemmini_fence(); +#endif + + // re-configure DMA for K and V load that will later happen in the loop + // GMEM addr stride for K + gemmini_extended3_config_ld(dim_seqlen * sizeof(elem_t), + MVIN_SCALE_IDENTITY, false, 0); + // GMEM addr stride for V + gemmini_extended3_config_ld(HEADDIM * sizeof(elem_t), MVIN_SCALE_IDENTITY, + 
false, 1); + gemmini_fence(); + } + + asm volatile("dma_move_end_%=:" ::); + + // protect write to SMEM + threadblock_barrier(warpgroup_id_in_cluster, warps_per_warpgroup_per_core); + + // if constexpr (DEBUG) { + // thread_block_copy_tile(smem_Q0, gmem_tmp_d0, tid_in_warpgroup, + // threads_per_warpgroup, warpgroup_id_in_cluster); + // thread_block_copy_tile(smem_K0, gmem_tmp_d1, tid_in_warpgroup, + // threads_per_warpgroup, warpgroup_id_in_cluster); + // threadblock_barrier(warpgroup_id_in_cluster, warps_per_warpgroup_per_core); + // } + + asm volatile ("tile_loop_start_%=:" :: ); + + // "inner loop" along the columns of K^T + const uint32_t k_tiles = (dim_seqlen / B_COL); + for (uint32_t tile_k = 0; tile_k < k_tiles; tile_k++) { + // select the correct double buffer by tile iteration + float *smem_Q = (tile_k & 1) ? smem_Q1 : smem_Q0; // FIXME + float *smem_K = (tile_k & 1) ? smem_K1 : smem_K0; // FIXME + float *smem_V = (tile_k & 1) ? smem_V1 : smem_V0; // FIXME + float *smem_S = (tile_k & 1) ? smem_S1 : smem_S0; // FIXME + float *smem_O = (tile_k & 1) ? smem_O1 : smem_O0; // FIXME + float *smem_P = smem_S; // FIXME + float *smem_O_row_scale = + (tile_k & 1) ? smem_O_row_scale_1 : smem_O_row_scale_0; + float *smem_rowmax = (tile_k & 1) ? smem_rowmax_1 : smem_rowmax_0; + float *smem_rowsum = (tile_k & 1) ? smem_rowsum_1 : smem_rowsum_0; + float *smem_scratchpad = + (tile_k & 1) ? smem_scratchpad_1 : smem_scratchpad_0; + + const auto spad_addr_Q = (tile_k & 1) ? spad_addr_Q1 : spad_addr_Q0; + const auto spad_addr_K = (tile_k & 1) ? spad_addr_K1 : spad_addr_K0; + const auto spad_addr_V = (tile_k & 1) ? spad_addr_V1 : spad_addr_V0; + const auto spad_addr_S = (tile_k & 1) ? 
spad_addr_S1 : spad_addr_S0; + + // GEMM I: S = Q*K + // + asm volatile("gemm_qk_start_%=:" ::); + + if (tid_in_warpgroup == 0) { + if (tile_k == 0) { + gemmini_fence(); + GEMMINI_CISC_CMD_I(0); + } else if (tile_k & 1) { + gemmini_fence(); + GEMMINI_CISC_CMD_I(2); + } else { + gemmini_fence(); + GEMMINI_CISC_CMD_I(1); + } + + // mvout to SMEM + gemmini_fence(); + gemmini_fence(); + gemmini_fence(); + gemmini_fence(); + +#if 0 +// weight-stationary matmul loop +#define gemmini_loop_ws(I, J, K, pad_I, pad_J, pad_K, A, B, D, C, A_stride, B_stride, D_stride, C_stride, A_transpose, B_transpose, full_C, low_D, ex_accumulate, act, a_spad_id, b_spad_id, is_resadd) \ + { \ + ROCC_INSTRUCTION_RS1_RS2(XCUSTOM_ACC, ((uint64_t)(pad_K) << 32) | ((uint64_t)(pad_J) << 16) | (uint64_t)(pad_I), ((uint64_t)(K) << 32) | ((uint64_t)(J) << 16) | (uint64_t)(I), k_LOOP_WS_CONFIG_BOUNDS) \ + ROCC_INSTRUCTION_RS1_RS2(XCUSTOM_ACC, A, B, k_LOOP_WS_CONFIG_ADDRS_AB) \ + ROCC_INSTRUCTION_RS1_RS2(XCUSTOM_ACC, D, C, k_LOOP_WS_CONFIG_ADDRS_DC) \ + ROCC_INSTRUCTION_RS1_RS2(XCUSTOM_ACC, A_stride, B_stride, k_LOOP_WS_CONFIG_STRIDES_AB) \ + ROCC_INSTRUCTION_RS1_RS2(XCUSTOM_ACC, D_stride, C_stride, k_LOOP_WS_CONFIG_STRIDES_DC) \ + ROCC_INSTRUCTION_RS1_RS2(XCUSTOM_ACC, ((uint64_t)(a_spad_id) << 18) | ((uint64_t)(b_spad_id) << 16) | ((uint64_t)(act) << 8) | ((low_D) << 2) | ((full_C) << 1) | (ex_accumulate), ((is_resadd) << 2) | ((B_transpose) << 1) | (A_transpose), k_LOOP_WS) \ + } +#endif + + // GEMMINI_CISC_CMD_I(9); + sp_tiled_matmul_full_spad_ws( + /*spad_A=*/spad_addr_Q, /*spad_B=*/spad_addr_K, + /*spad_D=*/0, /*spad_C=*/spad_addr_S, + /*I=*/(B_ROW / DIM), /*J=*/(B_COL / DIM), /*K=*/(HEADDIM / DIM), + /*pad_I=*/0, /*pad_J=*/0, /*pad_K=*/0, + /*a_transpose=*/0, /*b_transpose=*/0, /*full_C=*/0, /*low_D=*/0, + /*acc=*/0, /*act=*/NO_ACTIVATION, /*skips=*/skips_mvout_spad); + gemmini_fence(); + + if constexpr (DEBUG) { + // for copy-out to GMEM + gemmini_fence(); + } + } + + // thread reconvergence + 
threadblock_barrier(warpgroup_id_in_cluster, warps_per_warpgroup_per_core); + + asm volatile("gemm_qk_finish_%=:" ::); + + if constexpr (DEBUG) { + if (warpgroup_id == 0) { + if (tile_k == 0) { + thread_block_copy_tile(smem_S, gmem_tmp_d0, + tid_in_warpgroup, threads_per_warpgroup, + warpgroup_id_in_cluster); + } else if (tile_k == 1) { + thread_block_copy_tile(smem_S, gmem_tmp_d1, + tid_in_warpgroup, threads_per_warpgroup, + warpgroup_id_in_cluster); + } + + threadblock_barrier(warpgroup_id_in_cluster, + warps_per_warpgroup_per_core); + } + } + + // inter-warpgroup barrier before online softmax + threadblock_barrier(global_barrier_id, warps_per_threadblock_per_core); + +#if 0 + // Online softmax + // + thread_block_online_softmax(smem_S, smem_P, tid_in_warpgroup, + threads_per_warpgroup, warpgroup_id_in_cluster, + smem_scratchpad, smem_rowmax, smem_rowsum, + smem_O_row_scale); + + // FIXME: unnecessary? + threadblock_barrier(warpgroup_id_in_cluster, warps_per_warpgroup_per_core); + + // data movement for K and V + // + // Q stays in SMEM for the entire loop + asm volatile("move_k_v_start_%=:" ::); + if constexpr (GEMMINI_DMA) { + // NOTE: Beware of race conditions; with warp specialization, we need to + // make sure below command code to DMA is not executed simultaneously + // from the two warpgroups (which will result in hardware fault). + // Currently the ping-pong scheduling scheme prevents that. + if (tid_in_warpgroup == 0) { + // configure GMEM addresses for K and V tiles + // load K for the next iteration + const float *gmem_K_tile = gmem_K + (B_COL * (tile_k + 1)); + // load V for the current iteration + const float *gmem_V_tile = gmem_V + (HEADDIM * B_COL * tile_k); + ROCC_INSTRUCTION_RS1_RS2(XCUSTOM_ACC, (uint64_t)(gmem_K_tile), + (uint64_t)(gmem_V_tile), + k_LOOP_WS_CONFIG_ADDRS_AB) + // configure address strides for the DMA + // FIXME: unnecessary? 
+ GEMMINI_CISC_CMD_R((HEADDIM /*V*/ << 16) | (dim_seqlen /*KT*/ << 8) | + 8 /*k_LOOP_WS_CONFIG_STRIDES_AB*/); + gemmini_fence(); + + // do DMA + sp_tiled_matmul_full_spad_ws( + spad_addr_K, spad_addr_V, + /*spad_D=*/0, /*spad_C=*/spad_addr_S, + /*I=*/(HEADDIM / DIM), /*J=*/(HEADDIM / DIM), /*K=*/(B_COL / DIM), + /*pad_I=*/0, /*pad_J=*/0, /*pad_K=*/0, + /*a_transpose=*/0, /*b_transpose=*/0, /*full_C=*/0, /*low_D=*/0, + /*acc=*/0, /*act=*/NO_ACTIVATION, /*skips=*/skips); + gemmini_fence(); + } + } else { + // load K for the next iteration + load_tile_to_smem( + dim_seqlen, tile_k + 1, 0 /* dim_k == headdim */, gmem_K, smem_K, + tid_in_warpgroup); + + // load V for the current iteration + // V dimension is [seqlen, headdim], stored N(headdim)-major + load_tile_to_smem( + HEADDIM, 0 /* full N-dimension */, tile_k, gmem_V, smem_V, + tid_in_warpgroup); + } + asm volatile("move_k_v_finish_%=:" ::); + + // protect write to SMEM + threadblock_barrier(warpgroup_id_in_cluster, warps_per_warpgroup_per_core); + + if constexpr (DEBUG) { + if (warpgroup_id == 0) { + if (tile_k == 0) { + thread_block_copy_rowmax(smem_rowmax, gmem_tmp_e0, tid_in_warpgroup, + threads_per_warpgroup, + warpgroup_id_in_cluster); + thread_block_copy_rowmax(smem_rowsum, gmem_tmp_e2, tid_in_warpgroup, + threads_per_warpgroup, + warpgroup_id_in_cluster); + } else if (tile_k == 1) { + thread_block_copy_rowmax(smem_rowmax, gmem_tmp_e1, tid_in_warpgroup, + threads_per_warpgroup, + warpgroup_id_in_cluster); + thread_block_copy_rowmax(smem_rowsum, gmem_tmp_e3, tid_in_warpgroup, + threads_per_warpgroup, + warpgroup_id_in_cluster); + } + + threadblock_barrier(warpgroup_id_in_cluster, + warps_per_warpgroup_per_core); + } + } + + // inter-warpgroup barrier before GEMM II + threadblock_barrier(global_barrier_id, warps_per_threadblock_per_core); + + // Oi rescale + thread_block_O_rescale(smem_O, smem_O /*in-place*/, + smem_O_row_scale, tid_in_warpgroup, + threads_per_warpgroup, warpgroup_id_in_cluster); + + // 
rescale-to-PV-GEMM barrier + threadblock_barrier(warpgroup_id_in_cluster, warps_per_warpgroup_per_core); + + if constexpr (DEBUG) { + if (warpgroup_id == 0) { + // O before PV + if (tile_k == 0) { + thread_block_copy_tile(smem_P, gmem_tmp_d2, tid_in_warpgroup, + threads_per_warpgroup, + warpgroup_id_in_cluster); + thread_block_copy_tile(smem_O, gmem_tmp_d4, tid_in_warpgroup, + threads_per_warpgroup, + warpgroup_id_in_cluster); + } else if (tile_k == 1) { + thread_block_copy_tile(smem_P, gmem_tmp_d3, tid_in_warpgroup, + threads_per_warpgroup, + warpgroup_id_in_cluster); + thread_block_copy_tile(smem_O, gmem_tmp_d5, tid_in_warpgroup, + threads_per_warpgroup, + warpgroup_id_in_cluster); + } + + threadblock_barrier(warpgroup_id_in_cluster, + warps_per_warpgroup_per_core); + } + } + + // GEMM II: O = O + P*V + + asm volatile("gemm_pv_start_%=:" ::); + + if constexpr (!WARP_SPECIALIZED) { + // clear out accumulators before GEMM + initialize_accum_regs<0>(); + initialize_accum_regs<1>(); + + if constexpr (GEMMINI_DMA) { + thread_block_gemm_single_tile< + float, MemLayout::K_major /* P matrix is row-major */, + MemLayout::block_row_major, B_ROW, HEADDIM, B_COL, + /*leading_dim_a=*/0, /*leading_dim_b=*/0, + /*load_accum=*/true, + /*write_to_smem=*/true>( + smem_P, smem_V, smem_O /*load accum*/, smem_O, tid_in_warpgroup, + threads_per_warpgroup, warpgroups_per_cluster, + warpgroup_id_in_cluster); + } else { + thread_block_gemm_single_tile( + smem_P, smem_V, smem_O /*load accum*/, smem_O, tid_in_warpgroup, + threads_per_warpgroup, warpgroups_per_cluster, + warpgroup_id_in_cluster); + // FIXME: wrong but fast + // thread_block_gemm_single_tile( + // smem_P, smem_V, smem_O /*load accum*/, smem_O, + // tid_in_warpgroup, threads_per_warpgroup, + // warpgroups_per_cluster, warpgroup_id_in_cluster); + } + } else { + // when warp-specialized, there's only enough warps to do 64x32 tile + // size so we need to do 2 GEMM calls + static_assert(B_ROW / 2 == 32, + "tile size assumption 
for warp-specialization not met"); + + // assumes smem_P is K-major + float *smem_P_half0 = smem_P; + float *smem_P_half1 = smem_P + (B_ROW / 2) * B_COL; + float *smem_O_half0 = smem_O; + float *smem_O_half1 = smem_O + (B_ROW / 2) * HEADDIM; + + // clear out accumulators before GEMM + initialize_accum_regs<0>(); + initialize_accum_regs<1>(); + + // split by rows into 2 chunks + if constexpr (GEMMINI_DMA) { + if constexpr (GEMMINI_DMA_FAST) { + thread_block_gemm_single_tile( + smem_P_half0, smem_V, smem_O_half0 /*load accum*/, smem_O_half0, + tid_in_warpgroup, threads_per_warpgroup, warpgroups_per_cluster, + warpgroup_id_in_cluster); + } else { + thread_block_gemm_single_tile< + float, MemLayout::K_major /* P matrix is row-major */, + MemLayout::block_row_major, B_ROW / 2, HEADDIM, B_COL, + /*leading_dim_a=*/0, + /*leading_dim_b=*/0, + /*load_accum=*/true, + /*write_to_smem=*/true>( + smem_P_half0, smem_V, smem_O_half0 /*load accum*/, smem_O_half0, + tid_in_warpgroup, threads_per_warpgroup, warpgroups_per_cluster, + warpgroup_id_in_cluster); + } + } else { + thread_block_gemm_single_tile< + float, MemLayout::K_major, MemLayout::MN_major, B_ROW / 2, HEADDIM, + B_COL, /*leading_dim_a=*/0, /*leading_dim_b=*/0, + /*load_accum=*/true, + /*write_to_smem=*/true>( + smem_P_half0, smem_V, smem_O_half0 /*load accum*/, smem_O_half0, + tid_in_warpgroup, threads_per_warpgroup, warpgroups_per_cluster, + warpgroup_id_in_cluster); + } + + initialize_accum_regs<0>(); + initialize_accum_regs<1>(); + + if constexpr (GEMMINI_DMA) { + if constexpr (GEMMINI_DMA_FAST) { + thread_block_gemm_single_tile( + smem_P_half1, smem_V, smem_O_half1 /*load accum*/, smem_O_half1, + tid_in_warpgroup, threads_per_warpgroup, warpgroups_per_cluster, + warpgroup_id_in_cluster); + } else { + thread_block_gemm_single_tile< + float, MemLayout::K_major /* P matrix is row-major */, + MemLayout::block_row_major, B_ROW / 2, HEADDIM, B_COL, + /*leading_dim_a=*/0, + /*leading_dim_b=*/0, + /*load_accum=*/true, + 
/*write_to_smem=*/true>( + smem_P_half1, smem_V, smem_O_half1 /*load accum*/, smem_O_half1, + tid_in_warpgroup, threads_per_warpgroup, warpgroups_per_cluster, + warpgroup_id_in_cluster); + } + } else { + thread_block_gemm_single_tile< + float, MemLayout::K_major, MemLayout::MN_major, B_ROW / 2, HEADDIM, + B_COL, /*leading_dim_a=*/0, /*leading_dim_b=*/0, + /*load_accum=*/true, + /*write_to_smem=*/true>( + smem_P_half1, smem_V, smem_O_half1 /*load accum*/, smem_O_half1, + tid_in_warpgroup, threads_per_warpgroup, warpgroups_per_cluster, + warpgroup_id_in_cluster); + } + } + + threadblock_barrier(warpgroup_id_in_cluster, warps_per_warpgroup_per_core); + + asm volatile("gemm_pv_finish_%=:" ::); + + if constexpr (DEBUG) { + if (warpgroup_id == 0) { + // O after PV + if (tile_k == 0) { + thread_block_copy_tile(smem_O, gmem_tmp_d6, tid_in_warpgroup, + threads_per_warpgroup, + warpgroup_id_in_cluster); + } else if (tile_k == 1) { + thread_block_copy_tile(smem_O, gmem_tmp_d7, tid_in_warpgroup, + threads_per_warpgroup, + warpgroup_id_in_cluster); + } + + threadblock_barrier(warpgroup_id_in_cluster, + warps_per_warpgroup_per_core); + } + } +#endif + } + + asm volatile ("tile_loop_finish_%=:" :: ); + + // wait for warpgroup 1 to finish, which called the global barrier before + // entering the loop + if (warpgroup_id == 0) { + threadblock_barrier(global_barrier_id, warps_per_threadblock_per_core); + } +} + +int main() { + kernel_arg_t *arg = (kernel_arg_t *)KERNEL_ARG_DEV_MEM_ADDR; + + // FIXME:: use actuall seqlen/headdim + const uint32_t problem_size = (B_ROW * B_COL) / (ELEM_PER_THREAD); + const uint32_t hw_threads_per_cluster = + CORES_PER_CLUSTER * vx_num_threads() * vx_num_warps(); + // prevent launching more threads than the necessary problem size + // TODO: this does not take into account multiple clusters + const uint32_t grid_size = (problem_size > hw_threads_per_cluster) + ? 
hw_threads_per_cluster + : problem_size; + +#ifdef RADIANCE + vx_spawn_tasks_cluster(grid_size, (vx_spawn_tasks_cb)kernel_body, arg); +#else + // NOTE: This kernel assumes contiguous thread scheduling for efficient shared + // memory allocation, and therefore does not work with original vx_spawn_tasks + vx_spawn_tasks_contiguous(grid_size, (vx_spawn_tasks_cb)kernel_body, arg); +#endif + return 0; +} From c51dc4902d5060202a3bbbdfe3ba25de2572f138 Mon Sep 17 00:00:00 2001 From: Hansung Kim Date: Sat, 7 Sep 2024 23:21:28 -0700 Subject: [PATCH 13/50] flash: Fix online softmax for DMA layout --- tests/regression/flash_attention/Makefile | 2 +- .../regression/flash_attention/flash_impl.hpp | 37 ++++++----- .../flash_attention/kernel.gemmini.cpp | 63 +++++++++---------- 3 files changed, 54 insertions(+), 48 deletions(-) diff --git a/tests/regression/flash_attention/Makefile b/tests/regression/flash_attention/Makefile index 4d4fcad1..0456e983 100644 --- a/tests/regression/flash_attention/Makefile +++ b/tests/regression/flash_attention/Makefile @@ -3,7 +3,7 @@ PROJECT = flash_attention SRCS = main.cpp common.h VX_SRCS = kernel.gemmini.cpp -VX_INCLUDES = ../sgemm_tcore/sgemm_impl.hpp +VX_INCLUDES = flash_impl.hpp ../sgemm_tcore/sgemm_impl.hpp OPTS ?= -n16 diff --git a/tests/regression/flash_attention/flash_impl.hpp b/tests/regression/flash_attention/flash_impl.hpp index 423ebd69..48a0068f 100644 --- a/tests/regression/flash_attention/flash_impl.hpp +++ b/tests/regression/flash_attention/flash_impl.hpp @@ -152,6 +152,7 @@ inline float exponential_taylor_term(const float x) { return res; } +template __attribute__((always_inline)) inline void thread_block_online_softmax( const float *smem_S, float *smem_P, const uint32_t tid_in_threadblock, const uint32_t threads_per_threadblock, @@ -180,7 +181,6 @@ __attribute__((always_inline)) inline void thread_block_online_softmax( // one warp handles one row in tile constexpr uint32_t per_row_iter = B_COL / NUM_THREADS; - uint32_t 
thread_offset = first_thread_offset + tid_in_warp; // FIXME: threadblock_id needs to be in here too float *warp_smem = smem_scratchpad + (warp_id * NUM_THREADS); @@ -219,11 +219,16 @@ __attribute__((always_inline)) inline void thread_block_online_softmax( float per_thread_max = FLT_MIN; #pragma GCC unroll for (int i = 0; i < per_row_iter; i++) { - const float next = smem_S[thread_offset]; + const uint32_t col_offset = NUM_THREADS * i; + const uint32_t col = col_offset + tid_in_warp; + const auto [smem_row, smem_col] = + remap_to_gemmini_dma_layout(row, col); + const uint32_t offset = B_COL * smem_row + smem_col; + + const float next = smem_S[offset]; asm volatile("fmax.s %0, %1, %2" : "=f"(per_thread_max) : "f"(per_thread_max), "f"(next)); - thread_offset += NUM_THREADS; } // stage per-thread max value in smem warp_smem[tid_in_warp] = per_thread_max; @@ -299,10 +304,15 @@ __attribute__((always_inline)) inline void thread_block_online_softmax( asm volatile("flashattn_exp_p_start_%=:" ::); - thread_offset = first_thread_offset + tid_in_warp; #pragma GCC unroll for (int i = 0; i < per_row_iter; i++) { - float f0 = smem_S[thread_offset]; + const uint32_t col_offset = NUM_THREADS * i; + const uint32_t col = col_offset + tid_in_warp; + const auto [smem_row, smem_col] = + remap_to_gemmini_dma_layout(row, col); + const uint32_t offset = B_COL * smem_row + smem_col; + + float f0 = smem_S[offset]; f0 -= rowmax_new; @@ -313,9 +323,7 @@ __attribute__((always_inline)) inline void thread_block_online_softmax( // Store S transposed to the shared memory - smem_P[thread_offset] = exp; - - thread_offset += NUM_THREADS; + smem_P[offset] = exp; } asm volatile("flashattn_exp_p_end_%=:" ::); @@ -332,11 +340,15 @@ __attribute__((always_inline)) inline void thread_block_online_softmax( float per_thread_sum = 0.0f; - thread_offset = first_thread_offset + tid_in_warp; #pragma GCC unroll for (int i = 0; i < per_row_iter; i++) { - per_thread_sum += smem_P[thread_offset]; - thread_offset += 
NUM_THREADS; + const uint32_t col_offset = NUM_THREADS * i; + const uint32_t col = col_offset + tid_in_warp; + const auto [smem_row, smem_col] = + remap_to_gemmini_dma_layout(row, col); + const uint32_t offset = B_COL * smem_row + smem_col; + + per_thread_sum += smem_P[offset]; } // stage per-thread sum value in smem // FIXME: threadblock_id needs to be in here too @@ -381,7 +393,6 @@ __attribute__((always_inline)) inline void thread_block_online_softmax( // asm volatile("flashattn_rescale_factor_start_%=:" ::); - thread_offset = first_thread_offset + tid_in_warp; #pragma GCC unroll for (int i = 0; i < per_row_iter; i++) { const float mi_prev = rowmax_prev; @@ -395,8 +406,6 @@ __attribute__((always_inline)) inline void thread_block_online_softmax( // @perf: div vs. expansion on e(-x)? smem_O_row_scale[row] = 1.0f / exp; - - thread_offset += NUM_THREADS; } asm volatile("flashattn_rescale_factor_end_%=:" ::); diff --git a/tests/regression/flash_attention/kernel.gemmini.cpp b/tests/regression/flash_attention/kernel.gemmini.cpp index 0df0cf87..4572c921 100644 --- a/tests/regression/flash_attention/kernel.gemmini.cpp +++ b/tests/regression/flash_attention/kernel.gemmini.cpp @@ -319,8 +319,8 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { gemmini_fence(); gemmini_fence(); -#if 0 -// weight-stationary matmul loop +#if 0 // TODO +// loop_ws variant that skips configuring strides #define gemmini_loop_ws(I, J, K, pad_I, pad_J, pad_K, A, B, D, C, A_stride, B_stride, D_stride, C_stride, A_transpose, B_transpose, full_C, low_D, ex_accumulate, act, a_spad_id, b_spad_id, is_resadd) \ { \ ROCC_INSTRUCTION_RS1_RS2(XCUSTOM_ACC, ((uint64_t)(pad_K) << 32) | ((uint64_t)(pad_J) << 16) | (uint64_t)(pad_I), ((uint64_t)(K) << 32) | ((uint64_t)(J) << 16) | (uint64_t)(I), k_LOOP_WS_CONFIG_BOUNDS) \ @@ -373,17 +373,40 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { // inter-warpgroup barrier before online softmax threadblock_barrier(global_barrier_id, 
warps_per_threadblock_per_core); -#if 0 // Online softmax // - thread_block_online_softmax(smem_S, smem_P, tid_in_warpgroup, - threads_per_warpgroup, warpgroup_id_in_cluster, - smem_scratchpad, smem_rowmax, smem_rowsum, - smem_O_row_scale); + thread_block_online_softmax( + smem_S, smem_P, tid_in_warpgroup, threads_per_warpgroup, + warpgroup_id_in_cluster, smem_scratchpad, smem_rowmax, smem_rowsum, + smem_O_row_scale); // FIXME: unnecessary? threadblock_barrier(warpgroup_id_in_cluster, warps_per_warpgroup_per_core); + if constexpr (DEBUG) { + if (warpgroup_id == 0) { + if (tile_k == 0) { + thread_block_copy_rowmax(smem_rowmax, gmem_tmp_e0, tid_in_warpgroup, + threads_per_warpgroup, + warpgroup_id_in_cluster); + thread_block_copy_rowmax(smem_rowsum, gmem_tmp_e2, tid_in_warpgroup, + threads_per_warpgroup, + warpgroup_id_in_cluster); + } else if (tile_k == 1) { + thread_block_copy_rowmax(smem_rowmax, gmem_tmp_e1, tid_in_warpgroup, + threads_per_warpgroup, + warpgroup_id_in_cluster); + thread_block_copy_rowmax(smem_rowsum, gmem_tmp_e3, tid_in_warpgroup, + threads_per_warpgroup, + warpgroup_id_in_cluster); + } + + threadblock_barrier(warpgroup_id_in_cluster, + warps_per_warpgroup_per_core); + } + } + +#if 0 // data movement for K and V // // Q stays in SMEM for the entire loop @@ -434,32 +457,6 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { } asm volatile("move_k_v_finish_%=:" ::); - // protect write to SMEM - threadblock_barrier(warpgroup_id_in_cluster, warps_per_warpgroup_per_core); - - if constexpr (DEBUG) { - if (warpgroup_id == 0) { - if (tile_k == 0) { - thread_block_copy_rowmax(smem_rowmax, gmem_tmp_e0, tid_in_warpgroup, - threads_per_warpgroup, - warpgroup_id_in_cluster); - thread_block_copy_rowmax(smem_rowsum, gmem_tmp_e2, tid_in_warpgroup, - threads_per_warpgroup, - warpgroup_id_in_cluster); - } else if (tile_k == 1) { - thread_block_copy_rowmax(smem_rowmax, gmem_tmp_e1, tid_in_warpgroup, - threads_per_warpgroup, - warpgroup_id_in_cluster); - 
thread_block_copy_rowmax(smem_rowsum, gmem_tmp_e3, tid_in_warpgroup, - threads_per_warpgroup, - warpgroup_id_in_cluster); - } - - threadblock_barrier(warpgroup_id_in_cluster, - warps_per_warpgroup_per_core); - } - } - // inter-warpgroup barrier before GEMM II threadblock_barrier(global_barrier_id, warps_per_threadblock_per_core); From adcd0a9d497488e2a5ad2645c96991bebffb5a8b Mon Sep 17 00:00:00 2001 From: Hansung Kim Date: Sun, 8 Sep 2024 02:23:51 -0700 Subject: [PATCH 14/50] sgemm_impl: Fix wrong smem address for fp16 Verified results for fp16 256x256. --- tests/regression/sgemm_tcore/kernel.cpp | 9 +++++---- tests/regression/sgemm_tcore/sgemm_impl.hpp | 10 +++++----- 2 files changed, 10 insertions(+), 9 deletions(-) diff --git a/tests/regression/sgemm_tcore/kernel.cpp b/tests/regression/sgemm_tcore/kernel.cpp index 59fd7194..bc77ac2a 100644 --- a/tests/regression/sgemm_tcore/kernel.cpp +++ b/tests/regression/sgemm_tcore/kernel.cpp @@ -7,7 +7,7 @@ #include "include/gemmini.h" #include "gemmini_mmio.h" -constexpr bool DEBUG = true; +constexpr bool DEBUG = false; template inline void thread_block_copy_tile(const float *src, float *dest, @@ -91,8 +91,8 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { /*write_to_gmem=*/true, /*smem_a_offset=*/0, /*smem_a_dbuf_offset=*/0, - /*smem_b_offset=*/2 * BM * BK * sizeof(float), - /*smem_b_dbuf_offset=*/2 * BM * BK * sizeof(float)>( + /*smem_b_offset=*/2 * BM * BK * sizeof(float_type), + /*smem_b_dbuf_offset=*/2 * BM * BK * sizeof(float_type)>( (const float_type *)arg->addr_a, (const float_type *)arg->addr_b, (float *)arg->addr_c, arg->dim_m, arg->dim_n, arg->dim_k, tid_in_threadblock, threadblocks_per_cluster, threadblock_id_in_cluster, @@ -102,7 +102,8 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { float *gmem_tmp_d1 = reinterpret_cast(0xd1000000UL); const float *smem_A = reinterpret_cast(sharedmem_per_threadblock); - const float *smem_B = smem_A + 2 * BM * BK; + const float *smem_B = 
reinterpret_cast( + sharedmem_per_threadblock + 2 * BM * BK * sizeof(float_type)); if constexpr (DEBUG) { threadblock_barrier(threadblock_id_in_cluster, diff --git a/tests/regression/sgemm_tcore/sgemm_impl.hpp b/tests/regression/sgemm_tcore/sgemm_impl.hpp index 1bb7b893..0c6274a2 100644 --- a/tests/regression/sgemm_tcore/sgemm_impl.hpp +++ b/tests/regression/sgemm_tcore/sgemm_impl.hpp @@ -6,7 +6,7 @@ #include "include/gemmini.h" #include "gemmini_mmio.h" -#define FP_SIZE 32 +#define FP_SIZE 16 // "fake" fp16 type that only has the correct data width. using float16_t = uint16_t; @@ -29,7 +29,7 @@ using float_type = float16_t; // (BM*BN) / (TM*TN) == threadblock size >= NT * CORES_PER_CLUSTER // * Combining BM * BK >= (BM*BN) / (TM*TN) == threadblock yields // BM <= BK*TM*TN -#define BM 64 +#define BM 128 #define BN 64 #if (FP_SIZE == 32) #define BK 64 @@ -72,7 +72,7 @@ static_assert(WMITER * WNITER * TCM * TCN * NUM_WARPS * CORES_PER_CLUSTER == #define TRANSPOSE_AT_PRODUCE 0 #define TRANSPOSE_AT_CONSUME 0 -#define GEMMINI_DMA 1 +#define GEMMINI_DMA 0 #define GEMMINI_DMA_FLEXIBLE_LAYOUT 0 #if SMEM_SIZE == 0x4000 #define SMEM_ADDR_Q0 ((float * const) 0xff000000) @@ -847,9 +847,9 @@ template < uint32_t smem_a_dbuf_offset = 0, // byte offset of A // double-buffer tile in shared // memory - uint32_t smem_b_offset = sizeof(float) * BM * BK, // byte offset of B tile + uint32_t smem_b_offset = sizeof(T) * BM * BK, // byte offset of B tile // in shared memory - uint32_t smem_b_dbuf_offset = sizeof(float) * BM * + uint32_t smem_b_dbuf_offset = sizeof(T) * BM * BK // byte offset of B double-buffer // tile in shared memory > From 42913c00c410541bccada1321202b42a5d024232 Mon Sep 17 00:00:00 2001 From: Hansung Kim Date: Sun, 8 Sep 2024 14:28:27 -0700 Subject: [PATCH 15/50] sgemm_impl: Use 12-bit cmd interface, allow DIM=16 --- tests/regression/sgemm_tcore/sgemm_impl.hpp | 16 ++++++++++++---- 1 file changed, 12 insertions(+), 4 deletions(-) diff --git 
a/tests/regression/sgemm_tcore/sgemm_impl.hpp b/tests/regression/sgemm_tcore/sgemm_impl.hpp index 0c6274a2..d2e88ace 100644 --- a/tests/regression/sgemm_tcore/sgemm_impl.hpp +++ b/tests/regression/sgemm_tcore/sgemm_impl.hpp @@ -207,7 +207,7 @@ template inline constexpr std::pair remap_to_gemmini_dma_layout(const uint32_t logical_row, const uint32_t logical_col) { - static_assert(DIM == 8, + static_assert(GEMMINI_DMA_FLEXIBLE_LAYOUT || DIM == 8, "GEMMINI_DMA layout remapping code only written for DIM == 8"); if constexpr (use_dma) { @@ -905,6 +905,8 @@ inline void thread_block_gemm(const T *A, const T *B, float *C, for (uint32_t block_m = block_m_start; block_m < block_m_end; block_m++) { #pragma GCC unroll 1 for (uint32_t block_n = 0; (block_n * BN) < dim_n; block_n++) { + asm volatile ("loop_mn_start_%=:" :: ); + // clear out accumulators initialize_accum_regs<0>(); initialize_accum_regs<1>(); @@ -920,7 +922,7 @@ inline void thread_block_gemm(const T *A, const T *B, float *C, (uint64_t)(B + /*block_k:*/0 * BK * dim_n + block_n * BN), k_LOOP_WS_CONFIG_ADDRS_AB) // GEMMINI_CISC(8) does k_LOOP_WS_CONFIG_STRIDES_AB - GEMMINI_CISC_CMD_R((dim_n << 16) | (dim_k << 8) | 8); + GEMMINI_CISC_CMD_R((dim_n << 20) | (dim_k << 8) | 8); gemmini_fence(); GEMMINI_CISC_CMD_I(10); @@ -951,6 +953,7 @@ inline void thread_block_gemm(const T *A, const T *B, float *C, #pragma GCC unroll 1 for (uint32_t block_k = 0; (block_k * BK) < dim_k; block_k++) { + asm volatile("loop_k_start_%=:" ::); // producer code: GMEM->SMEM memory movement // --------------------------------------------------------------------- @@ -967,8 +970,8 @@ inline void thread_block_gemm(const T *A, const T *B, float *C, (uint64_t)(B + (block_k + 1/*runahead*/) * BK * dim_n + block_n * BN), k_LOOP_WS_CONFIG_ADDRS_AB) // GEMMINI_CISC(8) does k_LOOP_WS_CONFIG_STRIDES_AB - GEMMINI_CISC_CMD_R((dim_n << 16) | (dim_k << 8) | 8); - // gemmini_fence(); + GEMMINI_CISC_CMD_R((dim_n << 20) | (dim_k << 8) | 8); + gemmini_fence(); // 
block_k is even: opcode 11 (write to local_a_buf) // block_k is odd: opcode 10 (write to local_a) @@ -1043,6 +1046,7 @@ inline void thread_block_gemm(const T *A, const T *B, float *C, // consumer code: SMEM->RF and compute // ---------------------------------------------------------------------- // @perf: this loop spills to stack a lot because of all the flws in + asm volatile("dbuf_sel_start_%=:" ::); const T *local_a_consume; const T *local_b_consume; if constexpr (GEMMINI_DMA) { @@ -1064,6 +1068,7 @@ inline void thread_block_gemm(const T *A, const T *B, float *C, local_a_consume = local_a; local_b_consume = local_b; } + asm volatile("dbuf_sel_end_%=:" ::); constexpr MemLayout layout_a = GEMMINI_DMA ? MemLayout::block_row_major @@ -1092,6 +1097,8 @@ inline void thread_block_gemm(const T *A, const T *B, float *C, threadblock_barrier(threadblock_id_in_cluster, warps_per_threadblock_per_core); + + asm volatile("loop_k_end_%=:" ::); } if constexpr (write_to_gmem) { @@ -1106,6 +1113,7 @@ inline void thread_block_gemm(const T *A, const T *B, float *C, } } } + asm volatile("loop_mn_end_%=:" ::); } } From 443a37be6ca93f22ddafade18349ccee8bdd617d Mon Sep 17 00:00:00 2001 From: Hansung Kim Date: Sun, 8 Sep 2024 14:56:48 -0700 Subject: [PATCH 16/50] sgemm_impl: Add DMA_FAST option; fix dbuf offset for dma --- tests/regression/sgemm_tcore/kernel.cpp | 21 ++++++++++++----- tests/regression/sgemm_tcore/sgemm_impl.hpp | 26 +++++++++++++-------- 2 files changed, 31 insertions(+), 16 deletions(-) diff --git a/tests/regression/sgemm_tcore/kernel.cpp b/tests/regression/sgemm_tcore/kernel.cpp index bc77ac2a..bb904baf 100644 --- a/tests/regression/sgemm_tcore/kernel.cpp +++ b/tests/regression/sgemm_tcore/kernel.cpp @@ -90,13 +90,22 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { thread_block_gemm( - (const float_type *)arg->addr_a, (const float_type *)arg->addr_b, - (float *)arg->addr_c, arg->dim_m, arg->dim_n, arg->dim_k, - tid_in_threadblock, 
threadblocks_per_cluster, threadblock_id_in_cluster, - sharedmem_per_threadblock); + /*smem_b_dbuf_offset=*/(2 * BM * BK + BK * BN) * sizeof(float_type) +#endif + >((const float_type *)arg->addr_a, + (const float_type *)arg->addr_b, (float *)arg->addr_c, + arg->dim_m, arg->dim_n, arg->dim_k, tid_in_threadblock, + threadblocks_per_cluster, threadblock_id_in_cluster, + sharedmem_per_threadblock); float *gmem_tmp_d0 = reinterpret_cast(0xd0000000UL); float *gmem_tmp_d1 = reinterpret_cast(0xd1000000UL); diff --git a/tests/regression/sgemm_tcore/sgemm_impl.hpp b/tests/regression/sgemm_tcore/sgemm_impl.hpp index d2e88ace..7ba19992 100644 --- a/tests/regression/sgemm_tcore/sgemm_impl.hpp +++ b/tests/regression/sgemm_tcore/sgemm_impl.hpp @@ -72,8 +72,9 @@ static_assert(WMITER * WNITER * TCM * TCN * NUM_WARPS * CORES_PER_CLUSTER == #define TRANSPOSE_AT_PRODUCE 0 #define TRANSPOSE_AT_CONSUME 0 -#define GEMMINI_DMA 0 -#define GEMMINI_DMA_FLEXIBLE_LAYOUT 0 +#define GEMMINI_DMA 1 +#define GEMMINI_DMA_FAST 1 +#define GEMMINI_DMA_FLEXIBLE_LAYOUT 1 #if SMEM_SIZE == 0x4000 #define SMEM_ADDR_Q0 ((float * const) 0xff000000) #define SMEM_ADDR_Q1 ((float * const) 0xff001000) @@ -207,7 +208,7 @@ template inline constexpr std::pair remap_to_gemmini_dma_layout(const uint32_t logical_row, const uint32_t logical_col) { - static_assert(GEMMINI_DMA_FLEXIBLE_LAYOUT || DIM == 8, + static_assert(!use_dma || DIM == 8, "GEMMINI_DMA layout remapping code only written for DIM == 8"); if constexpr (use_dma) { @@ -915,7 +916,6 @@ inline void thread_block_gemm(const T *A, const T *B, float *C, // pipeline initiation if (tid_in_threadblock == 0) { // configure dma gmem address to load from - // FIXME: block_k is wrong ROCC_INSTRUCTION_RS1_RS2( XCUSTOM_ACC, (uint64_t)(A + block_m * BM * dim_k + /*block_k:*/0 * BK), @@ -963,7 +963,6 @@ inline void thread_block_gemm(const T *A, const T *B, float *C, #if (GEMMINI_DMA == 1) if ((tid_in_threadblock == 0) && ((block_k * BK) != (dim_k - BK))) { // configure dma 
gmem address to load from - // FIXME: block_k is wrong ROCC_INSTRUCTION_RS1_RS2( XCUSTOM_ACC, (uint64_t)(A + block_m * BM * dim_k + (block_k + 1/*runahead*/) * BK), @@ -976,7 +975,7 @@ inline void thread_block_gemm(const T *A, const T *B, float *C, // block_k is even: opcode 11 (write to local_a_buf) // block_k is odd: opcode 10 (write to local_a) const uint32_t opcode = 11 - (block_k & 1); - GEMMINI_CISC_CMD_R(opcode); + GEMMINI_CISC_CMD_I(opcode); // // TODO: branch is probably slow // if (block_k & 1) { // GEMMINI_CISC_CMD_I(12); @@ -1061,8 +1060,12 @@ inline void thread_block_gemm(const T *A, const T *B, float *C, // local_b_consume = reinterpret_cast( // (mask_odd & reinterpret_cast(local_b_buf)) | // (mask_even & reinterpret_cast(local_b))); - local_a_consume = local_a + (block_k & 1) * (BM * BK); - local_b_consume = local_b + (block_k & 1) * (BK * BN); + local_a_consume = local_a + (block_k & 1) * + (smem_a_dbuf_offset - smem_a_offset) / + sizeof(T); + local_b_consume = local_b + (block_k & 1) * + (smem_b_dbuf_offset - smem_b_offset) / + sizeof(T); } else { // no double-buffering without DMA local_a_consume = local_a; @@ -1071,11 +1074,14 @@ inline void thread_block_gemm(const T *A, const T *B, float *C, asm volatile("dbuf_sel_end_%=:" ::); constexpr MemLayout layout_a = - GEMMINI_DMA ? MemLayout::block_row_major + GEMMINI_DMA ? (GEMMINI_DMA_FAST ? MemLayout::MN_major + : MemLayout::block_row_major) : (TRANSPOSE_AT_CONSUME ? MemLayout::K_major : MemLayout::MN_major); constexpr MemLayout layout_b = - GEMMINI_DMA ? MemLayout::block_row_major : MemLayout::MN_major; + GEMMINI_DMA ? (GEMMINI_DMA_FAST ? 
MemLayout::MN_major + : MemLayout::block_row_major) + : MemLayout::MN_major; thread_block_gemm_single_tile Date: Sun, 8 Sep 2024 15:29:15 -0700 Subject: [PATCH 17/50] sgemm_impl: Parameterize BM on NUM_CORES --- tests/regression/sgemm_tcore/sgemm_impl.hpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/regression/sgemm_tcore/sgemm_impl.hpp b/tests/regression/sgemm_tcore/sgemm_impl.hpp index 7ba19992..d24c61d6 100644 --- a/tests/regression/sgemm_tcore/sgemm_impl.hpp +++ b/tests/regression/sgemm_tcore/sgemm_impl.hpp @@ -29,7 +29,7 @@ using float_type = float16_t; // (BM*BN) / (TM*TN) == threadblock size >= NT * CORES_PER_CLUSTER // * Combining BM * BK >= (BM*BN) / (TM*TN) == threadblock yields // BM <= BK*TM*TN -#define BM 128 +#define BM ((NUM_CORES == 8) ? 128 : 64) #define BN 64 #if (FP_SIZE == 32) #define BK 64 From 3f50ac57ee5d06778569d6dd093e749ef4e1971e Mon Sep 17 00:00:00 2001 From: Hansung Kim Date: Sun, 8 Sep 2024 15:29:46 -0700 Subject: [PATCH 18/50] flash: use 12bit dma interface --- tests/regression/flash_attention/flash_impl.hpp | 2 -- tests/regression/flash_attention/kernel.gemmini.cpp | 8 +++++--- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/tests/regression/flash_attention/flash_impl.hpp b/tests/regression/flash_attention/flash_impl.hpp index 48a0068f..fd027553 100644 --- a/tests/regression/flash_attention/flash_impl.hpp +++ b/tests/regression/flash_attention/flash_impl.hpp @@ -9,12 +9,10 @@ #define HEADDIM 64 constexpr uint32_t ROWMAX_SETS = 3; -constexpr bool DEBUG = true; constexpr bool WARP_SPECIALIZED = false; constexpr uint32_t DEV_FAKE_SMEM_START_ADDR = 0xf0000000; -constexpr bool GEMMINI_DMA_FAST = false; constexpr bool Q_IS_K_MAJOR = true; // temporary safety stop for wrong configs diff --git a/tests/regression/flash_attention/kernel.gemmini.cpp b/tests/regression/flash_attention/kernel.gemmini.cpp index 4572c921..4a2c3133 100644 --- a/tests/regression/flash_attention/kernel.gemmini.cpp +++ 
b/tests/regression/flash_attention/kernel.gemmini.cpp @@ -8,6 +8,8 @@ #include "gemmini_mmio.h" #include "flash_impl.hpp" +constexpr bool DEBUG = true; + static_assert(GEMMINI_DMA && !WARP_SPECIALIZED, "GEMMINI_DMA should be set and WARP_SPECIALIZED unset"); @@ -227,7 +229,7 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { ROCC_INSTRUCTION_RS1_RS2(XCUSTOM_ACC, (uint64_t)(gmem_Q_tile), (uint64_t)(gmem_K_tile), k_LOOP_WS_CONFIG_ADDRS_AB) // configure address strides for the DMA - GEMMINI_CISC_CMD_R((dim_seqlen << 16) | (HEADDIM << 8) | + GEMMINI_CISC_CMD_R((dim_seqlen << 20) | (HEADDIM << 8) | 8 /*k_LOOP_WS_CONFIG_STRIDES_AB*/); gemmini_fence(); @@ -313,7 +315,6 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { GEMMINI_CISC_CMD_I(1); } - // mvout to SMEM gemmini_fence(); gemmini_fence(); gemmini_fence(); @@ -332,6 +333,7 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { } #endif + // mvout to SMEM // GEMMINI_CISC_CMD_I(9); sp_tiled_matmul_full_spad_ws( /*spad_A=*/spad_addr_Q, /*spad_B=*/spad_addr_K, @@ -427,7 +429,7 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { k_LOOP_WS_CONFIG_ADDRS_AB) // configure address strides for the DMA // FIXME: unnecessary? 
- GEMMINI_CISC_CMD_R((HEADDIM /*V*/ << 16) | (dim_seqlen /*KT*/ << 8) | + GEMMINI_CISC_CMD_R((HEADDIM /*V*/ << 20) | (dim_seqlen /*KT*/ << 8) | 8 /*k_LOOP_WS_CONFIG_STRIDES_AB*/); gemmini_fence(); From cdb8377b62af9b242ad9a55e371430c9cd283f04 Mon Sep 17 00:00:00 2001 From: Hansung Kim Date: Sun, 8 Sep 2024 16:09:06 -0700 Subject: [PATCH 19/50] flash: Do GEMM II in Gemmini; verify 1st iteration --- .../regression/flash_attention/flash_impl.hpp | 31 ++- .../flash_attention/kernel.gemmini.cpp | 252 +++++++----------- 2 files changed, 113 insertions(+), 170 deletions(-) diff --git a/tests/regression/flash_attention/flash_impl.hpp b/tests/regression/flash_attention/flash_impl.hpp index fd027553..46a62546 100644 --- a/tests/regression/flash_attention/flash_impl.hpp +++ b/tests/regression/flash_attention/flash_impl.hpp @@ -95,7 +95,7 @@ inline void thread_block_copy_rowmax(const float *src, float *dest, asm volatile("threadblock_copy_rowmax_finish_%=:" ::); } -template +template inline void thread_block_copy_tile(const float *src, float *dest, const uint32_t tid_in_threadblock, const uint32_t threads_per_threadblock, @@ -113,14 +113,18 @@ inline void thread_block_copy_tile(const float *src, float *dest, for (int row_offset = 0; row_offset < dim_row; row_offset += warps_in_threadblock) { const uint32_t row = row_offset + warp_id; - const uint32_t first_thread_offset = dim_col * row; constexpr uint32_t per_row_iter = dim_col / NUM_THREADS; - uint32_t thread_offset = first_thread_offset + tid_in_warp; #pragma GCC unroll for (int i = 0; i < per_row_iter; i++) { - dest[thread_offset] = src[thread_offset]; - thread_offset += NUM_THREADS; + const uint32_t col_offset = NUM_THREADS * i; + const uint32_t col = col_offset + tid_in_warp; + const auto [smem_row, smem_col] = + remap_to_gemmini_dma_layout(row, col); + const uint32_t smem_offset = B_COL * smem_row + smem_col; + const uint32_t gmem_offset = B_COL * row + col; + + dest[gmem_offset] = src[smem_offset]; } 
threadblock_barrier(threadblock_id_in_cluster, @@ -415,6 +419,7 @@ __attribute__((always_inline)) inline void thread_block_online_softmax( asm volatile("thread_block_online_softmax_finish_%=:" ::); } +template __attribute__((always_inline)) inline void thread_block_O_rescale( const float *smem_O_in, float *smem_O_out, const float *smem_O_row_scale, const uint32_t tid_in_threadblock, const uint32_t threads_per_threadblock, @@ -431,19 +436,21 @@ __attribute__((always_inline)) inline void thread_block_O_rescale( for (int row_offset = 0; row_offset < B_ROW; row_offset += warps_in_threadblock) { const uint32_t row = row_offset + warp_id; - const uint32_t first_thread_offset = B_COL * row; - constexpr uint32_t per_row_iter = B_COL / NUM_THREADS; - uint32_t thread_offset = first_thread_offset + tid_in_warp; + constexpr uint32_t per_row_iter = HEADDIM / NUM_THREADS; // Oi rescale // #pragma GCC unroll for (int i = 0; i < per_row_iter; i++) { - const float o = smem_O_in[thread_offset]; - const float scale = smem_O_row_scale[row]; - smem_O_out[thread_offset] = (o * scale); + const uint32_t col_offset = NUM_THREADS * i; + const uint32_t col = col_offset + tid_in_warp; + const auto [smem_row, smem_col] = + remap_to_gemmini_dma_layout(row, col); - thread_offset += NUM_THREADS; + const uint32_t offset = HEADDIM * smem_row + smem_col; + const float o = smem_O_in[offset]; + const float scale = smem_O_row_scale[row]; + smem_O_out[offset] = (o * scale); } } diff --git a/tests/regression/flash_attention/kernel.gemmini.cpp b/tests/regression/flash_attention/kernel.gemmini.cpp index 4a2c3133..d5c553d8 100644 --- a/tests/regression/flash_attention/kernel.gemmini.cpp +++ b/tests/regression/flash_attention/kernel.gemmini.cpp @@ -110,8 +110,10 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { smem_cursor += smem_QK_size; float *smem_S1 = smem_cursor; smem_cursor += smem_QK_size; - float *smem_P0 = smem_S0; // in-place update - float *smem_P1 = smem_S1; // in-place update + 
float *smem_P0 = smem_cursor; + smem_cursor += smem_QK_size; + float *smem_P1 = smem_cursor; + smem_cursor += smem_QK_size; float *smem_O0 = smem_cursor; smem_cursor += smem_O_size; float *smem_O1 = smem_cursor; @@ -135,6 +137,14 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { spad_addr_V1 + (smem_V_size * sizeof(float) / spad_addr_factor); constexpr uint32_t spad_addr_S1 = spad_addr_S0 + (smem_QK_size * sizeof(float) / spad_addr_factor); + constexpr uint32_t spad_addr_P0 = + spad_addr_S1 + (smem_QK_size * sizeof(float) / spad_addr_factor); + constexpr uint32_t spad_addr_P1 = + spad_addr_P0 + (smem_QK_size * sizeof(float) / spad_addr_factor); + constexpr uint32_t spad_addr_O0 = + spad_addr_P1 + (smem_QK_size * sizeof(float) / spad_addr_factor); + constexpr uint32_t spad_addr_O1 = + spad_addr_O0 + (smem_O_size * sizeof(float) / spad_addr_factor); // allocate rowmax/rowsum storage at the end of the sharedmem address space constexpr uint32_t smem_rowmax_size = B_ROW * ROWMAX_SETS; @@ -189,6 +199,9 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { constexpr uint32_t skips_mvout_spad = loop_matmul_skips(/*skip_lda=*/1, /*skip_ldb=*/1, /*skip_ldd=*/1, /*skip_ex=*/1, /*skip_stc=*/0); + constexpr uint32_t skips_matmul = + loop_matmul_skips(/*skip_lda=*/1, /*skip_ldb=*/1, /*skip_ldd=*/1, + /*skip_ex=*/0, /*skip_stc=*/1); if constexpr (GEMMINI_DMA) { if (tid_in_warpgroup == 0) { @@ -281,12 +294,13 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { const uint32_t k_tiles = (dim_seqlen / B_COL); for (uint32_t tile_k = 0; tile_k < k_tiles; tile_k++) { // select the correct double buffer by tile iteration - float *smem_Q = (tile_k & 1) ? smem_Q1 : smem_Q0; // FIXME - float *smem_K = (tile_k & 1) ? smem_K1 : smem_K0; // FIXME - float *smem_V = (tile_k & 1) ? smem_V1 : smem_V0; // FIXME - float *smem_S = (tile_k & 1) ? smem_S1 : smem_S0; // FIXME - float *smem_O = (tile_k & 1) ? 
smem_O1 : smem_O0; // FIXME - float *smem_P = smem_S; // FIXME + // FIXME do correct double buffering + float *smem_Q = (tile_k & 1) ? smem_Q1 : smem_Q0; + float *smem_K = (tile_k & 1) ? smem_K1 : smem_K0; + float *smem_V = (tile_k & 1) ? smem_V1 : smem_V0; + float *smem_S = (tile_k & 1) ? smem_S1 : smem_S0; + float *smem_P = (tile_k & 1) ? smem_P1 : smem_P0; + float *smem_O = (tile_k & 1) ? smem_O1 : smem_O0; float *smem_O_row_scale = (tile_k & 1) ? smem_O_row_scale_1 : smem_O_row_scale_0; float *smem_rowmax = (tile_k & 1) ? smem_rowmax_1 : smem_rowmax_0; @@ -298,6 +312,8 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { const auto spad_addr_K = (tile_k & 1) ? spad_addr_K1 : spad_addr_K0; const auto spad_addr_V = (tile_k & 1) ? spad_addr_V1 : spad_addr_V0; const auto spad_addr_S = (tile_k & 1) ? spad_addr_S1 : spad_addr_S0; + const auto spad_addr_P = (tile_k & 1) ? spad_addr_P1 : spad_addr_P0; + const auto spad_addr_O = spad_addr_O0; // NOTE: there's only single O tile // GEMM I: S = Q*K // @@ -320,7 +336,7 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { gemmini_fence(); gemmini_fence(); -#if 0 // TODO +#if 0 // TODO: speed up mvout to SMEM // loop_ws variant that skips configuring strides #define gemmini_loop_ws(I, J, K, pad_I, pad_J, pad_K, A, B, D, C, A_stride, B_stride, D_stride, C_stride, A_transpose, B_transpose, full_C, low_D, ex_accumulate, act, a_spad_id, b_spad_id, is_resadd) \ { \ @@ -350,7 +366,7 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { } } - // thread reconvergence + // reconverge from mmio divergence threadblock_barrier(warpgroup_id_in_cluster, warps_per_warpgroup_per_core); asm volatile("gemm_qk_finish_%=:" ::); @@ -358,13 +374,13 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { if constexpr (DEBUG) { if (warpgroup_id == 0) { if (tile_k == 0) { - thread_block_copy_tile(smem_S, gmem_tmp_d0, - tid_in_warpgroup, threads_per_warpgroup, - warpgroup_id_in_cluster); + 
thread_block_copy_tile( + smem_S, gmem_tmp_d0, tid_in_warpgroup, threads_per_warpgroup, + warpgroup_id_in_cluster); } else if (tile_k == 1) { - thread_block_copy_tile(smem_S, gmem_tmp_d1, - tid_in_warpgroup, threads_per_warpgroup, - warpgroup_id_in_cluster); + thread_block_copy_tile( + smem_S, gmem_tmp_d1, tid_in_warpgroup, threads_per_warpgroup, + warpgroup_id_in_cluster); } threadblock_barrier(warpgroup_id_in_cluster, @@ -408,7 +424,6 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { } } -#if 0 // data movement for K and V // // Q stays in SMEM for the entire loop @@ -463,9 +478,9 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { threadblock_barrier(global_barrier_id, warps_per_threadblock_per_core); // Oi rescale - thread_block_O_rescale(smem_O, smem_O /*in-place*/, - smem_O_row_scale, tid_in_warpgroup, - threads_per_warpgroup, warpgroup_id_in_cluster); + thread_block_O_rescale( + smem_O, smem_O /*in-place*/, smem_O_row_scale, tid_in_warpgroup, + threads_per_warpgroup, warpgroup_id_in_cluster); // rescale-to-PV-GEMM barrier threadblock_barrier(warpgroup_id_in_cluster, warps_per_warpgroup_per_core); @@ -474,19 +489,19 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { if (warpgroup_id == 0) { // O before PV if (tile_k == 0) { - thread_block_copy_tile(smem_P, gmem_tmp_d2, tid_in_warpgroup, - threads_per_warpgroup, - warpgroup_id_in_cluster); - thread_block_copy_tile(smem_O, gmem_tmp_d4, tid_in_warpgroup, - threads_per_warpgroup, - warpgroup_id_in_cluster); + thread_block_copy_tile( + smem_P, gmem_tmp_d2, tid_in_warpgroup, threads_per_warpgroup, + warpgroup_id_in_cluster); + thread_block_copy_tile( + smem_O, gmem_tmp_d4, tid_in_warpgroup, threads_per_warpgroup, + warpgroup_id_in_cluster); } else if (tile_k == 1) { - thread_block_copy_tile(smem_P, gmem_tmp_d3, tid_in_warpgroup, - threads_per_warpgroup, - warpgroup_id_in_cluster); - thread_block_copy_tile(smem_O, gmem_tmp_d5, tid_in_warpgroup, - threads_per_warpgroup, - 
warpgroup_id_in_cluster); + thread_block_copy_tile( + smem_P, gmem_tmp_d3, tid_in_warpgroup, threads_per_warpgroup, + warpgroup_id_in_cluster); + thread_block_copy_tile( + smem_O, gmem_tmp_d5, tid_in_warpgroup, threads_per_warpgroup, + warpgroup_id_in_cluster); } threadblock_barrier(warpgroup_id_in_cluster, @@ -498,134 +513,54 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { asm volatile("gemm_pv_start_%=:" ::); - if constexpr (!WARP_SPECIALIZED) { - // clear out accumulators before GEMM - initialize_accum_regs<0>(); - initialize_accum_regs<1>(); - - if constexpr (GEMMINI_DMA) { - thread_block_gemm_single_tile< - float, MemLayout::K_major /* P matrix is row-major */, - MemLayout::block_row_major, B_ROW, HEADDIM, B_COL, - /*leading_dim_a=*/0, /*leading_dim_b=*/0, - /*load_accum=*/true, - /*write_to_smem=*/true>( - smem_P, smem_V, smem_O /*load accum*/, smem_O, tid_in_warpgroup, - threads_per_warpgroup, warpgroups_per_cluster, - warpgroup_id_in_cluster); + if (tid_in_warpgroup == 0) { +#if 0 + if (tile_k == 0) { + gemmini_fence(); + GEMMINI_CISC_CMD_I(0); + } else if (tile_k & 1) { + gemmini_fence(); + GEMMINI_CISC_CMD_I(2); } else { - thread_block_gemm_single_tile( - smem_P, smem_V, smem_O /*load accum*/, smem_O, tid_in_warpgroup, - threads_per_warpgroup, warpgroups_per_cluster, - warpgroup_id_in_cluster); - // FIXME: wrong but fast - // thread_block_gemm_single_tile( - // smem_P, smem_V, smem_O /*load accum*/, smem_O, - // tid_in_warpgroup, threads_per_warpgroup, - // warpgroups_per_cluster, warpgroup_id_in_cluster); + gemmini_fence(); + GEMMINI_CISC_CMD_I(1); } - } else { - // when warp-specialized, there's only enough warps to do 64x32 tile - // size so we need to do 2 GEMM calls - static_assert(B_ROW / 2 == 32, - "tile size assumption for warp-specialization not met"); +#else + // do matmul + // among other things, this also configures CONFIG_BOUNDS so that the + // DMA knows the full matrix dimensions + sp_tiled_matmul_full_spad_ws( + 
spad_addr_P, spad_addr_V, + /*spad_D=*/0, /*spad_C=*/spad_addr_O, + /*I=*/(B_ROW / DIM), /*J=*/(HEADDIM / DIM), /*K=*/(B_COL / DIM), + /*pad_I=*/0, /*pad_J=*/0, /*pad_K=*/0, + /*a_transpose=*/0, /*b_transpose=*/0, /*full_C=*/0, /*low_D=*/0, + /*acc=*/0, /*act=*/NO_ACTIVATION, /*skips=*/skips_matmul); +#endif - // assumes smem_P is K-major - float *smem_P_half0 = smem_P; - float *smem_P_half1 = smem_P + (B_ROW / 2) * B_COL; - float *smem_O_half0 = smem_O; - float *smem_O_half1 = smem_O + (B_ROW / 2) * HEADDIM; + gemmini_fence(); + gemmini_fence(); + gemmini_fence(); + gemmini_fence(); - // clear out accumulators before GEMM - initialize_accum_regs<0>(); - initialize_accum_regs<1>(); + // mvout to SMEM + // GEMMINI_CISC_CMD_I(9); + sp_tiled_matmul_full_spad_ws( + /*spad_A=*/spad_addr_P, /*spad_B=*/spad_addr_V, + /*spad_D=*/0, /*spad_C=*/spad_addr_O, + /*I=*/(B_ROW / DIM), /*J=*/(HEADDIM / DIM), /*K=*/(B_COL/ DIM), + /*pad_I=*/0, /*pad_J=*/0, /*pad_K=*/0, + /*a_transpose=*/0, /*b_transpose=*/0, /*full_C=*/0, /*low_D=*/0, + /*acc=*/0, /*act=*/NO_ACTIVATION, /*skips=*/skips_mvout_spad); + gemmini_fence(); - // split by rows into 2 chunks - if constexpr (GEMMINI_DMA) { - if constexpr (GEMMINI_DMA_FAST) { - thread_block_gemm_single_tile( - smem_P_half0, smem_V, smem_O_half0 /*load accum*/, smem_O_half0, - tid_in_warpgroup, threads_per_warpgroup, warpgroups_per_cluster, - warpgroup_id_in_cluster); - } else { - thread_block_gemm_single_tile< - float, MemLayout::K_major /* P matrix is row-major */, - MemLayout::block_row_major, B_ROW / 2, HEADDIM, B_COL, - /*leading_dim_a=*/0, - /*leading_dim_b=*/0, - /*load_accum=*/true, - /*write_to_smem=*/true>( - smem_P_half0, smem_V, smem_O_half0 /*load accum*/, smem_O_half0, - tid_in_warpgroup, threads_per_warpgroup, warpgroups_per_cluster, - warpgroup_id_in_cluster); - } - } else { - thread_block_gemm_single_tile< - float, MemLayout::K_major, MemLayout::MN_major, B_ROW / 2, HEADDIM, - B_COL, /*leading_dim_a=*/0, /*leading_dim_b=*/0, - 
/*load_accum=*/true, - /*write_to_smem=*/true>( - smem_P_half0, smem_V, smem_O_half0 /*load accum*/, smem_O_half0, - tid_in_warpgroup, threads_per_warpgroup, warpgroups_per_cluster, - warpgroup_id_in_cluster); - } - - initialize_accum_regs<0>(); - initialize_accum_regs<1>(); - - if constexpr (GEMMINI_DMA) { - if constexpr (GEMMINI_DMA_FAST) { - thread_block_gemm_single_tile( - smem_P_half1, smem_V, smem_O_half1 /*load accum*/, smem_O_half1, - tid_in_warpgroup, threads_per_warpgroup, warpgroups_per_cluster, - warpgroup_id_in_cluster); - } else { - thread_block_gemm_single_tile< - float, MemLayout::K_major /* P matrix is row-major */, - MemLayout::block_row_major, B_ROW / 2, HEADDIM, B_COL, - /*leading_dim_a=*/0, - /*leading_dim_b=*/0, - /*load_accum=*/true, - /*write_to_smem=*/true>( - smem_P_half1, smem_V, smem_O_half1 /*load accum*/, smem_O_half1, - tid_in_warpgroup, threads_per_warpgroup, warpgroups_per_cluster, - warpgroup_id_in_cluster); - } - } else { - thread_block_gemm_single_tile< - float, MemLayout::K_major, MemLayout::MN_major, B_ROW / 2, HEADDIM, - B_COL, /*leading_dim_a=*/0, /*leading_dim_b=*/0, - /*load_accum=*/true, - /*write_to_smem=*/true>( - smem_P_half1, smem_V, smem_O_half1 /*load accum*/, smem_O_half1, - tid_in_warpgroup, threads_per_warpgroup, warpgroups_per_cluster, - warpgroup_id_in_cluster); + if constexpr (DEBUG) { + // for copy-out to GMEM + gemmini_fence(); } } + // reconverge from mmio divergence threadblock_barrier(warpgroup_id_in_cluster, warps_per_warpgroup_per_core); asm volatile("gemm_pv_finish_%=:" ::); @@ -634,19 +569,20 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { if (warpgroup_id == 0) { // O after PV if (tile_k == 0) { - thread_block_copy_tile(smem_O, gmem_tmp_d6, tid_in_warpgroup, - threads_per_warpgroup, - warpgroup_id_in_cluster); + thread_block_copy_tile( + smem_O, gmem_tmp_d6, tid_in_warpgroup, threads_per_warpgroup, + warpgroup_id_in_cluster); } else if (tile_k == 1) { - thread_block_copy_tile(smem_O, 
gmem_tmp_d7, tid_in_warpgroup, - threads_per_warpgroup, - warpgroup_id_in_cluster); + thread_block_copy_tile( + smem_O, gmem_tmp_d7, tid_in_warpgroup, threads_per_warpgroup, + warpgroup_id_in_cluster); } threadblock_barrier(warpgroup_id_in_cluster, warps_per_warpgroup_per_core); } } +#if 0 #endif } From 8efa6868ea369c5300add8109b52e6389c213523 Mon Sep 17 00:00:00 2001 From: Hansung Kim Date: Sun, 8 Sep 2024 18:45:32 -0700 Subject: [PATCH 20/50] flash: Restructure for full software pipelining Verified up to P and O before PV; need to fix iteration for V load. --- .../flash_attention/kernel.gemmini.cpp | 443 ++++++++++-------- tests/regression/sgemm_tcore/sgemm_impl.hpp | 2 +- 2 files changed, 239 insertions(+), 206 deletions(-) diff --git a/tests/regression/flash_attention/kernel.gemmini.cpp b/tests/regression/flash_attention/kernel.gemmini.cpp index d5c553d8..9e36bf83 100644 --- a/tests/regression/flash_attention/kernel.gemmini.cpp +++ b/tests/regression/flash_attention/kernel.gemmini.cpp @@ -184,10 +184,10 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { constexpr uint32_t global_barrier_id = NUM_WARPS - 1; // arbitrary - // delay warpgroup 0 by 1 iteration to do ping-pong scheduling - if (WARP_SPECIALIZED && warpgroup_id == 1) { - threadblock_barrier(global_barrier_id, warps_per_threadblock_per_core); - } + // // delay warpgroup 0 by 1 iteration to do ping-pong scheduling + // if (WARP_SPECIALIZED && warpgroup_id == 1) { + // threadblock_barrier(global_barrier_id, warps_per_threadblock_per_core); + // } static_assert(!GEMMINI_DMA || Q_IS_K_MAJOR, "DMA code assumes Q matrix is stored K-major"); @@ -196,6 +196,9 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { constexpr uint32_t skips = loop_matmul_skips(/*skip_lda=*/0, /*skip_ldb=*/0, /*skip_ldd=*/1, /*skip_ex=*/1, /*skip_stc=*/1); + constexpr uint32_t skips_only_a = + loop_matmul_skips(/*skip_lda=*/0, /*skip_ldb=*/1, /*skip_ldd=*/1, + /*skip_ex=*/1, /*skip_stc=*/1); constexpr 
uint32_t skips_mvout_spad = loop_matmul_skips(/*skip_lda=*/1, /*skip_ldb=*/1, /*skip_ldd=*/1, /*skip_ex=*/1, /*skip_stc=*/0); @@ -248,6 +251,8 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { #define GEMMINI_DMA_CISC #ifdef GEMMINI_DMA_CISC + // the target addresses of this should match with spad_addr_Q0 and + // spad_addr_K0 set in this kernel GEMMINI_CISC_CMD_I(10); gemmini_fence(); #else @@ -292,15 +297,30 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { // "inner loop" along the columns of K^T const uint32_t k_tiles = (dim_seqlen / B_COL); - for (uint32_t tile_k = 0; tile_k < k_tiles; tile_k++) { + for (uint32_t tile_k = 0; tile_k < k_tiles + 2 /*pipeline latency*/; + tile_k++) { + if constexpr (DEBUG) { + // barrier for debugging + threadblock_barrier(global_barrier_id, warps_per_threadblock_per_core); + } + // select the correct double buffer by tile iteration - // FIXME do correct double buffering - float *smem_Q = (tile_k & 1) ? smem_Q1 : smem_Q0; - float *smem_K = (tile_k & 1) ? smem_K1 : smem_K0; - float *smem_V = (tile_k & 1) ? smem_V1 : smem_V0; - float *smem_S = (tile_k & 1) ? smem_S1 : smem_S0; - float *smem_P = (tile_k & 1) ? smem_P1 : smem_P0; - float *smem_O = (tile_k & 1) ? smem_O1 : smem_O0; + // all iterations work on the same Q row tile; no ping-pong necessary + asm volatile ("dbuf_sel_start_%=:" :: ); + // FIXME speedup by doing arithmetic + float *smem_Q = smem_Q0; + float *smem_K_consume = (tile_k & 1) ? smem_K1 : smem_K0; + float *smem_K_produce = (tile_k & 1) ? smem_K0 : smem_K1; + float *smem_V_consume = (tile_k & 1) ? smem_V1 : smem_V0; + float *smem_V_produce = (tile_k & 1) ? smem_V0 : smem_V1; + float *smem_S_consume = (tile_k & 1) ? smem_S1 : smem_S0; + float *smem_S_produce = (tile_k & 1) ? smem_S0 : smem_S1; + float *smem_P_consume = (tile_k & 1) ? smem_P1 : smem_P0; + float *smem_P_produce = (tile_k & 1) ? 
smem_P0 : smem_P1; + // O tile is sequentially updated at every iteration; no ping-pong + // necessary + float *smem_O = smem_O0; + // FIXME: O_row_scale/rowmax/rowsum/spad shouldn't really need ping-pong float *smem_O_row_scale = (tile_k & 1) ? smem_O_row_scale_1 : smem_O_row_scale_0; float *smem_rowmax = (tile_k & 1) ? smem_rowmax_1 : smem_rowmax_0; @@ -308,28 +328,111 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { float *smem_scratchpad = (tile_k & 1) ? smem_scratchpad_1 : smem_scratchpad_0; - const auto spad_addr_Q = (tile_k & 1) ? spad_addr_Q1 : spad_addr_Q0; - const auto spad_addr_K = (tile_k & 1) ? spad_addr_K1 : spad_addr_K0; - const auto spad_addr_V = (tile_k & 1) ? spad_addr_V1 : spad_addr_V0; - const auto spad_addr_S = (tile_k & 1) ? spad_addr_S1 : spad_addr_S0; - const auto spad_addr_P = (tile_k & 1) ? spad_addr_P1 : spad_addr_P0; + const auto spad_addr_Q = spad_addr_Q0; + const auto spad_addr_K_consume = (tile_k & 1) ? spad_addr_K1 : spad_addr_K0; + const auto spad_addr_K_produce = (tile_k & 1) ? spad_addr_K0 : spad_addr_K1; + const auto spad_addr_V_consume = (tile_k & 1) ? spad_addr_V1 : spad_addr_V0; + const auto spad_addr_V_produce = (tile_k & 1) ? spad_addr_V0 : spad_addr_V1; + const auto spad_addr_S_consume = (tile_k & 1) ? spad_addr_S1 : spad_addr_S0; + const auto spad_addr_S_produce = (tile_k & 1) ? spad_addr_S0 : spad_addr_S1; + const auto spad_addr_P_consume = (tile_k & 1) ? spad_addr_P1 : spad_addr_P0; + const auto spad_addr_P_produce = (tile_k & 1) ? 
spad_addr_P0 : spad_addr_P1; const auto spad_addr_O = spad_addr_O0; // NOTE: there's only single O tile + asm volatile ("dbuf_sel_end_%=:" :: ); - // GEMM I: S = Q*K - // - asm volatile("gemm_qk_start_%=:" ::); + // GEMM II: O = O + P*V + // -------------------- + // This is done *before* GEMM I in the software pipeline, working on the + // online softmax result tile from the previous iteration - if (tid_in_warpgroup == 0) { - if (tile_k == 0) { + if (tile_k >= 2) // delay by 2 iters for pipelining + { + const uint32_t tile_k_ = tile_k - 2; + + asm volatile("gemm_pv_start_%=:" ::); + + if (tid_in_warpgroup == 0) { +#if 0 + if (tile_k_ == 0) { gemmini_fence(); GEMMINI_CISC_CMD_I(0); - } else if (tile_k & 1) { + } else if (tile_k_ & 1) { gemmini_fence(); GEMMINI_CISC_CMD_I(2); } else { gemmini_fence(); GEMMINI_CISC_CMD_I(1); } +#else + // do matmul + // among other things, this also configures CONFIG_BOUNDS so that the + // DMA knows the full matrix dimensions + sp_tiled_matmul_full_spad_ws( + spad_addr_P_consume, spad_addr_V_consume, + /*spad_D=*/0, /*spad_C=*/spad_addr_O, + /*I=*/(B_ROW / DIM), /*J=*/(HEADDIM / DIM), /*K=*/(B_COL / DIM), + /*pad_I=*/0, /*pad_J=*/0, /*pad_K=*/0, + /*a_transpose=*/0, /*b_transpose=*/0, /*full_C=*/0, /*low_D=*/0, + /*acc=*/0, /*act=*/NO_ACTIVATION, /*skips=*/skips_matmul); +#endif + + gemmini_fence(); + gemmini_fence(); + gemmini_fence(); + gemmini_fence(); + + // mvout to SMEM + // GEMMINI_CISC_CMD_I(9); + sp_tiled_matmul_full_spad_ws( + /*spad_A=*/spad_addr_P_consume, /*spad_B=*/spad_addr_V_consume, + /*spad_D=*/0, /*spad_C=*/spad_addr_O, + /*I=*/(B_ROW / DIM), /*J=*/(HEADDIM / DIM), /*K=*/(B_COL / DIM), + /*pad_I=*/0, /*pad_J=*/0, /*pad_K=*/0, + /*a_transpose=*/0, /*b_transpose=*/0, /*full_C=*/0, /*low_D=*/0, + /*acc=*/0, /*act=*/NO_ACTIVATION, /*skips=*/skips_mvout_spad); + gemmini_fence(); + + if constexpr (DEBUG) { + // for copy-out to GMEM + gemmini_fence(); + } + } + + // reconverge from mmio divergence + 
threadblock_barrier(warpgroup_id_in_cluster, + warps_per_warpgroup_per_core); + + asm volatile("gemm_pv_finish_%=:" ::); + + if constexpr (DEBUG) { + if (warpgroup_id == 0) { + // O after PV + if (tile_k_ == 0) { + thread_block_copy_tile( + smem_O, gmem_tmp_d6, tid_in_warpgroup, threads_per_warpgroup, + warpgroup_id_in_cluster); + } else if (tile_k_ == 1) { + thread_block_copy_tile( + smem_O, gmem_tmp_d7, tid_in_warpgroup, threads_per_warpgroup, + warpgroup_id_in_cluster); + } + + threadblock_barrier(warpgroup_id_in_cluster, + warps_per_warpgroup_per_core); + } + } + } + + // GEMM I: S = Q*K + // + asm volatile("gemm_qk_start_%=:" ::); + + if (tid_in_warpgroup == 0) { + // 0,2,.: opcode 0 (quartile 0/2, no accum) + // 1,3,.: opcode 3 (quartile 1/3, no accum) + const uint32_t opcode = 3 * (tile_k & 1); + gemmini_fence(); + GEMMINI_CISC_CMD_I(opcode); gemmini_fence(); gemmini_fence(); @@ -352,8 +455,8 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { // mvout to SMEM // GEMMINI_CISC_CMD_I(9); sp_tiled_matmul_full_spad_ws( - /*spad_A=*/spad_addr_Q, /*spad_B=*/spad_addr_K, - /*spad_D=*/0, /*spad_C=*/spad_addr_S, + /*spad_A=*/spad_addr_Q, /*spad_B=*/spad_addr_K_consume, + /*spad_D=*/0, /*spad_C=*/spad_addr_S_produce, /*I=*/(B_ROW / DIM), /*J=*/(B_COL / DIM), /*K=*/(HEADDIM / DIM), /*pad_I=*/0, /*pad_J=*/0, /*pad_K=*/0, /*a_transpose=*/0, /*b_transpose=*/0, /*full_C=*/0, /*low_D=*/0, @@ -375,11 +478,11 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { if (warpgroup_id == 0) { if (tile_k == 0) { thread_block_copy_tile( - smem_S, gmem_tmp_d0, tid_in_warpgroup, threads_per_warpgroup, + smem_S_produce, gmem_tmp_d0, tid_in_warpgroup, threads_per_warpgroup, warpgroup_id_in_cluster); } else if (tile_k == 1) { thread_block_copy_tile( - smem_S, gmem_tmp_d1, tid_in_warpgroup, threads_per_warpgroup, + smem_S_produce, gmem_tmp_d1, tid_in_warpgroup, threads_per_warpgroup, warpgroup_id_in_cluster); } @@ -388,39 +491,76 @@ void kernel_body(int task_id, 
kernel_arg_t *__UNIFORM__ arg) { } } - // inter-warpgroup barrier before online softmax - threadblock_barrier(global_barrier_id, warps_per_threadblock_per_core); + if (tile_k >= 1) // delay by 1 iters for pipelining + { + const uint32_t tile_k_ = tile_k - 1; - // Online softmax - // - thread_block_online_softmax( - smem_S, smem_P, tid_in_warpgroup, threads_per_warpgroup, - warpgroup_id_in_cluster, smem_scratchpad, smem_rowmax, smem_rowsum, - smem_O_row_scale); + // Online softmax + // + thread_block_online_softmax( + smem_S_consume, smem_P_produce, tid_in_warpgroup, + threads_per_warpgroup, warpgroup_id_in_cluster, smem_scratchpad, + smem_rowmax, smem_rowsum, smem_O_row_scale); - // FIXME: unnecessary? - threadblock_barrier(warpgroup_id_in_cluster, warps_per_warpgroup_per_core); + threadblock_barrier(warpgroup_id_in_cluster, + warps_per_warpgroup_per_core); - if constexpr (DEBUG) { - if (warpgroup_id == 0) { - if (tile_k == 0) { - thread_block_copy_rowmax(smem_rowmax, gmem_tmp_e0, tid_in_warpgroup, - threads_per_warpgroup, - warpgroup_id_in_cluster); - thread_block_copy_rowmax(smem_rowsum, gmem_tmp_e2, tid_in_warpgroup, - threads_per_warpgroup, - warpgroup_id_in_cluster); - } else if (tile_k == 1) { - thread_block_copy_rowmax(smem_rowmax, gmem_tmp_e1, tid_in_warpgroup, - threads_per_warpgroup, - warpgroup_id_in_cluster); - thread_block_copy_rowmax(smem_rowsum, gmem_tmp_e3, tid_in_warpgroup, - threads_per_warpgroup, - warpgroup_id_in_cluster); + if constexpr (DEBUG) { + if (warpgroup_id == 0) { + if (tile_k_ == 0) { + thread_block_copy_rowmax(smem_rowmax, gmem_tmp_e0, tid_in_warpgroup, + threads_per_warpgroup, + warpgroup_id_in_cluster); + thread_block_copy_rowmax(smem_rowsum, gmem_tmp_e2, tid_in_warpgroup, + threads_per_warpgroup, + warpgroup_id_in_cluster); + } else if (tile_k_ == 1) { + thread_block_copy_rowmax(smem_rowmax, gmem_tmp_e1, tid_in_warpgroup, + threads_per_warpgroup, + warpgroup_id_in_cluster); + thread_block_copy_rowmax(smem_rowsum, gmem_tmp_e3, 
tid_in_warpgroup, + threads_per_warpgroup, + warpgroup_id_in_cluster); + } + + threadblock_barrier(warpgroup_id_in_cluster, + warps_per_warpgroup_per_core); } + } - threadblock_barrier(warpgroup_id_in_cluster, - warps_per_warpgroup_per_core); + // TODO: put a synchronization here with GEMM-II + + // Oi rescale + thread_block_O_rescale( + smem_O, smem_O /*in-place*/, smem_O_row_scale, tid_in_warpgroup, + threads_per_warpgroup, warpgroup_id_in_cluster); + + // rescale-to-PV-GEMM barrier + threadblock_barrier(warpgroup_id_in_cluster, + warps_per_warpgroup_per_core); + + if constexpr (DEBUG) { + if (warpgroup_id == 0) { + // O before PV + if (tile_k_ == 0) { + thread_block_copy_tile( + smem_P_produce, gmem_tmp_d2, tid_in_warpgroup, + threads_per_warpgroup, warpgroup_id_in_cluster); + thread_block_copy_tile( + smem_O, gmem_tmp_d4, tid_in_warpgroup, threads_per_warpgroup, + warpgroup_id_in_cluster); + } else if (tile_k_ == 1) { + thread_block_copy_tile( + smem_P_produce, gmem_tmp_d3, tid_in_warpgroup, + threads_per_warpgroup, warpgroup_id_in_cluster); + thread_block_copy_tile( + smem_O, gmem_tmp_d5, tid_in_warpgroup, threads_per_warpgroup, + warpgroup_id_in_cluster); + } + + threadblock_barrier(warpgroup_id_in_cluster, + warps_per_warpgroup_per_core); + } } } @@ -428,171 +568,64 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { // // Q stays in SMEM for the entire loop asm volatile("move_k_v_start_%=:" ::); - if constexpr (GEMMINI_DMA) { - // NOTE: Beware of race conditions; with warp specialization, we need to - // make sure below command code to DMA is not executed simultaneously - // from the two warpgroups (which will result in hardware fault). - // Currently the ping-pong scheduling scheme prevents that. 
- if (tid_in_warpgroup == 0) { - // configure GMEM addresses for K and V tiles - // load K for the next iteration - const float *gmem_K_tile = gmem_K + (B_COL * (tile_k + 1)); - // load V for the current iteration - const float *gmem_V_tile = gmem_V + (HEADDIM * B_COL * tile_k); - ROCC_INSTRUCTION_RS1_RS2(XCUSTOM_ACC, (uint64_t)(gmem_K_tile), - (uint64_t)(gmem_V_tile), - k_LOOP_WS_CONFIG_ADDRS_AB) - // configure address strides for the DMA - // FIXME: unnecessary? - GEMMINI_CISC_CMD_R((HEADDIM /*V*/ << 20) | (dim_seqlen /*KT*/ << 8) | - 8 /*k_LOOP_WS_CONFIG_STRIDES_AB*/); - gemmini_fence(); - // do DMA + // NOTE: Beware of race conditions; with warp specialization, we need to + // make sure below command code to DMA is not executed simultaneously + // from the two warpgroups (which will result in hardware fault). + // Currently the ping-pong scheduling scheme prevents that. + if (tid_in_warpgroup == 0) { + // configure GMEM addresses for K and V tiles + // load K for the next iteration + const float *gmem_K_tile = gmem_K + (B_COL * (tile_k + 1 /*runahead*/)); + // load V for the *previous* iteration; this will be consumed 2 + // iterations later + const float *gmem_V_tile = + gmem_V + (HEADDIM * B_COL * (tile_k - 1 /*dragbehind*/)); + ROCC_INSTRUCTION_RS1_RS2(XCUSTOM_ACC, (uint64_t)(gmem_K_tile), + (uint64_t)(gmem_V_tile), + k_LOOP_WS_CONFIG_ADDRS_AB) + // configure address strides for the DMA + // FIXME: unnecessary? 
+ GEMMINI_CISC_CMD_R((HEADDIM /*V*/ << 20) | (dim_seqlen /*KT*/ << 8) | + 8 /*k_LOOP_WS_CONFIG_STRIDES_AB*/); + gemmini_fence(); + + // do DMA + if (tile_k == 0) { + // we load (k-1)th tile for V; skip V for the 1st iteration, sp_tiled_matmul_full_spad_ws( - spad_addr_K, spad_addr_V, - /*spad_D=*/0, /*spad_C=*/spad_addr_S, + spad_addr_K_produce, spad_addr_V_produce, + /*spad_D=*/0, /*spad_C=*/spad_addr_S_produce /*FIXME:bogus*/, + /*I=*/(HEADDIM / DIM), /*J=*/(HEADDIM / DIM), /*K=*/(B_COL / DIM), + /*pad_I=*/0, /*pad_J=*/0, /*pad_K=*/0, + /*a_transpose=*/0, /*b_transpose=*/0, /*full_C=*/0, /*low_D=*/0, + /*acc=*/0, /*act=*/NO_ACTIVATION, /*skips=*/skips_only_a); + } else { + sp_tiled_matmul_full_spad_ws( + spad_addr_K_produce, spad_addr_V_produce, + /*spad_D=*/0, /*spad_C=*/spad_addr_S_produce /*FIXME:bogus*/, /*I=*/(HEADDIM / DIM), /*J=*/(HEADDIM / DIM), /*K=*/(B_COL / DIM), /*pad_I=*/0, /*pad_J=*/0, /*pad_K=*/0, /*a_transpose=*/0, /*b_transpose=*/0, /*full_C=*/0, /*low_D=*/0, /*acc=*/0, /*act=*/NO_ACTIVATION, /*skips=*/skips); - gemmini_fence(); } - } else { - // load K for the next iteration - load_tile_to_smem( - dim_seqlen, tile_k + 1, 0 /* dim_k == headdim */, gmem_K, smem_K, - tid_in_warpgroup); - - // load V for the current iteration - // V dimension is [seqlen, headdim], stored N(headdim)-major - load_tile_to_smem( - HEADDIM, 0 /* full N-dimension */, tile_k, gmem_V, smem_V, - tid_in_warpgroup); + gemmini_fence(); } + + threadblock_barrier(warpgroup_id_in_cluster, warps_per_warpgroup_per_core); + asm volatile("move_k_v_finish_%=:" ::); - - // inter-warpgroup barrier before GEMM II - threadblock_barrier(global_barrier_id, warps_per_threadblock_per_core); - - // Oi rescale - thread_block_O_rescale( - smem_O, smem_O /*in-place*/, smem_O_row_scale, tid_in_warpgroup, - threads_per_warpgroup, warpgroup_id_in_cluster); - - // rescale-to-PV-GEMM barrier - threadblock_barrier(warpgroup_id_in_cluster, warps_per_warpgroup_per_core); - - if constexpr (DEBUG) { - if 
(warpgroup_id == 0) { - // O before PV - if (tile_k == 0) { - thread_block_copy_tile( - smem_P, gmem_tmp_d2, tid_in_warpgroup, threads_per_warpgroup, - warpgroup_id_in_cluster); - thread_block_copy_tile( - smem_O, gmem_tmp_d4, tid_in_warpgroup, threads_per_warpgroup, - warpgroup_id_in_cluster); - } else if (tile_k == 1) { - thread_block_copy_tile( - smem_P, gmem_tmp_d3, tid_in_warpgroup, threads_per_warpgroup, - warpgroup_id_in_cluster); - thread_block_copy_tile( - smem_O, gmem_tmp_d5, tid_in_warpgroup, threads_per_warpgroup, - warpgroup_id_in_cluster); - } - - threadblock_barrier(warpgroup_id_in_cluster, - warps_per_warpgroup_per_core); - } - } - - // GEMM II: O = O + P*V - - asm volatile("gemm_pv_start_%=:" ::); - - if (tid_in_warpgroup == 0) { -#if 0 - if (tile_k == 0) { - gemmini_fence(); - GEMMINI_CISC_CMD_I(0); - } else if (tile_k & 1) { - gemmini_fence(); - GEMMINI_CISC_CMD_I(2); - } else { - gemmini_fence(); - GEMMINI_CISC_CMD_I(1); - } -#else - // do matmul - // among other things, this also configures CONFIG_BOUNDS so that the - // DMA knows the full matrix dimensions - sp_tiled_matmul_full_spad_ws( - spad_addr_P, spad_addr_V, - /*spad_D=*/0, /*spad_C=*/spad_addr_O, - /*I=*/(B_ROW / DIM), /*J=*/(HEADDIM / DIM), /*K=*/(B_COL / DIM), - /*pad_I=*/0, /*pad_J=*/0, /*pad_K=*/0, - /*a_transpose=*/0, /*b_transpose=*/0, /*full_C=*/0, /*low_D=*/0, - /*acc=*/0, /*act=*/NO_ACTIVATION, /*skips=*/skips_matmul); -#endif - - gemmini_fence(); - gemmini_fence(); - gemmini_fence(); - gemmini_fence(); - - // mvout to SMEM - // GEMMINI_CISC_CMD_I(9); - sp_tiled_matmul_full_spad_ws( - /*spad_A=*/spad_addr_P, /*spad_B=*/spad_addr_V, - /*spad_D=*/0, /*spad_C=*/spad_addr_O, - /*I=*/(B_ROW / DIM), /*J=*/(HEADDIM / DIM), /*K=*/(B_COL/ DIM), - /*pad_I=*/0, /*pad_J=*/0, /*pad_K=*/0, - /*a_transpose=*/0, /*b_transpose=*/0, /*full_C=*/0, /*low_D=*/0, - /*acc=*/0, /*act=*/NO_ACTIVATION, /*skips=*/skips_mvout_spad); - gemmini_fence(); - - if constexpr (DEBUG) { - // for copy-out to GMEM 
- gemmini_fence(); - } - } - - // reconverge from mmio divergence - threadblock_barrier(warpgroup_id_in_cluster, warps_per_warpgroup_per_core); - - asm volatile("gemm_pv_finish_%=:" ::); - - if constexpr (DEBUG) { - if (warpgroup_id == 0) { - // O after PV - if (tile_k == 0) { - thread_block_copy_tile( - smem_O, gmem_tmp_d6, tid_in_warpgroup, threads_per_warpgroup, - warpgroup_id_in_cluster); - } else if (tile_k == 1) { - thread_block_copy_tile( - smem_O, gmem_tmp_d7, tid_in_warpgroup, threads_per_warpgroup, - warpgroup_id_in_cluster); - } - - threadblock_barrier(warpgroup_id_in_cluster, - warps_per_warpgroup_per_core); - } - } #if 0 #endif } asm volatile ("tile_loop_finish_%=:" :: ); - // wait for warpgroup 1 to finish, which called the global barrier before - // entering the loop - if (warpgroup_id == 0) { - threadblock_barrier(global_barrier_id, warps_per_threadblock_per_core); - } + // // wait for warpgroup 1 to finish, which called the global barrier before + // // entering the loop + // if (warpgroup_id == 0) { + // threadblock_barrier(global_barrier_id, warps_per_threadblock_per_core); + // } } int main() { diff --git a/tests/regression/sgemm_tcore/sgemm_impl.hpp b/tests/regression/sgemm_tcore/sgemm_impl.hpp index d24c61d6..05692308 100644 --- a/tests/regression/sgemm_tcore/sgemm_impl.hpp +++ b/tests/regression/sgemm_tcore/sgemm_impl.hpp @@ -73,7 +73,7 @@ static_assert(WMITER * WNITER * TCM * TCN * NUM_WARPS * CORES_PER_CLUSTER == #define TRANSPOSE_AT_CONSUME 0 #define GEMMINI_DMA 1 -#define GEMMINI_DMA_FAST 1 +#define GEMMINI_DMA_FAST 0 #define GEMMINI_DMA_FLEXIBLE_LAYOUT 1 #if SMEM_SIZE == 0x4000 #define SMEM_ADDR_Q0 ((float * const) 0xff000000) From 6547e927577c4df7b0bfc44aac2643ed6c5b3f8a Mon Sep 17 00:00:00 2001 From: Hansung Kim Date: Sun, 8 Sep 2024 19:47:55 -0700 Subject: [PATCH 21/50] flash: Load Q to both quartiles; preload O for acc --- .../flash_attention/kernel.gemmini.cpp | 20 ++++++++++++++----- 1 file changed, 15 insertions(+), 5 deletions(-) 
diff --git a/tests/regression/flash_attention/kernel.gemmini.cpp b/tests/regression/flash_attention/kernel.gemmini.cpp index 9e36bf83..e934a1e1 100644 --- a/tests/regression/flash_attention/kernel.gemmini.cpp +++ b/tests/regression/flash_attention/kernel.gemmini.cpp @@ -202,8 +202,8 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { constexpr uint32_t skips_mvout_spad = loop_matmul_skips(/*skip_lda=*/1, /*skip_ldb=*/1, /*skip_ldd=*/1, /*skip_ex=*/1, /*skip_stc=*/0); - constexpr uint32_t skips_matmul = - loop_matmul_skips(/*skip_lda=*/1, /*skip_ldb=*/1, /*skip_ldd=*/1, + constexpr uint32_t skips_matmul_preload = + loop_matmul_skips(/*skip_lda=*/1, /*skip_ldb=*/1, /*skip_ldd=*/0, /*skip_ex=*/0, /*skip_stc=*/1); if constexpr (GEMMINI_DMA) { @@ -255,6 +255,15 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { // spad_addr_K0 set in this kernel GEMMINI_CISC_CMD_I(10); gemmini_fence(); + + // need to also move to Q1 for the next iteration + ROCC_INSTRUCTION_RS1_RS2(XCUSTOM_ACC, (uint64_t)(gmem_Q_tile), + (uint64_t)(gmem_K_tile), k_LOOP_WS_CONFIG_ADDRS_AB) + GEMMINI_CISC_CMD_R((dim_seqlen << 20) | (HEADDIM << 8) | + 8 /*k_LOOP_WS_CONFIG_STRIDES_AB*/); + gemmini_fence(); + GEMMINI_CISC_CMD_I(11); + gemmini_fence(); #else // do DMA // @@ -369,11 +378,11 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { // DMA knows the full matrix dimensions sp_tiled_matmul_full_spad_ws( spad_addr_P_consume, spad_addr_V_consume, - /*spad_D=*/0, /*spad_C=*/spad_addr_O, + /*spad_D=*/spad_addr_O, /*spad_C=*/spad_addr_O, /*I=*/(B_ROW / DIM), /*J=*/(HEADDIM / DIM), /*K=*/(B_COL / DIM), /*pad_I=*/0, /*pad_J=*/0, /*pad_K=*/0, /*a_transpose=*/0, /*b_transpose=*/0, /*full_C=*/0, /*low_D=*/0, - /*acc=*/0, /*act=*/NO_ACTIVATION, /*skips=*/skips_matmul); + /*acc=*/0, /*act=*/NO_ACTIVATION, /*skips=*/skips_matmul_preload); #endif gemmini_fence(); @@ -455,7 +464,8 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { // mvout to SMEM // 
GEMMINI_CISC_CMD_I(9); sp_tiled_matmul_full_spad_ws( - /*spad_A=*/spad_addr_Q, /*spad_B=*/spad_addr_K_consume, + /*spad_A=*/spad_addr_Q /*bogus*/, + /*spad_B=*/spad_addr_K_consume /*bogus*/, /*spad_D=*/0, /*spad_C=*/spad_addr_S_produce, /*I=*/(B_ROW / DIM), /*J=*/(B_COL / DIM), /*K=*/(HEADDIM / DIM), /*pad_I=*/0, /*pad_J=*/0, /*pad_K=*/0, From a4dd45bc1b424e2681993bfefebcaeeb99e693b5 Mon Sep 17 00:00:00 2001 From: Hansung Kim Date: Sun, 8 Sep 2024 20:52:28 -0700 Subject: [PATCH 22/50] flash: Replace CISC with RISC spadQuartile in hw does not match spad addresses in kernel; match them later for optimization. --- .../flash_attention/kernel.gemmini.cpp | 72 +++++++++++++------ 1 file changed, 50 insertions(+), 22 deletions(-) diff --git a/tests/regression/flash_attention/kernel.gemmini.cpp b/tests/regression/flash_attention/kernel.gemmini.cpp index e934a1e1..ef16a6ee 100644 --- a/tests/regression/flash_attention/kernel.gemmini.cpp +++ b/tests/regression/flash_attention/kernel.gemmini.cpp @@ -199,9 +199,15 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { constexpr uint32_t skips_only_a = loop_matmul_skips(/*skip_lda=*/0, /*skip_ldb=*/1, /*skip_ldd=*/1, /*skip_ex=*/1, /*skip_stc=*/1); + constexpr uint32_t skips_only_b = + loop_matmul_skips(/*skip_lda=*/1, /*skip_ldb=*/0, /*skip_ldd=*/1, + /*skip_ex=*/1, /*skip_stc=*/1); constexpr uint32_t skips_mvout_spad = loop_matmul_skips(/*skip_lda=*/1, /*skip_ldb=*/1, /*skip_ldd=*/1, /*skip_ex=*/1, /*skip_stc=*/0); + constexpr uint32_t skips_matmul = + loop_matmul_skips(/*skip_lda=*/1, /*skip_ldb=*/1, /*skip_ldd=*/1, + /*skip_ex=*/0, /*skip_stc=*/1); constexpr uint32_t skips_matmul_preload = loop_matmul_skips(/*skip_lda=*/1, /*skip_ldb=*/1, /*skip_ldd=*/0, /*skip_ex=*/0, /*skip_stc=*/1); @@ -231,12 +237,11 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { // other warps behave differently on the branch condition. 
// threadblock_barrier(warpgroup_id_in_cluster, warps_per_warpgroup_per_core); - // move Q and K into SMEM before the loop starts - // static_assert(B_ROW == B_COL, "currently only supports square tiles"); + // move Q and K into SMEM before the loop starts + // asm volatile("dma_move_start_%=:" ::); - if (tid_in_warpgroup == 0) { // make sure to read from the correct row of Q const float *gmem_Q_tile = gmem_Q + HEADDIM * B_ROW * warpgroup_id; @@ -249,35 +254,48 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { 8 /*k_LOOP_WS_CONFIG_STRIDES_AB*/); gemmini_fence(); -#define GEMMINI_DMA_CISC +// #define GEMMINI_DMA_CISC #ifdef GEMMINI_DMA_CISC // the target addresses of this should match with spad_addr_Q0 and // spad_addr_K0 set in this kernel GEMMINI_CISC_CMD_I(10); - gemmini_fence(); - - // need to also move to Q1 for the next iteration - ROCC_INSTRUCTION_RS1_RS2(XCUSTOM_ACC, (uint64_t)(gmem_Q_tile), - (uint64_t)(gmem_K_tile), k_LOOP_WS_CONFIG_ADDRS_AB) - GEMMINI_CISC_CMD_R((dim_seqlen << 20) | (HEADDIM << 8) | - 8 /*k_LOOP_WS_CONFIG_STRIDES_AB*/); - gemmini_fence(); - GEMMINI_CISC_CMD_I(11); - gemmini_fence(); #else // do DMA // // among other things, this also configures CONFIG_BOUNDS so that the // DMA knows the full matrix dimensions sp_tiled_matmul_full_spad_ws( - spad_addr_Q, spad_addr_K, - /*spad_D=*/0, /*spad_C=*/spad_addr_S, + spad_addr_Q0, spad_addr_K0, + /*spad_D=*/0, /*spad_C=*/spad_addr_S0/*bogus*/, /*I=*/(B_ROW / DIM), /*J=*/(B_COL / DIM), /*K=*/(HEADDIM / DIM), /*pad_I=*/0, /*pad_J=*/0, /*pad_K=*/0, /*a_transpose=*/0, /*b_transpose=*/0, /*full_C=*/0, /*low_D=*/0, /*acc=*/0, /*act=*/NO_ACTIVATION, /*skips=*/skips); - gemmini_fence(); #endif + gemmini_fence(); + + // need to also move Q to spad_addr_Q1 for the next iteration + // FIXME: re-configure necessary? 
+ gmem_K_tile = gmem_K + (B_COL * 1); + ROCC_INSTRUCTION_RS1_RS2(XCUSTOM_ACC, (uint64_t)(gmem_Q_tile), + (uint64_t)(gmem_K_tile), k_LOOP_WS_CONFIG_ADDRS_AB) + GEMMINI_CISC_CMD_R((dim_seqlen << 20) | (HEADDIM << 8) | + 8 /*k_LOOP_WS_CONFIG_STRIDES_AB*/); + gemmini_fence(); +#ifdef GEMMINI_DMA_CISC + // GEMMINI_CISC_CMD_I(11); +#else + sp_tiled_matmul_full_spad_ws( + spad_addr_Q1, spad_addr_K1/*bogus*/, + /*spad_D=*/0, /*spad_C=*/spad_addr_S0/*bogus*/, + /*I=*/(B_ROW / DIM), /*J=*/(B_COL / DIM), /*K=*/(HEADDIM / DIM), + /*pad_I=*/0, /*pad_J=*/0, /*pad_K=*/0, + /*a_transpose=*/0, /*b_transpose=*/0, /*full_C=*/0, /*low_D=*/0, + /*acc=*/0, /*act=*/NO_ACTIVATION, /*skips=*/skips_only_a); +#endif + + gemmini_fence(); + gemmini_fence(); // re-configure DMA for K and V load that will later happen in the loop // GMEM addr stride for K @@ -376,6 +394,7 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { // do matmul // among other things, this also configures CONFIG_BOUNDS so that the // DMA knows the full matrix dimensions + gemmini_fence(); sp_tiled_matmul_full_spad_ws( spad_addr_P_consume, spad_addr_V_consume, /*spad_D=*/spad_addr_O, /*spad_C=*/spad_addr_O, @@ -437,11 +456,18 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { asm volatile("gemm_qk_start_%=:" ::); if (tid_in_warpgroup == 0) { + gemmini_fence(); // 0,2,.: opcode 0 (quartile 0/2, no accum) // 1,3,.: opcode 3 (quartile 1/3, no accum) const uint32_t opcode = 3 * (tile_k & 1); - gemmini_fence(); - GEMMINI_CISC_CMD_I(opcode); + //GEMMINI_CISC_CMD_I(opcode); + sp_tiled_matmul_full_spad_ws( + spad_addr_Q, spad_addr_K_consume, + /*spad_D=*/0, /*spad_C=*/spad_addr_S_produce, + /*I=*/(B_ROW / DIM), /*J=*/(B_COL / DIM), /*K=*/(HEADDIM / DIM), + /*pad_I=*/0, /*pad_J=*/0, /*pad_K=*/0, + /*a_transpose=*/0, /*b_transpose=*/0, /*full_C=*/0, /*low_D=*/0, + /*acc=*/0, /*act=*/NO_ACTIVATION, /*skips=*/skips_matmul); gemmini_fence(); gemmini_fence(); @@ -574,7 +600,7 @@ void kernel_body(int task_id, 
kernel_arg_t *__UNIFORM__ arg) { } } - // data movement for K and V + // data move for K and V // // Q stays in SMEM for the entire loop asm volatile("move_k_v_start_%=:" ::); @@ -606,7 +632,7 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { sp_tiled_matmul_full_spad_ws( spad_addr_K_produce, spad_addr_V_produce, /*spad_D=*/0, /*spad_C=*/spad_addr_S_produce /*FIXME:bogus*/, - /*I=*/(HEADDIM / DIM), /*J=*/(HEADDIM / DIM), /*K=*/(B_COL / DIM), + /*I=*/(B_ROW / DIM), /*J=*/(HEADDIM / DIM), /*K=*/(B_COL / DIM), /*pad_I=*/0, /*pad_J=*/0, /*pad_K=*/0, /*a_transpose=*/0, /*b_transpose=*/0, /*full_C=*/0, /*low_D=*/0, /*acc=*/0, /*act=*/NO_ACTIVATION, /*skips=*/skips_only_a); @@ -614,12 +640,14 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { sp_tiled_matmul_full_spad_ws( spad_addr_K_produce, spad_addr_V_produce, /*spad_D=*/0, /*spad_C=*/spad_addr_S_produce /*FIXME:bogus*/, - /*I=*/(HEADDIM / DIM), /*J=*/(HEADDIM / DIM), /*K=*/(B_COL / DIM), + /*I=*/(B_ROW / DIM), /*J=*/(HEADDIM / DIM), /*K=*/(B_COL / DIM), /*pad_I=*/0, /*pad_J=*/0, /*pad_K=*/0, /*a_transpose=*/0, /*b_transpose=*/0, /*full_C=*/0, /*low_D=*/0, /*acc=*/0, /*act=*/NO_ACTIVATION, /*skips=*/skips); } gemmini_fence(); + gemmini_fence(); + gemmini_fence(); } threadblock_barrier(warpgroup_id_in_cluster, warps_per_warpgroup_per_core); From 6911843a8259e1980072b1778042b4c08e19c78b Mon Sep 17 00:00:00 2001 From: Hansung Kim Date: Sun, 8 Sep 2024 21:11:59 -0700 Subject: [PATCH 23/50] flash: Remove unnecessary dmem preload, fix rowmax/rowsum dependency --- .../flash_attention/kernel.gemmini.cpp | 23 +++++++++---------- 1 file changed, 11 insertions(+), 12 deletions(-) diff --git a/tests/regression/flash_attention/kernel.gemmini.cpp b/tests/regression/flash_attention/kernel.gemmini.cpp index ef16a6ee..063bc468 100644 --- a/tests/regression/flash_attention/kernel.gemmini.cpp +++ b/tests/regression/flash_attention/kernel.gemmini.cpp @@ -216,10 +216,11 @@ void kernel_body(int task_id, 
kernel_arg_t *__UNIFORM__ arg) { if (tid_in_warpgroup == 0) { gemmini_extended_config_ex(WEIGHT_STATIONARY, 0, 0, 1, 0, 0); - // configure DMA for the full Q matrix + // configure DMA with GMEM address strides + // Q matrix gemmini_extended3_config_ld(HEADDIM * sizeof(elem_t), MVIN_SCALE_IDENTITY, false, 0); - // configure DMA for the full K matrix + // K matrix gemmini_extended3_config_ld(dim_seqlen * sizeof(elem_t), MVIN_SCALE_IDENTITY, false, 1); // configure DMA for Q*K store @@ -344,16 +345,13 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { float *smem_S_produce = (tile_k & 1) ? smem_S0 : smem_S1; float *smem_P_consume = (tile_k & 1) ? smem_P1 : smem_P0; float *smem_P_produce = (tile_k & 1) ? smem_P0 : smem_P1; - // O tile is sequentially updated at every iteration; no ping-pong - // necessary + // O, rowmax/rowsum etc. is sequentially updated at every iteration; no + // ping-pong necessary float *smem_O = smem_O0; - // FIXME: O_row_scale/rowmax/rowsum/spad shouldn't really need ping-pong - float *smem_O_row_scale = - (tile_k & 1) ? smem_O_row_scale_1 : smem_O_row_scale_0; - float *smem_rowmax = (tile_k & 1) ? smem_rowmax_1 : smem_rowmax_0; - float *smem_rowsum = (tile_k & 1) ? smem_rowsum_1 : smem_rowsum_0; - float *smem_scratchpad = - (tile_k & 1) ? smem_scratchpad_1 : smem_scratchpad_0; + float *smem_O_row_scale = smem_O_row_scale_0; + float *smem_rowmax = smem_rowmax_0; + float *smem_rowsum = smem_rowsum_0; + float *smem_scratchpad = smem_scratchpad_0; const auto spad_addr_Q = spad_addr_Q0; const auto spad_addr_K_consume = (tile_k & 1) ? 
spad_addr_K1 : spad_addr_K0; @@ -394,6 +392,7 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { // do matmul // among other things, this also configures CONFIG_BOUNDS so that the // DMA knows the full matrix dimensions + // FIXME: perf: prevent GMEM->SMEM load for O tile gemmini_fence(); sp_tiled_matmul_full_spad_ws( spad_addr_P_consume, spad_addr_V_consume, @@ -401,7 +400,7 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { /*I=*/(B_ROW / DIM), /*J=*/(HEADDIM / DIM), /*K=*/(B_COL / DIM), /*pad_I=*/0, /*pad_J=*/0, /*pad_K=*/0, /*a_transpose=*/0, /*b_transpose=*/0, /*full_C=*/0, /*low_D=*/0, - /*acc=*/0, /*act=*/NO_ACTIVATION, /*skips=*/skips_matmul_preload); + /*acc=*/0, /*act=*/NO_ACTIVATION, /*skips=*/skips_matmul); #endif gemmini_fence(); From 714b9f501e6ff8f1f17acf1d7ed85445370ef071 Mon Sep 17 00:00:00 2001 From: Hansung Kim Date: Sun, 8 Sep 2024 22:06:49 -0700 Subject: [PATCH 24/50] flash: Restructure to do delayed fence for better concurrency Verified up to O_before_PV of 2nd iteration; O_after_PV needs preload fix. FIXME: Stalls at barrier without DEBUG set. 
--- .../flash_attention/kernel.gemmini.cpp | 243 +++++++++--------- 1 file changed, 127 insertions(+), 116 deletions(-) diff --git a/tests/regression/flash_attention/kernel.gemmini.cpp b/tests/regression/flash_attention/kernel.gemmini.cpp index 063bc468..ae0aaf07 100644 --- a/tests/regression/flash_attention/kernel.gemmini.cpp +++ b/tests/regression/flash_attention/kernel.gemmini.cpp @@ -389,7 +389,7 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { GEMMINI_CISC_CMD_I(1); } #else - // do matmul + // kickoff matmul // among other things, this also configures CONFIG_BOUNDS so that the // DMA knows the full matrix dimensions // FIXME: perf: prevent GMEM->SMEM load for O tile @@ -402,27 +402,6 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { /*a_transpose=*/0, /*b_transpose=*/0, /*full_C=*/0, /*low_D=*/0, /*acc=*/0, /*act=*/NO_ACTIVATION, /*skips=*/skips_matmul); #endif - - gemmini_fence(); - gemmini_fence(); - gemmini_fence(); - gemmini_fence(); - - // mvout to SMEM - // GEMMINI_CISC_CMD_I(9); - sp_tiled_matmul_full_spad_ws( - /*spad_A=*/spad_addr_P_consume, /*spad_B=*/spad_addr_V_consume, - /*spad_D=*/0, /*spad_C=*/spad_addr_O, - /*I=*/(B_ROW / DIM), /*J=*/(HEADDIM / DIM), /*K=*/(B_COL / DIM), - /*pad_I=*/0, /*pad_J=*/0, /*pad_K=*/0, - /*a_transpose=*/0, /*b_transpose=*/0, /*full_C=*/0, /*low_D=*/0, - /*acc=*/0, /*act=*/NO_ACTIVATION, /*skips=*/skips_mvout_spad); - gemmini_fence(); - - if constexpr (DEBUG) { - // for copy-out to GMEM - gemmini_fence(); - } } // reconverge from mmio divergence @@ -431,99 +410,6 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { asm volatile("gemm_pv_finish_%=:" ::); - if constexpr (DEBUG) { - if (warpgroup_id == 0) { - // O after PV - if (tile_k_ == 0) { - thread_block_copy_tile( - smem_O, gmem_tmp_d6, tid_in_warpgroup, threads_per_warpgroup, - warpgroup_id_in_cluster); - } else if (tile_k_ == 1) { - thread_block_copy_tile( - smem_O, gmem_tmp_d7, tid_in_warpgroup, threads_per_warpgroup, 
- warpgroup_id_in_cluster); - } - - threadblock_barrier(warpgroup_id_in_cluster, - warps_per_warpgroup_per_core); - } - } - } - - // GEMM I: S = Q*K - // - asm volatile("gemm_qk_start_%=:" ::); - - if (tid_in_warpgroup == 0) { - gemmini_fence(); - // 0,2,.: opcode 0 (quartile 0/2, no accum) - // 1,3,.: opcode 3 (quartile 1/3, no accum) - const uint32_t opcode = 3 * (tile_k & 1); - //GEMMINI_CISC_CMD_I(opcode); - sp_tiled_matmul_full_spad_ws( - spad_addr_Q, spad_addr_K_consume, - /*spad_D=*/0, /*spad_C=*/spad_addr_S_produce, - /*I=*/(B_ROW / DIM), /*J=*/(B_COL / DIM), /*K=*/(HEADDIM / DIM), - /*pad_I=*/0, /*pad_J=*/0, /*pad_K=*/0, - /*a_transpose=*/0, /*b_transpose=*/0, /*full_C=*/0, /*low_D=*/0, - /*acc=*/0, /*act=*/NO_ACTIVATION, /*skips=*/skips_matmul); - - gemmini_fence(); - gemmini_fence(); - gemmini_fence(); - gemmini_fence(); - -#if 0 // TODO: speed up mvout to SMEM -// loop_ws variant that skips configuring strides -#define gemmini_loop_ws(I, J, K, pad_I, pad_J, pad_K, A, B, D, C, A_stride, B_stride, D_stride, C_stride, A_transpose, B_transpose, full_C, low_D, ex_accumulate, act, a_spad_id, b_spad_id, is_resadd) \ - { \ - ROCC_INSTRUCTION_RS1_RS2(XCUSTOM_ACC, ((uint64_t)(pad_K) << 32) | ((uint64_t)(pad_J) << 16) | (uint64_t)(pad_I), ((uint64_t)(K) << 32) | ((uint64_t)(J) << 16) | (uint64_t)(I), k_LOOP_WS_CONFIG_BOUNDS) \ - ROCC_INSTRUCTION_RS1_RS2(XCUSTOM_ACC, A, B, k_LOOP_WS_CONFIG_ADDRS_AB) \ - ROCC_INSTRUCTION_RS1_RS2(XCUSTOM_ACC, D, C, k_LOOP_WS_CONFIG_ADDRS_DC) \ - ROCC_INSTRUCTION_RS1_RS2(XCUSTOM_ACC, A_stride, B_stride, k_LOOP_WS_CONFIG_STRIDES_AB) \ - ROCC_INSTRUCTION_RS1_RS2(XCUSTOM_ACC, D_stride, C_stride, k_LOOP_WS_CONFIG_STRIDES_DC) \ - ROCC_INSTRUCTION_RS1_RS2(XCUSTOM_ACC, ((uint64_t)(a_spad_id) << 18) | ((uint64_t)(b_spad_id) << 16) | ((uint64_t)(act) << 8) | ((low_D) << 2) | ((full_C) << 1) | (ex_accumulate), ((is_resadd) << 2) | ((B_transpose) << 1) | (A_transpose), k_LOOP_WS) \ - } -#endif - - // mvout to SMEM - // GEMMINI_CISC_CMD_I(9); - 
sp_tiled_matmul_full_spad_ws( - /*spad_A=*/spad_addr_Q /*bogus*/, - /*spad_B=*/spad_addr_K_consume /*bogus*/, - /*spad_D=*/0, /*spad_C=*/spad_addr_S_produce, - /*I=*/(B_ROW / DIM), /*J=*/(B_COL / DIM), /*K=*/(HEADDIM / DIM), - /*pad_I=*/0, /*pad_J=*/0, /*pad_K=*/0, - /*a_transpose=*/0, /*b_transpose=*/0, /*full_C=*/0, /*low_D=*/0, - /*acc=*/0, /*act=*/NO_ACTIVATION, /*skips=*/skips_mvout_spad); - gemmini_fence(); - - if constexpr (DEBUG) { - // for copy-out to GMEM - gemmini_fence(); - } - } - - // reconverge from mmio divergence - threadblock_barrier(warpgroup_id_in_cluster, warps_per_warpgroup_per_core); - - asm volatile("gemm_qk_finish_%=:" ::); - - if constexpr (DEBUG) { - if (warpgroup_id == 0) { - if (tile_k == 0) { - thread_block_copy_tile( - smem_S_produce, gmem_tmp_d0, tid_in_warpgroup, threads_per_warpgroup, - warpgroup_id_in_cluster); - } else if (tile_k == 1) { - thread_block_copy_tile( - smem_S_produce, gmem_tmp_d1, tid_in_warpgroup, threads_per_warpgroup, - warpgroup_id_in_cluster); - } - - threadblock_barrier(warpgroup_id_in_cluster, - warps_per_warpgroup_per_core); - } } if (tile_k >= 1) // delay by 1 iters for pipelining @@ -563,7 +449,89 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { } } - // TODO: put a synchronization here with GEMM-II + if (tid_in_warpgroup == 0) { + // fence GEMM-II to make sure dependency on O tile is settled + gemmini_fence(); + gemmini_fence(); + + // mvout to SMEM + // GEMMINI_CISC_CMD_I(9); + sp_tiled_matmul_full_spad_ws( + /*spad_A=*/spad_addr_P_consume, /*spad_B=*/spad_addr_V_consume, + /*spad_D=*/0, /*spad_C=*/spad_addr_O, + /*I=*/(B_ROW / DIM), /*J=*/(HEADDIM / DIM), /*K=*/(B_COL / DIM), + /*pad_I=*/0, /*pad_J=*/0, /*pad_K=*/0, + /*a_transpose=*/0, /*b_transpose=*/0, /*full_C=*/0, /*low_D=*/0, + /*acc=*/0, /*act=*/NO_ACTIVATION, /*skips=*/skips_mvout_spad); + } + + // reconverge from mmio divergence + threadblock_barrier(warpgroup_id_in_cluster, warps_per_warpgroup_per_core); + + if constexpr 
(DEBUG) { + gemmini_fence(); + + if (warpgroup_id == 0) { + // O after PV + if (tile_k_ == 0) { + thread_block_copy_tile( + smem_O, gmem_tmp_d6, tid_in_warpgroup, threads_per_warpgroup, + warpgroup_id_in_cluster); + } else if (tile_k_ == 1) { + thread_block_copy_tile( + smem_O, gmem_tmp_d7, tid_in_warpgroup, threads_per_warpgroup, + warpgroup_id_in_cluster); + } + + threadblock_barrier(warpgroup_id_in_cluster, + warps_per_warpgroup_per_core); + } + } + } + + { + // GEMM I: S = Q*K + // + // kick off asynchronously; fence later + asm volatile("gemm_qk_start_%=:" ::); + + if (tid_in_warpgroup == 0) { + gemmini_fence(); + // 0,2,.: opcode 0 (quartile 0/2, no accum) + // 1,3,.: opcode 3 (quartile 1/3, no accum) + const uint32_t opcode = 3 * (tile_k & 1); + //GEMMINI_CISC_CMD_I(opcode); + sp_tiled_matmul_full_spad_ws( + spad_addr_Q, spad_addr_K_consume, + /*spad_D=*/0, /*spad_C=*/spad_addr_S_produce, + /*I=*/(B_ROW / DIM), /*J=*/(B_COL / DIM), /*K=*/(HEADDIM / DIM), + /*pad_I=*/0, /*pad_J=*/0, /*pad_K=*/0, + /*a_transpose=*/0, /*b_transpose=*/0, /*full_C=*/0, /*low_D=*/0, + /*acc=*/0, /*act=*/NO_ACTIVATION, /*skips=*/skips_matmul); + +#if 0 // TODO: speed up mvout to SMEM + // loop_ws variant that skips configuring strides +#define gemmini_loop_ws(I, J, K, pad_I, pad_J, pad_K, A, B, D, C, A_stride, B_stride, D_stride, C_stride, A_transpose, B_transpose, full_C, low_D, ex_accumulate, act, a_spad_id, b_spad_id, is_resadd) \ + { \ + ROCC_INSTRUCTION_RS1_RS2(XCUSTOM_ACC, ((uint64_t)(pad_K) << 32) | ((uint64_t)(pad_J) << 16) | (uint64_t)(pad_I), ((uint64_t)(K) << 32) | ((uint64_t)(J) << 16) | (uint64_t)(I), k_LOOP_WS_CONFIG_BOUNDS) \ + ROCC_INSTRUCTION_RS1_RS2(XCUSTOM_ACC, A, B, k_LOOP_WS_CONFIG_ADDRS_AB) \ + ROCC_INSTRUCTION_RS1_RS2(XCUSTOM_ACC, D, C, k_LOOP_WS_CONFIG_ADDRS_DC) \ + ROCC_INSTRUCTION_RS1_RS2(XCUSTOM_ACC, A_stride, B_stride, k_LOOP_WS_CONFIG_STRIDES_AB) \ + ROCC_INSTRUCTION_RS1_RS2(XCUSTOM_ACC, D_stride, C_stride, k_LOOP_WS_CONFIG_STRIDES_DC) \ + 
ROCC_INSTRUCTION_RS1_RS2(XCUSTOM_ACC, ((uint64_t)(a_spad_id) << 18) | ((uint64_t)(b_spad_id) << 16) | ((uint64_t)(act) << 8) | ((low_D) << 2) | ((full_C) << 1) | (ex_accumulate), ((is_resadd) << 2) | ((B_transpose) << 1) | (A_transpose), k_LOOP_WS) \ + } +#endif + } + + // reconverge from mmio divergence + threadblock_barrier(warpgroup_id_in_cluster, warps_per_warpgroup_per_core); + + asm volatile("gemm_qk_finish_%=:" ::); + } + + if (tile_k >= 1) // delay by 1 iters for pipelining + { + const uint32_t tile_k_ = tile_k - 1; // Oi rescale thread_block_O_rescale( @@ -599,6 +567,46 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { } } + // fence GEMM I after Oi rescale + if (tid_in_warpgroup == 0) { + gemmini_fence(); + gemmini_fence(); + gemmini_fence(); + gemmini_fence(); + + // mvout to SMEM + // GEMMINI_CISC_CMD_I(9); + sp_tiled_matmul_full_spad_ws( + /*spad_A=*/spad_addr_Q /*bogus*/, + /*spad_B=*/spad_addr_K_consume /*bogus*/, + /*spad_D=*/0, /*spad_C=*/spad_addr_S_produce, + /*I=*/(B_ROW / DIM), /*J=*/(B_COL / DIM), /*K=*/(HEADDIM / DIM), + /*pad_I=*/0, /*pad_J=*/0, /*pad_K=*/0, + /*a_transpose=*/0, /*b_transpose=*/0, /*full_C=*/0, /*low_D=*/0, + /*acc=*/0, /*act=*/NO_ACTIVATION, /*skips=*/skips_mvout_spad); + + } + + // reconverge from mmio divergence + threadblock_barrier(warpgroup_id_in_cluster, warps_per_warpgroup_per_core); + + if constexpr (DEBUG) { + if (warpgroup_id == 0) { + if (tile_k == 0) { + thread_block_copy_tile( + smem_S_produce, gmem_tmp_d0, tid_in_warpgroup, + threads_per_warpgroup, warpgroup_id_in_cluster); + } else if (tile_k == 1) { + thread_block_copy_tile( + smem_S_produce, gmem_tmp_d1, tid_in_warpgroup, + threads_per_warpgroup, warpgroup_id_in_cluster); + } + + threadblock_barrier(warpgroup_id_in_cluster, + warps_per_warpgroup_per_core); + } + } + // data move for K and V // // Q stays in SMEM for the entire loop @@ -616,6 +624,9 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { // iterations later const 
float *gmem_V_tile = gmem_V + (HEADDIM * B_COL * (tile_k - 1 /*dragbehind*/)); + + // fence mvout S to SMEM + gemmini_fence(); ROCC_INSTRUCTION_RS1_RS2(XCUSTOM_ACC, (uint64_t)(gmem_K_tile), (uint64_t)(gmem_V_tile), k_LOOP_WS_CONFIG_ADDRS_AB) From 1f51f7f9d44dd03048eccbdc6ecb265b549ee2c6 Mon Sep 17 00:00:00 2001 From: Hansung Kim Date: Sun, 8 Sep 2024 22:49:38 -0700 Subject: [PATCH 25/50] sgemm_impl: Mark threadblock_barrier convergent Thank you Chris Lattner --- tests/regression/sgemm_tcore/sgemm_impl.hpp | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/regression/sgemm_tcore/sgemm_impl.hpp b/tests/regression/sgemm_tcore/sgemm_impl.hpp index 05692308..bf4ca80d 100644 --- a/tests/regression/sgemm_tcore/sgemm_impl.hpp +++ b/tests/regression/sgemm_tcore/sgemm_impl.hpp @@ -536,7 +536,8 @@ wmma_store(const int thread_in_warp, const int warp_col, const int warp_row, asm volatile ("wmma_store_finish_%=:" :: ); } -inline void threadblock_barrier(const uint32_t barrier_id, const uint32_t count) { +__attribute__((convergent)) inline void +threadblock_barrier(const uint32_t barrier_id, const uint32_t count) { vx_fence(); vx_barrier(barrier_id, count); } From ecc800964abd209c4f1dcd274f5b1602da693f61 Mon Sep 17 00:00:00 2001 From: Hansung Kim Date: Mon, 9 Sep 2024 13:47:18 -0700 Subject: [PATCH 26/50] flash: Change smem alloc for less bank conflicts; noskip stc --- .../flash_attention/kernel.gemmini.cpp | 138 +++++++++--------- 1 file changed, 65 insertions(+), 73 deletions(-) diff --git a/tests/regression/flash_attention/kernel.gemmini.cpp b/tests/regression/flash_attention/kernel.gemmini.cpp index ae0aaf07..82a83aa4 100644 --- a/tests/regression/flash_attention/kernel.gemmini.cpp +++ b/tests/regression/flash_attention/kernel.gemmini.cpp @@ -8,7 +8,7 @@ #include "gemmini_mmio.h" #include "flash_impl.hpp" -constexpr bool DEBUG = true; +constexpr bool DEBUG = false; static_assert(GEMMINI_DMA && !WARP_SPECIALIZED, "GEMMINI_DMA should be set and 
WARP_SPECIALIZED unset"); @@ -90,69 +90,48 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { static_assert( threads_per_threadblock == NUM_WARPS * NUM_THREADS * CORES_PER_CLUSTER, "flashattention kernel assumes 1 threadblock occupancy per cluster"); - uint8_t *smem_per_threadblock = reinterpret_cast( - DEV_SMEM_START_ADDR); - float *smem_cursor = reinterpret_cast(smem_per_threadblock); - // float *smem_cursor = reinterpret_cast(DEV_FAKE_SMEM_START_ADDR); - float *smem_Q0 = smem_cursor; - smem_cursor += smem_Q_size; - float *smem_Q1 = smem_cursor; - smem_cursor += smem_Q_size; - float *smem_K0 = smem_cursor; - smem_cursor += smem_K_size; - float *smem_K1 = smem_cursor; - smem_cursor += smem_K_size; - float *smem_V0 = smem_cursor; - smem_cursor += smem_V_size; - float *smem_V1 = smem_cursor; - smem_cursor += smem_V_size; - float *smem_S0 = smem_cursor; - smem_cursor += smem_QK_size; - float *smem_S1 = smem_cursor; - smem_cursor += smem_QK_size; - float *smem_P0 = smem_cursor; - smem_cursor += smem_QK_size; - float *smem_P1 = smem_cursor; - smem_cursor += smem_QK_size; - float *smem_O0 = smem_cursor; - smem_cursor += smem_O_size; - float *smem_O1 = smem_cursor; - smem_cursor += smem_O_size; + uint8_t *smem_per_threadblock = reinterpret_cast(DEV_SMEM_START_ADDR); + constexpr uint32_t smem_start = DEV_SMEM_START_ADDR; + constexpr uint32_t smem_quart0 = 0 * (SMEM_SIZE / 4); + constexpr uint32_t smem_quart1 = 1 * (SMEM_SIZE / 4); + constexpr uint32_t smem_quart2 = 2 * (SMEM_SIZE / 4); + constexpr uint32_t smem_quart3 = 3 * (SMEM_SIZE / 4); - // NOTE: this has to match with smem_* - static_assert(sizeof(elem_t) == sizeof(float)); - constexpr uint32_t spad_addr_factor = DIM * sizeof(elem_t); - constexpr uint32_t spad_addr_Q0 = 0; - constexpr uint32_t spad_addr_Q1 = - spad_addr_Q0 + (smem_Q_size * sizeof(float) / spad_addr_factor); - constexpr uint32_t spad_addr_K0 = - spad_addr_Q1 + (smem_Q_size * sizeof(float) / spad_addr_factor); - constexpr uint32_t 
spad_addr_K1 = - spad_addr_K0 + (smem_K_size * sizeof(float) / spad_addr_factor); - constexpr uint32_t spad_addr_V0 = - spad_addr_K1 + (smem_K_size * sizeof(float) / spad_addr_factor); - constexpr uint32_t spad_addr_V1 = - spad_addr_V0 + (smem_V_size * sizeof(float) / spad_addr_factor); - constexpr uint32_t spad_addr_S0 = - spad_addr_V1 + (smem_V_size * sizeof(float) / spad_addr_factor); - constexpr uint32_t spad_addr_S1 = - spad_addr_S0 + (smem_QK_size * sizeof(float) / spad_addr_factor); - constexpr uint32_t spad_addr_P0 = - spad_addr_S1 + (smem_QK_size * sizeof(float) / spad_addr_factor); - constexpr uint32_t spad_addr_P1 = - spad_addr_P0 + (smem_QK_size * sizeof(float) / spad_addr_factor); - constexpr uint32_t spad_addr_O0 = - spad_addr_P1 + (smem_QK_size * sizeof(float) / spad_addr_factor); - constexpr uint32_t spad_addr_O1 = - spad_addr_O0 + (smem_O_size * sizeof(float) / spad_addr_factor); + // Q/V/S in quart0/1, K/P/O in quart2/3 + constexpr uint32_t smem_Q0_offset = smem_quart0; + constexpr uint32_t smem_Q1_offset = smem_quart1; + constexpr uint32_t smem_K0_offset = smem_quart2; + constexpr uint32_t smem_K1_offset = smem_quart3; + constexpr uint32_t smem_V0_offset = smem_Q0_offset + smem_Q_size * sizeof(float); + constexpr uint32_t smem_V1_offset = smem_Q1_offset + smem_Q_size * sizeof(float); + constexpr uint32_t smem_S0_offset = smem_V0_offset + smem_V_size * sizeof(float); + constexpr uint32_t smem_S1_offset = smem_V1_offset + smem_V_size * sizeof(float); + constexpr uint32_t smem_P0_offset = smem_K0_offset + smem_K_size * sizeof(float); + constexpr uint32_t smem_P1_offset = smem_K1_offset + smem_K_size * sizeof(float); + constexpr uint32_t smem_O0_offset = smem_P1_offset + smem_QK_size * sizeof(float); + constexpr uint32_t smem_O1_offset = smem_P0_offset + smem_QK_size * sizeof(float); // unused + + float *smem_Q0 = reinterpret_cast(smem_start + smem_Q0_offset); + float *smem_Q1 = reinterpret_cast(smem_start + smem_Q1_offset); + float *smem_K0 = 
reinterpret_cast(smem_start + smem_K0_offset); + float *smem_K1 = reinterpret_cast(smem_start + smem_K1_offset); + float *smem_V0 = reinterpret_cast(smem_start + smem_V0_offset); + float *smem_V1 = reinterpret_cast(smem_start + smem_V1_offset); + float *smem_S0 = reinterpret_cast(smem_start + smem_S0_offset); + float *smem_S1 = reinterpret_cast(smem_start + smem_S1_offset); + float *smem_P0 = reinterpret_cast(smem_start + smem_P0_offset); + float *smem_P1 = reinterpret_cast(smem_start + smem_P1_offset); + float *smem_O0 = reinterpret_cast(smem_start + smem_O0_offset); + float *smem_O1 = reinterpret_cast(smem_start + smem_O1_offset); // allocate rowmax/rowsum storage at the end of the sharedmem address space constexpr uint32_t smem_rowmax_size = B_ROW * ROWMAX_SETS; constexpr uint32_t smem_rowsum_size = B_ROW; constexpr uint32_t smem_O_row_scale_size = B_ROW; - // FIXME: dangerous - smem_cursor = reinterpret_cast(0xff038000); + float *smem_cursor = smem_O1 + smem_O_size; + // // FIXME: dangerous + // smem_cursor = reinterpret_cast(0xff038000); float *smem_rowmax_0 = smem_cursor; smem_cursor += smem_rowmax_size; float *smem_rowmax_1 = smem_cursor; @@ -176,6 +155,21 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { float *smem_scratchpad_1 = smem_cursor; smem_cursor += smem_scratchpad_size; + static_assert(sizeof(elem_t) == sizeof(float)); + constexpr uint32_t spad_addr_factor = DIM * sizeof(elem_t); + constexpr uint32_t spad_addr_Q0 = smem_Q0_offset / spad_addr_factor; + constexpr uint32_t spad_addr_Q1 = smem_Q1_offset / spad_addr_factor; + constexpr uint32_t spad_addr_K0 = smem_K0_offset / spad_addr_factor; + constexpr uint32_t spad_addr_K1 = smem_K1_offset / spad_addr_factor; + constexpr uint32_t spad_addr_V0 = smem_V0_offset / spad_addr_factor; + constexpr uint32_t spad_addr_V1 = smem_V1_offset / spad_addr_factor; + constexpr uint32_t spad_addr_S0 = smem_S0_offset / spad_addr_factor; + constexpr uint32_t spad_addr_S1 = smem_S1_offset / 
spad_addr_factor; + constexpr uint32_t spad_addr_P0 = smem_P0_offset / spad_addr_factor; + constexpr uint32_t spad_addr_P1 = smem_P1_offset / spad_addr_factor; + constexpr uint32_t spad_addr_O0 = smem_O0_offset / spad_addr_factor; + constexpr uint32_t spad_addr_O1 = smem_O1_offset / spad_addr_factor; + // initialize rowmax/rowsum values in sharedmem thread_block_init_sharedmem(tid_in_warpgroup, threads_per_warpgroup, smem_O0, smem_rowmax_0, smem_rowsum_0, smem_O_row_scale_0); @@ -184,11 +178,6 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { constexpr uint32_t global_barrier_id = NUM_WARPS - 1; // arbitrary - // // delay warpgroup 0 by 1 iteration to do ping-pong scheduling - // if (WARP_SPECIALIZED && warpgroup_id == 1) { - // threadblock_barrier(global_barrier_id, warps_per_threadblock_per_core); - // } - static_assert(!GEMMINI_DMA || Q_IS_K_MAJOR, "DMA code assumes Q matrix is stored K-major"); @@ -207,7 +196,7 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { /*skip_ex=*/1, /*skip_stc=*/0); constexpr uint32_t skips_matmul = loop_matmul_skips(/*skip_lda=*/1, /*skip_ldb=*/1, /*skip_ldd=*/1, - /*skip_ex=*/0, /*skip_stc=*/1); + /*skip_ex=*/0, /*skip_stc=*/0); constexpr uint32_t skips_matmul_preload = loop_matmul_skips(/*skip_lda=*/1, /*skip_ldb=*/1, /*skip_ldd=*/0, /*skip_ex=*/0, /*skip_stc=*/1); @@ -327,9 +316,9 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { const uint32_t k_tiles = (dim_seqlen / B_COL); for (uint32_t tile_k = 0; tile_k < k_tiles + 2 /*pipeline latency*/; tile_k++) { - if constexpr (DEBUG) { + if constexpr (DEBUG || true) { // barrier for debugging - threadblock_barrier(global_barrier_id, warps_per_threadblock_per_core); + // threadblock_barrier(global_barrier_id, warps_per_threadblock_per_core); } // select the correct double buffer by tile iteration @@ -394,6 +383,7 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { // DMA knows the full matrix dimensions // FIXME: perf: prevent 
GMEM->SMEM load for O tile gemmini_fence(); + gemmini_fence(); sp_tiled_matmul_full_spad_ws( spad_addr_P_consume, spad_addr_V_consume, /*spad_D=*/spad_addr_O, /*spad_C=*/spad_addr_O, @@ -449,11 +439,14 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { } } + // fence GEMM II to make sure dependency on O tile is settled if (tid_in_warpgroup == 0) { - // fence GEMM-II to make sure dependency on O tile is settled + gemmini_fence(); + gemmini_fence(); gemmini_fence(); gemmini_fence(); +#if 1 // mvout to SMEM // GEMMINI_CISC_CMD_I(9); sp_tiled_matmul_full_spad_ws( @@ -463,6 +456,7 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { /*pad_I=*/0, /*pad_J=*/0, /*pad_K=*/0, /*a_transpose=*/0, /*b_transpose=*/0, /*full_C=*/0, /*low_D=*/0, /*acc=*/0, /*act=*/NO_ACTIVATION, /*skips=*/skips_mvout_spad); +#endif } // reconverge from mmio divergence @@ -497,6 +491,9 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { if (tid_in_warpgroup == 0) { gemmini_fence(); + gemmini_fence(); + gemmini_fence(); + // 0,2,.: opcode 0 (quartile 0/2, no accum) // 1,3,.: opcode 3 (quartile 1/3, no accum) const uint32_t opcode = 3 * (tile_k & 1); @@ -574,6 +571,7 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { gemmini_fence(); gemmini_fence(); +#if 1 // mvout to SMEM // GEMMINI_CISC_CMD_I(9); sp_tiled_matmul_full_spad_ws( @@ -584,7 +582,7 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { /*pad_I=*/0, /*pad_J=*/0, /*pad_K=*/0, /*a_transpose=*/0, /*b_transpose=*/0, /*full_C=*/0, /*low_D=*/0, /*acc=*/0, /*act=*/NO_ACTIVATION, /*skips=*/skips_mvout_spad); - +#endif } // reconverge from mmio divergence @@ -668,12 +666,6 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { } asm volatile ("tile_loop_finish_%=:" :: ); - - // // wait for warpgroup 1 to finish, which called the global barrier before - // // entering the loop - // if (warpgroup_id == 0) { - // threadblock_barrier(global_barrier_id, 
warps_per_threadblock_per_core); - // } } int main() { From 829af5d429ed60f1fb803427fe928a35c816df97 Mon Sep 17 00:00:00 2001 From: Hansung Kim Date: Mon, 9 Sep 2024 15:21:49 -0700 Subject: [PATCH 27/50] flash: Comment out mvout to smem Verified up to O_before_PV; still stalls without DEBUG --- tests/regression/flash_attention/kernel.gemmini.cpp | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/tests/regression/flash_attention/kernel.gemmini.cpp b/tests/regression/flash_attention/kernel.gemmini.cpp index 82a83aa4..a8188dc4 100644 --- a/tests/regression/flash_attention/kernel.gemmini.cpp +++ b/tests/regression/flash_attention/kernel.gemmini.cpp @@ -82,6 +82,7 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { float *gmem_tmp_e3 = reinterpret_cast(0xe3000000UL); // static shared memory allocation + // these are in float elements, not bytes constexpr uint32_t smem_Q_size = B_ROW * HEADDIM; constexpr uint32_t smem_K_size = B_COL * HEADDIM; constexpr uint32_t smem_QK_size = B_ROW * B_COL; @@ -384,6 +385,8 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { // FIXME: perf: prevent GMEM->SMEM load for O tile gemmini_fence(); gemmini_fence(); + gemmini_fence(); + gemmini_fence(); sp_tiled_matmul_full_spad_ws( spad_addr_P_consume, spad_addr_V_consume, /*spad_D=*/spad_addr_O, /*spad_C=*/spad_addr_O, @@ -446,7 +449,7 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { gemmini_fence(); gemmini_fence(); -#if 1 +#if 0 // mvout to SMEM // GEMMINI_CISC_CMD_I(9); sp_tiled_matmul_full_spad_ws( @@ -493,10 +496,11 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { gemmini_fence(); gemmini_fence(); gemmini_fence(); + gemmini_fence(); // 0,2,.: opcode 0 (quartile 0/2, no accum) // 1,3,.: opcode 3 (quartile 1/3, no accum) - const uint32_t opcode = 3 * (tile_k & 1); + // const uint32_t opcode = 3 * (tile_k & 1); //GEMMINI_CISC_CMD_I(opcode); sp_tiled_matmul_full_spad_ws( spad_addr_Q, spad_addr_K_consume, 
@@ -571,7 +575,7 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { gemmini_fence(); gemmini_fence(); -#if 1 +#if 0 // mvout to SMEM // GEMMINI_CISC_CMD_I(9); sp_tiled_matmul_full_spad_ws( @@ -656,6 +660,7 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { gemmini_fence(); gemmini_fence(); gemmini_fence(); + gemmini_fence(); } threadblock_barrier(warpgroup_id_in_cluster, warps_per_warpgroup_per_core); From d31c8ffd7dfc72dda8ab9b13dffc25c8db3e3f8c Mon Sep 17 00:00:00 2001 From: Hansung Kim Date: Mon, 9 Sep 2024 15:43:31 -0700 Subject: [PATCH 28/50] flash: Fix grid size to hw cluster size Verified fast config, minus the barrier stall at the end. --- tests/regression/flash_attention/kernel.gemmini.cpp | 9 ++------- tests/regression/sgemm_tcore/sgemm_impl.hpp | 1 + 2 files changed, 3 insertions(+), 7 deletions(-) diff --git a/tests/regression/flash_attention/kernel.gemmini.cpp b/tests/regression/flash_attention/kernel.gemmini.cpp index a8188dc4..b943ef0f 100644 --- a/tests/regression/flash_attention/kernel.gemmini.cpp +++ b/tests/regression/flash_attention/kernel.gemmini.cpp @@ -676,15 +676,10 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { int main() { kernel_arg_t *arg = (kernel_arg_t *)KERNEL_ARG_DEV_MEM_ADDR; - // FIXME:: use actuall seqlen/headdim - const uint32_t problem_size = (B_ROW * B_COL) / (ELEM_PER_THREAD); const uint32_t hw_threads_per_cluster = CORES_PER_CLUSTER * vx_num_threads() * vx_num_warps(); - // prevent launching more threads than the necessary problem size - // TODO: this does not take into account multiple clusters - const uint32_t grid_size = (problem_size > hw_threads_per_cluster) - ? 
hw_threads_per_cluster - : problem_size; + // fix to 1 threadblock per cluster + const uint32_t grid_size = hw_threads_per_cluster; #ifdef RADIANCE vx_spawn_tasks_cluster(grid_size, (vx_spawn_tasks_cb)kernel_body, arg); diff --git a/tests/regression/sgemm_tcore/sgemm_impl.hpp b/tests/regression/sgemm_tcore/sgemm_impl.hpp index bf4ca80d..7e95dbdc 100644 --- a/tests/regression/sgemm_tcore/sgemm_impl.hpp +++ b/tests/regression/sgemm_tcore/sgemm_impl.hpp @@ -538,6 +538,7 @@ wmma_store(const int thread_in_warp, const int warp_col, const int warp_row, __attribute__((convergent)) inline void threadblock_barrier(const uint32_t barrier_id, const uint32_t count) { + asm volatile("" ::: "memory"); vx_fence(); vx_barrier(barrier_id, count); } From b652e259451e7a51d73cc32a4a03278745ddaa27 Mon Sep 17 00:00:00 2001 From: Hansung Kim Date: Mon, 9 Sep 2024 16:42:30 -0700 Subject: [PATCH 29/50] flash: Warp-specialize between warp 0 and 1-7 Finishes without stalls; No dependency check between O rescale and GEMM-II. 
--- .../regression/flash_attention/flash_impl.hpp | 30 +- .../flash_attention/kernel.gemmini.cpp | 473 +++++++++--------- 2 files changed, 261 insertions(+), 242 deletions(-) diff --git a/tests/regression/flash_attention/flash_impl.hpp b/tests/regression/flash_attention/flash_impl.hpp index 46a62546..8aac50ab 100644 --- a/tests/regression/flash_attention/flash_impl.hpp +++ b/tests/regression/flash_attention/flash_impl.hpp @@ -236,8 +236,9 @@ __attribute__((always_inline)) inline void thread_block_online_softmax( warp_smem[tid_in_warp] = per_thread_max; // sync writes to warp_smem - threadblock_barrier(threadblock_id_in_cluster, - warps_per_threadblock_per_core); + // threadblock_barrier(threadblock_id_in_cluster, + // warps_per_threadblock_per_core); + threadblock_barrier(1, 7); // #define PARALLEL_ROWMAX #ifndef PARALLEL_ROWMAX @@ -287,8 +288,9 @@ __attribute__((always_inline)) inline void thread_block_online_softmax( #endif // PARALLEL_ROWMAX #endif // DUMB_ROWMAX - threadblock_barrier(threadblock_id_in_cluster, - warps_per_threadblock_per_core); + // threadblock_barrier(threadblock_id_in_cluster, + // warps_per_threadblock_per_core); + threadblock_barrier(1, 7); // broadcast prev rowmax to all threads in the warp // NOTE: memory consistency is a little sketchy here @@ -331,8 +333,9 @@ __attribute__((always_inline)) inline void thread_block_online_softmax( asm volatile("flashattn_exp_p_end_%=:" ::); - threadblock_barrier(threadblock_id_in_cluster, - warps_per_threadblock_per_core); + // threadblock_barrier(threadblock_id_in_cluster, + // warps_per_threadblock_per_core); + threadblock_barrier(1, 7); // rowsum // @@ -358,8 +361,9 @@ __attribute__((always_inline)) inline void thread_block_online_softmax( warp_smem[tid_in_warp] = per_thread_sum; // sync writes to warp_smem - threadblock_barrier(threadblock_id_in_cluster, - warps_per_threadblock_per_core); + // threadblock_barrier(threadblock_id_in_cluster, + // warps_per_threadblock_per_core); + 
threadblock_barrier(1, 7); // 0-th thread collects all other thread's values in the warp if (tid_in_warp == 0) { @@ -387,8 +391,9 @@ __attribute__((always_inline)) inline void thread_block_online_softmax( asm volatile("flashattn_rowsum_end_%=:" ::); - threadblock_barrier(threadblock_id_in_cluster, - warps_per_threadblock_per_core); + // threadblock_barrier(threadblock_id_in_cluster, + // warps_per_threadblock_per_core); + threadblock_barrier(1, 7); // compute Oi rescale factor // FIXME: parallelize this across threads @@ -412,8 +417,9 @@ __attribute__((always_inline)) inline void thread_block_online_softmax( asm volatile("flashattn_rescale_factor_end_%=:" ::); - threadblock_barrier(threadblock_id_in_cluster, - warps_per_threadblock_per_core); + // threadblock_barrier(threadblock_id_in_cluster, + // warps_per_threadblock_per_core); + threadblock_barrier(1, 7); } asm volatile("thread_block_online_softmax_finish_%=:" ::); diff --git a/tests/regression/flash_attention/kernel.gemmini.cpp b/tests/regression/flash_attention/kernel.gemmini.cpp index b943ef0f..9d611ba2 100644 --- a/tests/regression/flash_attention/kernel.gemmini.cpp +++ b/tests/regression/flash_attention/kernel.gemmini.cpp @@ -58,6 +58,9 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { const uint32_t warpgroup_id_in_cluster = warpgroup_id % warpgroups_per_cluster; const uint32_t tid_in_warpgroup = tid_in_threadblock % threads_per_warpgroup; + // // warpgroup 0: warp 0 + // // warpgroup 1: warp 1~7 + // const uint32_t warpgroup_id = (warp_id != 0); const uint32_t dim_seqlen = arg->dim_seqlen; const uint32_t dim_headdim = arg->dim_headdim; @@ -178,6 +181,7 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { smem_rowmax_1, smem_rowsum_1, smem_O_row_scale_1); constexpr uint32_t global_barrier_id = NUM_WARPS - 1; // arbitrary + static_assert(warps_per_threadblock_per_core == NUM_WARPS); static_assert(!GEMMINI_DMA || Q_IS_K_MAJOR, "DMA code assumes Q matrix is stored K-major"); @@ 
-301,7 +305,8 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { asm volatile("dma_move_end_%=:" ::); // protect write to SMEM - threadblock_barrier(warpgroup_id_in_cluster, warps_per_warpgroup_per_core); + // threadblock_barrier(warpgroup_id_in_cluster, warps_per_warpgroup_per_core); + threadblock_barrier(global_barrier_id, warps_per_threadblock_per_core); // if constexpr (DEBUG) { // thread_block_copy_tile(smem_Q0, gmem_tmp_d0, tid_in_warpgroup, @@ -311,6 +316,18 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { // threadblock_barrier(warpgroup_id_in_cluster, warps_per_warpgroup_per_core); // } + threadblock_barrier(global_barrier_id, warps_per_threadblock_per_core); + + constexpr uint32_t threads_per_warpgroup_simt = + threads_per_warpgroup - + CORES_PER_CLUSTER * NUM_THREADS /*warp 0, 4, 8, 12*/; + constexpr uint32_t warpgroup_id_simt = 1; + constexpr uint32_t barrier_id_simt = 1; + constexpr uint32_t barrier_count_simt = NUM_WARPS - 1; + const uint32_t tid_in_warpgroup_simt = + tid_in_warpgroup - (CORES_PER_CLUSTER * NUM_THREADS); + static_assert(barrier_id_simt == 1 && barrier_count_simt == 7); + asm volatile ("tile_loop_start_%=:" :: ); // "inner loop" along the columns of K^T @@ -318,8 +335,7 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { for (uint32_t tile_k = 0; tile_k < k_tiles + 2 /*pipeline latency*/; tile_k++) { if constexpr (DEBUG || true) { - // barrier for debugging - // threadblock_barrier(global_barrier_id, warps_per_threadblock_per_core); + threadblock_barrier(global_barrier_id, warps_per_threadblock_per_core); } // select the correct double buffer by tile iteration @@ -360,13 +376,14 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { // This is done *before* GEMM I in the software pipeline, working on the // online softmax result tile from the previous iteration - if (tile_k >= 2) // delay by 2 iters for pipelining - { - const uint32_t tile_k_ = tile_k - 2; + if (vx_warp_id() == 0 
/* warp 0 in every core */) { + if (tile_k >= 2) // delay by 2 iters for pipelining + { + const uint32_t tile_k_ = tile_k - 2; - asm volatile("gemm_pv_start_%=:" ::); + asm volatile("gemm_pv_start_%=:" ::); - if (tid_in_warpgroup == 0) { + if (tid_in_warpgroup == 0) { #if 0 if (tile_k_ == 0) { gemmini_fence(); @@ -379,114 +396,31 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { GEMMINI_CISC_CMD_I(1); } #else - // kickoff matmul - // among other things, this also configures CONFIG_BOUNDS so that the - // DMA knows the full matrix dimensions - // FIXME: perf: prevent GMEM->SMEM load for O tile - gemmini_fence(); - gemmini_fence(); - gemmini_fence(); - gemmini_fence(); - sp_tiled_matmul_full_spad_ws( - spad_addr_P_consume, spad_addr_V_consume, - /*spad_D=*/spad_addr_O, /*spad_C=*/spad_addr_O, - /*I=*/(B_ROW / DIM), /*J=*/(HEADDIM / DIM), /*K=*/(B_COL / DIM), - /*pad_I=*/0, /*pad_J=*/0, /*pad_K=*/0, - /*a_transpose=*/0, /*b_transpose=*/0, /*full_C=*/0, /*low_D=*/0, - /*acc=*/0, /*act=*/NO_ACTIVATION, /*skips=*/skips_matmul); + // kickoff matmul + // among other things, this also configures CONFIG_BOUNDS so that the + // DMA knows the full matrix dimensions + // FIXME: perf: prevent GMEM->SMEM load for O tile + gemmini_fence(); + gemmini_fence(); + gemmini_fence(); + gemmini_fence(); + sp_tiled_matmul_full_spad_ws( + spad_addr_P_consume, spad_addr_V_consume, + /*spad_D=*/spad_addr_O, /*spad_C=*/spad_addr_O, + /*I=*/(B_ROW / DIM), /*J=*/(HEADDIM / DIM), /*K=*/(B_COL / DIM), + /*pad_I=*/0, /*pad_J=*/0, /*pad_K=*/0, + /*a_transpose=*/0, /*b_transpose=*/0, /*full_C=*/0, /*low_D=*/0, + /*acc=*/0, /*act=*/NO_ACTIVATION, /*skips=*/skips_matmul); #endif - } - - // reconverge from mmio divergence - threadblock_barrier(warpgroup_id_in_cluster, - warps_per_warpgroup_per_core); - - asm volatile("gemm_pv_finish_%=:" ::); - - } - - if (tile_k >= 1) // delay by 1 iters for pipelining - { - const uint32_t tile_k_ = tile_k - 1; - - // Online softmax - // - 
thread_block_online_softmax( - smem_S_consume, smem_P_produce, tid_in_warpgroup, - threads_per_warpgroup, warpgroup_id_in_cluster, smem_scratchpad, - smem_rowmax, smem_rowsum, smem_O_row_scale); - - threadblock_barrier(warpgroup_id_in_cluster, - warps_per_warpgroup_per_core); - - if constexpr (DEBUG) { - if (warpgroup_id == 0) { - if (tile_k_ == 0) { - thread_block_copy_rowmax(smem_rowmax, gmem_tmp_e0, tid_in_warpgroup, - threads_per_warpgroup, - warpgroup_id_in_cluster); - thread_block_copy_rowmax(smem_rowsum, gmem_tmp_e2, tid_in_warpgroup, - threads_per_warpgroup, - warpgroup_id_in_cluster); - } else if (tile_k_ == 1) { - thread_block_copy_rowmax(smem_rowmax, gmem_tmp_e1, tid_in_warpgroup, - threads_per_warpgroup, - warpgroup_id_in_cluster); - thread_block_copy_rowmax(smem_rowsum, gmem_tmp_e3, tid_in_warpgroup, - threads_per_warpgroup, - warpgroup_id_in_cluster); - } - - threadblock_barrier(warpgroup_id_in_cluster, - warps_per_warpgroup_per_core); } + + // // reconverge from mmio divergence + // threadblock_barrier(warpgroup_id_in_cluster, + // warps_per_warpgroup_per_core); + + asm volatile("gemm_pv_finish_%=:" ::); } - // fence GEMM II to make sure dependency on O tile is settled - if (tid_in_warpgroup == 0) { - gemmini_fence(); - gemmini_fence(); - gemmini_fence(); - gemmini_fence(); - -#if 0 - // mvout to SMEM - // GEMMINI_CISC_CMD_I(9); - sp_tiled_matmul_full_spad_ws( - /*spad_A=*/spad_addr_P_consume, /*spad_B=*/spad_addr_V_consume, - /*spad_D=*/0, /*spad_C=*/spad_addr_O, - /*I=*/(B_ROW / DIM), /*J=*/(HEADDIM / DIM), /*K=*/(B_COL / DIM), - /*pad_I=*/0, /*pad_J=*/0, /*pad_K=*/0, - /*a_transpose=*/0, /*b_transpose=*/0, /*full_C=*/0, /*low_D=*/0, - /*acc=*/0, /*act=*/NO_ACTIVATION, /*skips=*/skips_mvout_spad); -#endif - } - - // reconverge from mmio divergence - threadblock_barrier(warpgroup_id_in_cluster, warps_per_warpgroup_per_core); - - if constexpr (DEBUG) { - gemmini_fence(); - - if (warpgroup_id == 0) { - // O after PV - if (tile_k_ == 0) { - 
thread_block_copy_tile( - smem_O, gmem_tmp_d6, tid_in_warpgroup, threads_per_warpgroup, - warpgroup_id_in_cluster); - } else if (tile_k_ == 1) { - thread_block_copy_tile( - smem_O, gmem_tmp_d7, tid_in_warpgroup, threads_per_warpgroup, - warpgroup_id_in_cluster); - } - - threadblock_barrier(warpgroup_id_in_cluster, - warps_per_warpgroup_per_core); - } - } - } - - { // GEMM I: S = Q*K // // kick off asynchronously; fence later @@ -510,6 +444,11 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { /*a_transpose=*/0, /*b_transpose=*/0, /*full_C=*/0, /*low_D=*/0, /*acc=*/0, /*act=*/NO_ACTIVATION, /*skips=*/skips_matmul); + gemmini_fence(); + gemmini_fence(); + gemmini_fence(); + gemmini_fence(); + #if 0 // TODO: speed up mvout to SMEM // loop_ws variant that skips configuring strides #define gemmini_loop_ws(I, J, K, pad_I, pad_J, pad_K, A, B, D, C, A_stride, B_stride, D_stride, C_stride, A_transpose, B_transpose, full_C, low_D, ex_accumulate, act, a_spad_id, b_spad_id, is_resadd) \ @@ -523,57 +462,186 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { } #endif } - - // reconverge from mmio divergence - threadblock_barrier(warpgroup_id_in_cluster, warps_per_warpgroup_per_core); + // // reconverge after mmio + // threadblock_barrier(warpgroup_id_in_cluster, warps_per_warpgroup_per_core); asm volatile("gemm_qk_finish_%=:" ::); - } - if (tile_k >= 1) // delay by 1 iters for pipelining - { - const uint32_t tile_k_ = tile_k - 1; + // TODO: put synchronization here with online softmax - // Oi rescale - thread_block_O_rescale( - smem_O, smem_O /*in-place*/, smem_O_row_scale, tid_in_warpgroup, - threads_per_warpgroup, warpgroup_id_in_cluster); + // data move for K and V + // + // Q stays in SMEM for the entire loop + asm volatile("move_k_v_start_%=:" ::); - // rescale-to-PV-GEMM barrier - threadblock_barrier(warpgroup_id_in_cluster, - warps_per_warpgroup_per_core); + // NOTE: Beware of race conditions; with warp specialization, we need to + // make 
sure below command code to DMA is not executed simultaneously + // from the two warpgroups (which will result in hardware fault). + // Currently the ping-pong scheduling scheme prevents that. + if (tid_in_warpgroup == 0) { + // configure GMEM addresses for K and V tiles + // load K for the next iteration + const float *gmem_K_tile = gmem_K + (B_COL * (tile_k + 1 /*runahead*/)); + // load V for the *previous* iteration; this will be consumed 2 + // iterations later + const float *gmem_V_tile = + gmem_V + (HEADDIM * B_COL * (tile_k - 1 /*dragbehind*/)); - if constexpr (DEBUG) { - if (warpgroup_id == 0) { - // O before PV - if (tile_k_ == 0) { - thread_block_copy_tile( - smem_P_produce, gmem_tmp_d2, tid_in_warpgroup, - threads_per_warpgroup, warpgroup_id_in_cluster); - thread_block_copy_tile( - smem_O, gmem_tmp_d4, tid_in_warpgroup, threads_per_warpgroup, - warpgroup_id_in_cluster); - } else if (tile_k_ == 1) { - thread_block_copy_tile( - smem_P_produce, gmem_tmp_d3, tid_in_warpgroup, - threads_per_warpgroup, warpgroup_id_in_cluster); - thread_block_copy_tile( - smem_O, gmem_tmp_d5, tid_in_warpgroup, threads_per_warpgroup, - warpgroup_id_in_cluster); + // fence mvout S to SMEM + gemmini_fence(); + ROCC_INSTRUCTION_RS1_RS2(XCUSTOM_ACC, (uint64_t)(gmem_K_tile), + (uint64_t)(gmem_V_tile), + k_LOOP_WS_CONFIG_ADDRS_AB) + // configure address strides for the DMA + // FIXME: unnecessary? 
+ GEMMINI_CISC_CMD_R((HEADDIM /*V*/ << 20) | (dim_seqlen /*KT*/ << 8) | + 8 /*k_LOOP_WS_CONFIG_STRIDES_AB*/); + gemmini_fence(); + + // do DMA + if (tile_k == 0) { + // we load (k-1)th tile for V; skip V for the 1st iteration, + sp_tiled_matmul_full_spad_ws( + spad_addr_K_produce, spad_addr_V_produce, + /*spad_D=*/0, /*spad_C=*/spad_addr_S_produce /*FIXME:bogus*/, + /*I=*/(B_ROW / DIM), /*J=*/(HEADDIM / DIM), /*K=*/(B_COL / DIM), + /*pad_I=*/0, /*pad_J=*/0, /*pad_K=*/0, + /*a_transpose=*/0, /*b_transpose=*/0, /*full_C=*/0, /*low_D=*/0, + /*acc=*/0, /*act=*/NO_ACTIVATION, /*skips=*/skips_only_a); + } else { + sp_tiled_matmul_full_spad_ws( + spad_addr_K_produce, spad_addr_V_produce, + /*spad_D=*/0, /*spad_C=*/spad_addr_S_produce /*FIXME:bogus*/, + /*I=*/(B_ROW / DIM), /*J=*/(HEADDIM / DIM), /*K=*/(B_COL / DIM), + /*pad_I=*/0, /*pad_J=*/0, /*pad_K=*/0, + /*a_transpose=*/0, /*b_transpose=*/0, /*full_C=*/0, /*low_D=*/0, + /*acc=*/0, /*act=*/NO_ACTIVATION, /*skips=*/skips); + } + gemmini_fence(); + gemmini_fence(); + gemmini_fence(); + gemmini_fence(); + } + + // threadblock_barrier(warpgroup_id_in_cluster, + // warps_per_warpgroup_per_core); + + asm volatile("move_k_v_finish_%=:" ::); + + // // intra-warpgroup barrier + // // FIXME hardcoded + // threadblock_barrier(0, 1); + + } else /* warp_id != 0 */ { + + if (tile_k >= 1) // delay by 1 iters for pipelining + { + const uint32_t tile_k_ = tile_k - 1; + + // Online softmax + // + thread_block_online_softmax( + smem_S_consume, smem_P_produce, tid_in_warpgroup_simt, + threads_per_warpgroup_simt, warpgroup_id_simt, smem_scratchpad, + smem_rowmax, smem_rowsum, smem_O_row_scale); + + threadblock_barrier(barrier_id_simt, barrier_count_simt); + + if constexpr (DEBUG) { + if (warpgroup_id == 0) { + if (tile_k_ == 0) { + thread_block_copy_rowmax(smem_rowmax, gmem_tmp_e0, + tid_in_warpgroup_simt, threads_per_warpgroup, + warpgroup_id_in_cluster); + thread_block_copy_rowmax(smem_rowsum, gmem_tmp_e2, + tid_in_warpgroup_simt, 
threads_per_warpgroup, + warpgroup_id_in_cluster); + } else if (tile_k_ == 1) { + thread_block_copy_rowmax(smem_rowmax, gmem_tmp_e1, + tid_in_warpgroup_simt, threads_per_warpgroup, + warpgroup_id_in_cluster); + thread_block_copy_rowmax(smem_rowsum, gmem_tmp_e3, + tid_in_warpgroup_simt, threads_per_warpgroup, + warpgroup_id_in_cluster); + } + + threadblock_barrier(barrier_id_simt, barrier_count_simt); } + } - threadblock_barrier(warpgroup_id_in_cluster, - warps_per_warpgroup_per_core); +#if 0 + // fence GEMM II to make sure dependency on O tile is settled + if (tid_in_warpgroup == 0) { + gemmini_fence(); + gemmini_fence(); + gemmini_fence(); + gemmini_fence(); + } + + // reconverge from mmio divergence + threadblock_barrier(warpgroup_id_in_cluster, + warps_per_warpgroup_per_core); +#endif + + if constexpr (DEBUG) { + // gemmini_fence(); + + if (warpgroup_id == 0) { + // O after PV + if (tile_k_ == 0) { + thread_block_copy_tile( + smem_O, gmem_tmp_d6, tid_in_warpgroup_simt, threads_per_warpgroup_simt, + warpgroup_id_simt); + } else if (tile_k_ == 1) { + thread_block_copy_tile( + smem_O, gmem_tmp_d7, tid_in_warpgroup_simt, threads_per_warpgroup_simt, + warpgroup_id_simt); + } + + threadblock_barrier(barrier_id_simt, barrier_count_simt); + } + } + + // Oi rescale + thread_block_O_rescale( + smem_O, smem_O /*in-place*/, smem_O_row_scale, + tid_in_warpgroup_simt, threads_per_warpgroup_simt, + warpgroup_id_simt); + + // rescale-to-PV-GEMM barrier + threadblock_barrier(barrier_id_simt, barrier_count_simt); + + if constexpr (DEBUG) { + if (warpgroup_id == 0) { + // O before PV + if (tile_k_ == 0) { + thread_block_copy_tile( + smem_P_produce, gmem_tmp_d2, tid_in_warpgroup_simt, + threads_per_warpgroup_simt, warpgroup_id_simt); + thread_block_copy_tile( + smem_O, gmem_tmp_d4, tid_in_warpgroup_simt, + threads_per_warpgroup_simt, warpgroup_id_simt); + } else if (tile_k_ == 1) { + thread_block_copy_tile( + smem_P_produce, gmem_tmp_d3, tid_in_warpgroup_simt, + 
threads_per_warpgroup_simt, warpgroup_id_simt); + thread_block_copy_tile( + smem_O, gmem_tmp_d5, tid_in_warpgroup_simt, + threads_per_warpgroup_simt, warpgroup_id_simt); + } + + threadblock_barrier(barrier_id_simt, barrier_count_simt); + } } } - } - // fence GEMM I after Oi rescale - if (tid_in_warpgroup == 0) { - gemmini_fence(); - gemmini_fence(); - gemmini_fence(); - gemmini_fence(); +#if 0 + // fence GEMM I after Oi rescale + if (tid_in_warpgroup == 0) { + gemmini_fence(); + gemmini_fence(); + gemmini_fence(); + gemmini_fence(); #if 0 // mvout to SMEM @@ -587,87 +655,32 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { /*a_transpose=*/0, /*b_transpose=*/0, /*full_C=*/0, /*low_D=*/0, /*acc=*/0, /*act=*/NO_ACTIVATION, /*skips=*/skips_mvout_spad); #endif - } - - // reconverge from mmio divergence - threadblock_barrier(warpgroup_id_in_cluster, warps_per_warpgroup_per_core); - - if constexpr (DEBUG) { - if (warpgroup_id == 0) { - if (tile_k == 0) { - thread_block_copy_tile( - smem_S_produce, gmem_tmp_d0, tid_in_warpgroup, - threads_per_warpgroup, warpgroup_id_in_cluster); - } else if (tile_k == 1) { - thread_block_copy_tile( - smem_S_produce, gmem_tmp_d1, tid_in_warpgroup, - threads_per_warpgroup, warpgroup_id_in_cluster); - } - - threadblock_barrier(warpgroup_id_in_cluster, - warps_per_warpgroup_per_core); } - } - // data move for K and V - // - // Q stays in SMEM for the entire loop - asm volatile("move_k_v_start_%=:" ::); - - // NOTE: Beware of race conditions; with warp specialization, we need to - // make sure below command code to DMA is not executed simultaneously - // from the two warpgroups (which will result in hardware fault). - // Currently the ping-pong scheduling scheme prevents that. 
- if (tid_in_warpgroup == 0) { - // configure GMEM addresses for K and V tiles - // load K for the next iteration - const float *gmem_K_tile = gmem_K + (B_COL * (tile_k + 1 /*runahead*/)); - // load V for the *previous* iteration; this will be consumed 2 - // iterations later - const float *gmem_V_tile = - gmem_V + (HEADDIM * B_COL * (tile_k - 1 /*dragbehind*/)); - - // fence mvout S to SMEM - gemmini_fence(); - ROCC_INSTRUCTION_RS1_RS2(XCUSTOM_ACC, (uint64_t)(gmem_K_tile), - (uint64_t)(gmem_V_tile), - k_LOOP_WS_CONFIG_ADDRS_AB) - // configure address strides for the DMA - // FIXME: unnecessary? - GEMMINI_CISC_CMD_R((HEADDIM /*V*/ << 20) | (dim_seqlen /*KT*/ << 8) | - 8 /*k_LOOP_WS_CONFIG_STRIDES_AB*/); - gemmini_fence(); - - // do DMA - if (tile_k == 0) { - // we load (k-1)th tile for V; skip V for the 1st iteration, - sp_tiled_matmul_full_spad_ws( - spad_addr_K_produce, spad_addr_V_produce, - /*spad_D=*/0, /*spad_C=*/spad_addr_S_produce /*FIXME:bogus*/, - /*I=*/(B_ROW / DIM), /*J=*/(HEADDIM / DIM), /*K=*/(B_COL / DIM), - /*pad_I=*/0, /*pad_J=*/0, /*pad_K=*/0, - /*a_transpose=*/0, /*b_transpose=*/0, /*full_C=*/0, /*low_D=*/0, - /*acc=*/0, /*act=*/NO_ACTIVATION, /*skips=*/skips_only_a); - } else { - sp_tiled_matmul_full_spad_ws( - spad_addr_K_produce, spad_addr_V_produce, - /*spad_D=*/0, /*spad_C=*/spad_addr_S_produce /*FIXME:bogus*/, - /*I=*/(B_ROW / DIM), /*J=*/(HEADDIM / DIM), /*K=*/(B_COL / DIM), - /*pad_I=*/0, /*pad_J=*/0, /*pad_K=*/0, - /*a_transpose=*/0, /*b_transpose=*/0, /*full_C=*/0, /*low_D=*/0, - /*acc=*/0, /*act=*/NO_ACTIVATION, /*skips=*/skips); - } - gemmini_fence(); - gemmini_fence(); - gemmini_fence(); - gemmini_fence(); - } - - threadblock_barrier(warpgroup_id_in_cluster, warps_per_warpgroup_per_core); - - asm volatile("move_k_v_finish_%=:" ::); -#if 0 + // reconverge from mmio divergence + threadblock_barrier(warpgroup_id_in_cluster, + warps_per_warpgroup_per_core); #endif + + if constexpr (DEBUG) { + if (warpgroup_id == 0) { + if (tile_k == 0) { 
+ thread_block_copy_tile( + smem_S_produce, gmem_tmp_d0, tid_in_warpgroup_simt, + threads_per_warpgroup_simt, warpgroup_id_simt); + } else if (tile_k == 1) { + thread_block_copy_tile( + smem_S_produce, gmem_tmp_d1, tid_in_warpgroup_simt, + threads_per_warpgroup_simt, warpgroup_id_simt); + } + + threadblock_barrier(barrier_id_simt, barrier_count_simt); + } + } + + // intra-warpgroup barrier + threadblock_barrier(barrier_id_simt, barrier_count_simt); + } } asm volatile ("tile_loop_finish_%=:" :: ); From a17edac8759a131c5ababf9dd492be97b1fe14bf Mon Sep 17 00:00:00 2001 From: Hansung Kim Date: Mon, 9 Sep 2024 17:02:05 -0700 Subject: [PATCH 30/50] flash: Fix barrier stall with DEBUG Verified for up to P_expected on 2nd iter; O_before_PV is partially correct --- .../regression/flash_attention/flash_impl.hpp | 10 +-- .../flash_attention/kernel.gemmini.cpp | 63 ++++++++++--------- 2 files changed, 40 insertions(+), 33 deletions(-) diff --git a/tests/regression/flash_attention/flash_impl.hpp b/tests/regression/flash_attention/flash_impl.hpp index 8aac50ab..bd4aee9d 100644 --- a/tests/regression/flash_attention/flash_impl.hpp +++ b/tests/regression/flash_attention/flash_impl.hpp @@ -89,8 +89,9 @@ inline void thread_block_copy_rowmax(const float *src, float *dest, dest[offset] = src[offset]; } - threadblock_barrier(threadblock_id_in_cluster, - warps_per_threadblock_per_core); + // threadblock_barrier(threadblock_id_in_cluster, + // warps_per_threadblock_per_core); + threadblock_barrier(1, 7); asm volatile("threadblock_copy_rowmax_finish_%=:" ::); } @@ -127,8 +128,9 @@ inline void thread_block_copy_tile(const float *src, float *dest, dest[gmem_offset] = src[smem_offset]; } - threadblock_barrier(threadblock_id_in_cluster, - warps_per_threadblock_per_core); + // threadblock_barrier(threadblock_id_in_cluster, + // warps_per_threadblock_per_core); + threadblock_barrier(1, 7); } asm volatile("threadblock_copy_tile_finish_%=:" ::); diff --git 
a/tests/regression/flash_attention/kernel.gemmini.cpp b/tests/regression/flash_attention/kernel.gemmini.cpp index 9d611ba2..f85755e1 100644 --- a/tests/regression/flash_attention/kernel.gemmini.cpp +++ b/tests/regression/flash_attention/kernel.gemmini.cpp @@ -8,7 +8,7 @@ #include "gemmini_mmio.h" #include "flash_impl.hpp" -constexpr bool DEBUG = false; +constexpr bool DEBUG = true; static_assert(GEMMINI_DMA && !WARP_SPECIALIZED, "GEMMINI_DMA should be set and WARP_SPECIALIZED unset"); @@ -528,8 +528,10 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { asm volatile("move_k_v_finish_%=:" ::); - // // intra-warpgroup barrier - // // FIXME hardcoded + // NOTE: cannot put barrier here; thread 1-7 in warp 0 will skip the + // branch and call this barrier earlier than when thread 0 finishes. + // Since tmask is not considered, that will be a barrier resolve done too + // early // threadblock_barrier(0, 1); } else /* warp_id != 0 */ { @@ -538,6 +540,24 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { { const uint32_t tile_k_ = tile_k - 1; + if constexpr (DEBUG) { + // verify S = Q*K + + if (warpgroup_id == 0) { + if (tile_k == 0) { + thread_block_copy_tile( + smem_S_produce, gmem_tmp_d0, tid_in_warpgroup_simt, + threads_per_warpgroup_simt, warpgroup_id_simt); + } else if (tile_k == 1) { + thread_block_copy_tile( + smem_S_produce, gmem_tmp_d1, tid_in_warpgroup_simt, + threads_per_warpgroup_simt, warpgroup_id_simt); + } + + threadblock_barrier(barrier_id_simt, barrier_count_simt); + } + } + // Online softmax // thread_block_online_softmax( @@ -550,25 +570,26 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { if constexpr (DEBUG) { if (warpgroup_id == 0) { if (tile_k_ == 0) { - thread_block_copy_rowmax(smem_rowmax, gmem_tmp_e0, - tid_in_warpgroup_simt, threads_per_warpgroup, - warpgroup_id_in_cluster); - thread_block_copy_rowmax(smem_rowsum, gmem_tmp_e2, - tid_in_warpgroup_simt, threads_per_warpgroup, - warpgroup_id_in_cluster); 
+ thread_block_copy_rowmax( + smem_rowmax, gmem_tmp_e0, tid_in_warpgroup_simt, + threads_per_warpgroup_simt, warpgroup_id_simt); + thread_block_copy_rowmax( + smem_rowsum, gmem_tmp_e2, tid_in_warpgroup_simt, + threads_per_warpgroup_simt, warpgroup_id_simt); } else if (tile_k_ == 1) { thread_block_copy_rowmax(smem_rowmax, gmem_tmp_e1, - tid_in_warpgroup_simt, threads_per_warpgroup, - warpgroup_id_in_cluster); + tid_in_warpgroup_simt, threads_per_warpgroup_simt, + warpgroup_id_simt); thread_block_copy_rowmax(smem_rowsum, gmem_tmp_e3, - tid_in_warpgroup_simt, threads_per_warpgroup, - warpgroup_id_in_cluster); + tid_in_warpgroup_simt, threads_per_warpgroup_simt, + warpgroup_id_simt); } threadblock_barrier(barrier_id_simt, barrier_count_simt); } } + // FIXME: put synchronization with GEMM II here #if 0 // fence GEMM II to make sure dependency on O tile is settled if (tid_in_warpgroup == 0) { @@ -662,22 +683,6 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { warps_per_warpgroup_per_core); #endif - if constexpr (DEBUG) { - if (warpgroup_id == 0) { - if (tile_k == 0) { - thread_block_copy_tile( - smem_S_produce, gmem_tmp_d0, tid_in_warpgroup_simt, - threads_per_warpgroup_simt, warpgroup_id_simt); - } else if (tile_k == 1) { - thread_block_copy_tile( - smem_S_produce, gmem_tmp_d1, tid_in_warpgroup_simt, - threads_per_warpgroup_simt, warpgroup_id_simt); - } - - threadblock_barrier(barrier_id_simt, barrier_count_simt); - } - } - // intra-warpgroup barrier threadblock_barrier(barrier_id_simt, barrier_count_simt); } From 88760596cb5e8bfd73a8c37ef296408c45b21099 Mon Sep 17 00:00:00 2001 From: Hansung Kim Date: Mon, 9 Sep 2024 17:18:59 -0700 Subject: [PATCH 31/50] flash: Remove bogus mvout to SMEM code --- .../flash_attention/kernel.gemmini.cpp | 27 ++----------------- 1 file changed, 2 insertions(+), 25 deletions(-) diff --git a/tests/regression/flash_attention/kernel.gemmini.cpp b/tests/regression/flash_attention/kernel.gemmini.cpp index f85755e1..884762d7 
100644 --- a/tests/regression/flash_attention/kernel.gemmini.cpp +++ b/tests/regression/flash_attention/kernel.gemmini.cpp @@ -449,18 +449,6 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { gemmini_fence(); gemmini_fence(); -#if 0 // TODO: speed up mvout to SMEM - // loop_ws variant that skips configuring strides -#define gemmini_loop_ws(I, J, K, pad_I, pad_J, pad_K, A, B, D, C, A_stride, B_stride, D_stride, C_stride, A_transpose, B_transpose, full_C, low_D, ex_accumulate, act, a_spad_id, b_spad_id, is_resadd) \ - { \ - ROCC_INSTRUCTION_RS1_RS2(XCUSTOM_ACC, ((uint64_t)(pad_K) << 32) | ((uint64_t)(pad_J) << 16) | (uint64_t)(pad_I), ((uint64_t)(K) << 32) | ((uint64_t)(J) << 16) | (uint64_t)(I), k_LOOP_WS_CONFIG_BOUNDS) \ - ROCC_INSTRUCTION_RS1_RS2(XCUSTOM_ACC, A, B, k_LOOP_WS_CONFIG_ADDRS_AB) \ - ROCC_INSTRUCTION_RS1_RS2(XCUSTOM_ACC, D, C, k_LOOP_WS_CONFIG_ADDRS_DC) \ - ROCC_INSTRUCTION_RS1_RS2(XCUSTOM_ACC, A_stride, B_stride, k_LOOP_WS_CONFIG_STRIDES_AB) \ - ROCC_INSTRUCTION_RS1_RS2(XCUSTOM_ACC, D_stride, C_stride, k_LOOP_WS_CONFIG_STRIDES_DC) \ - ROCC_INSTRUCTION_RS1_RS2(XCUSTOM_ACC, ((uint64_t)(a_spad_id) << 18) | ((uint64_t)(b_spad_id) << 16) | ((uint64_t)(act) << 8) | ((low_D) << 2) | ((full_C) << 1) | (ex_accumulate), ((is_resadd) << 2) | ((B_transpose) << 1) | (A_transpose), k_LOOP_WS) \ - } -#endif } // // reconverge after mmio // threadblock_barrier(warpgroup_id_in_cluster, warps_per_warpgroup_per_core); @@ -487,6 +475,7 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { const float *gmem_V_tile = gmem_V + (HEADDIM * B_COL * (tile_k - 1 /*dragbehind*/)); +#if 0 // fence mvout S to SMEM gemmini_fence(); ROCC_INSTRUCTION_RS1_RS2(XCUSTOM_ACC, (uint64_t)(gmem_K_tile), @@ -496,6 +485,7 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { // FIXME: unnecessary? 
GEMMINI_CISC_CMD_R((HEADDIM /*V*/ << 20) | (dim_seqlen /*KT*/ << 8) | 8 /*k_LOOP_WS_CONFIG_STRIDES_AB*/); +#endif gemmini_fence(); // do DMA @@ -663,19 +653,6 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { gemmini_fence(); gemmini_fence(); gemmini_fence(); - -#if 0 - // mvout to SMEM - // GEMMINI_CISC_CMD_I(9); - sp_tiled_matmul_full_spad_ws( - /*spad_A=*/spad_addr_Q /*bogus*/, - /*spad_B=*/spad_addr_K_consume /*bogus*/, - /*spad_D=*/0, /*spad_C=*/spad_addr_S_produce, - /*I=*/(B_ROW / DIM), /*J=*/(B_COL / DIM), /*K=*/(HEADDIM / DIM), - /*pad_I=*/0, /*pad_J=*/0, /*pad_K=*/0, - /*a_transpose=*/0, /*b_transpose=*/0, /*full_C=*/0, /*low_D=*/0, - /*acc=*/0, /*act=*/NO_ACTIVATION, /*skips=*/skips_mvout_spad); -#endif } // reconverge from mmio divergence From 90e03894fc69cbf6559f00597a424e84a3a3501d Mon Sep 17 00:00:00 2001 From: Hansung Kim Date: Tue, 10 Sep 2024 13:37:32 -0700 Subject: [PATCH 32/50] flash: Add flag in SMEM for dependency check on O TODO: results unverified. Stalls O rescale until GEMM II finishes. --- .../regression/flash_attention/flash_impl.hpp | 14 ++++++- .../flash_attention/kernel.gemmini.cpp | 39 +++++++++++++------ 2 files changed, 41 insertions(+), 12 deletions(-) diff --git a/tests/regression/flash_attention/flash_impl.hpp b/tests/regression/flash_attention/flash_impl.hpp index bd4aee9d..eb1a43bb 100644 --- a/tests/regression/flash_attention/flash_impl.hpp +++ b/tests/regression/flash_attention/flash_impl.hpp @@ -176,6 +176,19 @@ __attribute__((always_inline)) inline void thread_block_online_softmax( for (int row_offset = 0; row_offset < B_ROW; row_offset += warps_in_threadblock) { const uint32_t row = row_offset + warp_id; + // if the number of warps doesn't exactly divide the number of rows, + // early-exit to prevent out-of-bounds access + // if (row >= B_ROW) { + // // WARNING: the number of barrier calls have to exactly match that in the + // // outside of the branch to prevent stalls!! FIXME better proof this. 
+ // threadblock_barrier(1, 7); + // threadblock_barrier(1, 7); + // threadblock_barrier(1, 7); + // threadblock_barrier(1, 7); + // threadblock_barrier(1, 7); + // threadblock_barrier(1, 7); + // continue; + // } const uint32_t first_thread_offset = B_COL * row; // rowmax @@ -334,7 +347,6 @@ __attribute__((always_inline)) inline void thread_block_online_softmax( asm volatile("flashattn_exp_p_end_%=:" ::); - // threadblock_barrier(threadblock_id_in_cluster, // warps_per_threadblock_per_core); threadblock_barrier(1, 7); diff --git a/tests/regression/flash_attention/kernel.gemmini.cpp b/tests/regression/flash_attention/kernel.gemmini.cpp index 884762d7..35a8cdf6 100644 --- a/tests/regression/flash_attention/kernel.gemmini.cpp +++ b/tests/regression/flash_attention/kernel.gemmini.cpp @@ -112,6 +112,7 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { constexpr uint32_t smem_S1_offset = smem_V1_offset + smem_V_size * sizeof(float); constexpr uint32_t smem_P0_offset = smem_K0_offset + smem_K_size * sizeof(float); constexpr uint32_t smem_P1_offset = smem_K1_offset + smem_K_size * sizeof(float); + // reversed! 
constexpr uint32_t smem_O0_offset = smem_P1_offset + smem_QK_size * sizeof(float); constexpr uint32_t smem_O1_offset = smem_P0_offset + smem_QK_size * sizeof(float); // unused @@ -158,6 +159,8 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { smem_cursor += smem_scratchpad_size; float *smem_scratchpad_1 = smem_cursor; smem_cursor += smem_scratchpad_size; + uint32_t *smem_O_flag = reinterpret_cast(smem_cursor); + smem_cursor += 1 /* 4Byte */; static_assert(sizeof(elem_t) == sizeof(float)); constexpr uint32_t spad_addr_factor = DIM * sizeof(elem_t); @@ -332,7 +335,9 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { // "inner loop" along the columns of K^T const uint32_t k_tiles = (dim_seqlen / B_COL); - for (uint32_t tile_k = 0; tile_k < k_tiles + 2 /*pipeline latency*/; + for (uint32_t tile_k = 0; + tile_k < + (1 /*FIXME: for perf measurement*/ * k_tiles) + 2 /*pipeline latency*/; tile_k++) { if constexpr (DEBUG || true) { threadblock_barrier(global_barrier_id, warps_per_threadblock_per_core); @@ -371,16 +376,16 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { const auto spad_addr_O = spad_addr_O0; // NOTE: there's only single O tile asm volatile ("dbuf_sel_end_%=:" :: ); - // GEMM II: O = O + P*V - // -------------------- - // This is done *before* GEMM I in the software pipeline, working on the - // online softmax result tile from the previous iteration - if (vx_warp_id() == 0 /* warp 0 in every core */) { if (tile_k >= 2) // delay by 2 iters for pipelining { const uint32_t tile_k_ = tile_k - 2; + // GEMM II: O = O + P*V + // -------------------- + // This is done *before* GEMM I in the software pipeline, working on the + // online softmax result tile from the previous iteration + asm volatile("gemm_pv_start_%=:" ::); if (tid_in_warpgroup == 0) { @@ -427,11 +432,16 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { asm volatile("gemm_qk_start_%=:" ::); if (tid_in_warpgroup == 0) { + // fence to GEMM 
II completion gemmini_fence(); gemmini_fence(); gemmini_fence(); gemmini_fence(); + // signal that GEMM II is finished to O rescale step + *smem_O_flag = 1; + vx_fence(); + // 0,2,.: opcode 0 (quartile 0/2, no accum) // 1,3,.: opcode 3 (quartile 1/3, no accum) // const uint32_t opcode = 3 * (tile_k & 1); @@ -448,7 +458,6 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { gemmini_fence(); gemmini_fence(); gemmini_fence(); - } // // reconverge after mmio // threadblock_barrier(warpgroup_id_in_cluster, warps_per_warpgroup_per_core); @@ -534,11 +543,11 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { // verify S = Q*K if (warpgroup_id == 0) { - if (tile_k == 0) { + if (tile_k_ == 0) { thread_block_copy_tile( smem_S_produce, gmem_tmp_d0, tid_in_warpgroup_simt, threads_per_warpgroup_simt, warpgroup_id_simt); - } else if (tile_k == 1) { + } else if (tile_k_ == 1) { thread_block_copy_tile( smem_S_produce, gmem_tmp_d1, tid_in_warpgroup_simt, threads_per_warpgroup_simt, warpgroup_id_simt); @@ -579,9 +588,17 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { } } - // FIXME: put synchronization with GEMM II here + // check flag to make sure GEMM II finished and read-after-write + // dependency on O tile is settled for rescale + if (tid_in_warpgroup_simt == 0) { + while ((*smem_O_flag) != 1) + ; + // set it back to 0 for the next tile iteration + *smem_O_flag = 0; + vx_fence(); + } + #if 0 - // fence GEMM II to make sure dependency on O tile is settled if (tid_in_warpgroup == 0) { gemmini_fence(); gemmini_fence(); From ccddd0bcc9c8e7f400605cf159ce61bc018af114 Mon Sep 17 00:00:00 2001 From: Hansung Kim Date: Tue, 10 Sep 2024 15:54:17 -0700 Subject: [PATCH 33/50] sgemm_impl: Remove unused FLEXIBLE_LAYOUT --- tests/regression/sgemm_tcore/sgemm_impl.hpp | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/tests/regression/sgemm_tcore/sgemm_impl.hpp b/tests/regression/sgemm_tcore/sgemm_impl.hpp index 7e95dbdc..d52f9b0a 
100644 --- a/tests/regression/sgemm_tcore/sgemm_impl.hpp +++ b/tests/regression/sgemm_tcore/sgemm_impl.hpp @@ -73,8 +73,7 @@ static_assert(WMITER * WNITER * TCM * TCN * NUM_WARPS * CORES_PER_CLUSTER == #define TRANSPOSE_AT_CONSUME 0 #define GEMMINI_DMA 1 -#define GEMMINI_DMA_FAST 0 -#define GEMMINI_DMA_FLEXIBLE_LAYOUT 1 +#define GEMMINI_DMA_FAST 1 #if SMEM_SIZE == 0x4000 #define SMEM_ADDR_Q0 ((float * const) 0xff000000) #define SMEM_ADDR_Q1 ((float * const) 0xff001000) From 2152c80ffd06f725103973dfd03381c6dc607b0b Mon Sep 17 00:00:00 2001 From: Hansung Kim Date: Tue, 10 Sep 2024 18:05:01 -0700 Subject: [PATCH 34/50] sgemm_impl: Add missing reconvergence barrier after mmio --- tests/regression/sgemm_tcore/sgemm_impl.hpp | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/tests/regression/sgemm_tcore/sgemm_impl.hpp b/tests/regression/sgemm_tcore/sgemm_impl.hpp index d52f9b0a..f0998873 100644 --- a/tests/regression/sgemm_tcore/sgemm_impl.hpp +++ b/tests/regression/sgemm_tcore/sgemm_impl.hpp @@ -1022,6 +1022,10 @@ inline void thread_block_gemm(const T *A, const T *B, float *C, /*acc=*/0, /*act=*/NO_ACTIVATION, /*skips=*/skips) #endif } + + // reconverge after mmio divergence + threadblock_barrier(threadblock_id_in_cluster, + warps_per_threadblock_per_core); #else // move A if constexpr (!TRANSPOSE_AT_PRODUCE) { @@ -1038,9 +1042,6 @@ inline void thread_block_gemm(const T *A, const T *B, float *C, load_tile_to_smem(dim_n, block_n, block_k, B, local_b, tid_in_threadblock); - - threadblock_barrier(threadblock_id_in_cluster, - warps_per_threadblock_per_core); #endif // consumer code: SMEM->RF and compute From 28b2eaec8f6c5c454094fd48c8fc7598c747b903 Mon Sep 17 00:00:00 2001 From: Hansung Kim Date: Tue, 10 Sep 2024 18:29:26 -0700 Subject: [PATCH 35/50] sgemm_gemmini_dma: Fix tile size to (128,64,128) --- tests/regression/sgemm_gemmini_dma/kernel.cpp | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git 
a/tests/regression/sgemm_gemmini_dma/kernel.cpp b/tests/regression/sgemm_gemmini_dma/kernel.cpp index 6391e500..6b72c92c 100644 --- a/tests/regression/sgemm_gemmini_dma/kernel.cpp +++ b/tests/regression/sgemm_gemmini_dma/kernel.cpp @@ -8,9 +8,9 @@ // fp16 16x16 #define TILE_M 128 -#define TILE_N 128 +#define TILE_N 64 #define TILE_K 128 -#define BOUND_INST 0x800080008ULL +#define BOUND_INST 0x800040008ULL #define NUM_THREADS_IN_CLUSTER 512 // fp32 8x8 @@ -195,4 +195,4 @@ int main() { vx_spawn_tasks_contiguous(grid_size, (vx_spawn_tasks_cb)kernel_body, arg); #endif return 0; -} \ No newline at end of file +} From dc746272fb1a36e7508c47c5be572ce78e739f83 Mon Sep 17 00:00:00 2001 From: Hansung Kim Date: Tue, 10 Sep 2024 22:53:35 -0700 Subject: [PATCH 36/50] flash: Conditionally enable GEMM II fence code, fix tile_k for DEBUG --- .../flash_attention/kernel.gemmini.cpp | 19 +++++++++++++------ 1 file changed, 13 insertions(+), 6 deletions(-) diff --git a/tests/regression/flash_attention/kernel.gemmini.cpp b/tests/regression/flash_attention/kernel.gemmini.cpp index 35a8cdf6..51993b21 100644 --- a/tests/regression/flash_attention/kernel.gemmini.cpp +++ b/tests/regression/flash_attention/kernel.gemmini.cpp @@ -8,7 +8,7 @@ #include "gemmini_mmio.h" #include "flash_impl.hpp" -constexpr bool DEBUG = true; +constexpr bool DEBUG = false; static_assert(GEMMINI_DMA && !WARP_SPECIALIZED, "GEMMINI_DMA should be set and WARP_SPECIALIZED unset"); @@ -438,9 +438,11 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { gemmini_fence(); gemmini_fence(); +#ifdef FENCE_GEMM_II // signal that GEMM II is finished to O rescale step *smem_O_flag = 1; vx_fence(); +#endif // 0,2,.: opcode 0 (quartile 0/2, no accum) // 1,3,.: opcode 3 (quartile 1/3, no accum) @@ -540,8 +542,10 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { const uint32_t tile_k_ = tile_k - 1; if constexpr (DEBUG) { - // verify S = Q*K + gemmini_fence(); + gemmini_fence(); + // verify S = Q*K if 
(warpgroup_id == 0) { if (tile_k_ == 0) { thread_block_copy_tile( @@ -588,6 +592,7 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { } } +#ifdef FENCE_GEMM_II // check flag to make sure GEMM II finished and read-after-write // dependency on O tile is settled for rescale if (tid_in_warpgroup_simt == 0) { @@ -597,6 +602,7 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { *smem_O_flag = 0; vx_fence(); } +#endif #if 0 if (tid_in_warpgroup == 0) { @@ -612,15 +618,16 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { #endif if constexpr (DEBUG) { - // gemmini_fence(); - if (warpgroup_id == 0) { + gemmini_fence(); + gemmini_fence(); + // O after PV - if (tile_k_ == 0) { + if (tile_k_ == 1 /*wait until GEMM II finshes */) { thread_block_copy_tile( smem_O, gmem_tmp_d6, tid_in_warpgroup_simt, threads_per_warpgroup_simt, warpgroup_id_simt); - } else if (tile_k_ == 1) { + } else if (tile_k_ == 2) { thread_block_copy_tile( smem_O, gmem_tmp_d7, tid_in_warpgroup_simt, threads_per_warpgroup_simt, warpgroup_id_simt); From ba66d2c2bd23b2a54e50d0c6fae7874ae14047b9 Mon Sep 17 00:00:00 2001 From: Hansung Kim Date: Wed, 11 Sep 2024 00:01:56 -0700 Subject: [PATCH 37/50] sgemm_impl: barrier dumb dumb --- tests/regression/sgemm_tcore/sgemm_impl.hpp | 3 +++ 1 file changed, 3 insertions(+) diff --git a/tests/regression/sgemm_tcore/sgemm_impl.hpp b/tests/regression/sgemm_tcore/sgemm_impl.hpp index f0998873..d1b9d76e 100644 --- a/tests/regression/sgemm_tcore/sgemm_impl.hpp +++ b/tests/regression/sgemm_tcore/sgemm_impl.hpp @@ -1042,6 +1042,9 @@ inline void thread_block_gemm(const T *A, const T *B, float *C, load_tile_to_smem(dim_n, block_n, block_k, B, local_b, tid_in_threadblock); + + threadblock_barrier(threadblock_id_in_cluster, + warps_per_threadblock_per_core); #endif // consumer code: SMEM->RF and compute From 068d48534efa7a352c00270806006cf8679e7071 Mon Sep 17 00:00:00 2001 From: Hansung Kim Date: Wed, 11 Sep 2024 00:55:36 -0700 Subject: 
[PATCH 38/50] flash: Swap S1/S0 to avoid GEMM II - softmax bank conflict + remove spurrious fences to better overlap GEMM I and DMA --- .../flash_attention/kernel.gemmini.cpp | 76 ++++++++----------- 1 file changed, 33 insertions(+), 43 deletions(-) diff --git a/tests/regression/flash_attention/kernel.gemmini.cpp b/tests/regression/flash_attention/kernel.gemmini.cpp index 51993b21..a583feb7 100644 --- a/tests/regression/flash_attention/kernel.gemmini.cpp +++ b/tests/regression/flash_attention/kernel.gemmini.cpp @@ -108,8 +108,10 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { constexpr uint32_t smem_K1_offset = smem_quart3; constexpr uint32_t smem_V0_offset = smem_Q0_offset + smem_Q_size * sizeof(float); constexpr uint32_t smem_V1_offset = smem_Q1_offset + smem_Q_size * sizeof(float); - constexpr uint32_t smem_S0_offset = smem_V0_offset + smem_V_size * sizeof(float); - constexpr uint32_t smem_S1_offset = smem_V1_offset + smem_V_size * sizeof(float); + // put S1/S0 with V0/V1 so that softmax and GEMM-II doesn't cause bank + // conflicts + constexpr uint32_t smem_S0_offset = smem_V1_offset + smem_V_size * sizeof(float); + constexpr uint32_t smem_S1_offset = smem_V0_offset + smem_V_size * sizeof(float); constexpr uint32_t smem_P0_offset = smem_K0_offset + smem_K_size * sizeof(float); constexpr uint32_t smem_P1_offset = smem_K1_offset + smem_K_size * sizeof(float); // reversed! 
@@ -177,14 +179,16 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { constexpr uint32_t spad_addr_O0 = smem_O0_offset / spad_addr_factor; constexpr uint32_t spad_addr_O1 = smem_O1_offset / spad_addr_factor; + constexpr uint32_t global_barrier_id = NUM_WARPS - 1; // arbitrary + static_assert(warps_per_threadblock_per_core == NUM_WARPS); + // initialize rowmax/rowsum values in sharedmem thread_block_init_sharedmem(tid_in_warpgroup, threads_per_warpgroup, smem_O0, smem_rowmax_0, smem_rowsum_0, smem_O_row_scale_0); thread_block_init_sharedmem(tid_in_warpgroup, threads_per_warpgroup, smem_O1, smem_rowmax_1, smem_rowsum_1, smem_O_row_scale_1); - constexpr uint32_t global_barrier_id = NUM_WARPS - 1; // arbitrary - static_assert(warps_per_threadblock_per_core == NUM_WARPS); + threadblock_barrier(global_barrier_id, warps_per_threadblock_per_core); static_assert(!GEMMINI_DMA || Q_IS_K_MAJOR, "DMA code assumes Q matrix is stored K-major"); @@ -209,22 +213,19 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { loop_matmul_skips(/*skip_lda=*/1, /*skip_ldb=*/1, /*skip_ldd=*/0, /*skip_ex=*/0, /*skip_stc=*/1); - if constexpr (GEMMINI_DMA) { - if (tid_in_warpgroup == 0) { - gemmini_extended_config_ex(WEIGHT_STATIONARY, 0, 0, 1, 0, 0); + if (tid_in_warpgroup == 0) { + gemmini_extended_config_ex(WEIGHT_STATIONARY, 0, 0, 1, 0, 0); - // configure DMA with GMEM address strides - // Q matrix - gemmini_extended3_config_ld(HEADDIM * sizeof(elem_t), MVIN_SCALE_IDENTITY, - false, 0); - // K matrix - gemmini_extended3_config_ld(dim_seqlen * sizeof(elem_t), MVIN_SCALE_IDENTITY, - false, 1); - // configure DMA for Q*K store - gemmini_extended_config_st(B_COL * sizeof(elem_t), 0, - MVIN_SCALE_IDENTITY); - gemmini_fence(); - } + // configure DMA with GMEM address strides + // Q matrix + gemmini_extended3_config_ld(HEADDIM * sizeof(elem_t), MVIN_SCALE_IDENTITY, + false, 0); + // K matrix + gemmini_extended3_config_ld(dim_seqlen * sizeof(elem_t), + MVIN_SCALE_IDENTITY, 
false, 1); + // configure DMA for Q*K store + gemmini_extended_config_st(B_COL * sizeof(elem_t), 0, MVIN_SCALE_IDENTITY); + gemmini_fence(); } // NOTE about barriers: Placing barriers around thread-divergent branches may @@ -319,8 +320,6 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { // threadblock_barrier(warpgroup_id_in_cluster, warps_per_warpgroup_per_core); // } - threadblock_barrier(global_barrier_id, warps_per_threadblock_per_core); - constexpr uint32_t threads_per_warpgroup_simt = threads_per_warpgroup - CORES_PER_CLUSTER * NUM_THREADS /*warp 0, 4, 8, 12*/; @@ -337,7 +336,7 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { const uint32_t k_tiles = (dim_seqlen / B_COL); for (uint32_t tile_k = 0; tile_k < - (1 /*FIXME: for perf measurement*/ * k_tiles) + 2 /*pipeline latency*/; + (4 /*FIXME: for perf measurement*/ * k_tiles) + 2 /*pipeline latency*/; tile_k++) { if constexpr (DEBUG || true) { threadblock_barrier(global_barrier_id, warps_per_threadblock_per_core); @@ -456,28 +455,17 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { /*a_transpose=*/0, /*b_transpose=*/0, /*full_C=*/0, /*low_D=*/0, /*acc=*/0, /*act=*/NO_ACTIVATION, /*skips=*/skips_matmul); - gemmini_fence(); - gemmini_fence(); - gemmini_fence(); - gemmini_fence(); - } - // // reconverge after mmio - // threadblock_barrier(warpgroup_id_in_cluster, warps_per_warpgroup_per_core); + // gemmini_fence(); + // gemmini_fence(); + // gemmini_fence(); + // gemmini_fence(); + asm volatile("gemm_qk_finish_%=:" ::); - asm volatile("gemm_qk_finish_%=:" ::); + // data move for K and V + // + // Q stays in SMEM for the entire loop + asm volatile("move_k_v_start_%=:" ::); - // TODO: put synchronization here with online softmax - - // data move for K and V - // - // Q stays in SMEM for the entire loop - asm volatile("move_k_v_start_%=:" ::); - - // NOTE: Beware of race conditions; with warp specialization, we need to - // make sure below command code to DMA is not 
executed simultaneously - // from the two warpgroups (which will result in hardware fault). - // Currently the ping-pong scheduling scheme prevents that. - if (tid_in_warpgroup == 0) { // configure GMEM addresses for K and V tiles // load K for the next iteration const float *gmem_K_tile = gmem_K + (B_COL * (tile_k + 1 /*runahead*/)); @@ -497,7 +485,7 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { GEMMINI_CISC_CMD_R((HEADDIM /*V*/ << 20) | (dim_seqlen /*KT*/ << 8) | 8 /*k_LOOP_WS_CONFIG_STRIDES_AB*/); #endif - gemmini_fence(); + // gemmini_fence(); // do DMA if (tile_k == 0) { @@ -518,6 +506,8 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { /*a_transpose=*/0, /*b_transpose=*/0, /*full_C=*/0, /*low_D=*/0, /*acc=*/0, /*act=*/NO_ACTIVATION, /*skips=*/skips); } + + // fence everything before going to the next tile gemmini_fence(); gemmini_fence(); gemmini_fence(); From 18cf0e73cd6ed9a8b14031e4afe9faf0f48cd651 Mon Sep 17 00:00:00 2001 From: Hansung Kim Date: Wed, 11 Sep 2024 00:56:09 -0700 Subject: [PATCH 39/50] flash: Add early return for warp-indivisible row iter --- .../regression/flash_attention/flash_impl.hpp | 53 +++++++++++++++---- 1 file changed, 42 insertions(+), 11 deletions(-) diff --git a/tests/regression/flash_attention/flash_impl.hpp b/tests/regression/flash_attention/flash_impl.hpp index eb1a43bb..410c5f4f 100644 --- a/tests/regression/flash_attention/flash_impl.hpp +++ b/tests/regression/flash_attention/flash_impl.hpp @@ -8,6 +8,8 @@ #define B_COL 64 #define HEADDIM 64 +#define ROW_REMAINDER_LOGIC + constexpr uint32_t ROWMAX_SETS = 3; constexpr bool WARP_SPECIALIZED = false; @@ -56,6 +58,14 @@ inline void thread_block_init_sharedmem(const uint32_t tid_in_threadblock, for (int row_offset = 0; row_offset < B_COL; row_offset += warps_in_threadblock) { const uint32_t row = row_offset + warp_id; +#ifdef ROW_REMAINDER_LOGIC + if (row >= B_ROW) { + // WARNING: the number of barrier calls have to exactly match that in the + 
// outside of the branch to prevent stalls!! FIXME better proof this. + continue; + } +#endif + uint32_t thread_offset = HEADDIM * row + tid_in_warp; constexpr uint32_t per_row_iter = HEADDIM / NUM_THREADS; const float one = 0.0f; @@ -114,6 +124,14 @@ inline void thread_block_copy_tile(const float *src, float *dest, for (int row_offset = 0; row_offset < dim_row; row_offset += warps_in_threadblock) { const uint32_t row = row_offset + warp_id; +#ifdef ROW_REMAINDER_LOGIC + if (row >= B_ROW) { + // WARNING: the number of barrier calls have to exactly match that in the + // outside of the branch to prevent stalls!! FIXME better proof this. + threadblock_barrier(1, 7); + continue; + } +#endif constexpr uint32_t per_row_iter = dim_col / NUM_THREADS; #pragma GCC unroll @@ -176,19 +194,21 @@ __attribute__((always_inline)) inline void thread_block_online_softmax( for (int row_offset = 0; row_offset < B_ROW; row_offset += warps_in_threadblock) { const uint32_t row = row_offset + warp_id; +#ifdef ROW_REMAINDER_LOGIC // if the number of warps doesn't exactly divide the number of rows, // early-exit to prevent out-of-bounds access - // if (row >= B_ROW) { - // // WARNING: the number of barrier calls have to exactly match that in the - // // outside of the branch to prevent stalls!! FIXME better proof this. - // threadblock_barrier(1, 7); - // threadblock_barrier(1, 7); - // threadblock_barrier(1, 7); - // threadblock_barrier(1, 7); - // threadblock_barrier(1, 7); - // threadblock_barrier(1, 7); - // continue; - // } + if (row >= B_ROW) { + // WARNING: the number of barrier calls have to exactly match that in the + // outside of the branch to prevent stalls!! FIXME better proof this. 
+ threadblock_barrier(1, 7); + threadblock_barrier(1, 7); + threadblock_barrier(1, 7); + threadblock_barrier(1, 7); + threadblock_barrier(1, 7); + threadblock_barrier(1, 7); + continue; + } +#endif const uint32_t first_thread_offset = B_COL * row; // rowmax @@ -456,6 +476,14 @@ __attribute__((always_inline)) inline void thread_block_O_rescale( for (int row_offset = 0; row_offset < B_ROW; row_offset += warps_in_threadblock) { const uint32_t row = row_offset + warp_id; +#ifdef ROW_REMAINDER_LOGIC + if (row >= B_ROW) { + // WARNING: the number of barrier calls have to exactly match that in the + // outside of the branch to prevent stalls!! FIXME better proof this. + continue; + } +#endif + constexpr uint32_t per_row_iter = HEADDIM / NUM_THREADS; // Oi rescale @@ -474,6 +502,9 @@ __attribute__((always_inline)) inline void thread_block_O_rescale( } } + // reconverge after warp divergence + threadblock_barrier(1, 7); + asm volatile("thread_block_O_rescale_finish_%=:" ::); } From d69707f686a04f9e12152beb7f5a6848a79f21f0 Mon Sep 17 00:00:00 2001 From: Hansung Kim Date: Wed, 11 Sep 2024 19:24:06 -0700 Subject: [PATCH 40/50] flash: Enable GEMM II fence; Pull 1st KV move out of the loop --- .../flash_attention/kernel.gemmini.cpp | 35 ++++++++++--------- 1 file changed, 18 insertions(+), 17 deletions(-) diff --git a/tests/regression/flash_attention/kernel.gemmini.cpp b/tests/regression/flash_attention/kernel.gemmini.cpp index a583feb7..63d3bd56 100644 --- a/tests/regression/flash_attention/kernel.gemmini.cpp +++ b/tests/regression/flash_attention/kernel.gemmini.cpp @@ -8,7 +8,9 @@ #include "gemmini_mmio.h" #include "flash_impl.hpp" -constexpr bool DEBUG = false; +#define FENCE_GEMM_II + +constexpr bool DEBUG = true; static_assert(GEMMINI_DMA && !WARP_SPECIALIZED, "GEMMINI_DMA should be set and WARP_SPECIALIZED unset"); @@ -290,11 +292,13 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { /*I=*/(B_ROW / DIM), /*J=*/(B_COL / DIM), /*K=*/(HEADDIM / DIM), 
/*pad_I=*/0, /*pad_J=*/0, /*pad_K=*/0, /*a_transpose=*/0, /*b_transpose=*/0, /*full_C=*/0, /*low_D=*/0, - /*acc=*/0, /*act=*/NO_ACTIVATION, /*skips=*/skips_only_a); + /*acc=*/0, /*act=*/NO_ACTIVATION, /*skips=*/skips); #endif gemmini_fence(); gemmini_fence(); + gemmini_fence(); + gemmini_fence(); // re-configure DMA for K and V load that will later happen in the loop // GMEM addr stride for K @@ -480,27 +484,27 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { ROCC_INSTRUCTION_RS1_RS2(XCUSTOM_ACC, (uint64_t)(gmem_K_tile), (uint64_t)(gmem_V_tile), k_LOOP_WS_CONFIG_ADDRS_AB) +#endif // configure address strides for the DMA // FIXME: unnecessary? GEMMINI_CISC_CMD_R((HEADDIM /*V*/ << 20) | (dim_seqlen /*KT*/ << 8) | 8 /*k_LOOP_WS_CONFIG_STRIDES_AB*/); -#endif // gemmini_fence(); // do DMA if (tile_k == 0) { // we load (k-1)th tile for V; skip V for the 1st iteration, - sp_tiled_matmul_full_spad_ws( - spad_addr_K_produce, spad_addr_V_produce, - /*spad_D=*/0, /*spad_C=*/spad_addr_S_produce /*FIXME:bogus*/, - /*I=*/(B_ROW / DIM), /*J=*/(HEADDIM / DIM), /*K=*/(B_COL / DIM), - /*pad_I=*/0, /*pad_J=*/0, /*pad_K=*/0, - /*a_transpose=*/0, /*b_transpose=*/0, /*full_C=*/0, /*low_D=*/0, - /*acc=*/0, /*act=*/NO_ACTIVATION, /*skips=*/skips_only_a); + // sp_tiled_matmul_full_spad_ws( + // spad_addr_K_produce, spad_addr_V_produce, + // /*spad_D=*/0, /*spad_C=*/0, + // /*I=*/(B_ROW / DIM), /*J=*/(HEADDIM / DIM), /*K=*/(B_COL / DIM), + // /*pad_I=*/0, /*pad_J=*/0, /*pad_K=*/0, + // /*a_transpose=*/0, /*b_transpose=*/0, /*full_C=*/0, /*low_D=*/0, + // /*acc=*/0, /*act=*/NO_ACTIVATION, /*skips=*/skips_only_a); } else { sp_tiled_matmul_full_spad_ws( spad_addr_K_produce, spad_addr_V_produce, - /*spad_D=*/0, /*spad_C=*/spad_addr_S_produce /*FIXME:bogus*/, + /*spad_D=*/0, /*spad_C=*/0, /*I=*/(B_ROW / DIM), /*J=*/(HEADDIM / DIM), /*K=*/(B_COL / DIM), /*pad_I=*/0, /*pad_J=*/0, /*pad_K=*/0, /*a_transpose=*/0, /*b_transpose=*/0, /*full_C=*/0, /*low_D=*/0, @@ -532,18 +536,15 @@ 
void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { const uint32_t tile_k_ = tile_k - 1; if constexpr (DEBUG) { - gemmini_fence(); - gemmini_fence(); - - // verify S = Q*K + // verify S = Q*K before softmax if (warpgroup_id == 0) { if (tile_k_ == 0) { thread_block_copy_tile( - smem_S_produce, gmem_tmp_d0, tid_in_warpgroup_simt, + smem_S_consume, gmem_tmp_d0, tid_in_warpgroup_simt, threads_per_warpgroup_simt, warpgroup_id_simt); } else if (tile_k_ == 1) { thread_block_copy_tile( - smem_S_produce, gmem_tmp_d1, tid_in_warpgroup_simt, + smem_S_consume, gmem_tmp_d1, tid_in_warpgroup_simt, threads_per_warpgroup_simt, warpgroup_id_simt); } From b5916f3f0718eb0232c5c104bd9985ffea1ad6d5 Mon Sep 17 00:00:00 2001 From: Hansung Kim Date: Wed, 11 Sep 2024 22:08:06 -0700 Subject: [PATCH 41/50] flash: Fix hardcoded barrier for tcore; move tcore-specific flags --- .../regression/flash_attention/flash_impl.hpp | 122 ++++++++++++------ tests/regression/flash_attention/kernel.cpp | 7 +- 2 files changed, 91 insertions(+), 38 deletions(-) diff --git a/tests/regression/flash_attention/flash_impl.hpp b/tests/regression/flash_attention/flash_impl.hpp index 410c5f4f..47e21c70 100644 --- a/tests/regression/flash_attention/flash_impl.hpp +++ b/tests/regression/flash_attention/flash_impl.hpp @@ -11,11 +11,8 @@ #define ROW_REMAINDER_LOGIC constexpr uint32_t ROWMAX_SETS = 3; -constexpr bool WARP_SPECIALIZED = false; - -constexpr uint32_t DEV_FAKE_SMEM_START_ADDR = 0xf0000000; - -constexpr bool Q_IS_K_MAJOR = true; +constexpr bool WARP_SPECIALIZED = true; +constexpr bool TENSOR_CORE = true; // temporary safety stop for wrong configs static_assert(NUM_CORES == 4); @@ -99,9 +96,12 @@ inline void thread_block_copy_rowmax(const float *src, float *dest, dest[offset] = src[offset]; } - // threadblock_barrier(threadblock_id_in_cluster, - // warps_per_threadblock_per_core); - threadblock_barrier(1, 7); + if constexpr (!TENSOR_CORE) { + threadblock_barrier(1, 7); + } else { + 
threadblock_barrier(threadblock_id_in_cluster, + warps_per_threadblock_per_core); + } asm volatile("threadblock_copy_rowmax_finish_%=:" ::); } @@ -128,7 +128,12 @@ inline void thread_block_copy_tile(const float *src, float *dest, if (row >= B_ROW) { // WARNING: the number of barrier calls have to exactly match that in the // outside of the branch to prevent stalls!! FIXME better proof this. - threadblock_barrier(1, 7); + if constexpr (!TENSOR_CORE) { + threadblock_barrier(1, 7); + } else { + threadblock_barrier(threadblock_id_in_cluster, + warps_per_threadblock_per_core); + } continue; } #endif @@ -146,9 +151,12 @@ inline void thread_block_copy_tile(const float *src, float *dest, dest[gmem_offset] = src[smem_offset]; } - // threadblock_barrier(threadblock_id_in_cluster, - // warps_per_threadblock_per_core); - threadblock_barrier(1, 7); + if constexpr (!TENSOR_CORE) { + threadblock_barrier(1, 7); + } else { + threadblock_barrier(threadblock_id_in_cluster, + warps_per_threadblock_per_core); + } } asm volatile("threadblock_copy_tile_finish_%=:" ::); @@ -200,12 +208,28 @@ __attribute__((always_inline)) inline void thread_block_online_softmax( if (row >= B_ROW) { // WARNING: the number of barrier calls have to exactly match that in the // outside of the branch to prevent stalls!! FIXME better proof this. 
- threadblock_barrier(1, 7); - threadblock_barrier(1, 7); - threadblock_barrier(1, 7); - threadblock_barrier(1, 7); - threadblock_barrier(1, 7); - threadblock_barrier(1, 7); + if constexpr (!TENSOR_CORE) { + threadblock_barrier(1, 7); + threadblock_barrier(1, 7); + threadblock_barrier(1, 7); + threadblock_barrier(1, 7); + threadblock_barrier(1, 7); + threadblock_barrier(1, 7); + } else { + threadblock_barrier(threadblock_id_in_cluster, + warps_per_threadblock_per_core); + threadblock_barrier(threadblock_id_in_cluster, + warps_per_threadblock_per_core); + threadblock_barrier(threadblock_id_in_cluster, + warps_per_threadblock_per_core); + threadblock_barrier(threadblock_id_in_cluster, + warps_per_threadblock_per_core); + threadblock_barrier(threadblock_id_in_cluster, + warps_per_threadblock_per_core); + threadblock_barrier(threadblock_id_in_cluster, + warps_per_threadblock_per_core); + } + continue; } #endif @@ -271,9 +295,12 @@ __attribute__((always_inline)) inline void thread_block_online_softmax( warp_smem[tid_in_warp] = per_thread_max; // sync writes to warp_smem - // threadblock_barrier(threadblock_id_in_cluster, - // warps_per_threadblock_per_core); - threadblock_barrier(1, 7); + if constexpr (!TENSOR_CORE) { + threadblock_barrier(1, 7); + } else { + threadblock_barrier(threadblock_id_in_cluster, + warps_per_threadblock_per_core); + } // #define PARALLEL_ROWMAX #ifndef PARALLEL_ROWMAX @@ -323,9 +350,13 @@ __attribute__((always_inline)) inline void thread_block_online_softmax( #endif // PARALLEL_ROWMAX #endif // DUMB_ROWMAX - // threadblock_barrier(threadblock_id_in_cluster, - // warps_per_threadblock_per_core); - threadblock_barrier(1, 7); + if constexpr (!TENSOR_CORE) { + threadblock_barrier(1, 7); + } else { + threadblock_barrier(threadblock_id_in_cluster, + warps_per_threadblock_per_core); + } + // broadcast prev rowmax to all threads in the warp // NOTE: memory consistency is a little sketchy here @@ -367,9 +398,12 @@ __attribute__((always_inline)) inline 
void thread_block_online_softmax( asm volatile("flashattn_exp_p_end_%=:" ::); - // threadblock_barrier(threadblock_id_in_cluster, - // warps_per_threadblock_per_core); - threadblock_barrier(1, 7); + if constexpr (!TENSOR_CORE) { + threadblock_barrier(1, 7); + } else { + threadblock_barrier(threadblock_id_in_cluster, + warps_per_threadblock_per_core); + } // rowsum // @@ -395,9 +429,12 @@ __attribute__((always_inline)) inline void thread_block_online_softmax( warp_smem[tid_in_warp] = per_thread_sum; // sync writes to warp_smem - // threadblock_barrier(threadblock_id_in_cluster, - // warps_per_threadblock_per_core); - threadblock_barrier(1, 7); + if constexpr (!TENSOR_CORE) { + threadblock_barrier(1, 7); + } else { + threadblock_barrier(threadblock_id_in_cluster, + warps_per_threadblock_per_core); + } // 0-th thread collects all other thread's values in the warp if (tid_in_warp == 0) { @@ -425,9 +462,12 @@ __attribute__((always_inline)) inline void thread_block_online_softmax( asm volatile("flashattn_rowsum_end_%=:" ::); - // threadblock_barrier(threadblock_id_in_cluster, - // warps_per_threadblock_per_core); - threadblock_barrier(1, 7); + if constexpr (!TENSOR_CORE) { + threadblock_barrier(1, 7); + } else { + threadblock_barrier(threadblock_id_in_cluster, + warps_per_threadblock_per_core); + } // compute Oi rescale factor // FIXME: parallelize this across threads @@ -451,9 +491,12 @@ __attribute__((always_inline)) inline void thread_block_online_softmax( asm volatile("flashattn_rescale_factor_end_%=:" ::); - // threadblock_barrier(threadblock_id_in_cluster, - // warps_per_threadblock_per_core); - threadblock_barrier(1, 7); + if constexpr (!TENSOR_CORE) { + threadblock_barrier(1, 7); + } else { + threadblock_barrier(threadblock_id_in_cluster, + warps_per_threadblock_per_core); + } } asm volatile("thread_block_online_softmax_finish_%=:" ::); @@ -503,7 +546,12 @@ __attribute__((always_inline)) inline void thread_block_O_rescale( } // reconverge after warp divergence - 
threadblock_barrier(1, 7); + if constexpr (!TENSOR_CORE) { + threadblock_barrier(1, 7); + } else { + threadblock_barrier(threadblock_id_in_cluster, + warps_per_threadblock_per_core); + } asm volatile("thread_block_O_rescale_finish_%=:" ::); } diff --git a/tests/regression/flash_attention/kernel.cpp b/tests/regression/flash_attention/kernel.cpp index 9eee2b60..1c9b015d 100644 --- a/tests/regression/flash_attention/kernel.cpp +++ b/tests/regression/flash_attention/kernel.cpp @@ -8,6 +8,9 @@ #include "gemmini_mmio.h" #include "flash_impl.hpp" +constexpr bool DEBUG = false; +constexpr bool Q_IS_K_MAJOR = true; + void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { // @perf: All threads are running these compute whose result is mostly same // across the threadblock @@ -88,6 +91,7 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { uint8_t *smem_per_threadblock = reinterpret_cast( DEV_SMEM_START_ADDR); float *smem_cursor = reinterpret_cast(smem_per_threadblock); + // constexpr uint32_t DEV_FAKE_SMEM_START_ADDR = 0xf0000000; // float *smem_cursor = reinterpret_cast(DEV_FAKE_SMEM_START_ADDR); float *smem_Q0 = smem_cursor; smem_cursor += smem_Q_size; @@ -310,7 +314,8 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { // "inner loop" along the columns of K^T const uint32_t k_tiles = (dim_seqlen / B_COL); - for (uint32_t tile_k = 0; tile_k < k_tiles; tile_k++) { + for (uint32_t tile_k = 0; tile_k < (4 /* for perf measurement */ * k_tiles); + tile_k++) { // float *smem_P_produce = (tile_k % 2) ? smem_P0 : smem_P1; // float *smem_P_consume = (tile_k % 2) ? smem_P1 : smem_P0; // float *smem_V_produce = (tile_k % 2) ? 
smem_V0 : smem_V1; From be15cffbf39451099937519a8078ab7b7db5f233 Mon Sep 17 00:00:00 2001 From: Hansung Kim Date: Thu, 12 Sep 2024 14:25:33 -0700 Subject: [PATCH 42/50] flash: Revert to gemmini config, remove DEBUG and unnecessary checks --- tests/regression/flash_attention/flash_impl.hpp | 4 ++-- tests/regression/flash_attention/kernel.gemmini.cpp | 8 ++------ 2 files changed, 4 insertions(+), 8 deletions(-) diff --git a/tests/regression/flash_attention/flash_impl.hpp b/tests/regression/flash_attention/flash_impl.hpp index 47e21c70..93dc3cc9 100644 --- a/tests/regression/flash_attention/flash_impl.hpp +++ b/tests/regression/flash_attention/flash_impl.hpp @@ -11,8 +11,8 @@ #define ROW_REMAINDER_LOGIC constexpr uint32_t ROWMAX_SETS = 3; -constexpr bool WARP_SPECIALIZED = true; -constexpr bool TENSOR_CORE = true; +constexpr bool WARP_SPECIALIZED = false; +constexpr bool TENSOR_CORE = false; // temporary safety stop for wrong configs static_assert(NUM_CORES == 4); diff --git a/tests/regression/flash_attention/kernel.gemmini.cpp b/tests/regression/flash_attention/kernel.gemmini.cpp index 63d3bd56..ac3788d4 100644 --- a/tests/regression/flash_attention/kernel.gemmini.cpp +++ b/tests/regression/flash_attention/kernel.gemmini.cpp @@ -10,7 +10,7 @@ #define FENCE_GEMM_II -constexpr bool DEBUG = true; +constexpr bool DEBUG = false; static_assert(GEMMINI_DMA && !WARP_SPECIALIZED, "GEMMINI_DMA should be set and WARP_SPECIALIZED unset"); @@ -192,9 +192,6 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { threadblock_barrier(global_barrier_id, warps_per_threadblock_per_core); - static_assert(!GEMMINI_DMA || Q_IS_K_MAJOR, - "DMA code assumes Q matrix is stored K-major"); - // skip everything except DMA in the loop FSM constexpr uint32_t skips = loop_matmul_skips(/*skip_lda=*/0, /*skip_ldb=*/0, /*skip_ldd=*/1, @@ -339,8 +336,7 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { // "inner loop" along the columns of K^T const uint32_t k_tiles = 
(dim_seqlen / B_COL); for (uint32_t tile_k = 0; - tile_k < - (4 /*FIXME: for perf measurement*/ * k_tiles) + 2 /*pipeline latency*/; + tile_k < (4 /*for perf measurement*/ * k_tiles) + 2 /*pipeline latency*/; tile_k++) { if constexpr (DEBUG || true) { threadblock_barrier(global_barrier_id, warps_per_threadblock_per_core); From b9cafd63727c7272c8a73baa81ddb5d960ec1dd5 Mon Sep 17 00:00:00 2001 From: Hansung Kim Date: Wed, 18 Sep 2024 18:10:29 -0700 Subject: [PATCH 43/50] idle: unused const --- tests/regression/idle/kernel.cpp | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/regression/idle/kernel.cpp b/tests/regression/idle/kernel.cpp index ccd9bcc5..12ca862d 100644 --- a/tests/regression/idle/kernel.cpp +++ b/tests/regression/idle/kernel.cpp @@ -7,7 +7,6 @@ #include "gemmini_mmio.h" #define NUM_CLUSTERS 1 -#define NUM_THREADS_IN_CLUSTER 256 #define HW_TID() ({uint32_t gtid; asm volatile ("csrr %0, mhartid" : "=r" (gtid)); gtid;}) From d0ef06cec1429c8d9c970c8296cdb1f04796783d Mon Sep 17 00:00:00 2001 From: Hansung Kim Date: Thu, 19 Sep 2024 20:36:03 -0700 Subject: [PATCH 44/50] flash: Complete Q_IS_K_MAJOR code for GEMM II --- tests/regression/flash_attention/kernel.cpp | 27 ++++++++++++++++++--- 1 file changed, 23 insertions(+), 4 deletions(-) diff --git a/tests/regression/flash_attention/kernel.cpp b/tests/regression/flash_attention/kernel.cpp index 1c9b015d..1d88b4de 100644 --- a/tests/regression/flash_attention/kernel.cpp +++ b/tests/regression/flash_attention/kernel.cpp @@ -672,9 +672,10 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { static_assert(B_ROW / 2 == 32, "tile size assumption for warp-specialization not met"); - // assumes smem_P is K-major float *smem_P_half0 = smem_P; - float *smem_P_half1 = smem_P + (B_ROW / 2) * B_COL; + float *smem_P_half1 = (Q_IS_K_MAJOR || GEMMINI_DMA) + ? 
smem_P + (B_ROW / 2) * B_COL + : smem_P + (B_ROW / 2); float *smem_O_half0 = smem_O; float *smem_O_half1 = smem_O + (B_ROW / 2) * HEADDIM; @@ -707,7 +708,7 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { tid_in_warpgroup, threads_per_warpgroup, warpgroups_per_cluster, warpgroup_id_in_cluster); } - } else { + } else if constexpr (Q_IS_K_MAJOR) { thread_block_gemm_single_tile< float, MemLayout::K_major, MemLayout::MN_major, B_ROW / 2, HEADDIM, B_COL, /*leading_dim_a=*/0, /*leading_dim_b=*/0, @@ -716,6 +717,15 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { smem_P_half0, smem_V, smem_O_half0 /*load accum*/, smem_O_half0, tid_in_warpgroup, threads_per_warpgroup, warpgroups_per_cluster, warpgroup_id_in_cluster); + } else { + thread_block_gemm_single_tile< + float, MemLayout::MN_major, MemLayout::MN_major, B_ROW / 2, HEADDIM, + B_COL, /*leading_dim_a=*/B_ROW, /*leading_dim_b=*/0, + /*load_accum=*/true, + /*write_to_smem=*/true>( + smem_P_half0, smem_V, smem_O_half0 /*load accum*/, smem_O_half0, + tid_in_warpgroup, threads_per_warpgroup, warpgroups_per_cluster, + warpgroup_id_in_cluster); } initialize_accum_regs<0>(); @@ -745,7 +755,7 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { tid_in_warpgroup, threads_per_warpgroup, warpgroups_per_cluster, warpgroup_id_in_cluster); } - } else { + } else if constexpr (Q_IS_K_MAJOR) { thread_block_gemm_single_tile< float, MemLayout::K_major, MemLayout::MN_major, B_ROW / 2, HEADDIM, B_COL, /*leading_dim_a=*/0, /*leading_dim_b=*/0, @@ -754,6 +764,15 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { smem_P_half1, smem_V, smem_O_half1 /*load accum*/, smem_O_half1, tid_in_warpgroup, threads_per_warpgroup, warpgroups_per_cluster, warpgroup_id_in_cluster); + } else { + thread_block_gemm_single_tile< + float, MemLayout::MN_major, MemLayout::MN_major, B_ROW / 2, HEADDIM, + B_COL, /*leading_dim_a=*/B_ROW, /*leading_dim_b=*/0, + /*load_accum=*/true, + /*write_to_smem=*/true>( 
+ smem_P_half1, smem_V, smem_O_half1 /*load accum*/, smem_O_half1, + tid_in_warpgroup, threads_per_warpgroup, warpgroups_per_cluster, + warpgroup_id_in_cluster); } } From 221d5f75c2b889b02e8cf9e286cf5662025f7035 Mon Sep 17 00:00:00 2001 From: Hansung Kim Date: Thu, 19 Sep 2024 21:31:11 -0700 Subject: [PATCH 45/50] flash: Optimize smem alloc for tcore for 8banks Divide into first half & last half for warpgroup 0 & 1, and allocate Q/K and P/V in different banks for parallel access. --- .../regression/flash_attention/flash_impl.hpp | 4 +- tests/regression/flash_attention/kernel.cpp | 138 ++++++++++-------- 2 files changed, 79 insertions(+), 63 deletions(-) diff --git a/tests/regression/flash_attention/flash_impl.hpp b/tests/regression/flash_attention/flash_impl.hpp index 93dc3cc9..47e21c70 100644 --- a/tests/regression/flash_attention/flash_impl.hpp +++ b/tests/regression/flash_attention/flash_impl.hpp @@ -11,8 +11,8 @@ #define ROW_REMAINDER_LOGIC constexpr uint32_t ROWMAX_SETS = 3; -constexpr bool WARP_SPECIALIZED = false; -constexpr bool TENSOR_CORE = false; +constexpr bool WARP_SPECIALIZED = true; +constexpr bool TENSOR_CORE = true; // temporary safety stop for wrong configs static_assert(NUM_CORES == 4); diff --git a/tests/regression/flash_attention/kernel.cpp b/tests/regression/flash_attention/kernel.cpp index 1d88b4de..3c2d463c 100644 --- a/tests/regression/flash_attention/kernel.cpp +++ b/tests/regression/flash_attention/kernel.cpp @@ -11,6 +11,9 @@ constexpr bool DEBUG = false; constexpr bool Q_IS_K_MAJOR = true; +// temporary safety stop +static_assert(TENSOR_CORE && WARP_SPECIALIZED); + void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { // @perf: All threads are running these compute whose result is mostly same // across the threadblock @@ -90,80 +93,78 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { "flashattention kernel assumes 1 threadblock occupancy per cluster"); uint8_t *smem_per_threadblock = reinterpret_cast( 
DEV_SMEM_START_ADDR); - float *smem_cursor = reinterpret_cast(smem_per_threadblock); - // constexpr uint32_t DEV_FAKE_SMEM_START_ADDR = 0xf0000000; - // float *smem_cursor = reinterpret_cast(DEV_FAKE_SMEM_START_ADDR); - float *smem_Q0 = smem_cursor; - smem_cursor += smem_Q_size; - float *smem_Q1 = smem_cursor; - smem_cursor += smem_Q_size; - float *smem_K0 = smem_cursor; - smem_cursor += smem_K_size; - float *smem_K1 = smem_cursor; - smem_cursor += smem_K_size; - float *smem_V0 = smem_cursor; - smem_cursor += smem_V_size; - float *smem_V1 = smem_cursor; - smem_cursor += smem_V_size; - float *smem_S0 = smem_cursor; - smem_cursor += smem_QK_size; - float *smem_S1 = smem_cursor; - smem_cursor += smem_QK_size; - float *smem_P0 = smem_S0; // in-place update - float *smem_P1 = smem_S1; // in-place update - float *smem_O0 = smem_cursor; - smem_cursor += smem_O_size; - float *smem_O1 = smem_cursor; - smem_cursor += smem_O_size; + constexpr uint32_t smem_start = DEV_SMEM_START_ADDR; + constexpr uint32_t smem_octet0 = 0 * (SMEM_SIZE / 8); + constexpr uint32_t smem_octet1 = 1 * (SMEM_SIZE / 8); + constexpr uint32_t smem_octet2 = 2 * (SMEM_SIZE / 8); + constexpr uint32_t smem_octet3 = 3 * (SMEM_SIZE / 8); + constexpr uint32_t smem_octet4 = 4 * (SMEM_SIZE / 8); + constexpr uint32_t smem_octet5 = 5 * (SMEM_SIZE / 8); + constexpr uint32_t smem_octet6 = 6 * (SMEM_SIZE / 8); + constexpr uint32_t smem_octet7 = 7 * (SMEM_SIZE / 8); - // NOTE: this has to match with smem_* - static_assert(sizeof(elem_t) == sizeof(float)); - constexpr uint32_t spad_addr_factor = DIM * sizeof(elem_t); - constexpr uint32_t spad_addr_Q0 = 0; - constexpr uint32_t spad_addr_Q1 = - spad_addr_Q0 + (smem_Q_size * sizeof(float) / spad_addr_factor); - constexpr uint32_t spad_addr_K0 = - spad_addr_Q1 + (smem_Q_size * sizeof(float) / spad_addr_factor); - constexpr uint32_t spad_addr_K1 = - spad_addr_K0 + (smem_K_size * sizeof(float) / spad_addr_factor); - constexpr uint32_t spad_addr_V0 = - spad_addr_K1 + 
(smem_K_size * sizeof(float) / spad_addr_factor); - constexpr uint32_t spad_addr_V1 = - spad_addr_V0 + (smem_V_size * sizeof(float) / spad_addr_factor); - constexpr uint32_t spad_addr_S0 = - spad_addr_V1 + (smem_V_size * sizeof(float) / spad_addr_factor); - constexpr uint32_t spad_addr_S1 = - spad_addr_S0 + (smem_QK_size * sizeof(float) / spad_addr_factor); + // allocation strategy: since the two warpgroups only access *0 and *1 + // buffers each, allocate *0 in the first half of SMEM, and *1 in the latter + // half + // at the same time, make sure Q and K are in different banks so that they + // can be accessed in parallel for GEMM; same for P and V + constexpr uint32_t smem_Q0_offset = smem_octet0; + constexpr uint32_t smem_Q1_offset = smem_octet4; + constexpr uint32_t smem_K0_offset = smem_octet1; + constexpr uint32_t smem_K1_offset = smem_octet5; + constexpr uint32_t smem_V0_offset = smem_K0_offset + smem_K_size * sizeof(float); + constexpr uint32_t smem_V1_offset = smem_K1_offset + smem_K_size * sizeof(float); + constexpr uint32_t smem_S0_offset = smem_octet2; + constexpr uint32_t smem_S1_offset = smem_octet6; + constexpr uint32_t smem_P0_offset = smem_Q0_offset + smem_Q_size * sizeof(float); + constexpr uint32_t smem_P1_offset = smem_Q1_offset + smem_Q_size * sizeof(float); + constexpr uint32_t smem_O0_offset = smem_octet3; + constexpr uint32_t smem_O1_offset = smem_octet7; + + float *smem_Q0 = reinterpret_cast(smem_start + smem_Q0_offset); + float *smem_Q1 = reinterpret_cast(smem_start + smem_Q1_offset); + float *smem_K0 = reinterpret_cast(smem_start + smem_K0_offset); + float *smem_K1 = reinterpret_cast(smem_start + smem_K1_offset); + float *smem_V0 = reinterpret_cast(smem_start + smem_V0_offset); + float *smem_V1 = reinterpret_cast(smem_start + smem_V1_offset); + float *smem_S0 = reinterpret_cast(smem_start + smem_S0_offset); + float *smem_S1 = reinterpret_cast(smem_start + smem_S1_offset); + float *smem_P0 = reinterpret_cast(smem_start + smem_P0_offset); 
+ float *smem_P1 = reinterpret_cast(smem_start + smem_P1_offset); + float *smem_O0 = reinterpret_cast(smem_start + smem_O0_offset); + float *smem_O1 = reinterpret_cast(smem_start + smem_O1_offset); // allocate rowmax/rowsum storage at the end of the sharedmem address space constexpr uint32_t smem_rowmax_size = B_ROW * ROWMAX_SETS; constexpr uint32_t smem_rowsum_size = B_ROW; constexpr uint32_t smem_O_row_scale_size = B_ROW; - // FIXME: dangerous - smem_cursor = reinterpret_cast(0xff038000); - float *smem_rowmax_0 = smem_cursor; - smem_cursor += smem_rowmax_size; - float *smem_rowmax_1 = smem_cursor; - smem_cursor += smem_rowmax_size; - float *smem_rowsum_0 = smem_cursor; - smem_cursor += smem_rowsum_size; - float *smem_rowsum_1 = smem_cursor; - smem_cursor += smem_rowsum_size; - float *smem_O_row_scale_0 = smem_cursor; - smem_cursor += smem_O_row_scale_size; - float *smem_O_row_scale_1 = smem_cursor; - smem_cursor += smem_O_row_scale_size; + float *smem_cursor_0 = smem_O0 + smem_O_size; + float *smem_cursor_1 = smem_O1 + smem_O_size; + // // FIXME: dangerous + // smem_cursor = reinterpret_cast(0xff038000); + float *smem_rowmax_0 = smem_cursor_0; + smem_cursor_0 += smem_rowmax_size; + float *smem_rowmax_1 = smem_cursor_1; + smem_cursor_1 += smem_rowmax_size; + float *smem_rowsum_0 = smem_cursor_0; + smem_cursor_0 += smem_rowsum_size; + float *smem_rowsum_1 = smem_cursor_1; + smem_cursor_1 += smem_rowsum_size; + float *smem_O_row_scale_0 = smem_cursor_0; + smem_cursor_0 += smem_O_row_scale_size; + float *smem_O_row_scale_1 = smem_cursor_1; + smem_cursor_1 += smem_O_row_scale_size; // sharedmem "scratchpad" area to put temporary data, e.g. 
for tree reduction // in rowsum // NOTE: out-of bounds is not checked constexpr uint32_t smem_scratchpad_size = threads_per_warpgroup * 2 /*arbitrary slack*/; - float *smem_scratchpad_0 = smem_cursor; - smem_cursor += smem_scratchpad_size; - float *smem_scratchpad_1 = smem_cursor; - smem_cursor += smem_scratchpad_size; + float *smem_scratchpad_0 = smem_cursor_0; + smem_cursor_0 += smem_scratchpad_size; + float *smem_scratchpad_1 = smem_cursor_1; + smem_cursor_1 += smem_scratchpad_size; // select the correct buffer by warpgroup float *smem_Q = (warpgroup_id % 2) ? smem_Q1 : smem_Q0; @@ -179,6 +180,21 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { float *smem_scratchpad = (warpgroup_id % 2) ? smem_scratchpad_1 : smem_scratchpad_0; + static_assert(sizeof(elem_t) == sizeof(float)); + constexpr uint32_t spad_addr_factor = DIM * sizeof(elem_t); + constexpr uint32_t spad_addr_Q0 = smem_Q0_offset / spad_addr_factor; + constexpr uint32_t spad_addr_Q1 = smem_Q1_offset / spad_addr_factor; + constexpr uint32_t spad_addr_K0 = smem_K0_offset / spad_addr_factor; + constexpr uint32_t spad_addr_K1 = smem_K1_offset / spad_addr_factor; + constexpr uint32_t spad_addr_V0 = smem_V0_offset / spad_addr_factor; + constexpr uint32_t spad_addr_V1 = smem_V1_offset / spad_addr_factor; + constexpr uint32_t spad_addr_S0 = smem_S0_offset / spad_addr_factor; + constexpr uint32_t spad_addr_S1 = smem_S1_offset / spad_addr_factor; + constexpr uint32_t spad_addr_P0 = smem_P0_offset / spad_addr_factor; + constexpr uint32_t spad_addr_P1 = smem_P1_offset / spad_addr_factor; + constexpr uint32_t spad_addr_O0 = smem_O0_offset / spad_addr_factor; + constexpr uint32_t spad_addr_O1 = smem_O1_offset / spad_addr_factor; + const auto spad_addr_Q = (warpgroup_id % 2) ? spad_addr_Q1 : spad_addr_Q0; const auto spad_addr_K = (warpgroup_id % 2) ? spad_addr_K1 : spad_addr_K0; const auto spad_addr_V = (warpgroup_id % 2) ? 
spad_addr_V1 : spad_addr_V0; From 6f6ee5616f056feaa8f07c10cfdc23174e4bcab3 Mon Sep 17 00:00:00 2001 From: Hansung Kim Date: Wed, 2 Oct 2024 10:57:45 -0700 Subject: [PATCH 46/50] Add convergent attribute to vx_barrier Note this attribute is only supported by Clang, so this will only be applied to the kernel binary but not runtime. --- kernel/include/gemmini_mmio.h | 2 ++ kernel/include/vx_intrinsics.h | 1 + 2 files changed, 3 insertions(+) diff --git a/kernel/include/gemmini_mmio.h b/kernel/include/gemmini_mmio.h index ebd3a5ba..ed55236c 100644 --- a/kernel/include/gemmini_mmio.h +++ b/kernel/include/gemmini_mmio.h @@ -9,6 +9,8 @@ // #define SMEM_SIZE 0x4000 // 64KB // #define SMEM_SIZE 0x10000 +// 128KB +// #define SMEM_SIZE 0x20000 // 256KB #define SMEM_SIZE 0x40000 diff --git a/kernel/include/vx_intrinsics.h b/kernel/include/vx_intrinsics.h index f6cfbf58..f51601f7 100644 --- a/kernel/include/vx_intrinsics.h +++ b/kernel/include/vx_intrinsics.h @@ -149,6 +149,7 @@ inline void vx_join(unsigned stack_ptr) { } // Warp Barrier +__attribute__((convergent)) inline void vx_barrier(unsigned barried_id, unsigned num_warps) { asm volatile (".insn r %0, 4, 0, x0, %1, %2" :: "i"(RISCV_CUSTOM0), "r"(barried_id), "r"(num_warps)); } From db2789bf23be0c607def642f66913e9ce41e31ea Mon Sep 17 00:00:00 2001 From: Hansung Kim Date: Wed, 2 Oct 2024 10:59:14 -0700 Subject: [PATCH 47/50] Add asm label for cisc compute --- tests/regression/sgemm_gemmini_dma/kernel.cpp | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tests/regression/sgemm_gemmini_dma/kernel.cpp b/tests/regression/sgemm_gemmini_dma/kernel.cpp index 6b72c92c..6c8b3249 100644 --- a/tests/regression/sgemm_gemmini_dma/kernel.cpp +++ b/tests/regression/sgemm_gemmini_dma/kernel.cpp @@ -118,8 +118,10 @@ void thread_block_matmul_gemmini(kernel_arg_t *__UNIFORM__ arg, } if (tile_k == 0) { + asm volatile("cisc_start_%=:" ::); gemmini_fence(); GEMMINI_CISC_CMD_I(0); + asm volatile("cisc_end_%=:" ::); } else if (tile_k & 1) { 
gemmini_fence(); GEMMINI_CISC_CMD_I(2); From 34902946262989b453d696b6502d9442842a14bd Mon Sep 17 00:00:00 2001 From: Hansung Kim Date: Wed, 2 Oct 2024 11:01:23 -0700 Subject: [PATCH 48/50] generate_matrix.py: switch to fp16 rand, generate row-major A --- tests/kernel/tensor/generate_matrix.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/tests/kernel/tensor/generate_matrix.py b/tests/kernel/tensor/generate_matrix.py index d54ece46..c9255465 100644 --- a/tests/kernel/tensor/generate_matrix.py +++ b/tests/kernel/tensor/generate_matrix.py @@ -46,7 +46,7 @@ def pack_fp16_by_row(array): if __name__ == "__main__": M, N, K = parse_mnk() - rand = False + rand = True if not rand: A_array = np.arange(M * K).reshape([M, K]) B_array = np.arange(K * N).reshape([K, N]) @@ -77,12 +77,16 @@ if __name__ == "__main__": np.savez("abc", A_array=A_array, B_array=B_array, C_array=C_array) - fp16 = False + fp16 = True if fp16: A_packed = pack_fp16_by_row(A_array) + A_swizzled = A_packed.reshape([-1, M * 2]) + A_swizzled.astype('float16').tofile("input.a.row.bin") AT_packed = A_packed.transpose([1, 0, 2]) AT_swizzled = AT_packed.reshape([-1, M * 2]) AT_swizzled.astype('float16').tofile("input.a.col.bin") + print('A:') + print(A_swizzled) print('AT:') print(AT_swizzled) B_packed = pack_fp16_by_column(B_array) From 34d0956cd59d783698a4a873d66f2004238dc255 Mon Sep 17 00:00:00 2001 From: Hansung Kim Date: Wed, 2 Oct 2024 15:14:55 -0700 Subject: [PATCH 49/50] tensor: Attempt row-major mapping for C store (WIP) Doesn't work because 1x2 jagged mapping is required to achieve throughput for storing the bigger C matrix (2x4, vs. 2x2 in A). 
--- tests/kernel/tensor/main.cpp | 57 +++++++++++++++++++++++++++++------- 1 file changed, 47 insertions(+), 10 deletions(-) diff --git a/tests/kernel/tensor/main.cpp b/tests/kernel/tensor/main.cpp index c373507a..05a80454 100644 --- a/tests/kernel/tensor/main.cpp +++ b/tests/kernel/tensor/main.cpp @@ -93,6 +93,23 @@ inline constexpr void map_c_8lanes(const int tid, int &row, int &col) { col += ((tid % 4) / 2) * 2; } +inline constexpr void map_c_rowmajor_8lanes(const int tid, int &row, int &col) { + const int tg = tid / 4; + + // A (row major) + // row 0~ 3: threadgroup 0 + // row 4~ 7: threadgroup 1 + row = tid % 4; + row += tg * 4; + + // B (column major) + // col 0~ 3: threadgroup 0 + // col 4~ 7: threadgroup 1 + col = tid % 4; + col += tg * 4; +} + + void vx_wmma_load() { int tid = vx_thread_id(); int tg = tid / 4; @@ -174,11 +191,31 @@ void store_wmma_result() { int row = 0; int col = 0; - map_c_8lanes(tid, row, col); + // map_c_8lanes(tid, row, col); + map_c_rowmajor_8lanes(tid, row, col); // store C float *const results_wid = results + (DIM_M * DIM_N * wid); - // uncomment to have two accum buffers in rf + + // asm volatile("fsw f16, %0" ::"m"(results_wid[DIM_N * 0 + col])); + // asm volatile("fsw f17, %0" ::"m"(results_wid[DIM_N * 1 + col])); + // asm volatile("fsw f18, %0" ::"m"(results_wid[DIM_N * 2 + col])); + // asm volatile("fsw f19, %0" ::"m"(results_wid[DIM_N * 3 + col])); + // asm volatile("fsw f20, %0" ::"m"(results_wid[DIM_N * 4 + col])); + // asm volatile("fsw f21, %0" ::"m"(results_wid[DIM_N * 5 + col])); + // asm volatile("fsw f22, %0" ::"m"(results_wid[DIM_N * 6 + col])); + // asm volatile("fsw f23, %0" ::"m"(results_wid[DIM_N * 7 + col])); + asm volatile("fsw f24, %0" ::"m"(results_wid[DIM_N * 0 + col])); + asm volatile("fsw f25, %0" ::"m"(results_wid[DIM_N * 1 + col])); + asm volatile("fsw f26, %0" ::"m"(results_wid[DIM_N * 2 + col])); + asm volatile("fsw f27, %0" ::"m"(results_wid[DIM_N * 3 + col])); + asm volatile("fsw f28, %0" 
::"m"(results_wid[DIM_N * 4 + col])); + asm volatile("fsw f29, %0" ::"m"(results_wid[DIM_N * 5 + col])); + asm volatile("fsw f30, %0" ::"m"(results_wid[DIM_N * 6 + col])); + asm volatile("fsw f31, %0" ::"m"(results_wid[DIM_N * 7 + col])); + + + // 1x2 jagged mapping // asm volatile("fsw f16, %0" ::"m"(results_wid[DIM_N * (row + 0) + (col + 0)])); // asm volatile("fsw f17, %0" ::"m"(results_wid[DIM_N * (row + 0) + (col + 1)])); // asm volatile("fsw f18, %0" ::"m"(results_wid[DIM_N * (row + 2) + (col + 0)])); @@ -187,14 +224,14 @@ void store_wmma_result() { // asm volatile("fsw f21, %0" ::"m"(results_wid[DIM_N * (row + 0) + (col + 5)])); // asm volatile("fsw f22, %0" ::"m"(results_wid[DIM_N * (row + 2) + (col + 4)])); // asm volatile("fsw f23, %0" ::"m"(results_wid[DIM_N * (row + 2) + (col + 5)])); - asm volatile("fsw f24, %0" ::"m"(results_wid[DIM_N * (row + 0) + (col + 0)])); - asm volatile("fsw f25, %0" ::"m"(results_wid[DIM_N * (row + 0) + (col + 1)])); - asm volatile("fsw f26, %0" ::"m"(results_wid[DIM_N * (row + 2) + (col + 0)])); - asm volatile("fsw f27, %0" ::"m"(results_wid[DIM_N * (row + 2) + (col + 1)])); - asm volatile("fsw f28, %0" ::"m"(results_wid[DIM_N * (row + 0) + (col + 4)])); - asm volatile("fsw f29, %0" ::"m"(results_wid[DIM_N * (row + 0) + (col + 5)])); - asm volatile("fsw f30, %0" ::"m"(results_wid[DIM_N * (row + 2) + (col + 4)])); - asm volatile("fsw f31, %0" ::"m"(results_wid[DIM_N * (row + 2) + (col + 5)])); + // asm volatile("fsw f24, %0" ::"m"(results_wid[DIM_N * (row + 0) + (col + 0)])); + // asm volatile("fsw f25, %0" ::"m"(results_wid[DIM_N * (row + 0) + (col + 1)])); + // asm volatile("fsw f26, %0" ::"m"(results_wid[DIM_N * (row + 2) + (col + 0)])); + // asm volatile("fsw f27, %0" ::"m"(results_wid[DIM_N * (row + 2) + (col + 1)])); + // asm volatile("fsw f28, %0" ::"m"(results_wid[DIM_N * (row + 0) + (col + 4)])); + // asm volatile("fsw f29, %0" ::"m"(results_wid[DIM_N * (row + 0) + (col + 5)])); + // asm volatile("fsw f30, %0" 
::"m"(results_wid[DIM_N * (row + 2) + (col + 4)])); + // asm volatile("fsw f31, %0" ::"m"(results_wid[DIM_N * (row + 2) + (col + 5)])); } void print_wmma_result() { From 68cd6455fe3fa9835909f70d1b3d9a2d544f0b81 Mon Sep 17 00:00:00 2001 From: Hansung Kim Date: Wed, 2 Oct 2024 15:17:44 -0700 Subject: [PATCH 50/50] sgemm_impl: Add mmio reconverge barrier to avoid slip-off; switch to FP32 --- tests/regression/sgemm_tcore/sgemm_impl.hpp | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/tests/regression/sgemm_tcore/sgemm_impl.hpp b/tests/regression/sgemm_tcore/sgemm_impl.hpp index d1b9d76e..0134d6e5 100644 --- a/tests/regression/sgemm_tcore/sgemm_impl.hpp +++ b/tests/regression/sgemm_tcore/sgemm_impl.hpp @@ -6,7 +6,7 @@ #include "include/gemmini.h" #include "gemmini_mmio.h" -#define FP_SIZE 16 +#define FP_SIZE 32 // "fake" fp16 type that only has the correct data width. using float16_t = uint16_t; @@ -822,6 +822,10 @@ __attribute__((always_inline)) inline void thread_block_gemm_single_tile( if (tid_in_threadblock == 0) { gemmini_fence(); } + + // reconverge after mmio + threadblock_barrier(threadblock_id_in_cluster, + warps_per_threadblock_per_core); } if constexpr (write_to_mem) {