From 550f4017bed3c68e8ef485d73f085d8c1823c9e7 Mon Sep 17 00:00:00 2001 From: rain2133 <1370973498@qq.com> Date: Mon, 21 Jul 2025 15:19:38 +0800 Subject: [PATCH] =?UTF-8?q?[midend]=E9=87=8D=E6=9E=84=E4=B8=AD=E7=AB=AF?= =?UTF-8?q?=EF=BC=8C=E5=BB=BA=E7=AB=8B=E9=81=8D=E7=AE=A1=E7=90=86=E5=99=A8?= =?UTF-8?q?=EF=BC=8C=E6=B3=A8=E5=86=8C=E5=99=A8=E7=AD=89=EF=BC=8C=E5=88=9D?= =?UTF-8?q?=E6=AD=A5=E6=9E=84=E5=BB=BA=E6=94=AF=E9=85=8D=E6=A0=91=E5=88=86?= =?UTF-8?q?=E6=9E=90=E9=81=8D=EF=BC=8C=E5=A2=9E=E5=8A=A0=E5=9F=BA=E6=9C=AC?= =?UTF-8?q?=E5=9D=97=E6=96=B9=E6=B3=95?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/Dom.cpp | 181 ++++++++++++++++++++++ src/include/Dom.h | 52 +++++++ src/include/IR.h | 19 ++- src/include/Pass.h | 284 +++++++++++++++++++++++++++++++++++ src/include/SysYIROptUtils.h | 3 + 5 files changed, 538 insertions(+), 1 deletion(-) create mode 100644 src/Dom.cpp create mode 100644 src/include/Dom.h create mode 100644 src/include/Pass.h diff --git a/src/Dom.cpp b/src/Dom.cpp new file mode 100644 index 0000000..beb1d65 --- /dev/null +++ b/src/Dom.cpp @@ -0,0 +1,181 @@ +#include "Dom.h" +#include // for std::numeric_limits +#include + +namespace sysy { + +// 初始化 支配树静态 ID +char DominatorTreeAnalysisPass::ID = 0; + +// ============================================================== +// DominatorTree 结果类的实现 +// ============================================================== + +DominatorTree::DominatorTree(Function *F) : AssociatedFunction(F) { + // 构造时可以不计算,在分析遍运行里计算并填充 +} + +const std::set *DominatorTree::getDominators(BasicBlock *BB) const { + auto it = Dominators.find(BB); + if (it != Dominators.end()) { + return &(it->second); + } + return nullptr; +} + +BasicBlock *DominatorTree::getImmediateDominator(BasicBlock *BB) const { + auto it = IDoms.find(BB); + if (it != IDoms.end()) { + return it->second; + } + return nullptr; +} + +const std::set *DominatorTree::getDominanceFrontier(BasicBlock *BB) const { + auto it = DominanceFrontiers.find(BB); + if (it != DominanceFrontiers.end()) { + return &(it->second); + } + return nullptr; +} + +void DominatorTree::computeDominators(Function *F) { + // 经典的迭代算法计算支配者集合 + // TODO: 可以替换为更高效的算法,如 Lengauer-Tarjan 算法 + BasicBlock *entryBlock = F->getEntryBlock(); + + for (const auto &bb_ptr : F->getBasicBlocks()) { + BasicBlock *bb = bb_ptr.get(); + if (bb == entryBlock) { + Dominators[bb].insert(bb); + } else { + for (const auto &all_bb_ptr : F->getBasicBlocks()) { + Dominators[bb].insert(all_bb_ptr.get()); + } + } + } + + bool changed = true; + while (changed) { + changed = false; + for (const auto &bb_ptr : F->getBasicBlocks()) { + BasicBlock *bb = bb_ptr.get(); + if (bb == entryBlock) + continue; + + std::set newDom; + bool firstPred = true; + for (BasicBlock *pred : bb->getPredecessors()) { + if (Dominators.count(pred)) { + if (firstPred) { + newDom = Dominators[pred]; + firstPred = false; + } else { + std::set intersection; + std::set_intersection(newDom.begin(), newDom.end(), Dominators[pred].begin(), Dominators[pred].end(), + std::inserter(intersection, intersection.begin())); + newDom = intersection; + } + } + } + newDom.insert(bb); + + if (newDom != Dominators[bb]) { + Dominators[bb] = newDom; + changed = true; + } + } + } +} + +void DominatorTree::computeIDoms(Function *F) { + // 采用与之前类似的简化实现。TODO:Lengauer-Tarjan等算法。 + BasicBlock *entryBlock = F->getEntryBlock(); + IDoms[entryBlock] = nullptr; + + for (const auto &bb_ptr : F->getBasicBlocks()) { + BasicBlock *bb = bb_ptr.get(); + if (bb == entryBlock) + continue; + + BasicBlock *currentIDom = nullptr; + const std::set *domsOfBB = getDominators(bb); + if (!domsOfBB) + continue; + + for (BasicBlock *D : *domsOfBB) { + if (D == bb) + continue; + + bool isCandidateIDom = true; + for (BasicBlock *candidate : *domsOfBB) { + if (candidate == bb || candidate == D) + continue; + const std::set *domsOfCandidate = getDominators(candidate); + if (domsOfCandidate && domsOfCandidate->count(D) == 0 && domsOfBB->count(candidate)) { + isCandidateIDom = false; + break; + } + } + if (isCandidateIDom) { + currentIDom = D; + break; + } + } + IDoms[bb] = currentIDom; + } +} + +void DominatorTree::computeDominanceFrontiers(Function *F) { + // 经典的支配边界计算算法 + for (const auto &bb_ptr_X : F->getBasicBlocks()) { + BasicBlock *X = bb_ptr_X.get(); + DominanceFrontiers[X].clear(); + + for (BasicBlock *Y : X->getSuccessors()) { + const std::set *domsOfY = getDominators(Y); + if (domsOfY && domsOfY->find(X) == domsOfY->end()) { + DominanceFrontiers[X].insert(Y); + } + } + + const std::set *domsOfX = getDominators(X); + if (!domsOfX) + continue; + for (const auto &bb_ptr_Z : F->getBasicBlocks()) { + BasicBlock *Z = bb_ptr_Z.get(); + if (Z == X) + continue; + const std::set *domsOfZ = getDominators(Z); + if (domsOfZ && domsOfZ->count(X) && Z != X) { + + for (BasicBlock *Y : Z->getSuccessors()) { + const std::set *domsOfY = getDominators(Y); + if (domsOfY && domsOfY->find(X) == domsOfY->end()) { + DominanceFrontiers[X].insert(Y); + } + } + } + } + } +} + +// ============================================================== +// DominatorTreeAnalysisPass 的实现 +// ============================================================== + + +bool DominatorTreeAnalysisPass::runOnFunction(Function* F) { + CurrentDominatorTree = std::make_unique(F); + CurrentDominatorTree->computeDominators(F); + CurrentDominatorTree->computeIDoms(F); + CurrentDominatorTree->computeDominanceFrontiers(F); + return false; +} + +std::unique_ptr DominatorTreeAnalysisPass::getResult() { + // 返回计算好的 DominatorTree 实例,所有权转移给 AnalysisManager + return std::move(CurrentDominatorTree); +} + +} // namespace sysy \ No newline at end of file diff --git a/src/include/Dom.h b/src/include/Dom.h new file mode 100644 index 0000000..6b38f83 --- /dev/null +++ b/src/include/Dom.h @@ -0,0 +1,52 @@ +#pragma once + +#include "Pass.h" // 包含 Pass 框架 +#include "IR.h" // 包含 IR 定义 +#include +#include +#include +#include + +namespace sysy { + +// 支配树分析结果类 (保持不变) +class DominatorTree : public AnalysisResultBase { +public: + DominatorTree(Function* F); + const std::set* getDominators(BasicBlock* BB) const; + BasicBlock* getImmediateDominator(BasicBlock* BB) const; + const std::set* getDominanceFrontier(BasicBlock* BB) const; + const std::map>& getDominatorsMap() const { return Dominators; } + const std::map& getIDomsMap() const { return IDoms; } + const std::map>& getDominanceFrontiersMap() const { return DominanceFrontiers; } + void computeDominators(Function* F); + void computeIDoms(Function* F); + void computeDominanceFrontiers(Function* F); +private: + Function* AssociatedFunction; + std::map> Dominators; + std::map IDoms; + std::map> DominanceFrontiers; +}; + + +// 支配树分析遍 +class DominatorTreeAnalysisPass : public AnalysisPass { +public: + // 唯一的 Pass ID + static char ID; // LLVM 风格的唯一 ID + + DominatorTreeAnalysisPass() : AnalysisPass("DominatorTreeAnalysis", Pass::Granularity::Function) {} + + // 实现 getPassID + void* getPassID() const override { return &ID; } + + bool runOnFunction(Function* F) override; + + std::unique_ptr getResult() override; + +private: + std::unique_ptr CurrentDominatorTree; +}; + +} // namespace sysy \ No newline at end of file diff --git a/src/include/IR.h b/src/include/IR.h index 6e35715..2abe3e1 100644 --- a/src/include/IR.h +++ b/src/include/IR.h @@ -522,12 +522,20 @@ public: void setParent(Function *func) { parent = func; } inst_list& getInstructions() { return instructions; } arg_list& getArguments() { return arguments; } - const block_list& getPredecessors() const { return predecessors; } + block_list& getPredecessors() { return predecessors; } + void clearPredecessors() { predecessors.clear(); } block_list& getSuccessors() { return successors; } + void clearSuccessors() { successors.clear(); } iterator begin() { return instructions.begin(); } iterator end() { return instructions.end(); } iterator terminator() { return std::prev(end()); } void insertArgument(AllocaInst *inst) { arguments.push_back(inst); } + bool hasSuccessor(BasicBlock *block) const { + return std::find(successors.begin(), successors.end(), block) != successors.end(); + } ///< 判断是否有后继块 + bool hasPredecessor(BasicBlock *block) const { + return std::find(predecessors.begin(), predecessors.end(), block) != predecessors.end(); + } ///< 判断是否有前驱块 void addPredecessor(BasicBlock *block) { if (std::find(predecessors.begin(), predecessors.end(), block) == predecessors.end()) { predecessors.push_back(block); @@ -580,6 +588,15 @@ public: next->addPredecessor(prev); } void removeInst(iterator pos) { instructions.erase(pos); } + void removeInst(Instruction *inst) { + auto pos = std::find_if(instructions.begin(), instructions.end(), + [inst](const std::unique_ptr &i) { return i.get() == inst; }); + if (pos != instructions.end()) { + instructions.erase(pos); + } else { + assert(false && "Instruction not found in BasicBlock"); + } + } ///< 移除指定位置的指令 iterator moveInst(iterator sourcePos, iterator targetPos, BasicBlock *block); }; diff --git a/src/include/Pass.h b/src/include/Pass.h new file mode 100644 index 0000000..063462f --- /dev/null +++ b/src/include/Pass.h @@ -0,0 +1,284 @@ +#pragma once + +#include // For std::function +#include +#include +#include +#include +#include // For std::type_index (although void* ID is more common in LLVM) +#include + +namespace sysy { + +// 抽象基类:分析结果 +class AnalysisResultBase { +public: + virtual ~AnalysisResultBase() = default; +}; + +// 抽象基类:Pass +class Pass { +public: + enum class Granularity { Module, Function, BasicBlock }; + + enum class PassKind { Analysis, Optimization }; + + Pass(const std::string &name, Granularity g, PassKind k) : Name(name), G(g), K(k) {} + virtual ~Pass() = default; + + const std::string &getName() const { return Name; } + Granularity getGranularity() const { return G; } + PassKind getPassKind() const { return K; } + + virtual bool runOnModule(Module *M, AnalysisManager& AM) { return false; } + virtual bool runOnFunction(Function *F, AnalysisManager& AM) { return false; } + virtual bool runOnBasicBlock(BasicBlock *BB, AnalysisManager& AM) { return false; } + + // 所有 Pass 都必须提供一个唯一的 ID + // 这通常是一个静态成员,并在 Pass 类外部定义 + virtual void *getPassID() const = 0; + +protected: + std::string Name; + Granularity G; + PassKind K; +}; + +// 抽象基类:分析遍 +class AnalysisPass : public Pass { +public: + AnalysisPass(const std::string &name, Granularity g) : Pass(name, g, PassKind::Analysis) {} + + virtual std::unique_ptr getResult() = 0; +}; + +// 抽象基类:优化遍 +class OptimizationPass : public Pass { +public: + OptimizationPass(const std::string &name, Granularity g) : Pass(name, g, PassKind::Optimization) {} + + virtual void getAnalysisUsage(std::set &analysisDependencies, std::set &analysisInvalidations) const { + // 默认不依赖也不修改任何分析 + } +}; + +// ====================================================================== +// PassRegistry: 全局 Pass 注册表 (单例) +// ====================================================================== +class PassRegistry { +public: + // Pass 工厂函数类型:返回 Pass 的唯一指针 + using PassFactory = std::function()>; + + // 获取 PassRegistry 实例 (单例模式) + static PassRegistry &getPassRegistry() { + static PassRegistry instance; + return instance; + } + + // 注册一个 Pass + // passID 是 Pass 类的唯一静态 ID (例如 MyPass::ID 的地址) + // factory 是一个 lambda 或函数指针,用于创建该 Pass 的实例 + void registerPass(void *passID, PassFactory factory) { + if (factories.count(passID)) { + // Error: Pass with this ID already registered + // You might want to throw an exception or log an error + return; + } + factories[passID] = std::move(factory); + } + + // 通过 Pass ID 创建一个 Pass 实例 + std::unique_ptr createPass(void *passID) { + auto it = factories.find(passID); + if (it == factories.end()) { + // Error: Pass with this ID not registered + return nullptr; + } + return it->second(); // 调用工厂函数创建实例 + } + +private: + PassRegistry() = default; // 私有构造函数,实现单例 + ~PassRegistry() = default; + PassRegistry(const PassRegistry &) = delete; // 禁用拷贝构造 + PassRegistry &operator=(const PassRegistry &) = delete; // 禁用赋值操作 + + std::map factories; +}; + +// ====================================================================== +// AnalysisManager: 负责管理和提供分析结果 +// ====================================================================== +class AnalysisManager { +public: + AnalysisManager() = default; + ~AnalysisManager() = default; + + // 获取分析结果 + // T 是 AnalysisResult 的具体类型,E 是 AnalysisPass 的具体类型 + // PassManager 应该在运行 Pass 之前调用 registerAnalysisPass + template T *getAnalysisResult(Function *F) { // 针对函数级别的分析,需要传入 Function* + void *analysisID = E::ID; // 获取分析遍的唯一 ID + + // 检查是否已存在有效结果 + auto it = cachedResults.find({F, analysisID}); + if (it != cachedResults.end()) { + return static_cast(it->second.get()); // 返回缓存结果 + } + + // 如果没有缓存结果,通过 PassRegistry 创建分析遍并运行它 + // 注意:这里需要 PassRegistry 实例。如果 AnalysisManager 独立于 PassManager, + // 则需要传入 PassRegistry 引用或指针。 + // 为了简化,假设 AnalysisManager 能够访问到 PassRegistry + std::unique_ptr basePass = PassRegistry::getPassRegistry().createPass(analysisID); + if (!basePass) { + // Error: Analysis pass not registered + return nullptr; + } + + AnalysisPass *analysisPass = static_cast(basePass.get()); + + // 确保分析遍的粒度与请求的上下文匹配 + if (analysisPass->getGranularity() == Pass::Granularity::Function) { + analysisPass->runOnFunction(F); // 运行分析遍 + // 获取结果并缓存 + std::unique_ptr result = analysisPass->getResult(); + T *specificResult = static_cast(result.get()); + cachedResults[{F, analysisID}] = std::move(result); // 缓存结果 + return specificResult; + } + // TODO: 处理 Module 或 BasicBlock 粒度的分析 + + return nullptr; + } + + // 使所有或特定分析结果失效 (当 IR 被修改时调用) + void invalidateAllAnalyses() { cachedResults.clear(); } + + // 使特定分析结果失效 + void invalidateAnalysis(void *analysisID, Function *F = nullptr) { + if (F) { + // 使特定函数的特定分析结果失效 + cachedResults.erase({F, analysisID}); + } else { + // 使所有函数的特定分析结果失效 + std::map, std::unique_ptr> newCachedResults; + for (auto &pair : cachedResults) { + if (pair.first.second != analysisID) { + newCachedResults.insert(std::move(pair)); + } + } + cachedResults = std::move(newCachedResults); + } + } + +private: + std::map, std::unique_ptr> cachedResults; +}; + +// ====================================================================== +// PassManager:遍管理器 +// ====================================================================== +class PassManager { + + Module *pmodule; + AnalysisManager &AM; // 引用 AnalysisManager,用于获取分析结果 + +public: + PassManager() = default; + ~PassManager() = default; + + // 添加遍:现在接受 Pass 的 ID,而不是直接的 unique_ptr + void addPass(void *passID) { + PassRegistry ®istry = PassRegistry::getPassRegistry(); + std::unique_ptr P = registry.createPass(passID); + if (!P) { + // Error: Pass not found or failed to create + return; + } + + passes.push_back(std::move(P)); + } + + // 运行所有注册的遍 + bool run(Module *M) { + bool changed = false; + for (const auto &p : passes) { + bool passChanged = false; // 记录当前遍是否修改了 IR + + // 处理优化遍的分析依赖和失效 + if (p->getPassKind() == Pass::PassKind::Optimization) { + OptimizationPass *optPass = static_cast(p.get()); + std::set analysisDependencies; + std::set analysisInvalidations; + optPass->getAnalysisUsage(analysisDependencies, analysisInvalidations); + + // PassManager 不显式运行分析依赖。 + // 而是优化遍在 runOnFunction 内部通过 AnalysisManager.getAnalysisResult 按需请求。 + } + + if (p->getGranularity() == Pass::Granularity::Module) { + passChanged = p->runOnModule(M, AM); + } else if (p->getGranularity() == Pass::Granularity::Function) { + for (auto &funcPair : M->getFunctions()) { + Function *F = funcPair.second.get(); + passChanged = p->runOnFunction(F, AM) || passChanged; + + if (passChanged && p->getPassKind() == Pass::PassKind::Optimization) { + OptimizationPass *optPass = static_cast(p.get()); + std::set analysisDependencies; + std::set analysisInvalidations; + optPass->getAnalysisUsage(analysisDependencies, analysisInvalidations); + for (void *invalidationID : analysisInvalidations) { + analysisManager.invalidateAnalysis(invalidationID, F); + } + } + } + } else if (p->getGranularity() == Pass::Granularity::BasicBlock) { + for (auto &funcPair : M->getFunctions()) { + Function *F = funcPair.second.get(); + for (auto &bbPtr : funcPair.second->getBasicBlocks()) { + passChanged = p->runOnBasicBlock(bbPtr.get(), AM) || passChanged; + + if (passChanged && p->getPassKind() == Pass::PassKind::Optimization) { + OptimizationPass *optPass = static_cast(p.get()); + std::set analysisDependencies; + std::set analysisInvalidations; + optPass->getAnalysisUsage(analysisDependencies, analysisInvalidations); + for (void *invalidationID : analysisInvalidations) { + analysisManager.invalidateAnalysis(invalidationID, F); + } + } + } + } + } + changed = changed || passChanged; + } + return changed; + } + + AnalysisManager &getAnalysisManager() { return analysisManager; } + +private: + std::vector> passes; + AnalysisManager analysisManager; +}; + +// ====================================================================== +// 辅助宏或函数,用于简化 Pass 的注册 +// ====================================================================== + +// 用于分析遍的注册 +template void registerAnalysisPass() { + PassRegistry::getPassRegistry().registerPass(&AnalysisPassType::ID, + []() { return std::make_unique(); }); +} + +// 用于优化遍的注册 +template void registerOptimizationPass() { + PassRegistry::getPassRegistry().registerPass(&OptimizationPassType::ID, + []() { return std::make_unique(); }); +} + +} // namespace sysy \ No newline at end of file diff --git a/src/include/SysYIROptUtils.h b/src/include/SysYIROptUtils.h index d2d2e55..66929d1 100644 --- a/src/include/SysYIROptUtils.h +++ b/src/include/SysYIROptUtils.h @@ -11,11 +11,14 @@ class SysYIROptUtils{ public: // 删除use关系 + // 根据指令的使用情况删除其所有的use关系 + // 找到指令的所有使用者,并从它们的使用列表中删除该指令 static void usedelete(Instruction *instr) { for (auto &use : instr->getOperands()) { Value* val = use->getValue(); val->removeUse(use); } + instr->getParent()->removeInst(instr); // 从基本块中删除指令 } // 判断是否是全局变量