diff --git a/tests/regression/flash_attention/kernel.gemmini.cpp b/tests/regression/flash_attention/kernel.gemmini.cpp index c46b0bd1..e97ea635 100644 --- a/tests/regression/flash_attention/kernel.gemmini.cpp +++ b/tests/regression/flash_attention/kernel.gemmini.cpp @@ -10,7 +10,9 @@ #define FENCE_GEMM_II -#define GEMMINI_NEW_CISC +#define GEMMINI_NEW_CISC 1 +static_assert(GEMMINI_NEW_CISC, "NOTE: old non-CISC code is untested; look for " + "any misalignment of fields in ciscArgs."); constexpr bool DEBUG = false; @@ -282,6 +284,8 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { /*a_transpose=*/0, /*b_transpose=*/0, /*full_C=*/0, /*low_D=*/0, /*acc=*/0, /*act=*/NO_ACTIVATION, /*skips=*/skips); #endif + + // block until DMA complete gemmini_fence(); // also move Q to spad_addr_Q1 for the second iteration @@ -309,12 +313,12 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { /*acc=*/0, /*act=*/NO_ACTIVATION, /*skips=*/skips); #endif - gemmini_fence(); - gemmini_fence(); - gemmini_fence(); + // block until DMA complete gemmini_fence(); // re-configure DMA for K and V load that will later happen in the loop + // FIXME: not sure necessary with new CISC + // // GMEM addr stride for K gemmini_extended3_config_ld(dim_seqlen * sizeof(elem_t), MVIN_SCALE_IDENTITY, false, 0); @@ -424,9 +428,6 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { // FIXME: perf: prevent GMEM->SMEM load for O tile gemmini_fence(); - gemmini_fence(); - gemmini_fence(); - gemmini_fence(); #ifdef GEMMINI_NEW_CISC gemmini_tile_compute( spad_hex_P_consume, spad_hex_V_consume, spad_hex_O, @@ -458,16 +459,17 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { if (tid_in_warpgroup == 0) { // fence to GEMM II completion gemmini_fence(); - gemmini_fence(); - gemmini_fence(); - gemmini_fence(); #ifdef FENCE_GEMM_II + asm volatile("rescale_fence_write_start_%=:" ::); // signal that GEMM II is finished to O rescale step *smem_O_flag = 1; vx_fence(); + asm volatile("rescale_fence_write_end_%=:" ::); #endif + // Kick off GEMM I + // // 0,2,.: opcode 0 (quartile 0/2, no accum) // 1,3,.: opcode 3 (quartile 1/3, no accum) // const uint32_t opcode = 3 * (tile_k & 1); @@ -485,10 +487,6 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { /*acc=*/0, /*act=*/NO_ACTIVATION, /*skips=*/skips_matmul); #endif - // gemmini_fence(); - // gemmini_fence(); - // gemmini_fence(); - // gemmini_fence(); asm volatile("gemm_qk_finish_%=:" ::); // data move for K and V @@ -511,7 +509,6 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { (uint64_t)(gmem_V_tile), k_LOOP_WS_CONFIG_ADDRS_AB) #endif - // gemmini_fence(); // do DMA if (tile_k == 0) { @@ -554,9 +551,6 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { // fence everything before going to the next tile gemmini_fence(); - gemmini_fence(); - gemmini_fence(); - gemmini_fence(); } // threadblock_barrier(warpgroup_id_in_cluster, @@ -625,6 +619,7 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { } #ifdef FENCE_GEMM_II + asm volatile("rescale_fence_read_start_%=:" ::); // check flag to make sure GEMM II finished and read-after-write // dependency on O tile is settled for rescale if (tid_in_warpgroup_simt == 0) { @@ -634,6 +629,7 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { *smem_O_flag = 0; vx_fence(); } + asm volatile("rescale_fence_read_end_%=:" ::); #endif #if 0