Compare commits

..

14 Commits

Author SHA1 Message Date
042b1a5d99 [midend-tco]修复命名重复问题 2025-08-19 00:13:32 +08:00
937833117e [midend-tco]添加TCO尾递归优化 2025-08-18 23:46:00 +08:00
ad74e435ba [midend-GSR]修复错误的代数简化 2025-08-18 21:55:57 +08:00
5c34cbc7b8 [midend-GSR]将魔数求解移动到utils的静态方法中。 2025-08-18 20:37:20 +08:00
c9a0c700e1 [midend]增加全局强度削弱优化遍 2025-08-18 11:30:40 +08:00
f317010d76 [midend-Loop-LICM][fix]检查load能否外提时其内存地址在循环中是否会被修改,需要判断函数调用对load内存地址的影响。 2025-08-17 17:42:19 +08:00
8ca64610eb [midend-GVN]重构GVN的值编号系统 2025-08-17 16:33:15 +08:00
969a78a088 [midend-GVN]segmentation fault是GVN引入的已修复,LICM仍然有错误 2025-08-17 14:37:27 +08:00
8763c0a11a [midend-LICM][fix]修改计算循环不变量依赖关系的排序错误,但是引入了很多Segmentation fault。 2025-08-17 01:35:03 +08:00
d83dc7a2e7 [midend-LICM][fix]修复循环不变量的识别逻辑 2025-08-17 01:19:44 +08:00
e32585fd25 [midend-GVN]修复GVN中部分逻辑问题,LICM有bug待修复 2025-08-17 00:14:47 +08:00
c4eb1c3980 [midend-GVN&SideEffect]修复GVN的部分问题和副作用分析的缺陷 2025-08-16 18:52:29 +08:00
d038884ffb [midend-GVN] commit头文件 2025-08-16 15:43:51 +08:00
467f2f6b24 [midend-GVN]初步构建GVN,能够优化部分CSE无法处理的子表达式但是有错误需要debug。 2025-08-16 15:38:41 +08:00
36 changed files with 2696 additions and 2053 deletions

View File

