diff --git a/tests/regression/flash_attention/kernel.gemmini.cpp b/tests/regression/flash_attention/kernel.gemmini.cpp index 063bc468..ae0aaf07 100644 --- a/tests/regression/flash_attention/kernel.gemmini.cpp +++ b/tests/regression/flash_attention/kernel.gemmini.cpp @@ -389,7 +389,7 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { GEMMINI_CISC_CMD_I(1); } #else - // do matmul + // kickoff matmul // among other things, this also configures CONFIG_BOUNDS so that the // DMA knows the full matrix dimensions // FIXME: perf: prevent GMEM->SMEM load for O tile @@ -402,27 +402,6 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { /*a_transpose=*/0, /*b_transpose=*/0, /*full_C=*/0, /*low_D=*/0, /*acc=*/0, /*act=*/NO_ACTIVATION, /*skips=*/skips_matmul); #endif - - gemmini_fence(); - gemmini_fence(); - gemmini_fence(); - gemmini_fence(); - - // mvout to SMEM - // GEMMINI_CISC_CMD_I(9); - sp_tiled_matmul_full_spad_ws( - /*spad_A=*/spad_addr_P_consume, /*spad_B=*/spad_addr_V_consume, - /*spad_D=*/0, /*spad_C=*/spad_addr_O, - /*I=*/(B_ROW / DIM), /*J=*/(HEADDIM / DIM), /*K=*/(B_COL / DIM), - /*pad_I=*/0, /*pad_J=*/0, /*pad_K=*/0, - /*a_transpose=*/0, /*b_transpose=*/0, /*full_C=*/0, /*low_D=*/0, - /*acc=*/0, /*act=*/NO_ACTIVATION, /*skips=*/skips_mvout_spad); - gemmini_fence(); - - if constexpr (DEBUG) { - // for copy-out to GMEM - gemmini_fence(); - } } // reconverge from mmio divergence @@ -431,99 +410,6 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { asm volatile("gemm_pv_finish_%=:" ::); - if constexpr (DEBUG) { - if (warpgroup_id == 0) { - // O after PV - if (tile_k_ == 0) { - thread_block_copy_tile( - smem_O, gmem_tmp_d6, tid_in_warpgroup, threads_per_warpgroup, - warpgroup_id_in_cluster); - } else if (tile_k_ == 1) { - thread_block_copy_tile( - smem_O, gmem_tmp_d7, tid_in_warpgroup, threads_per_warpgroup, - warpgroup_id_in_cluster); - } - - threadblock_barrier(warpgroup_id_in_cluster, - warps_per_warpgroup_per_core); - } - } - } - - // GEMM I: S = Q*K - // - asm volatile("gemm_qk_start_%=:" ::); - - if (tid_in_warpgroup == 0) { - gemmini_fence(); - // 0,2,.: opcode 0 (quartile 0/2, no accum) - // 1,3,.: opcode 3 (quartile 1/3, no accum) - const uint32_t opcode = 3 * (tile_k & 1); - //GEMMINI_CISC_CMD_I(opcode); - sp_tiled_matmul_full_spad_ws( - spad_addr_Q, spad_addr_K_consume, - /*spad_D=*/0, /*spad_C=*/spad_addr_S_produce, - /*I=*/(B_ROW / DIM), /*J=*/(B_COL / DIM), /*K=*/(HEADDIM / DIM), - /*pad_I=*/0, /*pad_J=*/0, /*pad_K=*/0, - /*a_transpose=*/0, /*b_transpose=*/0, /*full_C=*/0, /*low_D=*/0, - /*acc=*/0, /*act=*/NO_ACTIVATION, /*skips=*/skips_matmul); - - gemmini_fence(); - gemmini_fence(); - gemmini_fence(); - gemmini_fence(); - -#if 0 // TODO: speed up mvout to SMEM -// loop_ws variant that skips configuring strides -#define gemmini_loop_ws(I, J, K, pad_I, pad_J, pad_K, A, B, D, C, A_stride, B_stride, D_stride, C_stride, A_transpose, B_transpose, full_C, low_D, ex_accumulate, act, a_spad_id, b_spad_id, is_resadd) \ - { \ - ROCC_INSTRUCTION_RS1_RS2(XCUSTOM_ACC, ((uint64_t)(pad_K) << 32) | ((uint64_t)(pad_J) << 16) | (uint64_t)(pad_I), ((uint64_t)(K) << 32) | ((uint64_t)(J) << 16) | (uint64_t)(I), k_LOOP_WS_CONFIG_BOUNDS) \ - ROCC_INSTRUCTION_RS1_RS2(XCUSTOM_ACC, A, B, k_LOOP_WS_CONFIG_ADDRS_AB) \ - ROCC_INSTRUCTION_RS1_RS2(XCUSTOM_ACC, D, C, k_LOOP_WS_CONFIG_ADDRS_DC) \ - ROCC_INSTRUCTION_RS1_RS2(XCUSTOM_ACC, A_stride, B_stride, k_LOOP_WS_CONFIG_STRIDES_AB) \ - ROCC_INSTRUCTION_RS1_RS2(XCUSTOM_ACC, D_stride, C_stride, k_LOOP_WS_CONFIG_STRIDES_DC) \ - ROCC_INSTRUCTION_RS1_RS2(XCUSTOM_ACC, ((uint64_t)(a_spad_id) << 18) | ((uint64_t)(b_spad_id) << 16) | ((uint64_t)(act) << 8) | ((low_D) << 2) | ((full_C) << 1) | (ex_accumulate), ((is_resadd) << 2) | ((B_transpose) << 1) | (A_transpose), k_LOOP_WS) \ - } -#endif - - // mvout to SMEM - // GEMMINI_CISC_CMD_I(9); - sp_tiled_matmul_full_spad_ws( - /*spad_A=*/spad_addr_Q /*bogus*/, - /*spad_B=*/spad_addr_K_consume /*bogus*/, - /*spad_D=*/0, /*spad_C=*/spad_addr_S_produce, - /*I=*/(B_ROW / DIM), /*J=*/(B_COL / DIM), /*K=*/(HEADDIM / DIM), - /*pad_I=*/0, /*pad_J=*/0, /*pad_K=*/0, - /*a_transpose=*/0, /*b_transpose=*/0, /*full_C=*/0, /*low_D=*/0, - /*acc=*/0, /*act=*/NO_ACTIVATION, /*skips=*/skips_mvout_spad); - gemmini_fence(); - - if constexpr (DEBUG) { - // for copy-out to GMEM - gemmini_fence(); - } - } - - // reconverge from mmio divergence - threadblock_barrier(warpgroup_id_in_cluster, warps_per_warpgroup_per_core); - - asm volatile("gemm_qk_finish_%=:" ::); - - if constexpr (DEBUG) { - if (warpgroup_id == 0) { - if (tile_k == 0) { - thread_block_copy_tile( - smem_S_produce, gmem_tmp_d0, tid_in_warpgroup, threads_per_warpgroup, - warpgroup_id_in_cluster); - } else if (tile_k == 1) { - thread_block_copy_tile( - smem_S_produce, gmem_tmp_d1, tid_in_warpgroup, threads_per_warpgroup, - warpgroup_id_in_cluster); - } - - threadblock_barrier(warpgroup_id_in_cluster, - warps_per_warpgroup_per_core); - } } if (tile_k >= 1) // delay by 1 iters for pipelining @@ -563,7 +449,89 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { } } - // TODO: put a synchronization here with GEMM-II + if (tid_in_warpgroup == 0) { + // fence GEMM-II to make sure dependency on O tile is settled + gemmini_fence(); + gemmini_fence(); + + // mvout to SMEM + // GEMMINI_CISC_CMD_I(9); + sp_tiled_matmul_full_spad_ws( + /*spad_A=*/spad_addr_P_consume, /*spad_B=*/spad_addr_V_consume, + /*spad_D=*/0, /*spad_C=*/spad_addr_O, + /*I=*/(B_ROW / DIM), /*J=*/(HEADDIM / DIM), /*K=*/(B_COL / DIM), + /*pad_I=*/0, /*pad_J=*/0, /*pad_K=*/0, + /*a_transpose=*/0, /*b_transpose=*/0, /*full_C=*/0, /*low_D=*/0, + /*acc=*/0, /*act=*/NO_ACTIVATION, /*skips=*/skips_mvout_spad); + } + + // reconverge from mmio divergence + threadblock_barrier(warpgroup_id_in_cluster, warps_per_warpgroup_per_core); + + if constexpr (DEBUG) { + gemmini_fence(); + + if (warpgroup_id == 0) { + // O after PV + if (tile_k_ == 0) { + thread_block_copy_tile( + smem_O, gmem_tmp_d6, tid_in_warpgroup, threads_per_warpgroup, + warpgroup_id_in_cluster); + } else if (tile_k_ == 1) { + thread_block_copy_tile( + smem_O, gmem_tmp_d7, tid_in_warpgroup, threads_per_warpgroup, + warpgroup_id_in_cluster); + } + + threadblock_barrier(warpgroup_id_in_cluster, + warps_per_warpgroup_per_core); + } + } + } + + { + // GEMM I: S = Q*K + // + // kick off asynchronously; fence later + asm volatile("gemm_qk_start_%=:" ::); + + if (tid_in_warpgroup == 0) { + gemmini_fence(); + // 0,2,.: opcode 0 (quartile 0/2, no accum) + // 1,3,.: opcode 3 (quartile 1/3, no accum) + const uint32_t opcode = 3 * (tile_k & 1); + //GEMMINI_CISC_CMD_I(opcode); + sp_tiled_matmul_full_spad_ws( + spad_addr_Q, spad_addr_K_consume, + /*spad_D=*/0, /*spad_C=*/spad_addr_S_produce, + /*I=*/(B_ROW / DIM), /*J=*/(B_COL / DIM), /*K=*/(HEADDIM / DIM), + /*pad_I=*/0, /*pad_J=*/0, /*pad_K=*/0, + /*a_transpose=*/0, /*b_transpose=*/0, /*full_C=*/0, /*low_D=*/0, + /*acc=*/0, /*act=*/NO_ACTIVATION, /*skips=*/skips_matmul); + +#if 0 // TODO: speed up mvout to SMEM + // loop_ws variant that skips configuring strides +#define gemmini_loop_ws(I, J, K, pad_I, pad_J, pad_K, A, B, D, C, A_stride, B_stride, D_stride, C_stride, A_transpose, B_transpose, full_C, low_D, ex_accumulate, act, a_spad_id, b_spad_id, is_resadd) \ + { \ + ROCC_INSTRUCTION_RS1_RS2(XCUSTOM_ACC, ((uint64_t)(pad_K) << 32) | ((uint64_t)(pad_J) << 16) | (uint64_t)(pad_I), ((uint64_t)(K) << 32) | ((uint64_t)(J) << 16) | (uint64_t)(I), k_LOOP_WS_CONFIG_BOUNDS) \ + ROCC_INSTRUCTION_RS1_RS2(XCUSTOM_ACC, A, B, k_LOOP_WS_CONFIG_ADDRS_AB) \ + ROCC_INSTRUCTION_RS1_RS2(XCUSTOM_ACC, D, C, k_LOOP_WS_CONFIG_ADDRS_DC) \ + ROCC_INSTRUCTION_RS1_RS2(XCUSTOM_ACC, A_stride, B_stride, k_LOOP_WS_CONFIG_STRIDES_AB) \ + ROCC_INSTRUCTION_RS1_RS2(XCUSTOM_ACC, D_stride, C_stride, k_LOOP_WS_CONFIG_STRIDES_DC) \ + ROCC_INSTRUCTION_RS1_RS2(XCUSTOM_ACC, ((uint64_t)(a_spad_id) << 18) | ((uint64_t)(b_spad_id) << 16) | ((uint64_t)(act) << 8) | ((low_D) << 2) | ((full_C) << 1) | (ex_accumulate), ((is_resadd) << 2) | ((B_transpose) << 1) | (A_transpose), k_LOOP_WS) \ + } +#endif + } + + // reconverge from mmio divergence + threadblock_barrier(warpgroup_id_in_cluster, warps_per_warpgroup_per_core); + + asm volatile("gemm_qk_finish_%=:" ::); + } + + if (tile_k >= 1) // delay by 1 iters for pipelining + { + const uint32_t tile_k_ = tile_k - 1; // Oi rescale thread_block_O_rescale( @@ -599,6 +567,46 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { } } + // fence GEMM I after Oi rescale + if (tid_in_warpgroup == 0) { + gemmini_fence(); + gemmini_fence(); + gemmini_fence(); + gemmini_fence(); + + // mvout to SMEM + // GEMMINI_CISC_CMD_I(9); + sp_tiled_matmul_full_spad_ws( + /*spad_A=*/spad_addr_Q /*bogus*/, + /*spad_B=*/spad_addr_K_consume /*bogus*/, + /*spad_D=*/0, /*spad_C=*/spad_addr_S_produce, + /*I=*/(B_ROW / DIM), /*J=*/(B_COL / DIM), /*K=*/(HEADDIM / DIM), + /*pad_I=*/0, /*pad_J=*/0, /*pad_K=*/0, + /*a_transpose=*/0, /*b_transpose=*/0, /*full_C=*/0, /*low_D=*/0, + /*acc=*/0, /*act=*/NO_ACTIVATION, /*skips=*/skips_mvout_spad); + + } + + // reconverge from mmio divergence + threadblock_barrier(warpgroup_id_in_cluster, warps_per_warpgroup_per_core); + + if constexpr (DEBUG) { + if (warpgroup_id == 0) { + if (tile_k == 0) { + thread_block_copy_tile( + smem_S_produce, gmem_tmp_d0, tid_in_warpgroup, + threads_per_warpgroup, warpgroup_id_in_cluster); + } else if (tile_k == 1) { + thread_block_copy_tile( + smem_S_produce, gmem_tmp_d1, tid_in_warpgroup, + threads_per_warpgroup, warpgroup_id_in_cluster); + } + + threadblock_barrier(warpgroup_id_in_cluster, + warps_per_warpgroup_per_core); + } + } + // data move for K and V // // Q stays in SMEM for the entire loop @@ -616,6 +624,9 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { // iterations later const float *gmem_V_tile = gmem_V + (HEADDIM * B_COL * (tile_k - 1 /*dragbehind*/)); + + // fence mvout S to SMEM + gemmini_fence(); ROCC_INSTRUCTION_RS1_RS2(XCUSTOM_ACC, (uint64_t)(gmem_K_tile), (uint64_t)(gmem_V_tile), k_LOOP_WS_CONFIG_ADDRS_AB)