diff --git a/tests/regression/flash_attention/kernel.gemmini.cpp b/tests/regression/flash_attention/kernel.gemmini.cpp index a583feb7..63d3bd56 100644 --- a/tests/regression/flash_attention/kernel.gemmini.cpp +++ b/tests/regression/flash_attention/kernel.gemmini.cpp @@ -8,7 +8,9 @@ #include "gemmini_mmio.h" #include "flash_impl.hpp" -constexpr bool DEBUG = false; +#define FENCE_GEMM_II + +constexpr bool DEBUG = true; static_assert(GEMMINI_DMA && !WARP_SPECIALIZED, "GEMMINI_DMA should be set and WARP_SPECIALIZED unset"); @@ -290,11 +292,13 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { /*I=*/(B_ROW / DIM), /*J=*/(B_COL / DIM), /*K=*/(HEADDIM / DIM), /*pad_I=*/0, /*pad_J=*/0, /*pad_K=*/0, /*a_transpose=*/0, /*b_transpose=*/0, /*full_C=*/0, /*low_D=*/0, - /*acc=*/0, /*act=*/NO_ACTIVATION, /*skips=*/skips_only_a); + /*acc=*/0, /*act=*/NO_ACTIVATION, /*skips=*/skips); #endif gemmini_fence(); gemmini_fence(); + gemmini_fence(); + gemmini_fence(); // re-configure DMA for K and V load that will later happen in the loop // GMEM addr stride for K @@ -480,27 +484,27 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { ROCC_INSTRUCTION_RS1_RS2(XCUSTOM_ACC, (uint64_t)(gmem_K_tile), (uint64_t)(gmem_V_tile), k_LOOP_WS_CONFIG_ADDRS_AB) +#endif // configure address strides for the DMA // FIXME: unnecessary? GEMMINI_CISC_CMD_R((HEADDIM /*V*/ << 20) | (dim_seqlen /*KT*/ << 8) | 8 /*k_LOOP_WS_CONFIG_STRIDES_AB*/); -#endif // gemmini_fence(); // do DMA if (tile_k == 0) { // we load (k-1)th tile for V; skip V for the 1st iteration, - sp_tiled_matmul_full_spad_ws( - spad_addr_K_produce, spad_addr_V_produce, - /*spad_D=*/0, /*spad_C=*/spad_addr_S_produce /*FIXME:bogus*/, - /*I=*/(B_ROW / DIM), /*J=*/(HEADDIM / DIM), /*K=*/(B_COL / DIM), - /*pad_I=*/0, /*pad_J=*/0, /*pad_K=*/0, - /*a_transpose=*/0, /*b_transpose=*/0, /*full_C=*/0, /*low_D=*/0, - /*acc=*/0, /*act=*/NO_ACTIVATION, /*skips=*/skips_only_a); + // sp_tiled_matmul_full_spad_ws( + // spad_addr_K_produce, spad_addr_V_produce, + // /*spad_D=*/0, /*spad_C=*/0, + // /*I=*/(B_ROW / DIM), /*J=*/(HEADDIM / DIM), /*K=*/(B_COL / DIM), + // /*pad_I=*/0, /*pad_J=*/0, /*pad_K=*/0, + // /*a_transpose=*/0, /*b_transpose=*/0, /*full_C=*/0, /*low_D=*/0, + // /*acc=*/0, /*act=*/NO_ACTIVATION, /*skips=*/skips_only_a); } else { sp_tiled_matmul_full_spad_ws( spad_addr_K_produce, spad_addr_V_produce, - /*spad_D=*/0, /*spad_C=*/spad_addr_S_produce /*FIXME:bogus*/, + /*spad_D=*/0, /*spad_C=*/0, /*I=*/(B_ROW / DIM), /*J=*/(HEADDIM / DIM), /*K=*/(B_COL / DIM), /*pad_I=*/0, /*pad_J=*/0, /*pad_K=*/0, /*a_transpose=*/0, /*b_transpose=*/0, /*full_C=*/0, /*low_D=*/0, @@ -532,18 +536,15 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { const uint32_t tile_k_ = tile_k - 1; if constexpr (DEBUG) { - gemmini_fence(); - gemmini_fence(); - - // verify S = Q*K + // verify S = Q*K before softmax if (warpgroup_id == 0) { if (tile_k_ == 0) { thread_block_copy_tile( - smem_S_produce, gmem_tmp_d0, tid_in_warpgroup_simt, + smem_S_consume, gmem_tmp_d0, tid_in_warpgroup_simt, threads_per_warpgroup_simt, warpgroup_id_simt); } else if (tile_k_ == 1) { thread_block_copy_tile( - smem_S_produce, gmem_tmp_d1, tid_in_warpgroup_simt, + smem_S_consume, gmem_tmp_d1, tid_in_warpgroup_simt, threads_per_warpgroup_simt, warpgroup_id_simt); }