From 467f2f6b242c704dda62439e7fd79b622a0a9cfb Mon Sep 17 00:00:00 2001 From: rain2133 <1370973498@qq.com> Date: Sat, 16 Aug 2025 15:38:41 +0800 Subject: [PATCH 01/12] =?UTF-8?q?[midend-GVN]=E5=88=9D=E6=AD=A5=E6=9E=84?= =?UTF-8?q?=E5=BB=BAGVN=EF=BC=8C=E8=83=BD=E5=A4=9F=E4=BC=98=E5=8C=96?= =?UTF-8?q?=E9=83=A8=E5=88=86CSE=E6=97=A0=E6=B3=95=E5=A4=84=E7=90=86?= =?UTF-8?q?=E7=9A=84=E5=AD=90=E8=A1=A8=E8=BE=BE=E5=BC=8F=E4=BD=86=E6=98=AF?= =?UTF-8?q?=E6=9C=89=E9=94=99=E8=AF=AF=E9=9C=80=E8=A6=81debug=E3=80=82?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/midend/CMakeLists.txt | 1 + src/midend/IR.cpp | 4 +- src/midend/Pass/Optimize/GVN.cpp | 450 +++++++++++++++++++++++++++++++ src/midend/Pass/Pass.cpp | 13 + 4 files changed, 466 insertions(+), 2 deletions(-) create mode 100644 src/midend/Pass/Optimize/GVN.cpp diff --git a/src/midend/CMakeLists.txt b/src/midend/CMakeLists.txt index b3b86cc..66fc461 100644 --- a/src/midend/CMakeLists.txt +++ b/src/midend/CMakeLists.txt @@ -15,6 +15,7 @@ add_library(midend_lib STATIC Pass/Optimize/DCE.cpp Pass/Optimize/Mem2Reg.cpp Pass/Optimize/Reg2Mem.cpp + Pass/Optimize/GVN.cpp Pass/Optimize/SysYIRCFGOpt.cpp Pass/Optimize/SCCP.cpp Pass/Optimize/LoopNormalization.cpp diff --git a/src/midend/IR.cpp b/src/midend/IR.cpp index 39293f2..d35e16b 100644 --- a/src/midend/IR.cpp +++ b/src/midend/IR.cpp @@ -847,7 +847,7 @@ void CondBrInst::print(std::ostream &os) const { os << "%tmp_cond_" << condName << "_" << uniqueSuffix << " = icmp ne i32 "; printOperand(os, condition); - os << ", 0\n br i1 %tmp_cond_" << condName << "_" << uniqueSuffix; + os << ", 0\n br i1 %tmp_cond_" << condName << "_" << uniqueSuffix; os << ", label %"; printBlockName(os, getThenBlock()); @@ -886,7 +886,7 @@ void MemsetInst::print(std::ostream &os) const { // This is done at print time to avoid modifying the IR structure os << "%tmp_bitcast_" << ptr->getName() << " = bitcast " << *ptr->getType() << " "; printOperand(os, ptr); - os << " to i8*\n "; + os << " to i8*\n "; // Now call memset with the bitcast result os << "call void @llvm.memset.p0i8.i32(i8* %tmp_bitcast_" << ptr->getName() << ", i8 "; diff --git a/src/midend/Pass/Optimize/GVN.cpp b/src/midend/Pass/Optimize/GVN.cpp new file mode 100644 index 0000000..a06ec5f --- /dev/null +++ b/src/midend/Pass/Optimize/GVN.cpp @@ -0,0 +1,450 @@ +#include "GVN.h" +#include "Dom.h" +#include "SysYIROptUtils.h" +#include +#include +#include + +extern int DEBUG; + +namespace sysy { + +// GVN 遍的静态 ID +void *GVN::ID = (void *)&GVN::ID; + +// ====================================================================== +// GVN 类的实现 +// ====================================================================== + +bool GVN::runOnFunction(Function *func, AnalysisManager &AM) { + if (func->getBasicBlocks().empty()) { + return false; + } + + if (DEBUG) { + std::cout << "\n=== Running GVN on function: " << func->getName() << " ===" << std::endl; + } + + bool changed = false; + GVNContext context; + context.run(func, &AM, changed); + + if (DEBUG) { + if (changed) { + std::cout << "GVN: Function " << func->getName() << " was modified" << std::endl; + } else { + std::cout << "GVN: Function " << func->getName() << " was not modified" << std::endl; + } + std::cout << "=== GVN completed for function: " << func->getName() << " ===" << std::endl; + } + + return changed; +} + +void GVN::getAnalysisUsage(std::set &analysisDependencies, std::set &analysisInvalidations) const { + // GVN依赖以下分析: + // 1. 支配树分析 - 用于检查指令的支配关系,确保替换的安全性 + analysisDependencies.insert(&DominatorTreeAnalysisPass::ID); + + // 2. 副作用分析 - 用于判断函数调用是否可以进行GVN + analysisDependencies.insert(&SysYSideEffectAnalysisPass::ID); + + // GVN不会使任何分析失效,因为: + // - GVN只删除冗余计算,不改变CFG结构 + // - GVN不修改程序的语义,只是消除重复计算 + // - 支配关系保持不变 + // - 副作用分析结果保持不变 + // analysisInvalidations 保持为空 + + if (DEBUG) { + std::cout << "GVN: Declared analysis dependencies (DominatorTree, SideEffectAnalysis)" << std::endl; + } +} + +// ====================================================================== +// GVNContext 类的实现 +// ====================================================================== + +void GVNContext::run(Function *func, AnalysisManager *AM, bool &changed) { + if (DEBUG) { + std::cout << " Starting GVN analysis for function: " << func->getName() << std::endl; + } + + // 获取分析结果 + if (AM) { + domTree = AM->getAnalysisResult(func); + sideEffectAnalysis = AM->getAnalysisResult(); + + if (DEBUG) { + if (domTree) { + std::cout << " GVN: Using dominator tree analysis" << std::endl; + } else { + std::cout << " GVN: Warning - dominator tree analysis not available" << std::endl; + } + if (sideEffectAnalysis) { + std::cout << " GVN: Using side effect analysis" << std::endl; + } else { + std::cout << " GVN: Warning - side effect analysis not available" << std::endl; + } + } + } + + // 清空状态 + hashtable.clear(); + visited.clear(); + rpoBlocks.clear(); + needRemove.clear(); + + // 计算逆后序遍历 + computeRPO(func); + + if (DEBUG) { + std::cout << " Computed RPO with " << rpoBlocks.size() << " blocks" << std::endl; + } + + // 按逆后序遍历基本块进行GVN + int blockCount = 0; + for (auto bb : rpoBlocks) { + if (DEBUG) { + std::cout << " Processing block " << ++blockCount << "/" << rpoBlocks.size() + << ": " << bb->getName() << std::endl; + } + + int instCount = 0; + for (auto &instPtr : bb->getInstructions()) { + if (DEBUG) { + std::cout << " Processing instruction " << ++instCount + << ": " << instPtr->getName() << std::endl; + } + visitInstruction(instPtr.get()); + } + } + + if (DEBUG) { + std::cout << " Found " << needRemove.size() << " redundant instructions to remove" << std::endl; + } + + // 删除冗余指令 + int removeCount = 0; + for (auto inst : needRemove) { + auto bb = inst->getParent(); + if (DEBUG) { + std::cout << " Removing redundant instruction " << ++removeCount + << "/" << needRemove.size() << ": " << inst->getName() << std::endl; + } + // 删除指令前先断开所有使用关系 + inst->replaceAllUsesWith(nullptr); + // 使用基本块的删除方法 + // bb->removeInst(inst); + SysYIROptUtils::usedelete(inst); + changed = true; + } + + if (DEBUG) { + std::cout << " GVN analysis completed for function: " << func->getName() << std::endl; + std::cout << " Total instructions analyzed: " << hashtable.size() << std::endl; + std::cout << " Instructions eliminated: " << needRemove.size() << std::endl; + } +} + +void GVNContext::computeRPO(Function *func) { + rpoBlocks.clear(); + visited.clear(); + + auto entry = func->getEntryBlock(); + if (entry) { + dfs(entry); + std::reverse(rpoBlocks.begin(), rpoBlocks.end()); + } +} + +void GVNContext::dfs(BasicBlock *bb) { + if (!bb || visited.count(bb)) { + return; + } + + visited.insert(bb); + + // 访问所有后继基本块 + for (auto succ : bb->getSuccessors()) { + if (visited.find(succ) == visited.end()) { + dfs(succ); + } + } + + rpoBlocks.push_back(bb); +} + +Value *GVNContext::checkHashtable(Value *value) { + if (auto it = hashtable.find(value); it != hashtable.end()) { + return it->second; + } + + if (auto inst = dynamic_cast(value)) { + if (auto valueNumber = getValueNumber(inst)) { + hashtable[value] = valueNumber; + return valueNumber; + } + } + + hashtable[value] = value; + return value; +} + +Value *GVNContext::getValueNumber(Instruction *inst) { + if (auto binary = dynamic_cast(inst)) { + return getValueNumber(binary); + } else if (auto unary = dynamic_cast(inst)) { + return getValueNumber(unary); + } else if (auto gep = dynamic_cast(inst)) { + return getValueNumber(gep); + } else if (auto load = dynamic_cast(inst)) { + return getValueNumber(load); + } else if (auto call = dynamic_cast(inst)) { + return getValueNumber(call); + } + + return nullptr; +} + +Value *GVNContext::getValueNumber(BinaryInst *inst) { + auto lhs = checkHashtable(inst->getLhs()); + auto rhs = checkHashtable(inst->getRhs()); + + if (DEBUG) { + std::cout << " Checking binary instruction: " << inst->getName() + << " (kind: " << static_cast(inst->getKind()) << ")" << std::endl; + } + + for (auto [key, value] : hashtable) { + if (auto binary = dynamic_cast(key)) { + auto binLhs = checkHashtable(binary->getLhs()); + auto binRhs = checkHashtable(binary->getRhs()); + + if (binary->getKind() == inst->getKind()) { + // 检查操作数是否匹配 + if ((lhs == binLhs && rhs == binRhs) || (inst->isCommutative() && lhs == binRhs && rhs == binLhs)) { + if (DEBUG) { + std::cout << " Found equivalent binary instruction: " << binary->getName() << std::endl; + } + return value; + } + } + } + } + + if (DEBUG) { + std::cout << " No equivalent binary instruction found" << std::endl; + } + return inst; +} + +Value *GVNContext::getValueNumber(UnaryInst *inst) { + auto operand = checkHashtable(inst->getOperand()); + + for (auto [key, value] : hashtable) { + if (auto unary = dynamic_cast(key)) { + auto unOperand = checkHashtable(unary->getOperand()); + + if (unary->getKind() == inst->getKind() && operand == unOperand) { + return value; + } + } + } + + return inst; +} + +Value *GVNContext::getValueNumber(GetElementPtrInst *inst) { + auto ptr = checkHashtable(inst->getBasePointer()); + std::vector indices; + + // 使用正确的索引访问方法 + for (unsigned i = 0; i < inst->getNumIndices(); ++i) { + indices.push_back(checkHashtable(inst->getIndex(i))); + } + + for (auto [key, value] : hashtable) { + if (auto gep = dynamic_cast(key)) { + auto gepPtr = checkHashtable(gep->getBasePointer()); + + if (ptr == gepPtr && gep->getNumIndices() == inst->getNumIndices()) { + bool indicesMatch = true; + for (unsigned i = 0; i < inst->getNumIndices(); ++i) { + if (checkHashtable(gep->getIndex(i)) != indices[i]) { + indicesMatch = false; + break; + } + } + + if (indicesMatch && inst->getType() == gep->getType()) { + return value; + } + } + } + } + + return inst; +} + +Value *GVNContext::getValueNumber(LoadInst *inst) { + auto ptr = checkHashtable(inst->getPointer()); + + for (auto [key, value] : hashtable) { + if (auto load = dynamic_cast(key)) { + auto loadPtr = checkHashtable(load->getPointer()); + + if (ptr == loadPtr && inst->getType() == load->getType()) { + return value; + } + } + } + + return inst; +} + +Value *GVNContext::getValueNumber(CallInst *inst) { + // 只为无副作用的函数调用进行GVN + if (sideEffectAnalysis && !sideEffectAnalysis->isPureFunction(inst->getCallee())) { + return nullptr; + } + + for (auto [key, value] : hashtable) { + if (auto call = dynamic_cast(key)) { + if (call->getCallee() == inst->getCallee() && call->getNumOperands() == inst->getNumOperands()) { + + bool argsMatch = true; + // 跳过第一个操作数(函数指针),从参数开始比较 + for (size_t i = 1; i < inst->getNumOperands(); ++i) { + if (checkHashtable(inst->getOperand(i)) != checkHashtable(call->getOperand(i))) { + argsMatch = false; + break; + } + } + + if (argsMatch) { + return value; + } + } + } + } + + return inst; +} + +void GVNContext::visitInstruction(Instruction *inst) { + // 跳过分支指令 + if (inst->isBranch()) { + if (DEBUG) { + std::cout << " Skipping branch instruction: " << inst->getName() << std::endl; + } + return; + } + + if (DEBUG) { + std::cout << " Visiting instruction: " << inst->getName() + << " (kind: " << static_cast(inst->getKind()) << ")" << std::endl; + } + + auto value = checkHashtable(inst); + + if (inst != value) { + if (auto instValue = dynamic_cast(value)) { + if (canReplace(inst, instValue)) { + inst->replaceAllUsesWith(instValue); + needRemove.insert(inst); + + if (DEBUG) { + std::cout << " GVN: Replacing redundant instruction " << inst->getName() + << " with existing instruction " << instValue->getName() << std::endl; + } + } else { + if (DEBUG) { + std::cout << " Cannot replace instruction " << inst->getName() + << " with " << instValue->getName() << " (dominance check failed)" << std::endl; + } + } + } + } else { + if (DEBUG) { + std::cout << " Instruction " << inst->getName() << " is unique" << std::endl; + } + } +} + +bool GVNContext::canReplace(Instruction *original, Value *replacement) { + auto replInst = dynamic_cast(replacement); + if (!replInst) { + return true; // 替换为常量总是安全的 + } + + auto originalBB = original->getParent(); + auto replBB = replInst->getParent(); + + // 如果replacement是Call指令,需要特殊处理 + if (auto callInst = dynamic_cast(replInst)) { + if (sideEffectAnalysis && !sideEffectAnalysis->isPureFunction(callInst->getCallee())) { + // 对于有副作用的函数,只有在同一个基本块且相邻时才能替换 + if (originalBB != replBB) { + return false; + } + + // 检查指令顺序 + auto &insts = originalBB->getInstructions(); + auto origIt = + std::find_if(insts.begin(), insts.end(), [original](const auto &ptr) { return ptr.get() == original; }); + auto replIt = + std::find_if(insts.begin(), insts.end(), [replInst](const auto &ptr) { return ptr.get() == replInst; }); + + if (origIt == insts.end() || replIt == insts.end()) { + return false; + } + + return std::abs(std::distance(origIt, replIt)) == 1; + } + } + + // 简单的支配关系检查:如果在同一个基本块,检查指令顺序 + if (originalBB == replBB) { + auto &insts = originalBB->getInstructions(); + auto origIt = + std::find_if(insts.begin(), insts.end(), [original](const auto &ptr) { return ptr.get() == original; }); + auto replIt = + std::find_if(insts.begin(), insts.end(), [replInst](const auto &ptr) { return ptr.get() == replInst; }); + + // 替换指令必须在原指令之前 + return std::distance(insts.begin(), replIt) < std::distance(insts.begin(), origIt); + } + + // 使用支配关系检查(如果支配树分析可用) + if (domTree) { + auto dominators = domTree->getDominators(originalBB); + if (dominators && dominators->count(replBB)) { + return true; + } + } + + return false; +} + +std::string GVNContext::getCanonicalExpression(Instruction *inst) { + std::ostringstream oss; + + if (auto binary = dynamic_cast(inst)) { + oss << "binary_" << static_cast(binary->getKind()) << "_"; + oss << checkHashtable(binary->getLhs()) << "_"; + oss << checkHashtable(binary->getRhs()); + } else if (auto unary = dynamic_cast(inst)) { + oss << "unary_" << static_cast(unary->getKind()) << "_"; + oss << checkHashtable(unary->getOperand()); + } else if (auto gep = dynamic_cast(inst)) { + oss << "gep_" << checkHashtable(gep->getBasePointer()); + for (unsigned i = 0; i < gep->getNumIndices(); ++i) { + oss << "_" << checkHashtable(gep->getIndex(i)); + } + } + + return oss.str(); +} + +} // namespace sysy diff --git a/src/midend/Pass/Pass.cpp b/src/midend/Pass/Pass.cpp index 440ce0c..449508e 100644 --- a/src/midend/Pass/Pass.cpp +++ b/src/midend/Pass/Pass.cpp @@ -10,6 +10,7 @@ #include "DCE.h" #include "Mem2Reg.h" #include "Reg2Mem.h" +#include "GVN.h" #include "SCCP.h" #include "BuildCFG.h" #include "LargeArrayToGlobal.h" @@ -59,6 +60,8 @@ void PassManager::runOptimizationPipeline(Module* moduleIR, IRBuilder* builderIR // 注册优化遍 registerOptimizationPass(); registerOptimizationPass(); + + registerOptimizationPass(); registerOptimizationPass(); registerOptimizationPass(); @@ -129,6 +132,16 @@ void PassManager::runOptimizationPipeline(Module* moduleIR, IRBuilder* builderIR printPasses(); } + // 添加GVN优化遍 + this->clearPasses(); + this->addPass(&GVN::ID); + this->run(); + + if(DEBUG) { + std::cout << "=== IR After GVN Optimizations ===\n"; + printPasses(); + } + this->clearPasses(); this->addPass(&SCCP::ID); this->run(); From d038884ffba636ba6c55fe4e90a4853d223acf69 Mon Sep 17 00:00:00 2001 From: rain2133 <1370973498@qq.com> Date: Sat, 16 Aug 2025 15:43:51 +0800 Subject: [PATCH 02/12] =?UTF-8?q?[midend-GVN]=20commit=E5=A4=B4=E6=96=87?= =?UTF-8?q?=E4=BB=B6?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/include/midend/Pass/Optimize/GVN.h | 82 ++++++++++++++++++++++++++ 1 file changed, 82 insertions(+) create mode 100644 src/include/midend/Pass/Optimize/GVN.h diff --git a/src/include/midend/Pass/Optimize/GVN.h b/src/include/midend/Pass/Optimize/GVN.h new file mode 100644 index 0000000..3358e48 --- /dev/null +++ b/src/include/midend/Pass/Optimize/GVN.h @@ -0,0 +1,82 @@ +#pragma once + +#include "Pass.h" +#include "IR.h" +#include "Dom.h" +#include "SideEffectAnalysis.h" +#include +#include +#include +#include +#include + +namespace sysy { + +// GVN优化遍的核心逻辑封装类 +class GVNContext { +public: + // 运行GVN优化的主要方法 + void run(Function* func, AnalysisManager* AM, bool& changed); + +private: + // 值编号的哈希表:Value -> 代表值 + std::unordered_map hashtable; + + // 已访问的基本块集合 + std::unordered_set visited; + + // 逆后序遍历的基本块列表 + std::vector rpoBlocks; + + // 需要删除的指令集合 + std::unordered_set needRemove; + + // 分析结果 + DominatorTree* domTree = nullptr; + SideEffectAnalysisResult* sideEffectAnalysis = nullptr; + + // 计算逆后序遍历 + void computeRPO(Function* func); + void dfs(BasicBlock* bb); + + // 检查哈希表并获取值编号 + Value* checkHashtable(Value* value); + + // 为不同类型的指令获取值编号 + Value* getValueNumber(Instruction* inst); + Value* getValueNumber(BinaryInst* inst); + Value* getValueNumber(UnaryInst* inst); + Value* getValueNumber(GetElementPtrInst* inst); + Value* getValueNumber(LoadInst* inst); + Value* getValueNumber(CallInst* inst); + + // 访问指令并进行GVN优化 + void visitInstruction(Instruction* inst); + + // 检查是否可以安全地用一个值替换另一个值 + bool canReplace(Instruction* original, Value* replacement); + + // 生成表达式的标准化字符串 + std::string getCanonicalExpression(Instruction* inst); +}; + +// GVN优化遍类 +class GVN : public OptimizationPass { +public: + // 静态成员,作为该遍的唯一ID + static void* ID; + + GVN() : OptimizationPass("GVN", Granularity::Function) {} + + // 在函数上运行优化 + bool runOnFunction(Function* func, AnalysisManager& AM) override; + + // 返回该遍的唯一ID + void* getPassID() const override { return ID; } + + // 声明分析依赖 + void getAnalysisUsage(std::set& analysisDependencies, + std::set& analysisInvalidations) const override; +}; + +} // namespace sysy From c4eb1c39808730d2fd5034e0540a37bef0712de0 Mon Sep 17 00:00:00 2001 From: rain2133 <1370973498@qq.com> Date: Sat, 16 Aug 2025 18:52:29 +0800 Subject: [PATCH 03/12] =?UTF-8?q?[midend-GVN&SideEffect]=E4=BF=AE=E5=A4=8D?= =?UTF-8?q?GVN=E7=9A=84=E9=83=A8=E5=88=86=E9=97=AE=E9=A2=98=E5=92=8C?= =?UTF-8?q?=E5=89=AF=E4=BD=9C=E7=94=A8=E5=88=86=E6=9E=90=E7=9A=84=E7=BC=BA?= =?UTF-8?q?=E9=99=B7?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/include/midend/Pass/Optimize/GVN.h | 3 + .../Pass/Analysis/SideEffectAnalysis.cpp | 11 +++ src/midend/Pass/Optimize/GVN.cpp | 75 +++++++++++++++++-- 3 files changed, 83 insertions(+), 6 deletions(-) diff --git a/src/include/midend/Pass/Optimize/GVN.h b/src/include/midend/Pass/Optimize/GVN.h index 3358e48..ce11769 100644 --- a/src/include/midend/Pass/Optimize/GVN.h +++ b/src/include/midend/Pass/Optimize/GVN.h @@ -56,6 +56,9 @@ private: // 检查是否可以安全地用一个值替换另一个值 bool canReplace(Instruction* original, Value* replacement); + // 检查两个load指令之间是否有store指令修改了相同的内存位置 + bool hasInterveningStore(LoadInst* earlierLoad, LoadInst* laterLoad, Value* ptr); + // 生成表达式的标准化字符串 std::string getCanonicalExpression(Instruction* inst); }; diff --git a/src/midend/Pass/Analysis/SideEffectAnalysis.cpp b/src/midend/Pass/Analysis/SideEffectAnalysis.cpp index 805f98b..3887add 100644 --- a/src/midend/Pass/Analysis/SideEffectAnalysis.cpp +++ b/src/midend/Pass/Analysis/SideEffectAnalysis.cpp @@ -26,10 +26,21 @@ const SideEffectInfo &SideEffectAnalysisResult::getInstructionSideEffect(Instruc } const SideEffectInfo &SideEffectAnalysisResult::getFunctionSideEffect(Function *func) const { + // 首先检查分析过的用户定义函数 auto it = functionSideEffects.find(func); if (it != functionSideEffects.end()) { return it->second; } + + // 如果没有找到,检查是否为已知的库函数 + if (func) { + std::string funcName = func->getName(); + const SideEffectInfo *knownInfo = getKnownFunctionSideEffect(funcName); + if (knownInfo) { + return *knownInfo; + } + } + // 返回默认的无副作用信息 static SideEffectInfo noEffect; return noEffect; diff --git a/src/midend/Pass/Optimize/GVN.cpp b/src/midend/Pass/Optimize/GVN.cpp index a06ec5f..76b0ffb 100644 --- a/src/midend/Pass/Optimize/GVN.cpp +++ b/src/midend/Pass/Optimize/GVN.cpp @@ -201,7 +201,11 @@ Value *GVNContext::getValueNumber(Instruction *inst) { } else if (auto load = dynamic_cast(inst)) { return getValueNumber(load); } else if (auto call = dynamic_cast(inst)) { - return getValueNumber(call); + // 只为无副作用的函数调用进行GVN + if (sideEffectAnalysis && sideEffectAnalysis->isPureFunction(call->getCallee())) { + return getValueNumber(call); + } + return nullptr; } return nullptr; @@ -295,6 +299,10 @@ Value *GVNContext::getValueNumber(LoadInst *inst) { auto loadPtr = checkHashtable(load->getPointer()); if (ptr == loadPtr && inst->getType() == load->getType()) { + // 检查两次load之间是否有store指令修改了内存 + if (hasInterveningStore(load, inst, ptr)) { + continue; // 如果有store指令,不能复用之前的load + } return value; } } @@ -304,11 +312,7 @@ Value *GVNContext::getValueNumber(LoadInst *inst) { } Value *GVNContext::getValueNumber(CallInst *inst) { - // 只为无副作用的函数调用进行GVN - if (sideEffectAnalysis && !sideEffectAnalysis->isPureFunction(inst->getCallee())) { - return nullptr; - } - + // 此时已经确认是无副作用的函数调用,可以安全进行GVN for (auto [key, value] : hashtable) { if (auto call = dynamic_cast(key)) { if (call->getCallee() == inst->getCallee() && call->getNumOperands() == inst->getNumOperands()) { @@ -427,6 +431,65 @@ bool GVNContext::canReplace(Instruction *original, Value *replacement) { return false; } +bool GVNContext::hasInterveningStore(LoadInst* earlierLoad, LoadInst* laterLoad, Value* ptr) { + // 如果两个load在不同的基本块,需要更复杂的分析 + auto earlierBB = earlierLoad->getParent(); + auto laterBB = laterLoad->getParent(); + + if (earlierBB != laterBB) { + // 跨基本块的情况:为了安全起见,暂时认为有intervening store + // 这是保守的做法,可能会错过一些优化机会,但确保正确性 + return true; + } + + // 同一基本块内的情况:检查指令序列 + auto &insts = earlierBB->getInstructions(); + + // 找到两个load指令的位置 + auto earlierIt = std::find_if(insts.begin(), insts.end(), + [earlierLoad](const auto &ptr) { return ptr.get() == earlierLoad; }); + auto laterIt = std::find_if(insts.begin(), insts.end(), + [laterLoad](const auto &ptr) { return ptr.get() == laterLoad; }); + + if (earlierIt == insts.end() || laterIt == insts.end()) { + return true; // 找不到指令,保守返回true + } + + // 检查两个load之间的所有指令 + for (auto it = std::next(earlierIt); it != laterIt; ++it) { + auto inst = it->get(); + + // 检查是否是store指令 + if (auto storeInst = dynamic_cast(inst)) { + auto storePtr = checkHashtable(storeInst->getPointer()); + + // 如果store的目标地址与load的地址相同,说明内存被修改了 + if (storePtr == ptr) { + if (DEBUG) { + std::cout << " Found intervening store to same address, cannot optimize load" << std::endl; + } + return true; + } + } + + // TODO: 还需要检查函数调用是否可能修改内存 + // 对于全局变量,任何函数调用都可能修改它 + if (auto callInst = dynamic_cast(inst)) { + if (sideEffectAnalysis && !sideEffectAnalysis->isPureFunction(callInst->getCallee())) { + // 如果是有副作用的函数调用,且load的是全局变量,则可能被修改 + if (auto globalPtr = dynamic_cast(ptr)) { + if (DEBUG) { + std::cout << " Found function call that may modify global variable, cannot optimize load" << std::endl; + } + return true; + } + } + } + } + + return false; // 没有找到会修改内存的指令 +} + std::string GVNContext::getCanonicalExpression(Instruction *inst) { std::ostringstream oss; From e32585fd251c934859a2dacb914913d4d17f1d36 Mon Sep 17 00:00:00 2001 From: rain2133 <1370973498@qq.com> Date: Sun, 17 Aug 2025 00:14:47 +0800 Subject: [PATCH 04/12] =?UTF-8?q?[midend-GVN]=E4=BF=AE=E5=A4=8DGVN?= =?UTF-8?q?=E4=B8=AD=E9=83=A8=E5=88=86=E9=80=BB=E8=BE=91=E9=97=AE=E9=A2=98?= =?UTF-8?q?=EF=BC=8CLICM=E6=9C=89bug=E5=BE=85=E4=BF=AE=E5=A4=8D?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/include/midend/Pass/Optimize/GVN.h | 3 ++ src/midend/Pass/Optimize/GVN.cpp | 41 ++++++++++++++++++++++++++ src/midend/Pass/Pass.cpp | 11 ++++--- 3 files changed, 49 insertions(+), 6 deletions(-) diff --git a/src/include/midend/Pass/Optimize/GVN.h b/src/include/midend/Pass/Optimize/GVN.h index ce11769..2aafd8d 100644 --- a/src/include/midend/Pass/Optimize/GVN.h +++ b/src/include/midend/Pass/Optimize/GVN.h @@ -59,6 +59,9 @@ private: // 检查两个load指令之间是否有store指令修改了相同的内存位置 bool hasInterveningStore(LoadInst* earlierLoad, LoadInst* laterLoad, Value* ptr); + // 使受store指令影响的load指令失效 + void invalidateLoadsAffectedByStore(StoreInst* storeInst); + // 生成表达式的标准化字符串 std::string getCanonicalExpression(Instruction* inst); }; diff --git a/src/midend/Pass/Optimize/GVN.cpp b/src/midend/Pass/Optimize/GVN.cpp index 76b0ffb..9f28609 100644 --- a/src/midend/Pass/Optimize/GVN.cpp +++ b/src/midend/Pass/Optimize/GVN.cpp @@ -301,8 +301,14 @@ Value *GVNContext::getValueNumber(LoadInst *inst) { if (ptr == loadPtr && inst->getType() == load->getType()) { // 检查两次load之间是否有store指令修改了内存 if (hasInterveningStore(load, inst, ptr)) { + if (DEBUG) { + std::cout << " Found intervening store, cannot reuse load value" << std::endl; + } continue; // 如果有store指令,不能复用之前的load } + if (DEBUG) { + std::cout << " No intervening store found, can reuse load value" << std::endl; + } return value; } } @@ -345,6 +351,11 @@ void GVNContext::visitInstruction(Instruction *inst) { return; } + // 如果是store指令,需要清理hashtable中可能被影响的load指令 + if (auto storeInst = dynamic_cast(inst)) { + invalidateLoadsAffectedByStore(storeInst); + } + if (DEBUG) { std::cout << " Visiting instruction: " << inst->getName() << " (kind: " << static_cast(inst->getKind()) << ")" << std::endl; @@ -490,6 +501,36 @@ bool GVNContext::hasInterveningStore(LoadInst* earlierLoad, LoadInst* laterLoad, return false; // 没有找到会修改内存的指令 } +void GVNContext::invalidateLoadsAffectedByStore(StoreInst* storeInst) { + auto storePtr = checkHashtable(storeInst->getPointer()); + + if (DEBUG) { + std::cout << " Invalidating loads affected by store to address" << std::endl; + } + + // 查找hashtable中所有可能被这个store影响的load指令 + std::vector toRemove; + + for (auto& [key, value] : hashtable) { + if (auto loadInst = dynamic_cast(key)) { + auto loadPtr = checkHashtable(loadInst->getPointer()); + + // 如果load的地址与store的地址相同,则需要从hashtable中移除 + if (loadPtr == storePtr) { + toRemove.push_back(key); + if (DEBUG) { + std::cout << " Invalidating load from same address: " << loadInst->getName() << std::endl; + } + } + } + } + + // 从hashtable中移除被影响的load指令 + for (auto key : toRemove) { + hashtable.erase(key); + } +} + std::string GVNContext::getCanonicalExpression(Instruction *inst) { std::ostringstream oss; diff --git a/src/midend/Pass/Pass.cpp b/src/midend/Pass/Pass.cpp index 449508e..e5e6aab 100644 --- a/src/midend/Pass/Pass.cpp +++ b/src/midend/Pass/Pass.cpp @@ -132,7 +132,6 @@ void PassManager::runOptimizationPipeline(Module* moduleIR, IRBuilder* builderIR printPasses(); } - // 添加GVN优化遍 this->clearPasses(); this->addPass(&GVN::ID); this->run(); @@ -154,14 +153,14 @@ void PassManager::runOptimizationPipeline(Module* moduleIR, IRBuilder* builderIR this->clearPasses(); this->addPass(&LoopNormalizationPass::ID); this->addPass(&InductionVariableElimination::ID); - this->addPass(&LICM::ID); + // this->addPass(&LICM::ID); this->addPass(&LoopStrengthReduction::ID); this->run(); - if(DEBUG) { - std::cout << "=== IR After Loop Normalization, LICM, and Strength Reduction Optimizations ===\n"; - printPasses(); - } + // if(DEBUG) { + // std::cout << "=== IR After Loop Normalization, LICM, and Strength Reduction Optimizations ===\n"; + // printPasses(); + // } // this->clearPasses(); // this->addPass(&Reg2Mem::ID); From d83dc7a2e7f4cb2f2e06f1fe8c18c796b83b48a1 Mon Sep 17 00:00:00 2001 From: rain2133 <1370973498@qq.com> Date: Sun, 17 Aug 2025 01:19:44 +0800 Subject: [PATCH 05/12] =?UTF-8?q?[midend-LICM][fix]=E4=BF=AE=E5=A4=8D?= =?UTF-8?q?=E5=BE=AA=E7=8E=AF=E4=B8=8D=E5=8F=98=E9=87=8F=E7=9A=84=E8=AF=86?= =?UTF-8?q?=E5=88=AB=E9=80=BB=E8=BE=91?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../Pass/Analysis/LoopCharacteristics.h | 6 +- .../Pass/Analysis/LoopCharacteristics.cpp | 296 ++++++++++++++++-- src/midend/Pass/Optimize/LICM.cpp | 5 +- src/midend/Pass/Pass.cpp | 27 +- 4 files changed, 300 insertions(+), 34 deletions(-) diff --git a/src/include/midend/Pass/Analysis/LoopCharacteristics.h b/src/include/midend/Pass/Analysis/LoopCharacteristics.h index e5ccafb..1f4a7ae 100644 --- a/src/include/midend/Pass/Analysis/LoopCharacteristics.h +++ b/src/include/midend/Pass/Analysis/LoopCharacteristics.h @@ -350,7 +350,11 @@ private: std::set& visited ); bool isBasicInductionVariable(Value* val, Loop* loop); - bool hasSimpleMemoryPattern(Loop* loop); // 简单的内存模式检查 + // ========== 循环不变量分析辅助方法 ========== + bool isInvariantOperands(Instruction* inst, Loop* loop, const std::unordered_set& invariants); + bool isMemoryLocationModifiedInLoop(Value* ptr, Loop* loop); + bool isMemoryLocationLoadedInLoop(Value* ptr, Loop* loop, Instruction* excludeInst = nullptr); + bool isPureFunction(Function* calledFunc); }; } // namespace sysy diff --git a/src/midend/Pass/Analysis/LoopCharacteristics.cpp b/src/midend/Pass/Analysis/LoopCharacteristics.cpp index daef567..eb58af1 100644 --- a/src/midend/Pass/Analysis/LoopCharacteristics.cpp +++ b/src/midend/Pass/Analysis/LoopCharacteristics.cpp @@ -776,38 +776,282 @@ void LoopCharacteristicsPass::findDerivedInductionVars( } } -// 递归/推进式判定 -bool LoopCharacteristicsPass::isClassicLoopInvariant(Value* val, Loop* loop, const std::unordered_set& invariants) { - // 1. 常量 - if (auto* constval = dynamic_cast(val)) return true; - - // 2. 参数(函数参数)通常不在任何BasicBlock内,直接判定为不变量 - if (auto* arg = dynamic_cast(val)) return true; - - // 3. 指令且定义在循环外 - if (auto* inst = dynamic_cast(val)) { - if (!loop->contains(inst->getParent())) - return true; - - // 4. 跳转 phi指令 副作用 不外提 - if (inst->isTerminator() || inst->isPhi() || sideEffectAnalysis->hasSideEffect(inst)) +// 检查操作数是否都是不变量 +bool LoopCharacteristicsPass::isInvariantOperands(Instruction* inst, Loop* loop, const std::unordered_set& invariants) { + for (size_t i = 0; i < inst->getNumOperands(); ++i) { + Value* op = inst->getOperand(i); + if (!isClassicLoopInvariant(op, loop, invariants) && !invariants.count(op)) { return false; - - // 5. 所有操作数都是不变量 - for (size_t i = 0; i < inst->getNumOperands(); ++i) { - Value* op = inst->getOperand(i); - if (!isClassicLoopInvariant(op, loop, invariants) && !invariants.count(op)) - return false; } - return true; } - // 其它情况 + return true; +} + +// 检查内存位置是否在循环中被修改 +bool LoopCharacteristicsPass::isMemoryLocationModifiedInLoop(Value* ptr, Loop* loop) { + // 遍历循环中的所有Store指令,检查是否有对该内存位置的写入 + for (BasicBlock* bb : loop->getBlocks()) { + for (auto& inst : bb->getInstructions()) { + if (auto* storeInst = dynamic_cast(inst.get())) { + Value* storeTar = storeInst->getPointer(); + + // 使用别名分析检查是否可能别名 + if (aliasAnalysis) { + auto aliasType = aliasAnalysis->queryAlias(ptr, storeTar); + if (aliasType != AliasType::NO_ALIAS) { + if (DEBUG) { + std::cout << " Memory location " << ptr->getName() + << " may be modified by store to " << storeTar->getName() << std::endl; + } + return true; + } + } else { + // 如果没有别名分析,保守处理 - 只检查精确匹配 + if (ptr == storeTar) { + return true; + } + } + } + } + } return false; } -bool LoopCharacteristicsPass::hasSimpleMemoryPattern(Loop* loop) { - // 检查是否有简单的内存访问模式 - return true; // 暂时简化处理 +// 检查内存位置是否在循环中被读取 +bool LoopCharacteristicsPass::isMemoryLocationLoadedInLoop(Value* ptr, Loop* loop, Instruction* excludeInst) { + // 遍历循环中的所有Load指令,检查是否有对该内存位置的读取 + for (BasicBlock* bb : loop->getBlocks()) { + for (auto& inst : bb->getInstructions()) { + if (inst.get() == excludeInst) continue; // 排除当前指令本身 + + if (auto* loadInst = dynamic_cast(inst.get())) { + Value* loadSrc = loadInst->getPointer(); + + // 使用别名分析检查是否可能别名 + if (aliasAnalysis) { + auto aliasType = aliasAnalysis->queryAlias(ptr, loadSrc); + if (aliasType != AliasType::NO_ALIAS) { + return true; + } + } else { + // 如果没有别名分析,保守处理 - 只检查精确匹配 + if (ptr == loadSrc) { + return true; + } + } + } + } + } + return false; } +// 检查函数调用是否为纯函数 +bool LoopCharacteristicsPass::isPureFunction(Function* calledFunc) { + if (!calledFunc) return false; + + // 使用副作用分析检查函数是否为纯函数 + if (sideEffectAnalysis && sideEffectAnalysis->isPureFunction(calledFunc)) { + return true; + } + + // 检查是否为内置纯函数(如数学函数) + std::string funcName = calledFunc->getName(); + static const std::set pureFunctions = { + "abs", "fabs", "sqrt", "sin", "cos", "tan", "exp", "log", "pow", + "floor", "ceil", "round", "min", "max" + }; + + return pureFunctions.count(funcName) > 0; +} + +// 递归/推进式判定 - 完善版本 +bool LoopCharacteristicsPass::isClassicLoopInvariant(Value* val, Loop* loop, const std::unordered_set& invariants) { + if (DEBUG >= 2) { + std::cout << " Checking loop invariant for: " << val->getName() << std::endl; + } + + // 1. 常量 + if (auto* constval = dynamic_cast(val)) { + if (DEBUG >= 2) std::cout << " -> Constant: YES" << std::endl; + return true; + } + + // 2. 参数(函数参数)通常不在任何BasicBlock内,直接判定为不变量 + // 在SSA形式下,参数不会被重新赋值 + if (auto* arg = dynamic_cast(val)) { + if (DEBUG >= 2) std::cout << " -> Function argument: YES" << std::endl; + return true; + } + + // 3. 指令且定义在循环外 + if (auto* inst = dynamic_cast(val)) { + if (!loop->contains(inst->getParent())) { + if (DEBUG >= 2) std::cout << " -> Defined outside loop: YES" << std::endl; + return true; + } + + // 4. 跳转指令、phi指令不能外提 + if (inst->isTerminator() || inst->isPhi()) { + if (DEBUG >= 2) std::cout << " -> Terminator or PHI: NO" << std::endl; + return false; + } + + // 5. 根据指令类型进行具体分析 + switch (inst->getKind()) { + case Instruction::Kind::kStore: { + // Store指令:检查循环内是否有对该内存的load + auto* storeInst = dynamic_cast(inst); + Value* storePtr = storeInst->getPointer(); + + // 首先检查操作数是否不变 + if (!isInvariantOperands(inst, loop, invariants)) { + if (DEBUG >= 2) std::cout << " -> Store: operands not invariant: NO" << std::endl; + return false; + } + + // 检查是否有对该内存位置的load + if (isMemoryLocationLoadedInLoop(storePtr, loop, inst)) { + if (DEBUG >= 2) std::cout << " -> Store: memory location loaded in loop: NO" << std::endl; + return false; + } + + if (DEBUG >= 2) std::cout << " -> Store: safe to hoist: YES" << std::endl; + return true; + } + + case Instruction::Kind::kLoad: { + // Load指令:检查循环内是否有对该内存的store + auto* loadInst = dynamic_cast(inst); + Value* loadPtr = loadInst->getPointer(); + + // 首先检查指针操作数是否不变 + if (!isInvariantOperands(inst, loop, invariants)) { + if (DEBUG >= 2) std::cout << " -> Load: pointer not invariant: NO" << std::endl; + return false; + } + + // 检查是否有对该内存位置的store + if (isMemoryLocationModifiedInLoop(loadPtr, loop)) { + if (DEBUG >= 2) std::cout << " -> Load: memory location modified in loop: NO" << std::endl; + return false; + } + + if (DEBUG >= 2) std::cout << " -> Load: safe to hoist: YES" << std::endl; + return true; + } + + case Instruction::Kind::kCall: { + // Call指令:检查是否为纯函数且参数不变 + auto* callInst = dynamic_cast(inst); + Function* calledFunc = callInst->getCallee(); + + // 检查是否为纯函数 + if (!isPureFunction(calledFunc)) { + if (DEBUG >= 2) std::cout << " -> Call: not pure function: NO" << std::endl; + return false; + } + + // 检查参数是否都不变 + if (!isInvariantOperands(inst, loop, invariants)) { + if (DEBUG >= 2) std::cout << " -> Call: arguments not invariant: NO" << std::endl; + return false; + } + + if (DEBUG >= 2) std::cout << " -> Call: pure function with invariant args: YES" << std::endl; + return true; + } + + case Instruction::Kind::kGetElementPtr: { + // GEP指令:检查基址和索引是否都不变 + if (!isInvariantOperands(inst, loop, invariants)) { + if (DEBUG >= 2) std::cout << " -> GEP: base or indices not invariant: NO" << std::endl; + return false; + } + + if (DEBUG >= 2) std::cout << " -> GEP: base and indices invariant: YES" << std::endl; + return true; + } + + // 一元运算指令 + case Instruction::Kind::kNeg: + case Instruction::Kind::kNot: + case Instruction::Kind::kFNeg: + case Instruction::Kind::kFNot: + case Instruction::Kind::kFtoI: + case Instruction::Kind::kItoF: + case Instruction::Kind::kBitItoF: + case Instruction::Kind::kBitFtoI: { + // 检查操作数是否不变 + if (!isInvariantOperands(inst, loop, invariants)) { + if (DEBUG >= 2) std::cout << " -> Unary op: operand not invariant: NO" << std::endl; + return false; + } + + if (DEBUG >= 2) std::cout << " -> Unary op: operand invariant: YES" << std::endl; + return true; + } + + // 二元运算指令 + case Instruction::Kind::kAdd: + case Instruction::Kind::kSub: + case Instruction::Kind::kMul: + case Instruction::Kind::kDiv: + case Instruction::Kind::kRem: + case Instruction::Kind::kSll: + case Instruction::Kind::kSrl: + case Instruction::Kind::kSra: + case Instruction::Kind::kAnd: + case Instruction::Kind::kOr: + case Instruction::Kind::kFAdd: + case Instruction::Kind::kFSub: + case Instruction::Kind::kFMul: + case Instruction::Kind::kFDiv: + case Instruction::Kind::kICmpEQ: + case Instruction::Kind::kICmpNE: + case Instruction::Kind::kICmpLT: + case Instruction::Kind::kICmpGT: + case Instruction::Kind::kICmpLE: + case Instruction::Kind::kICmpGE: + case Instruction::Kind::kFCmpEQ: + case Instruction::Kind::kFCmpNE: + case Instruction::Kind::kFCmpLT: + case Instruction::Kind::kFCmpGT: + case Instruction::Kind::kFCmpLE: + case Instruction::Kind::kFCmpGE: + case Instruction::Kind::kMulh: { + // 检查所有操作数是否不变 + if (!isInvariantOperands(inst, loop, invariants)) { + if (DEBUG >= 2) std::cout << " -> Binary op: operands not invariant: NO" << std::endl; + return false; + } + + if (DEBUG >= 2) std::cout << " -> Binary op: operands invariant: YES" << std::endl; + return true; + } + + default: { + // 其他指令:使用副作用分析 + if (sideEffectAnalysis && sideEffectAnalysis->hasSideEffect(inst)) { + if (DEBUG >= 2) std::cout << " -> Other inst: has side effect: NO" << std::endl; + return false; + } + + // 检查操作数是否都不变 + if (!isInvariantOperands(inst, loop, invariants)) { + if (DEBUG >= 2) std::cout << " -> Other inst: operands not invariant: NO" << std::endl; + return false; + } + + if (DEBUG >= 2) std::cout << " -> Other inst: no side effect, operands invariant: YES" << std::endl; + return true; + } + } + } + + // 其它情况 + if (DEBUG >= 2) std::cout << " -> Other value type: NO" << std::endl; + return false; +} + + } // namespace sysy diff --git a/src/midend/Pass/Optimize/LICM.cpp b/src/midend/Pass/Optimize/LICM.cpp index b583066..3193dd3 100644 --- a/src/midend/Pass/Optimize/LICM.cpp +++ b/src/midend/Pass/Optimize/LICM.cpp @@ -55,10 +55,11 @@ bool LICMContext::hoistInstructions() { } } - // 检查是否全部排序,若未全部排序,说明有环(理论上不会) + // 检查是否全部排序,若未全部排序,打印错误信息 + // 这可能是因为存在循环依赖或其他问题导致无法完成拓扑排序 if (sorted.size() != workSet.size()) { if (DEBUG) - std::cerr << "LICM: Topological sort failed, possible dependency cycle." << std::endl; + std::cout << "LICM: Topological sort failed, possible dependency cycle." << std::endl; return false; } diff --git a/src/midend/Pass/Pass.cpp b/src/midend/Pass/Pass.cpp index e5e6aab..09de26e 100644 --- a/src/midend/Pass/Pass.cpp +++ b/src/midend/Pass/Pass.cpp @@ -153,14 +153,31 @@ void PassManager::runOptimizationPipeline(Module* moduleIR, IRBuilder* builderIR this->clearPasses(); this->addPass(&LoopNormalizationPass::ID); this->addPass(&InductionVariableElimination::ID); - // this->addPass(&LICM::ID); + this->run(); + + if(DEBUG) { + std::cout << "=== IR After Loop Normalization, Induction Variable Elimination ===\n"; + printPasses(); + } + + + this->clearPasses(); + this->addPass(&LICM::ID); + this->run(); + + if(DEBUG) { + std::cout << "=== IR After LICM ===\n"; + printPasses(); + } + + this->clearPasses(); this->addPass(&LoopStrengthReduction::ID); this->run(); - // if(DEBUG) { - // std::cout << "=== IR After Loop Normalization, LICM, and Strength Reduction Optimizations ===\n"; - // printPasses(); - // } + if(DEBUG) { + std::cout << "=== IR After Loop Normalization, and Strength Reduction Optimizations ===\n"; + printPasses(); + } // this->clearPasses(); // this->addPass(&Reg2Mem::ID); From 8763c0a11a330485e81492e10d501a4aa4fb2c3a Mon Sep 17 00:00:00 2001 From: rain2133 <1370973498@qq.com> Date: Sun, 17 Aug 2025 01:35:03 +0800 Subject: [PATCH 06/12] =?UTF-8?q?[midend-LICM][fix]=E4=BF=AE=E6=94=B9?= =?UTF-8?q?=E8=AE=A1=E7=AE=97=E5=BE=AA=E7=8E=AF=E4=B8=8D=E5=8F=98=E9=87=8F?= =?UTF-8?q?=E4=BE=9D=E8=B5=96=E5=85=B3=E7=B3=BB=E7=9A=84=E6=8E=92=E5=BA=8F?= =?UTF-8?q?=E9=94=99=E8=AF=AF=EF=BC=8C=E4=BD=86=E6=98=AF=E5=BC=95=E5=85=A5?= =?UTF-8?q?=E4=BA=86=E5=BE=88=E5=A4=9ASegmentation=20fault=E3=80=82?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/midend/Pass/Optimize/LICM.cpp | 173 ++++++++++++++++++++++++++++-- 1 file changed, 162 insertions(+), 11 deletions(-) diff --git a/src/midend/Pass/Optimize/LICM.cpp b/src/midend/Pass/Optimize/LICM.cpp index 3193dd3..6aef11d 100644 --- a/src/midend/Pass/Optimize/LICM.cpp +++ b/src/midend/Pass/Optimize/LICM.cpp @@ -18,38 +18,100 @@ bool LICMContext::hoistInstructions() { // 1. 先收集所有可外提指令 std::unordered_set workSet(chars->invariantInsts.begin(), chars->invariantInsts.end()); + if (DEBUG) { + std::cout << "LICM: Found " << workSet.size() << " candidate invariant instructions to hoist:" << std::endl; + for (auto *inst : workSet) { + std::cout << " - " << inst->getName() << " (kind: " << static_cast(inst->getKind()) + << ", in BB: " << inst->getParent()->getName() << ")" << std::endl; + } + } + // 2. 计算每个指令被依赖的次数(入度) std::unordered_map indegree; + std::unordered_map> dependencies; // 记录依赖关系 + std::unordered_map> dependents; // 记录被依赖关系 + for (auto *inst : workSet) { indegree[inst] = 0; + dependencies[inst] = {}; + dependents[inst] = {}; } + + if (DEBUG) { + std::cout << "LICM: Analyzing dependencies between invariant instructions..." << std::endl; + } + for (auto *inst : workSet) { for (size_t i = 0; i < inst->getNumOperands(); ++i) { if (auto *dep = dynamic_cast(inst->getOperand(i))) { if (workSet.count(dep)) { indegree[inst]++; + dependencies[inst].push_back(dep); + dependents[dep].push_back(inst); + + if (DEBUG) { + std::cout << " Dependency: " << inst->getName() << " depends on " << dep->getName() << std::endl; + } } } } } + if (DEBUG) { + std::cout << "LICM: Initial indegree analysis:" << std::endl; + for (auto &[inst, deg] : indegree) { + std::cout << " " << inst->getName() << ": indegree=" << deg; + if (deg > 0) { + std::cout << ", depends on: "; + for (auto *dep : dependencies[inst]) { + std::cout << dep->getName() << " "; + } + } + std::cout << std::endl; + } + } + // 3. Kahn拓扑排序 std::vector sorted; std::queue q; - for (auto &[inst, deg] : indegree) { - if (deg == 0) - q.push(inst); + + if (DEBUG) { + std::cout << "LICM: Starting topological sort..." << std::endl; } + + for (auto &[inst, deg] : indegree) { + if (deg == 0) { + q.push(inst); + if (DEBUG) { + std::cout << " Initial zero-indegree instruction: " << inst->getName() << std::endl; + } + } + } + + int sortStep = 0; while (!q.empty()) { auto *inst = q.front(); q.pop(); sorted.push_back(inst); - for (size_t i = 0; i < inst->getNumOperands(); ++i) { - if (auto *dep = dynamic_cast(inst->getOperand(i))) { - if (workSet.count(dep)) { - indegree[dep]--; - if (indegree[dep] == 0) - q.push(dep); + + if (DEBUG) { + std::cout << " Step " << (++sortStep) << ": Processing " << inst->getName() << std::endl; + } + + if (DEBUG) { + std::cout << " Reducing indegree of dependents of " << inst->getName() << std::endl; + } + + // 正确的拓扑排序:当处理一个指令时,应该减少其所有使用者(dependents)的入度 + for (auto *dependent : dependents[inst]) { + indegree[dependent]--; + if (DEBUG) { + std::cout << " Reducing indegree of " << dependent->getName() << " to " << indegree[dependent] << std::endl; + } + if (indegree[dependent] == 0) { + q.push(dependent); + if (DEBUG) { + std::cout << " Adding " << dependent->getName() << " to queue (indegree=0)" << std::endl; } } } @@ -58,23 +120,112 @@ bool LICMContext::hoistInstructions() { // 检查是否全部排序,若未全部排序,打印错误信息 // 这可能是因为存在循环依赖或其他问题导致无法完成拓扑排序 if (sorted.size() != workSet.size()) { - if (DEBUG) - std::cout << "LICM: Topological sort failed, possible dependency cycle." << std::endl; + if (DEBUG) { + std::cout << "LICM: Topological sort failed! Sorted " << sorted.size() + << " instructions out of " << workSet.size() << " total." << std::endl; + + // 找出未被排序的指令(形成循环依赖的指令) + std::unordered_set remaining; + for (auto *inst : workSet) { + bool found = false; + for (auto *sortedInst : sorted) { + if (inst == sortedInst) { + found = true; + break; + } + } + if (!found) { + remaining.insert(inst); + } + } + + std::cout << "LICM: Instructions involved in dependency cycle:" << std::endl; + for (auto *inst : remaining) { + std::cout << " - " << inst->getName() << " (indegree=" << indegree[inst] << ")" << std::endl; + std::cout << " Dependencies within cycle: "; + for (auto *dep : dependencies[inst]) { + if (remaining.count(dep)) { + std::cout << dep->getName() << " "; + } + } + std::cout << std::endl; + std::cout << " Dependents within cycle: "; + for (auto *dependent : dependents[inst]) { + if (remaining.count(dependent)) { + std::cout << dependent->getName() << " "; + } + } + std::cout << std::endl; + } + + // 尝试找出一个具体的循环路径 + std::cout << "LICM: Attempting to trace a dependency cycle:" << std::endl; + if (!remaining.empty()) { + auto *start = *remaining.begin(); + std::unordered_set visited; + std::vector path; + + std::function findCycle = [&](Instruction *current) -> bool { + if (visited.count(current)) { + // 找到环 + auto it = std::find(path.begin(), path.end(), current); + if (it != path.end()) { + std::cout << " Cycle found: "; + for (auto cycleIt = it; cycleIt != path.end(); ++cycleIt) { + std::cout << (*cycleIt)->getName() << " -> "; + } + std::cout << current->getName() << std::endl; + return true; + } + return false; + } + + visited.insert(current); + path.push_back(current); + + for (auto *dep : dependencies[current]) { + if (remaining.count(dep)) { + if (findCycle(dep)) { + return true; + } + } + } + + path.pop_back(); + return false; + }; + + findCycle(start); + } + } return false; } // 4. 按拓扑序外提 + if (DEBUG) { + std::cout << "LICM: Successfully completed topological sort. Hoisting instructions in order:" << std::endl; + } + for (auto *inst : sorted) { if (!inst) continue; BasicBlock *parent = inst->getParent(); if (parent && loop->contains(parent)) { + if (DEBUG) { + std::cout << " Hoisting " << inst->getName() << " from " << parent->getName() + << " to preheader " << preheader->getName() << std::endl; + } auto sourcePos = parent->findInstIterator(inst); auto targetPos = preheader->terminator(); parent->moveInst(sourcePos, targetPos, preheader); changed = true; } } + + if (DEBUG && changed) { + std::cout << "LICM: Successfully hoisted " << sorted.size() << " invariant instructions" << std::endl; + } + return changed; } // ---- LICM Pass Implementation ---- From 969a78a08817501f72e95265c556b0590d520c49 Mon Sep 17 00:00:00 2001 From: rain2133 <1370973498@qq.com> Date: Sun, 17 Aug 2025 14:37:27 +0800 Subject: [PATCH 07/12] =?UTF-8?q?[midend-GVN]segmentation=20fault=E6=98=AF?= =?UTF-8?q?GVN=E5=BC=95=E5=85=A5=E7=9A=84=E5=B7=B2=E4=BF=AE=E5=A4=8D?= =?UTF-8?q?=EF=BC=8CLICM=E4=BB=8D=E7=84=B6=E6=9C=89=E9=94=99=E8=AF=AF?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/midend/Pass/Optimize/GVN.cpp | 253 ++++++++++++++++++++++++++++--- src/midend/Pass/Pass.cpp | 14 +- 2 files changed, 243 insertions(+), 24 deletions(-) diff --git a/src/midend/Pass/Optimize/GVN.cpp b/src/midend/Pass/Optimize/GVN.cpp index 9f28609..a2f1c57 100644 --- a/src/midend/Pass/Optimize/GVN.cpp +++ b/src/midend/Pass/Optimize/GVN.cpp @@ -176,18 +176,35 @@ void GVNContext::dfs(BasicBlock *bb) { } Value *GVNContext::checkHashtable(Value *value) { + // 避免无限递归:如果已经在哈希表中,直接返回映射的值 if (auto it = hashtable.find(value); it != hashtable.end()) { + if (DEBUG >= 2) { + std::cout << " Found " << value->getName() << " in hashtable, mapped to " + << it->second->getName() << std::endl; + } return it->second; } + // 如果是指令,尝试获取其值编号 if (auto inst = dynamic_cast(value)) { if (auto valueNumber = getValueNumber(inst)) { - hashtable[value] = valueNumber; - return valueNumber; + // 如果找到了等价的值,建立映射关系 + if (valueNumber != inst) { + hashtable[value] = valueNumber; + if (DEBUG >= 2) { + std::cout << " Mapping " << value->getName() << " to equivalent value " + << valueNumber->getName() << std::endl; + } + return valueNumber; + } } } + // 没有找到等价值,将自己映射到自己 hashtable[value] = value; + if (DEBUG >= 2) { + std::cout << " Mapping " << value->getName() << " to itself (unique)" << std::endl; + } return value; } @@ -227,11 +244,73 @@ Value *GVNContext::getValueNumber(BinaryInst *inst) { if (binary->getKind() == inst->getKind()) { // 检查操作数是否匹配 - if ((lhs == binLhs && rhs == binRhs) || (inst->isCommutative() && lhs == binRhs && rhs == binLhs)) { - if (DEBUG) { - std::cout << " Found equivalent binary instruction: " << binary->getName() << std::endl; + bool operandsMatch = false; + if (lhs == binLhs && rhs == binRhs) { + operandsMatch = true; + } else if (inst->isCommutative() && lhs == binRhs && rhs == binLhs) { + operandsMatch = true; + } + + if (operandsMatch) { + // 检查支配关系,确保替换是安全的 + if (canReplace(inst, binary)) { + // 对于涉及load指令的情况,需要特别检查 + bool hasLoadOperands = (dynamic_cast(lhs) != nullptr) || + (dynamic_cast(rhs) != nullptr); + + if (hasLoadOperands) { + // 检查是否有任何load操作数之间有intervening store + bool hasIntervening = false; + + auto loadLhs = dynamic_cast(lhs); + auto loadRhs = dynamic_cast(rhs); + auto binLoadLhs = dynamic_cast(binLhs); + auto binLoadRhs = dynamic_cast(binRhs); + + if (loadLhs && binLoadLhs) { + if (hasInterveningStore(binLoadLhs, loadLhs, checkHashtable(loadLhs->getPointer()))) { + hasIntervening = true; + } + } + + if (!hasIntervening && loadRhs && binLoadRhs) { + if (hasInterveningStore(binLoadRhs, loadRhs, checkHashtable(loadRhs->getPointer()))) { + hasIntervening = true; + } + } + + // 对于交换操作数的情况,也需要检查 + if (!hasIntervening && inst->isCommutative()) { + if (loadLhs && binLoadRhs) { + if (hasInterveningStore(binLoadRhs, loadLhs, checkHashtable(loadLhs->getPointer()))) { + hasIntervening = true; + } + } + + if (!hasIntervening && loadRhs && binLoadLhs) { + if (hasInterveningStore(binLoadLhs, loadRhs, checkHashtable(loadRhs->getPointer()))) { + hasIntervening = true; + } + } + } + + if (hasIntervening) { + if (DEBUG) { + std::cout << " Found equivalent binary but load operands have intervening store, skipping" << std::endl; + } + continue; + } + } + + if (DEBUG) { + std::cout << " Found equivalent binary instruction: " << binary->getName() << std::endl; + } + return value; + } else { + if (DEBUG) { + std::cout << " Found equivalent binary but dominance check failed: " << binary->getName() << std::endl; + } } - return value; } } } @@ -294,26 +373,47 @@ Value *GVNContext::getValueNumber(GetElementPtrInst *inst) { Value *GVNContext::getValueNumber(LoadInst *inst) { auto ptr = checkHashtable(inst->getPointer()); + if (DEBUG) { + std::cout << " Checking load instruction: " << inst->getName() + << " from address: " << ptr->getName() << std::endl; + } + for (auto [key, value] : hashtable) { if (auto load = dynamic_cast(key)) { auto loadPtr = checkHashtable(load->getPointer()); if (ptr == loadPtr && inst->getType() == load->getType()) { - // 检查两次load之间是否有store指令修改了内存 + if (DEBUG) { + std::cout << " Found potential equivalent load: " << load->getName() << std::endl; + } + + // 检查支配关系:load 必须支配 inst + if (!canReplace(inst, load)) { + if (DEBUG) { + std::cout << " Equivalent load does not dominate current load, skipping" << std::endl; + } + continue; + } + + // 检查是否有中间的store指令影响 if (hasInterveningStore(load, inst, ptr)) { if (DEBUG) { std::cout << " Found intervening store, cannot reuse load value" << std::endl; } continue; // 如果有store指令,不能复用之前的load } + if (DEBUG) { - std::cout << " No intervening store found, can reuse load value" << std::endl; + std::cout << " Can safely reuse load value from: " << load->getName() << std::endl; } return value; } } } + if (DEBUG) { + std::cout << " No equivalent load found" << std::endl; + } return inst; } @@ -427,8 +527,21 @@ bool GVNContext::canReplace(Instruction *original, Value *replacement) { auto replIt = std::find_if(insts.begin(), insts.end(), [replInst](const auto &ptr) { return ptr.get() == replInst; }); - // 替换指令必须在原指令之前 - return std::distance(insts.begin(), replIt) < std::distance(insts.begin(), origIt); + if (origIt == insts.end() || replIt == insts.end()) { + if (DEBUG) { + std::cout << " Cannot find instructions in basic block for dominance check" << std::endl; + } + return false; + } + + // 替换指令必须在原指令之前(支配原指令) + bool canRepl = std::distance(insts.begin(), replIt) < std::distance(insts.begin(), origIt); + if (DEBUG) { + std::cout << " Same block dominance check: " << (canRepl ? "PASS" : "FAIL") + << " (repl at " << std::distance(insts.begin(), replIt) + << ", orig at " << std::distance(insts.begin(), origIt) << ")" << std::endl; + } + return canRepl; } // 使用支配关系检查(如果支配树分析可用) @@ -450,6 +563,9 @@ bool GVNContext::hasInterveningStore(LoadInst* earlierLoad, LoadInst* laterLoad, if (earlierBB != laterBB) { // 跨基本块的情况:为了安全起见,暂时认为有intervening store // 这是保守的做法,可能会错过一些优化机会,但确保正确性 + if (DEBUG) { + std::cout << " Cross-block load optimization: conservatively assuming intervening store" << std::endl; + } return true; } @@ -463,11 +579,28 @@ bool GVNContext::hasInterveningStore(LoadInst* earlierLoad, LoadInst* laterLoad, [laterLoad](const auto &ptr) { return ptr.get() == laterLoad; }); if (earlierIt == insts.end() || laterIt == insts.end()) { + if (DEBUG) { + std::cout << " Could not find load instructions in basic block" << std::endl; + } return true; // 找不到指令,保守返回true } + // 确定实际的执行顺序(哪个load在前,哪个在后) + auto firstIt = earlierIt; + auto secondIt = laterIt; + + if (std::distance(insts.begin(), earlierIt) > std::distance(insts.begin(), laterIt)) { + // 如果"earlier"实际上在"later"之后,交换它们 + firstIt = laterIt; + secondIt = earlierIt; + if (DEBUG) { + std::cout << " Swapped load order: " << laterLoad->getName() + << " actually comes before " << earlierLoad->getName() << std::endl; + } + } + // 检查两个load之间的所有指令 - for (auto it = std::next(earlierIt); it != laterIt; ++it) { + for (auto it = std::next(firstIt); it != secondIt; ++it) { auto inst = it->get(); // 检查是否是store指令 @@ -477,27 +610,34 @@ bool GVNContext::hasInterveningStore(LoadInst* earlierLoad, LoadInst* laterLoad, // 如果store的目标地址与load的地址相同,说明内存被修改了 if (storePtr == ptr) { if (DEBUG) { - std::cout << " Found intervening store to same address, cannot optimize load" << std::endl; + std::cout << " Found intervening store to same address: " << storeInst->getName() << std::endl; } return true; } + + // TODO: 这里还应该检查别名分析,看store是否可能影响load的地址 + // 为了简化,现在只检查精确匹配 } - // TODO: 还需要检查函数调用是否可能修改内存 - // 对于全局变量,任何函数调用都可能修改它 + // 检查函数调用是否可能修改内存 if (auto callInst = dynamic_cast(inst)) { if (sideEffectAnalysis && !sideEffectAnalysis->isPureFunction(callInst->getCallee())) { // 如果是有副作用的函数调用,且load的是全局变量,则可能被修改 if (auto globalPtr = dynamic_cast(ptr)) { if (DEBUG) { - std::cout << " Found function call that may modify global variable, cannot optimize load" << std::endl; + std::cout << " Found function call that may modify global variable: " << callInst->getName() << std::endl; } return true; } + // TODO: 这里还应该检查函数是否可能修改通过指针参数传递的内存 } } } + if (DEBUG) { + std::cout << " No intervening store found between loads" << std::endl; + } + return false; // 没有找到会修改内存的指令 } @@ -508,9 +648,11 @@ void GVNContext::invalidateLoadsAffectedByStore(StoreInst* storeInst) { std::cout << " Invalidating loads affected by store to address" << std::endl; } - // 查找hashtable中所有可能被这个store影响的load指令 + // 查找hashtable中所有可能被这个store影响的指令 std::vector toRemove; + std::set invalidatedLoads; + // 第一步:找到所有被直接影响的load指令 for (auto& [key, value] : hashtable) { if (auto loadInst = dynamic_cast(key)) { auto loadPtr = checkHashtable(loadInst->getPointer()); @@ -518,6 +660,7 @@ void GVNContext::invalidateLoadsAffectedByStore(StoreInst* storeInst) { // 如果load的地址与store的地址相同,则需要从hashtable中移除 if (loadPtr == storePtr) { toRemove.push_back(key); + invalidatedLoads.insert(loadInst); if (DEBUG) { std::cout << " Invalidating load from same address: " << loadInst->getName() << std::endl; } @@ -525,10 +668,86 @@ void GVNContext::invalidateLoadsAffectedByStore(StoreInst* storeInst) { } } - // 从hashtable中移除被影响的load指令 + // 第二步:找到所有依赖被失效load的指令(如binary指令) + bool foundMore = true; + while (foundMore) { + foundMore = false; + std::vector additionalToRemove; + + for (auto& [key, value] : hashtable) { + // 跳过已经标记要删除的指令 + if (std::find(toRemove.begin(), toRemove.end(), key) != toRemove.end()) { + continue; + } + + bool shouldInvalidate = false; + + // 检查binary指令的操作数 + if (auto binaryInst = dynamic_cast(key)) { + auto lhs = checkHashtable(binaryInst->getLhs()); + auto rhs = checkHashtable(binaryInst->getRhs()); + + if (invalidatedLoads.count(lhs) || invalidatedLoads.count(rhs)) { + shouldInvalidate = true; + if (DEBUG) { + std::cout << " Invalidating binary instruction due to invalidated operand: " + << binaryInst->getName() << std::endl; + } + } + } + // 检查unary指令的操作数 + else if (auto unaryInst = dynamic_cast(key)) { + auto operand = checkHashtable(unaryInst->getOperand()); + if (invalidatedLoads.count(operand)) { + shouldInvalidate = true; + if (DEBUG) { + std::cout << " Invalidating unary instruction due to invalidated operand: " + << unaryInst->getName() << std::endl; + } + } + } + // 检查GEP指令的操作数 + else if (auto gepInst = dynamic_cast(key)) { + auto basePtr = checkHashtable(gepInst->getBasePointer()); + if (invalidatedLoads.count(basePtr)) { + shouldInvalidate = true; + } else { + // 检查索引操作数 + for (unsigned i = 0; i < gepInst->getNumIndices(); ++i) { + if (invalidatedLoads.count(checkHashtable(gepInst->getIndex(i)))) { + shouldInvalidate = true; + break; + } + } + } + if (shouldInvalidate && DEBUG) { + std::cout << " Invalidating GEP instruction due to invalidated operand: " + << gepInst->getName() << std::endl; + } + } + + if (shouldInvalidate) { + additionalToRemove.push_back(key); + if (auto inst = dynamic_cast(key)) { + invalidatedLoads.insert(inst); + } + foundMore = true; + } + } + + // 将新找到的失效指令加入移除列表 + toRemove.insert(toRemove.end(), additionalToRemove.begin(), additionalToRemove.end()); + } + + // 从hashtable中移除所有被影响的指令 for (auto key : toRemove) { hashtable.erase(key); } + + if (DEBUG && toRemove.size() > invalidatedLoads.size()) { + std::cout << " Total invalidated instructions: " << toRemove.size() + << " (including " << (toRemove.size() - invalidatedLoads.size()) << " dependent instructions)" << std::endl; + } } std::string GVNContext::getCanonicalExpression(Instruction *inst) { diff --git a/src/midend/Pass/Pass.cpp b/src/midend/Pass/Pass.cpp index 09de26e..69f9790 100644 --- a/src/midend/Pass/Pass.cpp +++ b/src/midend/Pass/Pass.cpp @@ -161,14 +161,14 @@ void PassManager::runOptimizationPipeline(Module* moduleIR, IRBuilder* builderIR } - this->clearPasses(); - this->addPass(&LICM::ID); - this->run(); + // this->clearPasses(); + // this->addPass(&LICM::ID); + // this->run(); - if(DEBUG) { - std::cout << "=== IR After LICM ===\n"; - printPasses(); - } + // if(DEBUG) { + // std::cout << "=== IR After LICM ===\n"; + // printPasses(); + // } this->clearPasses(); this->addPass(&LoopStrengthReduction::ID); From 8ca64610ebea02af370959aa58328daf8ba36a48 Mon Sep 17 00:00:00 2001 From: rain2133 <1370973498@qq.com> Date: Sun, 17 Aug 2025 16:33:15 +0800 Subject: [PATCH 08/12] =?UTF-8?q?[midend-GVN]=E9=87=8D=E6=9E=84GVN?= =?UTF-8?q?=E7=9A=84=E5=80=BC=E7=BC=96=E5=8F=B7=E7=B3=BB=E7=BB=9F?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/include/midend/Pass/Optimize/GVN.h | 43 +- src/midend/Pass/Optimize/GVN.cpp | 821 ++++++++----------------- 2 files changed, 291 insertions(+), 573 deletions(-) diff --git a/src/include/midend/Pass/Optimize/GVN.h b/src/include/midend/Pass/Optimize/GVN.h index 2aafd8d..552b82e 100644 --- a/src/include/midend/Pass/Optimize/GVN.h +++ b/src/include/midend/Pass/Optimize/GVN.h @@ -19,8 +19,11 @@ public: void run(Function* func, AnalysisManager* AM, bool& changed); private: - // 值编号的哈希表:Value -> 代表值 - std::unordered_map hashtable; + // 新的值编号系统 + std::unordered_map valueToNumber; // Value -> 值编号 + std::unordered_map numberToValue; // 值编号 -> 代表值 + std::unordered_map expressionToNumber; // 表达式 -> 值编号 + unsigned nextValueNumber = 1; // 已访问的基本块集合 std::unordered_set visited; @@ -39,31 +42,27 @@ private: void computeRPO(Function* func); void dfs(BasicBlock* bb); - // 检查哈希表并获取值编号 - Value* checkHashtable(Value* value); + // 新的值编号方法 + unsigned getValueNumber(Value* value); + unsigned assignValueNumber(Value* value); - // 为不同类型的指令获取值编号 - Value* getValueNumber(Instruction* inst); - Value* getValueNumber(BinaryInst* inst); - Value* getValueNumber(UnaryInst* inst); - Value* getValueNumber(GetElementPtrInst* inst); - Value* getValueNumber(LoadInst* inst); - Value* getValueNumber(CallInst* inst); + // 基本块处理 + void processBasicBlock(BasicBlock* bb, bool& changed); - // 访问指令并进行GVN优化 - void visitInstruction(Instruction* inst); + // 指令处理 + bool processInstruction(Instruction* inst); - // 检查是否可以安全地用一个值替换另一个值 - bool canReplace(Instruction* original, Value* replacement); + // 表达式构建和查找 + std::string buildExpressionKey(Instruction* inst); + Value* findExistingValue(const std::string& exprKey, Instruction* inst); - // 检查两个load指令之间是否有store指令修改了相同的内存位置 - bool hasInterveningStore(LoadInst* earlierLoad, LoadInst* laterLoad, Value* ptr); + // 支配关系和安全性检查 + bool dominates(Instruction* a, Instruction* b); + bool isMemorySafe(LoadInst* earlierLoad, LoadInst* laterLoad); - // 使受store指令影响的load指令失效 - void invalidateLoadsAffectedByStore(StoreInst* storeInst); - - // 生成表达式的标准化字符串 - std::string getCanonicalExpression(Instruction* inst); + // 清理方法 + void eliminateRedundantInstructions(bool& changed); + void invalidateMemoryValues(StoreInst* store); }; // GVN优化遍类 diff --git a/src/midend/Pass/Optimize/GVN.cpp b/src/midend/Pass/Optimize/GVN.cpp index a2f1c57..09b67a1 100644 --- a/src/midend/Pass/Optimize/GVN.cpp +++ b/src/midend/Pass/Optimize/GVN.cpp @@ -4,6 +4,8 @@ #include #include #include +#include +#include extern int DEBUG; @@ -62,9 +64,33 @@ void GVN::getAnalysisUsage(std::set &analysisDependencies, std::set operands; + Type* resultType; + + bool operator==(const ExpressionKey& other) const { + return type == other.type && opcode == other.opcode && + operands == other.operands && resultType == other.resultType; + } +}; + +struct ExpressionKeyHash { + size_t operator()(const ExpressionKey& key) const { + size_t hash = std::hash()(static_cast(key.type)) ^ + std::hash()(key.opcode); + for (auto op : key.operands) { + hash ^= std::hash()(op) + 0x9e3779b9 + (hash << 6) + (hash >> 2); + } + return hash; + } +}; + void GVNContext::run(Function *func, AnalysisManager *AM, bool &changed) { if (DEBUG) { std::cout << " Starting GVN analysis for function: " << func->getName() << std::endl; @@ -90,7 +116,10 @@ void GVNContext::run(Function *func, AnalysisManager *AM, bool &changed) { } // 清空状态 - hashtable.clear(); + valueToNumber.clear(); + numberToValue.clear(); + expressionToNumber.clear(); + nextValueNumber = 1; visited.clear(); rpoBlocks.clear(); needRemove.clear(); @@ -110,14 +139,7 @@ void GVNContext::run(Function *func, AnalysisManager *AM, bool &changed) { << ": " << bb->getName() << std::endl; } - int instCount = 0; - for (auto &instPtr : bb->getInstructions()) { - if (DEBUG) { - std::cout << " Processing instruction " << ++instCount - << ": " << instPtr->getName() << std::endl; - } - visitInstruction(instPtr.get()); - } + processBasicBlock(bb, changed); } if (DEBUG) { @@ -125,24 +147,11 @@ void GVNContext::run(Function *func, AnalysisManager *AM, bool &changed) { } // 删除冗余指令 - int removeCount = 0; - for (auto inst : needRemove) { - auto bb = inst->getParent(); - if (DEBUG) { - std::cout << " Removing redundant instruction " << ++removeCount - << "/" << needRemove.size() << ": " << inst->getName() << std::endl; - } - // 删除指令前先断开所有使用关系 - inst->replaceAllUsesWith(nullptr); - // 使用基本块的删除方法 - // bb->removeInst(inst); - SysYIROptUtils::usedelete(inst); - changed = true; - } + eliminateRedundantInstructions(changed); if (DEBUG) { std::cout << " GVN analysis completed for function: " << func->getName() << std::endl; - std::cout << " Total instructions analyzed: " << hashtable.size() << std::endl; + std::cout << " Total values numbered: " << valueToNumber.size() << std::endl; std::cout << " Instructions eliminated: " << needRemove.size() << std::endl; } } @@ -175,599 +184,309 @@ void GVNContext::dfs(BasicBlock *bb) { rpoBlocks.push_back(bb); } -Value *GVNContext::checkHashtable(Value *value) { - // 避免无限递归:如果已经在哈希表中,直接返回映射的值 - if (auto it = hashtable.find(value); it != hashtable.end()) { - if (DEBUG >= 2) { - std::cout << " Found " << value->getName() << " in hashtable, mapped to " - << it->second->getName() << std::endl; - } +unsigned GVNContext::getValueNumber(Value* value) { + // 如果已经有值编号,直接返回 + auto it = valueToNumber.find(value); + if (it != valueToNumber.end()) { return it->second; } + + // 为新值分配编号 + return assignValueNumber(value); +} - // 如果是指令,尝试获取其值编号 - if (auto inst = dynamic_cast(value)) { - if (auto valueNumber = getValueNumber(inst)) { - // 如果找到了等价的值,建立映射关系 - if (valueNumber != inst) { - hashtable[value] = valueNumber; - if (DEBUG >= 2) { - std::cout << " Mapping " << value->getName() << " to equivalent value " - << valueNumber->getName() << std::endl; - } - return valueNumber; - } - } - } - - // 没有找到等价值,将自己映射到自己 - hashtable[value] = value; +unsigned GVNContext::assignValueNumber(Value* value) { + unsigned number = nextValueNumber++; + valueToNumber[value] = number; + numberToValue[number] = value; + if (DEBUG >= 2) { - std::cout << " Mapping " << value->getName() << " to itself (unique)" << std::endl; + std::cout << " Assigned value number " << number + << " to " << value->getName() << std::endl; } - return value; + + return number; } -Value *GVNContext::getValueNumber(Instruction *inst) { - if (auto binary = dynamic_cast(inst)) { - return getValueNumber(binary); - } else if (auto unary = dynamic_cast(inst)) { - return getValueNumber(unary); - } else if (auto gep = dynamic_cast(inst)) { - return getValueNumber(gep); - } else if (auto load = dynamic_cast(inst)) { - return getValueNumber(load); - } else if (auto call = dynamic_cast(inst)) { - // 只为无副作用的函数调用进行GVN - if (sideEffectAnalysis && sideEffectAnalysis->isPureFunction(call->getCallee())) { - return getValueNumber(call); - } - return nullptr; - } - - return nullptr; -} - -Value *GVNContext::getValueNumber(BinaryInst *inst) { - auto lhs = checkHashtable(inst->getLhs()); - auto rhs = checkHashtable(inst->getRhs()); - - if (DEBUG) { - std::cout << " Checking binary instruction: " << inst->getName() - << " (kind: " << static_cast(inst->getKind()) << ")" << std::endl; - } - - for (auto [key, value] : hashtable) { - if (auto binary = dynamic_cast(key)) { - auto binLhs = checkHashtable(binary->getLhs()); - auto binRhs = checkHashtable(binary->getRhs()); - - if (binary->getKind() == inst->getKind()) { - // 检查操作数是否匹配 - bool operandsMatch = false; - if (lhs == binLhs && rhs == binRhs) { - operandsMatch = true; - } else if (inst->isCommutative() && lhs == binRhs && rhs == binLhs) { - operandsMatch = true; - } - - if (operandsMatch) { - // 检查支配关系,确保替换是安全的 - if (canReplace(inst, binary)) { - // 对于涉及load指令的情况,需要特别检查 - bool hasLoadOperands = (dynamic_cast(lhs) != nullptr) || - (dynamic_cast(rhs) != nullptr); - - if (hasLoadOperands) { - // 检查是否有任何load操作数之间有intervening store - bool hasIntervening = false; - - auto loadLhs = dynamic_cast(lhs); - auto loadRhs = dynamic_cast(rhs); - auto binLoadLhs = dynamic_cast(binLhs); - auto binLoadRhs = dynamic_cast(binRhs); - - if (loadLhs && binLoadLhs) { - if (hasInterveningStore(binLoadLhs, loadLhs, checkHashtable(loadLhs->getPointer()))) { - hasIntervening = true; - } - } - - if (!hasIntervening && loadRhs && binLoadRhs) { - if (hasInterveningStore(binLoadRhs, loadRhs, checkHashtable(loadRhs->getPointer()))) { - hasIntervening = true; - } - } - - // 对于交换操作数的情况,也需要检查 - if (!hasIntervening && inst->isCommutative()) { - if (loadLhs && binLoadRhs) { - if (hasInterveningStore(binLoadRhs, loadLhs, checkHashtable(loadLhs->getPointer()))) { - hasIntervening = true; - } - } - - if (!hasIntervening && loadRhs && binLoadLhs) { - if (hasInterveningStore(binLoadLhs, loadRhs, checkHashtable(loadRhs->getPointer()))) { - hasIntervening = true; - } - } - } - - if (hasIntervening) { - if (DEBUG) { - std::cout << " Found equivalent binary but load operands have intervening store, skipping" << std::endl; - } - continue; - } - } - - if (DEBUG) { - std::cout << " Found equivalent binary instruction: " << binary->getName() << std::endl; - } - return value; - } else { - if (DEBUG) { - std::cout << " Found equivalent binary but dominance check failed: " << binary->getName() << std::endl; - } - } - } - } - } - } - - if (DEBUG) { - std::cout << " No equivalent binary instruction found" << std::endl; - } - return inst; -} - -Value *GVNContext::getValueNumber(UnaryInst *inst) { - auto operand = checkHashtable(inst->getOperand()); - - for (auto [key, value] : hashtable) { - if (auto unary = dynamic_cast(key)) { - auto unOperand = checkHashtable(unary->getOperand()); - - if (unary->getKind() == inst->getKind() && operand == unOperand) { - return value; - } - } - } - - return inst; -} - -Value *GVNContext::getValueNumber(GetElementPtrInst *inst) { - auto ptr = checkHashtable(inst->getBasePointer()); - std::vector indices; - - // 使用正确的索引访问方法 - for (unsigned i = 0; i < inst->getNumIndices(); ++i) { - indices.push_back(checkHashtable(inst->getIndex(i))); - } - - for (auto [key, value] : hashtable) { - if (auto gep = dynamic_cast(key)) { - auto gepPtr = checkHashtable(gep->getBasePointer()); - - if (ptr == gepPtr && gep->getNumIndices() == inst->getNumIndices()) { - bool indicesMatch = true; - for (unsigned i = 0; i < inst->getNumIndices(); ++i) { - if (checkHashtable(gep->getIndex(i)) != indices[i]) { - indicesMatch = false; - break; - } - } - - if (indicesMatch && inst->getType() == gep->getType()) { - return value; - } - } - } - } - - return inst; -} - -Value *GVNContext::getValueNumber(LoadInst *inst) { - auto ptr = checkHashtable(inst->getPointer()); - - if (DEBUG) { - std::cout << " Checking load instruction: " << inst->getName() - << " from address: " << ptr->getName() << std::endl; - } - - for (auto [key, value] : hashtable) { - if (auto load = dynamic_cast(key)) { - auto loadPtr = checkHashtable(load->getPointer()); - - if (ptr == loadPtr && inst->getType() == load->getType()) { - if (DEBUG) { - std::cout << " Found potential equivalent load: " << load->getName() << std::endl; - } - - // 检查支配关系:load 必须支配 inst - if (!canReplace(inst, load)) { - if (DEBUG) { - std::cout << " Equivalent load does not dominate current load, skipping" << std::endl; - } - continue; - } - - // 检查是否有中间的store指令影响 - if (hasInterveningStore(load, inst, ptr)) { - if (DEBUG) { - std::cout << " Found intervening store, cannot reuse load value" << std::endl; - } - continue; // 如果有store指令,不能复用之前的load - } - - if (DEBUG) { - std::cout << " Can safely reuse load value from: " << load->getName() << std::endl; - } - return value; - } - } - } - - if (DEBUG) { - std::cout << " No equivalent load found" << std::endl; - } - return inst; -} - -Value *GVNContext::getValueNumber(CallInst *inst) { - // 此时已经确认是无副作用的函数调用,可以安全进行GVN - for (auto [key, value] : hashtable) { - if (auto call = dynamic_cast(key)) { - if (call->getCallee() == inst->getCallee() && call->getNumOperands() == inst->getNumOperands()) { - - bool argsMatch = true; - // 跳过第一个操作数(函数指针),从参数开始比较 - for (size_t i = 1; i < inst->getNumOperands(); ++i) { - if (checkHashtable(inst->getOperand(i)) != checkHashtable(call->getOperand(i))) { - argsMatch = false; - break; - } - } - - if (argsMatch) { - return value; - } - } - } - } - - return inst; -} - -void GVNContext::visitInstruction(Instruction *inst) { - // 跳过分支指令 - if (inst->isBranch()) { +void GVNContext::processBasicBlock(BasicBlock* bb, bool& changed) { + int instCount = 0; + for (auto &instPtr : bb->getInstructions()) { if (DEBUG) { - std::cout << " Skipping branch instruction: " << inst->getName() << std::endl; + std::cout << " Processing instruction " << ++instCount + << ": " << instPtr->getName() << std::endl; + } + + if (processInstruction(instPtr.get())) { + changed = true; } - return; } +} - // 如果是store指令,需要清理hashtable中可能被影响的load指令 - if (auto storeInst = dynamic_cast(inst)) { - invalidateLoadsAffectedByStore(storeInst); +bool GVNContext::processInstruction(Instruction* inst) { + // 跳过分支指令和其他不可优化的指令 + if (inst->isBranch() || dynamic_cast(inst) || + dynamic_cast(inst) || dynamic_cast(inst)) { + + // 如果是store指令,需要使相关的内存值失效 + if (auto store = dynamic_cast(inst)) { + invalidateMemoryValues(store); + } + + // 为这些指令分配值编号但不尝试优化 + getValueNumber(inst); + return false; } - + if (DEBUG) { - std::cout << " Visiting instruction: " << inst->getName() + std::cout << " Processing optimizable instruction: " << inst->getName() << " (kind: " << static_cast(inst->getKind()) << ")" << std::endl; } - - auto value = checkHashtable(inst); - - if (inst != value) { - if (auto instValue = dynamic_cast(value)) { - if (canReplace(inst, instValue)) { - inst->replaceAllUsesWith(instValue); - needRemove.insert(inst); - + + // 构建表达式键 + std::string exprKey = buildExpressionKey(inst); + if (exprKey.empty()) { + // 不可优化的指令,只分配值编号 + getValueNumber(inst); + return false; + } + + if (DEBUG >= 2) { + std::cout << " Expression key: " << exprKey << std::endl; + } + + // 查找已存在的等价值 + Value* existing = findExistingValue(exprKey, inst); + if (existing && existing != inst) { + // 检查支配关系 + if (auto existingInst = dynamic_cast(existing)) { + if (dominates(existingInst, inst)) { if (DEBUG) { - std::cout << " GVN: Replacing redundant instruction " << inst->getName() - << " with existing instruction " << instValue->getName() << std::endl; + std::cout << " GVN: Replacing " << inst->getName() + << " with existing " << existing->getName() << std::endl; } + + // 用已存在的值替换当前指令 + inst->replaceAllUsesWith(existing); + needRemove.insert(inst); + + // 将当前指令的值编号指向已存在的值 + unsigned existingNumber = getValueNumber(existing); + valueToNumber[inst] = existingNumber; + + return true; } else { if (DEBUG) { - std::cout << " Cannot replace instruction " << inst->getName() - << " with " << instValue->getName() << " (dominance check failed)" << std::endl; + std::cout << " Found equivalent but dominance check failed" << std::endl; } } } - } else { - if (DEBUG) { - std::cout << " Instruction " << inst->getName() << " is unique" << std::endl; - } } -} - -bool GVNContext::canReplace(Instruction *original, Value *replacement) { - auto replInst = dynamic_cast(replacement); - if (!replInst) { - return true; // 替换为常量总是安全的 + + // 没有找到等价值,为这个表达式分配新的值编号 + unsigned number = assignValueNumber(inst); + expressionToNumber[exprKey] = number; + + if (DEBUG) { + std::cout << " Instruction " << inst->getName() << " is unique" << std::endl; } - - auto originalBB = original->getParent(); - auto replBB = replInst->getParent(); - - // 如果replacement是Call指令,需要特殊处理 - if (auto callInst = dynamic_cast(replInst)) { - if (sideEffectAnalysis && !sideEffectAnalysis->isPureFunction(callInst->getCallee())) { - // 对于有副作用的函数,只有在同一个基本块且相邻时才能替换 - if (originalBB != replBB) { - return false; - } - - // 检查指令顺序 - auto &insts = originalBB->getInstructions(); - auto origIt = - std::find_if(insts.begin(), insts.end(), [original](const auto &ptr) { return ptr.get() == original; }); - auto replIt = - std::find_if(insts.begin(), insts.end(), [replInst](const auto &ptr) { return ptr.get() == replInst; }); - - if (origIt == insts.end() || replIt == insts.end()) { - return false; - } - - return std::abs(std::distance(origIt, replIt)) == 1; - } - } - - // 简单的支配关系检查:如果在同一个基本块,检查指令顺序 - if (originalBB == replBB) { - auto &insts = originalBB->getInstructions(); - auto origIt = - std::find_if(insts.begin(), insts.end(), [original](const auto &ptr) { return ptr.get() == original; }); - auto replIt = - std::find_if(insts.begin(), insts.end(), [replInst](const auto &ptr) { return ptr.get() == replInst; }); - - if (origIt == insts.end() || replIt == insts.end()) { - if (DEBUG) { - std::cout << " Cannot find instructions in basic block for dominance check" << std::endl; - } - return false; - } - - // 替换指令必须在原指令之前(支配原指令) - bool canRepl = std::distance(insts.begin(), replIt) < std::distance(insts.begin(), origIt); - if (DEBUG) { - std::cout << " Same block dominance check: " << (canRepl ? "PASS" : "FAIL") - << " (repl at " << std::distance(insts.begin(), replIt) - << ", orig at " << std::distance(insts.begin(), origIt) << ")" << std::endl; - } - return canRepl; - } - - // 使用支配关系检查(如果支配树分析可用) - if (domTree) { - auto dominators = domTree->getDominators(originalBB); - if (dominators && dominators->count(replBB)) { - return true; - } - } - + return false; } -bool GVNContext::hasInterveningStore(LoadInst* earlierLoad, LoadInst* laterLoad, Value* ptr) { - // 如果两个load在不同的基本块,需要更复杂的分析 +std::string GVNContext::buildExpressionKey(Instruction* inst) { + std::ostringstream oss; + + if (auto binary = dynamic_cast(inst)) { + oss << "binary_" << static_cast(binary->getKind()) << "_"; + oss << getValueNumber(binary->getLhs()) << "_" << getValueNumber(binary->getRhs()); + + // 对于可交换操作,确保操作数顺序一致 + if (binary->isCommutative()) { + unsigned lhsNum = getValueNumber(binary->getLhs()); + unsigned rhsNum = getValueNumber(binary->getRhs()); + if (lhsNum > rhsNum) { + oss.str(""); + oss << "binary_" << static_cast(binary->getKind()) << "_"; + oss << rhsNum << "_" << lhsNum; + } + } + } else if (auto unary = dynamic_cast(inst)) { + oss << "unary_" << static_cast(unary->getKind()) << "_"; + oss << getValueNumber(unary->getOperand()); + } else if (auto gep = dynamic_cast(inst)) { + oss << "gep_" << getValueNumber(gep->getBasePointer()); + for (unsigned i = 0; i < gep->getNumIndices(); ++i) { + oss << "_" << getValueNumber(gep->getIndex(i)); + } + } else if (auto load = dynamic_cast(inst)) { + oss << "load_" << getValueNumber(load->getPointer()); + oss << "_" << reinterpret_cast(load->getType()); // 类型区分 + } else if (auto call = dynamic_cast(inst)) { + // 只为无副作用的函数调用建立表达式 + if (sideEffectAnalysis && sideEffectAnalysis->isPureFunction(call->getCallee())) { + oss << "call_" << call->getCallee()->getName(); + for (size_t i = 1; i < call->getNumOperands(); ++i) { // 跳过函数指针 + oss << "_" << getValueNumber(call->getOperand(i)); + } + } else { + return ""; // 有副作用的函数调用不可优化 + } + } else { + return ""; // 不支持的指令类型 + } + + return oss.str(); +} + +Value* GVNContext::findExistingValue(const std::string& exprKey, Instruction* inst) { + auto it = expressionToNumber.find(exprKey); + if (it != expressionToNumber.end()) { + unsigned number = it->second; + auto valueIt = numberToValue.find(number); + if (valueIt != numberToValue.end()) { + Value* existing = valueIt->second; + + // 对于load指令,需要额外检查内存安全性 + if (auto loadInst = dynamic_cast(inst)) { + if (auto existingLoad = dynamic_cast(existing)) { + if (!isMemorySafe(existingLoad, loadInst)) { + return nullptr; + } + } + } + + return existing; + } + } + return nullptr; +} + +bool GVNContext::dominates(Instruction* a, Instruction* b) { + auto aBB = a->getParent(); + auto bBB = b->getParent(); + + // 同一基本块内的情况 + if (aBB == bBB) { + auto &insts = aBB->getInstructions(); + auto aIt = std::find_if(insts.begin(), insts.end(), + [a](const auto &ptr) { return ptr.get() == a; }); + auto bIt = std::find_if(insts.begin(), insts.end(), + [b](const auto &ptr) { return ptr.get() == b; }); + + if (aIt == insts.end() || bIt == insts.end()) { + return false; + } + + return std::distance(insts.begin(), aIt) < std::distance(insts.begin(), bIt); + } + + // 不同基本块的情况,使用支配树 + if (domTree) { + auto dominators = domTree->getDominators(bBB); + return dominators && dominators->count(aBB); + } + + return false; // 保守做法 +} + +bool GVNContext::isMemorySafe(LoadInst* earlierLoad, LoadInst* laterLoad) { + // 检查两个load是否访问相同的内存位置 + unsigned earlierPtr = getValueNumber(earlierLoad->getPointer()); + unsigned laterPtr = getValueNumber(laterLoad->getPointer()); + + if (earlierPtr != laterPtr) { + return false; // 不同的内存位置 + } + + // 检查类型是否匹配 + if (earlierLoad->getType() != laterLoad->getType()) { + return false; + } + + // 简单情况:如果在同一个基本块且没有中间的store,则安全 auto earlierBB = earlierLoad->getParent(); auto laterBB = laterLoad->getParent(); if (earlierBB != laterBB) { - // 跨基本块的情况:为了安全起见,暂时认为有intervening store - // 这是保守的做法,可能会错过一些优化机会,但确保正确性 - if (DEBUG) { - std::cout << " Cross-block load optimization: conservatively assuming intervening store" << std::endl; - } - return true; + // 跨基本块的情况需要更复杂的分析,暂时保守处理 + return false; } - // 同一基本块内的情况:检查指令序列 + // 同一基本块内检查是否有中间的store auto &insts = earlierBB->getInstructions(); - - // 找到两个load指令的位置 - auto earlierIt = std::find_if(insts.begin(), insts.end(), + auto earlierIt = std::find_if(insts.begin(), insts.end(), [earlierLoad](const auto &ptr) { return ptr.get() == earlierLoad; }); auto laterIt = std::find_if(insts.begin(), insts.end(), [laterLoad](const auto &ptr) { return ptr.get() == laterLoad; }); if (earlierIt == insts.end() || laterIt == insts.end()) { - if (DEBUG) { - std::cout << " Could not find load instructions in basic block" << std::endl; - } - return true; // 找不到指令,保守返回true + return false; } - // 确定实际的执行顺序(哪个load在前,哪个在后) - auto firstIt = earlierIt; - auto secondIt = laterIt; - - if (std::distance(insts.begin(), earlierIt) > std::distance(insts.begin(), laterIt)) { - // 如果"earlier"实际上在"later"之后,交换它们 - firstIt = laterIt; - secondIt = earlierIt; - if (DEBUG) { - std::cout << " Swapped load order: " << laterLoad->getName() - << " actually comes before " << earlierLoad->getName() << std::endl; - } + // 确保earlierLoad真的在laterLoad之前 + if (std::distance(insts.begin(), earlierIt) >= std::distance(insts.begin(), laterIt)) { + return false; } - // 检查两个load之间的所有指令 - for (auto it = std::next(firstIt); it != secondIt; ++it) { - auto inst = it->get(); - - // 检查是否是store指令 - if (auto storeInst = dynamic_cast(inst)) { - auto storePtr = checkHashtable(storeInst->getPointer()); - - // 如果store的目标地址与load的地址相同,说明内存被修改了 - if (storePtr == ptr) { - if (DEBUG) { - std::cout << " Found intervening store to same address: " << storeInst->getName() << std::endl; - } - return true; + // 检查中间是否有store指令修改了相同的内存位置 + for (auto it = std::next(earlierIt); it != laterIt; ++it) { + if (auto store = dynamic_cast(it->get())) { + unsigned storePtr = getValueNumber(store->getPointer()); + if (storePtr == earlierPtr) { + return false; // 找到中间的store } - - // TODO: 这里还应该检查别名分析,看store是否可能影响load的地址 - // 为了简化,现在只检查精确匹配 } // 检查函数调用是否可能修改内存 - if (auto callInst = dynamic_cast(inst)) { - if (sideEffectAnalysis && !sideEffectAnalysis->isPureFunction(callInst->getCallee())) { - // 如果是有副作用的函数调用,且load的是全局变量,则可能被修改 - if (auto globalPtr = dynamic_cast(ptr)) { - if (DEBUG) { - std::cout << " Found function call that may modify global variable: " << callInst->getName() << std::endl; - } - return true; - } - // TODO: 这里还应该检查函数是否可能修改通过指针参数传递的内存 + if (auto call = dynamic_cast(it->get())) { + if (sideEffectAnalysis && !sideEffectAnalysis->isPureFunction(call->getCallee())) { + // 保守处理:有副作用的函数可能修改内存 + return false; } } } - if (DEBUG) { - std::cout << " No intervening store found between loads" << std::endl; - } - - return false; // 没有找到会修改内存的指令 + return true; // 安全 } -void GVNContext::invalidateLoadsAffectedByStore(StoreInst* storeInst) { - auto storePtr = checkHashtable(storeInst->getPointer()); +void GVNContext::invalidateMemoryValues(StoreInst* store) { + unsigned storePtr = getValueNumber(store->getPointer()); if (DEBUG) { - std::cout << " Invalidating loads affected by store to address" << std::endl; + std::cout << " Invalidating memory values affected by store" << std::endl; } - // 查找hashtable中所有可能被这个store影响的指令 - std::vector toRemove; - std::set invalidatedLoads; + // 找到所有可能被这个store影响的load表达式 + std::vector toRemove; - // 第一步:找到所有被直接影响的load指令 - for (auto& [key, value] : hashtable) { - if (auto loadInst = dynamic_cast(key)) { - auto loadPtr = checkHashtable(loadInst->getPointer()); - - // 如果load的地址与store的地址相同,则需要从hashtable中移除 - if (loadPtr == storePtr) { - toRemove.push_back(key); - invalidatedLoads.insert(loadInst); - if (DEBUG) { - std::cout << " Invalidating load from same address: " << loadInst->getName() << std::endl; - } + for (auto& [exprKey, number] : expressionToNumber) { + if (exprKey.find("load_" + std::to_string(storePtr)) == 0) { + toRemove.push_back(exprKey); + if (DEBUG) { + std::cout << " Invalidating expression: " << exprKey << std::endl; } } } - // 第二步:找到所有依赖被失效load的指令(如binary指令) - bool foundMore = true; - while (foundMore) { - foundMore = false; - std::vector additionalToRemove; - - for (auto& [key, value] : hashtable) { - // 跳过已经标记要删除的指令 - if (std::find(toRemove.begin(), toRemove.end(), key) != toRemove.end()) { - continue; - } - - bool shouldInvalidate = false; - - // 检查binary指令的操作数 - if (auto binaryInst = dynamic_cast(key)) { - auto lhs = checkHashtable(binaryInst->getLhs()); - auto rhs = checkHashtable(binaryInst->getRhs()); - - if (invalidatedLoads.count(lhs) || invalidatedLoads.count(rhs)) { - shouldInvalidate = true; - if (DEBUG) { - std::cout << " Invalidating binary instruction due to invalidated operand: " - << binaryInst->getName() << std::endl; - } - } - } - // 检查unary指令的操作数 - else if (auto unaryInst = dynamic_cast(key)) { - auto operand = checkHashtable(unaryInst->getOperand()); - if (invalidatedLoads.count(operand)) { - shouldInvalidate = true; - if (DEBUG) { - std::cout << " Invalidating unary instruction due to invalidated operand: " - << unaryInst->getName() << std::endl; - } - } - } - // 检查GEP指令的操作数 - else if (auto gepInst = dynamic_cast(key)) { - auto basePtr = checkHashtable(gepInst->getBasePointer()); - if (invalidatedLoads.count(basePtr)) { - shouldInvalidate = true; - } else { - // 检查索引操作数 - for (unsigned i = 0; i < gepInst->getNumIndices(); ++i) { - if (invalidatedLoads.count(checkHashtable(gepInst->getIndex(i)))) { - shouldInvalidate = true; - break; - } - } - } - if (shouldInvalidate && DEBUG) { - std::cout << " Invalidating GEP instruction due to invalidated operand: " - << gepInst->getName() << std::endl; - } - } - - if (shouldInvalidate) { - additionalToRemove.push_back(key); - if (auto inst = dynamic_cast(key)) { - invalidatedLoads.insert(inst); - } - foundMore = true; - } - } - - // 将新找到的失效指令加入移除列表 - toRemove.insert(toRemove.end(), additionalToRemove.begin(), additionalToRemove.end()); - } - - // 从hashtable中移除所有被影响的指令 - for (auto key : toRemove) { - hashtable.erase(key); - } - - if (DEBUG && toRemove.size() > invalidatedLoads.size()) { - std::cout << " Total invalidated instructions: " << toRemove.size() - << " (including " << (toRemove.size() - invalidatedLoads.size()) << " dependent instructions)" << std::endl; + // 移除失效的表达式 + for (const auto& key : toRemove) { + expressionToNumber.erase(key); } } -std::string GVNContext::getCanonicalExpression(Instruction *inst) { - std::ostringstream oss; - - if (auto binary = dynamic_cast(inst)) { - oss << "binary_" << static_cast(binary->getKind()) << "_"; - oss << checkHashtable(binary->getLhs()) << "_"; - oss << checkHashtable(binary->getRhs()); - } else if (auto unary = dynamic_cast(inst)) { - oss << "unary_" << static_cast(unary->getKind()) << "_"; - oss << checkHashtable(unary->getOperand()); - } else if (auto gep = dynamic_cast(inst)) { - oss << "gep_" << checkHashtable(gep->getBasePointer()); - for (unsigned i = 0; i < gep->getNumIndices(); ++i) { - oss << "_" << checkHashtable(gep->getIndex(i)); +void GVNContext::eliminateRedundantInstructions(bool& changed) { + int removeCount = 0; + for (auto inst : needRemove) { + if (DEBUG) { + std::cout << " Removing redundant instruction " << ++removeCount + << "/" << needRemove.size() << ": " << inst->getName() << std::endl; } + + // 删除指令前先断开所有使用关系 + // inst->replaceAllUsesWith 已在 processInstruction 中调用 + SysYIROptUtils::usedelete(inst); + changed = true; } - - return oss.str(); } } // namespace sysy From f317010d76894b95ebd93035742f6d8d80a2fec9 Mon Sep 17 00:00:00 2001 From: rain2133 <1370973498@qq.com> Date: Sun, 17 Aug 2025 17:42:19 +0800 Subject: [PATCH 09/12] =?UTF-8?q?[midend-Loop-LICM][fix]=E6=A3=80=E6=9F=A5?= =?UTF-8?q?load=E8=83=BD=E5=90=A6=E5=A4=96=E6=8F=90=E6=97=B6=E5=85=B6?= =?UTF-8?q?=E5=86=85=E5=AD=98=E5=9C=B0=E5=9D=80=E5=9C=A8=E5=BE=AA=E7=8E=AF?= =?UTF-8?q?=E4=B8=AD=E6=98=AF=E5=90=A6=E4=BC=9A=E8=A2=AB=E4=BF=AE=E6=94=B9?= =?UTF-8?q?=EF=BC=8C=E9=9C=80=E8=A6=81=E5=88=A4=E6=96=AD=E5=87=BD=E6=95=B0?= =?UTF-8?q?=E8=B0=83=E7=94=A8=E5=AF=B9load=E5=86=85=E5=AD=98=E5=9C=B0?= =?UTF-8?q?=E5=9D=80=E7=9A=84=E5=BD=B1=E5=93=8D=E3=80=82?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../Pass/Analysis/LoopCharacteristics.cpp | 44 ++++++++++++++++++- src/midend/Pass/Pass.cpp | 14 +++--- 2 files changed, 50 insertions(+), 8 deletions(-) diff --git a/src/midend/Pass/Analysis/LoopCharacteristics.cpp b/src/midend/Pass/Analysis/LoopCharacteristics.cpp index eb58af1..5d2f2aa 100644 --- a/src/midend/Pass/Analysis/LoopCharacteristics.cpp +++ b/src/midend/Pass/Analysis/LoopCharacteristics.cpp @@ -789,9 +789,10 @@ bool LoopCharacteristicsPass::isInvariantOperands(Instruction* inst, Loop* loop, // 检查内存位置是否在循环中被修改 bool LoopCharacteristicsPass::isMemoryLocationModifiedInLoop(Value* ptr, Loop* loop) { - // 遍历循环中的所有Store指令,检查是否有对该内存位置的写入 + // 遍历循环中的所有指令,检查是否有对该内存位置的写入 for (BasicBlock* bb : loop->getBlocks()) { for (auto& inst : bb->getInstructions()) { + // 1. 检查直接的Store指令 if (auto* storeInst = dynamic_cast(inst.get())) { Value* storeTar = storeInst->getPointer(); @@ -812,6 +813,47 @@ bool LoopCharacteristicsPass::isMemoryLocationModifiedInLoop(Value* ptr, Loop* l } } } + + // 2. 检查函数调用是否可能修改该内存位置 + else if (auto* callInst = dynamic_cast(inst.get())) { + Function* calledFunc = callInst->getCallee(); + + // 如果是纯函数,不会修改内存 + if (isPureFunction(calledFunc)) { + continue; + } + + // 检查函数参数中是否有该内存位置的指针 + for (size_t i = 1; i < callInst->getNumOperands(); ++i) { // 跳过函数指针 + Value* arg = callInst->getOperand(i); + + // 检查参数是否是指针类型且可能指向该内存位置 + if (auto* ptrType = dynamic_cast(arg->getType())) { + // 使用别名分析检查 + if (aliasAnalysis) { + auto aliasType = aliasAnalysis->queryAlias(ptr, arg); + if (aliasType != AliasType::NO_ALIAS) { + if (DEBUG) { + std::cout << " Memory location " << ptr->getName() + << " may be modified by function call " << calledFunc->getName() + << " through parameter " << arg->getName() << std::endl; + } + return true; + } + } else { + // 没有别名分析,检查精确匹配 + if (ptr == arg) { + if (DEBUG) { + std::cout << " Memory location " << ptr->getName() + << " may be modified by function call " << calledFunc->getName() + << " (exact match)" << std::endl; + } + return true; + } + } + } + } + } } } return false; diff --git a/src/midend/Pass/Pass.cpp b/src/midend/Pass/Pass.cpp index 69f9790..09de26e 100644 --- a/src/midend/Pass/Pass.cpp +++ b/src/midend/Pass/Pass.cpp @@ -161,14 +161,14 @@ void PassManager::runOptimizationPipeline(Module* moduleIR, IRBuilder* builderIR } - // this->clearPasses(); - // this->addPass(&LICM::ID); - // this->run(); + this->clearPasses(); + this->addPass(&LICM::ID); + this->run(); - // if(DEBUG) { - // std::cout << "=== IR After LICM ===\n"; - // printPasses(); - // } + if(DEBUG) { + std::cout << "=== IR After LICM ===\n"; + printPasses(); + } this->clearPasses(); this->addPass(&LoopStrengthReduction::ID); From c9a0c700e1a9a5fc81d3ba5215bae541935c14ca Mon Sep 17 00:00:00 2001 From: rain2133 <1370973498@qq.com> Date: Mon, 18 Aug 2025 11:30:40 +0800 Subject: [PATCH 10/12] =?UTF-8?q?[midend]=E5=A2=9E=E5=8A=A0=E5=85=A8?= =?UTF-8?q?=E5=B1=80=E5=BC=BA=E5=BA=A6=E5=89=8A=E5=BC=B1=E4=BC=98=E5=8C=96?= =?UTF-8?q?=E9=81=8D?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../Pass/Optimize/GlobalStrengthReduction.h | 107 ++ src/midend/CMakeLists.txt | 1 + .../Pass/Optimize/GlobalStrengthReduction.cpp | 1031 +++++++++++++++++ src/midend/Pass/Pass.cpp | 11 + 4 files changed, 1150 insertions(+) create mode 100644 src/include/midend/Pass/Optimize/GlobalStrengthReduction.h create mode 100644 src/midend/Pass/Optimize/GlobalStrengthReduction.cpp diff --git a/src/include/midend/Pass/Optimize/GlobalStrengthReduction.h b/src/include/midend/Pass/Optimize/GlobalStrengthReduction.h new file mode 100644 index 0000000..574494c --- /dev/null +++ b/src/include/midend/Pass/Optimize/GlobalStrengthReduction.h @@ -0,0 +1,107 @@ +#pragma once + +#include "Pass.h" +#include "IR.h" +#include "SideEffectAnalysis.h" +#include +#include +#include +#include + +namespace sysy { + +// 魔数乘法结构,用于除法优化 +struct MagicNumber { + uint32_t multiplier; + int shift; + bool needAdd; + + MagicNumber(uint32_t m, int s, bool add = false) + : multiplier(m), shift(s), needAdd(add) {} +}; + +// 全局强度削弱优化遍的核心逻辑封装类 +class GlobalStrengthReductionContext { +public: + // 构造函数,接受IRBuilder参数 + explicit GlobalStrengthReductionContext(IRBuilder* builder) : builder(builder) {} + + // 运行优化的主要方法 + void run(Function* func, AnalysisManager* AM, bool& changed); + +private: + IRBuilder* builder; // IR构建器 + + // 分析结果 + SideEffectAnalysisResult* sideEffectAnalysis = nullptr; + + // 优化计数 + int algebraicOptCount = 0; + int strengthReductionCount = 0; + int divisionOptCount = 0; + + // 主要优化方法 + bool processBasicBlock(BasicBlock* bb); + bool processInstruction(Instruction* inst); + + // 代数优化方法 + bool tryAlgebraicOptimization(Instruction* inst); + bool optimizeAddition(BinaryInst* inst); + bool optimizeSubtraction(BinaryInst* inst); + bool optimizeMultiplication(BinaryInst* inst); + bool optimizeDivision(BinaryInst* inst); + bool optimizeComparison(BinaryInst* inst); + bool optimizeLogical(BinaryInst* inst); + + // 强度削弱方法 + bool tryStrengthReduction(Instruction* inst); + bool reduceMultiplication(BinaryInst* inst); + bool reduceDivision(BinaryInst* inst); + bool reducePower(CallInst* inst); + + // 复杂乘法强度削弱方法 + bool tryComplexMultiplication(BinaryInst* inst, Value* variable, int constant); + bool findOptimalShiftDecomposition(int constant, std::vector& shifts); + Value* createShiftDecomposition(BinaryInst* inst, Value* variable, const std::vector& shifts); + + // 魔数乘法相关方法 + MagicNumber computeMagicNumber(uint32_t divisor); + std::pair computeMulhMagicNumbers(int divisor); + Value* createMagicDivision(BinaryInst* divInst, uint32_t divisor, const MagicNumber& magic); + Value* createMagicDivisionLibdivide(BinaryInst* divInst, int divisor, const std::pair& magicPair); + bool isPowerOfTwo(uint32_t n); + int log2OfPowerOfTwo(uint32_t n); + + // 辅助方法 + bool isConstantInt(Value* val, int& constVal); + bool isConstantInt(Value* val, uint32_t& constVal); + ConstantInteger* getConstantInt(int val); + bool hasOnlyLocalUses(Instruction* inst); + void replaceWithOptimized(Instruction* original, Value* replacement); +}; + +// 全局强度削弱优化遍类 +class GlobalStrengthReduction : public OptimizationPass { +private: + IRBuilder* builder; // IR构建器,用于创建新指令 + +public: + // 静态成员,作为该遍的唯一ID + static void* ID; + + // 构造函数,接受IRBuilder参数 + explicit GlobalStrengthReduction(IRBuilder* builder) + : OptimizationPass("GlobalStrengthReduction", Granularity::Function), builder(builder) {} + + // 在函数上运行优化 + bool runOnFunction(Function* func, AnalysisManager& AM) override; + + // 返回该遍的唯一ID + void* getPassID() const override { return ID; } + + // 声明分析依赖 + void getAnalysisUsage(std::set& analysisDependencies, + std::set& analysisInvalidations) const override; +}; + +} // namespace sysy diff --git a/src/midend/CMakeLists.txt b/src/midend/CMakeLists.txt index 66fc461..8243a63 100644 --- a/src/midend/CMakeLists.txt +++ b/src/midend/CMakeLists.txt @@ -22,6 +22,7 @@ add_library(midend_lib STATIC Pass/Optimize/LICM.cpp Pass/Optimize/LoopStrengthReduction.cpp Pass/Optimize/InductionVariableElimination.cpp + Pass/Optimize/GlobalStrengthReduction.cpp Pass/Optimize/BuildCFG.cpp Pass/Optimize/LargeArrayToGlobal.cpp ) diff --git a/src/midend/Pass/Optimize/GlobalStrengthReduction.cpp b/src/midend/Pass/Optimize/GlobalStrengthReduction.cpp new file mode 100644 index 0000000..118b793 --- /dev/null +++ b/src/midend/Pass/Optimize/GlobalStrengthReduction.cpp @@ -0,0 +1,1031 @@ +#include "GlobalStrengthReduction.h" +#include "SysYIROptUtils.h" +#include "IRBuilder.h" +#include +#include +#include +#include + +extern int DEBUG; + +namespace sysy { + +// 全局强度削弱优化遍的静态 ID +void *GlobalStrengthReduction::ID = (void *)&GlobalStrengthReduction::ID; + +// ====================================================================== +// GlobalStrengthReduction 类的实现 +// ====================================================================== + +bool GlobalStrengthReduction::runOnFunction(Function *func, AnalysisManager &AM) { + if (func->getBasicBlocks().empty()) { + return false; + } + + if (DEBUG) { + std::cout << "\n=== Running GlobalStrengthReduction on function: " << func->getName() << " ===" << std::endl; + } + + bool changed = false; + GlobalStrengthReductionContext context(builder); + context.run(func, &AM, changed); + + if (DEBUG) { + if (changed) { + std::cout << "GlobalStrengthReduction: Function " << func->getName() << " was modified" << std::endl; + } else { + std::cout << "GlobalStrengthReduction: Function " << func->getName() << " was not modified" << std::endl; + } + std::cout << "=== GlobalStrengthReduction completed for function: " << func->getName() << " ===" << std::endl; + } + + return changed; +} + +void GlobalStrengthReduction::getAnalysisUsage(std::set &analysisDependencies, std::set &analysisInvalidations) const { + // 强度削弱依赖副作用分析来判断指令是否可以安全优化 + analysisDependencies.insert(&SysYSideEffectAnalysisPass::ID); + + // 强度削弱不会使分析失效,因为: + // - 只替换计算指令,不改变控制流 + // - 不修改内存,不影响别名分析 + // - 保持程序语义不变 + // analysisInvalidations 保持为空 + + if (DEBUG) { + std::cout << "GlobalStrengthReduction: Declared analysis dependencies (SideEffectAnalysis)" << std::endl; + } +} + +// ====================================================================== +// GlobalStrengthReductionContext 类的实现 +// ====================================================================== + +void GlobalStrengthReductionContext::run(Function *func, AnalysisManager *AM, bool &changed) { + if (DEBUG) { + std::cout << " Starting GlobalStrengthReduction analysis for function: " << func->getName() << std::endl; + } + + // 获取分析结果 + if (AM) { + sideEffectAnalysis = AM->getAnalysisResult(); + + if (DEBUG) { + if (sideEffectAnalysis) { + std::cout << " GlobalStrengthReduction: Using side effect analysis" << std::endl; + } else { + std::cout << " GlobalStrengthReduction: Warning - side effect analysis not available" << std::endl; + } + } + } + + // 重置计数器 + algebraicOptCount = 0; + strengthReductionCount = 0; + divisionOptCount = 0; + + // 遍历所有基本块进行优化 + for (auto &bb_ptr : func->getBasicBlocks()) { + if (processBasicBlock(bb_ptr.get())) { + changed = true; + } + } + + if (DEBUG) { + std::cout << " GlobalStrengthReduction completed for function: " << func->getName() << std::endl; + std::cout << " Algebraic optimizations: " << algebraicOptCount << std::endl; + std::cout << " Strength reductions: " << strengthReductionCount << std::endl; + std::cout << " Division optimizations: " << divisionOptCount << std::endl; + } +} + +bool GlobalStrengthReductionContext::processBasicBlock(BasicBlock *bb) { + bool changed = false; + + if (DEBUG) { + std::cout << " Processing block: " << bb->getName() << std::endl; + } + + // 收集需要处理的指令(避免迭代器失效) + std::vector instructions; + for (auto &inst_ptr : bb->getInstructions()) { + instructions.push_back(inst_ptr.get()); + } + + // 处理每条指令 + for (auto inst : instructions) { + if (processInstruction(inst)) { + changed = true; + } + } + + return changed; +} + +bool GlobalStrengthReductionContext::processInstruction(Instruction *inst) { + bool changed = false; + + if (DEBUG >= 2) { + std::cout << " Processing instruction: " << inst->getName() << std::endl; + } + + // 先尝试代数优化 + if (tryAlgebraicOptimization(inst)) { + changed = true; + algebraicOptCount++; + } + + // 再尝试强度削弱 + if (tryStrengthReduction(inst)) { + changed = true; + strengthReductionCount++; + } + + return changed; +} + +// ====================================================================== +// 代数优化方法 +// ====================================================================== + +bool GlobalStrengthReductionContext::tryAlgebraicOptimization(Instruction *inst) { + auto binary = dynamic_cast(inst); + if (!binary) { + return false; + } + + switch (binary->getKind()) { + case Instruction::kAdd: + return optimizeAddition(binary); + case Instruction::kSub: + return optimizeSubtraction(binary); + case Instruction::kMul: + return optimizeMultiplication(binary); + case Instruction::kDiv: + return optimizeDivision(binary); + case Instruction::kICmpEQ: + case Instruction::kICmpNE: + case Instruction::kICmpLT: + case Instruction::kICmpGT: + case Instruction::kICmpLE: + case Instruction::kICmpGE: + return optimizeComparison(binary); + case Instruction::kAnd: + case Instruction::kOr: + return optimizeLogical(binary); + default: + return false; + } +} + +bool GlobalStrengthReductionContext::optimizeAddition(BinaryInst *inst) { + Value *lhs = inst->getLhs(); + Value *rhs = inst->getRhs(); + int constVal; + + // x + 0 = x + if (isConstantInt(rhs, constVal) && constVal == 0) { + if (DEBUG) { + std::cout << " Algebraic: " << inst->getName() << " = x + 0 -> x" << std::endl; + } + replaceWithOptimized(inst, lhs); + return true; + } + + // 0 + x = x + if (isConstantInt(lhs, constVal) && constVal == 0) { + if (DEBUG) { + std::cout << " Algebraic: " << inst->getName() << " = 0 + x -> x" << std::endl; + } + replaceWithOptimized(inst, rhs); + return true; + } + + // x + (-y) = x - y + if (auto rhsInst = dynamic_cast(rhs)) { + if (rhsInst->getKind() == Instruction::kNeg) { + if (DEBUG) { + std::cout << " Algebraic: " << inst->getName() << " = x + (-y) -> x - y" << std::endl; + } + // 创建减法指令 + builder->setPosition(inst->getParent(), inst->getParent()->findInstIterator(inst)); + auto subInst = builder->createSubInst(lhs, rhsInst->getOperand()); + replaceWithOptimized(inst, subInst); + return true; + } + } + + return false; +} + +bool GlobalStrengthReductionContext::optimizeSubtraction(BinaryInst *inst) { + Value *lhs = inst->getLhs(); + Value *rhs = inst->getRhs(); + int constVal; + + // x - 0 = x + if (isConstantInt(rhs, constVal) && constVal == 0) { + if (DEBUG) { + std::cout << " Algebraic: " << inst->getName() << " = x - 0 -> x" << std::endl; + } + replaceWithOptimized(inst, lhs); + return true; + } + + // x - x = 0 (如果x没有副作用) + if (lhs == rhs && hasOnlyLocalUses(dynamic_cast(lhs))) { + if (DEBUG) { + std::cout << " Algebraic: " << inst->getName() << " = x - x -> 0" << std::endl; + } + replaceWithOptimized(inst, getConstantInt(0)); + return true; + } + + // x - (-y) = x + y + if (auto rhsInst = dynamic_cast(rhs)) { + if (rhsInst->getKind() == Instruction::kNeg) { + if (DEBUG) { + std::cout << " Algebraic: " << inst->getName() << " = x - (-y) -> x + y" << std::endl; + } + builder->setPosition(inst->getParent(), inst->getParent()->findInstIterator(inst)); + auto addInst = builder->createAddInst(lhs, rhsInst->getOperand()); + replaceWithOptimized(inst, addInst); + return true; + } + } + + return false; +} + +bool GlobalStrengthReductionContext::optimizeMultiplication(BinaryInst *inst) { + Value *lhs = inst->getLhs(); + Value *rhs = inst->getRhs(); + int constVal; + + // x * 0 = 0 + if (isConstantInt(rhs, constVal) && constVal == 0) { + if (DEBUG) { + std::cout << " Algebraic: " << inst->getName() << " = x * 0 -> 0" << std::endl; + } + replaceWithOptimized(inst, getConstantInt(0)); + return true; + } + + // 0 * x = 0 + if (isConstantInt(lhs, constVal) && constVal == 0) { + if (DEBUG) { + std::cout << " Algebraic: " << inst->getName() << " = 0 * x -> 0" << std::endl; + } + replaceWithOptimized(inst, getConstantInt(0)); + return true; + } + + // x * 1 = x + if (isConstantInt(rhs, constVal) && constVal == 1) { + if (DEBUG) { + std::cout << " Algebraic: " << inst->getName() << " = x * 1 -> x" << std::endl; + } + replaceWithOptimized(inst, lhs); + return true; + } + + // 1 * x = x + if (isConstantInt(lhs, constVal) && constVal == 1) { + if (DEBUG) { + std::cout << " Algebraic: " << inst->getName() << " = 1 * x -> x" << std::endl; + } + replaceWithOptimized(inst, rhs); + return true; + } + + // x * (-1) = -x + if (isConstantInt(rhs, constVal) && constVal == -1) { + if (DEBUG) { + std::cout << " Algebraic: " << inst->getName() << " = x * (-1) -> -x" << std::endl; + } + builder->setPosition(inst->getParent(), inst->getParent()->findInstIterator(inst)); + auto negInst = builder->createNegInst(lhs); + replaceWithOptimized(inst, negInst); + return true; + } + + return false; +} + +bool GlobalStrengthReductionContext::optimizeDivision(BinaryInst *inst) { + Value *lhs = inst->getLhs(); + Value *rhs = inst->getRhs(); + int constVal; + + // x / 1 = x + if (isConstantInt(rhs, constVal) && constVal == 1) { + if (DEBUG) { + std::cout << " Algebraic: " << inst->getName() << " = x / 1 -> x" << std::endl; + } + replaceWithOptimized(inst, lhs); + return true; + } + + // x / (-1) = -x + if (isConstantInt(rhs, constVal) && constVal == -1) { + if (DEBUG) { + std::cout << " Algebraic: " << inst->getName() << " = x / (-1) -> -x" << std::endl; + } + builder->setPosition(inst->getParent(), inst->getParent()->findInstIterator(inst)); + auto negInst = builder->createNegInst(lhs); + replaceWithOptimized(inst, negInst); + return true; + } + + // x / x = 1 (如果x != 0且没有副作用) + if (lhs == rhs && hasOnlyLocalUses(dynamic_cast(lhs))) { + if (DEBUG) { + std::cout << " Algebraic: " << inst->getName() << " = x / x -> 1" << std::endl; + } + replaceWithOptimized(inst, getConstantInt(1)); + return true; + } + + return false; +} + +bool GlobalStrengthReductionContext::optimizeComparison(BinaryInst *inst) { + Value *lhs = inst->getLhs(); + Value *rhs = inst->getRhs(); + + // x == x = true (如果x没有副作用) + if (inst->getKind() == Instruction::kICmpEQ && lhs == rhs && + hasOnlyLocalUses(dynamic_cast(lhs))) { + if (DEBUG) { + std::cout << " Algebraic: " << inst->getName() << " = x == x -> true" << std::endl; + } + replaceWithOptimized(inst, getConstantInt(1)); + return true; + } + + // x != x = false (如果x没有副作用) + if (inst->getKind() == Instruction::kICmpNE && lhs == rhs && + hasOnlyLocalUses(dynamic_cast(lhs))) { + if (DEBUG) { + std::cout << " Algebraic: " << inst->getName() << " = x != x -> false" << std::endl; + } + replaceWithOptimized(inst, getConstantInt(0)); + return true; + } + + return false; +} + +bool GlobalStrengthReductionContext::optimizeLogical(BinaryInst *inst) { + Value *lhs = inst->getLhs(); + Value *rhs = inst->getRhs(); + int constVal; + + if (inst->getKind() == Instruction::kAnd) { + // x && 0 = 0 + if (isConstantInt(rhs, constVal) && constVal == 0) { + if (DEBUG) { + std::cout << " Algebraic: " << inst->getName() << " = x && 0 -> 0" << std::endl; + } + replaceWithOptimized(inst, getConstantInt(0)); + return true; + } + + // x && 1 = x + if (isConstantInt(rhs, constVal) && constVal == 1) { + if (DEBUG) { + std::cout << " Algebraic: " << inst->getName() << " = x && 1 -> x" << std::endl; + } + replaceWithOptimized(inst, lhs); + return true; + } + + // x && x = x + if (lhs == rhs) { + if (DEBUG) { + std::cout << " Algebraic: " << inst->getName() << " = x && x -> x" << std::endl; + } + replaceWithOptimized(inst, lhs); + return true; + } + } else if (inst->getKind() == Instruction::kOr) { + // x || 0 = x + if (isConstantInt(rhs, constVal) && constVal == 0) { + if (DEBUG) { + std::cout << " Algebraic: " << inst->getName() << " = x || 0 -> x" << std::endl; + } + replaceWithOptimized(inst, lhs); + return true; + } + + // x || 1 = 1 + if (isConstantInt(rhs, constVal) && constVal == 1) { + if (DEBUG) { + std::cout << " Algebraic: " << inst->getName() << " = x || 1 -> 1" << std::endl; + } + replaceWithOptimized(inst, getConstantInt(1)); + return true; + } + + // x || x = x + if (lhs == rhs) { + if (DEBUG) { + std::cout << " Algebraic: " << inst->getName() << " = x || x -> x" << std::endl; + } + replaceWithOptimized(inst, lhs); + return true; + } + } + + return false; +} + +// ====================================================================== +// 强度削弱方法 +// ====================================================================== + +bool GlobalStrengthReductionContext::tryStrengthReduction(Instruction *inst) { + if (auto binary = dynamic_cast(inst)) { + switch (binary->getKind()) { + case Instruction::kMul: + return reduceMultiplication(binary); + case Instruction::kDiv: + return reduceDivision(binary); + default: + return false; + } + } else if (auto call = dynamic_cast(inst)) { + return reducePower(call); + } + + return false; +} + +bool GlobalStrengthReductionContext::reduceMultiplication(BinaryInst *inst) { + Value *lhs = inst->getLhs(); + Value *rhs = inst->getRhs(); + int constVal; + + // 尝试右操作数为常数 + Value* variable = lhs; + if (isConstantInt(rhs, constVal) && constVal > 0) { + return tryComplexMultiplication(inst, variable, constVal); + } + + // 尝试左操作数为常数 + if (isConstantInt(lhs, constVal) && constVal > 0) { + variable = rhs; + return tryComplexMultiplication(inst, variable, constVal); + } + + return false; +} + +bool GlobalStrengthReductionContext::tryComplexMultiplication(BinaryInst* inst, Value* variable, int constant) { + // 首先检查是否为2的幂,使用简单位移 + if (isPowerOfTwo(constant)) { + int shiftAmount = log2OfPowerOfTwo(constant); + if (DEBUG) { + std::cout << " StrengthReduction: " << inst->getName() + << " = x * " << constant << " -> x << " << shiftAmount << std::endl; + } + + builder->setPosition(inst->getParent(), inst->getParent()->findInstIterator(inst)); + auto shiftInst = builder->createBinaryInst(Instruction::kSll, Type::getIntType(), variable, getConstantInt(shiftAmount)); + replaceWithOptimized(inst, shiftInst); + return true; + } + + // 尝试分解为位移和加法的组合 + std::vector shifts; + if (findOptimalShiftDecomposition(constant, shifts)) { + if (DEBUG) { + std::cout << " StrengthReduction: " << inst->getName() + << " = x * " << constant << " -> shift decomposition with " << shifts.size() << " terms" << std::endl; + } + + Value* result = createShiftDecomposition(inst, variable, shifts); + if (result) { + replaceWithOptimized(inst, result); + return true; + } + } + + return false; +} + +bool GlobalStrengthReductionContext::findOptimalShiftDecomposition(int constant, std::vector& shifts) { + shifts.clear(); + + // 常见的有效分解模式 + switch (constant) { + case 3: // 3 = 2^1 + 2^0 -> (x << 1) + x + shifts = {1, 0}; + return true; + case 5: // 5 = 2^2 + 2^0 -> (x << 2) + x + shifts = {2, 0}; + return true; + case 6: // 6 = 2^2 + 2^1 -> (x << 2) + (x << 1) + shifts = {2, 1}; + return true; + case 7: // 7 = 2^2 + 2^1 + 2^0 -> (x << 2) + (x << 1) + x + shifts = {2, 1, 0}; + return true; + case 9: // 9 = 2^3 + 2^0 -> (x << 3) + x + shifts = {3, 0}; + return true; + case 10: // 10 = 2^3 + 2^1 -> (x << 3) + (x << 1) + shifts = {3, 1}; + return true; + case 11: // 11 = 2^3 + 2^1 + 2^0 -> (x << 3) + (x << 1) + x + shifts = {3, 1, 0}; + return true; + case 12: // 12 = 2^3 + 2^2 -> (x << 3) + (x << 2) + shifts = {3, 2}; + return true; + case 13: // 13 = 2^3 + 2^2 + 2^0 -> (x << 3) + (x << 2) + x + shifts = {3, 2, 0}; + return true; + case 14: // 14 = 2^3 + 2^2 + 2^1 -> (x << 3) + (x << 2) + (x << 1) + shifts = {3, 2, 1}; + return true; + case 15: // 15 = 2^3 + 2^2 + 2^1 + 2^0 -> (x << 3) + (x << 2) + (x << 1) + x + shifts = {3, 2, 1, 0}; + return true; + case 17: // 17 = 2^4 + 2^0 -> (x << 4) + x + shifts = {4, 0}; + return true; + case 18: // 18 = 2^4 + 2^1 -> (x << 4) + (x << 1) + shifts = {4, 1}; + return true; + case 20: // 20 = 2^4 + 2^2 -> (x << 4) + (x << 2) + shifts = {4, 2}; + return true; + case 24: // 24 = 2^4 + 2^3 -> (x << 4) + (x << 3) + shifts = {4, 3}; + return true; + case 25: // 25 = 2^4 + 2^3 + 2^0 -> (x << 4) + (x << 3) + x + shifts = {4, 3, 0}; + return true; + case 100: // 100 = 2^6 + 2^5 + 2^2 -> (x << 6) + (x << 5) + (x << 2) + shifts = {6, 5, 2}; + return true; + } + + // 通用二进制分解(最多4个项,避免过度复杂化) + if (constant > 0 && constant < 256) { + std::vector binaryShifts; + int temp = constant; + int bit = 0; + + while (temp > 0 && binaryShifts.size() < 4) { + if (temp & 1) { + binaryShifts.push_back(bit); + } + temp >>= 1; + bit++; + } + + // 只有当项数不超过3个时才使用二进制分解(比直接乘法更有效) + if (binaryShifts.size() <= 3 && binaryShifts.size() >= 2) { + shifts = binaryShifts; + return true; + } + } + + return false; +} + +Value* GlobalStrengthReductionContext::createShiftDecomposition(BinaryInst* inst, Value* variable, const std::vector& shifts) { + if (shifts.empty()) return nullptr; + + builder->setPosition(inst->getParent(), inst->getParent()->findInstIterator(inst)); + + Value* result = nullptr; + + for (int shift : shifts) { + Value* term; + if (shift == 0) { + // 0位移就是原变量 + term = variable; + } else { + // 创建位移指令 + term = builder->createBinaryInst(Instruction::kSll, Type::getIntType(), variable, getConstantInt(shift)); + } + + if (result == nullptr) { + result = term; + } else { + // 累加到结果中 + result = builder->createAddInst(result, term); + } + } + + return result; +} + +bool GlobalStrengthReductionContext::reduceDivision(BinaryInst *inst) { + Value *lhs = inst->getLhs(); + Value *rhs = inst->getRhs(); + uint32_t constVal; + + // x / 2^n = x >> n (对于无符号除法或已知为正数的情况) + if (isConstantInt(rhs, constVal) && constVal > 0 && isPowerOfTwo(constVal)) { + int shiftAmount = log2OfPowerOfTwo(constVal); + if (DEBUG) { + std::cout << " StrengthReduction: " << inst->getName() + << " = x / " << constVal << " -> x >> " << shiftAmount << std::endl; + } + + builder->setPosition(inst->getParent(), inst->getParent()->findInstIterator(inst)); + auto shiftInst = builder->createBinaryInst(Instruction::kSra, Type::getIntType(), lhs, getConstantInt(shiftAmount)); + replaceWithOptimized(inst, shiftInst); + strengthReductionCount++; + return true; + } + + // x / c = x * magic_number (魔数乘法优化 - 使用libdivide算法) + if (isConstantInt(rhs, constVal) && constVal > 1 && constVal != (uint32_t)(-1)) { + auto magicPair = computeMulhMagicNumbers(static_cast(constVal)); + if (magicPair.first != -1) { // 有效的魔数 + if (DEBUG) { + std::cout << " StrengthReduction: " << inst->getName() + << " = x / " << constVal << " -> libdivide magic multiplication" << std::endl; + } + + Value* magicResult = createMagicDivisionLibdivide(inst, static_cast(constVal), magicPair); + replaceWithOptimized(inst, magicResult); + divisionOptCount++; + return true; + } + } + + return false; +} + +bool GlobalStrengthReductionContext::reducePower(CallInst *inst) { + // 检查是否是pow函数调用 + Function* callee = inst->getCallee(); + if (!callee || callee->getName() != "pow") { + return false; + } + + // pow(x, 2) = x * x + if (inst->getNumOperands() >= 2) { + int exponent; + if (isConstantInt(inst->getOperand(1), exponent)) { + if (exponent == 2) { + if (DEBUG) { + std::cout << " StrengthReduction: pow(x, 2) -> x * x" << std::endl; + } + + Value* base = inst->getOperand(0); + builder->setPosition(inst->getParent(), inst->getParent()->findInstIterator(inst)); + auto mulInst = builder->createMulInst(base, base); + replaceWithOptimized(inst, mulInst); + strengthReductionCount++; + return true; + } else if (exponent >= 3 && exponent <= 8) { + // 对于小的指数,展开为连续乘法 + if (DEBUG) { + std::cout << " StrengthReduction: pow(x, " << exponent << ") -> repeated multiplication" << std::endl; + } + + Value* base = inst->getOperand(0); + Value* result = base; + builder->setPosition(inst->getParent(), inst->getParent()->findInstIterator(inst)); + + for (int i = 1; i < exponent; i++) { + result = builder->createMulInst(result, base); + } + + replaceWithOptimized(inst, result); + strengthReductionCount++; + return true; + } + } + } + + return false; +} + +// ====================================================================== +// 魔数乘法相关方法 +// ====================================================================== + +// 该实现参考了libdivide的算法 +std::pair GlobalStrengthReductionContext::computeMulhMagicNumbers(int divisor) { + + if (DEBUG) { + std::cout << "\n[SR] ===== Computing magic numbers for divisor " << divisor << " (libdivide algorithm) =====" << std::endl; + } + + if (divisor == 0) { + if (DEBUG) std::cout << "[SR] Error: divisor must be != 0" << std::endl; + return {-1, -1}; + } + + // libdivide 常数 + const uint8_t LIBDIVIDE_ADD_MARKER = 0x40; + const uint8_t LIBDIVIDE_NEGATIVE_DIVISOR = 0x80; + + // 辅助函数:计算前导零个数 + auto count_leading_zeros32 = [](uint32_t val) -> uint32_t { + if (val == 0) return 32; + return __builtin_clz(val); + }; + + // 辅助函数:64位除法返回32位商和余数 + auto div_64_32 = [](uint32_t high, uint32_t low, uint32_t divisor, uint32_t* rem) -> uint32_t { + uint64_t dividend = ((uint64_t)high << 32) | low; + uint32_t quotient = dividend / divisor; + *rem = dividend % divisor; + return quotient; + }; + + if (DEBUG) { + std::cout << "[SR] Input divisor: " << divisor << std::endl; + } + + // libdivide_internal_s32_gen 算法实现 + int32_t d = divisor; + uint32_t ud = (uint32_t)d; + uint32_t absD = (d < 0) ? -ud : ud; + + if (DEBUG) { + std::cout << "[SR] absD = " << absD << std::endl; + } + + uint32_t floor_log_2_d = 31 - count_leading_zeros32(absD); + + if (DEBUG) { + std::cout << "[SR] floor_log_2_d = " << floor_log_2_d << std::endl; + } + + // 检查 absD 是否为2的幂 + if ((absD & (absD - 1)) == 0) { + if (DEBUG) { + std::cout << "[SR] " << absD << " 是2的幂,使用移位方法" << std::endl; + } + + // 对于2的幂,我们只使用移位,不需要魔数 + int shift = floor_log_2_d; + if (d < 0) shift |= 0x80; // 标记负数 + + if (DEBUG) { + std::cout << "[SR] Power of 2 result: magic=0, shift=" << shift << std::endl; + std::cout << "[SR] ===== End magic computation =====" << std::endl; + } + + // 对于我们的目的,我们将在IR生成中以不同方式处理2的幂 + // 返回特殊标记 + return {0, shift}; + } + + if (DEBUG) { + std::cout << "[SR] " << absD << " is not a power of 2, computing magic number" << std::endl; + } + + // 非2的幂除数的魔数计算 + uint8_t more; + uint32_t rem, proposed_m; + + // 计算 proposed_m = floor(2^(floor_log_2_d + 31) / absD) + proposed_m = div_64_32((uint32_t)1 << (floor_log_2_d - 1), 0, absD, &rem); + const uint32_t e = absD - rem; + + if (DEBUG) { + std::cout << "[SR] proposed_m = " << proposed_m << ", rem = " << rem << ", e = " << e << std::endl; + } + + // 确定是否需要"加法"版本 + const bool branchfree = false; // 使用分支版本 + + if (!branchfree && e < ((uint32_t)1 << floor_log_2_d)) { + // 这个幂次有效 + more = (uint8_t)(floor_log_2_d - 1); + if (DEBUG) { + std::cout << "[SR] Using basic algorithm, shift = " << (int)more << std::endl; + } + } else { + // 我们需要上升一个等级 + proposed_m += proposed_m; + const uint32_t twice_rem = rem + rem; + if (twice_rem >= absD || twice_rem < rem) { + proposed_m += 1; + } + more = (uint8_t)(floor_log_2_d | LIBDIVIDE_ADD_MARKER); + if (DEBUG) { + std::cout << "[SR] Using add algorithm, proposed_m = " << proposed_m << ", more = " << (int)more << std::endl; + } + } + + proposed_m += 1; + int32_t magic = (int32_t)proposed_m; + + // 处理负除数 + if (d < 0) { + more |= LIBDIVIDE_NEGATIVE_DIVISOR; + if (!branchfree) { + magic = -magic; + } + if (DEBUG) { + std::cout << "[SR] Negative divisor, magic = " << magic << ", more = " << (int)more << std::endl; + } + } + + // 为我们的IR生成提取移位量和标志 + int shift = more & 0x3F; // 移除标志,保留移位量(位0-5) + bool need_add = (more & LIBDIVIDE_ADD_MARKER) != 0; + bool is_negative = (more & LIBDIVIDE_NEGATIVE_DIVISOR) != 0; + // 返回魔数、移位量,并在移位中编码ADD_MARKER标志 + // 我们将使用移位的第6位表示ADD_MARKER,第7位表示负数(如果需要) + int encoded_shift = shift; + if (need_add) { + encoded_shift |= 0x40; // 设置第6位表示ADD_MARKER + if (DEBUG) { + std::cout << "[SR] Encoding ADD_MARKER in shift: " << encoded_shift << std::endl; + } + } + + return {magic, encoded_shift}; + } + +Value* GlobalStrengthReductionContext::createMagicDivision(BinaryInst* divInst, uint32_t divisor, const MagicNumber& magic) { + builder->setPosition(divInst->getParent(), divInst->getParent()->findInstIterator(divInst)); + + Value* dividend = divInst->getLhs(); + + // 创建魔数常量 + Value* magicConst = getConstantInt(static_cast(magic.multiplier)); + + // 执行乘法: tmp = dividend * magic + Value* tmp = builder->createMulInst(dividend, magicConst); + + if (magic.needAdd) { + // 需要额外加法的情况 + Value* sum = builder->createAddInst(tmp, dividend); + if (magic.shift > 0) { + Value* shiftConst = getConstantInt(magic.shift); + tmp = builder->createBinaryInst(Instruction::kSra, Type::getIntType(), sum, shiftConst); + } else { + tmp = sum; + } + } else { + // 直接右移 + if (magic.shift > 0) { + Value* shiftConst = getConstantInt(magic.shift); + tmp = builder->createBinaryInst(Instruction::kSra, Type::getIntType(), tmp, shiftConst); + } + } + + // 处理符号:如果被除数为负,结果需要调整 + // 这里简化处理,假设都是正数除法 + + return tmp; +} + +Value* GlobalStrengthReductionContext::createMagicDivisionLibdivide(BinaryInst* divInst, int divisor, const std::pair& magicPair) { + builder->setPosition(divInst->getParent(), divInst->getParent()->findInstIterator(divInst)); + + Value* dividend = divInst->getLhs(); + int magic = magicPair.first; + int encoded_shift = magicPair.second; + + if (DEBUG) { + std::cout << "[SR] Creating libdivide magic division: magic=" << magic + << ", encoded_shift=" << encoded_shift << std::endl; + } + + // 检查是否为2的幂(magic=0表示是2的幂) + if (magic == 0) { + // 2的幂除法,直接使用算术右移 + int shift = encoded_shift & 0x3F; // 获取实际移位量 + bool is_negative = (encoded_shift & 0x80) != 0; + + if (DEBUG) { + std::cout << "[SR] Power of 2 division: shift=" << shift + << ", negative=" << is_negative << std::endl; + } + + Value* result = dividend; + if (shift > 0) { + Value* shiftConst = getConstantInt(shift); + result = builder->createBinaryInst(Instruction::kSra, Type::getIntType(), dividend, shiftConst); + } + + // 如果原除数为负,需要取反结果 + if (is_negative) { + result = builder->createNegInst(result); + } + + return result; + } + + // 非2的幂除法,使用魔数乘法 + int shift = encoded_shift & 0x3F; // 获取移位量(位0-5) + bool need_add = (encoded_shift & 0x40) != 0; // 检查ADD_MARKER标志(位6) + + if (DEBUG) { + std::cout << "[SR] Magic multiplication: shift=" << shift + << ", need_add=" << need_add << std::endl; + } + + // 创建魔数常量 + Value* magicConst = getConstantInt(magic); + + // 执行高位乘法:mulh(dividend, magic) + // 由于我们的IR可能没有直接的mulh指令,我们使用64位乘法然后取高32位 + // 这里需要根据实际的IR指令集进行调整 + Value* tmp = builder->createMulInst(dividend, magicConst); + + if (need_add) { + // ADD算法:(mulh(dividend, magic) + dividend) >> shift + tmp = builder->createAddInst(tmp, dividend); + } + + if (shift > 0) { + Value* shiftConst = getConstantInt(shift); + tmp = builder->createBinaryInst(Instruction::kSra, Type::getIntType(), tmp, shiftConst); + } + + // 处理符号位调整 + // 如果被除数为负数,可能需要额外的符号处理 + // 这里简化处理,实际实现可能需要更复杂的符号位处理 + + return tmp; +} + +// ====================================================================== +// 辅助方法 +// ====================================================================== + +bool GlobalStrengthReductionContext::isPowerOfTwo(uint32_t n) { + return n > 0 && (n & (n - 1)) == 0; +} + +int GlobalStrengthReductionContext::log2OfPowerOfTwo(uint32_t n) { + int result = 0; + while (n > 1) { + n >>= 1; + result++; + } + return result; +} + +bool GlobalStrengthReductionContext::isConstantInt(Value* val, int& constVal) { + if (auto constInt = dynamic_cast(val)) { + constVal = std::get(constInt->getVal()); + return true; + } + return false; +} + +bool GlobalStrengthReductionContext::isConstantInt(Value* val, uint32_t& constVal) { + if (auto constInt = dynamic_cast(val)) { + int signedVal = std::get(constInt->getVal()); + if (signedVal >= 0) { + constVal = static_cast(signedVal); + return true; + } + } + return false; +} + +ConstantInteger* GlobalStrengthReductionContext::getConstantInt(int val) { + return ConstantInteger::get(val); +} + +bool GlobalStrengthReductionContext::hasOnlyLocalUses(Instruction* inst) { + if (!inst) return true; + + // 简单检查:如果指令没有副作用,则认为是本地的 + if (sideEffectAnalysis) { + auto sideEffect = sideEffectAnalysis->getInstructionSideEffect(inst); + return sideEffect.type == SideEffectType::NO_SIDE_EFFECT; + } + + // 没有副作用分析时,保守处理 + return !inst->isCall() && !inst->isStore() && !inst->isLoad(); +} + +void GlobalStrengthReductionContext::replaceWithOptimized(Instruction* original, Value* replacement) { + if (DEBUG >= 2) { + std::cout << " Replacing " << original->getName() + << " with " << replacement->getName() << std::endl; + } + + original->replaceAllUsesWith(replacement); + + // 如果替换值是新创建的指令,确保它有合适的名字 +// if (auto replInst = dynamic_cast(replacement)) { +// if (replInst->getName().empty()) { +// replInst->setName(original->getName() + "_opt"); +// } +// } + + // 删除原指令,让调用者处理 + SysYIROptUtils::usedelete(original); +} + +} // namespace sysy diff --git a/src/midend/Pass/Pass.cpp b/src/midend/Pass/Pass.cpp index 09de26e..c834e30 100644 --- a/src/midend/Pass/Pass.cpp +++ b/src/midend/Pass/Pass.cpp @@ -18,6 +18,7 @@ #include "LICM.h" #include "LoopStrengthReduction.h" #include "InductionVariableElimination.h" +#include "GlobalStrengthReduction.h" #include "Pass.h" #include #include @@ -179,6 +180,16 @@ void PassManager::runOptimizationPipeline(Module* moduleIR, IRBuilder* builderIR printPasses(); } + // 全局强度削弱优化,包括代数优化和魔数除法 + this->clearPasses(); + this->addPass(&GlobalStrengthReduction::ID); + this->run(); + + if(DEBUG) { + std::cout << "=== IR After Global Strength Reduction Optimizations ===\n"; + printPasses(); + } + // this->clearPasses(); // this->addPass(&Reg2Mem::ID); // this->run(); From 5c34cbc7b871b30eb1f2a960286c4c098f8933b9 Mon Sep 17 00:00:00 2001 From: rain2133 <1370973498@qq.com> Date: Mon, 18 Aug 2025 20:37:20 +0800 Subject: [PATCH 11/12] =?UTF-8?q?[midend-GSR]=E5=B0=86=E9=AD=94=E6=95=B0?= =?UTF-8?q?=E6=B1=82=E8=A7=A3=E7=A7=BB=E5=8A=A8=E5=88=B0utils=E7=9A=84?= =?UTF-8?q?=E9=9D=99=E6=80=81=E6=96=B9=E6=B3=95=E4=B8=AD=E3=80=82?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../Pass/Optimize/GlobalStrengthReduction.h | 2 +- .../Pass/Optimize/LoopStrengthReduction.h | 7 - .../midend/Pass/Optimize/SysYIROptUtils.h | 184 ++++++++++ .../Pass/Optimize/GlobalStrengthReduction.cpp | 341 +++++------------- .../Pass/Optimize/LoopStrengthReduction.cpp | 183 +--------- src/midend/Pass/Pass.cpp | 2 + 6 files changed, 279 insertions(+), 440 deletions(-) diff --git a/src/include/midend/Pass/Optimize/GlobalStrengthReduction.h b/src/include/midend/Pass/Optimize/GlobalStrengthReduction.h index 574494c..43c75fb 100644 --- a/src/include/midend/Pass/Optimize/GlobalStrengthReduction.h +++ b/src/include/midend/Pass/Optimize/GlobalStrengthReduction.h @@ -68,7 +68,7 @@ private: MagicNumber computeMagicNumber(uint32_t divisor); std::pair computeMulhMagicNumbers(int divisor); Value* createMagicDivision(BinaryInst* divInst, uint32_t divisor, const MagicNumber& magic); - Value* createMagicDivisionLibdivide(BinaryInst* divInst, int divisor, const std::pair& magicPair); + Value* createMagicDivisionLibdivide(BinaryInst* divInst, int divisor); bool isPowerOfTwo(uint32_t n); int log2OfPowerOfTwo(uint32_t n); diff --git a/src/include/midend/Pass/Optimize/LoopStrengthReduction.h b/src/include/midend/Pass/Optimize/LoopStrengthReduction.h index ecf96dd..a6bc8cd 100644 --- a/src/include/midend/Pass/Optimize/LoopStrengthReduction.h +++ b/src/include/midend/Pass/Optimize/LoopStrengthReduction.h @@ -127,13 +127,6 @@ private: */ bool analyzeInductionVariableRange(const InductionVarInfo* ivInfo, Loop* loop) const; - /** - * 计算用于除法优化的魔数和移位量 - * @param divisor 除数 - * @return {魔数, 移位量} - */ - std::pair computeMulhMagicNumbers(int divisor) const; - /** * 生成除法替换代码 * @param candidate 优化候选项 diff --git a/src/include/midend/Pass/Optimize/SysYIROptUtils.h b/src/include/midend/Pass/Optimize/SysYIROptUtils.h index 00a4db5..48d2f26 100644 --- a/src/include/midend/Pass/Optimize/SysYIROptUtils.h +++ b/src/include/midend/Pass/Optimize/SysYIROptUtils.h @@ -107,6 +107,190 @@ public: // 所以当AllocaInst的basetype是PointerType时(一维数组)或者是指向ArrayType的PointerType(多位数组)时,返回true return aval && (baseType->isPointer() || baseType->as()->getBaseType()->isArray()); } + + + //该实现参考了libdivide的算法 + static std::pair computeMulhMagicNumbers(int divisor) { + + if (DEBUG) { + std::cout << "\n[SR] ===== Computing magic numbers for divisor " << divisor << " (libdivide algorithm) =====" << std::endl; + } + + if (divisor == 0) { + if (DEBUG) std::cout << "[SR] Error: divisor must be != 0" << std::endl; + return {-1, -1}; + } + + // libdivide 常数 + const uint8_t LIBDIVIDE_ADD_MARKER = 0x40; + const uint8_t LIBDIVIDE_NEGATIVE_DIVISOR = 0x80; + + // 辅助函数:计算前导零个数 + auto count_leading_zeros32 = [](uint32_t val) -> uint32_t { + if (val == 0) return 32; + return __builtin_clz(val); + }; + + // 辅助函数:64位除法返回32位商和余数 + auto div_64_32 = [](uint32_t high, uint32_t low, uint32_t divisor, uint32_t* rem) -> uint32_t { + uint64_t dividend = ((uint64_t)high << 32) | low; + uint32_t quotient = dividend / divisor; + *rem = dividend % divisor; + return quotient; + }; + + if (DEBUG) { + std::cout << "[SR] Input divisor: " << divisor << std::endl; + } + + // libdivide_internal_s32_gen 算法实现 + int32_t d = divisor; + uint32_t ud = (uint32_t)d; + uint32_t absD = (d < 0) ? -ud : ud; + + if (DEBUG) { + std::cout << "[SR] absD = " << absD << std::endl; + } + + uint32_t floor_log_2_d = 31 - count_leading_zeros32(absD); + + if (DEBUG) { + std::cout << "[SR] floor_log_2_d = " << floor_log_2_d << std::endl; + } + + // 检查 absD 是否为2的幂 + if ((absD & (absD - 1)) == 0) { + if (DEBUG) { + std::cout << "[SR] " << absD << " 是2的幂,使用移位方法" << std::endl; + } + + // 对于2的幂,我们只使用移位,不需要魔数 + int shift = floor_log_2_d; + if (d < 0) shift |= 0x80; // 标记负数 + + if (DEBUG) { + std::cout << "[SR] Power of 2 result: magic=0, shift=" << shift << std::endl; + std::cout << "[SR] ===== End magic computation =====" << std::endl; + } + + // 对于我们的目的,我们将在IR生成中以不同方式处理2的幂 + // 返回特殊标记 + return {0, shift}; + } + + if (DEBUG) { + std::cout << "[SR] " << absD << " is not a power of 2, computing magic number" << std::endl; + } + + // 非2的幂除数的魔数计算 + uint8_t more; + uint32_t rem, proposed_m; + + // 计算 proposed_m = floor(2^(floor_log_2_d + 31) / absD) + proposed_m = div_64_32((uint32_t)1 << (floor_log_2_d - 1), 0, absD, &rem); + const uint32_t e = absD - rem; + + if (DEBUG) { + std::cout << "[SR] proposed_m = " << proposed_m << ", rem = " << rem << ", e = " << e << std::endl; + } + + // 确定是否需要"加法"版本 + const bool branchfree = false; // 使用分支版本 + + if (!branchfree && e < ((uint32_t)1 << floor_log_2_d)) { + // 这个幂次有效 + more = (uint8_t)(floor_log_2_d - 1); + if (DEBUG) { + std::cout << "[SR] Using basic algorithm, shift = " << (int)more << std::endl; + } + } else { + // 我们需要上升一个等级 + proposed_m += proposed_m; + const uint32_t twice_rem = rem + rem; + if (twice_rem >= absD || twice_rem < rem) { + proposed_m += 1; + } + more = (uint8_t)(floor_log_2_d | LIBDIVIDE_ADD_MARKER); + if (DEBUG) { + std::cout << "[SR] Using add algorithm, proposed_m = " << proposed_m << ", more = " << (int)more << std::endl; + } + } + + proposed_m += 1; + int32_t magic = (int32_t)proposed_m; + + // 处理负除数 + if (d < 0) { + more |= LIBDIVIDE_NEGATIVE_DIVISOR; + if (!branchfree) { + magic = -magic; + } + if (DEBUG) { + std::cout << "[SR] Negative divisor, magic = " << magic << ", more = " << (int)more << std::endl; + } + } + + // 为我们的IR生成提取移位量和标志 + int shift = more & 0x3F; // 移除标志,保留移位量(位0-5) + bool need_add = (more & LIBDIVIDE_ADD_MARKER) != 0; + bool is_negative = (more & LIBDIVIDE_NEGATIVE_DIVISOR) != 0; + + if (DEBUG) { + std::cout << "[SR] Final result: magic = " << magic << ", more = " << (int)more + << " (0x" << std::hex << (int)more << std::dec << ")" << std::endl; + std::cout << "[SR] Shift = " << shift << ", need_add = " << need_add + << ", is_negative = " << is_negative << std::endl; + + // Test the magic number using the correct libdivide algorithm + std::cout << "[SR] Testing magic number (libdivide algorithm):" << std::endl; + int test_values[] = {1, 7, 37, 100, 999, -1, -7, -37, -100}; + + for (int test_val : test_values) { + int64_t quotient; + + // 实现正确的libdivide算法 + int64_t product = (int64_t)test_val * magic; + int64_t high_bits = product >> 32; + + if (need_add) { + // ADD_MARKER情况:移位前加上被除数 + // 这是libdivide的关键洞察! + high_bits += test_val; + quotient = high_bits >> shift; + } else { + // 正常情况:只是移位 + quotient = high_bits >> shift; + } + + // 符号修正:这是libdivide有符号除法的关键部分! + // 如果被除数为负,商需要加1来匹配C语言的截断除法语义 + if (test_val < 0) { + quotient += 1; + } + + int expected = test_val / divisor; + + bool correct = (quotient == expected); + std::cout << "[SR] " << test_val << " / " << divisor << " = " << quotient + << " (expected " << expected << ") " << (correct ? "✓" : "✗") << std::endl; + } + + std::cout << "[SR] ===== End magic computation =====" << std::endl; + } + + // 返回魔数、移位量,并在移位中编码ADD_MARKER标志 + // 我们将使用移位的第6位表示ADD_MARKER,第7位表示负数(如果需要) + int encoded_shift = shift; + if (need_add) { + encoded_shift |= 0x40; // 设置第6位表示ADD_MARKER + if (DEBUG) { + std::cout << "[SR] Encoding ADD_MARKER in shift: " << encoded_shift << std::endl; + } + } + + return {magic, encoded_shift}; + } + }; }// namespace sysy \ No newline at end of file diff --git a/src/midend/Pass/Optimize/GlobalStrengthReduction.cpp b/src/midend/Pass/Optimize/GlobalStrengthReduction.cpp index 118b793..e15a9c5 100644 --- a/src/midend/Pass/Optimize/GlobalStrengthReduction.cpp +++ b/src/midend/Pass/Optimize/GlobalStrengthReduction.cpp @@ -123,25 +123,24 @@ bool GlobalStrengthReductionContext::processBasicBlock(BasicBlock *bb) { } bool GlobalStrengthReductionContext::processInstruction(Instruction *inst) { - bool changed = false; - if (DEBUG >= 2) { + if (DEBUG) { std::cout << " Processing instruction: " << inst->getName() << std::endl; } // 先尝试代数优化 if (tryAlgebraicOptimization(inst)) { - changed = true; algebraicOptCount++; + return true; } // 再尝试强度削弱 if (tryStrengthReduction(inst)) { - changed = true; strengthReductionCount++; + return true; } - return changed; + return false; } // ====================================================================== @@ -638,7 +637,9 @@ bool GlobalStrengthReductionContext::reduceDivision(BinaryInst *inst) { } builder->setPosition(inst->getParent(), inst->getParent()->findInstIterator(inst)); - auto shiftInst = builder->createBinaryInst(Instruction::kSra, Type::getIntType(), lhs, getConstantInt(shiftAmount)); + Value* divisor_minus_1 = ConstantInteger::get(constVal - 1); + Value* adjusted = builder->createAddInst(lhs, divisor_minus_1); + Value* shiftInst = builder->createBinaryInst(Instruction::kSra, Type::getIntType(), adjusted, getConstantInt(shiftAmount)); replaceWithOptimized(inst, shiftInst); strengthReductionCount++; return true; @@ -646,18 +647,11 @@ bool GlobalStrengthReductionContext::reduceDivision(BinaryInst *inst) { // x / c = x * magic_number (魔数乘法优化 - 使用libdivide算法) if (isConstantInt(rhs, constVal) && constVal > 1 && constVal != (uint32_t)(-1)) { - auto magicPair = computeMulhMagicNumbers(static_cast(constVal)); - if (magicPair.first != -1) { // 有效的魔数 - if (DEBUG) { - std::cout << " StrengthReduction: " << inst->getName() - << " = x / " << constVal << " -> libdivide magic multiplication" << std::endl; - } - - Value* magicResult = createMagicDivisionLibdivide(inst, static_cast(constVal), magicPair); - replaceWithOptimized(inst, magicResult); - divisionOptCount++; - return true; - } + // auto magicPair = computeMulhMagicNumbers(static_cast(constVal)); + Value* magicResult = createMagicDivisionLibdivide(inst, static_cast(constVal)); + replaceWithOptimized(inst, magicResult); + divisionOptCount++; + return true; } return false; @@ -709,251 +703,98 @@ bool GlobalStrengthReductionContext::reducePower(CallInst *inst) { return false; } -// ====================================================================== -// 魔数乘法相关方法 -// ====================================================================== - -// 该实现参考了libdivide的算法 -std::pair GlobalStrengthReductionContext::computeMulhMagicNumbers(int divisor) { - - if (DEBUG) { - std::cout << "\n[SR] ===== Computing magic numbers for divisor " << divisor << " (libdivide algorithm) =====" << std::endl; - } - - if (divisor == 0) { - if (DEBUG) std::cout << "[SR] Error: divisor must be != 0" << std::endl; - return {-1, -1}; - } - - // libdivide 常数 - const uint8_t LIBDIVIDE_ADD_MARKER = 0x40; - const uint8_t LIBDIVIDE_NEGATIVE_DIVISOR = 0x80; - - // 辅助函数:计算前导零个数 - auto count_leading_zeros32 = [](uint32_t val) -> uint32_t { - if (val == 0) return 32; - return __builtin_clz(val); - }; - - // 辅助函数:64位除法返回32位商和余数 - auto div_64_32 = [](uint32_t high, uint32_t low, uint32_t divisor, uint32_t* rem) -> uint32_t { - uint64_t dividend = ((uint64_t)high << 32) | low; - uint32_t quotient = dividend / divisor; - *rem = dividend % divisor; - return quotient; - }; - - if (DEBUG) { - std::cout << "[SR] Input divisor: " << divisor << std::endl; - } - - // libdivide_internal_s32_gen 算法实现 - int32_t d = divisor; - uint32_t ud = (uint32_t)d; - uint32_t absD = (d < 0) ? -ud : ud; - - if (DEBUG) { - std::cout << "[SR] absD = " << absD << std::endl; - } - - uint32_t floor_log_2_d = 31 - count_leading_zeros32(absD); - - if (DEBUG) { - std::cout << "[SR] floor_log_2_d = " << floor_log_2_d << std::endl; - } - - // 检查 absD 是否为2的幂 - if ((absD & (absD - 1)) == 0) { - if (DEBUG) { - std::cout << "[SR] " << absD << " 是2的幂,使用移位方法" << std::endl; - } - - // 对于2的幂,我们只使用移位,不需要魔数 - int shift = floor_log_2_d; - if (d < 0) shift |= 0x80; // 标记负数 - - if (DEBUG) { - std::cout << "[SR] Power of 2 result: magic=0, shift=" << shift << std::endl; - std::cout << "[SR] ===== End magic computation =====" << std::endl; - } - - // 对于我们的目的,我们将在IR生成中以不同方式处理2的幂 - // 返回特殊标记 - return {0, shift}; - } - - if (DEBUG) { - std::cout << "[SR] " << absD << " is not a power of 2, computing magic number" << std::endl; - } - - // 非2的幂除数的魔数计算 - uint8_t more; - uint32_t rem, proposed_m; - - // 计算 proposed_m = floor(2^(floor_log_2_d + 31) / absD) - proposed_m = div_64_32((uint32_t)1 << (floor_log_2_d - 1), 0, absD, &rem); - const uint32_t e = absD - rem; - - if (DEBUG) { - std::cout << "[SR] proposed_m = " << proposed_m << ", rem = " << rem << ", e = " << e << std::endl; - } - - // 确定是否需要"加法"版本 - const bool branchfree = false; // 使用分支版本 - - if (!branchfree && e < ((uint32_t)1 << floor_log_2_d)) { - // 这个幂次有效 - more = (uint8_t)(floor_log_2_d - 1); - if (DEBUG) { - std::cout << "[SR] Using basic algorithm, shift = " << (int)more << std::endl; - } - } else { - // 我们需要上升一个等级 - proposed_m += proposed_m; - const uint32_t twice_rem = rem + rem; - if (twice_rem >= absD || twice_rem < rem) { - proposed_m += 1; - } - more = (uint8_t)(floor_log_2_d | LIBDIVIDE_ADD_MARKER); - if (DEBUG) { - std::cout << "[SR] Using add algorithm, proposed_m = " << proposed_m << ", more = " << (int)more << std::endl; - } - } - - proposed_m += 1; - int32_t magic = (int32_t)proposed_m; - - // 处理负除数 - if (d < 0) { - more |= LIBDIVIDE_NEGATIVE_DIVISOR; - if (!branchfree) { - magic = -magic; - } - if (DEBUG) { - std::cout << "[SR] Negative divisor, magic = " << magic << ", more = " << (int)more << std::endl; - } - } - - // 为我们的IR生成提取移位量和标志 - int shift = more & 0x3F; // 移除标志,保留移位量(位0-5) - bool need_add = (more & LIBDIVIDE_ADD_MARKER) != 0; - bool is_negative = (more & LIBDIVIDE_NEGATIVE_DIVISOR) != 0; - // 返回魔数、移位量,并在移位中编码ADD_MARKER标志 - // 我们将使用移位的第6位表示ADD_MARKER,第7位表示负数(如果需要) - int encoded_shift = shift; - if (need_add) { - encoded_shift |= 0x40; // 设置第6位表示ADD_MARKER - if (DEBUG) { - std::cout << "[SR] Encoding ADD_MARKER in shift: " << encoded_shift << std::endl; - } - } - - return {magic, encoded_shift}; - } - -Value* GlobalStrengthReductionContext::createMagicDivision(BinaryInst* divInst, uint32_t divisor, const MagicNumber& magic) { +Value* GlobalStrengthReductionContext::createMagicDivisionLibdivide(BinaryInst* divInst, int divisor) { builder->setPosition(divInst->getParent(), divInst->getParent()->findInstIterator(divInst)); + // 使用mulh指令优化任意常数除法 + auto [magic, shift] = SysYIROptUtils::computeMulhMagicNumbers(divisor); - Value* dividend = divInst->getLhs(); - - // 创建魔数常量 - Value* magicConst = getConstantInt(static_cast(magic.multiplier)); - - // 执行乘法: tmp = dividend * magic - Value* tmp = builder->createMulInst(dividend, magicConst); - - if (magic.needAdd) { - // 需要额外加法的情况 - Value* sum = builder->createAddInst(tmp, dividend); - if (magic.shift > 0) { - Value* shiftConst = getConstantInt(magic.shift); - tmp = builder->createBinaryInst(Instruction::kSra, Type::getIntType(), sum, shiftConst); - } else { - tmp = sum; - } - } else { - // 直接右移 - if (magic.shift > 0) { - Value* shiftConst = getConstantInt(magic.shift); - tmp = builder->createBinaryInst(Instruction::kSra, Type::getIntType(), tmp, shiftConst); - } - } - - // 处理符号:如果被除数为负,结果需要调整 - // 这里简化处理,假设都是正数除法 - - return tmp; -} - -Value* GlobalStrengthReductionContext::createMagicDivisionLibdivide(BinaryInst* divInst, int divisor, const std::pair& magicPair) { - builder->setPosition(divInst->getParent(), divInst->getParent()->findInstIterator(divInst)); - - Value* dividend = divInst->getLhs(); - int magic = magicPair.first; - int encoded_shift = magicPair.second; - - if (DEBUG) { - std::cout << "[SR] Creating libdivide magic division: magic=" << magic - << ", encoded_shift=" << encoded_shift << std::endl; - } - - // 检查是否为2的幂(magic=0表示是2的幂) - if (magic == 0) { - // 2的幂除法,直接使用算术右移 - int shift = encoded_shift & 0x3F; // 获取实际移位量 - bool is_negative = (encoded_shift & 0x80) != 0; - + // 检查是否无法优化(magic == -1, shift == -1 表示失败) + if (magic == -1 && shift == -1) { if (DEBUG) { - std::cout << "[SR] Power of 2 division: shift=" << shift - << ", negative=" << is_negative << std::endl; + std::cout << "[SR] Cannot optimize division by " << divisor + << ", keeping original division" << std::endl; } - - Value* result = dividend; - if (shift > 0) { - Value* shiftConst = getConstantInt(shift); - result = builder->createBinaryInst(Instruction::kSra, Type::getIntType(), dividend, shiftConst); - } - - // 如果原除数为负,需要取反结果 - if (is_negative) { - result = builder->createNegInst(result); - } - - return result; + // 返回 nullptr 表示无法优化,调用方应该保持原始除法 + return nullptr; } - // 非2的幂除法,使用魔数乘法 - int shift = encoded_shift & 0x3F; // 获取移位量(位0-5) - bool need_add = (encoded_shift & 0x40) != 0; // 检查ADD_MARKER标志(位6) - - if (DEBUG) { - std::cout << "[SR] Magic multiplication: shift=" << shift - << ", need_add=" << need_add << std::endl; + // 2的幂次方除法可以用移位优化(但这不是魔数法的情况)这种情况应该不会被分类到这里但是还是做一个保护措施 + if ((divisor & (divisor - 1)) == 0 && divisor > 0) { + // 是2的幂次方,可以用移位 + int shift_amount = 0; + int temp = divisor; + while (temp > 1) { + temp >>= 1; + shift_amount++; + } + + Value* shiftConstant = ConstantInteger::get(shift_amount); + // 对于有符号除法,需要先加上除数-1然后再移位(为了正确处理负数舍入) + Value* divisor_minus_1 = ConstantInteger::get(divisor - 1); + Value* adjusted = builder->createAddInst(divInst->getOperand(0), divisor_minus_1); + return builder->createBinaryInst( + Instruction::Kind::kSra, // 算术右移 + divInst->getOperand(0)->getType(), + adjusted, + shiftConstant + ); } // 创建魔数常量 - Value* magicConst = getConstantInt(magic); - - // 执行高位乘法:mulh(dividend, magic) - // 由于我们的IR可能没有直接的mulh指令,我们使用64位乘法然后取高32位 - // 这里需要根据实际的IR指令集进行调整 - Value* tmp = builder->createMulInst(dividend, magicConst); - - if (need_add) { - // ADD算法:(mulh(dividend, magic) + dividend) >> shift - tmp = builder->createAddInst(tmp, dividend); + // 检查魔数是否能放入32位,如果不能,则不进行优化 + if (magic > INT32_MAX || magic < INT32_MIN) { + if (DEBUG) { + std::cout << "[SR] Magic number " << magic << " exceeds 32-bit range, skipping optimization" << std::endl; + } + return nullptr; // 无法优化,保持原始除法 } - if (shift > 0) { - Value* shiftConst = getConstantInt(shift); - tmp = builder->createBinaryInst(Instruction::kSra, Type::getIntType(), tmp, shiftConst); + Value* magicConstant = ConstantInteger::get((int32_t)magic); + + // 检查是否需要ADD_MARKER处理(加法调整) + bool needAdd = (shift & 0x40) != 0; + int actualShift = shift & 0x3F; // 提取真实的移位量 + + if (DEBUG) { + std::cout << "[SR] IR Generation: magic=" << magic << ", needAdd=" << needAdd + << ", actualShift=" << actualShift << std::endl; } - // 处理符号位调整 - // 如果被除数为负数,可能需要额外的符号处理 - // 这里简化处理,实际实现可能需要更复杂的符号位处理 + // 执行高位乘法:mulh(x, magic) + Value* mulhResult = builder->createBinaryInst( + Instruction::Kind::kMulh, // 高位乘法 + divInst->getOperand(0)->getType(), + divInst->getOperand(0), + magicConstant + ); - return tmp; + if (needAdd) { + // ADD_MARKER 情况:需要在移位前加上被除数 + // 这对应于 libdivide 的加法调整算法 + if (DEBUG) { + std::cout << "[SR] Applying ADD_MARKER: adding dividend before shift" << std::endl; + } + mulhResult = builder->createAddInst(mulhResult, divInst->getOperand(0)); + } + + if (actualShift > 0) { + // 如果需要额外移位 + Value* shiftConstant = ConstantInteger::get(actualShift); + mulhResult = builder->createBinaryInst( + Instruction::Kind::kSra, // 算术右移 + divInst->getOperand(0)->getType(), + mulhResult, + shiftConstant + ); + } + + // 标准的有符号除法符号修正:如果被除数为负,商需要加1 + // 这对所有有符号除法都需要,不管是否可能有负数 + Value* isNegative = builder->createICmpLTInst(divInst->getOperand(0), ConstantInteger::get(0)); + // 将i1转换为i32:负数时为1,非负数时为0 ICmpLTInst的结果会默认转化为32位 + mulhResult = builder->createAddInst(mulhResult, isNegative); + + return mulhResult; } // ====================================================================== @@ -1010,7 +851,7 @@ bool GlobalStrengthReductionContext::hasOnlyLocalUses(Instruction* inst) { } void GlobalStrengthReductionContext::replaceWithOptimized(Instruction* original, Value* replacement) { - if (DEBUG >= 2) { + if (DEBUG) { std::cout << " Replacing " << original->getName() << " with " << replacement->getName() << std::endl; } diff --git a/src/midend/Pass/Optimize/LoopStrengthReduction.cpp b/src/midend/Pass/Optimize/LoopStrengthReduction.cpp index 973a053..0edbed4 100644 --- a/src/midend/Pass/Optimize/LoopStrengthReduction.cpp +++ b/src/midend/Pass/Optimize/LoopStrengthReduction.cpp @@ -106,187 +106,6 @@ bool StrengthReductionContext::analyzeInductionVariableRange( return hasNegativePotential; } -//该实现参考了libdivide的算法 -std::pair StrengthReductionContext::computeMulhMagicNumbers(int divisor) const { - - if (DEBUG) { - std::cout << "\n[SR] ===== Computing magic numbers for divisor " << divisor << " (libdivide algorithm) =====" << std::endl; - } - - if (divisor == 0) { - if (DEBUG) std::cout << "[SR] Error: divisor must be != 0" << std::endl; - return {-1, -1}; - } - - // libdivide 常数 - const uint8_t LIBDIVIDE_ADD_MARKER = 0x40; - const uint8_t LIBDIVIDE_NEGATIVE_DIVISOR = 0x80; - - // 辅助函数:计算前导零个数 - auto count_leading_zeros32 = [](uint32_t val) -> uint32_t { - if (val == 0) return 32; - return __builtin_clz(val); - }; - - // 辅助函数:64位除法返回32位商和余数 - auto div_64_32 = [](uint32_t high, uint32_t low, uint32_t divisor, uint32_t* rem) -> uint32_t { - uint64_t dividend = ((uint64_t)high << 32) | low; - uint32_t quotient = dividend / divisor; - *rem = dividend % divisor; - return quotient; - }; - - if (DEBUG) { - std::cout << "[SR] Input divisor: " << divisor << std::endl; - } - - // libdivide_internal_s32_gen 算法实现 - int32_t d = divisor; - uint32_t ud = (uint32_t)d; - uint32_t absD = (d < 0) ? -ud : ud; - - if (DEBUG) { - std::cout << "[SR] absD = " << absD << std::endl; - } - - uint32_t floor_log_2_d = 31 - count_leading_zeros32(absD); - - if (DEBUG) { - std::cout << "[SR] floor_log_2_d = " << floor_log_2_d << std::endl; - } - - // 检查 absD 是否为2的幂 - if ((absD & (absD - 1)) == 0) { - if (DEBUG) { - std::cout << "[SR] " << absD << " 是2的幂,使用移位方法" << std::endl; - } - - // 对于2的幂,我们只使用移位,不需要魔数 - int shift = floor_log_2_d; - if (d < 0) shift |= 0x80; // 标记负数 - - if (DEBUG) { - std::cout << "[SR] Power of 2 result: magic=0, shift=" << shift << std::endl; - std::cout << "[SR] ===== End magic computation =====" << std::endl; - } - - // 对于我们的目的,我们将在IR生成中以不同方式处理2的幂 - // 返回特殊标记 - return {0, shift}; - } - - if (DEBUG) { - std::cout << "[SR] " << absD << " is not a power of 2, computing magic number" << std::endl; - } - - // 非2的幂除数的魔数计算 - uint8_t more; - uint32_t rem, proposed_m; - - // 计算 proposed_m = floor(2^(floor_log_2_d + 31) / absD) - proposed_m = div_64_32((uint32_t)1 << (floor_log_2_d - 1), 0, absD, &rem); - const uint32_t e = absD - rem; - - if (DEBUG) { - std::cout << "[SR] proposed_m = " << proposed_m << ", rem = " << rem << ", e = " << e << std::endl; - } - - // 确定是否需要"加法"版本 - const bool branchfree = false; // 使用分支版本 - - if (!branchfree && e < ((uint32_t)1 << floor_log_2_d)) { - // 这个幂次有效 - more = (uint8_t)(floor_log_2_d - 1); - if (DEBUG) { - std::cout << "[SR] Using basic algorithm, shift = " << (int)more << std::endl; - } - } else { - // 我们需要上升一个等级 - proposed_m += proposed_m; - const uint32_t twice_rem = rem + rem; - if (twice_rem >= absD || twice_rem < rem) { - proposed_m += 1; - } - more = (uint8_t)(floor_log_2_d | LIBDIVIDE_ADD_MARKER); - if (DEBUG) { - std::cout << "[SR] Using add algorithm, proposed_m = " << proposed_m << ", more = " << (int)more << std::endl; - } - } - - proposed_m += 1; - int32_t magic = (int32_t)proposed_m; - - // 处理负除数 - if (d < 0) { - more |= LIBDIVIDE_NEGATIVE_DIVISOR; - if (!branchfree) { - magic = -magic; - } - if (DEBUG) { - std::cout << "[SR] Negative divisor, magic = " << magic << ", more = " << (int)more << std::endl; - } - } - - // 为我们的IR生成提取移位量和标志 - int shift = more & 0x3F; // 移除标志,保留移位量(位0-5) - bool need_add = (more & LIBDIVIDE_ADD_MARKER) != 0; - bool is_negative = (more & LIBDIVIDE_NEGATIVE_DIVISOR) != 0; - - if (DEBUG) { - std::cout << "[SR] Final result: magic = " << magic << ", more = " << (int)more - << " (0x" << std::hex << (int)more << std::dec << ")" << std::endl; - std::cout << "[SR] Shift = " << shift << ", need_add = " << need_add - << ", is_negative = " << is_negative << std::endl; - - // Test the magic number using the correct libdivide algorithm - std::cout << "[SR] Testing magic number (libdivide algorithm):" << std::endl; - int test_values[] = {1, 7, 37, 100, 999, -1, -7, -37, -100}; - - for (int test_val : test_values) { - int64_t quotient; - - // 实现正确的libdivide算法 - int64_t product = (int64_t)test_val * magic; - int64_t high_bits = product >> 32; - - if (need_add) { - // ADD_MARKER情况:移位前加上被除数 - // 这是libdivide的关键洞察! - high_bits += test_val; - quotient = high_bits >> shift; - } else { - // 正常情况:只是移位 - quotient = high_bits >> shift; - } - - // 符号修正:这是libdivide有符号除法的关键部分! - // 如果被除数为负,商需要加1来匹配C语言的截断除法语义 - if (test_val < 0) { - quotient += 1; - } - - int expected = test_val / divisor; - - bool correct = (quotient == expected); - std::cout << "[SR] " << test_val << " / " << divisor << " = " << quotient - << " (expected " << expected << ") " << (correct ? "✓" : "✗") << std::endl; - } - - std::cout << "[SR] ===== End magic computation =====" << std::endl; - } - - // 返回魔数、移位量,并在移位中编码ADD_MARKER标志 - // 我们将使用移位的第6位表示ADD_MARKER,第7位表示负数(如果需要) - int encoded_shift = shift; - if (need_add) { - encoded_shift |= 0x40; // 设置第6位表示ADD_MARKER - if (DEBUG) { - std::cout << "[SR] Encoding ADD_MARKER in shift: " << encoded_shift << std::endl; - } - } - - return {magic, encoded_shift}; -} bool LoopStrengthReduction::runOnFunction(Function* F, AnalysisManager& AM) { if (F->getBasicBlocks().empty()) { @@ -1018,7 +837,7 @@ Value* StrengthReductionContext::generateConstantDivisionReplacement( IRBuilder* builder ) const { // 使用mulh指令优化任意常数除法 - auto [magic, shift] = computeMulhMagicNumbers(candidate->multiplier); + auto [magic, shift] = SysYIROptUtils::computeMulhMagicNumbers(candidate->multiplier); // 检查是否无法优化(magic == -1, shift == -1 表示失败) if (magic == -1 && shift == -1) { diff --git a/src/midend/Pass/Pass.cpp b/src/midend/Pass/Pass.cpp index c834e30..aee29b0 100644 --- a/src/midend/Pass/Pass.cpp +++ b/src/midend/Pass/Pass.cpp @@ -78,6 +78,8 @@ void PassManager::runOptimizationPipeline(Module* moduleIR, IRBuilder* builderIR registerOptimizationPass(builderIR); registerOptimizationPass(builderIR); registerOptimizationPass(); + + registerOptimizationPass(builderIR); registerOptimizationPass(builderIR); registerOptimizationPass(builderIR); From ad74e435bad01cc633d5b05375959c54067c3dd7 Mon Sep 17 00:00:00 2001 From: rain2133 <1370973498@qq.com> Date: Mon, 18 Aug 2025 21:55:57 +0800 Subject: [PATCH 12/12] =?UTF-8?q?[midend-GSR]=E4=BF=AE=E5=A4=8D=E9=94=99?= =?UTF-8?q?=E8=AF=AF=E7=9A=84=E4=BB=A3=E6=95=B0=E7=AE=80=E5=8C=96?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../Pass/Optimize/GlobalStrengthReduction.cpp | 57 +++++++++++++------ 1 file changed, 41 insertions(+), 16 deletions(-) diff --git a/src/midend/Pass/Optimize/GlobalStrengthReduction.cpp b/src/midend/Pass/Optimize/GlobalStrengthReduction.cpp index e15a9c5..e8254a2 100644 --- a/src/midend/Pass/Optimize/GlobalStrengthReduction.cpp +++ b/src/midend/Pass/Optimize/GlobalStrengthReduction.cpp @@ -390,8 +390,8 @@ bool GlobalStrengthReductionContext::optimizeLogical(BinaryInst *inst) { return true; } - // x && 1 = x - if (isConstantInt(rhs, constVal) && constVal == 1) { + // x && -1 = x + if (isConstantInt(rhs, constVal) && constVal == -1) { if (DEBUG) { std::cout << " Algebraic: " << inst->getName() << " = x && 1 -> x" << std::endl; } @@ -416,15 +416,6 @@ bool GlobalStrengthReductionContext::optimizeLogical(BinaryInst *inst) { replaceWithOptimized(inst, lhs); return true; } - - // x || 1 = 1 - if (isConstantInt(rhs, constVal) && constVal == 1) { - if (DEBUG) { - std::cout << " Algebraic: " << inst->getName() << " = x || 1 -> 1" << std::endl; - } - replaceWithOptimized(inst, getConstantInt(1)); - return true; - } // x || x = x if (lhs == rhs) { @@ -630,16 +621,50 @@ bool GlobalStrengthReductionContext::reduceDivision(BinaryInst *inst) { // x / 2^n = x >> n (对于无符号除法或已知为正数的情况) if (isConstantInt(rhs, constVal) && constVal > 0 && isPowerOfTwo(constVal)) { + builder->setPosition(inst->getParent(), inst->getParent()->findInstIterator(inst)); int shiftAmount = log2OfPowerOfTwo(constVal); + // 有符号除法校正:(x + (x >> 31) & mask) >> k + int maskValue = constVal - 1; + + // x >> 31 (算术右移获取符号位) + Value* signShift = ConstantInteger::get(31); + Value* signBits = builder->createBinaryInst( + Instruction::Kind::kSra, // 算术右移 + lhs->getType(), + lhs, + signShift + ); + + // (x >> 31) & mask + Value* mask = ConstantInteger::get(maskValue); + Value* correction = builder->createBinaryInst( + Instruction::Kind::kAnd, + lhs->getType(), + signBits, + mask + ); + + // x + correction + Value* corrected = builder->createAddInst(lhs, correction); + + // (x + correction) >> k + Value* divShift = ConstantInteger::get(shiftAmount); + Value* shiftInst = builder->createBinaryInst( + Instruction::Kind::kSra, // 算术右移 + lhs->getType(), + corrected, + divShift + ); + if (DEBUG) { std::cout << " StrengthReduction: " << inst->getName() - << " = x / " << constVal << " -> x >> " << shiftAmount << std::endl; + << " = x / " << constVal << " -> (x + (x >> 31) & mask) >> " << shiftAmount << std::endl; } - builder->setPosition(inst->getParent(), inst->getParent()->findInstIterator(inst)); - Value* divisor_minus_1 = ConstantInteger::get(constVal - 1); - Value* adjusted = builder->createAddInst(lhs, divisor_minus_1); - Value* shiftInst = builder->createBinaryInst(Instruction::kSra, Type::getIntType(), adjusted, getConstantInt(shiftAmount)); + // builder->setPosition(inst->getParent(), inst->getParent()->findInstIterator(inst)); + // Value* divisor_minus_1 = ConstantInteger::get(constVal - 1); + // Value* adjusted = builder->createAddInst(lhs, divisor_minus_1); + // Value* shiftInst = builder->createBinaryInst(Instruction::kSra, Type::getIntType(), adjusted, getConstantInt(shiftAmount)); replaceWithOptimized(inst, shiftInst); strengthReductionCount++; return true;