From c9a0c700e1a9a5fc81d3ba5215bae541935c14ca Mon Sep 17 00:00:00 2001 From: rain2133 <1370973498@qq.com> Date: Mon, 18 Aug 2025 11:30:40 +0800 Subject: [PATCH] =?UTF-8?q?[midend]=E5=A2=9E=E5=8A=A0=E5=85=A8=E5=B1=80?= =?UTF-8?q?=E5=BC=BA=E5=BA=A6=E5=89=8A=E5=BC=B1=E4=BC=98=E5=8C=96=E9=81=8D?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../Pass/Optimize/GlobalStrengthReduction.h | 107 ++ src/midend/CMakeLists.txt | 1 + .../Pass/Optimize/GlobalStrengthReduction.cpp | 1031 +++++++++++++++++ src/midend/Pass/Pass.cpp | 11 + 4 files changed, 1150 insertions(+) create mode 100644 src/include/midend/Pass/Optimize/GlobalStrengthReduction.h create mode 100644 src/midend/Pass/Optimize/GlobalStrengthReduction.cpp diff --git a/src/include/midend/Pass/Optimize/GlobalStrengthReduction.h b/src/include/midend/Pass/Optimize/GlobalStrengthReduction.h new file mode 100644 index 0000000..574494c --- /dev/null +++ b/src/include/midend/Pass/Optimize/GlobalStrengthReduction.h @@ -0,0 +1,107 @@ +#pragma once + +#include "Pass.h" +#include "IR.h" +#include "SideEffectAnalysis.h" +#include +#include +#include +#include + +namespace sysy { + +// 魔数乘法结构,用于除法优化 +struct MagicNumber { + uint32_t multiplier; + int shift; + bool needAdd; + + MagicNumber(uint32_t m, int s, bool add = false) + : multiplier(m), shift(s), needAdd(add) {} +}; + +// 全局强度削弱优化遍的核心逻辑封装类 +class GlobalStrengthReductionContext { +public: + // 构造函数,接受IRBuilder参数 + explicit GlobalStrengthReductionContext(IRBuilder* builder) : builder(builder) {} + + // 运行优化的主要方法 + void run(Function* func, AnalysisManager* AM, bool& changed); + +private: + IRBuilder* builder; // IR构建器 + + // 分析结果 + SideEffectAnalysisResult* sideEffectAnalysis = nullptr; + + // 优化计数 + int algebraicOptCount = 0; + int strengthReductionCount = 0; + int divisionOptCount = 0; + + // 主要优化方法 + bool processBasicBlock(BasicBlock* bb); + bool processInstruction(Instruction* inst); + + // 代数优化方法 + bool tryAlgebraicOptimization(Instruction* inst); + bool optimizeAddition(BinaryInst* inst); + bool optimizeSubtraction(BinaryInst* inst); + bool optimizeMultiplication(BinaryInst* inst); + bool optimizeDivision(BinaryInst* inst); + bool optimizeComparison(BinaryInst* inst); + bool optimizeLogical(BinaryInst* inst); + + // 强度削弱方法 + bool tryStrengthReduction(Instruction* inst); + bool reduceMultiplication(BinaryInst* inst); + bool reduceDivision(BinaryInst* inst); + bool reducePower(CallInst* inst); + + // 复杂乘法强度削弱方法 + bool tryComplexMultiplication(BinaryInst* inst, Value* variable, int constant); + bool findOptimalShiftDecomposition(int constant, std::vector& shifts); + Value* createShiftDecomposition(BinaryInst* inst, Value* variable, const std::vector& shifts); + + // 魔数乘法相关方法 + MagicNumber computeMagicNumber(uint32_t divisor); + std::pair computeMulhMagicNumbers(int divisor); + Value* createMagicDivision(BinaryInst* divInst, uint32_t divisor, const MagicNumber& magic); + Value* createMagicDivisionLibdivide(BinaryInst* divInst, int divisor, const std::pair& magicPair); + bool isPowerOfTwo(uint32_t n); + int log2OfPowerOfTwo(uint32_t n); + + // 辅助方法 + bool isConstantInt(Value* val, int& constVal); + bool isConstantInt(Value* val, uint32_t& constVal); + ConstantInteger* getConstantInt(int val); + bool hasOnlyLocalUses(Instruction* inst); + void replaceWithOptimized(Instruction* original, Value* replacement); +}; + +// 全局强度削弱优化遍类 +class GlobalStrengthReduction : public OptimizationPass { +private: + IRBuilder* builder; // IR构建器,用于创建新指令 + +public: + // 静态成员,作为该遍的唯一ID + static void* ID; + + // 构造函数,接受IRBuilder参数 + explicit GlobalStrengthReduction(IRBuilder* builder) + : OptimizationPass("GlobalStrengthReduction", Granularity::Function), builder(builder) {} + + // 在函数上运行优化 + bool runOnFunction(Function* func, AnalysisManager& AM) override; + + // 返回该遍的唯一ID + void* getPassID() const override { return ID; } + + // 声明分析依赖 + void getAnalysisUsage(std::set& analysisDependencies, + std::set& analysisInvalidations) const override; +}; + +} // namespace sysy diff --git a/src/midend/CMakeLists.txt b/src/midend/CMakeLists.txt index 66fc461..8243a63 100644 --- a/src/midend/CMakeLists.txt +++ b/src/midend/CMakeLists.txt @@ -22,6 +22,7 @@ add_library(midend_lib STATIC Pass/Optimize/LICM.cpp Pass/Optimize/LoopStrengthReduction.cpp Pass/Optimize/InductionVariableElimination.cpp + Pass/Optimize/GlobalStrengthReduction.cpp Pass/Optimize/BuildCFG.cpp Pass/Optimize/LargeArrayToGlobal.cpp ) diff --git a/src/midend/Pass/Optimize/GlobalStrengthReduction.cpp b/src/midend/Pass/Optimize/GlobalStrengthReduction.cpp new file mode 100644 index 0000000..118b793 --- /dev/null +++ b/src/midend/Pass/Optimize/GlobalStrengthReduction.cpp @@ -0,0 +1,1031 @@ +#include "GlobalStrengthReduction.h" +#include "SysYIROptUtils.h" +#include "IRBuilder.h" +#include +#include +#include +#include + +extern int DEBUG; + +namespace sysy { + +// 全局强度削弱优化遍的静态 ID +void *GlobalStrengthReduction::ID = (void *)&GlobalStrengthReduction::ID; + +// ====================================================================== +// GlobalStrengthReduction 类的实现 +// ====================================================================== + +bool GlobalStrengthReduction::runOnFunction(Function *func, AnalysisManager &AM) { + if (func->getBasicBlocks().empty()) { + return false; + } + + if (DEBUG) { + std::cout << "\n=== Running GlobalStrengthReduction on function: " << func->getName() << " ===" << std::endl; + } + + bool changed = false; + GlobalStrengthReductionContext context(builder); + context.run(func, &AM, changed); + + if (DEBUG) { + if (changed) { + std::cout << "GlobalStrengthReduction: Function " << func->getName() << " was modified" << std::endl; + } else { + std::cout << "GlobalStrengthReduction: Function " << func->getName() << " was not modified" << std::endl; + } + std::cout << "=== GlobalStrengthReduction completed for function: " << func->getName() << " ===" << std::endl; + } + + return changed; +} + +void GlobalStrengthReduction::getAnalysisUsage(std::set &analysisDependencies, std::set &analysisInvalidations) const { + // 强度削弱依赖副作用分析来判断指令是否可以安全优化 + analysisDependencies.insert(&SysYSideEffectAnalysisPass::ID); + + // 强度削弱不会使分析失效,因为: + // - 只替换计算指令,不改变控制流 + // - 不修改内存,不影响别名分析 + // - 保持程序语义不变 + // analysisInvalidations 保持为空 + + if (DEBUG) { + std::cout << "GlobalStrengthReduction: Declared analysis dependencies (SideEffectAnalysis)" << std::endl; + } +} + +// ====================================================================== +// GlobalStrengthReductionContext 类的实现 +// ====================================================================== + +void GlobalStrengthReductionContext::run(Function *func, AnalysisManager *AM, bool &changed) { + if (DEBUG) { + std::cout << " Starting GlobalStrengthReduction analysis for function: " << func->getName() << std::endl; + } + + // 获取分析结果 + if (AM) { + sideEffectAnalysis = AM->getAnalysisResult(); + + if (DEBUG) { + if (sideEffectAnalysis) { + std::cout << " GlobalStrengthReduction: Using side effect analysis" << std::endl; + } else { + std::cout << " GlobalStrengthReduction: Warning - side effect analysis not available" << std::endl; + } + } + } + + // 重置计数器 + algebraicOptCount = 0; + strengthReductionCount = 0; + divisionOptCount = 0; + + // 遍历所有基本块进行优化 + for (auto &bb_ptr : func->getBasicBlocks()) { + if (processBasicBlock(bb_ptr.get())) { + changed = true; + } + } + + if (DEBUG) { + std::cout << " GlobalStrengthReduction completed for function: " << func->getName() << std::endl; + std::cout << " Algebraic optimizations: " << algebraicOptCount << std::endl; + std::cout << " Strength reductions: " << strengthReductionCount << std::endl; + std::cout << " Division optimizations: " << divisionOptCount << std::endl; + } +} + +bool GlobalStrengthReductionContext::processBasicBlock(BasicBlock *bb) { + bool changed = false; + + if (DEBUG) { + std::cout << " Processing block: " << bb->getName() << std::endl; + } + + // 收集需要处理的指令(避免迭代器失效) + std::vector instructions; + for (auto &inst_ptr : bb->getInstructions()) { + instructions.push_back(inst_ptr.get()); + } + + // 处理每条指令 + for (auto inst : instructions) { + if (processInstruction(inst)) { + changed = true; + } + } + + return changed; +} + +bool GlobalStrengthReductionContext::processInstruction(Instruction *inst) { + bool changed = false; + + if (DEBUG >= 2) { + std::cout << " Processing instruction: " << inst->getName() << std::endl; + } + + // 先尝试代数优化 + if (tryAlgebraicOptimization(inst)) { + changed = true; + algebraicOptCount++; + } + + // 再尝试强度削弱 + if (tryStrengthReduction(inst)) { + changed = true; + strengthReductionCount++; + } + + return changed; +} + +// ====================================================================== +// 代数优化方法 +// ====================================================================== + +bool GlobalStrengthReductionContext::tryAlgebraicOptimization(Instruction *inst) { + auto binary = dynamic_cast(inst); + if (!binary) { + return false; + } + + switch (binary->getKind()) { + case Instruction::kAdd: + return optimizeAddition(binary); + case Instruction::kSub: + return optimizeSubtraction(binary); + case Instruction::kMul: + return optimizeMultiplication(binary); + case Instruction::kDiv: + return optimizeDivision(binary); + case Instruction::kICmpEQ: + case Instruction::kICmpNE: + case Instruction::kICmpLT: + case Instruction::kICmpGT: + case Instruction::kICmpLE: + case Instruction::kICmpGE: + return optimizeComparison(binary); + case Instruction::kAnd: + case Instruction::kOr: + return optimizeLogical(binary); + default: + return false; + } +} + +bool GlobalStrengthReductionContext::optimizeAddition(BinaryInst *inst) { + Value *lhs = inst->getLhs(); + Value *rhs = inst->getRhs(); + int constVal; + + // x + 0 = x + if (isConstantInt(rhs, constVal) && constVal == 0) { + if (DEBUG) { + std::cout << " Algebraic: " << inst->getName() << " = x + 0 -> x" << std::endl; + } + replaceWithOptimized(inst, lhs); + return true; + } + + // 0 + x = x + if (isConstantInt(lhs, constVal) && constVal == 0) { + if (DEBUG) { + std::cout << " Algebraic: " << inst->getName() << " = 0 + x -> x" << std::endl; + } + replaceWithOptimized(inst, rhs); + return true; + } + + // x + (-y) = x - y + if (auto rhsInst = dynamic_cast(rhs)) { + if (rhsInst->getKind() == Instruction::kNeg) { + if (DEBUG) { + std::cout << " Algebraic: " << inst->getName() << " = x + (-y) -> x - y" << std::endl; + } + // 创建减法指令 + builder->setPosition(inst->getParent(), inst->getParent()->findInstIterator(inst)); + auto subInst = builder->createSubInst(lhs, rhsInst->getOperand()); + replaceWithOptimized(inst, subInst); + return true; + } + } + + return false; +} + +bool GlobalStrengthReductionContext::optimizeSubtraction(BinaryInst *inst) { + Value *lhs = inst->getLhs(); + Value *rhs = inst->getRhs(); + int constVal; + + // x - 0 = x + if (isConstantInt(rhs, constVal) && constVal == 0) { + if (DEBUG) { + std::cout << " Algebraic: " << inst->getName() << " = x - 0 -> x" << std::endl; + } + replaceWithOptimized(inst, lhs); + return true; + } + + // x - x = 0 (如果x没有副作用) + if (lhs == rhs && hasOnlyLocalUses(dynamic_cast(lhs))) { + if (DEBUG) { + std::cout << " Algebraic: " << inst->getName() << " = x - x -> 0" << std::endl; + } + replaceWithOptimized(inst, getConstantInt(0)); + return true; + } + + // x - (-y) = x + y + if (auto rhsInst = dynamic_cast(rhs)) { + if (rhsInst->getKind() == Instruction::kNeg) { + if (DEBUG) { + std::cout << " Algebraic: " << inst->getName() << " = x - (-y) -> x + y" << std::endl; + } + builder->setPosition(inst->getParent(), inst->getParent()->findInstIterator(inst)); + auto addInst = builder->createAddInst(lhs, rhsInst->getOperand()); + replaceWithOptimized(inst, addInst); + return true; + } + } + + return false; +} + +bool GlobalStrengthReductionContext::optimizeMultiplication(BinaryInst *inst) { + Value *lhs = inst->getLhs(); + Value *rhs = inst->getRhs(); + int constVal; + + // x * 0 = 0 + if (isConstantInt(rhs, constVal) && constVal == 0) { + if (DEBUG) { + std::cout << " Algebraic: " << inst->getName() << " = x * 0 -> 0" << std::endl; + } + replaceWithOptimized(inst, getConstantInt(0)); + return true; + } + + // 0 * x = 0 + if (isConstantInt(lhs, constVal) && constVal == 0) { + if (DEBUG) { + std::cout << " Algebraic: " << inst->getName() << " = 0 * x -> 0" << std::endl; + } + replaceWithOptimized(inst, getConstantInt(0)); + return true; + } + + // x * 1 = x + if (isConstantInt(rhs, constVal) && constVal == 1) { + if (DEBUG) { + std::cout << " Algebraic: " << inst->getName() << " = x * 1 -> x" << std::endl; + } + replaceWithOptimized(inst, lhs); + return true; + } + + // 1 * x = x + if (isConstantInt(lhs, constVal) && constVal == 1) { + if (DEBUG) { + std::cout << " Algebraic: " << inst->getName() << " = 1 * x -> x" << std::endl; + } + replaceWithOptimized(inst, rhs); + return true; + } + + // x * (-1) = -x + if (isConstantInt(rhs, constVal) && constVal == -1) { + if (DEBUG) { + std::cout << " Algebraic: " << inst->getName() << " = x * (-1) -> -x" << std::endl; + } + builder->setPosition(inst->getParent(), inst->getParent()->findInstIterator(inst)); + auto negInst = builder->createNegInst(lhs); + replaceWithOptimized(inst, negInst); + return true; + } + + return false; +} + +bool GlobalStrengthReductionContext::optimizeDivision(BinaryInst *inst) { + Value *lhs = inst->getLhs(); + Value *rhs = inst->getRhs(); + int constVal; + + // x / 1 = x + if (isConstantInt(rhs, constVal) && constVal == 1) { + if (DEBUG) { + std::cout << " Algebraic: " << inst->getName() << " = x / 1 -> x" << std::endl; + } + replaceWithOptimized(inst, lhs); + return true; + } + + // x / (-1) = -x + if (isConstantInt(rhs, constVal) && constVal == -1) { + if (DEBUG) { + std::cout << " Algebraic: " << inst->getName() << " = x / (-1) -> -x" << std::endl; + } + builder->setPosition(inst->getParent(), inst->getParent()->findInstIterator(inst)); + auto negInst = builder->createNegInst(lhs); + replaceWithOptimized(inst, negInst); + return true; + } + + // x / x = 1 (如果x != 0且没有副作用) + if (lhs == rhs && hasOnlyLocalUses(dynamic_cast(lhs))) { + if (DEBUG) { + std::cout << " Algebraic: " << inst->getName() << " = x / x -> 1" << std::endl; + } + replaceWithOptimized(inst, getConstantInt(1)); + return true; + } + + return false; +} + +bool GlobalStrengthReductionContext::optimizeComparison(BinaryInst *inst) { + Value *lhs = inst->getLhs(); + Value *rhs = inst->getRhs(); + + // x == x = true (如果x没有副作用) + if (inst->getKind() == Instruction::kICmpEQ && lhs == rhs && + hasOnlyLocalUses(dynamic_cast(lhs))) { + if (DEBUG) { + std::cout << " Algebraic: " << inst->getName() << " = x == x -> true" << std::endl; + } + replaceWithOptimized(inst, getConstantInt(1)); + return true; + } + + // x != x = false (如果x没有副作用) + if (inst->getKind() == Instruction::kICmpNE && lhs == rhs && + hasOnlyLocalUses(dynamic_cast(lhs))) { + if (DEBUG) { + std::cout << " Algebraic: " << inst->getName() << " = x != x -> false" << std::endl; + } + replaceWithOptimized(inst, getConstantInt(0)); + return true; + } + + return false; +} + +bool GlobalStrengthReductionContext::optimizeLogical(BinaryInst *inst) { + Value *lhs = inst->getLhs(); + Value *rhs = inst->getRhs(); + int constVal; + + if (inst->getKind() == Instruction::kAnd) { + // x && 0 = 0 + if (isConstantInt(rhs, constVal) && constVal == 0) { + if (DEBUG) { + std::cout << " Algebraic: " << inst->getName() << " = x && 0 -> 0" << std::endl; + } + replaceWithOptimized(inst, getConstantInt(0)); + return true; + } + + // x && 1 = x + if (isConstantInt(rhs, constVal) && constVal == 1) { + if (DEBUG) { + std::cout << " Algebraic: " << inst->getName() << " = x && 1 -> x" << std::endl; + } + replaceWithOptimized(inst, lhs); + return true; + } + + // x && x = x + if (lhs == rhs) { + if (DEBUG) { + std::cout << " Algebraic: " << inst->getName() << " = x && x -> x" << std::endl; + } + replaceWithOptimized(inst, lhs); + return true; + } + } else if (inst->getKind() == Instruction::kOr) { + // x || 0 = x + if (isConstantInt(rhs, constVal) && constVal == 0) { + if (DEBUG) { + std::cout << " Algebraic: " << inst->getName() << " = x || 0 -> x" << std::endl; + } + replaceWithOptimized(inst, lhs); + return true; + } + + // x || 1 = 1 + if (isConstantInt(rhs, constVal) && constVal == 1) { + if (DEBUG) { + std::cout << " Algebraic: " << inst->getName() << " = x || 1 -> 1" << std::endl; + } + replaceWithOptimized(inst, getConstantInt(1)); + return true; + } + + // x || x = x + if (lhs == rhs) { + if (DEBUG) { + std::cout << " Algebraic: " << inst->getName() << " = x || x -> x" << std::endl; + } + replaceWithOptimized(inst, lhs); + return true; + } + } + + return false; +} + +// ====================================================================== +// 强度削弱方法 +// ====================================================================== + +bool GlobalStrengthReductionContext::tryStrengthReduction(Instruction *inst) { + if (auto binary = dynamic_cast(inst)) { + switch (binary->getKind()) { + case Instruction::kMul: + return reduceMultiplication(binary); + case Instruction::kDiv: + return reduceDivision(binary); + default: + return false; + } + } else if (auto call = dynamic_cast(inst)) { + return reducePower(call); + } + + return false; +} + +bool GlobalStrengthReductionContext::reduceMultiplication(BinaryInst *inst) { + Value *lhs = inst->getLhs(); + Value *rhs = inst->getRhs(); + int constVal; + + // 尝试右操作数为常数 + Value* variable = lhs; + if (isConstantInt(rhs, constVal) && constVal > 0) { + return tryComplexMultiplication(inst, variable, constVal); + } + + // 尝试左操作数为常数 + if (isConstantInt(lhs, constVal) && constVal > 0) { + variable = rhs; + return tryComplexMultiplication(inst, variable, constVal); + } + + return false; +} + +bool GlobalStrengthReductionContext::tryComplexMultiplication(BinaryInst* inst, Value* variable, int constant) { + // 首先检查是否为2的幂,使用简单位移 + if (isPowerOfTwo(constant)) { + int shiftAmount = log2OfPowerOfTwo(constant); + if (DEBUG) { + std::cout << " StrengthReduction: " << inst->getName() + << " = x * " << constant << " -> x << " << shiftAmount << std::endl; + } + + builder->setPosition(inst->getParent(), inst->getParent()->findInstIterator(inst)); + auto shiftInst = builder->createBinaryInst(Instruction::kSll, Type::getIntType(), variable, getConstantInt(shiftAmount)); + replaceWithOptimized(inst, shiftInst); + return true; + } + + // 尝试分解为位移和加法的组合 + std::vector shifts; + if (findOptimalShiftDecomposition(constant, shifts)) { + if (DEBUG) { + std::cout << " StrengthReduction: " << inst->getName() + << " = x * " << constant << " -> shift decomposition with " << shifts.size() << " terms" << std::endl; + } + + Value* result = createShiftDecomposition(inst, variable, shifts); + if (result) { + replaceWithOptimized(inst, result); + return true; + } + } + + return false; +} + +bool GlobalStrengthReductionContext::findOptimalShiftDecomposition(int constant, std::vector& shifts) { + shifts.clear(); + + // 常见的有效分解模式 + switch (constant) { + case 3: // 3 = 2^1 + 2^0 -> (x << 1) + x + shifts = {1, 0}; + return true; + case 5: // 5 = 2^2 + 2^0 -> (x << 2) + x + shifts = {2, 0}; + return true; + case 6: // 6 = 2^2 + 2^1 -> (x << 2) + (x << 1) + shifts = {2, 1}; + return true; + case 7: // 7 = 2^2 + 2^1 + 2^0 -> (x << 2) + (x << 1) + x + shifts = {2, 1, 0}; + return true; + case 9: // 9 = 2^3 + 2^0 -> (x << 3) + x + shifts = {3, 0}; + return true; + case 10: // 10 = 2^3 + 2^1 -> (x << 3) + (x << 1) + shifts = {3, 1}; + return true; + case 11: // 11 = 2^3 + 2^1 + 2^0 -> (x << 3) + (x << 1) + x + shifts = {3, 1, 0}; + return true; + case 12: // 12 = 2^3 + 2^2 -> (x << 3) + (x << 2) + shifts = {3, 2}; + return true; + case 13: // 13 = 2^3 + 2^2 + 2^0 -> (x << 3) + (x << 2) + x + shifts = {3, 2, 0}; + return true; + case 14: // 14 = 2^3 + 2^2 + 2^1 -> (x << 3) + (x << 2) + (x << 1) + shifts = {3, 2, 1}; + return true; + case 15: // 15 = 2^3 + 2^2 + 2^1 + 2^0 -> (x << 3) + (x << 2) + (x << 1) + x + shifts = {3, 2, 1, 0}; + return true; + case 17: // 17 = 2^4 + 2^0 -> (x << 4) + x + shifts = {4, 0}; + return true; + case 18: // 18 = 2^4 + 2^1 -> (x << 4) + (x << 1) + shifts = {4, 1}; + return true; + case 20: // 20 = 2^4 + 2^2 -> (x << 4) + (x << 2) + shifts = {4, 2}; + return true; + case 24: // 24 = 2^4 + 2^3 -> (x << 4) + (x << 3) + shifts = {4, 3}; + return true; + case 25: // 25 = 2^4 + 2^3 + 2^0 -> (x << 4) + (x << 3) + x + shifts = {4, 3, 0}; + return true; + case 100: // 100 = 2^6 + 2^5 + 2^2 -> (x << 6) + (x << 5) + (x << 2) + shifts = {6, 5, 2}; + return true; + } + + // 通用二进制分解(最多4个项,避免过度复杂化) + if (constant > 0 && constant < 256) { + std::vector binaryShifts; + int temp = constant; + int bit = 0; + + while (temp > 0 && binaryShifts.size() < 4) { + if (temp & 1) { + binaryShifts.push_back(bit); + } + temp >>= 1; + bit++; + } + + // 只有当项数不超过3个时才使用二进制分解(比直接乘法更有效) + if (binaryShifts.size() <= 3 && binaryShifts.size() >= 2) { + shifts = binaryShifts; + return true; + } + } + + return false; +} + +Value* GlobalStrengthReductionContext::createShiftDecomposition(BinaryInst* inst, Value* variable, const std::vector& shifts) { + if (shifts.empty()) return nullptr; + + builder->setPosition(inst->getParent(), inst->getParent()->findInstIterator(inst)); + + Value* result = nullptr; + + for (int shift : shifts) { + Value* term; + if (shift == 0) { + // 0位移就是原变量 + term = variable; + } else { + // 创建位移指令 + term = builder->createBinaryInst(Instruction::kSll, Type::getIntType(), variable, getConstantInt(shift)); + } + + if (result == nullptr) { + result = term; + } else { + // 累加到结果中 + result = builder->createAddInst(result, term); + } + } + + return result; +} + +bool GlobalStrengthReductionContext::reduceDivision(BinaryInst *inst) { + Value *lhs = inst->getLhs(); + Value *rhs = inst->getRhs(); + uint32_t constVal; + + // x / 2^n = x >> n (对于无符号除法或已知为正数的情况) + if (isConstantInt(rhs, constVal) && constVal > 0 && isPowerOfTwo(constVal)) { + int shiftAmount = log2OfPowerOfTwo(constVal); + if (DEBUG) { + std::cout << " StrengthReduction: " << inst->getName() + << " = x / " << constVal << " -> x >> " << shiftAmount << std::endl; + } + + builder->setPosition(inst->getParent(), inst->getParent()->findInstIterator(inst)); + auto shiftInst = builder->createBinaryInst(Instruction::kSra, Type::getIntType(), lhs, getConstantInt(shiftAmount)); + replaceWithOptimized(inst, shiftInst); + strengthReductionCount++; + return true; + } + + // x / c = x * magic_number (魔数乘法优化 - 使用libdivide算法) + if (isConstantInt(rhs, constVal) && constVal > 1 && constVal != (uint32_t)(-1)) { + auto magicPair = computeMulhMagicNumbers(static_cast(constVal)); + if (magicPair.first != -1) { // 有效的魔数 + if (DEBUG) { + std::cout << " StrengthReduction: " << inst->getName() + << " = x / " << constVal << " -> libdivide magic multiplication" << std::endl; + } + + Value* magicResult = createMagicDivisionLibdivide(inst, static_cast(constVal), magicPair); + replaceWithOptimized(inst, magicResult); + divisionOptCount++; + return true; + } + } + + return false; +} + +bool GlobalStrengthReductionContext::reducePower(CallInst *inst) { + // 检查是否是pow函数调用 + Function* callee = inst->getCallee(); + if (!callee || callee->getName() != "pow") { + return false; + } + + // pow(x, 2) = x * x + if (inst->getNumOperands() >= 2) { + int exponent; + if (isConstantInt(inst->getOperand(1), exponent)) { + if (exponent == 2) { + if (DEBUG) { + std::cout << " StrengthReduction: pow(x, 2) -> x * x" << std::endl; + } + + Value* base = inst->getOperand(0); + builder->setPosition(inst->getParent(), inst->getParent()->findInstIterator(inst)); + auto mulInst = builder->createMulInst(base, base); + replaceWithOptimized(inst, mulInst); + strengthReductionCount++; + return true; + } else if (exponent >= 3 && exponent <= 8) { + // 对于小的指数,展开为连续乘法 + if (DEBUG) { + std::cout << " StrengthReduction: pow(x, " << exponent << ") -> repeated multiplication" << std::endl; + } + + Value* base = inst->getOperand(0); + Value* result = base; + builder->setPosition(inst->getParent(), inst->getParent()->findInstIterator(inst)); + + for (int i = 1; i < exponent; i++) { + result = builder->createMulInst(result, base); + } + + replaceWithOptimized(inst, result); + strengthReductionCount++; + return true; + } + } + } + + return false; +} + +// ====================================================================== +// 魔数乘法相关方法 +// ====================================================================== + +// 该实现参考了libdivide的算法 +std::pair GlobalStrengthReductionContext::computeMulhMagicNumbers(int divisor) { + + if (DEBUG) { + std::cout << "\n[SR] ===== Computing magic numbers for divisor " << divisor << " (libdivide algorithm) =====" << std::endl; + } + + if (divisor == 0) { + if (DEBUG) std::cout << "[SR] Error: divisor must be != 0" << std::endl; + return {-1, -1}; + } + + // libdivide 常数 + const uint8_t LIBDIVIDE_ADD_MARKER = 0x40; + const uint8_t LIBDIVIDE_NEGATIVE_DIVISOR = 0x80; + + // 辅助函数:计算前导零个数 + auto count_leading_zeros32 = [](uint32_t val) -> uint32_t { + if (val == 0) return 32; + return __builtin_clz(val); + }; + + // 辅助函数:64位除法返回32位商和余数 + auto div_64_32 = [](uint32_t high, uint32_t low, uint32_t divisor, uint32_t* rem) -> uint32_t { + uint64_t dividend = ((uint64_t)high << 32) | low; + uint32_t quotient = dividend / divisor; + *rem = dividend % divisor; + return quotient; + }; + + if (DEBUG) { + std::cout << "[SR] Input divisor: " << divisor << std::endl; + } + + // libdivide_internal_s32_gen 算法实现 + int32_t d = divisor; + uint32_t ud = (uint32_t)d; + uint32_t absD = (d < 0) ? -ud : ud; + + if (DEBUG) { + std::cout << "[SR] absD = " << absD << std::endl; + } + + uint32_t floor_log_2_d = 31 - count_leading_zeros32(absD); + + if (DEBUG) { + std::cout << "[SR] floor_log_2_d = " << floor_log_2_d << std::endl; + } + + // 检查 absD 是否为2的幂 + if ((absD & (absD - 1)) == 0) { + if (DEBUG) { + std::cout << "[SR] " << absD << " 是2的幂,使用移位方法" << std::endl; + } + + // 对于2的幂,我们只使用移位,不需要魔数 + int shift = floor_log_2_d; + if (d < 0) shift |= 0x80; // 标记负数 + + if (DEBUG) { + std::cout << "[SR] Power of 2 result: magic=0, shift=" << shift << std::endl; + std::cout << "[SR] ===== End magic computation =====" << std::endl; + } + + // 对于我们的目的,我们将在IR生成中以不同方式处理2的幂 + // 返回特殊标记 + return {0, shift}; + } + + if (DEBUG) { + std::cout << "[SR] " << absD << " is not a power of 2, computing magic number" << std::endl; + } + + // 非2的幂除数的魔数计算 + uint8_t more; + uint32_t rem, proposed_m; + + // 计算 proposed_m = floor(2^(floor_log_2_d + 31) / absD) + proposed_m = div_64_32((uint32_t)1 << (floor_log_2_d - 1), 0, absD, &rem); + const uint32_t e = absD - rem; + + if (DEBUG) { + std::cout << "[SR] proposed_m = " << proposed_m << ", rem = " << rem << ", e = " << e << std::endl; + } + + // 确定是否需要"加法"版本 + const bool branchfree = false; // 使用分支版本 + + if (!branchfree && e < ((uint32_t)1 << floor_log_2_d)) { + // 这个幂次有效 + more = (uint8_t)(floor_log_2_d - 1); + if (DEBUG) { + std::cout << "[SR] Using basic algorithm, shift = " << (int)more << std::endl; + } + } else { + // 我们需要上升一个等级 + proposed_m += proposed_m; + const uint32_t twice_rem = rem + rem; + if (twice_rem >= absD || twice_rem < rem) { + proposed_m += 1; + } + more = (uint8_t)(floor_log_2_d | LIBDIVIDE_ADD_MARKER); + if (DEBUG) { + std::cout << "[SR] Using add algorithm, proposed_m = " << proposed_m << ", more = " << (int)more << std::endl; + } + } + + proposed_m += 1; + int32_t magic = (int32_t)proposed_m; + + // 处理负除数 + if (d < 0) { + more |= LIBDIVIDE_NEGATIVE_DIVISOR; + if (!branchfree) { + magic = -magic; + } + if (DEBUG) { + std::cout << "[SR] Negative divisor, magic = " << magic << ", more = " << (int)more << std::endl; + } + } + + // 为我们的IR生成提取移位量和标志 + int shift = more & 0x3F; // 移除标志,保留移位量(位0-5) + bool need_add = (more & LIBDIVIDE_ADD_MARKER) != 0; + bool is_negative = (more & LIBDIVIDE_NEGATIVE_DIVISOR) != 0; + // 返回魔数、移位量,并在移位中编码ADD_MARKER标志 + // 我们将使用移位的第6位表示ADD_MARKER,第7位表示负数(如果需要) + int encoded_shift = shift; + if (need_add) { + encoded_shift |= 0x40; // 设置第6位表示ADD_MARKER + if (DEBUG) { + std::cout << "[SR] Encoding ADD_MARKER in shift: " << encoded_shift << std::endl; + } + } + + return {magic, encoded_shift}; + } + +Value* GlobalStrengthReductionContext::createMagicDivision(BinaryInst* divInst, uint32_t divisor, const MagicNumber& magic) { + builder->setPosition(divInst->getParent(), divInst->getParent()->findInstIterator(divInst)); + + Value* dividend = divInst->getLhs(); + + // 创建魔数常量 + Value* magicConst = getConstantInt(static_cast(magic.multiplier)); + + // 执行乘法: tmp = dividend * magic + Value* tmp = builder->createMulInst(dividend, magicConst); + + if (magic.needAdd) { + // 需要额外加法的情况 + Value* sum = builder->createAddInst(tmp, dividend); + if (magic.shift > 0) { + Value* shiftConst = getConstantInt(magic.shift); + tmp = builder->createBinaryInst(Instruction::kSra, Type::getIntType(), sum, shiftConst); + } else { + tmp = sum; + } + } else { + // 直接右移 + if (magic.shift > 0) { + Value* shiftConst = getConstantInt(magic.shift); + tmp = builder->createBinaryInst(Instruction::kSra, Type::getIntType(), tmp, shiftConst); + } + } + + // 处理符号:如果被除数为负,结果需要调整 + // 这里简化处理,假设都是正数除法 + + return tmp; +} + +Value* GlobalStrengthReductionContext::createMagicDivisionLibdivide(BinaryInst* divInst, int divisor, const std::pair& magicPair) { + builder->setPosition(divInst->getParent(), divInst->getParent()->findInstIterator(divInst)); + + Value* dividend = divInst->getLhs(); + int magic = magicPair.first; + int encoded_shift = magicPair.second; + + if (DEBUG) { + std::cout << "[SR] Creating libdivide magic division: magic=" << magic + << ", encoded_shift=" << encoded_shift << std::endl; + } + + // 检查是否为2的幂(magic=0表示是2的幂) + if (magic == 0) { + // 2的幂除法,直接使用算术右移 + int shift = encoded_shift & 0x3F; // 获取实际移位量 + bool is_negative = (encoded_shift & 0x80) != 0; + + if (DEBUG) { + std::cout << "[SR] Power of 2 division: shift=" << shift + << ", negative=" << is_negative << std::endl; + } + + Value* result = dividend; + if (shift > 0) { + Value* shiftConst = getConstantInt(shift); + result = builder->createBinaryInst(Instruction::kSra, Type::getIntType(), dividend, shiftConst); + } + + // 如果原除数为负,需要取反结果 + if (is_negative) { + result = builder->createNegInst(result); + } + + return result; + } + + // 非2的幂除法,使用魔数乘法 + int shift = encoded_shift & 0x3F; // 获取移位量(位0-5) + bool need_add = (encoded_shift & 0x40) != 0; // 检查ADD_MARKER标志(位6) + + if (DEBUG) { + std::cout << "[SR] Magic multiplication: shift=" << shift + << ", need_add=" << need_add << std::endl; + } + + // 创建魔数常量 + Value* magicConst = getConstantInt(magic); + + // 执行高位乘法:mulh(dividend, magic) + // 由于我们的IR可能没有直接的mulh指令,我们使用64位乘法然后取高32位 + // 这里需要根据实际的IR指令集进行调整 + Value* tmp = builder->createMulInst(dividend, magicConst); + + if (need_add) { + // ADD算法:(mulh(dividend, magic) + dividend) >> shift + tmp = builder->createAddInst(tmp, dividend); + } + + if (shift > 0) { + Value* shiftConst = getConstantInt(shift); + tmp = builder->createBinaryInst(Instruction::kSra, Type::getIntType(), tmp, shiftConst); + } + + // 处理符号位调整 + // 如果被除数为负数,可能需要额外的符号处理 + // 这里简化处理,实际实现可能需要更复杂的符号位处理 + + return tmp; +} + +// ====================================================================== +// 辅助方法 +// ====================================================================== + +bool GlobalStrengthReductionContext::isPowerOfTwo(uint32_t n) { + return n > 0 && (n & (n - 1)) == 0; +} + +int GlobalStrengthReductionContext::log2OfPowerOfTwo(uint32_t n) { + int result = 0; + while (n > 1) { + n >>= 1; + result++; + } + return result; +} + +bool GlobalStrengthReductionContext::isConstantInt(Value* val, int& constVal) { + if (auto constInt = dynamic_cast(val)) { + constVal = std::get(constInt->getVal()); + return true; + } + return false; +} + +bool GlobalStrengthReductionContext::isConstantInt(Value* val, uint32_t& constVal) { + if (auto constInt = dynamic_cast(val)) { + int signedVal = std::get(constInt->getVal()); + if (signedVal >= 0) { + constVal = static_cast(signedVal); + return true; + } + } + return false; +} + +ConstantInteger* GlobalStrengthReductionContext::getConstantInt(int val) { + return ConstantInteger::get(val); +} + +bool GlobalStrengthReductionContext::hasOnlyLocalUses(Instruction* inst) { + if (!inst) return true; + + // 简单检查:如果指令没有副作用,则认为是本地的 + if (sideEffectAnalysis) { + auto sideEffect = sideEffectAnalysis->getInstructionSideEffect(inst); + return sideEffect.type == SideEffectType::NO_SIDE_EFFECT; + } + + // 没有副作用分析时,保守处理 + return !inst->isCall() && !inst->isStore() && !inst->isLoad(); +} + +void GlobalStrengthReductionContext::replaceWithOptimized(Instruction* original, Value* replacement) { + if (DEBUG >= 2) { + std::cout << " Replacing " << original->getName() + << " with " << replacement->getName() << std::endl; + } + + original->replaceAllUsesWith(replacement); + + // 如果替换值是新创建的指令,确保它有合适的名字 +// if (auto replInst = dynamic_cast(replacement)) { +// if (replInst->getName().empty()) { +// replInst->setName(original->getName() + "_opt"); +// } +// } + + // 删除原指令,让调用者处理 + SysYIROptUtils::usedelete(original); +} + +} // namespace sysy diff --git a/src/midend/Pass/Pass.cpp b/src/midend/Pass/Pass.cpp index 09de26e..c834e30 100644 --- a/src/midend/Pass/Pass.cpp +++ b/src/midend/Pass/Pass.cpp @@ -18,6 +18,7 @@ #include "LICM.h" #include "LoopStrengthReduction.h" #include "InductionVariableElimination.h" +#include "GlobalStrengthReduction.h" #include "Pass.h" #include #include @@ -179,6 +180,16 @@ void PassManager::runOptimizationPipeline(Module* moduleIR, IRBuilder* builderIR printPasses(); } + // 全局强度削弱优化,包括代数优化和魔数除法 + this->clearPasses(); + this->addPass(&GlobalStrengthReduction::ID); + this->run(); + + if(DEBUG) { + std::cout << "=== IR After Global Strength Reduction Optimizations ===\n"; + printPasses(); + } + // this->clearPasses(); // this->addPass(&Reg2Mem::ID); // this->run();