From d2f086344db2f1e72e821e0126690ef2f1f4d522 Mon Sep 17 00:00:00 2001 From: Hansung Kim Date: Sat, 7 Sep 2024 15:48:37 -0700 Subject: [PATCH] flash: Fix DMA addr stride, stop at S=Q*K --- tests/regression/flash_attention/kernel.cpp | 192 ++++++++++++-------- tests/regression/sgemm_tcore/sgemm_impl.hpp | 14 +- 2 files changed, 127 insertions(+), 79 deletions(-) diff --git a/tests/regression/flash_attention/kernel.cpp b/tests/regression/flash_attention/kernel.cpp index 78fc8969..64e28939 100644 --- a/tests/regression/flash_attention/kernel.cpp +++ b/tests/regression/flash_attention/kernel.cpp @@ -8,10 +8,9 @@ #include "include/gemmini.h" #include "gemmini_mmio.h" -#define B_ROW BM -#define B_COL BN -// FIXME -#define HEADDIM B_COL +#define B_ROW 64 +#define B_COL 64 +#define HEADDIM 64 constexpr uint32_t ROWMAX_SETS = 3; constexpr bool DEBUG = true; @@ -19,6 +18,8 @@ constexpr bool WARP_SPECIALIZED = false; constexpr uint32_t DEV_FAKE_SMEM_START_ADDR = 0xf0000000; +constexpr bool Q_IS_K_MAJOR = true; + // temporary safety stop for wrong configs static_assert(NUM_CORES == 4); static_assert(NUM_THREADS == 8); @@ -99,6 +100,7 @@ inline void thread_block_copy_rowmax(const float *src, float *dest, asm volatile("threadblock_copy_rowmax_finish_%=:" ::); } +template inline void thread_block_copy_tile(const float *src, float *dest, const uint32_t tid_in_threadblock, const uint32_t threads_per_threadblock, @@ -113,12 +115,12 @@ inline void thread_block_copy_tile(const float *src, float *dest, // FIXME: dedup this pattern #pragma GCC unroll 1 - for (int row_offset = 0; row_offset < B_ROW; + for (int row_offset = 0; row_offset < dim_row; row_offset += warps_in_threadblock) { const uint32_t row = row_offset + warp_id; - const uint32_t first_thread_offset = B_COL * row; + const uint32_t first_thread_offset = dim_col * row; - constexpr uint32_t per_row_iter = B_COL / NUM_THREADS; + constexpr uint32_t per_row_iter = dim_col / NUM_THREADS; uint32_t thread_offset = 
first_thread_offset + tid_in_warp; #pragma GCC unroll for (int i = 0; i < per_row_iter; i++) { @@ -533,12 +535,12 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { constexpr uint32_t smem_QK_size = B_ROW * B_COL; constexpr uint32_t smem_V_size = B_COL * HEADDIM; constexpr uint32_t smem_O_size = B_COL * HEADDIM; + static_assert( + threads_per_threadblock == NUM_WARPS * NUM_THREADS * CORES_PER_CLUSTER, + "flashattention kernel assumes 1 threadblock occupancy per cluster"); uint8_t *smem_per_threadblock = reinterpret_cast( - DEV_SMEM_START_ADDR + - sizeof(float_type) * - (smem_QK_size + smem_V_size + smem_O_size) * - threadblock_id_in_cluster); - float *smem_cursor = reinterpret_cast(DEV_FAKE_SMEM_START_ADDR); + DEV_SMEM_START_ADDR); + float *smem_cursor = reinterpret_cast(smem_per_threadblock); float *smem_Q0 = smem_cursor; smem_cursor += smem_Q_size; float *smem_Q1 = smem_cursor; @@ -563,6 +565,7 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { smem_cursor += smem_O_size; // NOTE: this has to match with smem_* + static_assert(sizeof(elem_t) == sizeof(float)); constexpr uint32_t spad_addr_factor = DIM * sizeof(elem_t); constexpr uint32_t spad_addr_Q0 = 0; constexpr uint32_t spad_addr_Q1 = @@ -635,15 +638,18 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { // threadblock_barrier(global_barrier_id, warps_per_threadblock_per_core); // } + static_assert(!GEMMINI_DMA || Q_IS_K_MAJOR, + "DMA code assumes Q matrix is stored K-major"); + if constexpr (GEMMINI_DMA) { if (tid_in_warpgroup == 0) { gemmini_extended_config_ex(WEIGHT_STATIONARY, 0, 0, 1, 0, 0); - // configure DMA for Q tile + // configure DMA for the full Q matrix gemmini_extended3_config_ld(HEADDIM * sizeof(elem_t), MVIN_SCALE_IDENTITY, false, 0); - // configure DMA for K tile - gemmini_extended3_config_ld(B_COL * sizeof(elem_t), MVIN_SCALE_IDENTITY, + // configure DMA for the full K matrix + gemmini_extended3_config_ld(dim_seqlen * sizeof(elem_t), 
MVIN_SCALE_IDENTITY, false, 1); // configure DMA for Q*K store gemmini_extended_config_st(B_COL * sizeof(elem_t), 0, @@ -652,12 +658,12 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { } } - // NOTE about barriers: placing barriers around thread-divergent branches may - // cause bugs, since the core doesn't check tmask for barriers. The compiler - // might decide to replicate vx_bar into both paths of a conditional branch, - // which will get evaluated twice along the split/join process and result in - // a different number of calls w.r.t other non-divergent warps and therefore - // stalls. + // NOTE about barriers: Placing barriers around thread-divergent branches may + // cause bugs, because the Vortex core doesn't check tmask for barriers. + // The compiler might decide to duplicate vx_bar into both paths of a + // conditional branch, which will get evaluated twice because of the way + // branches are handled in SIMT; this might result in stalls, especially when + // other warps behave differently on the branch condition. 
// threadblock_barrier(warpgroup_id_in_cluster, warps_per_warpgroup_per_core); // move Q and K into SMEM before the loop starts @@ -674,13 +680,15 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { ROCC_INSTRUCTION_RS1_RS2(XCUSTOM_ACC, (uint64_t)(gmem_Q), (uint64_t)(gmem_K), k_LOOP_WS_CONFIG_ADDRS_AB) // configure address strides for the DMA - GEMMINI_CISC_CMD_R((B_COL << 16) | (HEADDIM << 8) | + // GEMMINI_CISC_CMD_R((B_COL << 16) | (HEADDIM << 8) | + // 8 /*k_LOOP_WS_CONFIG_STRIDES_AB*/); + GEMMINI_CISC_CMD_R((dim_seqlen << 16) | (HEADDIM << 8) | 8 /*k_LOOP_WS_CONFIG_STRIDES_AB*/); gemmini_fence(); -#define GEMMINI_DMA_CISC +// #define GEMMINI_DMA_CISC #ifdef GEMMINI_DMA_CISC - GEMMINI_CISC_CMD_I(10); + GEMMINI_CISC_CMD_I(9); gemmini_fence(); #else // skip everything except DMA in the loop FSM @@ -693,7 +701,7 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { sp_tiled_matmul_full_spad_ws( spad_addr_Q0, spad_addr_K0, /*spad_D=*/0, /*spad_C=*/spad_addr_S0, - /*I=*/(B_ROW / DIM), /*J=*/(HEADDIM / DIM), /*K=*/(B_COL / DIM), + /*I=*/(B_ROW / DIM), /*J=*/(B_COL / DIM), /*K=*/(HEADDIM / DIM), /*pad_I=*/0, /*pad_J=*/0, /*pad_K=*/0, /*a_transpose=*/0, /*b_transpose=*/0, /*full_C=*/0, /*low_D=*/0, /*acc=*/0, /*act=*/NO_ACTIVATION, /*skips=*/skips); @@ -704,10 +712,17 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { asm volatile("dma_move_end_%=:" ::); } else { // load Q; this stays in SMEM for the entire loop - load_tile_to_smem( - dim_seqlen, warpgroup_id, 0 /* dim_k == headdim */, gmem_Q, smem_Q, - tid_in_warpgroup); + if constexpr (Q_IS_K_MAJOR) { + load_tile_to_smem( + HEADDIM, warpgroup_id, 0 /* dim_k == headdim */, gmem_Q, smem_Q, + tid_in_warpgroup); + } else { + load_tile_to_smem( + dim_seqlen, warpgroup_id, 0 /* dim_k == headdim */, gmem_Q, smem_Q, + tid_in_warpgroup); + } // load K load_tile_to_smem(smem_Q0, gmem_tmp_d0, tid_in_warpgroup, + // threads_per_warpgroup, warpgroup_id_in_cluster); + // 
thread_block_copy_tile(smem_K0, gmem_tmp_d1, tid_in_warpgroup, + // threads_per_warpgroup, warpgroup_id_in_cluster); - // threadblock_barrier(warpgroup_id_in_cluster, warps_per_warpgroup_per_core); - } + // threadblock_barrier(warpgroup_id_in_cluster, warps_per_warpgroup_per_core); + // } -#if 0 asm volatile ("tile_loop_start_%=:" :: ); // "inner loop" along the columns of K^T @@ -751,25 +767,34 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { initialize_accum_regs<0>(); initialize_accum_regs<1>(); - thread_block_gemm_single_tile< - float, MemLayout::MN_major, MemLayout::MN_major, B_ROW, B_COL, - HEADDIM, /*leading_dim_a=*/0, /*leading_dim_b=*/0, - /*load_accum=*/false, - /*write_to_smem=*/true>( - smem_Q, smem_K, nullptr /*ignore accum*/, smem_S, tid_in_warpgroup, - threads_per_warpgroup, warpgroups_per_cluster, - warpgroup_id_in_cluster); + if constexpr (Q_IS_K_MAJOR) { + thread_block_gemm_single_tile< + float, MemLayout::K_major, MemLayout::MN_major, B_ROW, B_COL, + HEADDIM, /*leading_dim_a=*/0, /*leading_dim_b=*/0, + /*load_accum=*/false, + /*write_to_smem=*/true>( + smem_Q, smem_K, nullptr /*ignore accum*/, smem_S, + tid_in_warpgroup, threads_per_warpgroup, warpgroups_per_cluster, + warpgroup_id_in_cluster); + } else { + thread_block_gemm_single_tile< + float, MemLayout::MN_major, MemLayout::MN_major, B_ROW, B_COL, + HEADDIM, /*leading_dim_a=*/0, /*leading_dim_b=*/0, + /*load_accum=*/false, + /*write_to_smem=*/true>( + smem_Q, smem_K, nullptr /*ignore accum*/, smem_S, + tid_in_warpgroup, threads_per_warpgroup, warpgroups_per_cluster, + warpgroup_id_in_cluster); + } } else { // when warp-specialized, there's only enough warps to do 64x32 tile // size so we need to do 2 GEMM calls static_assert(B_ROW / 2 == 32, "tile size assumption for warp-specialization not met"); - // assumes smem_Q is K-major - // FIXME: fix this to MN-major float *smem_Q_half0 = smem_Q; - float *smem_Q_half1 = smem_Q + (B_ROW / 2); // MN-major - // float *smem_Q_half1 = 
smem_Q + (B_ROW / 2) * HEADDIM; // K-major + float *smem_Q_half1 = Q_IS_K_MAJOR ? smem_Q + (B_ROW / 2) * HEADDIM + : smem_Q + (B_ROW / 2); float *smem_S_half0 = smem_S; float *smem_S_half1 = smem_S + (B_ROW / 2) * B_COL; @@ -778,26 +803,48 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { initialize_accum_regs<1>(); // split by rows into 2 chunks - thread_block_gemm_single_tile< - float, MemLayout::MN_major, MemLayout::MN_major, B_ROW / 2, - B_COL, HEADDIM, /*leading_dim_a=*/B_ROW, /*leading_dim_b=*/0, - /*load_accum=*/false, - /*write_to_smem=*/true>( - smem_Q_half0, smem_K, nullptr /*ignore accum*/, smem_S_half0, - tid_in_warpgroup, threads_per_warpgroup, warpgroups_per_cluster, - warpgroup_id_in_cluster); + if constexpr (Q_IS_K_MAJOR) { + thread_block_gemm_single_tile< + float, MemLayout::K_major, MemLayout::MN_major, B_ROW / 2, B_COL, + HEADDIM, /*leading_dim_a=*/0, /*leading_dim_b=*/0, + /*load_accum=*/false, + /*write_to_smem=*/true>( + smem_Q_half0, smem_K, nullptr /*ignore accum*/, smem_S_half0, + tid_in_warpgroup, threads_per_warpgroup, warpgroups_per_cluster, + warpgroup_id_in_cluster); + } else { + thread_block_gemm_single_tile< + float, MemLayout::MN_major, MemLayout::MN_major, B_ROW / 2, B_COL, + HEADDIM, /*leading_dim_a=*/B_ROW, /*leading_dim_b=*/0, + /*load_accum=*/false, + /*write_to_smem=*/true>( + smem_Q_half0, smem_K, nullptr /*ignore accum*/, smem_S_half0, + tid_in_warpgroup, threads_per_warpgroup, warpgroups_per_cluster, + warpgroup_id_in_cluster); + } initialize_accum_regs<0>(); initialize_accum_regs<1>(); - thread_block_gemm_single_tile< - float, MemLayout::MN_major, MemLayout::MN_major, B_ROW / 2, - B_COL, HEADDIM, /*leading_dim_a=*/B_ROW, /*leading_dim_b=*/0, - /*load_accum=*/false, - /*write_to_smem=*/true>( - smem_Q_half1, smem_K, nullptr /*ignore accum*/, smem_S_half1, - tid_in_warpgroup, threads_per_warpgroup, warpgroups_per_cluster, - warpgroup_id_in_cluster); + if constexpr (Q_IS_K_MAJOR) { + 
thread_block_gemm_single_tile< + float, MemLayout::K_major, MemLayout::MN_major, B_ROW / 2, B_COL, + HEADDIM, /*leading_dim_a=*/0, /*leading_dim_b=*/0, + /*load_accum=*/false, + /*write_to_smem=*/true>( + smem_Q_half1, smem_K, nullptr /*ignore accum*/, smem_S_half1, + tid_in_warpgroup, threads_per_warpgroup, warpgroups_per_cluster, + warpgroup_id_in_cluster); + } else { + thread_block_gemm_single_tile< + float, MemLayout::MN_major, MemLayout::MN_major, B_ROW / 2, B_COL, + HEADDIM, /*leading_dim_a=*/B_ROW, /*leading_dim_b=*/0, + /*load_accum=*/false, + /*write_to_smem=*/true>( + smem_Q_half1, smem_K, nullptr /*ignore accum*/, smem_S_half1, + tid_in_warpgroup, threads_per_warpgroup, warpgroups_per_cluster, + warpgroup_id_in_cluster); + } } } else { // load Q*K @@ -813,11 +860,11 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { if constexpr (DEBUG) { if (warpgroup_id == 0) { if (tile_k == 0) { - thread_block_copy_tile(smem_S, gmem_tmp_d0, + thread_block_copy_tile(smem_S, gmem_tmp_d0, tid_in_warpgroup, threads_per_warpgroup, warpgroup_id_in_cluster); } else if (tile_k == 1) { - thread_block_copy_tile(smem_S, gmem_tmp_d1, + thread_block_copy_tile(smem_S, gmem_tmp_d1, tid_in_warpgroup, threads_per_warpgroup, warpgroup_id_in_cluster); } @@ -830,6 +877,7 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { // inter-warpgroup barrier before online softmax threadblock_barrier(global_barrier_id, warps_per_threadblock_per_core); +#if 0 // Online softmax // thread_block_online_softmax(smem_S, smem_P, tid_in_warpgroup, @@ -897,17 +945,17 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { if (warpgroup_id == 0) { // O before PV if (tile_k == 0) { - thread_block_copy_tile(smem_P, gmem_tmp_d2, tid_in_warpgroup, + thread_block_copy_tile(smem_P, gmem_tmp_d2, tid_in_warpgroup, threads_per_warpgroup, warpgroup_id_in_cluster); - thread_block_copy_tile(smem_O, gmem_tmp_d4, tid_in_warpgroup, + thread_block_copy_tile(smem_O, gmem_tmp_d4, 
tid_in_warpgroup, threads_per_warpgroup, warpgroup_id_in_cluster); } else if (tile_k == 1) { - thread_block_copy_tile(smem_P, gmem_tmp_d3, tid_in_warpgroup, + thread_block_copy_tile(smem_P, gmem_tmp_d3, tid_in_warpgroup, threads_per_warpgroup, warpgroup_id_in_cluster); - thread_block_copy_tile(smem_O, gmem_tmp_d5, tid_in_warpgroup, + thread_block_copy_tile(smem_O, gmem_tmp_d5, tid_in_warpgroup, threads_per_warpgroup, warpgroup_id_in_cluster); } @@ -986,11 +1034,11 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { if (warpgroup_id == 0) { // O after PV if (tile_k == 0) { - thread_block_copy_tile(smem_O, gmem_tmp_d6, tid_in_warpgroup, + thread_block_copy_tile(smem_O, gmem_tmp_d6, tid_in_warpgroup, threads_per_warpgroup, warpgroup_id_in_cluster); } else if (tile_k == 1) { - thread_block_copy_tile(smem_O, gmem_tmp_d7, tid_in_warpgroup, + thread_block_copy_tile(smem_O, gmem_tmp_d7, tid_in_warpgroup, threads_per_warpgroup, warpgroup_id_in_cluster); } @@ -1006,6 +1054,7 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { // warps_per_threadblock_per_core); // threadblock_barrier(3, // FIXME // NUM_WARPS); +#endif } asm volatile ("tile_loop_finish_%=:" :: ); @@ -1015,7 +1064,6 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { if (warpgroup_id == 0) { threadblock_barrier(global_barrier_id, warps_per_threadblock_per_core); } -#endif } int main() { diff --git a/tests/regression/sgemm_tcore/sgemm_impl.hpp b/tests/regression/sgemm_tcore/sgemm_impl.hpp index ac07a666..db8df789 100644 --- a/tests/regression/sgemm_tcore/sgemm_impl.hpp +++ b/tests/regression/sgemm_tcore/sgemm_impl.hpp @@ -6,7 +6,7 @@ #include "include/gemmini.h" #include "gemmini_mmio.h" -#define FP_SIZE 16 +#define FP_SIZE 32 // "fake" fp16 type that only has the correct data width. 
using float16_t = uint16_t; @@ -29,7 +29,7 @@ using float_type = float16_t; // (BM*BN) / (TM*TN) == threadblock size >= NT * CORES_PER_CLUSTER // * Combining BM * BK >= (BM*BN) / (TM*TN) == threadblock yields // BM <= BK*TM*TN -#define BM 128 +#define BM 64 #define BN 64 #if (FP_SIZE == 32) #define BK 64 @@ -62,18 +62,18 @@ static_assert(WMITER * WNITER * TCM * TCN * NUM_WARPS * CORES_PER_CLUSTER == #define BK_LOOP 1 // Whether to transpose smem A tile at GMEM->SMEM (produce), or SMEM->RF // (consume). This is because the tensor core expects the A tile to be stored -// in column-major order in SMEM, whereas it will be ultimately stored in -// row-major in the RF. +// in column-major order in SMEM, so a transpose is necessary if A was stored +// row-major in GMEM. // // For correctness, only one of either should be 1. E.g., PRODUCE 1 CONSUME 0 // generates the NN kernel where both A and B are stored row-major in GMEM. // To model the case where the A matrix is already stored column-major in GMEM, // set both to 0. #define TRANSPOSE_AT_PRODUCE 0 -#define TRANSPOSE_AT_CONSUME 0 +#define TRANSPOSE_AT_CONSUME 1 -#define GEMMINI_DMA 0 -#define GEMMINI_DMA_MN_MAJOR 1 +#define GEMMINI_DMA 1 +#define GEMMINI_DMA_MN_MAJOR 0 #if SMEM_SIZE == 0x4000 #define SMEM_ADDR_Q0 ((float * const) 0xff000000) #define SMEM_ADDR_Q1 ((float * const) 0xff001000)