@ -1,272 +0,0 @@
# 编译器核心技术与优化详解
本文档深入剖析 mysysy 编译器的内部实现,重点阐述其在前端、中端和后端所采用的核心编译技术及优化算法,并结合具体实现函数进行说明。
## 1. 编译器整体架构
本编译器采用经典的三段式架构将编译过程清晰地划分为前端、中端和后端三个主要部分。每个部分处理不同的抽象层级并通过定义良好的接口AST, IR进行通信实现了高度的模块化。
```mermaid
graph TD
A[源代码 .sy] --> B{前端 Frontend};
B --> C[抽象语法树 AST];
C --> D{中端 Midend};
D --> E[SSA-based IR];
E -- 优化 --> F[优化后的 IR];
F --> G{后端 Backend};
G --> H[目标机代码 MachineInstr];
H --> I[RISC-V 64 汇编代码 .s];
subgraph 前端
B
end
subgraph 中端
D
end
subgraph 后端
G
end
```
- **前端 (Frontend)**:负责词法、语法、语义分析,将 SysY 源代码解析为抽象语法树 (AST)。
- **中端 (Midend)**:基于 AST 生成与具体机器无关的中间表示 (IR),并在此基础上进行深入的分析和优化。
- **后端 (Backend)**:将优化后的 IR 翻译成目标平台RISC-V 64的汇编代码。
---
## 2. 前端技术 (Frontend)
前端的核心任务是进行语法和语义的分析与验证,其工作流程如下:
```mermaid
graph TD
subgraph "前端处理流程"
Source["源文件 (.sy)"] --> Lexer["词法分析器 (SysYLexer)"];
Lexer --> TokenStream["Token 流"];
TokenStream --> Parser["语法分析器 (SysYParser)"];
Parser --> ParseTree["解析树"];
ParseTree --> Visitor["AST构建 (SysYVisitor)"];
Visitor --> AST[抽象语法树];
end
```
- **词法与语法分析**:
- **技术**: 采用 **ANTLR (ANother Tool for Language Recognition)** 框架。通过在 `frontend/SysY.g4` 文件中定义的上下文无关文法ANTLR 能够自动生成高效的 LL(*) 词法分析器 (`SysYLexer.cpp`) 和语法分析器 (`SysYParser.cpp`)。
- **实现**: 词法分析器将字符流转换为记号 (Token) 流,语法分析器则根据文法规则将记号流组织成一棵解析树 (Parse Tree)。这棵树精确地反映了源代码的语法结构。
- **AST 构建**:
- **技术**: 应用 **访问者 (Visitor) 设计模式** 遍历 ANTLR 生成的解析树。该模式将数据结构解析树与作用于其上的操作AST构建逻辑解耦。
- **实现**: `frontend/SysYVisitor.cpp` 中定义了具体的遍历逻辑。在遍历过程中,会构建一个比解析树更抽象、更面向编译需求的**抽象语法树 (Abstract Syntax Tree, AST)**。AST 忽略了纯粹的语法细节(如括号、分号),只保留了核心的语义结构,是前端传递给中端的接口。
---
## 3. 中端技术与优化 (Midend)
中端是编译器的核心,所有与目标机器无关的分析和优化都在此阶段完成。
### 3.1. 中间表示 (IR) 及设计要点
- **技术**: 设计了一种三地址码Three-Address Code风格的中间表示其形式和设计哲学深受 **LLVM IR** 的启发。IR 的核心特征是采用了**静态单赋值 (Static Single Assignment, SSA)** 形式。
- **实现**: `midend/IR.cpp` 定义了 IR 的核心数据结构,如 `Instruction`, `BasicBlock`, `Function``Module``midend/SysYIRGenerator.cpp` 负责将前端的 AST 转换为这种 IR。在 SSA 形式下,每个变量只被赋值一次,使得变量的定义-使用关系Def-Use Chain变得异常清晰极大地简化了后续的优化算法。通过继承并重写 SysYBaseVisitor 类,遍历 AST 节点生成自定义 IR并在 IR 生成阶段实现了简单的常量传播和公共子表达式消除CSE
- **设计要点**
- **`alloca` 指令集中管理**
所有 `alloca` 指令统一放置在入口基本块,并与实际计算指令分离。这有助于后续指令调度器专注于优化计算密集型指令的执行顺序,避免内存分配指令的干扰。
- **消除 `fallthrough` 现象**
通过确保所有基本块均以终结指令结尾,消除基本块间的 `fallthrough`简化了控制流图CFG的构建和分析。这一做法提升了编译器整体质量使中端各类 Pass 的编写和维护更加规范和高效。
### 3.2. 核心优化详解
编译器的分析和优化被组织成一系列独立的“遍”Pass。每个 Pass 都是一个独立的算法模块,对 IR 进行特定的分析或变换。这种设计具有高度的模块化和可扩展性。
#### 3.2.1. SSA 构建与解构
- **Mem2Reg (`Mem2Reg.cpp`)**:
- **目标**: 将对栈内存 (`alloca`) 的 `load`/`store` 操作,提升为对虚拟寄存器的直接操作,并构建 SSA 形式。
- **技术**: 该过程是实现 SSA 的关键。它依赖于**支配树 (Dominator Tree)** 分析,通过寻找变量定义块的**支配边界 (Dominance Frontier)** 来确定在何处插入 **Φ (Phi) 函数**
- **实现**: `Mem2RegContext::run` 驱动此过程。首先调用 `isPromotableAlloca` 识别所有仅被 `load`/`store` 使用的标量 `alloca`。然后,`insertPhis` 根据支配边界信息在必要的控制流汇合点插入 `phi` 指令。最后,`renameVariables` 递归地遍历支配树,用一个模拟的值栈来将 `load` 替换为栈顶的 SSA 值,将 `store` 视为对栈的一次 `push` 操作从而完成重命名。值得一提的是由于我们在IR生成阶段就将所有alloca指令统一放置在入口块极大地简化了Mem2Reg遍的实现和支配树分析的计算。
- **Reg2Mem (`Reg2Mem.cpp`)**:
- **目标**: 执行 `Mem2Reg` 的逆操作,将程序从 SSA 形式转换回基于内存的表示。这通常是为不支持 SSA 的后端做准备的**SSA解构 (SSA Destruction)** 步骤。
- **技术**: 为每个 SSA 值(指令结果、函数参数)在函数入口创建一个 `alloca` 栈槽。然后,在每个 SSA 值的定义点之后插入一个 `store` 将其存入对应的栈槽;在每个使用点之前插入一个 `load` 从栈槽中取出值。
- **实现**: `Reg2MemContext::run` 驱动此过程。`allocateMemoryForSSAValues` 为所有需要转换的 SSA 值创建 `alloca` 指令。`rewritePhis` 特殊处理 `phi` 指令,在每个前驱块的末尾插入 `store``insertLoadsAndStores` 则处理所有非 `phi` 指令的定义和使用,插入相应的 `store``load`。虽然
#### 3.2.2. 常量与死代码优化
- **SCCP (`SCCP.cpp`)**:
- **目标**: 稀疏条件常量传播。在编译期计算常量表达式,并利用分支条件为常数的信息来消除死代码,比简单的常量传播更强大。
- **技术**: 这是一种基于数据流分析的格理论Lattice Theory的优化。它为每个变量维护一个值状态可能为 `Top` (未定义), `Constant` (某个常量值), 或 `Bottom` (非常量)。同时,它跟踪基本块的可达性,如果一个分支的条件被推断为常量,则其不可达的后继分支在分析中会被直接忽略。
- **实现**: `SCCPContext::run` 驱动整个分析过程。它维护一个指令工作列表和一个边工作列表。`ProcessInstruction``ProcessEdge` 函数交替执行,不断地从 IR 中传播常量和可达性信息,直到达到不动点为止。最后,`PropagateConstants``SimplifyControlFlow` 将推断出的常量替换到代码中,并移除死块。
- **DCE (`DCE.cpp`)**:
- **目标**: 简单死代码消除。移除那些计算结果对程序输出没有贡献的指令。
- **技术**: 采用**标记-清除 (Mark and Sweep)** 算法。从具有副作用的指令(如 `store`, `call`, `return`)开始,反向追溯其操作数,标记所有相关的指令为“活跃”。
- **实现**: `DCEContext::run` 实现了此算法。第一次遍历时,通过 `isAlive` 函数识别出具有副作用的“根”指令,然后调用 `addAlive` 递归地将所有依赖的指令加入 `alive_insts` 集合。第二次遍历时,所有未被标记为活跃的指令都将被删除。
- **未来规划**: 后续开发更多分析遍会为DCE收集更多的IR信息能够迭代出更健壮的DEC遍。
#### 3.2.3. 控制流图 (CFG) 优化
- **实现**: `SysYIRCFGOpt.cpp` 中定义了一系列用于清理和简化控制流图的 Pass。
- **`SysYDelInstAfterBrPass`**: 删除分支指令后的死代码。
- **`SysYDelNoPreBLockPass`**: 通过从入口块开始的图遍历BFS识别并删除所有不可达的基本块。
- **`SysYDelEmptyBlockPass`**: 识别并删除仅包含一条无条件跳转指令的空块,将其前驱直接重定向到其后继。
- **`SysYBlockMergePass`**: 如果一个块 A 只有一个后继 B且 B 只有一个前驱 A则将 A 和 B 合并为一个块。
- **`SysYCondBr2BrPass`**: 如果一个条件分支的条件是常量,则将其转换为一个无条件分支。
- **`SysYAddReturnPass`**: 确保所有没有终结指令的函数出口路径都有一个 `return` 指令,以保证 CFG 的完整性。
#### 3.2.4. 其他优化
- **LargeArrayToGlobal (`LargeArrayToGlobal.cpp`)**:
- **目标**: 防止因大型局部数组导致的栈溢出,并可能改善数据局部性。
- **技术**: 遍历函数中的 `alloca` 指令,如果通过 `calculateTypeSize` 计算出其分配的内存大小超过一个阈值(如 1024 字节),则将其转换为一个全局变量。
- **实现**: `convertAllocaToGlobal` 函数负责创建一个新的 `GlobalValue`,并调用 `replaceAllUsesWith` 将原 `alloca` 的所有使用者重定向到新的全局变量,最后删除原 `alloca` 指令。
#### 3.3. 核心分析遍
为了为优化遍收集信息,最大程度发掘程序优化潜力,我们目前设计并实现了以下关键的分析遍:
- **支配树分析 (Dominator Tree Analysis)**:
- **技术**: 通过计算每个基本块的支配节点,构建出一棵支配树结构。我们在计算支配节点时采用了**逆后序遍历RPO, Reverse Post Order**以保证数据流分析的收敛速度和正确性。在计算直接支配者Idom, Immediate Dominator采用了经典的**Lengauer-TarjanLT算法**,该算法以高效的并查集和路径压缩技术著称,能够在线性时间内准确计算出每个基本块的直接支配者关系。
- **实现**: `Dom.cpp` 实现了支配树分析。该分析为每个基本块分配其直接支配者,并递归构建整棵支配树。支配树是许多高级优化(尤其是 SSA 形式下的优化的基础。例如Mem2Reg 需要依赖支配树来正确插入 Phi 指令,并在变量重命名阶段高效遍历控制流图。此外,循环相关优化(如循环不变量外提)也依赖于支配树信息来识别循环头和循环体的关系。
- **活跃性分析 (Liveness Analysis)**:
- **技术**: 活跃性分析用于确定在程序的某一特定点上,哪些变量的值在未来会被用到。我们采用**经典的不动点迭代算法**,在数据流分析框架下,逆序遍历基本块,迭代计算每个基本块的 `live-in``live-out` 集合,直到收敛为止。这种方法简单且易于实现,能够满足大多数编译优化的需求。
- **未来规划**: 若后续对分析效率有更高要求,可考虑引入如**工作列表算法**或者**转化为基于SSA的图可达性分析**等更高效的算法,以进一步提升大型函数或复杂控制流下的分析性能。
- **实现**: `Liveness.cpp` 提供了活跃性分析。该分析采用经典的数据流分析框架,迭代计算每个基本块的 `live-in``live-out` 集合。活跃性信息是死代码消除DCE、寄存器分配等优化的必要前置步骤。通过准确的活跃性分析可以识别出无用的变量和指令从而为后续优化遍提供坚实的数据基础。
### 3.4. 未来的规划
基于现有的成果,我们规划将中端能力进一步扩展,近期我们重点将放在循环相关的分析和函数内联的实现,以期大幅提升最终程序的性能。
- **循环优化**:
我们正在开发一个健壮的分析遍来准确识别程序中的循环结构,并通过对已识别的循环进行规范化的转换遍,为后续的向量化、并行化工作做铺垫。并通过循环不变量提升、循环归纳变量分析与强度削减等优化提升循环相关代码的执行效率。
- **函数内联**:
函数内联能够将简单函数可能需要收集更多信息内联到call指令相应位置减少栈空间相关变动并且为其他遍发掘优化空间。
- **`LLVM IR`格式化**:
我们将为所有的IR设计并实现通用的打印器方法使得IR能够显式化为可编译运行的LLVM IR通过编排脚本和调用llvm相关工具链我们能够绕过后端编译运行中间代码为验证中端正确性提供系统化的方法同时减轻后端开发bug溯源的压力。
---
## 4. 后端技术与优化 (Backend)
后端负责将经过优化的、与机器无关的 IR 转换为针对 RISC-V 64 位架构的汇编代码。
### 4.1. 栈帧布局 (Stack Frame Layout)
在函数调用发生时,后端需要在栈上创建一个**栈帧 (Stack Frame)** 来存储局部变量、传递参数和保存寄存器。本编译器采用的栈帧布局遵循 RISC-V 调用约定,结构如下:
```
高地址 +-----------------------------+
| ... |
| 函数参数 (8+) | <-- 调用者传入的、放不进寄存器的参数
+-----------------------------+
| 返回地址 (ra) | <-- sp 在函数入口指向的位置
+-----------------------------+
| 旧的帧指针 (s0/fp) |
+-----------------------------+ <-- s0/fp 在函数序言后指向的位置
| 被调用者保存的寄存器 |
| (Callee-Saved Regs) |
+-----------------------------+
| 局部变量 (Alloca) |
+-----------------------------+
| 寄存器溢出区域 |
| (Spill Slots) |
+-----------------------------+
| 为调用其他函数预留的 |
| 出参空间 (Out-Args) |
低地址 +-----------------------------+ <-- sp 在函数序言后指向的位置
```
- **实现**: `PrologueEpilogueInsertion.h``EliminateFrameIndices.h` 中的 Pass 负责生成函数序言prologue和尾声epilogue代码来构建和销毁上述栈帧。`EliminateFrameIndices` 会将所有对抽象栈槽(如局部变量、溢出槽)的访问,替换为对帧指针 `s0` 或栈指针 `sp` 的、带有具体偏移量的访问。
### 4.2. 指令选择 (Instruction Selection)
- **目标**: 将抽象的 IR 指令高效地翻译成具体的目标机指令序列。
- **技术**: 采用 **基于 DAG (Directed Acyclic Graph) 的模式匹配** 算法。
- **实现**: `RISCv64ISel.cpp` 中的 `RISCv64ISel::select()` 驱动此过程。`selectBasicBlock()` 为每个基本块调用 `build_dag()` 来构建一个操作的 DAG然后通过 `select_recursive()` 对 DAG 进行自底向上的遍历和匹配。在 `selectNode()` 函数中,通过一个大的 `switch` 语句,为不同类型的 DAG 节点(如 `BINARY`, `LOAD`, `STORE`)匹配最优的指令序列。例如,一个 IR 的加法指令,如果其中一个操作数是小常数,会被直接匹配为一条 `ADDIW` 指令,而不是 `LI``ADDW` 两条指令。
### 4.3. 寄存器分配 (Register Allocation)
- **目标**: 将无限的虚拟寄存器映射到有限的物理寄存器上,并优雅地处理寄存器不足(溢出)的情况。
- **技术**: 实现了经典的**基于图着色 (Graph Coloring) 的全局寄存器分配算法**,这是一种强大但复杂的全局优化方法。
- **实现**: `RISCv64RegAlloc.cpp` 中的 `RISCv64RegAlloc::run()` 是主入口。它在一个循环中执行分配,直到没有寄存器需要溢出为止。其内部流程极其精密,如下图所示:
```mermaid
graph TD
subgraph "寄存器分配主循环 (RISCv64RegAlloc::run)"
direction LR
Start((Start)) --> Liveness[1. 活跃性分析 LivenessAnalysis]
Liveness --> Build[2. 构建冲突图 Build]
Build --> Worklist[3. 创建工作表 MakeWorklist]
Worklist --> Loop{Main Loop}
Loop -- simplifyWorklist 非空 --> Simplify[4a. 简化 Simplify]
Simplify --> Loop
Loop -- worklistMoves 非空 --> Coalesce[4b. 合并 Coalesce]
Coalesce --> Loop
Loop -- freezeWorklist 非空 --> Freeze[4c. 冻结 Freeze]
Freeze --> Loop
Loop -- spillWorklist 非空 --> Spill[4d. 选择溢出 SelectSpill]
Spill --> Loop
Loop -- 所有工作表为空 --> Assign[5. 分配颜色 AssignColors]
Assign --> CheckSpill{有溢出?}
CheckSpill -- Yes --> Rewrite[6. 重写代码 RewriteProgram]
Rewrite --> Liveness
CheckSpill -- No --> Finish((Finish))
end
```
1. **`analyzeLiveness()`**: 对机器指令进行数据流分析,计算出每个虚拟寄存器的活跃范围。
2. **`build()`**: 根据活跃性信息构建**冲突图 (Interference Graph)**。如果两个虚拟寄存器同时活跃,则它们冲突,在图中连接一条边。
3. **`makeWorklist()`**: 将图节点(虚拟寄存器)根据其度数放入不同的工作列表,为着色做准备。
4. **核心着色阶段 (The Loop)**:
- **`simplify()`**: 贪心地移除图中度数小于物理寄存器数量的节点,并将其压入栈中。这些节点保证可以被成功着色。
- **`coalesce()`**: 尝试将传送指令 (`MV`) 的源和目标节点合并,以消除这条指令。合并的条件基于 **Briggs****George** 启发式,以避免使图变得不可着色。
- **`freeze()`**: 当一个与传送指令相关的节点无法合并也无法简化时,放弃对该传送指令的合并希望,将其“冻结”为一个普通节点。
- **`selectSpill()`**: 当所有节点都无法进行上述操作时(即图中只剩下高度数的节点),必须选择一个节点进行**溢出 (Spill)**,即决定将其存放在内存中。
5. **`assignColors()`**: 在所有节点都被处理后,从栈中依次弹出节点,并根据其已着色邻居的颜色,为它选择一个可用的物理寄存器。
6. **`rewriteProgram()`**: 如果 `assignColors()` 阶段发现有节点被标记为溢出,此函数会被调用。它会修改机器指令,为溢出的虚拟寄存器插入从内存加载(`lw`/`ld`)和存入内存(`sw`/`sd`)的代码。然后,整个分配过程从步骤 1 重新开始。
### 4.4. 后端特定优化
在寄存器分配前后后端还会进行一系列针对目标机RISC-V特性的优化。
#### 4.4.1. 指令调度 (Instruction Scheduling)
- **寄存器分配前调度 (`PreRA_Scheduler.cpp`)**:
- **目标**: 在寄存器分配前,通过重排指令来提升性能。主要目标是**隐藏加载延迟 (Load Latency)**,即尽早发出 `load` 指令,使其结果能在需要时及时准备好,避免流水线停顿。同时,由于此时使用的是无限的虚拟寄存器,调度器有较大的自由度,但也可能因为过度重排而延长虚拟寄存器的生命周期,从而增加寄存器压力。
- **实现**: `scheduleBlock()` 函数会识别出基本块内的调度边界(如 `call` 或终结指令),然后在每个独立的区域内调用 `scheduleRegion()`。当前的实现是一种简化的列表调度,它会优先尝试将加载指令 (`LW`, `LD` 等) 在不违反数据依赖的前提下,尽可能地向前移动。
- **寄存器分配后调度 (`PostRA_Scheduler.cpp`)**:
- **目标**: 在寄存器分配完成之后,对指令序列进行最后一轮微调。此阶段调度的主要目标与分配前不同,它旨在解决由寄存器分配过程本身引入的性能问题,例如:
- **缓解溢出代价**: 将因溢出Spill而产生的 `load` 指令(从栈加载)尽可能地提前,远离其使用点;将 `store` 指令(存入栈)尽可能地推后,远离其定义点。
- **消除伪依赖**: 寄存器分配器可能会为两个原本不相关的虚拟寄存器分配同一个物理寄存器从而引入了虚假的写后读WAR或写后写WAW依赖。Post-RA 调度可以尝试解开这些伪依赖,为指令重排提供更多自由度。
- **实现**: `scheduleBlock()` 函数实现了此调度器。它采用了一种非常保守的**局部交换 (Local Swapping)** 策略。它迭代地检查相邻的两条指令,在 `canSwapInstructions()` 函数确认交换不会违反任何数据依赖RAW, WAR, WAW或内存依赖后才执行交换。这种方法虽然不如全局列表调度强大但在严格的 Post-RA 约束下是一种安全有效的优化手段。
#### 4.4.2. 强度削减 (Strength Reduction)
- **除法强度削减 (`DivStrengthReduction.cpp`)**:
- **目标**: 将机器指令中昂贵的 `DIV``DIVW` 指令(当除数为编译期常量时)替换为一系列更快、计算成本更低的指令组合。
- **技术**: 基于数论中的**乘法逆元 (Multiplicative Inverse)** 思想。对于一个整数除法 `x / d`,可以找到一个“魔数” `m` 和一个移位数 `s`,使得该除法可以被近似替换为 `(x * m) >> s`。这个过程需要处理复杂的符号、取整和溢出问题。
- **实现**: `runOnMachineFunction()` 实现了此优化。它会遍历机器指令,寻找以常量为除数的 `DIV`/`DIVW` 指令。`computeMagic()` 函数负责计算出对应的魔数和移位数。然后,根据除数是 2 的幂、1、-1 还是其他普通数字,生成不同的指令序列,包括 `MULH` (取高位乘积), `SRAI` (算术右移), `ADD`, `SUB` 等,来精确地模拟定点数除法的效果。
#### 4.4.3. 窥孔优化 (Peephole Optimization)
- **目标**: 在生成最终汇编代码之前,对相邻的机器指令序列进行局部优化,以消除冗余操作和利用目标机特性。
- **技术**: 窥孔优化是一种简单而高效的局部优化技术。它通过一个固定大小的“窥孔”(通常是 2-3 条指令)来扫描指令序列,寻找可以被更优指令序列替换的模式。
- **实现**: `PeepholeOptimizer::runOnMachineFunction()` 实现了此 Pass。它包含了一系列模式匹配和替换规则主要包括
- **冗余移动消除**: `mv x, y` 后跟着一条使用 `x` 的指令 `op z, x, ...`,如果 `x` 之后不再活跃,则将 `op` 的操作数直接替换为 `y`,并移除 `mv` 指令。
- **冗余加载消除**: `sw r1, mem; lw r2, mem` -> `sw r1, mem; mv r2, r1`。如果 `r1``r2` 是同一个寄存器,则直接移除 `lw`
- **地址计算优化**: `addi t1, base, imm1; lw t2, imm2(t1)` -> `lw t2, (imm1+imm2)(base)`。将两条指令合并为一条,减少了指令数量和中间寄存器的使用。
- **指令合并**: `addi t1, t0, imm1; addi t2, t1, imm2` -> `addi t2, t0, (imm1+imm2)`。合并连续的立即数加法。
### 4.5. 局限性与未来工作
根据项目中的 `TODO` 列表和源代码分析,当前实现存在一些可改进之处:
- **寄存器分配**:
- **`CALL` 指令处理**: 当前对 `CALL` 指令的 `use`/`def` 分析不完整,没有将所有调用者保存的寄存器标记为 `def`,这可能导致跨函数调用的值被错误破坏。
- **溢出处理**: 当前所有溢出的虚拟寄存器都被简单地映射到同一个物理寄存器 `t6` 上,这会引入大量不必要的 `load`/`store`,并可能导致 `t6` 成为性能瓶颈。
- **IR 设计**:
- 随着 SSA 的引入IR 中某些冗余信息(如基本块的 `args` 参数)可以被移除,以简化设计。
- **优化**:
- 当前的优化主要集中在标量上。可以引入更多面向循环的优化(如循环不变代码外提 LICM、归纳变量分析 IndVar和过程间优化来进一步提升性能。

View File

@ -5,8 +5,6 @@ add_library(riscv64_backend_lib STATIC
RISCv64ISel.cpp
RISCv64LLIR.cpp
RISCv64RegAlloc.cpp
RISCv64LinearScan.cpp
RISCv64BasicBlockAlloc.cpp
Handler/CalleeSavedHandler.cpp
Handler/LegalizeImmediates.cpp
Handler/PrologueEpilogueInsertion.cpp

View File

@ -4,7 +4,6 @@
namespace sysy {
char PeepholeOptimizer::ID = 0;
bool PeepholeOptimizer::fusedMulAddEnabled = true; // 默认启用浮点乘加融合优化
bool PeepholeOptimizer::runOnFunction(Function *F, AnalysisManager& AM) {
// This pass works on MachineFunction level, not IR level
@ -635,96 +634,19 @@ void PeepholeOptimizer::runOnMachineFunction(MachineFunction *mfunc) {
}
}
}
// 8. 浮点乘加融合优化
// 8.1 fmul.s t1, t2, t3; fadd.s t4, t1, t5 -> fmadd.s t4, t2, t3, t5
else if (isFusedMulAddEnabled() &&
mi1->getOpcode() == RVOpcodes::FMUL_S &&
mi2->getOpcode() == RVOpcodes::FADD_S) {
if (mi1->getOperands().size() == 3 && mi2->getOperands().size() == 3) {
auto *fmul_dst = static_cast<RegOperand *>(mi1->getOperands()[0].get());
auto *fmul_src1 = static_cast<RegOperand *>(mi1->getOperands()[1].get());
auto *fmul_src2 = static_cast<RegOperand *>(mi1->getOperands()[2].get());
auto *fadd_dst = static_cast<RegOperand *>(mi2->getOperands()[0].get());
auto *fadd_src1 = static_cast<RegOperand *>(mi2->getOperands()[1].get());
auto *fadd_src2 = static_cast<RegOperand *>(mi2->getOperands()[2].get());
// 检查fmul的目标是否是fadd的第一个源操作数
if (areRegsEqual(fmul_dst, fadd_src1)) {
// 检查中间寄存器是否在后续还会被使用
bool canOptimize = true;
for (size_t j = i + 2; j < instrs.size(); ++j) {
auto *later_instr = instrs[j].get();
// 如果中间寄存器被重新定义,则可以优化
if (isRegRedefinedAt(later_instr, fmul_dst, areRegsEqual)) {
break;
}
// 如果中间寄存器被使用,则不能优化
if (isRegUsedLater(instrs, fmul_dst, j)) {
canOptimize = false;
break;
}
}
if (canOptimize) {
// 创建新的FMADD_S指令: fmadd.s t4, t2, t3, t5
auto newInstr = std::make_unique<MachineInstr>(RVOpcodes::FMADD_S);
newInstr->addOperand(std::make_unique<RegOperand>(*fadd_dst));
newInstr->addOperand(std::make_unique<RegOperand>(*fmul_src1));
newInstr->addOperand(std::make_unique<RegOperand>(*fmul_src2));
newInstr->addOperand(std::make_unique<RegOperand>(*fadd_src2));
instrs[i + 1] = std::move(newInstr);
instrs.erase(instrs.begin() + i);
changed = true;
}
}
}
}
// 8.2 fmul.s t1, t2, t3; fadd.s t4, t5, t1 -> fmadd.s t4, t2, t3, t5
else if (isFusedMulAddEnabled() &&
mi1->getOpcode() == RVOpcodes::FMUL_S &&
mi2->getOpcode() == RVOpcodes::FADD_S) {
if (mi1->getOperands().size() == 3 && mi2->getOperands().size() == 3) {
auto *fmul_dst = static_cast<RegOperand *>(mi1->getOperands()[0].get());
auto *fmul_src1 = static_cast<RegOperand *>(mi1->getOperands()[1].get());
auto *fmul_src2 = static_cast<RegOperand *>(mi1->getOperands()[2].get());
auto *fadd_dst = static_cast<RegOperand *>(mi2->getOperands()[0].get());
auto *fadd_src1 = static_cast<RegOperand *>(mi2->getOperands()[1].get());
auto *fadd_src2 = static_cast<RegOperand *>(mi2->getOperands()[2].get());
// 检查fmul的目标是否是fadd的第二个源操作数
if (areRegsEqual(fmul_dst, fadd_src2)) {
// 检查中间寄存器是否在后续还会被使用
bool canOptimize = true;
for (size_t j = i + 2; j < instrs.size(); ++j) {
auto *later_instr = instrs[j].get();
// 如果中间寄存器被重新定义,则可以优化
if (isRegRedefinedAt(later_instr, fmul_dst, areRegsEqual)) {
break;
}
// 如果中间寄存器被使用,则不能优化
if (isRegUsedLater(instrs, fmul_dst, j)) {
canOptimize = false;
break;
}
}
if (canOptimize) {
// 创建新的FMADD_S指令: fmadd.s t4, t2, t3, t5
auto newInstr = std::make_unique<MachineInstr>(RVOpcodes::FMADD_S);
newInstr->addOperand(std::make_unique<RegOperand>(*fadd_dst));
newInstr->addOperand(std::make_unique<RegOperand>(*fmul_src1));
newInstr->addOperand(std::make_unique<RegOperand>(*fmul_src2));
newInstr->addOperand(std::make_unique<RegOperand>(*fadd_src1));
instrs[i + 1] = std::move(newInstr);
instrs.erase(instrs.begin() + i);
changed = true;
}
// 8. 消除无用移动指令: mv a, a -> (删除)
else if (mi1->getOpcode() == RVOpcodes::MV &&
mi1->getOperands().size() == 2) {
if (mi1->getOperands()[0]->getKind() == MachineOperand::KIND_REG &&
mi1->getOperands()[1]->getKind() == MachineOperand::KIND_REG) {
auto *dst = static_cast<RegOperand *>(mi1->getOperands()[0].get());
auto *src = static_cast<RegOperand *>(mi1->getOperands()[1].get());
// 检查源和目标寄存器是否相同
if (areRegsEqual(dst, src)) {
// 删除这条无用指令
instrs.erase(instrs.begin() + i);
changed = true;
}
}
}

View File

@ -5,6 +5,23 @@
#include <iostream>
namespace sysy {
// 检查是否为内存加载/存储指令,以处理特殊的打印格式
bool isMemoryOp(RVOpcodes opcode) {
switch (opcode) {
// --- 整数加载/存储 (原有逻辑) ---
case RVOpcodes::LB: case RVOpcodes::LH: case RVOpcodes::LW: case RVOpcodes::LD:
case RVOpcodes::LBU: case RVOpcodes::LHU: case RVOpcodes::LWU:
case RVOpcodes::SB: case RVOpcodes::SH: case RVOpcodes::SW: case RVOpcodes::SD:
case RVOpcodes::FLW:
case RVOpcodes::FSW:
// 如果未来支持双精度也在这里添加FLD/FSD
// case RVOpcodes::FLD:
// case RVOpcodes::FSD:
return true;
default:
return false;
}
}
RISCv64AsmPrinter::RISCv64AsmPrinter(MachineFunction* mfunc) : MFunc(mfunc) {}
@ -65,7 +82,7 @@ void RISCv64AsmPrinter::printInstruction(MachineInstr* instr, bool debug) {
case RVOpcodes::SB: *OS << "sb "; break; case RVOpcodes::LD: *OS << "ld "; break;
case RVOpcodes::SD: *OS << "sd "; break; case RVOpcodes::FLW: *OS << "flw "; break;
case RVOpcodes::FSW: *OS << "fsw "; break; case RVOpcodes::FLD: *OS << "fld "; break;
case RVOpcodes::FSD: *OS << "fsd "; break;
case RVOpcodes::FSD: *OS << "fsd "; break;
case RVOpcodes::J: *OS << "j "; break; case RVOpcodes::JAL: *OS << "jal "; break;
case RVOpcodes::JALR: *OS << "jalr "; break; case RVOpcodes::RET: *OS << "ret"; break;
case RVOpcodes::BEQ: *OS << "beq "; break; case RVOpcodes::BNE: *OS << "bne "; break;
@ -79,18 +96,15 @@ void RISCv64AsmPrinter::printInstruction(MachineInstr* instr, bool debug) {
case RVOpcodes::FSUB_S: *OS << "fsub.s "; break;
case RVOpcodes::FMUL_S: *OS << "fmul.s "; break;
case RVOpcodes::FDIV_S: *OS << "fdiv.s "; break;
case RVOpcodes::FMADD_S: *OS << "fmadd.s "; break;
case RVOpcodes::FNEG_S: *OS << "fneg.s "; break;
case RVOpcodes::FEQ_S: *OS << "feq.s "; break;
case RVOpcodes::FLT_S: *OS << "flt.s "; break;
case RVOpcodes::FLE_S: *OS << "fle.s "; break;
case RVOpcodes::FCVT_S_W: *OS << "fcvt.s.w "; break;
case RVOpcodes::FCVT_W_S: *OS << "fcvt.w.s "; break;
case RVOpcodes::FCVT_W_S_RTZ: *OS << "fcvt.w.s "; break;
case RVOpcodes::FMV_S: *OS << "fmv.s "; break;
case RVOpcodes::FMV_W_X: *OS << "fmv.w.x "; break;
case RVOpcodes::FMV_X_W: *OS << "fmv.x.w "; break;
case RVOpcodes::FSRMI: *OS << "fsrmi "; break;
case RVOpcodes::CALL: { // 为CALL指令添加特殊处理逻辑
*OS << "call ";
// 遍历所有操作数,只寻找并打印函数名标签

View File

@ -1,14 +1,10 @@
#include "RISCv64Backend.h"
#include "RISCv64ISel.h"
#include "RISCv64RegAlloc.h"
#include "RISCv64LinearScan.h"
#include "RISCv64BasicBlockAlloc.h"
#include "RISCv64AsmPrinter.h"
#include "RISCv64Passes.h"
#include <sstream>
#include <future>
#include <chrono>
#include <iostream>
namespace sysy {
// 顶层入口
@ -200,13 +196,19 @@ std::string RISCv64CodeGen::function_gen(Function* func) {
// === 完整的后端处理流水线 ===
// 阶段 1: 指令选择 (sysy::IR -> LLIR with virtual registers)
DEBUG = 0;
DEEPDEBUG = 0;
RISCv64ISel isel;
std::unique_ptr<MachineFunction> mfunc = isel.runOnFunction(func);
// 第一次调试打印输出
std::stringstream ss_after_isel;
RISCv64AsmPrinter printer_isel(mfunc.get());
printer_isel.run(ss_after_isel, true);
if (DEBUG) {
std::cout << ss_after_isel.str();
}
if (DEBUG) {
std::cerr << "====== Intermediate Representation after Instruction Selection ======\n"
<< ss_after_isel.str();
@ -226,78 +228,17 @@ std::string RISCv64CodeGen::function_gen(Function* func) {
<< ss_after_eli.str();
}
// // 阶段 2: 除法强度削弱优化 (Division Strength Reduction)
// DivStrengthReduction div_strength_reduction;
// div_strength_reduction.runOnMachineFunction(mfunc.get());
// 阶段 2: 除法强度削弱优化 (Division Strength Reduction)
DivStrengthReduction div_strength_reduction;
div_strength_reduction.runOnMachineFunction(mfunc.get());
// // 阶段 2.1: 指令调度 (Instruction Scheduling)
// PreRA_Scheduler scheduler;
// scheduler.runOnMachineFunction(mfunc.get());
// 阶段 2.1: 指令调度 (Instruction Scheduling)
PreRA_Scheduler scheduler;
scheduler.runOnMachineFunction(mfunc.get());
// 阶段 3: 物理寄存器分配 (Register Allocation)
// 首先尝试图着色分配器
if (DEBUG) std::cerr << "Attempting Register Allocation with Graph Coloring...\n";
if (!gc_failed) {
RISCv64RegAlloc gc_alloc(mfunc.get());
bool success_gc = gc_alloc.run();
if (!success_gc) {
gc_failed = 1; // 后续不再尝试图着色分配器
std::cerr << "Warning: Graph coloring register allocation failed function '"
<< func->getName()
<< "'. Switching to Linear Scan allocator."
<< std::endl;
RISCv64ISel isel_gc_fallback;
mfunc = isel_gc_fallback.runOnFunction(func);
EliminateFrameIndicesPass efi_pass_gc_fallback;
efi_pass_gc_fallback.runOnMachineFunction(mfunc.get());
RISCv64LinearScan ls_alloc(mfunc.get());
bool success = ls_alloc.run();
if (!success) {
// 如果线性扫描最终失败,则调用基本块分配器作为终极后备
std::cerr << "Info: Linear Scan failed. Switching to Basic Block Allocator as final fallback.\n";
// 注意我们需要在一个“干净”的MachineFunction上运行。
// 最安全的方式是重新运行指令选择。
RISCv64ISel isel_fallback;
mfunc = isel_fallback.runOnFunction(func);
EliminateFrameIndicesPass efi_pass_fallback;
efi_pass_fallback.runOnMachineFunction(mfunc.get());
if (DEBUG) {
std::cerr << "====== stack info after reg alloc ======\n";
}
RISCv64BasicBlockAlloc bb_alloc(mfunc.get());
bb_alloc.run();
}
} else {
// 图着色成功完成
if (DEBUG) std::cerr << "Graph Coloring allocation completed successfully.\n";
}
} else {
std::cerr << "Info: Graph Coloring allocation failed in last function. Switching to Linear Scan allocator...\n";
RISCv64LinearScan ls_alloc(mfunc.get());
bool success = ls_alloc.run();
if (!success) {
// 如果线性扫描最终失败,则调用基本块分配器作为终极后备
std::cerr << "Info: Linear Scan failed. Switching to Basic Block Allocator as final fallback.\n";
// 注意我们需要在一个“干净”的MachineFunction上运行。
// 最安全的方式是重新运行指令选择。
RISCv64ISel isel_fallback;
mfunc = isel_fallback.runOnFunction(func);
EliminateFrameIndicesPass efi_pass_fallback;
efi_pass_fallback.runOnMachineFunction(mfunc.get());
if (DEBUG) {
std::cerr << "====== stack info after reg alloc ======\n";
}
RISCv64BasicBlockAlloc bb_alloc(mfunc.get());
bb_alloc.run();
}
}
RISCv64RegAlloc reg_alloc(mfunc.get());
reg_alloc.run();
if (DEBUG) {
std::cerr << "====== stack info after reg alloc ======\n";
@ -335,6 +276,7 @@ std::string RISCv64CodeGen::function_gen(Function* func) {
printer.run(ss);
return ss.str();
}
} // namespace sysy

View File

@ -1,267 +0,0 @@
#include "RISCv64BasicBlockAlloc.h"
#include "RISCv64Info.h"
#include "RISCv64AsmPrinter.h"
#include <iostream>
#include <algorithm>
// 外部调试级别控制变量
extern int DEBUG;
extern int DEEPDEBUG;
namespace sysy {
// 将 getInstrUseDef 的定义移到这里,因为它是一个全局的辅助函数
void getInstrUseDef(const MachineInstr* instr, std::set<unsigned>& use, std::set<unsigned>& def) {
auto opcode = instr->getOpcode();
const auto& operands = instr->getOperands();
auto get_vreg_id_if_virtual = [&](const MachineOperand* op, std::set<unsigned>& s) {
if (op->getKind() == MachineOperand::KIND_REG) {
auto reg_op = static_cast<const RegOperand*>(op);
if (reg_op->isVirtual()) s.insert(reg_op->getVRegNum());
} else if (op->getKind() == MachineOperand::KIND_MEM) {
auto mem_op = static_cast<const MemOperand*>(op);
auto reg_op = mem_op->getBase();
if (reg_op->isVirtual()) s.insert(reg_op->getVRegNum());
}
};
if (op_info.count(opcode)) {
const auto& info = op_info.at(opcode);
for (int idx : info.first) if (idx < operands.size()) get_vreg_id_if_virtual(operands[idx].get(), def);
for (int idx : info.second) if (idx < operands.size()) get_vreg_id_if_virtual(operands[idx].get(), use);
// 内存操作数的基址寄存器总是use
for (const auto& op : operands) if (op->getKind() == MachineOperand::KIND_MEM) get_vreg_id_if_virtual(op.get(), use);
} else if (opcode == RVOpcodes::CALL) {
if (!operands.empty() && operands[0]->getKind() == MachineOperand::KIND_REG) get_vreg_id_if_virtual(operands[0].get(), def);
for (size_t i = 1; i < operands.size(); ++i) if (operands[i]->getKind() == MachineOperand::KIND_REG) get_vreg_id_if_virtual(operands[i].get(), use);
}
}
RISCv64BasicBlockAlloc::RISCv64BasicBlockAlloc(MachineFunction* mfunc)
: MFunc(mfunc), ISel(mfunc->getISel()) {
// 初始化临时寄存器池
int_temps = {PhysicalReg::T0, PhysicalReg::T1, PhysicalReg::T2, PhysicalReg::T3, PhysicalReg::T6};
fp_temps = {PhysicalReg::F0, PhysicalReg::F1, PhysicalReg::F2, PhysicalReg::F3, PhysicalReg::F4};
int_temp_idx = 0;
fp_temp_idx = 0;
// 构建ABI寄存器映射
if (MFunc->getFunc()) {
int int_arg_idx = 0;
int fp_arg_idx = 0;
for (Argument* arg : MFunc->getFunc()->getArguments()) {
unsigned arg_vreg = ISel->getVReg(arg);
if (arg->getType()->isFloat()) {
if (fp_arg_idx < 8) {
auto preg = static_cast<PhysicalReg>(static_cast<int>(PhysicalReg::F10) + fp_arg_idx++);
abi_vreg_map[arg_vreg] = preg;
}
} else {
if (int_arg_idx < 8) {
auto preg = static_cast<PhysicalReg>(static_cast<int>(PhysicalReg::A0) + int_arg_idx++);
abi_vreg_map[arg_vreg] = preg;
}
}
}
}
}
void RISCv64BasicBlockAlloc::run() {
if (DEBUG) std::cerr << "===== [BB-Alloc] Running Stateful Greedy Allocator for function: " << MFunc->getName() << " =====\n";
computeLiveness();
assignStackSlotsForAllVRegs();
for (auto& mbb : MFunc->getBlocks()) {
processBasicBlock(mbb.get());
}
// 将ABI寄存器映射如函数参数合并到最终结果中
MFunc->getFrameInfo().vreg_to_preg_map.insert(this->abi_vreg_map.begin(), this->abi_vreg_map.end());
}
PhysicalReg RISCv64BasicBlockAlloc::getNextIntTemp() {
PhysicalReg reg = int_temps[int_temp_idx];
int_temp_idx = (int_temp_idx + 1) % int_temps.size();
return reg;
}
PhysicalReg RISCv64BasicBlockAlloc::getNextFpTemp() {
PhysicalReg reg = fp_temps[fp_temp_idx];
fp_temp_idx = (fp_temp_idx + 1) % fp_temps.size();
return reg;
}
void RISCv64BasicBlockAlloc::computeLiveness() {
// 这是一个必需的步骤,用于确定在块末尾哪些变量需要被写回栈
// 为保持聚焦,此处暂时留空,但请确保您有一个有效的活性分析来填充 live_out 映射
}
void RISCv64BasicBlockAlloc::assignStackSlotsForAllVRegs() {
if (DEBUG) std::cerr << "[BB-Alloc] Assigning stack slots for all vregs.\n";
StackFrameInfo& frame_info = MFunc->getFrameInfo();
int current_offset = frame_info.locals_end_offset;
const auto& vreg_type_map = ISel->getVRegTypeMap();
for (unsigned vreg = 1; vreg < ISel->getVRegCounter(); ++vreg) {
if (this->abi_vreg_map.count(vreg) || frame_info.alloca_offsets.count(vreg) || frame_info.spill_offsets.count(vreg)) {
continue;
}
Type* type = vreg_type_map.count(vreg) ? vreg_type_map.at(vreg) : Type::getIntType();
int size = type->isPointer() ? 8 : 4;
current_offset -= size;
current_offset &= -size; // 按size对齐
frame_info.spill_offsets[vreg] = current_offset;
}
frame_info.spill_size = -(current_offset - frame_info.locals_end_offset);
}
void RISCv64BasicBlockAlloc::processBasicBlock(MachineBasicBlock* mbb) {
if (DEEPDEBUG) std::cerr << " [BB-Alloc] Processing block " << mbb->getName() << "\n";
vreg_to_preg.clear();
preg_to_vreg.clear();
dirty_pregs.clear();
auto& instrs = mbb->getInstructions();
std::vector<std::unique_ptr<MachineInstr>> new_instrs;
const auto& vreg_type_map = ISel->getVRegTypeMap();
for (auto& instr_ptr : instrs) {
std::set<unsigned> use_vregs, def_vregs;
getInstrUseDef(instr_ptr.get(), use_vregs, def_vregs);
std::map<unsigned, PhysicalReg> current_instr_map;
// 1. 确保所有use操作数都在物理寄存器中
for (unsigned vreg : use_vregs) {
current_instr_map[vreg] = ensureInReg(vreg, new_instrs);
}
// 2. 为所有def操作数分配物理寄存器
for (unsigned vreg : def_vregs) {
current_instr_map[vreg] = allocReg(vreg, new_instrs);
}
// 3. 重写指令将vreg替换为preg
for (const auto& pair : current_instr_map) {
instr_ptr->replaceVRegWithPReg(pair.first, pair.second);
}
new_instrs.push_back(std::move(instr_ptr));
}
// 4. 在块末尾,写回所有被修改过的且在后续块中活跃(live-out)的vreg
StackFrameInfo& frame_info = MFunc->getFrameInfo(); // **修正获取frame_info引用**
const auto& lo = live_out[mbb];
for(auto const& [preg, vreg] : preg_to_vreg) {
// **修正:简化逻辑,在此保底分配器中总是写回脏寄存器**
if (dirty_pregs.count(preg)) {
if (!frame_info.spill_offsets.count(vreg)) continue;
Type* type = vreg_type_map.at(vreg);
RVOpcodes store_op = type->isFloat() ? RVOpcodes::FSW : (type->isPointer() ? RVOpcodes::SD : RVOpcodes::SW);
auto store = std::make_unique<MachineInstr>(store_op);
store->addOperand(std::make_unique<RegOperand>(preg));
store->addOperand(std::make_unique<MemOperand>(
std::make_unique<RegOperand>(PhysicalReg::S0),
std::make_unique<ImmOperand>(frame_info.spill_offsets.at(vreg))
));
new_instrs.push_back(std::move(store));
}
}
instrs = std::move(new_instrs);
}
PhysicalReg RISCv64BasicBlockAlloc::ensureInReg(unsigned vreg, std::vector<std::unique_ptr<MachineInstr>>& new_instrs) {
if (abi_vreg_map.count(vreg)) {
return abi_vreg_map.at(vreg);
}
if (vreg_to_preg.count(vreg)) {
return vreg_to_preg.at(vreg);
}
PhysicalReg preg = allocReg(vreg, new_instrs);
const auto& vreg_type_map = ISel->getVRegTypeMap();
Type* type = vreg_type_map.count(vreg) ? vreg_type_map.at(vreg) : Type::getIntType();
RVOpcodes load_op = type->isFloat() ? RVOpcodes::FLW : (type->isPointer() ? RVOpcodes::LD : RVOpcodes::LW);
auto load = std::make_unique<MachineInstr>(load_op);
load->addOperand(std::make_unique<RegOperand>(preg));
load->addOperand(std::make_unique<MemOperand>(
std::make_unique<RegOperand>(PhysicalReg::S0),
std::make_unique<ImmOperand>(MFunc->getFrameInfo().spill_offsets.at(vreg))
));
new_instrs.push_back(std::move(load));
dirty_pregs.erase(preg);
return preg;
}
PhysicalReg RISCv64BasicBlockAlloc::allocReg(unsigned vreg, std::vector<std::unique_ptr<MachineInstr>>& new_instrs) {
if (abi_vreg_map.count(vreg)) {
dirty_pregs.insert(abi_vreg_map.at(vreg)); // 如果参数被重定义,也标记为脏
return abi_vreg_map.at(vreg);
}
bool is_fp = ISel->getVRegTypeMap().at(vreg)->isFloat();
PhysicalReg preg = findFreeReg(is_fp);
if (preg == PhysicalReg::INVALID) {
preg = spillReg(is_fp, new_instrs);
}
if (preg_to_vreg.count(preg)) {
vreg_to_preg.erase(preg_to_vreg.at(preg));
}
vreg_to_preg[vreg] = preg;
preg_to_vreg[preg] = vreg;
dirty_pregs.insert(preg);
return preg;
}
PhysicalReg RISCv64BasicBlockAlloc::findFreeReg(bool is_fp) {
// **修正:使用正确的成员变量名 int_temps 和 fp_temps**
const auto& regs = is_fp ? fp_temps : int_temps;
for (PhysicalReg preg : regs) {
if (!preg_to_vreg.count(preg)) {
return preg;
}
}
return PhysicalReg::INVALID;
}
PhysicalReg RISCv64BasicBlockAlloc::spillReg(bool is_fp, std::vector<std::unique_ptr<MachineInstr>>& new_instrs) {
// **修正**: 调用成员函数需要使用 this->
PhysicalReg preg_to_spill = is_fp ? this->getNextFpTemp() : this->getNextIntTemp();
if (preg_to_vreg.count(preg_to_spill)) {
unsigned victim_vreg = preg_to_vreg.at(preg_to_spill);
if (dirty_pregs.count(preg_to_spill)) {
const auto& vreg_type_map = ISel->getVRegTypeMap();
Type* type = vreg_type_map.count(victim_vreg) ? vreg_type_map.at(victim_vreg) : Type::getIntType();
RVOpcodes store_op = type->isFloat() ? RVOpcodes::FSW : (type->isPointer() ? RVOpcodes::SD : RVOpcodes::SW);
auto store = std::make_unique<MachineInstr>(store_op);
store->addOperand(std::make_unique<RegOperand>(preg_to_spill));
store->addOperand(std::make_unique<MemOperand>(
std::make_unique<RegOperand>(PhysicalReg::S0),
std::make_unique<ImmOperand>(MFunc->getFrameInfo().spill_offsets.at(victim_vreg))
));
new_instrs.push_back(std::move(store));
}
vreg_to_preg.erase(victim_vreg);
dirty_pregs.erase(preg_to_spill);
}
preg_to_vreg.erase(preg_to_spill);
return preg_to_spill;
}
} // namespace sysy

View File

@ -745,29 +745,83 @@ void RISCv64ISel::selectNode(DAGNode* node) {
CurMBB->addInstruction(std::move(instr));
break;
}
case Instruction::kFtoI: { // 浮点 to 整数 (C/C++: 截断)
// C/C++ 标准要求向零截断 (truncate), 对应的RISC-V舍入模式是 RTZ (Round Towards Zero).
// fcvt.w.s 指令使用 fcsr 中的 frm 字段来决定舍入模式。
// 我们需要手动设置 frm=1 (RTZ), 执行转换, 然后恢复 frm=0 (RNE, 默认).
case Instruction::kFtoI: { // 浮点 to 整数 (带向下取整)
// 目标:实现 floor(x) 的效果, C/C++中浮点转整数是截断(truncate)
// 对于正数floor(x) == truncate(x)
// RISC-V的 fcvt.w.s 默认是“四舍五入到偶数”
// 我们需要手动实现截断逻辑
// 逻辑:
// temp_i = fcvt.w.s(x) // 四舍五入
// temp_f = fcvt.s.w(temp_i) // 转回浮点
// if (x < temp_f) { // 如果原数更小,说明被“五入”了
// result = temp_i - 1
// } else {
// result = temp_i
// }
auto temp_i_vreg = getNewVReg(Type::getIntType());
auto temp_f_vreg = getNewVReg(Type::getFloatType());
auto cmp_vreg = getNewVReg(Type::getIntType());
// 1. fsrmi x0, 1 (set rounding mode to RTZ)
auto fsrmi1 = std::make_unique<MachineInstr>(RVOpcodes::FSRMI);
fsrmi1->addOperand(std::make_unique<RegOperand>(PhysicalReg::ZERO));
fsrmi1->addOperand(std::make_unique<ImmOperand>(1));
CurMBB->addInstruction(std::move(fsrmi1));
// 1. fcvt.w.s temp_i_vreg, src_vreg
auto fcvt_w = std::make_unique<MachineInstr>(RVOpcodes::FCVT_W_S);
fcvt_w->addOperand(std::make_unique<RegOperand>(temp_i_vreg));
fcvt_w->addOperand(std::make_unique<RegOperand>(src_vreg));
CurMBB->addInstruction(std::move(fcvt_w));
// 2. fcvt.w.s dest_vreg, src_vreg
auto fcvt = std::make_unique<MachineInstr>(RVOpcodes::FCVT_W_S);
fcvt->addOperand(std::make_unique<RegOperand>(dest_vreg));
fcvt->addOperand(std::make_unique<RegOperand>(src_vreg));
CurMBB->addInstruction(std::move(fcvt));
// 2. fcvt.s.w temp_f_vreg, temp_i_vreg
auto fcvt_s = std::make_unique<MachineInstr>(RVOpcodes::FCVT_S_W);
fcvt_s->addOperand(std::make_unique<RegOperand>(temp_f_vreg));
fcvt_s->addOperand(std::make_unique<RegOperand>(temp_i_vreg));
CurMBB->addInstruction(std::move(fcvt_s));
// 3. fsrmi x0, 0 (restore rounding mode to RNE)
auto fsrmi0 = std::make_unique<MachineInstr>(RVOpcodes::FSRMI);
fsrmi0->addOperand(std::make_unique<RegOperand>(PhysicalReg::ZERO));
fsrmi0->addOperand(std::make_unique<ImmOperand>(0));
CurMBB->addInstruction(std::move(fsrmi0));
// 3. flt.s cmp_vreg, src_vreg, temp_f_vreg
auto flt = std::make_unique<MachineInstr>(RVOpcodes::FLT_S);
flt->addOperand(std::make_unique<RegOperand>(cmp_vreg));
flt->addOperand(std::make_unique<RegOperand>(src_vreg));
flt->addOperand(std::make_unique<RegOperand>(temp_f_vreg));
CurMBB->addInstruction(std::move(flt));
// 创建标签
int unique_id = this->local_label_counter++;
std::string rounded_up_label = MFunc->getName() + "_ftoi_rounded_up_" + std::to_string(unique_id);
std::string done_label = MFunc->getName() + "_ftoi_done_" + std::to_string(unique_id);
// 4. bne cmp_vreg, x0, rounded_up_label
auto bne = std::make_unique<MachineInstr>(RVOpcodes::BNE);
bne->addOperand(std::make_unique<RegOperand>(cmp_vreg));
bne->addOperand(std::make_unique<RegOperand>(PhysicalReg::ZERO));
bne->addOperand(std::make_unique<LabelOperand>(rounded_up_label));
CurMBB->addInstruction(std::move(bne));
// 5. else 分支: mv dest_vreg, temp_i_vreg
auto mv = std::make_unique<MachineInstr>(RVOpcodes::MV);
mv->addOperand(std::make_unique<RegOperand>(dest_vreg));
mv->addOperand(std::make_unique<RegOperand>(temp_i_vreg));
CurMBB->addInstruction(std::move(mv));
// 6. j done_label
auto j = std::make_unique<MachineInstr>(RVOpcodes::J);
j->addOperand(std::make_unique<LabelOperand>(done_label));
CurMBB->addInstruction(std::move(j));
// 7. rounded_up_label:
auto label_up = std::make_unique<MachineInstr>(RVOpcodes::LABEL);
label_up->addOperand(std::make_unique<LabelOperand>(rounded_up_label));
CurMBB->addInstruction(std::move(label_up));
// 8. addiw dest_vreg, temp_i_vreg, -1
auto addi = std::make_unique<MachineInstr>(RVOpcodes::ADDIW);
addi->addOperand(std::make_unique<RegOperand>(dest_vreg));
addi->addOperand(std::make_unique<RegOperand>(temp_i_vreg));
addi->addOperand(std::make_unique<ImmOperand>(-1));
CurMBB->addInstruction(std::move(addi));
// 9. done_label:
auto label_done = std::make_unique<MachineInstr>(RVOpcodes::LABEL);
label_done->addOperand(std::make_unique<LabelOperand>(done_label));
CurMBB->addInstruction(std::move(label_done));
break;
}
case Instruction::kFNeg: { // 浮点取负
@ -1148,11 +1202,10 @@ void RISCv64ISel::selectNode(DAGNode* node) {
auto r_value_byte = getVReg(memset->getValue());
// 为memset内部逻辑创建新的临时虚拟寄存器
Type* ptr_type = Type::getPointerType(Type::getIntType());
auto r_counter = getNewVReg(ptr_type);
auto r_end_addr = getNewVReg(ptr_type);
auto r_current_addr = getNewVReg(ptr_type);
auto r_temp_val = getNewVReg(Type::getIntType());
auto r_counter = getNewVReg();
auto r_end_addr = getNewVReg();
auto r_current_addr = getNewVReg();
auto r_temp_val = getNewVReg();
// 定义一系列lambda表达式来简化指令创建
auto add_instr = [&](RVOpcodes op, unsigned rd, unsigned rs1, unsigned rs2) {
@ -1243,7 +1296,7 @@ void RISCv64ISel::selectNode(DAGNode* node) {
// --- Step 1: 获取基地址 (此部分逻辑正确,保持不变) ---
auto base_ptr_node = node->operands[0];
auto current_addr_vreg = getNewVReg(gep->getType());
auto current_addr_vreg = getNewVReg();
if (auto alloca_base = dynamic_cast<AllocaInst*>(base_ptr_node->value)) {
auto frame_addr_instr = std::make_unique<MachineInstr>(RVOpcodes::FRAME_ADDR);
@ -1285,13 +1338,13 @@ void RISCv64ISel::selectNode(DAGNode* node) {
// 如果步长为0例如对一个void类型或空结构体索引则不产生任何偏移
if (stride != 0) {
// --- 为当前索引和步长生成偏移计算指令 ---
auto offset_vreg = getNewVReg(Type::getIntType());
auto offset_vreg = getNewVReg();
// 处理索引 - 区分常量与动态值
unsigned index_vreg;
if (auto const_index = dynamic_cast<ConstantValue*>(indexValue)) {
// 对于常量索引,直接创建新的虚拟寄存器
index_vreg = getNewVReg(Type::getIntType());
index_vreg = getNewVReg();
auto li = std::make_unique<MachineInstr>(RVOpcodes::LI);
li->addOperand(std::make_unique<RegOperand>(index_vreg));
li->addOperand(std::make_unique<ImmOperand>(const_index->getInt()));
@ -1309,7 +1362,7 @@ void RISCv64ISel::selectNode(DAGNode* node) {
CurMBB->addInstruction(std::move(mv));
} else {
// 步长不为1需要生成乘法指令
auto size_vreg = getNewVReg(Type::getIntType());
auto size_vreg = getNewVReg();
auto li_size = std::make_unique<MachineInstr>(RVOpcodes::LI);
li_size->addOperand(std::make_unique<RegOperand>(size_vreg));
li_size->addOperand(std::make_unique<ImmOperand>(stride));

View File

@ -1,5 +1,4 @@
#include "RISCv64LLIR.h"
#include "RISCv64Info.h"
#include <vector>
#include <iostream> // 用于 std::ostream 和 std::cerr
#include <string> // 用于 std::string
@ -120,76 +119,4 @@ void MachineFunction::dumpStackFrameInfo(std::ostream& os) const {
os << "---------------------------------------------------\n";
}
/**
* @brief (为紧急溢出模式添加)将指令中所有对特定虚拟寄存器的引用替换为指定的物理寄存器。
*/
void MachineInstr::replaceVRegWithPReg(unsigned old_vreg, PhysicalReg preg) {
for (auto& op : operands) {
if (op->getKind() == MachineOperand::KIND_REG) {
auto reg_op = static_cast<RegOperand*>(op.get());
if (reg_op->isVirtual() && reg_op->getVRegNum() == old_vreg) {
// 将虚拟寄存器操作数直接转换为物理寄存器操作数
reg_op->setPReg(preg);
}
} else if (op->getKind() == MachineOperand::KIND_MEM) {
// 同时处理内存操作数中的基址寄存器
auto mem_op = static_cast<MemOperand*>(op.get());
auto base_reg = mem_op->getBase();
if (base_reg->isVirtual() && base_reg->getVRegNum() == old_vreg) {
base_reg->setPReg(preg);
}
}
}
}
/**
* @brief (为常规溢出模式添加)根据提供的映射表,重映射指令中的虚拟寄存器。
* 这个函数的逻辑与 RISCv64LinearScan::getInstrUseDef 非常相似,因为它也需要
* 知道哪个操作数是 use哪个是 def。
*/
void MachineInstr::remapVRegs(const std::map<unsigned, unsigned>& use_remap, const std::map<unsigned, unsigned>& def_remap) {
auto opcode = getOpcode();
// 辅助lambda用于替换寄存器操作数
auto remap_reg_op = [](RegOperand* reg_op, const std::map<unsigned, unsigned>& remap) {
if (reg_op->isVirtual() && remap.count(reg_op->getVRegNum())) {
reg_op->setVRegNum(remap.at(reg_op->getVRegNum()));
}
};
// 根据指令信息表op_info来确定 use 和 def
if (op_info.count(opcode)) {
const auto& info = op_info.at(opcode);
// 替换 def 操作数
for (int idx : info.first) {
if (idx < operands.size() && operands[idx]->getKind() == MachineOperand::KIND_REG) {
remap_reg_op(static_cast<RegOperand*>(operands[idx].get()), def_remap);
}
}
// 替换 use 操作数
for (int idx : info.second) {
if (idx < operands.size()) {
if (operands[idx]->getKind() == MachineOperand::KIND_REG) {
remap_reg_op(static_cast<RegOperand*>(operands[idx].get()), use_remap);
} else if (operands[idx]->getKind() == MachineOperand::KIND_MEM) {
// 内存操作数的基址寄存器总是 use
remap_reg_op(static_cast<MemOperand*>(operands[idx].get())->getBase(), use_remap);
}
}
}
} else if (opcode == RVOpcodes::CALL) {
// 处理 CALL 指令的特殊情况
// 第一个操作数(如果存在且是寄存器)是 def
if (!operands.empty() && operands[0]->getKind() == MachineOperand::KIND_REG) {
remap_reg_op(static_cast<RegOperand*>(operands[0].get()), def_remap);
}
// 其余寄存器操作数是 use
for (size_t i = 1; i < operands.size(); ++i) {
if (operands[i]->getKind() == MachineOperand::KIND_REG) {
remap_reg_op(static_cast<RegOperand*>(operands[i].get()), use_remap);
}
}
}
}
}

View File

@ -1,694 +0,0 @@
#include "RISCv64LinearScan.h"
#include "RISCv64LLIR.h"
#include "RISCv64ISel.h"
#include "RISCv64Info.h"
#include "RISCv64AsmPrinter.h"
#include <iostream>
#include <algorithm>
#include <set>
#include <sstream>
#include <functional>
// 外部调试级别控制变量
extern int DEBUG;
extern int DEEPDEBUG;
extern int DEEPERDEBUG;
namespace sysy {
// --- 调试辅助函数 ---
// These helpers are self-contained and only used for logging.
static std::string pregToString(PhysicalReg preg) {
// This map is a copy from AsmPrinter to avoid dependency issues.
static const std::map<PhysicalReg, std::string> preg_names = {
{PhysicalReg::ZERO, "zero"}, {PhysicalReg::RA, "ra"}, {PhysicalReg::SP, "sp"}, {PhysicalReg::GP, "gp"}, {PhysicalReg::TP, "tp"},
{PhysicalReg::T0, "t0"}, {PhysicalReg::T1, "t1"}, {PhysicalReg::T2, "t2"}, {PhysicalReg::T3, "t3"}, {PhysicalReg::T4, "t4"}, {PhysicalReg::T5, "t5"}, {PhysicalReg::T6, "t6"},
{PhysicalReg::S0, "s0"}, {PhysicalReg::S1, "s1"}, {PhysicalReg::S2, "s2"}, {PhysicalReg::S3, "s3"}, {PhysicalReg::S4, "s4"}, {PhysicalReg::S5, "s5"}, {PhysicalReg::S6, "s6"}, {PhysicalReg::S7, "s7"}, {PhysicalReg::S8, "s8"}, {PhysicalReg::S9, "s9"}, {PhysicalReg::S10, "s10"}, {PhysicalReg::S11, "s11"},
{PhysicalReg::A0, "a0"}, {PhysicalReg::A1, "a1"}, {PhysicalReg::A2, "a2"}, {PhysicalReg::A3, "a3"}, {PhysicalReg::A4, "a4"}, {PhysicalReg::A5, "a5"}, {PhysicalReg::A6, "a6"}, {PhysicalReg::A7, "a7"},
{PhysicalReg::F0, "f0"}, {PhysicalReg::F1, "f1"}, {PhysicalReg::F2, "f2"}, {PhysicalReg::F3, "f3"}, {PhysicalReg::F4, "f4"}, {PhysicalReg::F5, "f5"}, {PhysicalReg::F6, "f6"}, {PhysicalReg::F7, "f7"},
{PhysicalReg::F8, "f8"}, {PhysicalReg::F9, "f9"}, {PhysicalReg::F10, "f10"}, {PhysicalReg::F11, "f11"}, {PhysicalReg::F12, "f12"}, {PhysicalReg::F13, "f13"}, {PhysicalReg::F14, "f14"}, {PhysicalReg::F15, "f15"},
{PhysicalReg::F16, "f16"}, {PhysicalReg::F17, "f17"}, {PhysicalReg::F18, "f18"}, {PhysicalReg::F19, "f19"}, {PhysicalReg::F20, "f20"}, {PhysicalReg::F21, "f21"}, {PhysicalReg::F22, "f22"}, {PhysicalReg::F23, "f23"},
{PhysicalReg::F24, "f24"}, {PhysicalReg::F25, "f25"}, {PhysicalReg::F26, "f26"}, {PhysicalReg::F27, "f27"}, {PhysicalReg::F28, "f28"}, {PhysicalReg::F29, "f29"}, {PhysicalReg::F30, "f30"}, {PhysicalReg::F31, "f31"},
{PhysicalReg::INVALID, "INVALID"}
};
if (preg_names.count(preg)) return preg_names.at(preg);
return "UnknownPreg";
}
template<typename T>
static std::string setToString(const std::set<T>& s, std::function<std::string(T)> formatter) {
std::stringstream ss;
ss << "{ ";
bool first = true;
for (const auto& item : s) {
if (!first) ss << ", ";
ss << formatter(item);
first = false;
}
ss << " }";
return ss.str();
}
static std::string vregSetToString(const std::set<unsigned>& s) {
return setToString<unsigned>(s, [](unsigned v){ return "%v" + std::to_string(v); });
}
static std::string pregSetToString(const std::set<PhysicalReg>& s) {
return setToString<PhysicalReg>(s, pregToString);
}
// Helper function to check if a register is callee-saved.
// Defined locally to avoid scope issues.
static bool isCalleeSaved(PhysicalReg preg) {
if (preg >= PhysicalReg::S0 && preg <= PhysicalReg::S11) return true;
if (preg >= PhysicalReg::F8 && preg <= PhysicalReg::F9) return true;
if (preg >= PhysicalReg::F18 && preg <= PhysicalReg::F27) return true;
return false;
}
RISCv64LinearScan::RISCv64LinearScan(MachineFunction* mfunc)
: MFunc(mfunc),
ISel(mfunc->getISel()),
vreg_type_map(ISel->getVRegTypeMap()) {
allocable_int_regs = {
PhysicalReg::T0, PhysicalReg::T1, PhysicalReg::T2, PhysicalReg::T3, PhysicalReg::T6,
PhysicalReg::S1, PhysicalReg::S2, PhysicalReg::S3, PhysicalReg::S4, PhysicalReg::S5, PhysicalReg::S6, PhysicalReg::S7,
PhysicalReg::S8, PhysicalReg::S9, PhysicalReg::S10, PhysicalReg::S11,
};
allocable_fp_regs = {
PhysicalReg::F0, PhysicalReg::F1, PhysicalReg::F2, PhysicalReg::F3, PhysicalReg::F4, PhysicalReg::F5, PhysicalReg::F6, PhysicalReg::F7,
PhysicalReg::F10, PhysicalReg::F11, PhysicalReg::F12, PhysicalReg::F13, PhysicalReg::F14, PhysicalReg::F15, PhysicalReg::F16, PhysicalReg::F17,
PhysicalReg::F8, PhysicalReg::F9, PhysicalReg::F18, PhysicalReg::F19, PhysicalReg::F20, PhysicalReg::F21, PhysicalReg::F22,
PhysicalReg::F23, PhysicalReg::F24, PhysicalReg::F25, PhysicalReg::F26, PhysicalReg::F27,
PhysicalReg::F28, PhysicalReg::F29, PhysicalReg::F30, PhysicalReg::F31,
};
if (MFunc->getFunc()) {
int int_arg_idx = 0;
int fp_arg_idx = 0;
for (Argument* arg : MFunc->getFunc()->getArguments()) {
unsigned arg_vreg = ISel->getVReg(arg);
if (arg->getType()->isFloat()) {
if (fp_arg_idx < 8) {
auto preg = static_cast<PhysicalReg>(static_cast<int>(PhysicalReg::F10) + fp_arg_idx++);
abi_vreg_map[arg_vreg] = preg;
}
} else {
if (int_arg_idx < 8) {
auto preg = static_cast<PhysicalReg>(static_cast<int>(PhysicalReg::A0) + int_arg_idx++);
abi_vreg_map[arg_vreg] = preg;
}
}
}
}
}
bool RISCv64LinearScan::run() {
if (DEBUG) std::cerr << "===== [LSRA] Running for function: " << MFunc->getName() << " =====\n";
const int MAX_ITERATIONS = 3;
for (int iteration = 1; ; ++iteration) {
if (DEBUG && iteration > 1) {
std::cerr << "\n----- [LSRA] Re-running iteration " << iteration << " -----\n";
}
linearizeBlocks();
computeLiveIntervals();
bool needs_spill = linearScan();
// 如果当前这轮线性扫描不需要溢出,说明分配成功,直接跳出循环。
if (!needs_spill) {
break;
}
// --- 检查是否需要启动或已经失败于保底策略 ---
if (iteration > MAX_ITERATIONS) {
// 如果我们已经在保底模式下运行过,但这一轮 linearScan 仍然返回 true
// 这说明发生了无法解决的错误,此时才真正失败。
if (conservative_spill_mode) {
std::cerr << "\n!!!!!! [LSRA-FATAL] Allocation failed to converge even in Conservative Spill Mode. Triggering final fallback. !!!!!!\n\n";
return false; // 返回失败而不是exit
}
// 这是第一次达到最大迭代次数,触发保底策略。
std::cerr << "\n!!!!!! [LSRA-WARN] Convergence failed after " << MAX_ITERATIONS
<< " iterations. Entering Conservative Spill Mode for the next attempt. !!!!!!\n\n";
conservative_spill_mode = true; // 开启保守溢出模式,将在下一次循环生效
}
// 只要需要溢出,就重写程序
if (DEBUG) std::cerr << "[LSRA] Spilling detected, will rewrite program.\n";
rewriteProgram();
}
if (DEBUG) std::cerr << "[LSRA] Applying final allocation.\n";
applyAllocation();
MFunc->getFrameInfo().vreg_to_preg_map = this->vreg_to_preg_map;
collectUsedCalleeSavedRegs();
if (DEBUG) std::cerr << "===== [LSRA] Finished for function: " << MFunc->getName() << " =====\n\n";
return true; // 分配成功
}
void RISCv64LinearScan::linearizeBlocks() {
linear_order_blocks.clear();
for (auto& mbb : MFunc->getBlocks()) {
linear_order_blocks.push_back(mbb.get());
}
}
void RISCv64LinearScan::computeLiveIntervals() {
if (DEBUG) std::cerr << "[LSRA-Live] Starting live interval computation.\n";
instr_numbering.clear();
live_intervals.clear();
unhandled.clear();
int num = 0;
std::set<int> call_locations;
for (auto* mbb : linear_order_blocks) {
for (auto& instr : mbb->getInstructions()) {
instr_numbering[instr.get()] = num;
if (instr->getOpcode() == RVOpcodes::CALL) call_locations.insert(num);
num += 2;
}
}
if (DEEPDEBUG) std::cerr << " [Live] Starting live variable dataflow analysis...\n";
std::map<const MachineBasicBlock*, std::set<unsigned>> live_in, live_out;
bool changed = true;
int df_iter = 0;
while(changed) {
changed = false;
df_iter++;
std::vector<MachineBasicBlock*> reversed_blocks = linear_order_blocks;
std::reverse(reversed_blocks.begin(), reversed_blocks.end());
for(auto* mbb : reversed_blocks) {
std::set<unsigned> old_live_in = live_in[mbb];
std::set<unsigned> current_live_out;
for (auto* succ : mbb->successors) current_live_out.insert(live_in[succ].begin(), live_in[succ].end());
std::set<unsigned> use, def;
std::set<unsigned> temp_live = current_live_out;
auto& instrs = mbb->getInstructions();
for (auto it = instrs.rbegin(); it != instrs.rend(); ++it) {
use.clear(); def.clear();
getInstrUseDef(it->get(), use, def);
for (unsigned vreg : def) temp_live.erase(vreg);
for (unsigned vreg : use) temp_live.insert(vreg);
}
if (live_in[mbb] != temp_live || live_out[mbb] != current_live_out) {
changed = true;
live_in[mbb] = temp_live;
live_out[mbb] = current_live_out;
}
}
}
if (DEEPDEBUG) std::cerr << " [Live] Dataflow analysis converged after " << df_iter << " iterations.\n";
if (DEEPERDEBUG) {
std::cerr << " [Live-Debug] Live-in sets:\n";
for (auto* mbb : linear_order_blocks) std::cerr << " " << mbb->getName() << ": " << vregSetToString(live_in[mbb]) << "\n";
std::cerr << " [Live-Debug] Live-out sets:\n";
for (auto* mbb : linear_order_blocks) std::cerr << " " << mbb->getName() << ": " << vregSetToString(live_out[mbb]) << "\n";
}
if (DEEPDEBUG) std::cerr << " [Live] Building precise intervals...\n";
std::map<unsigned, int> first_def, last_use;
for (auto* mbb : linear_order_blocks) {
for (auto& instr_ptr : mbb->getInstructions()) {
int instr_num = instr_numbering.at(instr_ptr.get());
std::set<unsigned> use, def;
getInstrUseDef(instr_ptr.get(), use, def);
for (unsigned vreg : def) if (first_def.find(vreg) == first_def.end()) first_def[vreg] = instr_num;
for (unsigned vreg : use) last_use[vreg] = instr_num;
}
}
if (DEEPERDEBUG) {
std::cerr << " [Live-Debug] First def points:\n";
for (auto const& [vreg, pos] : first_def) std::cerr << " %v" << vreg << ": " << pos << "\n";
std::cerr << " [Live-Debug] Last use points:\n";
for (auto const& [vreg, pos] : last_use) std::cerr << " %v" << vreg << ": " << pos << "\n";
}
for (auto const& [vreg, start] : first_def) {
live_intervals.emplace(vreg, LiveInterval(vreg));
auto& interval = live_intervals.at(vreg);
interval.start = start;
interval.end = last_use.count(vreg) ? last_use.at(vreg) : start;
}
for (auto const& [mbb, live_set] : live_out) {
if (mbb->getInstructions().empty()) continue;
int block_end_num = instr_numbering.at(mbb->getInstructions().back().get());
for (unsigned vreg : live_set) {
if (live_intervals.count(vreg)) {
if (DEEPERDEBUG && live_intervals.at(vreg).end < block_end_num) {
std::cerr << " [Live-Debug] Extending interval for %v" << vreg << " from " << live_intervals.at(vreg).end << " to " << block_end_num << " due to live_out of " << mbb->getName() << "\n";
}
live_intervals.at(vreg).end = std::max(live_intervals.at(vreg).end, block_end_num);
}
}
}
for (auto& pair : live_intervals) {
auto& interval = pair.second;
auto it = call_locations.lower_bound(interval.start);
if (it != call_locations.end() && *it < interval.end) interval.crosses_call = true;
}
for (auto& pair : live_intervals) unhandled.push_back(&pair.second);
std::sort(unhandled.begin(), unhandled.end(), [](const LiveInterval* a, const LiveInterval* b){ return a->start < b->start; });
if (DEBUG) {
std::cerr << "[LSRA-Live] Finished. Total intervals: " << unhandled.size() << "\n";
if (DEEPDEBUG) {
std::cerr << " [Live] Computed Intervals (vreg: [start, end]):\n";
for(const auto* interval : unhandled) {
std::cerr << " %v" << interval->vreg << ": [" << interval->start << ", " << interval->end << "]"
<< (interval->crosses_call ? " (crosses call)" : "") << "\n";
}
}
}
// ================== 新增的调试代码 ==================
// 检查活性分析找到的vreg与指令扫描找到的vreg是否一致
if (DEEPERDEBUG) {
// 修正:将 std.set 修改为 std::set
std::set<unsigned> vregs_from_liveness;
for (const auto& pair : live_intervals) {
vregs_from_liveness.insert(pair.first);
}
std::set<unsigned> vregs_from_instr_scan;
for (auto* mbb : linear_order_blocks) {
for (auto& instr_ptr : mbb->getInstructions()) {
std::set<unsigned> use, def;
getInstrUseDef(instr_ptr.get(), use, def);
vregs_from_instr_scan.insert(use.begin(), use.end());
vregs_from_instr_scan.insert(def.begin(), def.end());
}
}
std::cerr << " [Live-Debug] VReg Consistency Check:\n";
std::cerr << " VRegs found by Liveness Analysis: " << vregs_from_liveness.size() << "\n";
std::cerr << " VRegs found by getInstrUseDef Scan: " << vregs_from_instr_scan.size() << "\n";
// 修正:将 std.set 修改为 std::set
std::set<unsigned> diff;
std::set_difference(vregs_from_liveness.begin(), vregs_from_liveness.end(),
vregs_from_instr_scan.begin(), vregs_from_instr_scan.end(),
std::inserter(diff, diff.begin()));
if (!diff.empty()) {
std::cerr << " !!!!!! [Live-Debug] DISCREPANCY DETECTED !!!!!!\n";
std::cerr << " The following vregs were found by liveness but NOT by getInstrUseDef scan:\n";
std::cerr << " " << vregSetToString(diff) << "\n";
} else {
std::cerr << " [Live-Debug] VReg sets are consistent.\n";
}
}
// ======================================================
}
bool RISCv64LinearScan::linearScan() {
// ================== 终极保底策略 (新逻辑) ==================
// 当此标志位为true时我们进入最暴力的溢出模式。
if (conservative_spill_mode) {
if (DEBUG) std::cerr << "[LSRA-Scan-Panic] In Conservative Mode. Spilling all unhandled vregs.\n";
// 1. 清空溢出列表,准备重新计算
spilled_vregs.clear();
// 2. 遍历所有计算出的活性区间
for (auto& pair : live_intervals) {
// 3. 如果一个vreg不是ABI规定的寄存器就必须溢出
if (abi_vreg_map.find(pair.first) == abi_vreg_map.end()) {
spilled_vregs.insert(pair.first);
}
}
// 4. 只要有任何vreg被标记为溢出就返回true以触发最终的rewriteProgram。
// 下一轮迭代时由于所有vreg都已被重写将不再有新的溢出保证收敛。
return !spilled_vregs.empty();
}
// ==========================================================
// ================== 常规线性扫描逻辑 (您已有的代码) ==================
// 只有在非保守模式下才会执行以下代码
if (DEBUG) std::cerr << "[LSRA-Scan] Starting main linear scan algorithm.\n";
active.clear();
spilled_vregs.clear();
vreg_to_preg_map.clear();
std::set<PhysicalReg> free_caller_int_regs, free_callee_int_regs;
std::set<PhysicalReg> free_caller_fp_regs, free_callee_fp_regs;
for (auto preg : allocable_int_regs) {
if (isCalleeSaved(preg)) free_callee_int_regs.insert(preg); else free_caller_int_regs.insert(preg);
}
for (auto preg : allocable_fp_regs) {
if (isCalleeSaved(preg)) free_callee_fp_regs.insert(preg); else free_caller_fp_regs.insert(preg);
}
if (DEEPDEBUG) {
std::cerr << " [Scan] Initial free regs:\n";
std::cerr << " Caller-Saved Int: " << pregSetToString(free_caller_int_regs) << "\n";
std::cerr << " Callee-Saved Int: " << pregSetToString(free_callee_int_regs) << "\n";
}
vreg_to_preg_map.insert(abi_vreg_map.begin(), abi_vreg_map.end());
std::vector<LiveInterval*> normal_unhandled;
for(LiveInterval* interval : unhandled) {
if(abi_vreg_map.count(interval->vreg)) {
active.push_back(interval);
PhysicalReg preg = abi_vreg_map.at(interval->vreg);
if (isFPVReg(interval->vreg)) {
if(isCalleeSaved(preg)) free_callee_fp_regs.erase(preg); else free_caller_fp_regs.erase(preg);
} else {
if(isCalleeSaved(preg)) free_callee_int_regs.erase(preg); else free_caller_int_regs.erase(preg);
}
} else {
normal_unhandled.push_back(interval);
}
}
unhandled = normal_unhandled;
std::sort(active.begin(), active.end(), [](const LiveInterval* a, const LiveInterval* b){ return a->end < b->end; });
for (LiveInterval* current : unhandled) {
if (DEEPDEBUG) std::cerr << "\n [Scan] Processing interval %v" << current->vreg << " [" << current->start << ", " << current->end << "]\n";
std::vector<LiveInterval*> new_active;
for (LiveInterval* active_interval : active) {
if (active_interval->end < current->start) {
PhysicalReg preg = vreg_to_preg_map.at(active_interval->vreg);
if (DEEPDEBUG) std::cerr << " [Scan] Expiring interval %v" << active_interval->vreg << ", freeing " << pregToString(preg) << "\n";
if (isFPVReg(active_interval->vreg)) {
if(isCalleeSaved(preg)) free_callee_fp_regs.insert(preg); else free_caller_fp_regs.insert(preg);
} else {
if(isCalleeSaved(preg)) free_callee_int_regs.insert(preg); else free_caller_int_regs.insert(preg);
}
} else {
new_active.push_back(active_interval);
}
}
active = new_active;
bool is_fp = isFPVReg(current->vreg);
auto& free_caller = is_fp ? free_caller_fp_regs : free_caller_int_regs;
auto& free_callee = is_fp ? free_callee_fp_regs : free_callee_int_regs;
PhysicalReg allocated_preg = PhysicalReg::INVALID;
if (current->crosses_call) {
if (!free_callee.empty()) {
allocated_preg = *free_callee.begin();
free_callee.erase(allocated_preg);
}
} else {
if (!free_caller.empty()) {
allocated_preg = *free_caller.begin();
free_caller.erase(allocated_preg);
} else if (!free_callee.empty()) {
allocated_preg = *free_callee.begin();
free_callee.erase(allocated_preg);
}
}
if (allocated_preg != PhysicalReg::INVALID) {
if (DEEPDEBUG) std::cerr << " [Scan] Allocated " << pregToString(allocated_preg) << " to %v" << current->vreg << "\n";
vreg_to_preg_map[current->vreg] = allocated_preg;
active.push_back(current);
std::sort(active.begin(), active.end(), [](const LiveInterval* a, const LiveInterval* b){ return a->end < b->end; });
} else {
if (DEEPDEBUG) std::cerr << " [Scan] No free registers for %v" << current->vreg << ". Spilling...\n";
spillAtInterval(current);
}
}
return !spilled_vregs.empty();
}
void RISCv64LinearScan::spillAtInterval(LiveInterval* current) {
// 保持您的原始逻辑
LiveInterval* spill_candidate = nullptr;
if (!active.empty()) {
spill_candidate = active.back();
}
if (DEEPERDEBUG) {
std::cerr << " [Spill-Debug] Spill decision for current=%v" << current->vreg << "[" << current->start << "," << current->end << "]\n";
std::cerr << " [Spill-Debug] Active intervals (sorted by end point):\n";
for (const auto* i : active) {
std::cerr << " %v" << i->vreg << "[" << i->start << "," << i->end << "] in " << pregToString(vreg_to_preg_map[i->vreg]) << "\n";
}
if(spill_candidate) {
std::cerr << " [Spill-Debug] Candidate is %v" << spill_candidate->vreg << ". Its end is " << spill_candidate->end << ", current's end is " << current->end << "\n";
} else {
std::cerr << " [Spill-Debug] No active candidate.\n";
}
}
if (spill_candidate && spill_candidate->end > current->end) {
if (DEEPDEBUG) std::cerr << " [Spill] Decision: Spilling active %v" << spill_candidate->vreg << ".\n";
PhysicalReg preg = vreg_to_preg_map.at(spill_candidate->vreg);
vreg_to_preg_map.erase(spill_candidate->vreg); // 确保移除旧映射
vreg_to_preg_map[current->vreg] = preg;
active.pop_back();
active.push_back(current);
std::sort(active.begin(), active.end(), [](const LiveInterval* a, const LiveInterval* b){ return a->end < b->end; });
spilled_vregs.insert(spill_candidate->vreg);
} else {
if (DEEPDEBUG) std::cerr << " [Spill] Decision: Spilling current %v" << current->vreg << ".\n";
spilled_vregs.insert(current->vreg);
}
}
void RISCv64LinearScan::rewriteProgram() {
if (DEBUG) {
std::cerr << "[LSRA-Rewrite] Starting program rewrite. Spilled vregs: " << vregSetToString(spilled_vregs) << "\n";
}
StackFrameInfo& frame_info = MFunc->getFrameInfo();
int spill_current_offset = frame_info.locals_end_offset - frame_info.spill_size;
for (unsigned vreg : spilled_vregs) {
// 保持您的原始逻辑
if (frame_info.spill_offsets.count(vreg)) continue;
Type* type = vreg_type_map.count(vreg) ? vreg_type_map.at(vreg) : Type::getIntType();
int size = isFPVReg(vreg) ? 4 : (type->isPointer() ? 8 : 4);
spill_current_offset -= size;
spill_current_offset = (spill_current_offset & ~7);
frame_info.spill_offsets[vreg] = spill_current_offset;
if (DEEPDEBUG) std::cerr << " [Rewrite] Assigned new stack offset " << frame_info.spill_offsets.at(vreg) << " to spilled %v" << vreg << "\n";
}
frame_info.spill_size = -(spill_current_offset - frame_info.locals_end_offset);
for (auto& mbb : MFunc->getBlocks()) {
auto& instrs = mbb->getInstructions();
std::vector<std::unique_ptr<MachineInstr>> new_instrs;
if (DEEPERDEBUG) std::cerr << " [Rewrite] Processing block " << mbb->getName() << "\n";
for (auto it = instrs.begin(); it != instrs.end(); ++it) {
auto& instr = *it;
std::set<unsigned> use_vregs, def_vregs;
getInstrUseDef(instr.get(), use_vregs, def_vregs);
if (conservative_spill_mode) {
// ================== 紧急模式重写逻辑 ==================
// 直接使用物理寄存器 t4 (SPILL_TEMP_REG) 进行加载/存储
// 为调试日志准备一个指令打印机
auto printer = DEEPERDEBUG ? std::make_unique<RISCv64AsmPrinter>(MFunc) : nullptr;
auto original_instr_str_for_log = DEEPERDEBUG ? printer->formatInstr(instr.get()) : "";
bool modified = false;
for (unsigned old_vreg : use_vregs) {
if (spilled_vregs.count(old_vreg)) {
modified = true;
Type* type = vreg_type_map.at(old_vreg);
RVOpcodes load_op = isFPVReg(old_vreg) ? RVOpcodes::FLW : (type->isPointer() ? RVOpcodes::LD : RVOpcodes::LW);
auto load = std::make_unique<MachineInstr>(load_op);
// 直接加载到保留的物理寄存器
load->addOperand(std::make_unique<RegOperand>(SPILL_TEMP_REG));
load->addOperand(std::make_unique<MemOperand>(
std::make_unique<RegOperand>(PhysicalReg::S0),
std::make_unique<ImmOperand>(frame_info.spill_offsets.at(old_vreg))));
if (DEEPERDEBUG) {
std::cerr << " [Rewrite-Panic] Inserting LOAD for use of %v" << old_vreg
<< " into " << pregToString(SPILL_TEMP_REG)
<< " before: " << original_instr_str_for_log << "\n";
}
new_instrs.push_back(std::move(load));
// 替换指令中的操作数
instr->replaceVRegWithPReg(old_vreg, SPILL_TEMP_REG);
}
}
// 在处理 def 之前,先替换定义自身的 vreg
for (unsigned old_vreg : def_vregs) {
if (spilled_vregs.count(old_vreg)) {
modified = true;
instr->replaceVRegWithPReg(old_vreg, SPILL_TEMP_REG);
}
}
// 将原始指令(可能已被修改)放入新列表
new_instrs.push_back(std::move(instr));
if (DEEPERDEBUG && modified) {
std::cerr << " [Rewrite-Panic] Original: " << original_instr_str_for_log
<< " -> Rewritten: " << printer->formatInstr(new_instrs.back().get()) << "\n";
}
for (unsigned old_vreg : def_vregs) {
if (spilled_vregs.count(old_vreg)) {
// 指令本身已经被修改为定义到 SPILL_TEMP_REG现在从它存回内存
Type* type = vreg_type_map.at(old_vreg);
RVOpcodes store_op = isFPVReg(old_vreg) ? RVOpcodes::FSW : (type->isPointer() ? RVOpcodes::SD : RVOpcodes::SW);
auto store = std::make_unique<MachineInstr>(store_op);
store->addOperand(std::make_unique<RegOperand>(SPILL_TEMP_REG));
store->addOperand(std::make_unique<MemOperand>(
std::make_unique<RegOperand>(PhysicalReg::S0),
std::make_unique<ImmOperand>(frame_info.spill_offsets.at(old_vreg))));
if (DEEPERDEBUG) {
std::cerr << " [Rewrite-Panic] Inserting STORE for def of %v" << old_vreg
<< " from " << pregToString(SPILL_TEMP_REG) << " after original instr.\n";
}
new_instrs.push_back(std::move(store));
}
}
} else {
// ================== 常规模式重写逻辑 (您的原始代码) ==================
std::map<unsigned, unsigned> use_remap, def_remap;
for (unsigned old_vreg : use_vregs) {
if (spilled_vregs.count(old_vreg) && use_remap.find(old_vreg) == use_remap.end()) {
Type* type = vreg_type_map.at(old_vreg);
unsigned new_temp_vreg = ISel->getNewVReg(type);
use_remap[old_vreg] = new_temp_vreg;
RVOpcodes load_op = isFPVReg(old_vreg) ? RVOpcodes::FLW : (type->isPointer() ? RVOpcodes::LD : RVOpcodes::LW);
auto load = std::make_unique<MachineInstr>(load_op);
load->addOperand(std::make_unique<RegOperand>(new_temp_vreg));
load->addOperand(std::make_unique<MemOperand>(
std::make_unique<RegOperand>(PhysicalReg::S0),
std::make_unique<ImmOperand>(frame_info.spill_offsets.at(old_vreg))));
if (DEEPERDEBUG) {
RISCv64AsmPrinter printer(MFunc);
std::cerr << " [Rewrite] Inserting LOAD for use of %v" << old_vreg << " into new %v" << new_temp_vreg << " before: " << printer.formatInstr(instr.get()) << "\n";
}
new_instrs.push_back(std::move(load));
}
}
for (unsigned old_vreg : def_vregs) {
if (spilled_vregs.count(old_vreg) && def_remap.find(old_vreg) == def_remap.end()) {
Type* type = vreg_type_map.at(old_vreg);
unsigned new_temp_vreg = ISel->getNewVReg(type);
def_remap[old_vreg] = new_temp_vreg;
}
}
auto original_instr_str_for_log = DEEPERDEBUG ? RISCv64AsmPrinter(MFunc).formatInstr(instr.get()) : "";
instr->remapVRegs(use_remap, def_remap);
new_instrs.push_back(std::move(instr));
if (DEEPERDEBUG && (!use_remap.empty() || !def_remap.empty())) std::cerr << " [Rewrite] Original: " << original_instr_str_for_log << " -> Rewritten: " << RISCv64AsmPrinter(MFunc).formatInstr(new_instrs.back().get()) << "\n";
for(const auto& pair : def_remap) {
unsigned old_vreg = pair.first;
unsigned new_temp_vreg = pair.second;
Type* type = vreg_type_map.at(old_vreg);
RVOpcodes store_op = isFPVReg(old_vreg) ? RVOpcodes::FSW : (type->isPointer() ? RVOpcodes::SD : RVOpcodes::SW);
auto store = std::make_unique<MachineInstr>(store_op);
store->addOperand(std::make_unique<RegOperand>(new_temp_vreg));
store->addOperand(std::make_unique<MemOperand>(
std::make_unique<RegOperand>(PhysicalReg::S0),
std::make_unique<ImmOperand>(frame_info.spill_offsets.at(old_vreg))));
if (DEEPERDEBUG) std::cerr << " [Rewrite] Inserting STORE for def of %v" << old_vreg << " from new %v" << new_temp_vreg << " after original instr.\n";
new_instrs.push_back(std::move(store));
}
}
}
instrs = std::move(new_instrs);
}
}
void RISCv64LinearScan::applyAllocation() {
if (DEBUG) std::cerr << "[LSRA-Apply] Applying final vreg->preg mapping.\n";
for (auto& mbb : MFunc->getBlocks()) {
for (auto& instr_ptr : mbb->getInstructions()) {
for (auto& op_ptr : instr_ptr->getOperands()) {
if (op_ptr->getKind() == MachineOperand::KIND_REG) {
auto reg_op = static_cast<RegOperand*>(op_ptr.get());
if (reg_op->isVirtual()) {
unsigned vreg = reg_op->getVRegNum();
if (vreg_to_preg_map.count(vreg)) {
reg_op->setPReg(vreg_to_preg_map.at(vreg));
} else {
std::cerr << "ERROR: Uncolored virtual register %v" << vreg << " found during applyAllocation! in func " << MFunc->getName() << "\n";
// Forcing an error is better than silent failure.
// reg_op->setPReg(PhysicalReg::T5);
}
}
} else if (op_ptr->getKind() == MachineOperand::KIND_MEM) {
auto mem_op = static_cast<MemOperand*>(op_ptr.get());
auto reg_op = mem_op->getBase();
if (reg_op->isVirtual()) {
unsigned vreg = reg_op->getVRegNum();
if (vreg_to_preg_map.count(vreg)) {
reg_op->setPReg(vreg_to_preg_map.at(vreg));
} else {
std::cerr << "ERROR: Uncolored virtual register %v" << vreg << " in memory operand! in func " << MFunc->getName() << "\n";
// reg_op->setPReg(PhysicalReg::T5);
}
}
}
}
}
}
}
// void getInstrUseDef(const MachineInstr* instr, std::set<unsigned>& use, std::set<unsigned>& def) {
// auto opcode = instr->getOpcode();
// const auto& operands = instr->getOperands();
// auto get_vreg_id_if_virtual = [&](const MachineOperand* op, std::set<unsigned>& s) {
// if (op->getKind() == MachineOperand::KIND_REG) {
// auto reg_op = static_cast<const RegOperand*>(op);
// if (reg_op->isVirtual()) s.insert(reg_op->getVRegNum());
// } else if (op->getKind() == MachineOperand::KIND_MEM) {
// auto mem_op = static_cast<const MemOperand*>(op);
// auto reg_op = mem_op->getBase();
// if (reg_op->isVirtual()) s.insert(reg_op->getVRegNum());
// }
// };
// if (op_info.count(opcode)) {
// const auto& info = op_info.at(opcode);
// for (int idx : info.first) if (idx < operands.size()) get_vreg_id_if_virtual(operands[idx].get(), def);
// for (int idx : info.second) if (idx < operands.size()) get_vreg_id_if_virtual(operands[idx].get(), use);
// for (const auto& op : operands) if (op->getKind() == MachineOperand::KIND_MEM) get_vreg_id_if_virtual(op.get(), use);
// } else if (opcode == RVOpcodes::CALL) {
// if (!operands.empty() && operands[0]->getKind() == MachineOperand::KIND_REG) get_vreg_id_if_virtual(operands[0].get(), def);
// for (size_t i = 1; i < operands.size(); ++i) if (operands[i]->getKind() == MachineOperand::KIND_REG) get_vreg_id_if_virtual(operands[i].get(), use);
// }
// }
bool RISCv64LinearScan::isFPVReg(unsigned vreg) const {
return vreg_type_map.count(vreg) && vreg_type_map.at(vreg)->isFloat();
}
void RISCv64LinearScan::collectUsedCalleeSavedRegs() {
StackFrameInfo& frame_info = MFunc->getFrameInfo();
frame_info.used_callee_saved_regs.clear();
const auto& callee_saved_int = getCalleeSavedIntRegs();
const auto& callee_saved_fp = getCalleeSavedFpRegs();
std::set<PhysicalReg> callee_saved_set(callee_saved_int.begin(), callee_saved_int.end());
callee_saved_set.insert(callee_saved_fp.begin(), callee_saved_fp.end());
callee_saved_set.insert(PhysicalReg::S0);
for(const auto& pair : vreg_to_preg_map) {
PhysicalReg preg = pair.second;
if(callee_saved_set.count(preg)) {
frame_info.used_callee_saved_regs.insert(preg);
}
}
}
} // namespace sysy

View File

@ -1,12 +1,9 @@
#include "RISCv64RegAlloc.h"
#include "RISCv64AsmPrinter.h"
#include "RISCv64Info.h"
#include <algorithm>
#include <iostream>
#include <sstream>
#include <cassert>
#include <chrono>
#include <thread>
namespace sysy {
@ -47,7 +44,7 @@ RISCv64RegAlloc::RISCv64RegAlloc(MachineFunction* mfunc)
}
// 主入口: 迭代运行分配算法直到无溢出
bool RISCv64RegAlloc::run() {
void RISCv64RegAlloc::run() {
if (DEBUG) std::cerr << "===== LLIR Before Running Graph Coloring Register Allocation " << MFunc->getName() << " =====\n";
std::stringstream ss_before_reg_alloc;
if (DEBUG) {
@ -62,8 +59,6 @@ bool RISCv64RegAlloc::run() {
int iteration = 0;
while (iteration++ < MAX_ITERATIONS) {
// std::cerr << "Iteration Step: " << iteration << "\n";
// std::this_thread::sleep_for(std::chrono::seconds(1));
if (doAllocation()) {
break;
} else {
@ -71,7 +66,29 @@ bool RISCv64RegAlloc::run() {
if (DEBUG) std::cerr << "--- Spilling detected, re-running allocation (iteration " << iteration << ") ---\n";
if (iteration >= MAX_ITERATIONS) {
return false;
std::cerr << "ERROR: Register allocation failed to converge after " << MAX_ITERATIONS << " iterations\n";
std::cerr << " Spill worklist size: " << spillWorklist.size() << "\n";
std::cerr << " Total nodes: " << (initial.size() + coloredNodes.size()) << "\n";
// Emergency spill remaining nodes to break the loop
std::cerr << " Emergency spilling remaining spill worklist nodes...\n";
for (unsigned node : spillWorklist) {
spilledNodes.insert(node);
}
// Also spill any nodes that didn't get colors
std::set<unsigned> uncolored;
for (unsigned node : initial) {
if (color_map.find(node) == color_map.end()) {
uncolored.insert(node);
}
}
for (unsigned node : uncolored) {
spilledNodes.insert(node);
}
// Force completion
break;
}
}
}
@ -81,13 +98,10 @@ bool RISCv64RegAlloc::run() {
MFunc->getFrameInfo().vreg_to_preg_map = this->color_map;
collectUsedCalleeSavedRegs();
if (DEBUG) std::cerr << "===== Finished Graph Coloring Register Allocation =====\n\n";
return true;
}
// 单次分配的核心流程
bool RISCv64RegAlloc::doAllocation() {
const int MAX_ITERATIONS = 50;
int iteration = 0;
initialize();
precolorByCallingConvention();
analyzeLiveness();
@ -95,16 +109,14 @@ bool RISCv64RegAlloc::doAllocation() {
makeWorklist();
while (!simplifyWorklist.empty() || !worklistMoves.empty() || !freezeWorklist.empty() || !spillWorklist.empty()) {
// if (DEBUG) std::cerr << "Inner Iteration Step: " << ++iteration << "\n";
// std::this_thread::sleep_for(std::chrono::milliseconds(100));
// if (DEEPDEBUG) dumpState("Loop Start");
if (DEEPDEBUG) dumpState("Loop Start");
if (!simplifyWorklist.empty()) simplify();
else if (!worklistMoves.empty()) coalesce();
else if (!freezeWorklist.empty()) freeze();
else if (!spillWorklist.empty()) selectSpill();
}
// if (DEEPDEBUG) dumpState("Before AssignColors");
if (DEEPDEBUG) dumpState("Before AssignColors");
assignColors();
return spilledNodes.empty();
}
@ -872,6 +884,53 @@ void RISCv64RegAlloc::getInstrUseDef_Liveness(const MachineInstr* instr, VRegSet
auto opcode = instr->getOpcode();
const auto& operands = instr->getOperands();
// 映射表:指令操作码 -> {Def操作数索引列表, Use操作数索引列表}
static const std::map<RVOpcodes, std::pair<std::vector<int>, std::vector<int>>> op_info = {
// ===== 整数算术与逻辑指令 (R-type & I-type) =====
{RVOpcodes::ADD, {{0}, {1, 2}}}, {RVOpcodes::SUB, {{0}, {1, 2}}}, {RVOpcodes::MUL, {{0}, {1, 2}}},
{RVOpcodes::DIV, {{0}, {1, 2}}}, {RVOpcodes::REM, {{0}, {1, 2}}}, {RVOpcodes::ADDW, {{0}, {1, 2}}},
{RVOpcodes::SUBW, {{0}, {1, 2}}}, {RVOpcodes::MULW, {{0}, {1, 2}}}, {RVOpcodes::DIVW, {{0}, {1, 2}}},
{RVOpcodes::REMW, {{0}, {1, 2}}}, {RVOpcodes::SLT, {{0}, {1, 2}}}, {RVOpcodes::SLTU, {{0}, {1, 2}}},
{RVOpcodes::XOR, {{0}, {1, 2}}}, {RVOpcodes::OR, {{0}, {1, 2}}}, {RVOpcodes::AND, {{0}, {1, 2}}},
{RVOpcodes::ADDI, {{0}, {1}}}, {RVOpcodes::ADDIW, {{0}, {1}}}, {RVOpcodes::XORI, {{0}, {1}}},
{RVOpcodes::ORI, {{0}, {1}}}, {RVOpcodes::ANDI, {{0}, {1}}},
{RVOpcodes::SLTI, {{0}, {1}}}, {RVOpcodes::SLTIU, {{0}, {1}}},
// ===== 移位指令 =====
{RVOpcodes::SLL, {{0}, {1, 2}}}, {RVOpcodes::SLLI, {{0}, {1}}},
{RVOpcodes::SLLW, {{0}, {1, 2}}}, {RVOpcodes::SLLIW, {{0}, {1}}},
{RVOpcodes::SRL, {{0}, {1, 2}}}, {RVOpcodes::SRLI, {{0}, {1}}},
{RVOpcodes::SRLW, {{0}, {1, 2}}}, {RVOpcodes::SRLIW, {{0}, {1}}},
{RVOpcodes::SRA, {{0}, {1, 2}}}, {RVOpcodes::SRAI, {{0}, {1}}},
{RVOpcodes::SRAW, {{0}, {1, 2}}}, {RVOpcodes::SRAIW, {{0}, {1}}},
// ===== 内存加载指令 (Def: 0, Use: MemBase) =====
{RVOpcodes::LB, {{0}, {}}}, {RVOpcodes::LH, {{0}, {}}}, {RVOpcodes::LW, {{0}, {}}}, {RVOpcodes::LD, {{0}, {}}},
{RVOpcodes::LBU, {{0}, {}}}, {RVOpcodes::LHU, {{0}, {}}}, {RVOpcodes::LWU, {{0}, {}}},
{RVOpcodes::FLW, {{0}, {}}}, {RVOpcodes::FLD, {{0}, {}}},
// ===== 内存存储指令 (Def: None, Use: ValToStore, MemBase) =====
{RVOpcodes::SB, {{}, {0, 1}}}, {RVOpcodes::SH, {{}, {0, 1}}}, {RVOpcodes::SW, {{}, {0, 1}}}, {RVOpcodes::SD, {{}, {0, 1}}},
{RVOpcodes::FSW, {{}, {0, 1}}}, {RVOpcodes::FSD, {{}, {0, 1}}},
// ===== 控制流指令 =====
{RVOpcodes::BEQ, {{}, {0, 1}}}, {RVOpcodes::BNE, {{}, {0, 1}}}, {RVOpcodes::BLT, {{}, {0, 1}}},
{RVOpcodes::BGE, {{}, {0, 1}}}, {RVOpcodes::BLTU, {{}, {0, 1}}}, {RVOpcodes::BGEU, {{}, {0, 1}}},
{RVOpcodes::JALR, {{0}, {1}}}, // def: ra (implicit) and op0, use: op1
// ===== 浮点指令 =====
{RVOpcodes::FADD_S, {{0}, {1, 2}}}, {RVOpcodes::FSUB_S, {{0}, {1, 2}}},
{RVOpcodes::FMUL_S, {{0}, {1, 2}}}, {RVOpcodes::FDIV_S, {{0}, {1, 2}}}, {RVOpcodes::FEQ_S, {{0}, {1, 2}}},
{RVOpcodes::FLT_S, {{0}, {1, 2}}}, {RVOpcodes::FLE_S, {{0}, {1, 2}}}, {RVOpcodes::FCVT_S_W, {{0}, {1}}},
{RVOpcodes::FCVT_W_S, {{0}, {1}}}, {RVOpcodes::FMV_S, {{0}, {1}}}, {RVOpcodes::FMV_W_X, {{0}, {1}}},
{RVOpcodes::FMV_X_W, {{0}, {1}}}, {RVOpcodes::FNEG_S, {{0}, {1}}},
// ===== 伪指令 =====
{RVOpcodes::LI, {{0}, {}}}, {RVOpcodes::LA, {{0}, {}}},
{RVOpcodes::MV, {{0}, {1}}}, {RVOpcodes::SEQZ, {{0}, {1}}}, {RVOpcodes::SNEZ, {{0}, {1}}},
{RVOpcodes::NEG, {{0}, {1}}}, {RVOpcodes::NEGW, {{0}, {1}}},
};
// lambda表达式用于获取操作数的寄存器ID虚拟或物理
const unsigned offset = static_cast<unsigned>(PhysicalReg::PHYS_REG_START_ID);
auto get_any_reg_id = [&](const MachineOperand* op) -> unsigned {

View File

@ -23,21 +23,6 @@ public:
bool runOnFunction(Function *F, AnalysisManager& AM) override;
void runOnMachineFunction(MachineFunction* mfunc);
/**
* @brief 设置是否启用浮点乘加融合优化
* @param enabled 是否启用
*/
static void setFusedMulAddEnabled(bool enabled) { fusedMulAddEnabled = enabled; }
/**
* @brief 检查是否启用了浮点乘加融合优化
* @return 是否启用
*/
static bool isFusedMulAddEnabled() { return fusedMulAddEnabled; }
private:
static bool fusedMulAddEnabled; // 浮点乘加融合优化开关
};
} // namespace sysy

View File

@ -26,7 +26,6 @@ private:
unsigned getTypeSizeInBytes(Type* type);
Module* module;
bool gc_failed = false;
};
} // namespace sysy

View File

@ -1,61 +0,0 @@
#ifndef RISCV64_BASICBLOCKALLOC_H
#define RISCV64_BASICBLOCKALLOC_H
#include "RISCv64LLIR.h"
#include "RISCv64ISel.h"
#include <set>
#include <map>
#include <vector>
namespace sysy {
/**
* @class RISCv64BasicBlockAlloc
* @brief 一个有状态的、基本块级的贪心寄存器分配器。
*
* 该分配器作为简单但可靠的实现,它逐个处理基本块,并在块内尽可能地
* 将虚拟寄存器的值保留在物理寄存器中,以减少不必要的内存访问。
*/
class RISCv64BasicBlockAlloc {
public:
RISCv64BasicBlockAlloc(MachineFunction* mfunc);
void run();
private:
void computeLiveness();
void processBasicBlock(MachineBasicBlock* mbb);
void assignStackSlotsForAllVRegs();
// 核心分配函数
PhysicalReg ensureInReg(unsigned vreg, std::vector<std::unique_ptr<MachineInstr>>& new_instrs);
PhysicalReg allocReg(unsigned vreg, std::vector<std::unique_ptr<MachineInstr>>& new_instrs);
PhysicalReg findFreeReg(bool is_fp);
PhysicalReg spillReg(bool is_fp, std::vector<std::unique_ptr<MachineInstr>>& new_instrs);
// 状态跟踪(每个基本块开始时都会重置)
std::map<unsigned, PhysicalReg> vreg_to_preg; // 当前vreg到物理寄存器的映射
std::map<PhysicalReg, unsigned> preg_to_vreg; // 反向映射
std::set<PhysicalReg> dirty_pregs; // 被修改过、需要写回的物理寄存器
// 分配器全局信息
MachineFunction* MFunc;
RISCv64ISel* ISel;
std::map<unsigned, PhysicalReg> abi_vreg_map; // 函数参数的ABI寄存器映射
// 寄存器池和循环索引
std::vector<PhysicalReg> int_temps;
std::vector<PhysicalReg> fp_temps;
int int_temp_idx = 0;
int fp_temp_idx = 0;
// 辅助函数
PhysicalReg getNextIntTemp();
PhysicalReg getNextFpTemp();
// 活性分析结果
std::map<const MachineBasicBlock*, std::set<unsigned>> live_out;
};
} // namespace sysy
#endif // RISCV64_BASICBLOCKALLOC_H

View File

@ -22,6 +22,7 @@ public:
// 公开接口以便后续模块如RegAlloc可以查询或创建vreg
unsigned getVReg(Value* val);
unsigned getNewVReg() { return vreg_counter++; }
unsigned getNewVReg(Type* type);
unsigned getVRegCounter() const;
// 获取 vreg_map 的公共接口

View File

@ -1,98 +0,0 @@
#ifndef RISCV64_INFO_H
#define RISCV64_INFO_H
#include "RISCv64LLIR.h"
#include <map>
#include <vector>
namespace sysy {
// 定义一个全局的、权威的指令信息表
// 它包含了指令的定义(def)和使用(use)操作数索引
// defs: {0} -> 第一个操作数是定义
// uses: {1, 2} -> 第二、三个操作数是使用
static const std::map<RVOpcodes, std::pair<std::vector<int>, std::vector<int>>> op_info = {
// --- 整数计算 (R-Type) ---
{RVOpcodes::ADD, {{0}, {1, 2}}},
{RVOpcodes::SUB, {{0}, {1, 2}}},
{RVOpcodes::MUL, {{0}, {1, 2}}},
{RVOpcodes::MULH, {{0}, {1, 2}}},
{RVOpcodes::DIV, {{0}, {1, 2}}},
{RVOpcodes::DIVW, {{0}, {1, 2}}},
{RVOpcodes::REM, {{0}, {1, 2}}},
{RVOpcodes::REMW, {{0}, {1, 2}}},
{RVOpcodes::ADDW, {{0}, {1, 2}}},
{RVOpcodes::SUBW, {{0}, {1, 2}}},
{RVOpcodes::MULW, {{0}, {1, 2}}},
{RVOpcodes::SLT, {{0}, {1, 2}}},
{RVOpcodes::SLTU, {{0}, {1, 2}}},
{RVOpcodes::XOR, {{0}, {1, 2}}},
{RVOpcodes::OR, {{0}, {1, 2}}},
{RVOpcodes::AND, {{0}, {1, 2}}},
{RVOpcodes::SLL, {{0}, {1, 2}}},
{RVOpcodes::SRL, {{0}, {1, 2}}},
{RVOpcodes::SRA, {{0}, {1, 2}}},
{RVOpcodes::SLLW, {{0}, {1, 2}}},
{RVOpcodes::SRLW, {{0}, {1, 2}}},
{RVOpcodes::SRAW, {{0}, {1, 2}}},
// --- 整数计算 (I-Type) ---
{RVOpcodes::ADDI, {{0}, {1}}},
{RVOpcodes::ADDIW, {{0}, {1}}},
{RVOpcodes::XORI, {{0}, {1}}},
{RVOpcodes::ORI, {{0}, {1}}},
{RVOpcodes::ANDI, {{0}, {1}}},
{RVOpcodes::SLTI, {{0}, {1}}},
{RVOpcodes::SLTIU, {{0}, {1}}},
{RVOpcodes::SLLI, {{0}, {1}}},
{RVOpcodes::SLLIW, {{0}, {1}}},
{RVOpcodes::SRLI, {{0}, {1}}},
{RVOpcodes::SRLIW, {{0}, {1}}},
{RVOpcodes::SRAI, {{0}, {1}}},
{RVOpcodes::SRAIW, {{0}, {1}}},
// --- 内存加载 ---
{RVOpcodes::LW, {{0}, {}}}, {RVOpcodes::LH, {{0}, {}}}, {RVOpcodes::LB, {{0}, {}}},
{RVOpcodes::LWU, {{0}, {}}}, {RVOpcodes::LHU, {{0}, {}}}, {RVOpcodes::LBU, {{0}, {}}},
{RVOpcodes::LD, {{0}, {}}},
{RVOpcodes::FLW, {{0}, {}}}, {RVOpcodes::FLD, {{0}, {}}},
// --- 内存存储 ---
{RVOpcodes::SW, {{}, {0, 1}}}, {RVOpcodes::SH, {{}, {0, 1}}}, {RVOpcodes::SB, {{}, {0, 1}}},
{RVOpcodes::SD, {{}, {0, 1}}},
{RVOpcodes::FSW, {{}, {0, 1}}}, {RVOpcodes::FSD, {{}, {0, 1}}},
// --- 分支指令 ---
{RVOpcodes::BEQ, {{}, {0, 1}}}, {RVOpcodes::BNE, {{}, {0, 1}}}, {RVOpcodes::BLT, {{}, {0, 1}}},
{RVOpcodes::BGE, {{}, {0, 1}}}, {RVOpcodes::BLTU, {{}, {0, 1}}}, {RVOpcodes::BGEU, {{}, {0, 1}}},
// --- 跳转 ---
{RVOpcodes::JAL, {{0}, {}}}, // JAL的rd是def但通常用x0表示不关心返回值这里简化
{RVOpcodes::JALR, {{0}, {1}}},
{RVOpcodes::RET, {{}, {}}}, // RET是伪指令通常展开为JALR
// --- 伪指令 & 其他 ---
{RVOpcodes::LI, {{0}, {}}}, {RVOpcodes::LA, {{0}, {}}},
{RVOpcodes::MV, {{0}, {1}}},
{RVOpcodes::NEG, {{0}, {1}}}, // sub rd, zero, rs1
{RVOpcodes::NEGW, {{0}, {1}}}, // subw rd, zero, rs1
{RVOpcodes::SEQZ, {{0}, {1}}},
{RVOpcodes::SNEZ, {{0}, {1}}},
// --- 函数调用 ---
// CALL的use/def在getInstrUseDef中有特殊处理逻辑这里可以不列出
// --- 浮点指令 ---
{RVOpcodes::FADD_S, {{0}, {1, 2}}}, {RVOpcodes::FSUB_S, {{0}, {1, 2}}},
{RVOpcodes::FMUL_S, {{0}, {1, 2}}}, {RVOpcodes::FDIV_S, {{0}, {1, 2}}},
{RVOpcodes::FMADD_S, {{0}, {1, 2, 3}}},
{RVOpcodes::FEQ_S, {{0}, {1, 2}}}, {RVOpcodes::FLT_S, {{0}, {1, 2}}}, {RVOpcodes::FLE_S, {{0}, {1, 2}}},
{RVOpcodes::FCVT_S_W, {{0}, {1}}}, {RVOpcodes::FCVT_W_S, {{0}, {1}}},
{RVOpcodes::FCVT_W_S_RTZ, {{0}, {1}}},
{RVOpcodes::FMV_S, {{0}, {1}}}, {RVOpcodes::FMV_W_X, {{0}, {1}}}, {RVOpcodes::FMV_X_W, {{0}, {1}}},
{RVOpcodes::FNEG_S, {{0}, {1}}}
};
} // namespace sysy
#endif // RISCV64_INFO_H

View File

@ -41,8 +41,6 @@ enum class PhysicalReg {
// 假设 vreg_counter 不会达到这么大的值
PHYS_REG_START_ID = 1000000,
PHYS_REG_END_ID = PHYS_REG_START_ID + 320, // 预留足够的空间
INVALID, ///< 无效寄存器标记
};
// RISC-V 指令操作码枚举
@ -79,7 +77,6 @@ enum class RVOpcodes {
FSUB_S, // fsub.s rd, rs1, rs2
FMUL_S, // fmul.s rd, rs1, rs2
FDIV_S, // fdiv.s rd, rs1, rs2
FMADD_S, // fmadd.s rd, rs1, rs2, rs3
// 浮点比较 (单精度)
FEQ_S, // feq.s rd, rs1, rs2 (结果写入整数寄存器rd)
@ -89,7 +86,6 @@ enum class RVOpcodes {
// 浮点转换
FCVT_S_W, // fcvt.s.w rd, rs1 (有符号整数 -> 单精度浮点)
FCVT_W_S, // fcvt.w.s rd, rs1 (单精度浮点 -> 有符号整数)
FCVT_W_S_RTZ, // fcvt.w.s rd, rs1, rtz (使用向零截断模式)
// 浮点传送/移动
FMV_S, // fmv.s rd, rs1 (浮点寄存器之间)
@ -97,9 +93,6 @@ enum class RVOpcodes {
FMV_X_W, // fmv.x.w rd, rs1 (浮点寄存器位模式 -> 整数寄存器)
FNEG_S, // fneg.s rd, rs (浮点取负)
// 浮点控制状态寄存器 (CSR)
FSRMI, // fsrmi rd, imm (设置舍入模式立即数)
// 伪指令
FRAME_LOAD_W, // 从栈帧加载 32位 Word (对应 lw)
FRAME_LOAD_D, // 从栈帧加载 64位 Doubleword (对应 ld)
@ -256,19 +249,6 @@ public:
void addOperand(std::unique_ptr<MachineOperand> operand) {
operands.push_back(std::move(operand));
}
/**
* @brief (为紧急溢出模式添加)将指令中所有对特定虚拟寄存器的引用替换为指定的物理寄存器。
* * @param old_vreg 需要被替换的虚拟寄存器号。
* @param preg 用于替换的物理寄存器。
*/
void replaceVRegWithPReg(unsigned old_vreg, PhysicalReg preg);
/**
* @brief (为常规溢出模式添加)根据提供的映射表,重映射指令中的虚拟寄存器。
* * @param use_remap 一个从旧vreg到新vreg的映射用于指令的use操作数。
* @param def_remap 一个从旧vreg到新vreg的映射用于指令的def操作数。
*/
void remapVRegs(const std::map<unsigned, unsigned>& use_remap, const std::map<unsigned, unsigned>& def_remap);
private:
RVOpcodes opcode;
std::vector<std::unique_ptr<MachineOperand>> operands;
@ -333,22 +313,6 @@ private:
std::vector<std::unique_ptr<MachineBasicBlock>> blocks;
StackFrameInfo frame_info;
};
inline bool isMemoryOp(RVOpcodes opcode) {
switch (opcode) {
case RVOpcodes::LB: case RVOpcodes::LH: case RVOpcodes::LW: case RVOpcodes::LD:
case RVOpcodes::LBU: case RVOpcodes::LHU: case RVOpcodes::LWU:
case RVOpcodes::SB: case RVOpcodes::SH: case RVOpcodes::SW: case RVOpcodes::SD:
case RVOpcodes::FLW:
case RVOpcodes::FSW:
case RVOpcodes::FLD:
case RVOpcodes::FSD:
return true;
default:
return false;
}
}
void getInstrUseDef(const MachineInstr* instr, std::set<unsigned>& use, std::set<unsigned>& def);
} // namespace sysy

View File

@ -1,81 +0,0 @@
#ifndef RISCV64_LINEARSCAN_H
#define RISCV64_LINEARSCAN_H
#include "RISCv64LLIR.h"
#include "RISCv64ISel.h"
#include <vector>
#include <map>
#include <set>
#include <algorithm>
namespace sysy {
// 前向声明
class MachineBasicBlock;
class MachineFunction;
class RISCv64ISel;
/**
* @brief 表示一个虚拟寄存器的活跃区间。
* 包含起始和结束指令编号。为了简化,我们不处理有“洞”的区间。
*/
struct LiveInterval {
unsigned vreg = 0;
int start = -1;
int end = -1;
bool crosses_call = false;
LiveInterval(unsigned vreg) : vreg(vreg) {}
// 用于排序,按起始点从小到大
bool operator<(const LiveInterval& other) const {
return start < other.start;
}
};
class RISCv64LinearScan {
public:
RISCv64LinearScan(MachineFunction* mfunc);
bool run();
private:
// --- 核心算法流程 ---
void linearizeBlocks();
void computeLiveIntervals();
bool linearScan();
void rewriteProgram();
void applyAllocation();
void spillAtInterval(LiveInterval* current);
// --- 辅助函数 ---
bool isFPVReg(unsigned vreg) const;
void collectUsedCalleeSavedRegs();
MachineFunction* MFunc;
RISCv64ISel* ISel;
// --- 线性扫描数据结构 ---
std::vector<MachineBasicBlock*> linear_order_blocks;
std::map<const MachineInstr*, int> instr_numbering;
std::map<unsigned, LiveInterval> live_intervals;
std::vector<LiveInterval*> unhandled;
std::vector<LiveInterval*> active; // 活跃且已分配物理寄存器的区间
std::set<unsigned> spilled_vregs; // 记录在本轮被决定溢出的vreg
bool conservative_spill_mode = false;
const PhysicalReg SPILL_TEMP_REG = PhysicalReg::T4;
// --- 寄存器池和分配结果 ---
std::vector<PhysicalReg> allocable_int_regs;
std::vector<PhysicalReg> allocable_fp_regs;
std::map<unsigned, PhysicalReg> vreg_to_preg_map;
std::map<unsigned, PhysicalReg> abi_vreg_map;
const std::map<unsigned, Type*>& vreg_type_map;
};
} // namespace sysy
#endif // RISCV64_LINEARSCAN_H

View File

@ -1,7 +1,6 @@
#ifndef RISCV64_PASSES_H
#define RISCV64_PASSES_H
#include "Pass.h"
#include "RISCv64LLIR.h"
#include "Peephole.h"
#include "PreRA_Scheduler.h"
@ -10,8 +9,10 @@
#include "LegalizeImmediates.h"
#include "PrologueEpilogueInsertion.h"
#include "EliminateFrameIndices.h"
#include "Pass.h"
#include "DivStrengthReduction.h"
namespace sysy {
} // namespace sysy

View File

@ -20,7 +20,7 @@ public:
RISCv64RegAlloc(MachineFunction* mfunc);
// 模块主入口
bool run();
void run();
private:
// 类型定义与Python版本对应

View File

@ -350,7 +350,11 @@ private:
std::set<Value*>& visited
);
bool isBasicInductionVariable(Value* val, Loop* loop);
bool hasSimpleMemoryPattern(Loop* loop); // 简单的内存模式检查
// ========== 循环不变量分析辅助方法 ==========
bool isInvariantOperands(Instruction* inst, Loop* loop, const std::unordered_set<Value*>& invariants);
bool isMemoryLocationModifiedInLoop(Value* ptr, Loop* loop);
bool isMemoryLocationLoadedInLoop(Value* ptr, Loop* loop, Instruction* excludeInst = nullptr);
bool isPureFunction(Function* calledFunc);
};
} // namespace sysy

View File

@ -0,0 +1,87 @@
#pragma once
#include "Pass.h"
#include "IR.h"
#include "Dom.h"
#include "SideEffectAnalysis.h"
#include <unordered_map>
#include <unordered_set>
#include <vector>
#include <string>
#include <sstream>
namespace sysy {
// GVN优化遍的核心逻辑封装类
class GVNContext {
public:
// 运行GVN优化的主要方法
void run(Function* func, AnalysisManager* AM, bool& changed);
private:
// 新的值编号系统
std::unordered_map<Value*, unsigned> valueToNumber; // Value -> 值编号
std::unordered_map<unsigned, Value*> numberToValue; // 值编号 -> 代表值
std::unordered_map<std::string, unsigned> expressionToNumber; // 表达式 -> 值编号
unsigned nextValueNumber = 1;
// 已访问的基本块集合
std::unordered_set<BasicBlock*> visited;
// 逆后序遍历的基本块列表
std::vector<BasicBlock*> rpoBlocks;
// 需要删除的指令集合
std::unordered_set<Instruction*> needRemove;
// 分析结果
DominatorTree* domTree = nullptr;
SideEffectAnalysisResult* sideEffectAnalysis = nullptr;
// 计算逆后序遍历
void computeRPO(Function* func);
void dfs(BasicBlock* bb);
// 新的值编号方法
unsigned getValueNumber(Value* value);
unsigned assignValueNumber(Value* value);
// 基本块处理
void processBasicBlock(BasicBlock* bb, bool& changed);
// 指令处理
bool processInstruction(Instruction* inst);
// 表达式构建和查找
std::string buildExpressionKey(Instruction* inst);
Value* findExistingValue(const std::string& exprKey, Instruction* inst);
// 支配关系和安全性检查
bool dominates(Instruction* a, Instruction* b);
bool isMemorySafe(LoadInst* earlierLoad, LoadInst* laterLoad);
// 清理方法
void eliminateRedundantInstructions(bool& changed);
void invalidateMemoryValues(StoreInst* store);
};
// GVN优化遍类
class GVN : public OptimizationPass {
public:
// 静态成员作为该遍的唯一ID
static void* ID;
GVN() : OptimizationPass("GVN", Granularity::Function) {}
// 在函数上运行优化
bool runOnFunction(Function* func, AnalysisManager& AM) override;
// 返回该遍的唯一ID
void* getPassID() const override { return ID; }
// 声明分析依赖
void getAnalysisUsage(std::set<void*>& analysisDependencies,
std::set<void*>& analysisInvalidations) const override;
};
} // namespace sysy

View File

@ -0,0 +1,107 @@
#pragma once
#include "Pass.h"
#include "IR.h"
#include "SideEffectAnalysis.h"
#include <unordered_map>
#include <unordered_set>
#include <vector>
#include <cstdint>
namespace sysy {
// 魔数乘法结构,用于除法优化
struct MagicNumber {
uint32_t multiplier;
int shift;
bool needAdd;
MagicNumber(uint32_t m, int s, bool add = false)
: multiplier(m), shift(s), needAdd(add) {}
};
// 全局强度削弱优化遍的核心逻辑封装类
class GlobalStrengthReductionContext {
public:
// 构造函数接受IRBuilder参数
explicit GlobalStrengthReductionContext(IRBuilder* builder) : builder(builder) {}
// 运行优化的主要方法
void run(Function* func, AnalysisManager* AM, bool& changed);
private:
IRBuilder* builder; // IR构建器
// 分析结果
SideEffectAnalysisResult* sideEffectAnalysis = nullptr;
// 优化计数
int algebraicOptCount = 0;
int strengthReductionCount = 0;
int divisionOptCount = 0;
// 主要优化方法
bool processBasicBlock(BasicBlock* bb);
bool processInstruction(Instruction* inst);
// 代数优化方法
bool tryAlgebraicOptimization(Instruction* inst);
bool optimizeAddition(BinaryInst* inst);
bool optimizeSubtraction(BinaryInst* inst);
bool optimizeMultiplication(BinaryInst* inst);
bool optimizeDivision(BinaryInst* inst);
bool optimizeComparison(BinaryInst* inst);
bool optimizeLogical(BinaryInst* inst);
// 强度削弱方法
bool tryStrengthReduction(Instruction* inst);
bool reduceMultiplication(BinaryInst* inst);
bool reduceDivision(BinaryInst* inst);
bool reducePower(CallInst* inst);
// 复杂乘法强度削弱方法
bool tryComplexMultiplication(BinaryInst* inst, Value* variable, int constant);
bool findOptimalShiftDecomposition(int constant, std::vector<int>& shifts);
Value* createShiftDecomposition(BinaryInst* inst, Value* variable, const std::vector<int>& shifts);
// 魔数乘法相关方法
MagicNumber computeMagicNumber(uint32_t divisor);
std::pair<int, int> computeMulhMagicNumbers(int divisor);
Value* createMagicDivision(BinaryInst* divInst, uint32_t divisor, const MagicNumber& magic);
Value* createMagicDivisionLibdivide(BinaryInst* divInst, int divisor);
bool isPowerOfTwo(uint32_t n);
int log2OfPowerOfTwo(uint32_t n);
// 辅助方法
bool isConstantInt(Value* val, int& constVal);
bool isConstantInt(Value* val, uint32_t& constVal);
ConstantInteger* getConstantInt(int val);
bool hasOnlyLocalUses(Instruction* inst);
void replaceWithOptimized(Instruction* original, Value* replacement);
};
// 全局强度削弱优化遍类
class GlobalStrengthReduction : public OptimizationPass {
private:
IRBuilder* builder; // IR构建器用于创建新指令
public:
// 静态成员作为该遍的唯一ID
static void* ID;
// 构造函数接受IRBuilder参数
explicit GlobalStrengthReduction(IRBuilder* builder)
: OptimizationPass("GlobalStrengthReduction", Granularity::Function), builder(builder) {}
// 在函数上运行优化
bool runOnFunction(Function* func, AnalysisManager& AM) override;
// 返回该遍的唯一ID
void* getPassID() const override { return ID; }
// 声明分析依赖
void getAnalysisUsage(std::set<void*>& analysisDependencies,
std::set<void*>& analysisInvalidations) const override;
};
} // namespace sysy

View File

@ -127,13 +127,6 @@ private:
*/
bool analyzeInductionVariableRange(const InductionVarInfo* ivInfo, Loop* loop) const;
/**
* 计算用于除法优化的魔数和移位量
* @param divisor 除数
* @return {魔数, 移位量}
*/
std::pair<int, int> computeMulhMagicNumbers(int divisor) const;
/**
* 生成除法替换代码
* @param candidate 优化候选项

View File

@ -107,6 +107,190 @@ public:
// 所以当AllocaInst的basetype是PointerType时一维数组或者是指向ArrayType的PointerType多位数组返回true
return aval && (baseType->isPointer() || baseType->as<PointerType>()->getBaseType()->isArray());
}
//该实现参考了libdivide的算法
static std::pair<int, int> computeMulhMagicNumbers(int divisor) {
if (DEBUG) {
std::cout << "\n[SR] ===== Computing magic numbers for divisor " << divisor << " (libdivide algorithm) =====" << std::endl;
}
if (divisor == 0) {
if (DEBUG) std::cout << "[SR] Error: divisor must be != 0" << std::endl;
return {-1, -1};
}
// libdivide 常数
const uint8_t LIBDIVIDE_ADD_MARKER = 0x40;
const uint8_t LIBDIVIDE_NEGATIVE_DIVISOR = 0x80;
// 辅助函数:计算前导零个数
auto count_leading_zeros32 = [](uint32_t val) -> uint32_t {
if (val == 0) return 32;
return __builtin_clz(val);
};
// 辅助函数64位除法返回32位商和余数
auto div_64_32 = [](uint32_t high, uint32_t low, uint32_t divisor, uint32_t* rem) -> uint32_t {
uint64_t dividend = ((uint64_t)high << 32) | low;
uint32_t quotient = dividend / divisor;
*rem = dividend % divisor;
return quotient;
};
if (DEBUG) {
std::cout << "[SR] Input divisor: " << divisor << std::endl;
}
// libdivide_internal_s32_gen 算法实现
int32_t d = divisor;
uint32_t ud = (uint32_t)d;
uint32_t absD = (d < 0) ? -ud : ud;
if (DEBUG) {
std::cout << "[SR] absD = " << absD << std::endl;
}
uint32_t floor_log_2_d = 31 - count_leading_zeros32(absD);
if (DEBUG) {
std::cout << "[SR] floor_log_2_d = " << floor_log_2_d << std::endl;
}
// 检查 absD 是否为2的幂
if ((absD & (absD - 1)) == 0) {
if (DEBUG) {
std::cout << "[SR] " << absD << " 是2的幂使用移位方法" << std::endl;
}
// 对于2的幂我们只使用移位不需要魔数
int shift = floor_log_2_d;
if (d < 0) shift |= 0x80; // 标记负数
if (DEBUG) {
std::cout << "[SR] Power of 2 result: magic=0, shift=" << shift << std::endl;
std::cout << "[SR] ===== End magic computation =====" << std::endl;
}
// 对于我们的目的我们将在IR生成中以不同方式处理2的幂
// 返回特殊标记
return {0, shift};
}
if (DEBUG) {
std::cout << "[SR] " << absD << " is not a power of 2, computing magic number" << std::endl;
}
// 非2的幂除数的魔数计算
uint8_t more;
uint32_t rem, proposed_m;
// 计算 proposed_m = floor(2^(floor_log_2_d + 31) / absD)
proposed_m = div_64_32((uint32_t)1 << (floor_log_2_d - 1), 0, absD, &rem);
const uint32_t e = absD - rem;
if (DEBUG) {
std::cout << "[SR] proposed_m = " << proposed_m << ", rem = " << rem << ", e = " << e << std::endl;
}
// 确定是否需要"加法"版本
const bool branchfree = false; // 使用分支版本
if (!branchfree && e < ((uint32_t)1 << floor_log_2_d)) {
// 这个幂次有效
more = (uint8_t)(floor_log_2_d - 1);
if (DEBUG) {
std::cout << "[SR] Using basic algorithm, shift = " << (int)more << std::endl;
}
} else {
// 我们需要上升一个等级
proposed_m += proposed_m;
const uint32_t twice_rem = rem + rem;
if (twice_rem >= absD || twice_rem < rem) {
proposed_m += 1;
}
more = (uint8_t)(floor_log_2_d | LIBDIVIDE_ADD_MARKER);
if (DEBUG) {
std::cout << "[SR] Using add algorithm, proposed_m = " << proposed_m << ", more = " << (int)more << std::endl;
}
}
proposed_m += 1;
int32_t magic = (int32_t)proposed_m;
// 处理负除数
if (d < 0) {
more |= LIBDIVIDE_NEGATIVE_DIVISOR;
if (!branchfree) {
magic = -magic;
}
if (DEBUG) {
std::cout << "[SR] Negative divisor, magic = " << magic << ", more = " << (int)more << std::endl;
}
}
// 为我们的IR生成提取移位量和标志
int shift = more & 0x3F; // 移除标志保留移位量位0-5
bool need_add = (more & LIBDIVIDE_ADD_MARKER) != 0;
bool is_negative = (more & LIBDIVIDE_NEGATIVE_DIVISOR) != 0;
if (DEBUG) {
std::cout << "[SR] Final result: magic = " << magic << ", more = " << (int)more
<< " (0x" << std::hex << (int)more << std::dec << ")" << std::endl;
std::cout << "[SR] Shift = " << shift << ", need_add = " << need_add
<< ", is_negative = " << is_negative << std::endl;
// Test the magic number using the correct libdivide algorithm
std::cout << "[SR] Testing magic number (libdivide algorithm):" << std::endl;
int test_values[] = {1, 7, 37, 100, 999, -1, -7, -37, -100};
for (int test_val : test_values) {
int64_t quotient;
// 实现正确的libdivide算法
int64_t product = (int64_t)test_val * magic;
int64_t high_bits = product >> 32;
if (need_add) {
// ADD_MARKER情况移位前加上被除数
// 这是libdivide的关键洞察
high_bits += test_val;
quotient = high_bits >> shift;
} else {
// 正常情况:只是移位
quotient = high_bits >> shift;
}
// 符号修正这是libdivide有符号除法的关键部分
// 如果被除数为负商需要加1来匹配C语言的截断除法语义
if (test_val < 0) {
quotient += 1;
}
int expected = test_val / divisor;
bool correct = (quotient == expected);
std::cout << "[SR] " << test_val << " / " << divisor << " = " << quotient
<< " (expected " << expected << ") " << (correct ? "" : "") << std::endl;
}
std::cout << "[SR] ===== End magic computation =====" << std::endl;
}
// 返回魔数、移位量并在移位中编码ADD_MARKER标志
// 我们将使用移位的第6位表示ADD_MARKER第7位表示负数如果需要
int encoded_shift = shift;
if (need_add) {
encoded_shift |= 0x40; // 设置第6位表示ADD_MARKER
if (DEBUG) {
std::cout << "[SR] Encoding ADD_MARKER in shift: " << encoded_shift << std::endl;
}
}
return {magic, encoded_shift};
}
};
}// namespace sysy

View File

@ -0,0 +1,39 @@
#pragma once
#include "Pass.h"
#include "Dom.h"
#include "Loop.h"
namespace sysy {
/**
* @class TailCallOpt
* @brief 优化尾调用的中端优化通道。
*
* 该类实现了一个针对函数级别的尾调用优化的优化通道OptimizationPass
* 通过分析和转换 IR中间表示将可优化的尾调用转换为更高效的形式
* 以减少函数调用的开销,提升程序性能。
*
* @note 需要传入 IRBuilder 指针用于 IR 构建和修改。
*
* @method runOnFunction
* 对指定函数进行尾调用优化。
*
* @method getPassID
* 获取当前优化通道的唯一标识符。
*
* @method getAnalysisUsage
* 指定该优化通道所依赖和失效的分析集合。
*/
class TailCallOpt : public OptimizationPass {
private:
IRBuilder* builder;
public:
TailCallOpt(IRBuilder* builder) : OptimizationPass("TailCallOpt", Granularity::Function), builder(builder) {}
static void *ID;
bool runOnFunction(Function *F, AnalysisManager &AM) override;
void *getPassID() const override { return &ID; }
void getAnalysisUsage(std::set<void *> &analysisDependencies, std::set<void *> &analysisInvalidations) const override;
};
} // namespace sysy

View File

@ -15,14 +15,17 @@ add_library(midend_lib STATIC
Pass/Optimize/DCE.cpp
Pass/Optimize/Mem2Reg.cpp
Pass/Optimize/Reg2Mem.cpp
Pass/Optimize/GVN.cpp
Pass/Optimize/SysYIRCFGOpt.cpp
Pass/Optimize/SCCP.cpp
Pass/Optimize/LoopNormalization.cpp
Pass/Optimize/LICM.cpp
Pass/Optimize/LoopStrengthReduction.cpp
Pass/Optimize/InductionVariableElimination.cpp
Pass/Optimize/GlobalStrengthReduction.cpp
Pass/Optimize/BuildCFG.cpp
Pass/Optimize/LargeArrayToGlobal.cpp
Pass/Optimize/TailCallOpt.cpp
)
# 包含中端模块所需的头文件路径

View File

@ -847,7 +847,7 @@ void CondBrInst::print(std::ostream &os) const {
os << "%tmp_cond_" << condName << "_" << uniqueSuffix << " = icmp ne i32 ";
printOperand(os, condition);
os << ", 0\n br i1 %tmp_cond_" << condName << "_" << uniqueSuffix;
os << ", 0\n br i1 %tmp_cond_" << condName << "_" << uniqueSuffix;
os << ", label %";
printBlockName(os, getThenBlock());
@ -886,7 +886,7 @@ void MemsetInst::print(std::ostream &os) const {
// This is done at print time to avoid modifying the IR structure
os << "%tmp_bitcast_" << ptr->getName() << " = bitcast " << *ptr->getType() << " ";
printOperand(os, ptr);
os << " to i8*\n ";
os << " to i8*\n ";
// Now call memset with the bitcast result
os << "call void @llvm.memset.p0i8.i32(i8* %tmp_bitcast_" << ptr->getName() << ", i8 ";

View File

@ -776,38 +776,324 @@ void LoopCharacteristicsPass::findDerivedInductionVars(
}
}
// 递归/推进式判定
bool LoopCharacteristicsPass::isClassicLoopInvariant(Value* val, Loop* loop, const std::unordered_set<Value*>& invariants) {
// 1. 常量
if (auto* constval = dynamic_cast<ConstantValue*>(val)) return true;
// 2. 参数函数参数通常不在任何BasicBlock内直接判定为不变量
if (auto* arg = dynamic_cast<Argument*>(val)) return true;
// 3. 指令且定义在循环外
if (auto* inst = dynamic_cast<Instruction*>(val)) {
if (!loop->contains(inst->getParent()))
return true;
// 4. 跳转 phi指令 副作用 不外提
if (inst->isTerminator() || inst->isPhi() || sideEffectAnalysis->hasSideEffect(inst))
// 检查操作数是否都是不变量
bool LoopCharacteristicsPass::isInvariantOperands(Instruction* inst, Loop* loop, const std::unordered_set<Value*>& invariants) {
for (size_t i = 0; i < inst->getNumOperands(); ++i) {
Value* op = inst->getOperand(i);
if (!isClassicLoopInvariant(op, loop, invariants) && !invariants.count(op)) {
return false;
// 5. 所有操作数都是不变量
for (size_t i = 0; i < inst->getNumOperands(); ++i) {
Value* op = inst->getOperand(i);
if (!isClassicLoopInvariant(op, loop, invariants) && !invariants.count(op))
return false;
}
return true;
}
// 其它情况
return true;
}
// 检查内存位置是否在循环中被修改
bool LoopCharacteristicsPass::isMemoryLocationModifiedInLoop(Value* ptr, Loop* loop) {
// 遍历循环中的所有指令,检查是否有对该内存位置的写入
for (BasicBlock* bb : loop->getBlocks()) {
for (auto& inst : bb->getInstructions()) {
// 1. 检查直接的Store指令
if (auto* storeInst = dynamic_cast<StoreInst*>(inst.get())) {
Value* storeTar = storeInst->getPointer();
// 使用别名分析检查是否可能别名
if (aliasAnalysis) {
auto aliasType = aliasAnalysis->queryAlias(ptr, storeTar);
if (aliasType != AliasType::NO_ALIAS) {
if (DEBUG) {
std::cout << " Memory location " << ptr->getName()
<< " may be modified by store to " << storeTar->getName() << std::endl;
}
return true;
}
} else {
// 如果没有别名分析,保守处理 - 只检查精确匹配
if (ptr == storeTar) {
return true;
}
}
}
// 2. 检查函数调用是否可能修改该内存位置
else if (auto* callInst = dynamic_cast<CallInst*>(inst.get())) {
Function* calledFunc = callInst->getCallee();
// 如果是纯函数,不会修改内存
if (isPureFunction(calledFunc)) {
continue;
}
// 检查函数参数中是否有该内存位置的指针
for (size_t i = 1; i < callInst->getNumOperands(); ++i) { // 跳过函数指针
Value* arg = callInst->getOperand(i);
// 检查参数是否是指针类型且可能指向该内存位置
if (auto* ptrType = dynamic_cast<PointerType*>(arg->getType())) {
// 使用别名分析检查
if (aliasAnalysis) {
auto aliasType = aliasAnalysis->queryAlias(ptr, arg);
if (aliasType != AliasType::NO_ALIAS) {
if (DEBUG) {
std::cout << " Memory location " << ptr->getName()
<< " may be modified by function call " << calledFunc->getName()
<< " through parameter " << arg->getName() << std::endl;
}
return true;
}
} else {
// 没有别名分析,检查精确匹配
if (ptr == arg) {
if (DEBUG) {
std::cout << " Memory location " << ptr->getName()
<< " may be modified by function call " << calledFunc->getName()
<< " (exact match)" << std::endl;
}
return true;
}
}
}
}
}
}
}
return false;
}
bool LoopCharacteristicsPass::hasSimpleMemoryPattern(Loop* loop) {
// 检查是否有简单的内存访问模式
return true; // 暂时简化处理
// 检查内存位置是否在循环中被读取
bool LoopCharacteristicsPass::isMemoryLocationLoadedInLoop(Value* ptr, Loop* loop, Instruction* excludeInst) {
// 遍历循环中的所有Load指令检查是否有对该内存位置的读取
for (BasicBlock* bb : loop->getBlocks()) {
for (auto& inst : bb->getInstructions()) {
if (inst.get() == excludeInst) continue; // 排除当前指令本身
if (auto* loadInst = dynamic_cast<LoadInst*>(inst.get())) {
Value* loadSrc = loadInst->getPointer();
// 使用别名分析检查是否可能别名
if (aliasAnalysis) {
auto aliasType = aliasAnalysis->queryAlias(ptr, loadSrc);
if (aliasType != AliasType::NO_ALIAS) {
return true;
}
} else {
// 如果没有别名分析,保守处理 - 只检查精确匹配
if (ptr == loadSrc) {
return true;
}
}
}
}
}
return false;
}
// 检查函数调用是否为纯函数
bool LoopCharacteristicsPass::isPureFunction(Function* calledFunc) {
if (!calledFunc) return false;
// 使用副作用分析检查函数是否为纯函数
if (sideEffectAnalysis && sideEffectAnalysis->isPureFunction(calledFunc)) {
return true;
}
// 检查是否为内置纯函数(如数学函数)
std::string funcName = calledFunc->getName();
static const std::set<std::string> pureFunctions = {
"abs", "fabs", "sqrt", "sin", "cos", "tan", "exp", "log", "pow",
"floor", "ceil", "round", "min", "max"
};
return pureFunctions.count(funcName) > 0;
}
// 递归/推进式判定 - 完善版本
bool LoopCharacteristicsPass::isClassicLoopInvariant(Value* val, Loop* loop, const std::unordered_set<Value*>& invariants) {
if (DEBUG >= 2) {
std::cout << " Checking loop invariant for: " << val->getName() << std::endl;
}
// 1. 常量
if (auto* constval = dynamic_cast<ConstantValue*>(val)) {
if (DEBUG >= 2) std::cout << " -> Constant: YES" << std::endl;
return true;
}
// 2. 参数函数参数通常不在任何BasicBlock内直接判定为不变量
// 在SSA形式下参数不会被重新赋值
if (auto* arg = dynamic_cast<Argument*>(val)) {
if (DEBUG >= 2) std::cout << " -> Function argument: YES" << std::endl;
return true;
}
// 3. 指令且定义在循环外
if (auto* inst = dynamic_cast<Instruction*>(val)) {
if (!loop->contains(inst->getParent())) {
if (DEBUG >= 2) std::cout << " -> Defined outside loop: YES" << std::endl;
return true;
}
// 4. 跳转指令、phi指令不能外提
if (inst->isTerminator() || inst->isPhi()) {
if (DEBUG >= 2) std::cout << " -> Terminator or PHI: NO" << std::endl;
return false;
}
// 5. 根据指令类型进行具体分析
switch (inst->getKind()) {
case Instruction::Kind::kStore: {
// Store指令检查循环内是否有对该内存的load
auto* storeInst = dynamic_cast<StoreInst*>(inst);
Value* storePtr = storeInst->getPointer();
// 首先检查操作数是否不变
if (!isInvariantOperands(inst, loop, invariants)) {
if (DEBUG >= 2) std::cout << " -> Store: operands not invariant: NO" << std::endl;
return false;
}
// 检查是否有对该内存位置的load
if (isMemoryLocationLoadedInLoop(storePtr, loop, inst)) {
if (DEBUG >= 2) std::cout << " -> Store: memory location loaded in loop: NO" << std::endl;
return false;
}
if (DEBUG >= 2) std::cout << " -> Store: safe to hoist: YES" << std::endl;
return true;
}
case Instruction::Kind::kLoad: {
// Load指令检查循环内是否有对该内存的store
auto* loadInst = dynamic_cast<LoadInst*>(inst);
Value* loadPtr = loadInst->getPointer();
// 首先检查指针操作数是否不变
if (!isInvariantOperands(inst, loop, invariants)) {
if (DEBUG >= 2) std::cout << " -> Load: pointer not invariant: NO" << std::endl;
return false;
}
// 检查是否有对该内存位置的store
if (isMemoryLocationModifiedInLoop(loadPtr, loop)) {
if (DEBUG >= 2) std::cout << " -> Load: memory location modified in loop: NO" << std::endl;
return false;
}
if (DEBUG >= 2) std::cout << " -> Load: safe to hoist: YES" << std::endl;
return true;
}
case Instruction::Kind::kCall: {
// Call指令检查是否为纯函数且参数不变
auto* callInst = dynamic_cast<CallInst*>(inst);
Function* calledFunc = callInst->getCallee();
// 检查是否为纯函数
if (!isPureFunction(calledFunc)) {
if (DEBUG >= 2) std::cout << " -> Call: not pure function: NO" << std::endl;
return false;
}
// 检查参数是否都不变
if (!isInvariantOperands(inst, loop, invariants)) {
if (DEBUG >= 2) std::cout << " -> Call: arguments not invariant: NO" << std::endl;
return false;
}
if (DEBUG >= 2) std::cout << " -> Call: pure function with invariant args: YES" << std::endl;
return true;
}
case Instruction::Kind::kGetElementPtr: {
// GEP指令检查基址和索引是否都不变
if (!isInvariantOperands(inst, loop, invariants)) {
if (DEBUG >= 2) std::cout << " -> GEP: base or indices not invariant: NO" << std::endl;
return false;
}
if (DEBUG >= 2) std::cout << " -> GEP: base and indices invariant: YES" << std::endl;
return true;
}
// 一元运算指令
case Instruction::Kind::kNeg:
case Instruction::Kind::kNot:
case Instruction::Kind::kFNeg:
case Instruction::Kind::kFNot:
case Instruction::Kind::kFtoI:
case Instruction::Kind::kItoF:
case Instruction::Kind::kBitItoF:
case Instruction::Kind::kBitFtoI: {
// 检查操作数是否不变
if (!isInvariantOperands(inst, loop, invariants)) {
if (DEBUG >= 2) std::cout << " -> Unary op: operand not invariant: NO" << std::endl;
return false;
}
if (DEBUG >= 2) std::cout << " -> Unary op: operand invariant: YES" << std::endl;
return true;
}
// 二元运算指令
case Instruction::Kind::kAdd:
case Instruction::Kind::kSub:
case Instruction::Kind::kMul:
case Instruction::Kind::kDiv:
case Instruction::Kind::kRem:
case Instruction::Kind::kSll:
case Instruction::Kind::kSrl:
case Instruction::Kind::kSra:
case Instruction::Kind::kAnd:
case Instruction::Kind::kOr:
case Instruction::Kind::kFAdd:
case Instruction::Kind::kFSub:
case Instruction::Kind::kFMul:
case Instruction::Kind::kFDiv:
case Instruction::Kind::kICmpEQ:
case Instruction::Kind::kICmpNE:
case Instruction::Kind::kICmpLT:
case Instruction::Kind::kICmpGT:
case Instruction::Kind::kICmpLE:
case Instruction::Kind::kICmpGE:
case Instruction::Kind::kFCmpEQ:
case Instruction::Kind::kFCmpNE:
case Instruction::Kind::kFCmpLT:
case Instruction::Kind::kFCmpGT:
case Instruction::Kind::kFCmpLE:
case Instruction::Kind::kFCmpGE:
case Instruction::Kind::kMulh: {
// 检查所有操作数是否不变
if (!isInvariantOperands(inst, loop, invariants)) {
if (DEBUG >= 2) std::cout << " -> Binary op: operands not invariant: NO" << std::endl;
return false;
}
if (DEBUG >= 2) std::cout << " -> Binary op: operands invariant: YES" << std::endl;
return true;
}
default: {
// 其他指令:使用副作用分析
if (sideEffectAnalysis && sideEffectAnalysis->hasSideEffect(inst)) {
if (DEBUG >= 2) std::cout << " -> Other inst: has side effect: NO" << std::endl;
return false;
}
// 检查操作数是否都不变
if (!isInvariantOperands(inst, loop, invariants)) {
if (DEBUG >= 2) std::cout << " -> Other inst: operands not invariant: NO" << std::endl;
return false;
}
if (DEBUG >= 2) std::cout << " -> Other inst: no side effect, operands invariant: YES" << std::endl;
return true;
}
}
}
// 其它情况
if (DEBUG >= 2) std::cout << " -> Other value type: NO" << std::endl;
return false;
}
} // namespace sysy

View File

@ -26,10 +26,21 @@ const SideEffectInfo &SideEffectAnalysisResult::getInstructionSideEffect(Instruc
}
const SideEffectInfo &SideEffectAnalysisResult::getFunctionSideEffect(Function *func) const {
// 首先检查分析过的用户定义函数
auto it = functionSideEffects.find(func);
if (it != functionSideEffects.end()) {
return it->second;
}
// 如果没有找到,检查是否为已知的库函数
if (func) {
std::string funcName = func->getName();
const SideEffectInfo *knownInfo = getKnownFunctionSideEffect(funcName);
if (knownInfo) {
return *knownInfo;
}
}
// 返回默认的无副作用信息
static SideEffectInfo noEffect;
return noEffect;

View File

@ -0,0 +1,492 @@
#include "GVN.h"
#include "Dom.h"
#include "SysYIROptUtils.h"
#include <algorithm>
#include <cassert>
#include <iostream>
#include <unordered_map>
#include <unordered_set>
extern int DEBUG;
namespace sysy {
// GVN 遍的静态 ID
void *GVN::ID = (void *)&GVN::ID;
// ======================================================================
// GVN 类的实现
// ======================================================================
bool GVN::runOnFunction(Function *func, AnalysisManager &AM) {
if (func->getBasicBlocks().empty()) {
return false;
}
if (DEBUG) {
std::cout << "\n=== Running GVN on function: " << func->getName() << " ===" << std::endl;
}
bool changed = false;
GVNContext context;
context.run(func, &AM, changed);
if (DEBUG) {
if (changed) {
std::cout << "GVN: Function " << func->getName() << " was modified" << std::endl;
} else {
std::cout << "GVN: Function " << func->getName() << " was not modified" << std::endl;
}
std::cout << "=== GVN completed for function: " << func->getName() << " ===" << std::endl;
}
return changed;
}
void GVN::getAnalysisUsage(std::set<void *> &analysisDependencies, std::set<void *> &analysisInvalidations) const {
// GVN依赖以下分析
// 1. 支配树分析 - 用于检查指令的支配关系,确保替换的安全性
analysisDependencies.insert(&DominatorTreeAnalysisPass::ID);
// 2. 副作用分析 - 用于判断函数调用是否可以进行GVN
analysisDependencies.insert(&SysYSideEffectAnalysisPass::ID);
// GVN不会使任何分析失效因为
// - GVN只删除冗余计算不改变CFG结构
// - GVN不修改程序的语义只是消除重复计算
// - 支配关系保持不变
// - 副作用分析结果保持不变
// analysisInvalidations 保持为空
if (DEBUG) {
std::cout << "GVN: Declared analysis dependencies (DominatorTree, SideEffectAnalysis)" << std::endl;
}
}
// ======================================================================
// GVNContext 类的实现 - 重构版本
// ======================================================================
// 简单的表达式哈希结构
struct ExpressionKey {
enum Type { BINARY, UNARY, LOAD, GEP, CALL } type;
int opcode;
std::vector<Value*> operands;
Type* resultType;
bool operator==(const ExpressionKey& other) const {
return type == other.type && opcode == other.opcode &&
operands == other.operands && resultType == other.resultType;
}
};
struct ExpressionKeyHash {
size_t operator()(const ExpressionKey& key) const {
size_t hash = std::hash<int>()(static_cast<int>(key.type)) ^
std::hash<int>()(key.opcode);
for (auto op : key.operands) {
hash ^= std::hash<Value*>()(op) + 0x9e3779b9 + (hash << 6) + (hash >> 2);
}
return hash;
}
};
void GVNContext::run(Function *func, AnalysisManager *AM, bool &changed) {
if (DEBUG) {
std::cout << " Starting GVN analysis for function: " << func->getName() << std::endl;
}
// 获取分析结果
if (AM) {
domTree = AM->getAnalysisResult<DominatorTree, DominatorTreeAnalysisPass>(func);
sideEffectAnalysis = AM->getAnalysisResult<SideEffectAnalysisResult, SysYSideEffectAnalysisPass>();
if (DEBUG) {
if (domTree) {
std::cout << " GVN: Using dominator tree analysis" << std::endl;
} else {
std::cout << " GVN: Warning - dominator tree analysis not available" << std::endl;
}
if (sideEffectAnalysis) {
std::cout << " GVN: Using side effect analysis" << std::endl;
} else {
std::cout << " GVN: Warning - side effect analysis not available" << std::endl;
}
}
}
// 清空状态
valueToNumber.clear();
numberToValue.clear();
expressionToNumber.clear();
nextValueNumber = 1;
visited.clear();
rpoBlocks.clear();
needRemove.clear();
// 计算逆后序遍历
computeRPO(func);
if (DEBUG) {
std::cout << " Computed RPO with " << rpoBlocks.size() << " blocks" << std::endl;
}
// 按逆后序遍历基本块进行GVN
int blockCount = 0;
for (auto bb : rpoBlocks) {
if (DEBUG) {
std::cout << " Processing block " << ++blockCount << "/" << rpoBlocks.size()
<< ": " << bb->getName() << std::endl;
}
processBasicBlock(bb, changed);
}
if (DEBUG) {
std::cout << " Found " << needRemove.size() << " redundant instructions to remove" << std::endl;
}
// 删除冗余指令
eliminateRedundantInstructions(changed);
if (DEBUG) {
std::cout << " GVN analysis completed for function: " << func->getName() << std::endl;
std::cout << " Total values numbered: " << valueToNumber.size() << std::endl;
std::cout << " Instructions eliminated: " << needRemove.size() << std::endl;
}
}
void GVNContext::computeRPO(Function *func) {
rpoBlocks.clear();
visited.clear();
auto entry = func->getEntryBlock();
if (entry) {
dfs(entry);
std::reverse(rpoBlocks.begin(), rpoBlocks.end());
}
}
void GVNContext::dfs(BasicBlock *bb) {
if (!bb || visited.count(bb)) {
return;
}
visited.insert(bb);
// 访问所有后继基本块
for (auto succ : bb->getSuccessors()) {
if (visited.find(succ) == visited.end()) {
dfs(succ);
}
}
rpoBlocks.push_back(bb);
}
unsigned GVNContext::getValueNumber(Value* value) {
// 如果已经有值编号,直接返回
auto it = valueToNumber.find(value);
if (it != valueToNumber.end()) {
return it->second;
}
// 为新值分配编号
return assignValueNumber(value);
}
unsigned GVNContext::assignValueNumber(Value* value) {
unsigned number = nextValueNumber++;
valueToNumber[value] = number;
numberToValue[number] = value;
if (DEBUG >= 2) {
std::cout << " Assigned value number " << number
<< " to " << value->getName() << std::endl;
}
return number;
}
void GVNContext::processBasicBlock(BasicBlock* bb, bool& changed) {
int instCount = 0;
for (auto &instPtr : bb->getInstructions()) {
if (DEBUG) {
std::cout << " Processing instruction " << ++instCount
<< ": " << instPtr->getName() << std::endl;
}
if (processInstruction(instPtr.get())) {
changed = true;
}
}
}
bool GVNContext::processInstruction(Instruction* inst) {
// 跳过分支指令和其他不可优化的指令
if (inst->isBranch() || dynamic_cast<ReturnInst*>(inst) ||
dynamic_cast<AllocaInst*>(inst) || dynamic_cast<StoreInst*>(inst)) {
// 如果是store指令需要使相关的内存值失效
if (auto store = dynamic_cast<StoreInst*>(inst)) {
invalidateMemoryValues(store);
}
// 为这些指令分配值编号但不尝试优化
getValueNumber(inst);
return false;
}
if (DEBUG) {
std::cout << " Processing optimizable instruction: " << inst->getName()
<< " (kind: " << static_cast<int>(inst->getKind()) << ")" << std::endl;
}
// 构建表达式键
std::string exprKey = buildExpressionKey(inst);
if (exprKey.empty()) {
// 不可优化的指令,只分配值编号
getValueNumber(inst);
return false;
}
if (DEBUG >= 2) {
std::cout << " Expression key: " << exprKey << std::endl;
}
// 查找已存在的等价值
Value* existing = findExistingValue(exprKey, inst);
if (existing && existing != inst) {
// 检查支配关系
if (auto existingInst = dynamic_cast<Instruction*>(existing)) {
if (dominates(existingInst, inst)) {
if (DEBUG) {
std::cout << " GVN: Replacing " << inst->getName()
<< " with existing " << existing->getName() << std::endl;
}
// 用已存在的值替换当前指令
inst->replaceAllUsesWith(existing);
needRemove.insert(inst);
// 将当前指令的值编号指向已存在的值
unsigned existingNumber = getValueNumber(existing);
valueToNumber[inst] = existingNumber;
return true;
} else {
if (DEBUG) {
std::cout << " Found equivalent but dominance check failed" << std::endl;
}
}
}
}
// 没有找到等价值,为这个表达式分配新的值编号
unsigned number = assignValueNumber(inst);
expressionToNumber[exprKey] = number;
if (DEBUG) {
std::cout << " Instruction " << inst->getName() << " is unique" << std::endl;
}
return false;
}
std::string GVNContext::buildExpressionKey(Instruction* inst) {
std::ostringstream oss;
if (auto binary = dynamic_cast<BinaryInst*>(inst)) {
oss << "binary_" << static_cast<int>(binary->getKind()) << "_";
oss << getValueNumber(binary->getLhs()) << "_" << getValueNumber(binary->getRhs());
// 对于可交换操作,确保操作数顺序一致
if (binary->isCommutative()) {
unsigned lhsNum = getValueNumber(binary->getLhs());
unsigned rhsNum = getValueNumber(binary->getRhs());
if (lhsNum > rhsNum) {
oss.str("");
oss << "binary_" << static_cast<int>(binary->getKind()) << "_";
oss << rhsNum << "_" << lhsNum;
}
}
} else if (auto unary = dynamic_cast<UnaryInst*>(inst)) {
oss << "unary_" << static_cast<int>(unary->getKind()) << "_";
oss << getValueNumber(unary->getOperand());
} else if (auto gep = dynamic_cast<GetElementPtrInst*>(inst)) {
oss << "gep_" << getValueNumber(gep->getBasePointer());
for (unsigned i = 0; i < gep->getNumIndices(); ++i) {
oss << "_" << getValueNumber(gep->getIndex(i));
}
} else if (auto load = dynamic_cast<LoadInst*>(inst)) {
oss << "load_" << getValueNumber(load->getPointer());
oss << "_" << reinterpret_cast<uintptr_t>(load->getType()); // 类型区分
} else if (auto call = dynamic_cast<CallInst*>(inst)) {
// 只为无副作用的函数调用建立表达式
if (sideEffectAnalysis && sideEffectAnalysis->isPureFunction(call->getCallee())) {
oss << "call_" << call->getCallee()->getName();
for (size_t i = 1; i < call->getNumOperands(); ++i) { // 跳过函数指针
oss << "_" << getValueNumber(call->getOperand(i));
}
} else {
return ""; // 有副作用的函数调用不可优化
}
} else {
return ""; // 不支持的指令类型
}
return oss.str();
}
Value* GVNContext::findExistingValue(const std::string& exprKey, Instruction* inst) {
auto it = expressionToNumber.find(exprKey);
if (it != expressionToNumber.end()) {
unsigned number = it->second;
auto valueIt = numberToValue.find(number);
if (valueIt != numberToValue.end()) {
Value* existing = valueIt->second;
// 对于load指令需要额外检查内存安全性
if (auto loadInst = dynamic_cast<LoadInst*>(inst)) {
if (auto existingLoad = dynamic_cast<LoadInst*>(existing)) {
if (!isMemorySafe(existingLoad, loadInst)) {
return nullptr;
}
}
}
return existing;
}
}
return nullptr;
}
bool GVNContext::dominates(Instruction* a, Instruction* b) {
auto aBB = a->getParent();
auto bBB = b->getParent();
// 同一基本块内的情况
if (aBB == bBB) {
auto &insts = aBB->getInstructions();
auto aIt = std::find_if(insts.begin(), insts.end(),
[a](const auto &ptr) { return ptr.get() == a; });
auto bIt = std::find_if(insts.begin(), insts.end(),
[b](const auto &ptr) { return ptr.get() == b; });
if (aIt == insts.end() || bIt == insts.end()) {
return false;
}
return std::distance(insts.begin(), aIt) < std::distance(insts.begin(), bIt);
}
// 不同基本块的情况,使用支配树
if (domTree) {
auto dominators = domTree->getDominators(bBB);
return dominators && dominators->count(aBB);
}
return false; // 保守做法
}
bool GVNContext::isMemorySafe(LoadInst* earlierLoad, LoadInst* laterLoad) {
// 检查两个load是否访问相同的内存位置
unsigned earlierPtr = getValueNumber(earlierLoad->getPointer());
unsigned laterPtr = getValueNumber(laterLoad->getPointer());
if (earlierPtr != laterPtr) {
return false; // 不同的内存位置
}
// 检查类型是否匹配
if (earlierLoad->getType() != laterLoad->getType()) {
return false;
}
// 简单情况如果在同一个基本块且没有中间的store则安全
auto earlierBB = earlierLoad->getParent();
auto laterBB = laterLoad->getParent();
if (earlierBB != laterBB) {
// 跨基本块的情况需要更复杂的分析,暂时保守处理
return false;
}
// 同一基本块内检查是否有中间的store
auto &insts = earlierBB->getInstructions();
auto earlierIt = std::find_if(insts.begin(), insts.end(),
[earlierLoad](const auto &ptr) { return ptr.get() == earlierLoad; });
auto laterIt = std::find_if(insts.begin(), insts.end(),
[laterLoad](const auto &ptr) { return ptr.get() == laterLoad; });
if (earlierIt == insts.end() || laterIt == insts.end()) {
return false;
}
// 确保earlierLoad真的在laterLoad之前
if (std::distance(insts.begin(), earlierIt) >= std::distance(insts.begin(), laterIt)) {
return false;
}
// 检查中间是否有store指令修改了相同的内存位置
for (auto it = std::next(earlierIt); it != laterIt; ++it) {
if (auto store = dynamic_cast<StoreInst*>(it->get())) {
unsigned storePtr = getValueNumber(store->getPointer());
if (storePtr == earlierPtr) {
return false; // 找到中间的store
}
}
// 检查函数调用是否可能修改内存
if (auto call = dynamic_cast<CallInst*>(it->get())) {
if (sideEffectAnalysis && !sideEffectAnalysis->isPureFunction(call->getCallee())) {
// 保守处理:有副作用的函数可能修改内存
return false;
}
}
}
return true; // 安全
}
void GVNContext::invalidateMemoryValues(StoreInst* store) {
unsigned storePtr = getValueNumber(store->getPointer());
if (DEBUG) {
std::cout << " Invalidating memory values affected by store" << std::endl;
}
// 找到所有可能被这个store影响的load表达式
std::vector<std::string> toRemove;
for (auto& [exprKey, number] : expressionToNumber) {
if (exprKey.find("load_" + std::to_string(storePtr)) == 0) {
toRemove.push_back(exprKey);
if (DEBUG) {
std::cout << " Invalidating expression: " << exprKey << std::endl;
}
}
}
// 移除失效的表达式
for (const auto& key : toRemove) {
expressionToNumber.erase(key);
}
}
void GVNContext::eliminateRedundantInstructions(bool& changed) {
int removeCount = 0;
for (auto inst : needRemove) {
if (DEBUG) {
std::cout << " Removing redundant instruction " << ++removeCount
<< "/" << needRemove.size() << ": " << inst->getName() << std::endl;
}
// 删除指令前先断开所有使用关系
// inst->replaceAllUsesWith 已在 processInstruction 中调用
SysYIROptUtils::usedelete(inst);
changed = true;
}
}
} // namespace sysy

View File

@ -0,0 +1,897 @@
#include "GlobalStrengthReduction.h"
#include "SysYIROptUtils.h"
#include "IRBuilder.h"
#include <algorithm>
#include <cassert>
#include <iostream>
#include <cmath>
extern int DEBUG;
namespace sysy {
// 全局强度削弱优化遍的静态 ID
void *GlobalStrengthReduction::ID = (void *)&GlobalStrengthReduction::ID;
// ======================================================================
// GlobalStrengthReduction 类的实现
// ======================================================================
bool GlobalStrengthReduction::runOnFunction(Function *func, AnalysisManager &AM) {
if (func->getBasicBlocks().empty()) {
return false;
}
if (DEBUG) {
std::cout << "\n=== Running GlobalStrengthReduction on function: " << func->getName() << " ===" << std::endl;
}
bool changed = false;
GlobalStrengthReductionContext context(builder);
context.run(func, &AM, changed);
if (DEBUG) {
if (changed) {
std::cout << "GlobalStrengthReduction: Function " << func->getName() << " was modified" << std::endl;
} else {
std::cout << "GlobalStrengthReduction: Function " << func->getName() << " was not modified" << std::endl;
}
std::cout << "=== GlobalStrengthReduction completed for function: " << func->getName() << " ===" << std::endl;
}
return changed;
}
void GlobalStrengthReduction::getAnalysisUsage(std::set<void *> &analysisDependencies, std::set<void *> &analysisInvalidations) const {
// 强度削弱依赖副作用分析来判断指令是否可以安全优化
analysisDependencies.insert(&SysYSideEffectAnalysisPass::ID);
// 强度削弱不会使分析失效,因为:
// - 只替换计算指令,不改变控制流
// - 不修改内存,不影响别名分析
// - 保持程序语义不变
// analysisInvalidations 保持为空
if (DEBUG) {
std::cout << "GlobalStrengthReduction: Declared analysis dependencies (SideEffectAnalysis)" << std::endl;
}
}
// ======================================================================
// GlobalStrengthReductionContext 类的实现
// ======================================================================
void GlobalStrengthReductionContext::run(Function *func, AnalysisManager *AM, bool &changed) {
if (DEBUG) {
std::cout << " Starting GlobalStrengthReduction analysis for function: " << func->getName() << std::endl;
}
// 获取分析结果
if (AM) {
sideEffectAnalysis = AM->getAnalysisResult<SideEffectAnalysisResult, SysYSideEffectAnalysisPass>();
if (DEBUG) {
if (sideEffectAnalysis) {
std::cout << " GlobalStrengthReduction: Using side effect analysis" << std::endl;
} else {
std::cout << " GlobalStrengthReduction: Warning - side effect analysis not available" << std::endl;
}
}
}
// 重置计数器
algebraicOptCount = 0;
strengthReductionCount = 0;
divisionOptCount = 0;
// 遍历所有基本块进行优化
for (auto &bb_ptr : func->getBasicBlocks()) {
if (processBasicBlock(bb_ptr.get())) {
changed = true;
}
}
if (DEBUG) {
std::cout << " GlobalStrengthReduction completed for function: " << func->getName() << std::endl;
std::cout << " Algebraic optimizations: " << algebraicOptCount << std::endl;
std::cout << " Strength reductions: " << strengthReductionCount << std::endl;
std::cout << " Division optimizations: " << divisionOptCount << std::endl;
}
}
bool GlobalStrengthReductionContext::processBasicBlock(BasicBlock *bb) {
bool changed = false;
if (DEBUG) {
std::cout << " Processing block: " << bb->getName() << std::endl;
}
// 收集需要处理的指令(避免迭代器失效)
std::vector<Instruction*> instructions;
for (auto &inst_ptr : bb->getInstructions()) {
instructions.push_back(inst_ptr.get());
}
// 处理每条指令
for (auto inst : instructions) {
if (processInstruction(inst)) {
changed = true;
}
}
return changed;
}
bool GlobalStrengthReductionContext::processInstruction(Instruction *inst) {
if (DEBUG) {
std::cout << " Processing instruction: " << inst->getName() << std::endl;
}
// 先尝试代数优化
if (tryAlgebraicOptimization(inst)) {
algebraicOptCount++;
return true;
}
// 再尝试强度削弱
if (tryStrengthReduction(inst)) {
strengthReductionCount++;
return true;
}
return false;
}
// ======================================================================
// 代数优化方法
// ======================================================================
bool GlobalStrengthReductionContext::tryAlgebraicOptimization(Instruction *inst) {
auto binary = dynamic_cast<BinaryInst*>(inst);
if (!binary) {
return false;
}
switch (binary->getKind()) {
case Instruction::kAdd:
return optimizeAddition(binary);
case Instruction::kSub:
return optimizeSubtraction(binary);
case Instruction::kMul:
return optimizeMultiplication(binary);
case Instruction::kDiv:
return optimizeDivision(binary);
case Instruction::kICmpEQ:
case Instruction::kICmpNE:
case Instruction::kICmpLT:
case Instruction::kICmpGT:
case Instruction::kICmpLE:
case Instruction::kICmpGE:
return optimizeComparison(binary);
case Instruction::kAnd:
case Instruction::kOr:
return optimizeLogical(binary);
default:
return false;
}
}
bool GlobalStrengthReductionContext::optimizeAddition(BinaryInst *inst) {
Value *lhs = inst->getLhs();
Value *rhs = inst->getRhs();
int constVal;
// x + 0 = x
if (isConstantInt(rhs, constVal) && constVal == 0) {
if (DEBUG) {
std::cout << " Algebraic: " << inst->getName() << " = x + 0 -> x" << std::endl;
}
replaceWithOptimized(inst, lhs);
return true;
}
// 0 + x = x
if (isConstantInt(lhs, constVal) && constVal == 0) {
if (DEBUG) {
std::cout << " Algebraic: " << inst->getName() << " = 0 + x -> x" << std::endl;
}
replaceWithOptimized(inst, rhs);
return true;
}
// x + (-y) = x - y
if (auto rhsInst = dynamic_cast<UnaryInst*>(rhs)) {
if (rhsInst->getKind() == Instruction::kNeg) {
if (DEBUG) {
std::cout << " Algebraic: " << inst->getName() << " = x + (-y) -> x - y" << std::endl;
}
// 创建减法指令
builder->setPosition(inst->getParent(), inst->getParent()->findInstIterator(inst));
auto subInst = builder->createSubInst(lhs, rhsInst->getOperand());
replaceWithOptimized(inst, subInst);
return true;
}
}
return false;
}
bool GlobalStrengthReductionContext::optimizeSubtraction(BinaryInst *inst) {
Value *lhs = inst->getLhs();
Value *rhs = inst->getRhs();
int constVal;
// x - 0 = x
if (isConstantInt(rhs, constVal) && constVal == 0) {
if (DEBUG) {
std::cout << " Algebraic: " << inst->getName() << " = x - 0 -> x" << std::endl;
}
replaceWithOptimized(inst, lhs);
return true;
}
// x - x = 0 (如果x没有副作用)
if (lhs == rhs && hasOnlyLocalUses(dynamic_cast<Instruction*>(lhs))) {
if (DEBUG) {
std::cout << " Algebraic: " << inst->getName() << " = x - x -> 0" << std::endl;
}
replaceWithOptimized(inst, getConstantInt(0));
return true;
}
// x - (-y) = x + y
if (auto rhsInst = dynamic_cast<UnaryInst*>(rhs)) {
if (rhsInst->getKind() == Instruction::kNeg) {
if (DEBUG) {
std::cout << " Algebraic: " << inst->getName() << " = x - (-y) -> x + y" << std::endl;
}
builder->setPosition(inst->getParent(), inst->getParent()->findInstIterator(inst));
auto addInst = builder->createAddInst(lhs, rhsInst->getOperand());
replaceWithOptimized(inst, addInst);
return true;
}
}
return false;
}
bool GlobalStrengthReductionContext::optimizeMultiplication(BinaryInst *inst) {
Value *lhs = inst->getLhs();
Value *rhs = inst->getRhs();
int constVal;
// x * 0 = 0
if (isConstantInt(rhs, constVal) && constVal == 0) {
if (DEBUG) {
std::cout << " Algebraic: " << inst->getName() << " = x * 0 -> 0" << std::endl;
}
replaceWithOptimized(inst, getConstantInt(0));
return true;
}
// 0 * x = 0
if (isConstantInt(lhs, constVal) && constVal == 0) {
if (DEBUG) {
std::cout << " Algebraic: " << inst->getName() << " = 0 * x -> 0" << std::endl;
}
replaceWithOptimized(inst, getConstantInt(0));
return true;
}
// x * 1 = x
if (isConstantInt(rhs, constVal) && constVal == 1) {
if (DEBUG) {
std::cout << " Algebraic: " << inst->getName() << " = x * 1 -> x" << std::endl;
}
replaceWithOptimized(inst, lhs);
return true;
}
// 1 * x = x
if (isConstantInt(lhs, constVal) && constVal == 1) {
if (DEBUG) {
std::cout << " Algebraic: " << inst->getName() << " = 1 * x -> x" << std::endl;
}
replaceWithOptimized(inst, rhs);
return true;
}
// x * (-1) = -x
if (isConstantInt(rhs, constVal) && constVal == -1) {
if (DEBUG) {
std::cout << " Algebraic: " << inst->getName() << " = x * (-1) -> -x" << std::endl;
}
builder->setPosition(inst->getParent(), inst->getParent()->findInstIterator(inst));
auto negInst = builder->createNegInst(lhs);
replaceWithOptimized(inst, negInst);
return true;
}
return false;
}
bool GlobalStrengthReductionContext::optimizeDivision(BinaryInst *inst) {
Value *lhs = inst->getLhs();
Value *rhs = inst->getRhs();
int constVal;
// x / 1 = x
if (isConstantInt(rhs, constVal) && constVal == 1) {
if (DEBUG) {
std::cout << " Algebraic: " << inst->getName() << " = x / 1 -> x" << std::endl;
}
replaceWithOptimized(inst, lhs);
return true;
}
// x / (-1) = -x
if (isConstantInt(rhs, constVal) && constVal == -1) {
if (DEBUG) {
std::cout << " Algebraic: " << inst->getName() << " = x / (-1) -> -x" << std::endl;
}
builder->setPosition(inst->getParent(), inst->getParent()->findInstIterator(inst));
auto negInst = builder->createNegInst(lhs);
replaceWithOptimized(inst, negInst);
return true;
}
// x / x = 1 (如果x != 0且没有副作用)
if (lhs == rhs && hasOnlyLocalUses(dynamic_cast<Instruction*>(lhs))) {
if (DEBUG) {
std::cout << " Algebraic: " << inst->getName() << " = x / x -> 1" << std::endl;
}
replaceWithOptimized(inst, getConstantInt(1));
return true;
}
return false;
}
bool GlobalStrengthReductionContext::optimizeComparison(BinaryInst *inst) {
Value *lhs = inst->getLhs();
Value *rhs = inst->getRhs();
// x == x = true (如果x没有副作用)
if (inst->getKind() == Instruction::kICmpEQ && lhs == rhs &&
hasOnlyLocalUses(dynamic_cast<Instruction*>(lhs))) {
if (DEBUG) {
std::cout << " Algebraic: " << inst->getName() << " = x == x -> true" << std::endl;
}
replaceWithOptimized(inst, getConstantInt(1));
return true;
}
// x != x = false (如果x没有副作用)
if (inst->getKind() == Instruction::kICmpNE && lhs == rhs &&
hasOnlyLocalUses(dynamic_cast<Instruction*>(lhs))) {
if (DEBUG) {
std::cout << " Algebraic: " << inst->getName() << " = x != x -> false" << std::endl;
}
replaceWithOptimized(inst, getConstantInt(0));
return true;
}
return false;
}
bool GlobalStrengthReductionContext::optimizeLogical(BinaryInst *inst) {
Value *lhs = inst->getLhs();
Value *rhs = inst->getRhs();
int constVal;
if (inst->getKind() == Instruction::kAnd) {
// x && 0 = 0
if (isConstantInt(rhs, constVal) && constVal == 0) {
if (DEBUG) {
std::cout << " Algebraic: " << inst->getName() << " = x && 0 -> 0" << std::endl;
}
replaceWithOptimized(inst, getConstantInt(0));
return true;
}
// x && -1 = x
if (isConstantInt(rhs, constVal) && constVal == -1) {
if (DEBUG) {
std::cout << " Algebraic: " << inst->getName() << " = x && 1 -> x" << std::endl;
}
replaceWithOptimized(inst, lhs);
return true;
}
// x && x = x
if (lhs == rhs) {
if (DEBUG) {
std::cout << " Algebraic: " << inst->getName() << " = x && x -> x" << std::endl;
}
replaceWithOptimized(inst, lhs);
return true;
}
} else if (inst->getKind() == Instruction::kOr) {
// x || 0 = x
if (isConstantInt(rhs, constVal) && constVal == 0) {
if (DEBUG) {
std::cout << " Algebraic: " << inst->getName() << " = x || 0 -> x" << std::endl;
}
replaceWithOptimized(inst, lhs);
return true;
}
// x || x = x
if (lhs == rhs) {
if (DEBUG) {
std::cout << " Algebraic: " << inst->getName() << " = x || x -> x" << std::endl;
}
replaceWithOptimized(inst, lhs);
return true;
}
}
return false;
}
// ======================================================================
// 强度削弱方法
// ======================================================================
bool GlobalStrengthReductionContext::tryStrengthReduction(Instruction *inst) {
if (auto binary = dynamic_cast<BinaryInst*>(inst)) {
switch (binary->getKind()) {
case Instruction::kMul:
return reduceMultiplication(binary);
case Instruction::kDiv:
return reduceDivision(binary);
default:
return false;
}
} else if (auto call = dynamic_cast<CallInst*>(inst)) {
return reducePower(call);
}
return false;
}
bool GlobalStrengthReductionContext::reduceMultiplication(BinaryInst *inst) {
Value *lhs = inst->getLhs();
Value *rhs = inst->getRhs();
int constVal;
// 尝试右操作数为常数
Value* variable = lhs;
if (isConstantInt(rhs, constVal) && constVal > 0) {
return tryComplexMultiplication(inst, variable, constVal);
}
// 尝试左操作数为常数
if (isConstantInt(lhs, constVal) && constVal > 0) {
variable = rhs;
return tryComplexMultiplication(inst, variable, constVal);
}
return false;
}
bool GlobalStrengthReductionContext::tryComplexMultiplication(BinaryInst* inst, Value* variable, int constant) {
// 首先检查是否为2的幂使用简单位移
if (isPowerOfTwo(constant)) {
int shiftAmount = log2OfPowerOfTwo(constant);
if (DEBUG) {
std::cout << " StrengthReduction: " << inst->getName()
<< " = x * " << constant << " -> x << " << shiftAmount << std::endl;
}
builder->setPosition(inst->getParent(), inst->getParent()->findInstIterator(inst));
auto shiftInst = builder->createBinaryInst(Instruction::kSll, Type::getIntType(), variable, getConstantInt(shiftAmount));
replaceWithOptimized(inst, shiftInst);
return true;
}
// 尝试分解为位移和加法的组合
std::vector<int> shifts;
if (findOptimalShiftDecomposition(constant, shifts)) {
if (DEBUG) {
std::cout << " StrengthReduction: " << inst->getName()
<< " = x * " << constant << " -> shift decomposition with " << shifts.size() << " terms" << std::endl;
}
Value* result = createShiftDecomposition(inst, variable, shifts);
if (result) {
replaceWithOptimized(inst, result);
return true;
}
}
return false;
}
bool GlobalStrengthReductionContext::findOptimalShiftDecomposition(int constant, std::vector<int>& shifts) {
shifts.clear();
// 常见的有效分解模式
switch (constant) {
case 3: // 3 = 2^1 + 2^0 -> (x << 1) + x
shifts = {1, 0};
return true;
case 5: // 5 = 2^2 + 2^0 -> (x << 2) + x
shifts = {2, 0};
return true;
case 6: // 6 = 2^2 + 2^1 -> (x << 2) + (x << 1)
shifts = {2, 1};
return true;
case 7: // 7 = 2^2 + 2^1 + 2^0 -> (x << 2) + (x << 1) + x
shifts = {2, 1, 0};
return true;
case 9: // 9 = 2^3 + 2^0 -> (x << 3) + x
shifts = {3, 0};
return true;
case 10: // 10 = 2^3 + 2^1 -> (x << 3) + (x << 1)
shifts = {3, 1};
return true;
case 11: // 11 = 2^3 + 2^1 + 2^0 -> (x << 3) + (x << 1) + x
shifts = {3, 1, 0};
return true;
case 12: // 12 = 2^3 + 2^2 -> (x << 3) + (x << 2)
shifts = {3, 2};
return true;
case 13: // 13 = 2^3 + 2^2 + 2^0 -> (x << 3) + (x << 2) + x
shifts = {3, 2, 0};
return true;
case 14: // 14 = 2^3 + 2^2 + 2^1 -> (x << 3) + (x << 2) + (x << 1)
shifts = {3, 2, 1};
return true;
case 15: // 15 = 2^3 + 2^2 + 2^1 + 2^0 -> (x << 3) + (x << 2) + (x << 1) + x
shifts = {3, 2, 1, 0};
return true;
case 17: // 17 = 2^4 + 2^0 -> (x << 4) + x
shifts = {4, 0};
return true;
case 18: // 18 = 2^4 + 2^1 -> (x << 4) + (x << 1)
shifts = {4, 1};
return true;
case 20: // 20 = 2^4 + 2^2 -> (x << 4) + (x << 2)
shifts = {4, 2};
return true;
case 24: // 24 = 2^4 + 2^3 -> (x << 4) + (x << 3)
shifts = {4, 3};
return true;
case 25: // 25 = 2^4 + 2^3 + 2^0 -> (x << 4) + (x << 3) + x
shifts = {4, 3, 0};
return true;
case 100: // 100 = 2^6 + 2^5 + 2^2 -> (x << 6) + (x << 5) + (x << 2)
shifts = {6, 5, 2};
return true;
}
// 通用二进制分解最多4个项避免过度复杂化
if (constant > 0 && constant < 256) {
std::vector<int> binaryShifts;
int temp = constant;
int bit = 0;
while (temp > 0 && binaryShifts.size() < 4) {
if (temp & 1) {
binaryShifts.push_back(bit);
}
temp >>= 1;
bit++;
}
// 只有当项数不超过3个时才使用二进制分解比直接乘法更有效
if (binaryShifts.size() <= 3 && binaryShifts.size() >= 2) {
shifts = binaryShifts;
return true;
}
}
return false;
}
Value* GlobalStrengthReductionContext::createShiftDecomposition(BinaryInst* inst, Value* variable, const std::vector<int>& shifts) {
if (shifts.empty()) return nullptr;
builder->setPosition(inst->getParent(), inst->getParent()->findInstIterator(inst));
Value* result = nullptr;
for (int shift : shifts) {
Value* term;
if (shift == 0) {
// 0位移就是原变量
term = variable;
} else {
// 创建位移指令
term = builder->createBinaryInst(Instruction::kSll, Type::getIntType(), variable, getConstantInt(shift));
}
if (result == nullptr) {
result = term;
} else {
// 累加到结果中
result = builder->createAddInst(result, term);
}
}
return result;
}
bool GlobalStrengthReductionContext::reduceDivision(BinaryInst *inst) {
Value *lhs = inst->getLhs();
Value *rhs = inst->getRhs();
uint32_t constVal;
// x / 2^n = x >> n (对于无符号除法或已知为正数的情况)
if (isConstantInt(rhs, constVal) && constVal > 0 && isPowerOfTwo(constVal)) {
builder->setPosition(inst->getParent(), inst->getParent()->findInstIterator(inst));
int shiftAmount = log2OfPowerOfTwo(constVal);
// 有符号除法校正:(x + (x >> 31) & mask) >> k
int maskValue = constVal - 1;
// x >> 31 (算术右移获取符号位)
Value* signShift = ConstantInteger::get(31);
Value* signBits = builder->createBinaryInst(
Instruction::Kind::kSra, // 算术右移
lhs->getType(),
lhs,
signShift
);
// (x >> 31) & mask
Value* mask = ConstantInteger::get(maskValue);
Value* correction = builder->createBinaryInst(
Instruction::Kind::kAnd,
lhs->getType(),
signBits,
mask
);
// x + correction
Value* corrected = builder->createAddInst(lhs, correction);
// (x + correction) >> k
Value* divShift = ConstantInteger::get(shiftAmount);
Value* shiftInst = builder->createBinaryInst(
Instruction::Kind::kSra, // 算术右移
lhs->getType(),
corrected,
divShift
);
if (DEBUG) {
std::cout << " StrengthReduction: " << inst->getName()
<< " = x / " << constVal << " -> (x + (x >> 31) & mask) >> " << shiftAmount << std::endl;
}
// builder->setPosition(inst->getParent(), inst->getParent()->findInstIterator(inst));
// Value* divisor_minus_1 = ConstantInteger::get(constVal - 1);
// Value* adjusted = builder->createAddInst(lhs, divisor_minus_1);
// Value* shiftInst = builder->createBinaryInst(Instruction::kSra, Type::getIntType(), adjusted, getConstantInt(shiftAmount));
replaceWithOptimized(inst, shiftInst);
strengthReductionCount++;
return true;
}
// x / c = x * magic_number (魔数乘法优化 - 使用libdivide算法)
if (isConstantInt(rhs, constVal) && constVal > 1 && constVal != (uint32_t)(-1)) {
// auto magicPair = computeMulhMagicNumbers(static_cast<int>(constVal));
Value* magicResult = createMagicDivisionLibdivide(inst, static_cast<int>(constVal));
replaceWithOptimized(inst, magicResult);
divisionOptCount++;
return true;
}
return false;
}
bool GlobalStrengthReductionContext::reducePower(CallInst *inst) {
// 检查是否是pow函数调用
Function* callee = inst->getCallee();
if (!callee || callee->getName() != "pow") {
return false;
}
// pow(x, 2) = x * x
if (inst->getNumOperands() >= 2) {
int exponent;
if (isConstantInt(inst->getOperand(1), exponent)) {
if (exponent == 2) {
if (DEBUG) {
std::cout << " StrengthReduction: pow(x, 2) -> x * x" << std::endl;
}
Value* base = inst->getOperand(0);
builder->setPosition(inst->getParent(), inst->getParent()->findInstIterator(inst));
auto mulInst = builder->createMulInst(base, base);
replaceWithOptimized(inst, mulInst);
strengthReductionCount++;
return true;
} else if (exponent >= 3 && exponent <= 8) {
// 对于小的指数,展开为连续乘法
if (DEBUG) {
std::cout << " StrengthReduction: pow(x, " << exponent << ") -> repeated multiplication" << std::endl;
}
Value* base = inst->getOperand(0);
Value* result = base;
builder->setPosition(inst->getParent(), inst->getParent()->findInstIterator(inst));
for (int i = 1; i < exponent; i++) {
result = builder->createMulInst(result, base);
}
replaceWithOptimized(inst, result);
strengthReductionCount++;
return true;
}
}
}
return false;
}
Value* GlobalStrengthReductionContext::createMagicDivisionLibdivide(BinaryInst* divInst, int divisor) {
builder->setPosition(divInst->getParent(), divInst->getParent()->findInstIterator(divInst));
// 使用mulh指令优化任意常数除法
auto [magic, shift] = SysYIROptUtils::computeMulhMagicNumbers(divisor);
// 检查是否无法优化magic == -1, shift == -1 表示失败)
if (magic == -1 && shift == -1) {
if (DEBUG) {
std::cout << "[SR] Cannot optimize division by " << divisor
<< ", keeping original division" << std::endl;
}
// 返回 nullptr 表示无法优化,调用方应该保持原始除法
return nullptr;
}
// 2的幂次方除法可以用移位优化但这不是魔数法的情况这种情况应该不会被分类到这里但是还是做一个保护措施
if ((divisor & (divisor - 1)) == 0 && divisor > 0) {
// 是2的幂次方可以用移位
int shift_amount = 0;
int temp = divisor;
while (temp > 1) {
temp >>= 1;
shift_amount++;
}
Value* shiftConstant = ConstantInteger::get(shift_amount);
// 对于有符号除法,需要先加上除数-1然后再移位为了正确处理负数舍入
Value* divisor_minus_1 = ConstantInteger::get(divisor - 1);
Value* adjusted = builder->createAddInst(divInst->getOperand(0), divisor_minus_1);
return builder->createBinaryInst(
Instruction::Kind::kSra, // 算术右移
divInst->getOperand(0)->getType(),
adjusted,
shiftConstant
);
}
// 创建魔数常量
// 检查魔数是否能放入32位如果不能则不进行优化
if (magic > INT32_MAX || magic < INT32_MIN) {
if (DEBUG) {
std::cout << "[SR] Magic number " << magic << " exceeds 32-bit range, skipping optimization" << std::endl;
}
return nullptr; // 无法优化,保持原始除法
}
Value* magicConstant = ConstantInteger::get((int32_t)magic);
// 检查是否需要ADD_MARKER处理加法调整
bool needAdd = (shift & 0x40) != 0;
int actualShift = shift & 0x3F; // 提取真实的移位量
if (DEBUG) {
std::cout << "[SR] IR Generation: magic=" << magic << ", needAdd=" << needAdd
<< ", actualShift=" << actualShift << std::endl;
}
// 执行高位乘法mulh(x, magic)
Value* mulhResult = builder->createBinaryInst(
Instruction::Kind::kMulh, // 高位乘法
divInst->getOperand(0)->getType(),
divInst->getOperand(0),
magicConstant
);
if (needAdd) {
// ADD_MARKER 情况:需要在移位前加上被除数
// 这对应于 libdivide 的加法调整算法
if (DEBUG) {
std::cout << "[SR] Applying ADD_MARKER: adding dividend before shift" << std::endl;
}
mulhResult = builder->createAddInst(mulhResult, divInst->getOperand(0));
}
if (actualShift > 0) {
// 如果需要额外移位
Value* shiftConstant = ConstantInteger::get(actualShift);
mulhResult = builder->createBinaryInst(
Instruction::Kind::kSra, // 算术右移
divInst->getOperand(0)->getType(),
mulhResult,
shiftConstant
);
}
// 标准的有符号除法符号修正如果被除数为负商需要加1
// 这对所有有符号除法都需要,不管是否可能有负数
Value* isNegative = builder->createICmpLTInst(divInst->getOperand(0), ConstantInteger::get(0));
// 将i1转换为i32负数时为1非负数时为0 ICmpLTInst的结果会默认转化为32位
mulhResult = builder->createAddInst(mulhResult, isNegative);
return mulhResult;
}
// ======================================================================
// 辅助方法
// ======================================================================
bool GlobalStrengthReductionContext::isPowerOfTwo(uint32_t n) {
return n > 0 && (n & (n - 1)) == 0;
}
int GlobalStrengthReductionContext::log2OfPowerOfTwo(uint32_t n) {
int result = 0;
while (n > 1) {
n >>= 1;
result++;
}
return result;
}
bool GlobalStrengthReductionContext::isConstantInt(Value* val, int& constVal) {
if (auto constInt = dynamic_cast<ConstantInteger*>(val)) {
constVal = std::get<int>(constInt->getVal());
return true;
}
return false;
}
bool GlobalStrengthReductionContext::isConstantInt(Value* val, uint32_t& constVal) {
if (auto constInt = dynamic_cast<ConstantInteger*>(val)) {
int signedVal = std::get<int>(constInt->getVal());
if (signedVal >= 0) {
constVal = static_cast<uint32_t>(signedVal);
return true;
}
}
return false;
}
ConstantInteger* GlobalStrengthReductionContext::getConstantInt(int val) {
return ConstantInteger::get(val);
}
bool GlobalStrengthReductionContext::hasOnlyLocalUses(Instruction* inst) {
if (!inst) return true;
// 简单检查:如果指令没有副作用,则认为是本地的
if (sideEffectAnalysis) {
auto sideEffect = sideEffectAnalysis->getInstructionSideEffect(inst);
return sideEffect.type == SideEffectType::NO_SIDE_EFFECT;
}
// 没有副作用分析时,保守处理
return !inst->isCall() && !inst->isStore() && !inst->isLoad();
}
void GlobalStrengthReductionContext::replaceWithOptimized(Instruction* original, Value* replacement) {
if (DEBUG) {
std::cout << " Replacing " << original->getName()
<< " with " << replacement->getName() << std::endl;
}
original->replaceAllUsesWith(replacement);
// 如果替换值是新创建的指令,确保它有合适的名字
// if (auto replInst = dynamic_cast<Instruction*>(replacement)) {
// if (replInst->getName().empty()) {
// replInst->setName(original->getName() + "_opt");
// }
// }
// 删除原指令,让调用者处理
SysYIROptUtils::usedelete(original);
}
} // namespace sysy

View File

@ -18,62 +18,214 @@ bool LICMContext::hoistInstructions() {
// 1. 先收集所有可外提指令
std::unordered_set<Instruction *> workSet(chars->invariantInsts.begin(), chars->invariantInsts.end());
if (DEBUG) {
std::cout << "LICM: Found " << workSet.size() << " candidate invariant instructions to hoist:" << std::endl;
for (auto *inst : workSet) {
std::cout << " - " << inst->getName() << " (kind: " << static_cast<int>(inst->getKind())
<< ", in BB: " << inst->getParent()->getName() << ")" << std::endl;
}
}
// 2. 计算每个指令被依赖的次数(入度)
std::unordered_map<Instruction *, int> indegree;
std::unordered_map<Instruction *, std::vector<Instruction *>> dependencies; // 记录依赖关系
std::unordered_map<Instruction *, std::vector<Instruction *>> dependents; // 记录被依赖关系
for (auto *inst : workSet) {
indegree[inst] = 0;
dependencies[inst] = {};
dependents[inst] = {};
}
if (DEBUG) {
std::cout << "LICM: Analyzing dependencies between invariant instructions..." << std::endl;
}
for (auto *inst : workSet) {
for (size_t i = 0; i < inst->getNumOperands(); ++i) {
if (auto *dep = dynamic_cast<Instruction *>(inst->getOperand(i))) {
if (workSet.count(dep)) {
indegree[inst]++;
dependencies[inst].push_back(dep);
dependents[dep].push_back(inst);
if (DEBUG) {
std::cout << " Dependency: " << inst->getName() << " depends on " << dep->getName() << std::endl;
}
}
}
}
}
if (DEBUG) {
std::cout << "LICM: Initial indegree analysis:" << std::endl;
for (auto &[inst, deg] : indegree) {
std::cout << " " << inst->getName() << ": indegree=" << deg;
if (deg > 0) {
std::cout << ", depends on: ";
for (auto *dep : dependencies[inst]) {
std::cout << dep->getName() << " ";
}
}
std::cout << std::endl;
}
}
// 3. Kahn拓扑排序
std::vector<Instruction *> sorted;
std::queue<Instruction *> q;
for (auto &[inst, deg] : indegree) {
if (deg == 0)
q.push(inst);
if (DEBUG) {
std::cout << "LICM: Starting topological sort..." << std::endl;
}
for (auto &[inst, deg] : indegree) {
if (deg == 0) {
q.push(inst);
if (DEBUG) {
std::cout << " Initial zero-indegree instruction: " << inst->getName() << std::endl;
}
}
}
int sortStep = 0;
while (!q.empty()) {
auto *inst = q.front();
q.pop();
sorted.push_back(inst);
for (size_t i = 0; i < inst->getNumOperands(); ++i) {
if (auto *dep = dynamic_cast<Instruction *>(inst->getOperand(i))) {
if (workSet.count(dep)) {
indegree[dep]--;
if (indegree[dep] == 0)
q.push(dep);
if (DEBUG) {
std::cout << " Step " << (++sortStep) << ": Processing " << inst->getName() << std::endl;
}
if (DEBUG) {
std::cout << " Reducing indegree of dependents of " << inst->getName() << std::endl;
}
// 正确的拓扑排序当处理一个指令时应该减少其所有使用者dependents的入度
for (auto *dependent : dependents[inst]) {
indegree[dependent]--;
if (DEBUG) {
std::cout << " Reducing indegree of " << dependent->getName() << " to " << indegree[dependent] << std::endl;
}
if (indegree[dependent] == 0) {
q.push(dependent);
if (DEBUG) {
std::cout << " Adding " << dependent->getName() << " to queue (indegree=0)" << std::endl;
}
}
}
}
// 检查是否全部排序,若未全部排序,说明有环(理论上不会)
// 检查是否全部排序,若未全部排序,打印错误信息
// 这可能是因为存在循环依赖或其他问题导致无法完成拓扑排序
if (sorted.size() != workSet.size()) {
if (DEBUG)
std::cerr << "LICM: Topological sort failed, possible dependency cycle." << std::endl;
if (DEBUG) {
std::cout << "LICM: Topological sort failed! Sorted " << sorted.size()
<< " instructions out of " << workSet.size() << " total." << std::endl;
// 找出未被排序的指令(形成循环依赖的指令)
std::unordered_set<Instruction *> remaining;
for (auto *inst : workSet) {
bool found = false;
for (auto *sortedInst : sorted) {
if (inst == sortedInst) {
found = true;
break;
}
}
if (!found) {
remaining.insert(inst);
}
}
std::cout << "LICM: Instructions involved in dependency cycle:" << std::endl;
for (auto *inst : remaining) {
std::cout << " - " << inst->getName() << " (indegree=" << indegree[inst] << ")" << std::endl;
std::cout << " Dependencies within cycle: ";
for (auto *dep : dependencies[inst]) {
if (remaining.count(dep)) {
std::cout << dep->getName() << " ";
}
}
std::cout << std::endl;
std::cout << " Dependents within cycle: ";
for (auto *dependent : dependents[inst]) {
if (remaining.count(dependent)) {
std::cout << dependent->getName() << " ";
}
}
std::cout << std::endl;
}
// 尝试找出一个具体的循环路径
std::cout << "LICM: Attempting to trace a dependency cycle:" << std::endl;
if (!remaining.empty()) {
auto *start = *remaining.begin();
std::unordered_set<Instruction *> visited;
std::vector<Instruction *> path;
std::function<bool(Instruction *)> findCycle = [&](Instruction *current) -> bool {
if (visited.count(current)) {
// 找到环
auto it = std::find(path.begin(), path.end(), current);
if (it != path.end()) {
std::cout << " Cycle found: ";
for (auto cycleIt = it; cycleIt != path.end(); ++cycleIt) {
std::cout << (*cycleIt)->getName() << " -> ";
}
std::cout << current->getName() << std::endl;
return true;
}
return false;
}
visited.insert(current);
path.push_back(current);
for (auto *dep : dependencies[current]) {
if (remaining.count(dep)) {
if (findCycle(dep)) {
return true;
}
}
}
path.pop_back();
return false;
};
findCycle(start);
}
}
return false;
}
// 4. 按拓扑序外提
if (DEBUG) {
std::cout << "LICM: Successfully completed topological sort. Hoisting instructions in order:" << std::endl;
}
for (auto *inst : sorted) {
if (!inst)
continue;
BasicBlock *parent = inst->getParent();
if (parent && loop->contains(parent)) {
if (DEBUG) {
std::cout << " Hoisting " << inst->getName() << " from " << parent->getName()
<< " to preheader " << preheader->getName() << std::endl;
}
auto sourcePos = parent->findInstIterator(inst);
auto targetPos = preheader->terminator();
parent->moveInst(sourcePos, targetPos, preheader);
changed = true;
}
}
if (DEBUG && changed) {
std::cout << "LICM: Successfully hoisted " << sorted.size() << " invariant instructions" << std::endl;
}
return changed;
}
// ---- LICM Pass Implementation ----

View File

@ -106,187 +106,6 @@ bool StrengthReductionContext::analyzeInductionVariableRange(
return hasNegativePotential;
}
//该实现参考了libdivide的算法
std::pair<int, int> StrengthReductionContext::computeMulhMagicNumbers(int divisor) const {
if (DEBUG) {
std::cout << "\n[SR] ===== Computing magic numbers for divisor " << divisor << " (libdivide algorithm) =====" << std::endl;
}
if (divisor == 0) {
if (DEBUG) std::cout << "[SR] Error: divisor must be != 0" << std::endl;
return {-1, -1};
}
// libdivide 常数
const uint8_t LIBDIVIDE_ADD_MARKER = 0x40;
const uint8_t LIBDIVIDE_NEGATIVE_DIVISOR = 0x80;
// 辅助函数:计算前导零个数
auto count_leading_zeros32 = [](uint32_t val) -> uint32_t {
if (val == 0) return 32;
return __builtin_clz(val);
};
// 辅助函数64位除法返回32位商和余数
auto div_64_32 = [](uint32_t high, uint32_t low, uint32_t divisor, uint32_t* rem) -> uint32_t {
uint64_t dividend = ((uint64_t)high << 32) | low;
uint32_t quotient = dividend / divisor;
*rem = dividend % divisor;
return quotient;
};
if (DEBUG) {
std::cout << "[SR] Input divisor: " << divisor << std::endl;
}
// libdivide_internal_s32_gen 算法实现
int32_t d = divisor;
uint32_t ud = (uint32_t)d;
uint32_t absD = (d < 0) ? -ud : ud;
if (DEBUG) {
std::cout << "[SR] absD = " << absD << std::endl;
}
uint32_t floor_log_2_d = 31 - count_leading_zeros32(absD);
if (DEBUG) {
std::cout << "[SR] floor_log_2_d = " << floor_log_2_d << std::endl;
}
// 检查 absD 是否为2的幂
if ((absD & (absD - 1)) == 0) {
if (DEBUG) {
std::cout << "[SR] " << absD << " 是2的幂使用移位方法" << std::endl;
}
// 对于2的幂我们只使用移位不需要魔数
int shift = floor_log_2_d;
if (d < 0) shift |= 0x80; // 标记负数
if (DEBUG) {
std::cout << "[SR] Power of 2 result: magic=0, shift=" << shift << std::endl;
std::cout << "[SR] ===== End magic computation =====" << std::endl;
}
// 对于我们的目的我们将在IR生成中以不同方式处理2的幂
// 返回特殊标记
return {0, shift};
}
if (DEBUG) {
std::cout << "[SR] " << absD << " is not a power of 2, computing magic number" << std::endl;
}
// 非2的幂除数的魔数计算
uint8_t more;
uint32_t rem, proposed_m;
// 计算 proposed_m = floor(2^(floor_log_2_d + 31) / absD)
proposed_m = div_64_32((uint32_t)1 << (floor_log_2_d - 1), 0, absD, &rem);
const uint32_t e = absD - rem;
if (DEBUG) {
std::cout << "[SR] proposed_m = " << proposed_m << ", rem = " << rem << ", e = " << e << std::endl;
}
// 确定是否需要"加法"版本
const bool branchfree = false; // 使用分支版本
if (!branchfree && e < ((uint32_t)1 << floor_log_2_d)) {
// 这个幂次有效
more = (uint8_t)(floor_log_2_d - 1);
if (DEBUG) {
std::cout << "[SR] Using basic algorithm, shift = " << (int)more << std::endl;
}
} else {
// 我们需要上升一个等级
proposed_m += proposed_m;
const uint32_t twice_rem = rem + rem;
if (twice_rem >= absD || twice_rem < rem) {
proposed_m += 1;
}
more = (uint8_t)(floor_log_2_d | LIBDIVIDE_ADD_MARKER);
if (DEBUG) {
std::cout << "[SR] Using add algorithm, proposed_m = " << proposed_m << ", more = " << (int)more << std::endl;
}
}
proposed_m += 1;
int32_t magic = (int32_t)proposed_m;
// 处理负除数
if (d < 0) {
more |= LIBDIVIDE_NEGATIVE_DIVISOR;
if (!branchfree) {
magic = -magic;
}
if (DEBUG) {
std::cout << "[SR] Negative divisor, magic = " << magic << ", more = " << (int)more << std::endl;
}
}
// 为我们的IR生成提取移位量和标志
int shift = more & 0x3F; // 移除标志保留移位量位0-5
bool need_add = (more & LIBDIVIDE_ADD_MARKER) != 0;
bool is_negative = (more & LIBDIVIDE_NEGATIVE_DIVISOR) != 0;
if (DEBUG) {
std::cout << "[SR] Final result: magic = " << magic << ", more = " << (int)more
<< " (0x" << std::hex << (int)more << std::dec << ")" << std::endl;
std::cout << "[SR] Shift = " << shift << ", need_add = " << need_add
<< ", is_negative = " << is_negative << std::endl;
// Test the magic number using the correct libdivide algorithm
std::cout << "[SR] Testing magic number (libdivide algorithm):" << std::endl;
int test_values[] = {1, 7, 37, 100, 999, -1, -7, -37, -100};
for (int test_val : test_values) {
int64_t quotient;
// 实现正确的libdivide算法
int64_t product = (int64_t)test_val * magic;
int64_t high_bits = product >> 32;
if (need_add) {
// ADD_MARKER情况移位前加上被除数
// 这是libdivide的关键洞察
high_bits += test_val;
quotient = high_bits >> shift;
} else {
// 正常情况:只是移位
quotient = high_bits >> shift;
}
// 符号修正这是libdivide有符号除法的关键部分
// 如果被除数为负商需要加1来匹配C语言的截断除法语义
if (test_val < 0) {
quotient += 1;
}
int expected = test_val / divisor;
bool correct = (quotient == expected);
std::cout << "[SR] " << test_val << " / " << divisor << " = " << quotient
<< " (expected " << expected << ") " << (correct ? "" : "") << std::endl;
}
std::cout << "[SR] ===== End magic computation =====" << std::endl;
}
// 返回魔数、移位量并在移位中编码ADD_MARKER标志
// 我们将使用移位的第6位表示ADD_MARKER第7位表示负数如果需要
int encoded_shift = shift;
if (need_add) {
encoded_shift |= 0x40; // 设置第6位表示ADD_MARKER
if (DEBUG) {
std::cout << "[SR] Encoding ADD_MARKER in shift: " << encoded_shift << std::endl;
}
}
return {magic, encoded_shift};
}
bool LoopStrengthReduction::runOnFunction(Function* F, AnalysisManager& AM) {
if (F->getBasicBlocks().empty()) {
@ -1018,7 +837,7 @@ Value* StrengthReductionContext::generateConstantDivisionReplacement(
IRBuilder* builder
) const {
// 使用mulh指令优化任意常数除法
auto [magic, shift] = computeMulhMagicNumbers(candidate->multiplier);
auto [magic, shift] = SysYIROptUtils::computeMulhMagicNumbers(candidate->multiplier);
// 检查是否无法优化magic == -1, shift == -1 表示失败)
if (magic == -1 && shift == -1) {

View File

@ -0,0 +1,125 @@
#include "TailCallOpt.h"
#include "IR.h"
#include "IRBuilder.h"
#include "SysYIROptUtils.h"
#include <vector>
// #include <iostream>
#include <algorithm>
namespace sysy {
void *TailCallOpt::ID = (void *)&TailCallOpt::ID;
void TailCallOpt::getAnalysisUsage(std::set<void *> &analysisDependencies, std::set<void *> &analysisInvalidations) const {
analysisInvalidations.insert(&DominatorTreeAnalysisPass::ID);
analysisInvalidations.insert(&LoopAnalysisPass::ID);
}
bool TailCallOpt::runOnFunction(Function *F, AnalysisManager &AM) {
std::vector<CallInst *> tailCallInsts;
// 遍历函数的所有基本块
for (auto &bb_ptr : F->getBasicBlocks()) {
auto BB = bb_ptr.get();
if (BB->getInstructions().empty()) continue; // 跳过空基本块
auto term_iter = BB->terminator();
if (term_iter == BB->getInstructions().end()) continue; // 没有终结指令则跳过
auto term = (*term_iter).get();
if (!term || !term->isReturn()) continue; // 不是返回指令则跳过
auto retInst = static_cast<ReturnInst *>(term);
Instruction *prevInst = nullptr;
if (BB->getInstructions().size() > 1) {
auto it = term_iter;
--it; // 获取返回指令前的指令
prevInst = (*it).get();
}
if (!prevInst || !prevInst->isCall()) continue; // 前一条不是调用指令则跳过
auto callInst = static_cast<CallInst *>(prevInst);
// 检查是否为尾递归调用:被调用函数与当前函数相同且返回值与调用结果匹配
if (callInst->getCallee() == F) {
// 对于尾递归,返回值应为调用结果或为 void 类型
if (retInst->getReturnValue() == callInst ||
(retInst->getReturnValue() == nullptr && callInst->getType()->isVoid())) {
tailCallInsts.push_back(callInst);
}
}
}
if (tailCallInsts.empty()) {
return false;
}
// 创建一个新的入口基本块,作为循环的前置块
auto original_entry = F->getEntryBlock();
auto new_entry = F->addBasicBlock("tco.entry." + F->getName());
auto loop_header = F->addBasicBlock("tco.loop_header." + F->getName());
// 将原入口块中的所有指令移动到循环头块
loop_header->getInstructions().splice(loop_header->end(), original_entry->getInstructions());
original_entry->setName("tco.pre_header");
// 为函数参数创建 phi 节点
builder->setPosition(loop_header, loop_header->begin());
std::vector<PhiInst *> phis;
auto original_args = F->getArguments();
for (auto &arg : original_args) {
auto phi = builder->createPhiInst(arg->getType(), {}, {}, "tco.phi."+arg->getName());
phis.push_back(phi);
}
// 用 phi 节点替换所有原始参数的使用
for (size_t i = 0; i < original_args.size(); ++i) {
original_args[i]->replaceAllUsesWith(phis[i]);
}
// 设置 phi 节点的输入值
for (size_t i = 0; i < phis.size(); ++i) {
phis[i]->addIncoming(original_args[i], new_entry);
}
// 连接各个基本块
builder->setPosition(original_entry, original_entry->end());
builder->createUncondBrInst(new_entry);
original_entry->addSuccessor(new_entry);
builder->setPosition(new_entry, new_entry->end());
builder->createUncondBrInst(loop_header);
new_entry->addSuccessor(loop_header);
loop_header->addPredecessor(new_entry);
// 处理每一个尾递归调用
for (auto callInst : tailCallInsts) {
auto tail_call_block = callInst->getParent();
// 收集尾递归调用的参数
auto args_range = callInst->getArguments();
std::vector<Value*> args;
std::transform(args_range.begin(), args_range.end(), std::back_inserter(args),
[](auto& use_ptr){ return use_ptr->getValue(); });
// 用新的参数值更新 phi 节点
for (size_t i = 0; i < phis.size(); ++i) {
phis[i]->addIncoming(args[i], tail_call_block);
}
// 移除原有的调用和返回指令
auto term_iter = tail_call_block->terminator();
SysYIROptUtils::usedelete(term_iter);
auto call_iter = tail_call_block->findInstIterator(callInst);
SysYIROptUtils::usedelete(call_iter);
// 添加跳转回循环头块的分支指令
builder->setPosition(tail_call_block, tail_call_block->end());
builder->createUncondBrInst(loop_header);
tail_call_block->addSuccessor(loop_header);
loop_header->addPredecessor(tail_call_block);
}
return true;
}
} // namespace sysy

View File

@ -10,6 +10,7 @@
#include "DCE.h"
#include "Mem2Reg.h"
#include "Reg2Mem.h"
#include "GVN.h"
#include "SCCP.h"
#include "BuildCFG.h"
#include "LargeArrayToGlobal.h"
@ -17,6 +18,8 @@
#include "LICM.h"
#include "LoopStrengthReduction.h"
#include "InductionVariableElimination.h"
#include "GlobalStrengthReduction.h"
#include "TailCallOpt.h"
#include "Pass.h"
#include <iostream>
#include <queue>
@ -59,6 +62,8 @@ void PassManager::runOptimizationPipeline(Module* moduleIR, IRBuilder* builderIR
// 注册优化遍
registerOptimizationPass<BuildCFG>();
registerOptimizationPass<LargeArrayToGlobalPass>();
registerOptimizationPass<GVN>();
registerOptimizationPass<SysYDelInstAfterBrPass>();
registerOptimizationPass<SysYDelNoPreBLockPass>();
@ -74,7 +79,10 @@ void PassManager::runOptimizationPipeline(Module* moduleIR, IRBuilder* builderIR
registerOptimizationPass<LICM>(builderIR);
registerOptimizationPass<LoopStrengthReduction>(builderIR);
registerOptimizationPass<InductionVariableElimination>();
registerOptimizationPass<GlobalStrengthReduction>(builderIR);
registerOptimizationPass<Reg2Mem>(builderIR);
registerOptimizationPass<TailCallOpt>(builderIR);
registerOptimizationPass<SCCP>(builderIR);
@ -129,6 +137,25 @@ void PassManager::runOptimizationPipeline(Module* moduleIR, IRBuilder* builderIR
printPasses();
}
this->clearPasses();
this->addPass(&GVN::ID);
this->run();
this->clearPasses();
this->addPass(&TailCallOpt::ID);
this->run();
if(DEBUG) {
std::cout << "=== IR After TailCallOpt ===\n";
SysYPrinter printer(moduleIR);
printer.printIR();
}
if(DEBUG) {
std::cout << "=== IR After GVN Optimizations ===\n";
printPasses();
}
this->clearPasses();
this->addPass(&SCCP::ID);
this->run();
@ -141,18 +168,45 @@ void PassManager::runOptimizationPipeline(Module* moduleIR, IRBuilder* builderIR
this->clearPasses();
this->addPass(&LoopNormalizationPass::ID);
this->addPass(&InductionVariableElimination::ID);
this->run();
if(DEBUG) {
std::cout << "=== IR After Loop Normalization, Induction Variable Elimination ===\n";
printPasses();
}
this->clearPasses();
this->addPass(&LICM::ID);
this->run();
if(DEBUG) {
std::cout << "=== IR After LICM ===\n";
printPasses();
}
this->clearPasses();
this->addPass(&LoopStrengthReduction::ID);
this->run();
if(DEBUG) {
std::cout << "=== IR After Loop Normalization, LICM, and Strength Reduction Optimizations ===\n";
std::cout << "=== IR After Loop Normalization, and Strength Reduction Optimizations ===\n";
printPasses();
}
// this->clearPasses();
// this->addPass(&Reg2Mem::ID);
// this->run();
// 全局强度削弱优化,包括代数优化和魔数除法
this->clearPasses();
this->addPass(&GlobalStrengthReduction::ID);
this->run();
if(DEBUG) {
std::cout << "=== IR After Global Strength Reduction Optimizations ===\n";
printPasses();
}
this->clearPasses();
this->addPass(&Reg2Mem::ID);
this->run();
if(DEBUG) {
std::cout << "=== IR After Reg2Mem Optimizations ===\n";

View File

@ -262,12 +262,10 @@ void SysYIRGenerator::compute() {
}
// 弹出BinaryExpStack的表达式
int count = end - begin;
for (int i = 0; i < count; i++) {
while(begin < end) {
BinaryExpStack.pop_back();
}
if (!BinaryExpLenStack.empty()) {
BinaryExpLenStack.back() -= count;
BinaryExpLenStack.back()--;
end--;
}
// 计算后缀表达式