diff --git a/tests/regression/flash_attention/kernel.gemmini.cpp b/tests/regression/flash_attention/kernel.gemmini.cpp index e934a1e1..ef16a6ee 100644 --- a/tests/regression/flash_attention/kernel.gemmini.cpp +++ b/tests/regression/flash_attention/kernel.gemmini.cpp @@ -199,9 +199,15 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { constexpr uint32_t skips_only_a = loop_matmul_skips(/*skip_lda=*/0, /*skip_ldb=*/1, /*skip_ldd=*/1, /*skip_ex=*/1, /*skip_stc=*/1); + constexpr uint32_t skips_only_b = + loop_matmul_skips(/*skip_lda=*/1, /*skip_ldb=*/0, /*skip_ldd=*/1, + /*skip_ex=*/1, /*skip_stc=*/1); constexpr uint32_t skips_mvout_spad = loop_matmul_skips(/*skip_lda=*/1, /*skip_ldb=*/1, /*skip_ldd=*/1, /*skip_ex=*/1, /*skip_stc=*/0); + constexpr uint32_t skips_matmul = + loop_matmul_skips(/*skip_lda=*/1, /*skip_ldb=*/1, /*skip_ldd=*/1, + /*skip_ex=*/0, /*skip_stc=*/1); constexpr uint32_t skips_matmul_preload = loop_matmul_skips(/*skip_lda=*/1, /*skip_ldb=*/1, /*skip_ldd=*/0, /*skip_ex=*/0, /*skip_stc=*/1); @@ -231,12 +237,11 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { // other warps behave differently on the branch condition. // threadblock_barrier(warpgroup_id_in_cluster, warps_per_warpgroup_per_core); - // move Q and K into SMEM before the loop starts - // static_assert(B_ROW == B_COL, "currently only supports square tiles"); + // move Q and K into SMEM before the loop starts + // asm volatile("dma_move_start_%=:" ::); - if (tid_in_warpgroup == 0) { // make sure to read from the correct row of Q const float *gmem_Q_tile = gmem_Q + HEADDIM * B_ROW * warpgroup_id; @@ -249,35 +254,48 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { 8 /*k_LOOP_WS_CONFIG_STRIDES_AB*/); gemmini_fence(); -#define GEMMINI_DMA_CISC +// #define GEMMINI_DMA_CISC #ifdef GEMMINI_DMA_CISC // the target addresses of this should match with spad_addr_Q0 and // spad_addr_K0 set in this kernel GEMMINI_CISC_CMD_I(10); - gemmini_fence(); - - // need to also move to Q1 for the next iteration - ROCC_INSTRUCTION_RS1_RS2(XCUSTOM_ACC, (uint64_t)(gmem_Q_tile), - (uint64_t)(gmem_K_tile), k_LOOP_WS_CONFIG_ADDRS_AB) - GEMMINI_CISC_CMD_R((dim_seqlen << 20) | (HEADDIM << 8) | - 8 /*k_LOOP_WS_CONFIG_STRIDES_AB*/); - gemmini_fence(); - GEMMINI_CISC_CMD_I(11); - gemmini_fence(); #else // do DMA // // among other things, this also configures CONFIG_BOUNDS so that the // DMA knows the full matrix dimensions sp_tiled_matmul_full_spad_ws( - spad_addr_Q, spad_addr_K, - /*spad_D=*/0, /*spad_C=*/spad_addr_S, + spad_addr_Q0, spad_addr_K0, + /*spad_D=*/0, /*spad_C=*/spad_addr_S0/*bogus*/, /*I=*/(B_ROW / DIM), /*J=*/(B_COL / DIM), /*K=*/(HEADDIM / DIM), /*pad_I=*/0, /*pad_J=*/0, /*pad_K=*/0, /*a_transpose=*/0, /*b_transpose=*/0, /*full_C=*/0, /*low_D=*/0, /*acc=*/0, /*act=*/NO_ACTIVATION, /*skips=*/skips); - gemmini_fence(); #endif + gemmini_fence(); + + // need to also move Q to spad_addr_Q1 for the next iteration + // FIXME: re-configure necessary? + gmem_K_tile = gmem_K + (B_COL * 1); + ROCC_INSTRUCTION_RS1_RS2(XCUSTOM_ACC, (uint64_t)(gmem_Q_tile), + (uint64_t)(gmem_K_tile), k_LOOP_WS_CONFIG_ADDRS_AB) + GEMMINI_CISC_CMD_R((dim_seqlen << 20) | (HEADDIM << 8) | + 8 /*k_LOOP_WS_CONFIG_STRIDES_AB*/); + gemmini_fence(); +#ifdef GEMMINI_DMA_CISC + // GEMMINI_CISC_CMD_I(11); +#else + sp_tiled_matmul_full_spad_ws( + spad_addr_Q1, spad_addr_K1/*bogus*/, + /*spad_D=*/0, /*spad_C=*/spad_addr_S0/*bogus*/, + /*I=*/(B_ROW / DIM), /*J=*/(B_COL / DIM), /*K=*/(HEADDIM / DIM), + /*pad_I=*/0, /*pad_J=*/0, /*pad_K=*/0, + /*a_transpose=*/0, /*b_transpose=*/0, /*full_C=*/0, /*low_D=*/0, + /*acc=*/0, /*act=*/NO_ACTIVATION, /*skips=*/skips_only_a); +#endif + + gemmini_fence(); + gemmini_fence(); // re-configure DMA for K and V load that will later happen in the loop // GMEM addr stride for K @@ -376,6 +394,7 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { // do matmul // among other things, this also configures CONFIG_BOUNDS so that the // DMA knows the full matrix dimensions + gemmini_fence(); sp_tiled_matmul_full_spad_ws( spad_addr_P_consume, spad_addr_V_consume, /*spad_D=*/spad_addr_O, /*spad_C=*/spad_addr_O, @@ -437,11 +456,18 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { asm volatile("gemm_qk_start_%=:" ::); if (tid_in_warpgroup == 0) { + gemmini_fence(); // 0,2,.: opcode 0 (quartile 0/2, no accum) // 1,3,.: opcode 3 (quartile 1/3, no accum) const uint32_t opcode = 3 * (tile_k & 1); - gemmini_fence(); - GEMMINI_CISC_CMD_I(opcode); + //GEMMINI_CISC_CMD_I(opcode); + sp_tiled_matmul_full_spad_ws( + spad_addr_Q, spad_addr_K_consume, + /*spad_D=*/0, /*spad_C=*/spad_addr_S_produce, + /*I=*/(B_ROW / DIM), /*J=*/(B_COL / DIM), /*K=*/(HEADDIM / DIM), + /*pad_I=*/0, /*pad_J=*/0, /*pad_K=*/0, + /*a_transpose=*/0, /*b_transpose=*/0, /*full_C=*/0, /*low_D=*/0, + /*acc=*/0, /*act=*/NO_ACTIVATION, /*skips=*/skips_matmul); gemmini_fence(); gemmini_fence(); @@ -574,7 +600,7 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { } } - // data movement for K and V + // data move for K and V // // Q stays in SMEM for the entire loop asm volatile("move_k_v_start_%=:" ::); @@ -606,7 +632,7 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { sp_tiled_matmul_full_spad_ws( spad_addr_K_produce, spad_addr_V_produce, /*spad_D=*/0, /*spad_C=*/spad_addr_S_produce /*FIXME:bogus*/, - /*I=*/(HEADDIM / DIM), /*J=*/(HEADDIM / DIM), /*K=*/(B_COL / DIM), + /*I=*/(B_ROW / DIM), /*J=*/(HEADDIM / DIM), /*K=*/(B_COL / DIM), /*pad_I=*/0, /*pad_J=*/0, /*pad_K=*/0, /*a_transpose=*/0, /*b_transpose=*/0, /*full_C=*/0, /*low_D=*/0, /*acc=*/0, /*act=*/NO_ACTIVATION, /*skips=*/skips_only_a); @@ -614,12 +640,14 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { sp_tiled_matmul_full_spad_ws( spad_addr_K_produce, spad_addr_V_produce, /*spad_D=*/0, /*spad_C=*/spad_addr_S_produce /*FIXME:bogus*/, - /*I=*/(HEADDIM / DIM), /*J=*/(HEADDIM / DIM), /*K=*/(B_COL / DIM), + /*I=*/(B_ROW / DIM), /*J=*/(HEADDIM / DIM), /*K=*/(B_COL / DIM), /*pad_I=*/0, /*pad_J=*/0, /*pad_K=*/0, /*a_transpose=*/0, /*b_transpose=*/0, /*full_C=*/0, /*low_D=*/0, /*acc=*/0, /*act=*/NO_ACTIVATION, /*skips=*/skips); } gemmini_fence(); + gemmini_fence(); + gemmini_fence(); } threadblock_barrier(warpgroup_id_in_cluster, warps_per_warpgroup_per_core);