diff --git a/tests/regression/flash_attention/kernel.cpp b/tests/regression/flash_attention/kernel.cpp index c3298a61..c5efbf3b 100644 --- a/tests/regression/flash_attention/kernel.cpp +++ b/tests/regression/flash_attention/kernel.cpp @@ -8,6 +8,8 @@ #include "gemmini_mmio.h" #include "flash_impl.hpp" +#define GEMMINI_NEW_CISC 1 + constexpr bool DEBUG = false; constexpr bool Q_IS_K_MAJOR = true; @@ -94,6 +96,14 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { uint8_t *smem_per_threadblock = reinterpret_cast( DEV_SMEM_START_ADDR); constexpr uint32_t smem_start = DEV_SMEM_START_ADDR; + constexpr uint32_t smem_hexadecile_size = (SMEM_SIZE / 16); + // currently assumes the Q/K/V tile sizes exactly match the hexadecile size + static_assert(smem_hexadecile_size == smem_Q_size * sizeof(float)); + static_assert(smem_hexadecile_size == smem_K_size * sizeof(float)); + static_assert(smem_hexadecile_size == smem_QK_size * sizeof(float)); + static_assert(smem_hexadecile_size == smem_V_size * sizeof(float)); + static_assert(smem_hexadecile_size == smem_O_size * sizeof(float)); + constexpr uint32_t smem_octet0 = 0 * (SMEM_SIZE / 8); constexpr uint32_t smem_octet1 = 1 * (SMEM_SIZE / 8); constexpr uint32_t smem_octet2 = 2 * (SMEM_SIZE / 8); @@ -108,31 +118,31 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { // half // at the same time, make sure Q and K are in different banks so that they // can be accessed in parallel for GEMM; same for P and V - constexpr uint32_t smem_Q0_offset = smem_octet0; - constexpr uint32_t smem_Q1_offset = smem_octet4; - constexpr uint32_t smem_K0_offset = smem_octet1; - constexpr uint32_t smem_K1_offset = smem_octet5; - constexpr uint32_t smem_V0_offset = smem_K0_offset + smem_K_size * sizeof(float); - constexpr uint32_t smem_V1_offset = smem_K1_offset + smem_K_size * sizeof(float); - constexpr uint32_t smem_S0_offset = smem_octet2; - constexpr uint32_t smem_S1_offset = smem_octet6; - constexpr uint32_t smem_P0_offset = smem_Q0_offset + smem_Q_size * sizeof(float); - constexpr uint32_t smem_P1_offset = smem_Q1_offset + smem_Q_size * sizeof(float); - constexpr uint32_t smem_O0_offset = smem_octet3; - constexpr uint32_t smem_O1_offset = smem_octet7; + constexpr uint32_t smem_Q0_hexadecile = 2 * 0; // octet0 + constexpr uint32_t smem_Q1_hexadecile = 2 * 4; // octet4 + constexpr uint32_t smem_K0_hexadecile = 2 * 1; // octet1 + constexpr uint32_t smem_K1_hexadecile = 2 * 5; // octet5 + constexpr uint32_t smem_V0_hexadecile = smem_K0_hexadecile + 1; + constexpr uint32_t smem_V1_hexadecile = smem_K1_hexadecile + 1; + constexpr uint32_t smem_S0_hexadecile = 2 * 2; // octet2 + constexpr uint32_t smem_S1_hexadecile = 2 * 6; // octet6 + constexpr uint32_t smem_P0_hexadecile = smem_Q0_hexadecile + 1; + constexpr uint32_t smem_P1_hexadecile = smem_Q1_hexadecile + 1; + constexpr uint32_t smem_O0_hexadecile = 2 * 3; // octet3 + constexpr uint32_t smem_O1_hexadecile = 2 * 7; // octet7 - float *smem_Q0 = reinterpret_cast(smem_start + smem_Q0_offset); - float *smem_Q1 = reinterpret_cast(smem_start + smem_Q1_offset); - float *smem_K0 = reinterpret_cast(smem_start + smem_K0_offset); - float *smem_K1 = reinterpret_cast(smem_start + smem_K1_offset); - float *smem_V0 = reinterpret_cast(smem_start + smem_V0_offset); - float *smem_V1 = reinterpret_cast(smem_start + smem_V1_offset); - float *smem_S0 = reinterpret_cast(smem_start + smem_S0_offset); - float *smem_S1 = reinterpret_cast(smem_start + smem_S1_offset); - float *smem_P0 = reinterpret_cast(smem_start + smem_P0_offset); - float *smem_P1 = reinterpret_cast(smem_start + smem_P1_offset); - float *smem_O0 = reinterpret_cast(smem_start + smem_O0_offset); - float *smem_O1 = reinterpret_cast(smem_start + smem_O1_offset); + float *smem_Q0 = reinterpret_cast(smem_start + smem_Q0_hexadecile * smem_hexadecile_size); + float *smem_Q1 = reinterpret_cast(smem_start + smem_Q1_hexadecile * smem_hexadecile_size); + float *smem_K0 = reinterpret_cast(smem_start + smem_K0_hexadecile * smem_hexadecile_size); + float *smem_K1 = reinterpret_cast(smem_start + smem_K1_hexadecile * smem_hexadecile_size); + float *smem_V0 = reinterpret_cast(smem_start + smem_V0_hexadecile * smem_hexadecile_size); + float *smem_V1 = reinterpret_cast(smem_start + smem_V1_hexadecile * smem_hexadecile_size); + float *smem_S0 = reinterpret_cast(smem_start + smem_S0_hexadecile * smem_hexadecile_size); + float *smem_S1 = reinterpret_cast(smem_start + smem_S1_hexadecile * smem_hexadecile_size); + float *smem_P0 = reinterpret_cast(smem_start + smem_P0_hexadecile * smem_hexadecile_size); + float *smem_P1 = reinterpret_cast(smem_start + smem_P1_hexadecile * smem_hexadecile_size); + float *smem_O0 = reinterpret_cast(smem_start + smem_O0_hexadecile * smem_hexadecile_size); + float *smem_O1 = reinterpret_cast(smem_start + smem_O1_hexadecile * smem_hexadecile_size); // allocate rowmax/rowsum storage at the end of the sharedmem address space constexpr uint32_t smem_rowmax_size = B_ROW * ROWMAX_SETS; @@ -180,25 +190,32 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { float *smem_scratchpad = (warpgroup_id % 2) ? smem_scratchpad_1 : smem_scratchpad_0; +#ifdef GEMMINI_NEW_CISC + const auto spad_hex_Q = (warpgroup_id % 2) ? smem_Q1_hexadecile : smem_Q0_hexadecile; + const auto spad_hex_K = (warpgroup_id % 2) ? smem_K1_hexadecile : smem_K0_hexadecile; + const auto spad_hex_V = (warpgroup_id % 2) ? smem_V1_hexadecile : smem_V0_hexadecile; + const auto spad_hex_S = (warpgroup_id % 2) ? smem_S1_hexadecile : smem_S0_hexadecile; +#else static_assert(sizeof(elem_t) == sizeof(float)); constexpr uint32_t spad_addr_factor = DIM * sizeof(elem_t); - constexpr uint32_t spad_addr_Q0 = smem_Q0_offset / spad_addr_factor; - constexpr uint32_t spad_addr_Q1 = smem_Q1_offset / spad_addr_factor; - constexpr uint32_t spad_addr_K0 = smem_K0_offset / spad_addr_factor; - constexpr uint32_t spad_addr_K1 = smem_K1_offset / spad_addr_factor; - constexpr uint32_t spad_addr_V0 = smem_V0_offset / spad_addr_factor; - constexpr uint32_t spad_addr_V1 = smem_V1_offset / spad_addr_factor; - constexpr uint32_t spad_addr_S0 = smem_S0_offset / spad_addr_factor; - constexpr uint32_t spad_addr_S1 = smem_S1_offset / spad_addr_factor; - constexpr uint32_t spad_addr_P0 = smem_P0_offset / spad_addr_factor; - constexpr uint32_t spad_addr_P1 = smem_P1_offset / spad_addr_factor; - constexpr uint32_t spad_addr_O0 = smem_O0_offset / spad_addr_factor; - constexpr uint32_t spad_addr_O1 = smem_O1_offset / spad_addr_factor; + constexpr uint32_t spad_addr_Q0 = smem_Q0_hexadecile * smem_hexadecile_size / spad_addr_factor; + constexpr uint32_t spad_addr_Q1 = smem_Q1_hexadecile * smem_hexadecile_size / spad_addr_factor; + constexpr uint32_t spad_addr_K0 = smem_K0_hexadecile * smem_hexadecile_size / spad_addr_factor; + constexpr uint32_t spad_addr_K1 = smem_K1_hexadecile * smem_hexadecile_size / spad_addr_factor; + constexpr uint32_t spad_addr_V0 = smem_V0_hexadecile * smem_hexadecile_size / spad_addr_factor; + constexpr uint32_t spad_addr_V1 = smem_V1_hexadecile * smem_hexadecile_size / spad_addr_factor; + constexpr uint32_t spad_addr_S0 = smem_S0_hexadecile * smem_hexadecile_size / spad_addr_factor; + constexpr uint32_t spad_addr_S1 = smem_S1_hexadecile * smem_hexadecile_size / spad_addr_factor; + constexpr uint32_t spad_addr_P0 = smem_P0_hexadecile * smem_hexadecile_size / spad_addr_factor; + constexpr uint32_t spad_addr_P1 = smem_P1_hexadecile * smem_hexadecile_size / spad_addr_factor; + constexpr uint32_t spad_addr_O0 = smem_O0_hexadecile * smem_hexadecile_size / spad_addr_factor; + constexpr uint32_t spad_addr_O1 = smem_O1_hexadecile * smem_hexadecile_size / spad_addr_factor; const auto spad_addr_Q = (warpgroup_id % 2) ? spad_addr_Q1 : spad_addr_Q0; const auto spad_addr_K = (warpgroup_id % 2) ? spad_addr_K1 : spad_addr_K0; const auto spad_addr_V = (warpgroup_id % 2) ? spad_addr_V1 : spad_addr_V0; const auto spad_addr_S = (warpgroup_id % 2) ? spad_addr_S1 : spad_addr_S0; +#endif // initialize rowmax/rowsum values in sharedmem thread_block_init_sharedmem(tid_in_warpgroup, threads_per_warpgroup, smem_O, @@ -246,8 +263,6 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { // other warps behave differently on the branch condition. // threadblock_barrier(warpgroup_id_in_cluster, warps_per_warpgroup_per_core); - // move Q and K into SMEM before the loop starts - // static_assert(B_ROW == B_COL, "currently only supports square tiles"); if constexpr (GEMMINI_DMA) { @@ -256,6 +271,19 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { if (tid_in_warpgroup == 0) { const float *gmem_Q_tile = gmem_Q + HEADDIM * B_ROW * warpgroup_id; const float *gmem_K_tile = gmem_K; + // do DMA + // + // move Q and K into SMEM before the loop starts. Note this will be done + // separately for the two warpgroups + // +#ifdef GEMMINI_NEW_CISC + // the target addresses of this should match with spad_addr_Q0 and + // spad_addr_K0 set in this kernel + gemmini_tile_load_ab(gmem_Q_tile, gmem_K_tile, spad_hex_Q, + spad_hex_K, 0 /*tile_idx_i*/, + 0 /*tile_idx_j*/, 0 /*tile_idx_k*/, dim_seqlen, + dim_seqlen, HEADDIM, B_ROW, B_COL, HEADDIM); +#else // configure the GMEM addresses for the DMA to read from ROCC_INSTRUCTION_RS1_RS2(XCUSTOM_ACC, (uint64_t)(gmem_Q_tile), (uint64_t)(gmem_K_tile), @@ -265,13 +293,6 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { 8 /*k_LOOP_WS_CONFIG_STRIDES_AB*/); gemmini_fence(); -// #define GEMMINI_DMA_CISC -#ifdef GEMMINI_DMA_CISC - GEMMINI_CISC_CMD_I(9); - gemmini_fence(); -#else - // do DMA - // // among other things, this also configures CONFIG_BOUNDS so that the // DMA knows the full matrix dimensions sp_tiled_matmul_full_spad_ws( @@ -281,9 +302,11 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { /*pad_I=*/0, /*pad_J=*/0, /*pad_K=*/0, /*a_transpose=*/0, /*b_transpose=*/0, /*full_C=*/0, /*low_D=*/0, /*acc=*/0, /*act=*/NO_ACTIVATION, /*skips=*/skips); - gemmini_fence(); #endif + // block until DMA complete + gemmini_fence(); + // re-configure DMA for K and V load that will later happen in the loop // GMEM addr stride for K gemmini_extended3_config_ld(dim_seqlen * sizeof(elem_t), MVIN_SCALE_IDENTITY, @@ -549,13 +572,20 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { ROCC_INSTRUCTION_RS1_RS2(XCUSTOM_ACC, (uint64_t)(gmem_K_tile), (uint64_t)(gmem_V_tile), k_LOOP_WS_CONFIG_ADDRS_AB) + // do DMA +#ifdef GEMMINI_NEW_CISC + gemmini_tile_load_ab(gmem_K_tile, gmem_V_tile, spad_hex_K, spad_hex_V, + 0 /*tile_idx_i*/, 0 /*tile_idx_j*/, + 0 /*tile_idx_k*/, HEADDIM /*dim_m of KT*/, + HEADDIM /*dim_n of V*/, dim_seqlen /*dim_k of KT*/, + B_ROW, HEADDIM, B_COL); +#else // configure address strides for the DMA // FIXME: unnecessary? GEMMINI_CISC_CMD_R((HEADDIM /*V*/ << 20) | (dim_seqlen /*KT*/ << 8) | 8 /*k_LOOP_WS_CONFIG_STRIDES_AB*/); gemmini_fence(); - // do DMA sp_tiled_matmul_full_spad_ws( spad_addr_K, spad_addr_V, /*spad_D=*/0, /*spad_C=*/spad_addr_S, @@ -563,6 +593,8 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { /*pad_I=*/0, /*pad_J=*/0, /*pad_K=*/0, /*a_transpose=*/0, /*b_transpose=*/0, /*full_C=*/0, /*low_D=*/0, /*acc=*/0, /*act=*/NO_ACTIVATION, /*skips=*/skips); +#endif + // FIXME: necessary? gemmini_fence(); } } else {