diff --git a/src/backend/RISCv64/Optimize/Peephole.cpp b/src/backend/RISCv64/Optimize/Peephole.cpp index 936af92..e7ac927 100644 --- a/src/backend/RISCv64/Optimize/Peephole.cpp +++ b/src/backend/RISCv64/Optimize/Peephole.cpp @@ -4,6 +4,7 @@ namespace sysy { char PeepholeOptimizer::ID = 0; +bool PeepholeOptimizer::fusedMulAddEnabled = true; // 默认启用浮点乘加融合优化 bool PeepholeOptimizer::runOnFunction(Function *F, AnalysisManager& AM) { // This pass works on MachineFunction level, not IR level @@ -634,6 +635,99 @@ void PeepholeOptimizer::runOnMachineFunction(MachineFunction *mfunc) { } } } + // 8. 浮点乘加融合优化 + // 8.1 fmul.s t1, t2, t3; fadd.s t4, t1, t5 -> fmadd.s t4, t2, t3, t5 + else if (isFusedMulAddEnabled() && + mi1->getOpcode() == RVOpcodes::FMUL_S && + mi2->getOpcode() == RVOpcodes::FADD_S) { + if (mi1->getOperands().size() == 3 && mi2->getOperands().size() == 3) { + auto *fmul_dst = static_cast(mi1->getOperands()[0].get()); + auto *fmul_src1 = static_cast(mi1->getOperands()[1].get()); + auto *fmul_src2 = static_cast(mi1->getOperands()[2].get()); + + auto *fadd_dst = static_cast(mi2->getOperands()[0].get()); + auto *fadd_src1 = static_cast(mi2->getOperands()[1].get()); + auto *fadd_src2 = static_cast(mi2->getOperands()[2].get()); + + // 检查fmul的目标是否是fadd的第一个源操作数 + if (areRegsEqual(fmul_dst, fadd_src1)) { + // 检查中间寄存器是否在后续还会被使用 + bool canOptimize = true; + for (size_t j = i + 2; j < instrs.size(); ++j) { + auto *later_instr = instrs[j].get(); + + // 如果中间寄存器被重新定义,则可以优化 + if (isRegRedefinedAt(later_instr, fmul_dst, areRegsEqual)) { + break; + } + + // 如果中间寄存器被使用,则不能优化 + if (isRegUsedLater(instrs, fmul_dst, j)) { + canOptimize = false; + break; + } + } + + if (canOptimize) { + // 创建新的FMADD_S指令: fmadd.s t4, t2, t3, t5 + auto newInstr = std::make_unique(RVOpcodes::FMADD_S); + newInstr->addOperand(std::make_unique(*fadd_dst)); + newInstr->addOperand(std::make_unique(*fmul_src1)); + newInstr->addOperand(std::make_unique(*fmul_src2)); + newInstr->addOperand(std::make_unique(*fadd_src2)); + instrs[i + 1] = std::move(newInstr); + instrs.erase(instrs.begin() + i); + changed = true; + } + } + } + } + // 8.2 fmul.s t1, t2, t3; fadd.s t4, t5, t1 -> fmadd.s t4, t2, t3, t5 + else if (isFusedMulAddEnabled() && + mi1->getOpcode() == RVOpcodes::FMUL_S && + mi2->getOpcode() == RVOpcodes::FADD_S) { + if (mi1->getOperands().size() == 3 && mi2->getOperands().size() == 3) { + auto *fmul_dst = static_cast(mi1->getOperands()[0].get()); + auto *fmul_src1 = static_cast(mi1->getOperands()[1].get()); + auto *fmul_src2 = static_cast(mi1->getOperands()[2].get()); + + auto *fadd_dst = static_cast(mi2->getOperands()[0].get()); + auto *fadd_src1 = static_cast(mi2->getOperands()[1].get()); + auto *fadd_src2 = static_cast(mi2->getOperands()[2].get()); + + // 检查fmul的目标是否是fadd的第二个源操作数 + if (areRegsEqual(fmul_dst, fadd_src2)) { + // 检查中间寄存器是否在后续还会被使用 + bool canOptimize = true; + for (size_t j = i + 2; j < instrs.size(); ++j) { + auto *later_instr = instrs[j].get(); + + // 如果中间寄存器被重新定义,则可以优化 + if (isRegRedefinedAt(later_instr, fmul_dst, areRegsEqual)) { + break; + } + + // 如果中间寄存器被使用,则不能优化 + if (isRegUsedLater(instrs, fmul_dst, j)) { + canOptimize = false; + break; + } + } + + if (canOptimize) { + // 创建新的FMADD_S指令: fmadd.s t4, t2, t3, t5 + auto newInstr = std::make_unique(RVOpcodes::FMADD_S); + newInstr->addOperand(std::make_unique(*fadd_dst)); + newInstr->addOperand(std::make_unique(*fmul_src1)); + newInstr->addOperand(std::make_unique(*fmul_src2)); + newInstr->addOperand(std::make_unique(*fadd_src1)); + instrs[i + 1] = std::move(newInstr); + instrs.erase(instrs.begin() + i); + changed = true; + } + } + } + } // 根据是否发生变化调整遍历索引 if (!changed) { diff --git a/src/backend/RISCv64/RISCv64AsmPrinter.cpp b/src/backend/RISCv64/RISCv64AsmPrinter.cpp index a9b5146..fd51d4e 100644 --- a/src/backend/RISCv64/RISCv64AsmPrinter.cpp +++ b/src/backend/RISCv64/RISCv64AsmPrinter.cpp @@ -79,6 +79,7 @@ void RISCv64AsmPrinter::printInstruction(MachineInstr* instr, bool debug) { case RVOpcodes::FSUB_S: *OS << "fsub.s "; break; case RVOpcodes::FMUL_S: *OS << "fmul.s "; break; case RVOpcodes::FDIV_S: *OS << "fdiv.s "; break; + case RVOpcodes::FMADD_S: *OS << "fmadd.s "; break; case RVOpcodes::FNEG_S: *OS << "fneg.s "; break; case RVOpcodes::FEQ_S: *OS << "feq.s "; break; case RVOpcodes::FLT_S: *OS << "flt.s "; break; diff --git a/src/include/backend/RISCv64/Optimize/Peephole.h b/src/include/backend/RISCv64/Optimize/Peephole.h index 8b98fd5..d34a875 100644 --- a/src/include/backend/RISCv64/Optimize/Peephole.h +++ b/src/include/backend/RISCv64/Optimize/Peephole.h @@ -23,6 +23,21 @@ public: bool runOnFunction(Function *F, AnalysisManager& AM) override; void runOnMachineFunction(MachineFunction* mfunc); + + /** + * @brief 设置是否启用浮点乘加融合优化 + * @param enabled 是否启用 + */ + static void setFusedMulAddEnabled(bool enabled) { fusedMulAddEnabled = enabled; } + + /** + * @brief 检查是否启用了浮点乘加融合优化 + * @return 是否启用 + */ + static bool isFusedMulAddEnabled() { return fusedMulAddEnabled; } + +private: + static bool fusedMulAddEnabled; // 浮点乘加融合优化开关 }; } // namespace sysy diff --git a/src/include/backend/RISCv64/RISCv64Info.h b/src/include/backend/RISCv64/RISCv64Info.h index ae8114d..603341d 100644 --- a/src/include/backend/RISCv64/RISCv64Info.h +++ b/src/include/backend/RISCv64/RISCv64Info.h @@ -85,6 +85,7 @@ static const std::map, std::vector>> // --- 浮点指令 --- {RVOpcodes::FADD_S, {{0}, {1, 2}}}, {RVOpcodes::FSUB_S, {{0}, {1, 2}}}, {RVOpcodes::FMUL_S, {{0}, {1, 2}}}, {RVOpcodes::FDIV_S, {{0}, {1, 2}}}, + {RVOpcodes::FMADD_S, {{0}, {1, 2, 3}}}, {RVOpcodes::FEQ_S, {{0}, {1, 2}}}, {RVOpcodes::FLT_S, {{0}, {1, 2}}}, {RVOpcodes::FLE_S, {{0}, {1, 2}}}, {RVOpcodes::FCVT_S_W, {{0}, {1}}}, {RVOpcodes::FCVT_W_S, {{0}, {1}}}, {RVOpcodes::FCVT_W_S_RTZ, {{0}, {1}}}, diff --git a/src/include/backend/RISCv64/RISCv64LLIR.h b/src/include/backend/RISCv64/RISCv64LLIR.h index 068fcac..b021f04 100644 --- a/src/include/backend/RISCv64/RISCv64LLIR.h +++ b/src/include/backend/RISCv64/RISCv64LLIR.h @@ -79,6 +79,7 @@ enum class RVOpcodes { FSUB_S, // fsub.s rd, rs1, rs2 FMUL_S, // fmul.s rd, rs1, rs2 FDIV_S, // fdiv.s rd, rs1, rs2 + FMADD_S, // fmadd.s rd, rs1, rs2, rs3 // 浮点比较 (单精度) FEQ_S, // feq.s rd, rs1, rs2 (结果写入整数寄存器rd)