From f312792fe9aec24ff3daf567012eac1e2109ee01 Mon Sep 17 00:00:00 2001 From: CGH0S7 <776459475@qq.com> Date: Sun, 3 Aug 2025 13:46:42 +0800 Subject: [PATCH 1/3] =?UTF-8?q?[optimze]=E6=B7=BB=E5=8A=A0=E5=9F=BA?= =?UTF-8?q?=E7=A1=80=E7=9A=84=E9=99=A4=E6=B3=95=E6=8C=87=E4=BB=A4=E4=BC=98?= =?UTF-8?q?=E5=8C=96=EF=BC=8C=E7=9B=AE=E5=89=8D=E5=8F=AA=E5=AF=B9=E9=99=A4?= =?UTF-8?q?=E4=BB=A52=E7=9A=84=E5=B9=82=E6=95=B0=E7=94=9F=E6=95=88?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/backend/RISCv64/CMakeLists.txt | 1 + .../RISCv64/Optimize/DivStrengthReduction.cpp | 329 ++++++++++++++++++ src/backend/RISCv64/RISCv64Backend.cpp | 6 +- src/backend/RISCv64/RISCv64ISel.cpp | 11 +- .../RISCv64/Optimize/DivStrengthReduction.h | 30 ++ src/include/backend/RISCv64/RISCv64Passes.h | 2 + src/include/midend/IR.h | 18 +- src/include/midend/IRBuilder.h | 3 + src/midend/SysYIRGenerator.cpp | 21 +- src/midend/SysYIRPrinter.cpp | 4 + 10 files changed, 419 insertions(+), 6 deletions(-) create mode 100644 src/backend/RISCv64/Optimize/DivStrengthReduction.cpp create mode 100644 src/include/backend/RISCv64/Optimize/DivStrengthReduction.h diff --git a/src/backend/RISCv64/CMakeLists.txt b/src/backend/RISCv64/CMakeLists.txt index eb9f37f..f1e8f55 100644 --- a/src/backend/RISCv64/CMakeLists.txt +++ b/src/backend/RISCv64/CMakeLists.txt @@ -11,6 +11,7 @@ add_library(riscv64_backend_lib STATIC Optimize/Peephole.cpp Optimize/PostRA_Scheduler.cpp Optimize/PreRA_Scheduler.cpp + Optimize/DivStrengthReduction.cpp ) # 包含后端模块所需的头文件路径 diff --git a/src/backend/RISCv64/Optimize/DivStrengthReduction.cpp b/src/backend/RISCv64/Optimize/DivStrengthReduction.cpp new file mode 100644 index 0000000..c052c5d --- /dev/null +++ b/src/backend/RISCv64/Optimize/DivStrengthReduction.cpp @@ -0,0 +1,329 @@ +#include "DivStrengthReduction.h" + +namespace sysy { + +char DivStrengthReduction::ID = 0; + +bool DivStrengthReduction::runOnFunction(Function *F, AnalysisManager& AM) { + 
// This pass works on MachineFunction level, not IR level + return false; +} + +void DivStrengthReduction::runOnMachineFunction(MachineFunction *mfunc) { + if (!mfunc) + return; + + bool debug = false; // Set to true for debugging + if (debug) + std::cout << "Running DivStrengthReduction optimization..." << std::endl; + + // 虚拟寄存器分配器 + int next_temp_reg = 1000; + auto createTempReg = [&]() -> int { + return next_temp_reg++; + }; + + // Magic number 信息结构 + struct MagicInfo { + int64_t magic; + int shift; + bool add_indicator; // 是否需要额外的加法修正 + }; + + // 针对缺少MULH指令的简化magic number计算 + auto computeMagicNumber = [](int64_t divisor, bool is_32bit) -> MagicInfo { + if (divisor == 0) return {0, 0, false}; + if (divisor == 1) return {1, 0, false}; + if (divisor == -1) return {-1, 0, false}; + + // 对于没有MULH的情况,我们使用更简单但有效的算法 + // 基于 2^n / divisor 的近似 + + bool neg = divisor < 0; + int64_t d = neg ? -divisor : divisor; + + int word_size = is_32bit ? 32 : 64; + + // 计算合适的移位量 + int shift = word_size; + int64_t magic = ((1LL << shift) + d - 1) / d; + + // 调整magic number以适应MUL指令 + if (is_32bit) { + // 32位情况:调整magic使其适合符号扩展后的乘法 + shift = 32; + magic = ((1LL << shift) + d - 1) / d; + } else { + // 64位情况:使用更保守的算法 + shift = 32; // 使用32位作为基础移位 + magic = ((1LL << shift) + d - 1) / d; + } + + bool add_indicator = false; + + // 检查是否需要加法修正 + if (magic >= (1LL << (word_size - 1))) { + add_indicator = true; + magic -= (1LL << word_size); + } + + if (neg) { + magic = -magic; + } + + return {magic, shift, add_indicator}; + }; + + // 检查是否为2的幂次 + auto isPowerOfTwo = [](int64_t n) -> bool { + return n > 0 && (n & (n - 1)) == 0; + }; + + // 获取2的幂次的指数 + auto getPowerOfTwoExponent = [](int64_t n) -> int { + if (n <= 0 || (n & (n - 1)) != 0) return -1; + int shift = 0; + while (n > 1) { + n >>= 1; + shift++; + } + return shift; + }; + + // 收集需要替换的指令 + struct InstructionReplacement { + size_t index; + std::vector> newInstrs; + }; + + for (auto &mbb_uptr : mfunc->getBlocks()) { + auto &mbb = *mbb_uptr; + 
auto &instrs = mbb.getInstructions(); + std::vector replacements; + + for (size_t i = 0; i < instrs.size(); ++i) { + auto *instr = instrs[i].get(); + + bool is_32bit = (instr->getOpcode() == RVOpcodes::DIVW); + + // 只处理 DIV 和 DIVW 指令 + if (instr->getOpcode() != RVOpcodes::DIV && !is_32bit) { + continue; + } + + if (instr->getOperands().size() != 3) { + continue; + } + + auto *dst_op = instr->getOperands()[0].get(); + auto *src1_op = instr->getOperands()[1].get(); + auto *src2_op = instr->getOperands()[2].get(); + + // 检查操作数类型 + if (dst_op->getKind() != MachineOperand::KIND_REG || + src1_op->getKind() != MachineOperand::KIND_REG || + src2_op->getKind() != MachineOperand::KIND_IMM) { + continue; + } + + auto *dst_reg = static_cast(dst_op); + auto *src1_reg = static_cast(src1_op); + auto *src2_imm = static_cast(src2_op); + + int64_t divisor = src2_imm->getValue(); + + // 跳过除数为0的情况 + if (divisor == 0) continue; + + std::vector> newInstrs; + + // 情况1: 除数为1 + if (divisor == 1) { + // dst = src1 (直接复制) + auto moveInstr = std::make_unique(is_32bit ? RVOpcodes::ADDW : RVOpcodes::ADD); + moveInstr->addOperand(std::make_unique(*dst_reg)); + moveInstr->addOperand(std::make_unique(*src1_reg)); + moveInstr->addOperand(std::make_unique(PhysicalReg::ZERO)); + newInstrs.push_back(std::move(moveInstr)); + } + // 情况2: 除数为-1 + else if (divisor == -1) { + // dst = -src1 + auto negInstr = std::make_unique(is_32bit ? RVOpcodes::SUBW : RVOpcodes::SUB); + negInstr->addOperand(std::make_unique(*dst_reg)); + negInstr->addOperand(std::make_unique(PhysicalReg::ZERO)); + negInstr->addOperand(std::make_unique(*src1_reg)); + newInstrs.push_back(std::move(negInstr)); + } + // 情况3: 正的2的幂次除法 + else if (isPowerOfTwo(divisor)) { + int shift = getPowerOfTwoExponent(divisor); + int temp_reg = createTempReg(); + + // 对于有符号除法,需要处理负数的舍入 + // if (src1 < 0) src1 += (divisor - 1) + + // 获取符号位:temp = src1 >> (word_size - 1) + auto sraSignInstr = std::make_unique(is_32bit ? 
RVOpcodes::SRAIW : RVOpcodes::SRAI); + sraSignInstr->addOperand(std::make_unique(temp_reg)); + sraSignInstr->addOperand(std::make_unique(*src1_reg)); + sraSignInstr->addOperand(std::make_unique(is_32bit ? 31 : 63)); + newInstrs.push_back(std::move(sraSignInstr)); + + // 计算偏移:temp = temp >> (word_size - shift) + if (shift < (is_32bit ? 32 : 64)) { + auto srlInstr = std::make_unique(is_32bit ? RVOpcodes::SRLIW : RVOpcodes::SRLI); + srlInstr->addOperand(std::make_unique(temp_reg)); + srlInstr->addOperand(std::make_unique(temp_reg)); + srlInstr->addOperand(std::make_unique((is_32bit ? 32 : 64) - shift)); + newInstrs.push_back(std::move(srlInstr)); + } + + // 加上偏移:temp = src1 + temp + auto addInstr = std::make_unique(is_32bit ? RVOpcodes::ADDW : RVOpcodes::ADD); + addInstr->addOperand(std::make_unique(temp_reg)); + addInstr->addOperand(std::make_unique(*src1_reg)); + addInstr->addOperand(std::make_unique(temp_reg)); + newInstrs.push_back(std::move(addInstr)); + + // 最终右移:dst = temp >> shift + auto sraInstr = std::make_unique(is_32bit ? RVOpcodes::SRAIW : RVOpcodes::SRAI); + sraInstr->addOperand(std::make_unique(*dst_reg)); + sraInstr->addOperand(std::make_unique(temp_reg)); + sraInstr->addOperand(std::make_unique(shift)); + newInstrs.push_back(std::move(sraInstr)); + } + // 情况4: 负的2的幂次除法 + else if (divisor < 0 && isPowerOfTwo(-divisor)) { + int shift = getPowerOfTwoExponent(-divisor); + int temp_reg = createTempReg(); + + // 先按正数处理 + auto sraSignInstr = std::make_unique(is_32bit ? RVOpcodes::SRAIW : RVOpcodes::SRAI); + sraSignInstr->addOperand(std::make_unique(temp_reg)); + sraSignInstr->addOperand(std::make_unique(*src1_reg)); + sraSignInstr->addOperand(std::make_unique(is_32bit ? 31 : 63)); + newInstrs.push_back(std::move(sraSignInstr)); + + if (shift < (is_32bit ? 32 : 64)) { + auto srlInstr = std::make_unique(is_32bit ? 
RVOpcodes::SRLIW : RVOpcodes::SRLI); + srlInstr->addOperand(std::make_unique(temp_reg)); + srlInstr->addOperand(std::make_unique(temp_reg)); + srlInstr->addOperand(std::make_unique((is_32bit ? 32 : 64) - shift)); + newInstrs.push_back(std::move(srlInstr)); + } + + auto addInstr = std::make_unique(is_32bit ? RVOpcodes::ADDW : RVOpcodes::ADD); + addInstr->addOperand(std::make_unique(temp_reg)); + addInstr->addOperand(std::make_unique(*src1_reg)); + addInstr->addOperand(std::make_unique(temp_reg)); + newInstrs.push_back(std::move(addInstr)); + + auto sraInstr = std::make_unique(is_32bit ? RVOpcodes::SRAIW : RVOpcodes::SRAI); + sraInstr->addOperand(std::make_unique(temp_reg)); + sraInstr->addOperand(std::make_unique(temp_reg)); + sraInstr->addOperand(std::make_unique(shift)); + newInstrs.push_back(std::move(sraInstr)); + + // 然后取反 + auto negInstr = std::make_unique(is_32bit ? RVOpcodes::SUBW : RVOpcodes::SUB); + negInstr->addOperand(std::make_unique(*dst_reg)); + negInstr->addOperand(std::make_unique(PhysicalReg::ZERO)); + negInstr->addOperand(std::make_unique(temp_reg)); + newInstrs.push_back(std::move(negInstr)); + } + // 情况5: 通用magic number算法(针对没有MULH的情况进行了简化) + else { + // 对于一般除法,在没有MULH的情况下,我们采用更保守的策略 + // 只处理一些简单的常数除法,复杂的情况保持原始除法指令 + + // 检查是否为小的常数(可以用简单乘法处理) + if (std::abs(divisor) <= 1024) { // 限制在较小的除数范围内 + auto magic_info = computeMagicNumber(divisor, is_32bit); + + if (magic_info.magic == 0) continue; + + int magic_reg = createTempReg(); + int temp_reg = createTempReg(); + + // 加载magic number到寄存器 + auto loadInstr = std::make_unique(RVOpcodes::LI); + loadInstr->addOperand(std::make_unique(magic_reg)); + loadInstr->addOperand(std::make_unique(magic_info.magic)); + newInstrs.push_back(std::move(loadInstr)); + + // 使用普通乘法模拟高位乘法 + if (is_32bit) { + // 32位:使用MULW + auto mulInstr = std::make_unique(RVOpcodes::MULW); + mulInstr->addOperand(std::make_unique(temp_reg)); + mulInstr->addOperand(std::make_unique(*src1_reg)); + 
mulInstr->addOperand(std::make_unique(magic_reg)); + newInstrs.push_back(std::move(mulInstr)); + + // 右移得到近似结果 + auto sraInstr = std::make_unique(RVOpcodes::SRAIW); + sraInstr->addOperand(std::make_unique(temp_reg)); + sraInstr->addOperand(std::make_unique(temp_reg)); + sraInstr->addOperand(std::make_unique(magic_info.shift)); + newInstrs.push_back(std::move(sraInstr)); + } else { + // 64位:使用MUL + auto mulInstr = std::make_unique(RVOpcodes::MUL); + mulInstr->addOperand(std::make_unique(temp_reg)); + mulInstr->addOperand(std::make_unique(*src1_reg)); + mulInstr->addOperand(std::make_unique(magic_reg)); + newInstrs.push_back(std::move(mulInstr)); + + // 右移得到近似结果 + auto sraInstr = std::make_unique(RVOpcodes::SRAI); + sraInstr->addOperand(std::make_unique(temp_reg)); + sraInstr->addOperand(std::make_unique(temp_reg)); + sraInstr->addOperand(std::make_unique(magic_info.shift)); + newInstrs.push_back(std::move(sraInstr)); + } + + // 符号修正:处理负数被除数 + int sign_reg = createTempReg(); + + // 获取被除数的符号位 + auto sraSignInstr = std::make_unique(is_32bit ? RVOpcodes::SRAIW : RVOpcodes::SRAI); + sraSignInstr->addOperand(std::make_unique(sign_reg)); + sraSignInstr->addOperand(std::make_unique(*src1_reg)); + sraSignInstr->addOperand(std::make_unique(is_32bit ? 31 : 63)); + newInstrs.push_back(std::move(sraSignInstr)); + + // 最终结果:dst = temp - sign(对于正除数)或 dst = temp + sign(对于负除数) + if (divisor > 0) { + auto finalSubInstr = std::make_unique(is_32bit ? RVOpcodes::SUBW : RVOpcodes::SUB); + finalSubInstr->addOperand(std::make_unique(*dst_reg)); + finalSubInstr->addOperand(std::make_unique(temp_reg)); + finalSubInstr->addOperand(std::make_unique(sign_reg)); + newInstrs.push_back(std::move(finalSubInstr)); + } else { + auto finalAddInstr = std::make_unique(is_32bit ? 
RVOpcodes::ADDW : RVOpcodes::ADD); + finalAddInstr->addOperand(std::make_unique(*dst_reg)); + finalAddInstr->addOperand(std::make_unique(temp_reg)); + finalAddInstr->addOperand(std::make_unique(sign_reg)); + newInstrs.push_back(std::move(finalAddInstr)); + } + } + // 对于大的除数或复杂情况,保持原始除法指令不变 + } + + if (!newInstrs.empty()) { + replacements.push_back({i, std::move(newInstrs)}); + } + } + + // 批量应用替换(从后往前处理避免索引问题) + for (auto it = replacements.rbegin(); it != replacements.rend(); ++it) { + instrs.erase(instrs.begin() + it->index); + instrs.insert(instrs.begin() + it->index, + std::make_move_iterator(it->newInstrs.begin()), + std::make_move_iterator(it->newInstrs.end())); + } + } +} + +} // namespace sysy \ No newline at end of file diff --git a/src/backend/RISCv64/RISCv64Backend.cpp b/src/backend/RISCv64/RISCv64Backend.cpp index 2797eb7..ae45d65 100644 --- a/src/backend/RISCv64/RISCv64Backend.cpp +++ b/src/backend/RISCv64/RISCv64Backend.cpp @@ -119,7 +119,11 @@ std::string RISCv64CodeGen::function_gen(Function* func) { RISCv64AsmPrinter printer1(mfunc.get()); printer1.run(ss1, true); - // 阶段 2: 指令调度 (Instruction Scheduling) + // 阶段 2: 除法强度削弱优化 (Division Strength Reduction) + DivStrengthReduction div_strength_reduction; + div_strength_reduction.runOnMachineFunction(mfunc.get()); + + // 阶段 2.1: 指令调度 (Instruction Scheduling) PreRA_Scheduler scheduler; scheduler.runOnMachineFunction(mfunc.get()); diff --git a/src/backend/RISCv64/RISCv64ISel.cpp b/src/backend/RISCv64/RISCv64ISel.cpp index 871a3e7..e6b0929 100644 --- a/src/backend/RISCv64/RISCv64ISel.cpp +++ b/src/backend/RISCv64/RISCv64ISel.cpp @@ -539,6 +539,15 @@ void RISCv64ISel::selectNode(DAGNode* node) { CurMBB->addInstruction(std::move(instr)); break; } + case Instruction::kSRA: { + auto rhs_const = dynamic_cast(rhs); + auto instr = std::make_unique(RVOpcodes::SRAIW); + instr->addOperand(std::make_unique(dest_vreg)); + instr->addOperand(std::make_unique(lhs_vreg)); + 
instr->addOperand(std::make_unique(rhs_const->getInt())); + CurMBB->addInstruction(std::move(instr)); + break; + } case BinaryInst::kICmpEQ: { // 等于 (a == b) -> (subw; seqz) auto sub = std::make_unique(RVOpcodes::SUBW); sub->addOperand(std::make_unique(dest_vreg)); @@ -1473,7 +1482,7 @@ std::vector> RISCv64ISel::build_dag(BasicB } } } - if (bin->getKind() >= Instruction::kFAdd) { // 假设浮点指令枚举值更大 + if (bin->isFPBinary()) { // 假设浮点指令枚举值更大 auto fbin_node = create_node(DAGNode::FBINARY, bin, value_to_node, nodes_storage); fbin_node->operands.push_back(get_operand_node(bin->getLhs(), value_to_node, nodes_storage)); fbin_node->operands.push_back(get_operand_node(bin->getRhs(), value_to_node, nodes_storage)); diff --git a/src/include/backend/RISCv64/Optimize/DivStrengthReduction.h b/src/include/backend/RISCv64/Optimize/DivStrengthReduction.h new file mode 100644 index 0000000..685bb19 --- /dev/null +++ b/src/include/backend/RISCv64/Optimize/DivStrengthReduction.h @@ -0,0 +1,30 @@ +#ifndef RISCV64_DIV_STRENGTH_REDUCTION_H +#define RISCV64_DIV_STRENGTH_REDUCTION_H + +#include "RISCv64LLIR.h" +#include "Pass.h" + +namespace sysy { + +/** + * @class DivStrengthReduction + * @brief 除法强度削弱优化器 + * * 将除法运算转换为乘法运算,使用magic number算法 + * 适用于除数为常数的情况,可以显著提高性能 + */ +class DivStrengthReduction : public Pass { +public: + static char ID; + + DivStrengthReduction() : Pass("div-strength-reduction", Granularity::Function, PassKind::Optimization) {} + + void *getPassID() const override { return &ID; } + + bool runOnFunction(Function *F, AnalysisManager& AM) override; + + void runOnMachineFunction(MachineFunction* mfunc); +}; + +} // namespace sysy + +#endif // RISCV64_DIV_STRENGTH_REDUCTION_H \ No newline at end of file diff --git a/src/include/backend/RISCv64/RISCv64Passes.h b/src/include/backend/RISCv64/RISCv64Passes.h index de08882..b456994 100644 --- a/src/include/backend/RISCv64/RISCv64Passes.h +++ b/src/include/backend/RISCv64/RISCv64Passes.h @@ -9,6 +9,8 @@ #include 
"LegalizeImmediates.h" #include "PrologueEpilogueInsertion.h" #include "Pass.h" +#include "DivStrengthReduction.h" + namespace sysy { diff --git a/src/include/midend/IR.h b/src/include/midend/IR.h index 901a3b7..2e4d72b 100644 --- a/src/include/midend/IR.h +++ b/src/include/midend/IR.h @@ -708,6 +708,8 @@ class Instruction : public User { kPhi = 0x1UL << 39, kBitItoF = 0x1UL << 40, kBitFtoI = 0x1UL << 41, + kSRA = 0x1UL << 42, + kMulh = 0x1UL << 43, }; protected: @@ -804,6 +806,12 @@ public: return "Memset"; case kPhi: return "Phi"; + case kBitItoF: + return "BitItoF"; + case kBitFtoI: + return "BitFtoI"; + case kSRA: + return "SRA"; default: return "Unknown"; } @@ -815,11 +823,15 @@ public: bool isBinary() const { static constexpr uint64_t BinaryOpMask = - (kAdd | kSub | kMul | kDiv | kRem | kAnd | kOr) | - (kICmpEQ | kICmpNE | kICmpLT | kICmpGT | kICmpLE | kICmpGE) | + (kAdd | kSub | kMul | kDiv | kRem | kAnd | kOr | kSRA) | + (kICmpEQ | kICmpNE | kICmpLT | kICmpGT | kICmpLE | kICmpGE); + return kind & BinaryOpMask; + } + bool isFPBinary() const { + static constexpr uint64_t FPBinaryOpMask = (kFAdd | kFSub | kFMul | kFDiv) | (kFCmpEQ | kFCmpNE | kFCmpLT | kFCmpGT | kFCmpLE | kFCmpGE); - return kind & BinaryOpMask; + return kind & FPBinaryOpMask; } bool isUnary() const { static constexpr uint64_t UnaryOpMask = diff --git a/src/include/midend/IRBuilder.h b/src/include/midend/IRBuilder.h index 73e40e5..c232578 100644 --- a/src/include/midend/IRBuilder.h +++ b/src/include/midend/IRBuilder.h @@ -217,6 +217,9 @@ class IRBuilder { BinaryInst * createOrInst(Value *lhs, Value *rhs, const std::string &name = "") { return createBinaryInst(Instruction::kOr, Type::getIntType(), lhs, rhs, name); } ///< 创建按位或指令 + BinaryInst * createSRAInst(Value *lhs, Value *rhs, const std::string &name = "") { + return createBinaryInst(Instruction::kSRA, Type::getIntType(), lhs, rhs, name); + } ///< 创建算术右移指令 CallInst * createCallInst(Function *callee, const std::vector &args, const std::string 
&name = "") { std::string newName; if (name.empty() && callee->getReturnType() != Type::getVoidType()) { diff --git a/src/midend/SysYIRGenerator.cpp b/src/midend/SysYIRGenerator.cpp index 812bd39..0b2d751 100644 --- a/src/midend/SysYIRGenerator.cpp +++ b/src/midend/SysYIRGenerator.cpp @@ -249,7 +249,26 @@ void SysYIRGenerator::compute() { case BinaryOp::ADD: resultValue = builder.createAddInst(lhs, rhs); break; case BinaryOp::SUB: resultValue = builder.createSubInst(lhs, rhs); break; case BinaryOp::MUL: resultValue = builder.createMulInst(lhs, rhs); break; - case BinaryOp::DIV: resultValue = builder.createDivInst(lhs, rhs); break; + case BinaryOp::DIV: { + ConstantInteger *rhsConst = dynamic_cast(rhs); + if (rhsConst) { + int divisor = rhsConst->getInt(); + if (divisor > 0 && (divisor & (divisor - 1)) == 0) { + int shift = 0; + int temp = divisor; + while (temp > 1) { + temp >>= 1; + shift++; + } + resultValue = builder.createSRAInst(lhs, ConstantInteger::get(shift)); + } else { + resultValue = builder.createDivInst(lhs, rhs); + } + } else { + resultValue = builder.createDivInst(lhs, rhs); + } + break; + } case BinaryOp::MOD: resultValue = builder.createRemInst(lhs, rhs); break; } } else if (commonType == Type::getFloatType()) { diff --git a/src/midend/SysYIRPrinter.cpp b/src/midend/SysYIRPrinter.cpp index 0a024d2..7891bca 100644 --- a/src/midend/SysYIRPrinter.cpp +++ b/src/midend/SysYIRPrinter.cpp @@ -240,6 +240,8 @@ void SysYPrinter::printInst(Instruction *pInst) { case Kind::kMul: case Kind::kDiv: case Kind::kRem: + case Kind::kSRA: + case Kind::kMulh: case Kind::kFAdd: case Kind::kFSub: case Kind::kFMul: @@ -272,6 +274,8 @@ void SysYPrinter::printInst(Instruction *pInst) { case Kind::kMul: std::cout << "mul"; break; case Kind::kDiv: std::cout << "sdiv"; break; case Kind::kRem: std::cout << "srem"; break; + case Kind::kSRA: std::cout << "ashr"; break; + case Kind::kMulh: std::cout << "mulh"; break; case Kind::kFAdd: std::cout << "fadd"; break; case Kind::kFSub: 
std::cout << "fsub"; break; case Kind::kFMul: std::cout << "fmul"; break; From 0ce742a86eca84071c32c4215a3099586326772f Mon Sep 17 00:00:00 2001 From: CGH0S7 <776459475@qq.com> Date: Sun, 3 Aug 2025 14:37:33 +0800 Subject: [PATCH 2/3] =?UTF-8?q?[optimize]=E6=B7=BB=E5=8A=A0=E6=9B=B4?= =?UTF-8?q?=E4=B8=BA=E9=80=9A=E7=94=A8=E7=9A=84=E9=99=A4=E6=B3=95=E5=BC=BA?= =?UTF-8?q?=E5=BA=A6=E5=89=8A=E5=87=8FPass,=20=E4=B8=8D=E5=8F=97=E9=99=A4?= =?UTF-8?q?=E6=95=B0=E9=99=90=E5=88=B6=E6=9B=BF=E6=8D=A2div=E6=8C=87?= =?UTF-8?q?=E4=BB=A4=EF=BC=8C=E4=B8=8D=E5=BD=B1=E5=93=8D=E5=BD=93=E5=89=8D?= =?UTF-8?q?=E5=88=86=E6=95=B0?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../RISCv64/Optimize/DivStrengthReduction.cpp | 319 ++++++++---------- src/backend/RISCv64/RISCv64AsmPrinter.cpp | 2 +- src/include/backend/RISCv64/RISCv64LLIR.h | 2 +- src/include/midend/IR.h | 4 +- src/include/midend/IRBuilder.h | 3 + src/midend/SysYIRGenerator.cpp | 23 ++ test_div_optimization.sy | 9 + 7 files changed, 175 insertions(+), 187 deletions(-) create mode 100644 test_div_optimization.sy diff --git a/src/backend/RISCv64/Optimize/DivStrengthReduction.cpp b/src/backend/RISCv64/Optimize/DivStrengthReduction.cpp index c052c5d..bcd7bda 100644 --- a/src/backend/RISCv64/Optimize/DivStrengthReduction.cpp +++ b/src/backend/RISCv64/Optimize/DivStrengthReduction.cpp @@ -1,4 +1,6 @@ #include "DivStrengthReduction.h" +#include +#include namespace sysy { @@ -17,69 +19,49 @@ void DivStrengthReduction::runOnMachineFunction(MachineFunction *mfunc) { if (debug) std::cout << "Running DivStrengthReduction optimization..." 
<< std::endl; - // 虚拟寄存器分配器 int next_temp_reg = 1000; auto createTempReg = [&]() -> int { return next_temp_reg++; }; - // Magic number 信息结构 struct MagicInfo { int64_t magic; int shift; - bool add_indicator; // 是否需要额外的加法修正 }; - // 针对缺少MULH指令的简化magic number计算 - auto computeMagicNumber = [](int64_t divisor, bool is_32bit) -> MagicInfo { - if (divisor == 0) return {0, 0, false}; - if (divisor == 1) return {1, 0, false}; - if (divisor == -1) return {-1, 0, false}; - - // 对于没有MULH的情况,我们使用更简单但有效的算法 - // 基于 2^n / divisor 的近似 - - bool neg = divisor < 0; - int64_t d = neg ? -divisor : divisor; - + auto computeMagic = [](int64_t d, bool is_32bit) -> MagicInfo { int word_size = is_32bit ? 32 : 64; + uint64_t ad = std::abs(d); - // 计算合适的移位量 - int shift = word_size; - int64_t magic = ((1LL << shift) + d - 1) / d; + if (ad == 0) return {0, 0}; - // 调整magic number以适应MUL指令 + int l = std::floor(std::log2(ad)); + if ((ad & (ad - 1)) == 0) { // power of 2 + l = 0; // special case for power of 2, shift will be calculated differently + } + + __int128_t one = 1; + __int128_t num; + int total_shift; + if (is_32bit) { - // 32位情况:调整magic使其适合符号扩展后的乘法 - shift = 32; - magic = ((1LL << shift) + d - 1) / d; + total_shift = 31 + l; + num = one << total_shift; } else { - // 64位情况:使用更保守的算法 - shift = 32; // 使用32位作为基础移位 - magic = ((1LL << shift) + d - 1) / d; + total_shift = 63 + l; + num = one << total_shift; } - bool add_indicator = false; + __int128_t den = ad; + int64_t magic = (num / den) + 1; - // 检查是否需要加法修正 - if (magic >= (1LL << (word_size - 1))) { - add_indicator = true; - magic -= (1LL << word_size); - } - - if (neg) { - magic = -magic; - } - - return {magic, shift, add_indicator}; + return {magic, total_shift}; }; - // 检查是否为2的幂次 auto isPowerOfTwo = [](int64_t n) -> bool { return n > 0 && (n & (n - 1)) == 0; }; - // 获取2的幂次的指数 auto getPowerOfTwoExponent = [](int64_t n) -> int { if (n <= 0 || (n & (n - 1)) != 0) return -1; int shift = 0; @@ -90,9 +72,9 @@ void 
DivStrengthReduction::runOnMachineFunction(MachineFunction *mfunc) { return shift; }; - // 收集需要替换的指令 struct InstructionReplacement { size_t index; + size_t count_to_erase; std::vector> newInstrs; }; @@ -106,7 +88,6 @@ void DivStrengthReduction::runOnMachineFunction(MachineFunction *mfunc) { bool is_32bit = (instr->getOpcode() == RVOpcodes::DIVW); - // 只处理 DIV 和 DIVW 指令 if (instr->getOpcode() != RVOpcodes::DIV && !is_32bit) { continue; } @@ -118,100 +99,74 @@ void DivStrengthReduction::runOnMachineFunction(MachineFunction *mfunc) { auto *dst_op = instr->getOperands()[0].get(); auto *src1_op = instr->getOperands()[1].get(); auto *src2_op = instr->getOperands()[2].get(); - - // 检查操作数类型 - if (dst_op->getKind() != MachineOperand::KIND_REG || - src1_op->getKind() != MachineOperand::KIND_REG || - src2_op->getKind() != MachineOperand::KIND_IMM) { + + int64_t divisor = 0; + bool const_divisor_found = false; + size_t instructions_to_replace = 1; + + if (src2_op->getKind() == MachineOperand::KIND_IMM) { + divisor = static_cast(src2_op)->getValue(); + const_divisor_found = true; + } else if (src2_op->getKind() == MachineOperand::KIND_REG) { + if (i > 0) { + auto *prev_instr = instrs[i - 1].get(); + if (prev_instr->getOpcode() == RVOpcodes::LI && prev_instr->getOperands().size() == 2) { + auto *li_dst_op = prev_instr->getOperands()[0].get(); + auto *li_imm_op = prev_instr->getOperands()[1].get(); + if (li_dst_op->getKind() == MachineOperand::KIND_REG && li_imm_op->getKind() == MachineOperand::KIND_IMM) { + auto *div_reg_op = static_cast(src2_op); + auto *li_dst_reg_op = static_cast(li_dst_op); + if (div_reg_op->isVirtual() && li_dst_reg_op->isVirtual() && + div_reg_op->getVRegNum() == li_dst_reg_op->getVRegNum()) { + divisor = static_cast(li_imm_op)->getValue(); + const_divisor_found = true; + instructions_to_replace = 2; + } + } + } + } + } + + if (!const_divisor_found) { continue; } auto *dst_reg = static_cast(dst_op); auto *src1_reg = static_cast(src1_op); - auto *src2_imm = 
static_cast(src2_op); - int64_t divisor = src2_imm->getValue(); - - // 跳过除数为0的情况 if (divisor == 0) continue; std::vector> newInstrs; - // 情况1: 除数为1 if (divisor == 1) { - // dst = src1 (直接复制) auto moveInstr = std::make_unique(is_32bit ? RVOpcodes::ADDW : RVOpcodes::ADD); moveInstr->addOperand(std::make_unique(*dst_reg)); moveInstr->addOperand(std::make_unique(*src1_reg)); moveInstr->addOperand(std::make_unique(PhysicalReg::ZERO)); newInstrs.push_back(std::move(moveInstr)); } - // 情况2: 除数为-1 else if (divisor == -1) { - // dst = -src1 auto negInstr = std::make_unique(is_32bit ? RVOpcodes::SUBW : RVOpcodes::SUB); negInstr->addOperand(std::make_unique(*dst_reg)); negInstr->addOperand(std::make_unique(PhysicalReg::ZERO)); negInstr->addOperand(std::make_unique(*src1_reg)); newInstrs.push_back(std::move(negInstr)); } - // 情况3: 正的2的幂次除法 - else if (isPowerOfTwo(divisor)) { - int shift = getPowerOfTwoExponent(divisor); + else if (isPowerOfTwo(std::abs(divisor))) { + int shift = getPowerOfTwoExponent(std::abs(divisor)); int temp_reg = createTempReg(); - // 对于有符号除法,需要处理负数的舍入 - // if (src1 < 0) src1 += (divisor - 1) - - // 获取符号位:temp = src1 >> (word_size - 1) auto sraSignInstr = std::make_unique(is_32bit ? RVOpcodes::SRAIW : RVOpcodes::SRAI); sraSignInstr->addOperand(std::make_unique(temp_reg)); sraSignInstr->addOperand(std::make_unique(*src1_reg)); sraSignInstr->addOperand(std::make_unique(is_32bit ? 31 : 63)); newInstrs.push_back(std::move(sraSignInstr)); - // 计算偏移:temp = temp >> (word_size - shift) - if (shift < (is_32bit ? 32 : 64)) { - auto srlInstr = std::make_unique(is_32bit ? RVOpcodes::SRLIW : RVOpcodes::SRLI); - srlInstr->addOperand(std::make_unique(temp_reg)); - srlInstr->addOperand(std::make_unique(temp_reg)); - srlInstr->addOperand(std::make_unique((is_32bit ? 32 : 64) - shift)); - newInstrs.push_back(std::move(srlInstr)); - } - - // 加上偏移:temp = src1 + temp - auto addInstr = std::make_unique(is_32bit ? 
RVOpcodes::ADDW : RVOpcodes::ADD); - addInstr->addOperand(std::make_unique(temp_reg)); - addInstr->addOperand(std::make_unique(*src1_reg)); - addInstr->addOperand(std::make_unique(temp_reg)); - newInstrs.push_back(std::move(addInstr)); - - // 最终右移:dst = temp >> shift - auto sraInstr = std::make_unique(is_32bit ? RVOpcodes::SRAIW : RVOpcodes::SRAI); - sraInstr->addOperand(std::make_unique(*dst_reg)); - sraInstr->addOperand(std::make_unique(temp_reg)); - sraInstr->addOperand(std::make_unique(shift)); - newInstrs.push_back(std::move(sraInstr)); - } - // 情况4: 负的2的幂次除法 - else if (divisor < 0 && isPowerOfTwo(-divisor)) { - int shift = getPowerOfTwoExponent(-divisor); - int temp_reg = createTempReg(); - - // 先按正数处理 - auto sraSignInstr = std::make_unique(is_32bit ? RVOpcodes::SRAIW : RVOpcodes::SRAI); - sraSignInstr->addOperand(std::make_unique(temp_reg)); - sraSignInstr->addOperand(std::make_unique(*src1_reg)); - sraSignInstr->addOperand(std::make_unique(is_32bit ? 31 : 63)); - newInstrs.push_back(std::move(sraSignInstr)); - - if (shift < (is_32bit ? 32 : 64)) { - auto srlInstr = std::make_unique(is_32bit ? RVOpcodes::SRLIW : RVOpcodes::SRLI); - srlInstr->addOperand(std::make_unique(temp_reg)); - srlInstr->addOperand(std::make_unique(temp_reg)); - srlInstr->addOperand(std::make_unique((is_32bit ? 32 : 64) - shift)); - newInstrs.push_back(std::move(srlInstr)); - } + auto srlInstr = std::make_unique(is_32bit ? RVOpcodes::SRLIW : RVOpcodes::SRLI); + srlInstr->addOperand(std::make_unique(temp_reg)); + srlInstr->addOperand(std::make_unique(temp_reg)); + srlInstr->addOperand(std::make_unique((is_32bit ? 32 : 64) - shift)); + newInstrs.push_back(std::move(srlInstr)); auto addInstr = std::make_unique(is_32bit ? 
RVOpcodes::ADDW : RVOpcodes::ADD); addInstr->addOperand(std::make_unique(temp_reg)); @@ -224,101 +179,99 @@ void DivStrengthReduction::runOnMachineFunction(MachineFunction *mfunc) { sraInstr->addOperand(std::make_unique(temp_reg)); sraInstr->addOperand(std::make_unique(shift)); newInstrs.push_back(std::move(sraInstr)); - - // 然后取反 - auto negInstr = std::make_unique(is_32bit ? RVOpcodes::SUBW : RVOpcodes::SUB); - negInstr->addOperand(std::make_unique(*dst_reg)); - negInstr->addOperand(std::make_unique(PhysicalReg::ZERO)); - negInstr->addOperand(std::make_unique(temp_reg)); - newInstrs.push_back(std::move(negInstr)); + + if (divisor < 0) { + auto negInstr = std::make_unique(is_32bit ? RVOpcodes::SUBW : RVOpcodes::SUB); + negInstr->addOperand(std::make_unique(*dst_reg)); + negInstr->addOperand(std::make_unique(PhysicalReg::ZERO)); + negInstr->addOperand(std::make_unique(temp_reg)); + newInstrs.push_back(std::move(negInstr)); + } else { + auto moveInstr = std::make_unique(is_32bit ? RVOpcodes::ADDW : RVOpcodes::ADD); + moveInstr->addOperand(std::make_unique(*dst_reg)); + moveInstr->addOperand(std::make_unique(temp_reg)); + moveInstr->addOperand(std::make_unique(PhysicalReg::ZERO)); + newInstrs.push_back(std::move(moveInstr)); + } } - // 情况5: 通用magic number算法(针对没有MULH的情况进行了简化) else { - // 对于一般除法,在没有MULH的情况下,我们采用更保守的策略 - // 只处理一些简单的常数除法,复杂的情况保持原始除法指令 - - // 检查是否为小的常数(可以用简单乘法处理) - if (std::abs(divisor) <= 1024) { // 限制在较小的除数范围内 - auto magic_info = computeMagicNumber(divisor, is_32bit); + auto magic_info = computeMagic(divisor, is_32bit); + int magic_reg = createTempReg(); + int temp_reg = createTempReg(); + + auto loadInstr = std::make_unique(RVOpcodes::LI); + loadInstr->addOperand(std::make_unique(magic_reg)); + loadInstr->addOperand(std::make_unique(magic_info.magic)); + newInstrs.push_back(std::move(loadInstr)); + + if (is_32bit) { + auto mulInstr = std::make_unique(RVOpcodes::MUL); + mulInstr->addOperand(std::make_unique(temp_reg)); + 
mulInstr->addOperand(std::make_unique(*src1_reg)); + mulInstr->addOperand(std::make_unique(magic_reg)); + newInstrs.push_back(std::move(mulInstr)); + + auto sraInstr = std::make_unique(RVOpcodes::SRAI); + sraInstr->addOperand(std::make_unique(temp_reg)); + sraInstr->addOperand(std::make_unique(temp_reg)); + sraInstr->addOperand(std::make_unique(magic_info.shift)); + newInstrs.push_back(std::move(sraInstr)); + } else { + auto mulhInstr = std::make_unique(RVOpcodes::MULH); + mulhInstr->addOperand(std::make_unique(temp_reg)); + mulhInstr->addOperand(std::make_unique(*src1_reg)); + mulhInstr->addOperand(std::make_unique(magic_reg)); + newInstrs.push_back(std::move(mulhInstr)); - if (magic_info.magic == 0) continue; - - int magic_reg = createTempReg(); - int temp_reg = createTempReg(); - - // 加载magic number到寄存器 - auto loadInstr = std::make_unique(RVOpcodes::LI); - loadInstr->addOperand(std::make_unique(magic_reg)); - loadInstr->addOperand(std::make_unique(magic_info.magic)); - newInstrs.push_back(std::move(loadInstr)); - - // 使用普通乘法模拟高位乘法 - if (is_32bit) { - // 32位:使用MULW - auto mulInstr = std::make_unique(RVOpcodes::MULW); - mulInstr->addOperand(std::make_unique(temp_reg)); - mulInstr->addOperand(std::make_unique(*src1_reg)); - mulInstr->addOperand(std::make_unique(magic_reg)); - newInstrs.push_back(std::move(mulInstr)); - - // 右移得到近似结果 - auto sraInstr = std::make_unique(RVOpcodes::SRAIW); - sraInstr->addOperand(std::make_unique(temp_reg)); - sraInstr->addOperand(std::make_unique(temp_reg)); - sraInstr->addOperand(std::make_unique(magic_info.shift)); - newInstrs.push_back(std::move(sraInstr)); - } else { - // 64位:使用MUL - auto mulInstr = std::make_unique(RVOpcodes::MUL); - mulInstr->addOperand(std::make_unique(temp_reg)); - mulInstr->addOperand(std::make_unique(*src1_reg)); - mulInstr->addOperand(std::make_unique(magic_reg)); - newInstrs.push_back(std::move(mulInstr)); - - // 右移得到近似结果 + int post_shift = magic_info.shift - 63; + if (post_shift > 0) { auto sraInstr = 
std::make_unique(RVOpcodes::SRAI); sraInstr->addOperand(std::make_unique(temp_reg)); sraInstr->addOperand(std::make_unique(temp_reg)); - sraInstr->addOperand(std::make_unique(magic_info.shift)); + sraInstr->addOperand(std::make_unique(post_shift)); newInstrs.push_back(std::move(sraInstr)); } - - // 符号修正:处理负数被除数 - int sign_reg = createTempReg(); - - // 获取被除数的符号位 - auto sraSignInstr = std::make_unique(is_32bit ? RVOpcodes::SRAIW : RVOpcodes::SRAI); - sraSignInstr->addOperand(std::make_unique(sign_reg)); - sraSignInstr->addOperand(std::make_unique(*src1_reg)); - sraSignInstr->addOperand(std::make_unique(is_32bit ? 31 : 63)); - newInstrs.push_back(std::move(sraSignInstr)); - - // 最终结果:dst = temp - sign(对于正除数)或 dst = temp + sign(对于负除数) - if (divisor > 0) { - auto finalSubInstr = std::make_unique(is_32bit ? RVOpcodes::SUBW : RVOpcodes::SUB); - finalSubInstr->addOperand(std::make_unique(*dst_reg)); - finalSubInstr->addOperand(std::make_unique(temp_reg)); - finalSubInstr->addOperand(std::make_unique(sign_reg)); - newInstrs.push_back(std::move(finalSubInstr)); - } else { - auto finalAddInstr = std::make_unique(is_32bit ? RVOpcodes::ADDW : RVOpcodes::ADD); - finalAddInstr->addOperand(std::make_unique(*dst_reg)); - finalAddInstr->addOperand(std::make_unique(temp_reg)); - finalAddInstr->addOperand(std::make_unique(sign_reg)); - newInstrs.push_back(std::move(finalAddInstr)); - } } - // 对于大的除数或复杂情况,保持原始除法指令不变 + + int sign_reg = createTempReg(); + auto sraSignInstr = std::make_unique(is_32bit ? RVOpcodes::SRAIW : RVOpcodes::SRAI); + sraSignInstr->addOperand(std::make_unique(sign_reg)); + sraSignInstr->addOperand(std::make_unique(*src1_reg)); + sraSignInstr->addOperand(std::make_unique(is_32bit ? 31 : 63)); + newInstrs.push_back(std::move(sraSignInstr)); + + auto subInstr = std::make_unique(is_32bit ? 
RVOpcodes::SUBW : RVOpcodes::SUB); + subInstr->addOperand(std::make_unique(temp_reg)); + subInstr->addOperand(std::make_unique(temp_reg)); + subInstr->addOperand(std::make_unique(sign_reg)); + newInstrs.push_back(std::move(subInstr)); + + if (divisor < 0) { + auto negInstr = std::make_unique(is_32bit ? RVOpcodes::SUBW : RVOpcodes::SUB); + negInstr->addOperand(std::make_unique(*dst_reg)); + negInstr->addOperand(std::make_unique(PhysicalReg::ZERO)); + negInstr->addOperand(std::make_unique(temp_reg)); + newInstrs.push_back(std::move(negInstr)); + } else { + auto moveInstr = std::make_unique(is_32bit ? RVOpcodes::ADDW : RVOpcodes::ADD); + moveInstr->addOperand(std::make_unique(*dst_reg)); + moveInstr->addOperand(std::make_unique(temp_reg)); + moveInstr->addOperand(std::make_unique(PhysicalReg::ZERO)); + newInstrs.push_back(std::move(moveInstr)); + } } if (!newInstrs.empty()) { - replacements.push_back({i, std::move(newInstrs)}); + size_t start_index = i; + if (instructions_to_replace == 2) { + start_index = i - 1; + } + replacements.push_back({start_index, instructions_to_replace, std::move(newInstrs)}); } } - // 批量应用替换(从后往前处理避免索引问题) for (auto it = replacements.rbegin(); it != replacements.rend(); ++it) { - instrs.erase(instrs.begin() + it->index); + instrs.erase(instrs.begin() + it->index, instrs.begin() + it->index + it->count_to_erase); instrs.insert(instrs.begin() + it->index, std::make_move_iterator(it->newInstrs.begin()), std::make_move_iterator(it->newInstrs.end())); @@ -326,4 +279,4 @@ void DivStrengthReduction::runOnMachineFunction(MachineFunction *mfunc) { } } -} // namespace sysy \ No newline at end of file +} // namespace sysy diff --git a/src/backend/RISCv64/RISCv64AsmPrinter.cpp b/src/backend/RISCv64/RISCv64AsmPrinter.cpp index 183003c..5e73cd2 100644 --- a/src/backend/RISCv64/RISCv64AsmPrinter.cpp +++ b/src/backend/RISCv64/RISCv64AsmPrinter.cpp @@ -60,7 +60,7 @@ void RISCv64AsmPrinter::printInstruction(MachineInstr* instr, bool debug) { case 
RVOpcodes::ADD: *OS << "add "; break; case RVOpcodes::ADDI: *OS << "addi "; break; case RVOpcodes::ADDW: *OS << "addw "; break; case RVOpcodes::ADDIW: *OS << "addiw "; break; case RVOpcodes::SUB: *OS << "sub "; break; case RVOpcodes::SUBW: *OS << "subw "; break; - case RVOpcodes::MUL: *OS << "mul "; break; case RVOpcodes::MULW: *OS << "mulw "; break; + case RVOpcodes::MUL: *OS << "mul "; break; case RVOpcodes::MULW: *OS << "mulw "; break; case RVOpcodes::MULH: *OS << "mulh "; break; case RVOpcodes::DIV: *OS << "div "; break; case RVOpcodes::DIVW: *OS << "divw "; break; case RVOpcodes::REM: *OS << "rem "; break; case RVOpcodes::REMW: *OS << "remw "; break; case RVOpcodes::XOR: *OS << "xor "; break; case RVOpcodes::XORI: *OS << "xori "; break; diff --git a/src/include/backend/RISCv64/RISCv64LLIR.h b/src/include/backend/RISCv64/RISCv64LLIR.h index be11528..3c8710e 100644 --- a/src/include/backend/RISCv64/RISCv64LLIR.h +++ b/src/include/backend/RISCv64/RISCv64LLIR.h @@ -45,7 +45,7 @@ enum class PhysicalReg { // RISC-V 指令操作码枚举 enum class RVOpcodes { // 算术指令 - ADD, ADDI, ADDW, ADDIW, SUB, SUBW, MUL, MULW, DIV, DIVW, REM, REMW, + ADD, ADDI, ADDW, ADDIW, SUB, SUBW, MUL, MULW, MULH, DIV, DIVW, REM, REMW, // 逻辑指令 XOR, XORI, OR, ORI, AND, ANDI, // 移位指令 diff --git a/src/include/midend/IR.h b/src/include/midend/IR.h index 2e4d72b..9423a8c 100644 --- a/src/include/midend/IR.h +++ b/src/include/midend/IR.h @@ -709,7 +709,7 @@ class Instruction : public User { kBitItoF = 0x1UL << 40, kBitFtoI = 0x1UL << 41, kSRA = 0x1UL << 42, - kMulh = 0x1UL << 43, + kMulh = 0x1UL << 43 }; protected: @@ -823,7 +823,7 @@ public: bool isBinary() const { static constexpr uint64_t BinaryOpMask = - (kAdd | kSub | kMul | kDiv | kRem | kAnd | kOr | kSRA) | + (kAdd | kSub | kMul | kDiv | kRem | kAnd | kOr | kSRA | kMulh) | (kICmpEQ | kICmpNE | kICmpLT | kICmpGT | kICmpLE | kICmpGE); return kind & BinaryOpMask; } diff --git a/src/include/midend/IRBuilder.h b/src/include/midend/IRBuilder.h index 
c232578..ab50173 100644
--- a/src/include/midend/IRBuilder.h
+++ b/src/include/midend/IRBuilder.h
@@ -220,6 +220,9 @@ class IRBuilder {
   BinaryInst * createSRAInst(Value *lhs, Value *rhs, const std::string &name = "") {
     return createBinaryInst(Instruction::kSRA, Type::getIntType(), lhs, rhs, name);
   } ///< 创建算术右移指令
+  BinaryInst * createMulhInst(Value *lhs, Value *rhs, const std::string &name = "") {
+    return createBinaryInst(Instruction::kMulh, Type::getIntType(), lhs, rhs, name);
+  } ///< Creates a high-part multiplication (mulh) instruction
   CallInst * createCallInst(Function *callee, const std::vector &args, const std::string &name = "") {
     std::string newName;
     if (name.empty() && callee->getReturnType() != Type::getVoidType()) {
diff --git a/src/midend/SysYIRGenerator.cpp b/src/midend/SysYIRGenerator.cpp
index 0b2d751..3228257 100644
--- a/src/midend/SysYIRGenerator.cpp
+++ b/src/midend/SysYIRGenerator.cpp
@@ -15,6 +15,42 @@ using namespace std;
 
 namespace sysy {
 
+/**
+ * Computes the multiplier and shift count used to replace a signed 32-bit
+ * division by the constant `d` with a multiply-and-shift sequence.
+ *
+ * With ad = |d| and k = ceil(log2(ad)), the result is
+ * {floor(2^(32 + k - 1) / ad) + 1, k}.
+ *
+ * @param d Constant divisor; must be non-zero.
+ * @return {multiplier, k}. For d == 1 and d == -1 the placeholder {0, 0} is
+ *         returned (not used by strength reduction).
+ * @throws std::runtime_error if d == 0.
+ */
+std::pair<long long, int> calculate_signed_magic(int d) {
+    if (d == 0) throw std::runtime_error("Division by zero");
+    if (d == 1 || d == -1) return {0, 0}; // Not used by strength reduction
+
+    // Take |d| in unsigned arithmetic: negating INT_MIN as a signed int overflows.
+    const unsigned int ad = (d > 0) ? static_cast<unsigned int>(d)
+                                    : 0u - static_cast<unsigned int>(d);
+
+    // Bit width of ad, lowered by one for exact powers of two: ceil(log2(ad)).
+    int k = 0;
+    for (unsigned int rest = ad; rest > 0; rest >>= 1) {
+        ++k;
+    }
+    if ((ad & (ad - 1)) == 0) {
+        --k;
+    }
+
+    // k <= 31 here, so the shift is at most 62 and everything fits in 64 bits.
+    const unsigned long long scaled = 1ULL << (32 + k - 1);
+    const long long m = static_cast<long long>(scaled / ad) + 1;
+
+    return {m, k};
+}
+
 // std::vector BinaryValueStack; ///< 用于存储value的栈
 // std::vector BinaryOpStack; ///< 用于存储二元表达式的操作符栈
 
diff --git a/test_div_optimization.sy b/test_div_optimization.sy
new file mode 100644
index 0000000..30c0bb1
--- /dev/null
+++ b/test_div_optimization.sy
@@ -0,0 +1,9 @@
+int main() {
+    int a = 100;
+    int b = a / 4;
+    int c = a / 8;
+    int d = a / 16;
+    int e = a / 7;
+    int f = a / 3;
+    return b + c + d + e;
+}
From 9c5d9ea78c836133bd8a0e0266a3c18abe1bdc5c Mon Sep 17 00:00:00 2001
From: CGH0S7 <776459475@qq.com>
Date: Sun, 3 Aug 2025 14:38:27 +0800
Subject: [PATCH 3/3] =?UTF-8?q?[optimize]=E5=88=A0=E9=99=A4=E5=A4=9A?=
 =?UTF-8?q?=E4=BD=99=E6=B5=8B=E8=AF=95=E6=96=87=E4=BB=B6?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 test_div_optimization.sy | 9 ---------
 1 file changed, 9 deletions(-)
 delete mode 100644 test_div_optimization.sy

diff --git a/test_div_optimization.sy b/test_div_optimization.sy
deleted file mode 100644
index 30c0bb1..0000000
--- a/test_div_optimization.sy
+++ /dev/null
@@ -1,9 +0,0 @@
-int main() {
-    int a = 100;
-    int b = a / 4;
-    int c = a / 8;
-    int d = a / 16;
-    int e = a / 7;
-    int f = a / 3;
-    return b + c + d + e;
-}