Compare commits
32 Commits
backend-fm
...
midend-m2r
| Author | SHA1 | Date | |
|---|---|---|---|
| d465fb02a5 | |||
| 3c49183280 | |||
| 7af3827098 | |||
| 1ab937961f | |||
| 06b4df79ee | |||
| d79857feb9 | |||
| 91d4a39c9a | |||
| 042b1a5d99 | |||
| 0fdcd0dd69 | |||
| d7bf4b061f | |||
| 937833117e | |||
| 094b4c7c39 | |||
| f4617b357e | |||
| babb576317 | |||
| 0720a622c1 | |||
| ad74e435ba | |||
| acb0302a29 | |||
| 5c34cbc7b8 | |||
| c9a0c700e1 | |||
| b57a3f1953 | |||
| f317010d76 | |||
| 8ca64610eb | |||
| 969a78a088 | |||
| d77aedaf8b | |||
| 8763c0a11a | |||
| d83dc7a2e7 | |||
| e32585fd25 | |||
| c4eb1c3980 | |||
| 5ef01ada90 | |||
| 072cd3e9b5 | |||
| d038884ffb | |||
| 467f2f6b24 |
@ -74,31 +74,30 @@ graph TD
|
||||
- **消除 `fallthrough` 现象**:
|
||||
通过确保所有基本块均以终结指令结尾,消除基本块间的 `fallthrough`,简化了控制流图(CFG)的构建和分析。这一做法提升了编译器整体质量,使中端各类 Pass 的编写和维护更加规范和高效。
|
||||
|
||||
|
||||
### 3.2. 核心优化详解
|
||||
|
||||
编译器的分析和优化被组织成一系列独立的“遍”(Pass)。每个 Pass 都是一个独立的算法模块,对 IR 进行特定的分析或变换。这种设计具有高度的模块化和可扩展性。
|
||||
|
||||
#### 3.2.1. SSA 构建与解构
|
||||
|
||||
- **Mem2Reg (`Mem2Reg.cpp`)**:
|
||||
- **Mem2Reg (`Mem2Reg.cpp`)**:
|
||||
- **目标**: 将对栈内存 (`alloca`) 的 `load`/`store` 操作,提升为对虚拟寄存器的直接操作,并构建 SSA 形式。
|
||||
- **技术**: 该过程是实现 SSA 的关键。它依赖于**支配树 (Dominator Tree)** 分析,通过寻找变量定义块的**支配边界 (Dominance Frontier)** 来确定在何处插入 **Φ (Phi) 函数**。
|
||||
- **实现**: `Mem2RegContext::run` 驱动此过程。首先调用 `isPromotableAlloca` 识别所有仅被 `load`/`store` 使用的标量 `alloca`。然后,`insertPhis` 根据支配边界信息在必要的控制流汇合点插入 `phi` 指令。最后,`renameVariables` 递归地遍历支配树,用一个模拟的值栈来将 `load` 替换为栈顶的 SSA 值,将 `store` 视为对栈的一次 `push` 操作,从而完成重命名。值得一提的是,由于我们在IR生成阶段就将所有alloca指令统一放置在入口块,极大地简化了Mem2Reg遍的实现和支配树分析的计算。
|
||||
|
||||
- **Reg2Mem (`Reg2Mem.cpp`)**:
|
||||
- **Reg2Mem (`Reg2Mem.cpp`)**:
|
||||
- **目标**: 执行 `Mem2Reg` 的逆操作,将程序从 SSA 形式转换回基于内存的表示。这通常是为不支持 SSA 的后端做准备的**SSA解构 (SSA Destruction)** 步骤。
|
||||
- **技术**: 为每个 SSA 值(指令结果、函数参数)在函数入口创建一个 `alloca` 栈槽。然后,在每个 SSA 值的定义点之后插入一个 `store` 将其存入对应的栈槽;在每个使用点之前插入一个 `load` 从栈槽中取出值。
|
||||
- **实现**: `Reg2MemContext::run` 驱动此过程。`allocateMemoryForSSAValues` 为所有需要转换的 SSA 值创建 `alloca` 指令。`rewritePhis` 特殊处理 `phi` 指令,在每个前驱块的末尾插入 `store`。`insertLoadsAndStores` 则处理所有非 `phi` 指令的定义和使用,插入相应的 `store` 和 `load`。虽然
|
||||
|
||||
#### 3.2.2. 常量与死代码优化
|
||||
|
||||
- **SCCP (`SCCP.cpp`)**:
|
||||
- **SCCP (`SCCP.cpp`)**:
|
||||
- **目标**: 稀疏条件常量传播。在编译期计算常量表达式,并利用分支条件为常数的信息来消除死代码,比简单的常量传播更强大。
|
||||
- **技术**: 这是一种基于数据流分析的格理论(Lattice Theory)的优化。它为每个变量维护一个值状态,可能为 `Top` (未定义), `Constant` (某个常量值), 或 `Bottom` (非常量)。同时,它跟踪基本块的可达性,如果一个分支的条件被推断为常量,则其不可达的后继分支在分析中会被直接忽略。
|
||||
- **实现**: `SCCPContext::run` 驱动整个分析过程。它维护一个指令工作列表和一个边工作列表。`ProcessInstruction` 和 `ProcessEdge` 函数交替执行,不断地从 IR 中传播常量和可达性信息,直到达到不动点为止。最后,`PropagateConstants` 和 `SimplifyControlFlow` 将推断出的常量替换到代码中,并移除死块。
|
||||
|
||||
- **DCE (`DCE.cpp`)**:
|
||||
- **DCE (`DCE.cpp`)**:
|
||||
- **目标**: 简单死代码消除。移除那些计算结果对程序输出没有贡献的指令。
|
||||
- **技术**: 采用**标记-清除 (Mark and Sweep)** 算法。从具有副作用的指令(如 `store`, `call`, `return`)开始,反向追溯其操作数,标记所有相关的指令为“活跃”。
|
||||
- **实现**: `DCEContext::run` 实现了此算法。第一次遍历时,通过 `isAlive` 函数识别出具有副作用的“根”指令,然后调用 `addAlive` 递归地将所有依赖的指令加入 `alive_insts` 集合。第二次遍历时,所有未被标记为活跃的指令都将被删除。
|
||||
@ -116,24 +115,18 @@ graph TD
|
||||
|
||||
#### 3.2.4. 其他优化
|
||||
|
||||
- **LargeArrayToGlobal (`LargeArrayToGlobal.cpp`)**:
|
||||
- **目标**: 防止因大型局部数组导致的栈溢出,并可能改善数据局部性。
|
||||
- **技术**: 遍历函数中的 `alloca` 指令,如果通过 `calculateTypeSize` 计算出其分配的内存大小超过一个阈值(如 1024 字节),则将其转换为一个全局变量。
|
||||
- **实现**: `convertAllocaToGlobal` 函数负责创建一个新的 `GlobalValue`,并调用 `replaceAllUsesWith` 将原 `alloca` 的所有使用者重定向到新的全局变量,最后删除原 `alloca` 指令。
|
||||
|
||||
#### 3.3. 核心分析遍
|
||||
#### 3.3. 核心分析遍
|
||||
|
||||
为了为优化遍收集信息,最大程度发掘程序优化潜力,我们目前设计并实现了以下关键的分析遍:
|
||||
|
||||
- **支配树分析 (Dominator Tree Analysis)**:
|
||||
- **技术**: 通过计算每个基本块的支配节点,构建出一棵支配树结构。我们在计算支配节点时采用了**逆后序遍历(RPO, Reverse Post Order)**,以保证数据流分析的收敛速度和正确性。在计算直接支配者(Idom, Immediate Dominator)时,采用了经典的**Lengauer-Tarjan(LT)算法**,该算法以高效的并查集和路径压缩技术著称,能够在线性时间内准确计算出每个基本块的直接支配者关系。
|
||||
- **实现**: `Dom.cpp` 实现了支配树分析。该分析为每个基本块分配其直接支配者,并递归构建整棵支配树。支配树是许多高级优化(尤其是 SSA 形式下的优化)的基础。例如,Mem2Reg 需要依赖支配树来正确插入 Phi 指令,并在变量重命名阶段高效遍历控制流图。此外,循环相关优化(如循环不变量外提)也依赖于支配树信息来识别循环头和循环体的关系。
|
||||
|
||||
- **活跃性分析 (Liveness Analysis)**:
|
||||
- **技术**: 活跃性分析用于确定在程序的某一特定点上,哪些变量的值在未来会被用到。我们采用**经典的不动点迭代算法**,在数据流分析框架下,逆序遍历基本块,迭代计算每个基本块的 `live-in` 和 `live-out` 集合,直到收敛为止。这种方法简单且易于实现,能够满足大多数编译优化的需求。
|
||||
- **未来规划**: 若后续对分析效率有更高要求,可考虑引入如**工作列表算法**或者**转化为基于SSA的图可达性分析**等更高效的算法,以进一步提升大型函数或复杂控制流下的分析性能。
|
||||
- **实现**: `Liveness.cpp` 提供了活跃性分析。该分析采用经典的数据流分析框架,迭代计算每个基本块的 `live-in` 和 `live-out` 集合。活跃性信息是死代码消除(DCE)、寄存器分配等优化的必要前置步骤。通过准确的活跃性分析,可以识别出无用的变量和指令,从而为后续优化遍提供坚实的数据基础。
|
||||
- **支配树分析 (Dominator Tree Analysis)**:
|
||||
- **技术**: 通过计算每个基本块的支配节点,构建出一棵支配树结构。我们在计算支配节点时采用了**逆后序遍历(RPO, Reverse Post Order)**,以保证数据流分析的收敛速度和正确性。在计算直接支配者(Idom, Immediate Dominator)时,采用了经典的**Lengauer-Tarjan(LT)算法**,该算法以高效的并查集和路径压缩技术著称,能够在线性时间内准确计算出每个基本块的直接支配者关系。
|
||||
- **实现**: `Dom.cpp` 实现了支配树分析。该分析为每个基本块分配其直接支配者,并递归构建整棵支配树。支配树是许多高级优化(尤其是 SSA 形式下的优化)的基础。例如,Mem2Reg 需要依赖支配树来正确插入 Phi 指令,并在变量重命名阶段高效遍历控制流图。此外,循环相关优化(如循环不变量外提)也依赖于支配树信息来识别循环头和循环体的关系。
|
||||
|
||||
- **活跃性分析 (Liveness Analysis)**:
|
||||
- **技术**: 活跃性分析用于确定在程序的某一特定点上,哪些变量的值在未来会被用到。我们采用**经典的不动点迭代算法**,在数据流分析框架下,逆序遍历基本块,迭代计算每个基本块的 `live-in` 和 `live-out` 集合,直到收敛为止。这种方法简单且易于实现,能够满足大多数编译优化的需求。
|
||||
- **未来规划**: 若后续对分析效率有更高要求,可考虑引入如**工作列表算法**或者**转化为基于SSA的图可达性分析**等更高效的算法,以进一步提升大型函数或复杂控制流下的分析性能。
|
||||
- **实现**: `Liveness.cpp` 提供了活跃性分析。该分析采用经典的数据流分析框架,迭代计算每个基本块的 `live-in` 和 `live-out` 集合。活跃性信息是死代码消除(DCE)、寄存器分配等优化的必要前置步骤。通过准确的活跃性分析,可以识别出无用的变量和指令,从而为后续优化遍提供坚实的数据基础。
|
||||
|
||||
### 3.4. 未来的规划
|
||||
|
||||
@ -145,6 +138,7 @@ graph TD
|
||||
函数内联能够将简单函数(可能需要收集更多信息)内联到call指令相应位置,减少栈空间相关变动,并且为其他遍发掘优化空间。
|
||||
- **`LLVM IR`格式化**:
|
||||
我们将为所有的IR设计并实现通用的打印器方法,使得IR能够显式化为可编译运行的LLVM IR,通过编排脚本和调用llvm相关工具链,我们能够绕过后端编译运行中间代码,为验证中端正确性提供系统化的方法,同时减轻后端开发bug溯源的压力。
|
||||
|
||||
---
|
||||
|
||||
## 4. 后端技术与优化 (Backend)
|
||||
@ -215,16 +209,16 @@ graph TD
|
||||
end
|
||||
```
|
||||
|
||||
1. **`analyzeLiveness()`**: 对机器指令进行数据流分析,计算出每个虚拟寄存器的活跃范围。
|
||||
2. **`build()`**: 根据活跃性信息构建**冲突图 (Interference Graph)**。如果两个虚拟寄存器同时活跃,则它们冲突,在图中连接一条边。
|
||||
3. **`makeWorklist()`**: 将图节点(虚拟寄存器)根据其度数放入不同的工作列表,为着色做准备。
|
||||
4. **核心着色阶段 (The Loop)**:
|
||||
1. **`analyzeLiveness()`**: 对机器指令进行数据流分析,计算出每个虚拟寄存器的活跃范围。
|
||||
2. **`build()`**: 根据活跃性信息构建**冲突图 (Interference Graph)**。如果两个虚拟寄存器同时活跃,则它们冲突,在图中连接一条边。
|
||||
3. **`makeWorklist()`**: 将图节点(虚拟寄存器)根据其度数放入不同的工作列表,为着色做准备。
|
||||
4. **核心着色阶段 (The Loop)**:
|
||||
- **`simplify()`**: 贪心地移除图中度数小于物理寄存器数量的节点,并将其压入栈中。这些节点保证可以被成功着色。
|
||||
- **`coalesce()`**: 尝试将传送指令 (`MV`) 的源和目标节点合并,以消除这条指令。合并的条件基于 **Briggs** 或 **George** 启发式,以避免使图变得不可着色。
|
||||
- **`freeze()`**: 当一个与传送指令相关的节点无法合并也无法简化时,放弃对该传送指令的合并希望,将其“冻结”为一个普通节点。
|
||||
- **`selectSpill()`**: 当所有节点都无法进行上述操作时(即图中只剩下高度数的节点),必须选择一个节点进行**溢出 (Spill)**,即决定将其存放在内存中。
|
||||
5. **`assignColors()`**: 在所有节点都被处理后,从栈中依次弹出节点,并根据其已着色邻居的颜色,为它选择一个可用的物理寄存器。
|
||||
6. **`rewriteProgram()`**: 如果 `assignColors()` 阶段发现有节点被标记为溢出,此函数会被调用。它会修改机器指令,为溢出的虚拟寄存器插入从内存加载(`lw`/`ld`)和存入内存(`sw`/`sd`)的代码。然后,整个分配过程从步骤 1 重新开始。
|
||||
5. **`assignColors()`**: 在所有节点都被处理后,从栈中依次弹出节点,并根据其已着色邻居的颜色,为它选择一个可用的物理寄存器。
|
||||
6. **`rewriteProgram()`**: 如果 `assignColors()` 阶段发现有节点被标记为溢出,此函数会被调用。它会修改机器指令,为溢出的虚拟寄存器插入从内存加载(`lw`/`ld`)和存入内存(`sw`/`sd`)的代码。然后,整个分配过程从步骤 1 重新开始。
|
||||
|
||||
### 4.4. 后端特定优化
|
||||
|
||||
@ -232,11 +226,11 @@ graph TD
|
||||
|
||||
#### 4.4.1. 指令调度 (Instruction Scheduling)
|
||||
|
||||
- **寄存器分配前调度 (`PreRA_Scheduler.cpp`)**:
|
||||
- **寄存器分配前调度 (`PreRA_Scheduler.cpp`)**:
|
||||
- **目标**: 在寄存器分配前,通过重排指令来提升性能。主要目标是**隐藏加载延迟 (Load Latency)**,即尽早发出 `load` 指令,使其结果能在需要时及时准备好,避免流水线停顿。同时,由于此时使用的是无限的虚拟寄存器,调度器有较大的自由度,但也可能因为过度重排而延长虚拟寄存器的生命周期,从而增加寄存器压力。
|
||||
- **实现**: `scheduleBlock()` 函数会识别出基本块内的调度边界(如 `call` 或终结指令),然后在每个独立的区域内调用 `scheduleRegion()`。当前的实现是一种简化的列表调度,它会优先尝试将加载指令 (`LW`, `LD` 等) 在不违反数据依赖的前提下,尽可能地向前移动。
|
||||
|
||||
- **寄存器分配后调度 (`PostRA_Scheduler.cpp`)**:
|
||||
- **寄存器分配后调度 (`PostRA_Scheduler.cpp`)**:
|
||||
- **目标**: 在寄存器分配完成之后,对指令序列进行最后一轮微调。此阶段调度的主要目标与分配前不同,它旨在解决由寄存器分配过程本身引入的性能问题,例如:
|
||||
- **缓解溢出代价**: 将因溢出(Spill)而产生的 `load` 指令(从栈加载)尽可能地提前,远离其使用点;将 `store` 指令(存入栈)尽可能地推后,远离其定义点。
|
||||
- **消除伪依赖**: 寄存器分配器可能会为两个原本不相关的虚拟寄存器分配同一个物理寄存器,从而引入了虚假的写后读(WAR)或写后写(WAW)依赖。Post-RA 调度可以尝试解开这些伪依赖,为指令重排提供更多自由度。
|
||||
@ -244,7 +238,7 @@ graph TD
|
||||
|
||||
#### 4.4.2. 强度削减 (Strength Reduction)
|
||||
|
||||
- **除法强度削减 (`DivStrengthReduction.cpp`)**:
|
||||
- **除法强度削减 (`DivStrengthReduction.cpp`)**:
|
||||
- **目标**: 将机器指令中昂贵的 `DIV` 或 `DIVW` 指令(当除数为编译期常量时)替换为一系列更快、计算成本更低的指令组合。
|
||||
- **技术**: 基于数论中的**乘法逆元 (Multiplicative Inverse)** 思想。对于一个整数除法 `x / d`,可以找到一个“魔数” `m` 和一个移位数 `s`,使得该除法可以被近似替换为 `(x * m) >> s`。这个过程需要处理复杂的符号、取整和溢出问题。
|
||||
- **实现**: `runOnMachineFunction()` 实现了此优化。它会遍历机器指令,寻找以常量为除数的 `DIV`/`DIVW` 指令。`computeMagic()` 函数负责计算出对应的魔数和移位数。然后,根据除数是 2 的幂、1、-1 还是其他普通数字,生成不同的指令序列,包括 `MULH` (取高位乘积), `SRAI` (算术右移), `ADD`, `SUB` 等,来精确地模拟定点数除法的效果。
|
||||
@ -263,10 +257,10 @@ graph TD
|
||||
|
||||
根据项目中的 `TODO` 列表和源代码分析,当前实现存在一些可改进之处:
|
||||
|
||||
- **寄存器分配**:
|
||||
- **寄存器分配**:
|
||||
- **`CALL` 指令处理**: 当前对 `CALL` 指令的 `use`/`def` 分析不完整,没有将所有调用者保存的寄存器标记为 `def`,这可能导致跨函数调用的值被错误破坏。
|
||||
- **溢出处理**: 当前所有溢出的虚拟寄存器都被简单地映射到同一个物理寄存器 `t6` 上,这会引入大量不必要的 `load`/`store`,并可能导致 `t6` 成为性能瓶颈。
|
||||
- **IR 设计**:
|
||||
- **IR 设计**:
|
||||
- 随着 SSA 的引入,IR 中某些冗余信息(如基本块的 `args` 参数)可以被移除,以简化设计。
|
||||
- **优化**:
|
||||
- 当前的优化主要集中在标量上。可以引入更多面向循环的优化(如循环不变代码外提 LICM、归纳变量分析 IndVar)和过程间优化来进一步提升性能。
|
||||
- **优化**:
|
||||
- 当前的优化主要集中在标量上。可以引入更多面向循环的优化(如循环不变代码外提 LICM、归纳变量分析 IndVar)和过程间优化来进一步提升性能。
|
||||
|
||||
@ -20,18 +20,19 @@ QEMU_RISCV64="qemu-riscv64"
|
||||
|
||||
# --- 初始化变量 ---
|
||||
EXECUTE_MODE=false
|
||||
IR_EXECUTE_MODE=false # 新增
|
||||
IR_EXECUTE_MODE=false
|
||||
CLEAN_MODE=false
|
||||
OPTIMIZE_FLAG=""
|
||||
SYSYC_TIMEOUT=30
|
||||
LLC_TIMEOUT=10 # 新增
|
||||
LLC_TIMEOUT=10
|
||||
GCC_TIMEOUT=10
|
||||
EXEC_TIMEOUT=30
|
||||
MAX_OUTPUT_LINES=20
|
||||
MAX_OUTPUT_CHARS=1000
|
||||
SY_FILES=()
|
||||
PASSED_CASES=0
|
||||
FAILED_CASES_LIST=""
|
||||
INTERRUPTED=false # 新增
|
||||
INTERRUPTED=false
|
||||
|
||||
# =================================================================
|
||||
# --- 函数定义 ---
|
||||
@ -50,22 +51,31 @@ show_help() {
|
||||
echo " -gct N 设置 gcc 交叉编译超时为 N 秒 (默认: 10)。"
|
||||
echo " -et N 设置 qemu 自动化执行超时为 N 秒 (默认: 30)。"
|
||||
echo " -ml N, --max-lines N 当输出对比失败时,最多显示 N 行内容 (默认: 20)。"
|
||||
echo " -mc N, --max-chars N 当输出对比失败时,最多显示 N 个字符 (默认: 1000)。"
|
||||
echo " -h, --help 显示此帮助信息并退出。"
|
||||
echo ""
|
||||
echo "可在任何时候按 Ctrl+C 来中断测试并显示当前已完成的测例总结。"
|
||||
}
|
||||
|
||||
# 显示文件内容并根据行数和字符数截断的函数
|
||||
display_file_content() {
|
||||
local file_path="$1"
|
||||
local title="$2"
|
||||
local max_lines="$3"
|
||||
local max_chars="$4" # 新增参数
|
||||
if [ ! -f "$file_path" ]; then return; fi
|
||||
echo -e "$title"
|
||||
local line_count
|
||||
local char_count
|
||||
line_count=$(wc -l < "$file_path")
|
||||
char_count=$(wc -c < "$file_path")
|
||||
|
||||
if [ "$line_count" -gt "$max_lines" ]; then
|
||||
head -n "$max_lines" "$file_path"
|
||||
echo -e "\e[33m[... 输出已截断,共 ${line_count} 行 ...]\e[0m"
|
||||
echo -e "\e[33m[... 输出因行数过多 (共 ${line_count} 行) 而截断 ...]\e[0m"
|
||||
elif [ "$char_count" -gt "$max_chars" ]; then
|
||||
head -c "$max_chars" "$file_path"
|
||||
echo -e "\n\e[33m[... 输出因字符数过多 (共 ${char_count} 字符) 而截断 ...]\e[0m"
|
||||
else
|
||||
cat "$file_path"
|
||||
fi
|
||||
@ -131,6 +141,7 @@ while [[ "$#" -gt 0 ]]; do
|
||||
-gct) if [[ -n "$2" && "$2" =~ ^[0-9]+$ ]]; then GCC_TIMEOUT="$2"; shift 2; else echo "错误: -gct 需要一个正整数参数。" >&2; exit 1; fi ;;
|
||||
-et) if [[ -n "$2" && "$2" =~ ^[0-9]+$ ]]; then EXEC_TIMEOUT="$2"; shift 2; else echo "错误: -et 需要一个正整数参数。" >&2; exit 1; fi ;;
|
||||
-ml|--max-lines) if [[ -n "$2" && "$2" =~ ^[0-9]+$ ]]; then MAX_OUTPUT_LINES="$2"; shift 2; else echo "错误: --max-lines 需要一个正整数参数。" >&2; exit 1; fi ;;
|
||||
-mc|--max-chars) if [[ -n "$2" && "$2" =~ ^[0-9]+$ ]]; then MAX_OUTPUT_CHARS="$2"; shift 2; else echo "错误: --max-chars 需要一个正整数参数。" >&2; exit 1; fi ;;
|
||||
-h|--help) show_help; exit 0 ;;
|
||||
-*) echo "未知选项: $1"; show_help; exit 1 ;;
|
||||
*)
|
||||
@ -180,6 +191,8 @@ TOTAL_CASES=${#SY_FILES[@]}
|
||||
echo "SysY 单例测试运行器启动..."
|
||||
if [ -n "$OPTIMIZE_FLAG" ]; then echo "优化等级: ${OPTIMIZE_FLAG}"; fi
|
||||
echo "超时设置: sysyc=${SYSYC_TIMEOUT}s, llc=${LLC_TIMEOUT}s, gcc=${GCC_TIMEOUT}s, qemu=${EXEC_TIMEOUT}s"
|
||||
echo "失败输出最大行数: ${MAX_OUTPUT_LINES}"
|
||||
echo "失败输出最大字符数: ${MAX_OUTPUT_CHARS}"
|
||||
echo ""
|
||||
|
||||
for sy_file in "${SY_FILES[@]}"; do
|
||||
@ -260,8 +273,8 @@ for sy_file in "${SY_FILES[@]}"; do
|
||||
out_ok=1
|
||||
if ! diff -q <(tr -d '[:space:]' < "${output_actual_file}") <(tr -d '[:space:]' < "${EXPECTED_STDOUT_FILE}") >/dev/null 2>&1; then
|
||||
echo -e "\e[31m 标准输出测试失败。\e[0m"; out_ok=0
|
||||
display_file_content "${EXPECTED_STDOUT_FILE}" " \e[36m--- 期望输出 ---\e[0m" "${MAX_OUTPUT_LINES}"
|
||||
display_file_content "${output_actual_file}" " \e[36m--- 实际输出 ---\e[0m" "${MAX_OUTPUT_LINES}"
|
||||
display_file_content "${EXPECTED_STDOUT_FILE}" " \e[36m--- 期望输出 ---\e[0m" "${MAX_OUTPUT_LINES}" "${MAX_OUTPUT_CHARS}"
|
||||
display_file_content "${output_actual_file}" " \e[36m--- 实际输出 ---\e[0m" "${MAX_OUTPUT_LINES}" "${MAX_OUTPUT_CHARS}"
|
||||
fi
|
||||
|
||||
if [ "$ret_ok" -eq 1 ] && [ "$out_ok" -eq 1 ]; then echo -e "\e[32m 返回码与标准输出测试成功。\e[0m"; else is_passed=0; fi
|
||||
@ -271,8 +284,8 @@ for sy_file in "${SY_FILES[@]}"; do
|
||||
echo -e "\e[32m 标准输出测试成功。\e[0m"
|
||||
else
|
||||
echo -e "\e[31m 标准输出测试失败。\e[0m"; is_passed=0
|
||||
display_file_content "${output_reference_file}" " \e[36m--- 期望输出 ---\e[0m" "${MAX_OUTPUT_LINES}"
|
||||
display_file_content "${output_actual_file}" " \e[36m--- 实际输出 ---\e[0m" "${MAX_OUTPUT_LINES}"
|
||||
display_file_content "${output_reference_file}" " \e[36m--- 期望输出 ---\e[0m" "${MAX_OUTPUT_LINES}" "${MAX_OUTPUT_CHARS}"
|
||||
display_file_content "${output_actual_file}" " \e[36m--- 实际输出 ---\e[0m" "${MAX_OUTPUT_LINES}" "${MAX_OUTPUT_CHARS}"
|
||||
fi
|
||||
fi
|
||||
else
|
||||
@ -301,4 +314,4 @@ for sy_file in "${SY_FILES[@]}"; do
|
||||
done
|
||||
|
||||
# --- 打印最终总结 ---
|
||||
print_summary
|
||||
print_summary
|
||||
|
||||
@ -27,11 +27,12 @@ LLC_TIMEOUT=10
|
||||
GCC_TIMEOUT=10
|
||||
EXEC_TIMEOUT=30
|
||||
MAX_OUTPUT_LINES=20
|
||||
MAX_OUTPUT_CHARS=1000
|
||||
TEST_SETS=()
|
||||
TOTAL_CASES=0
|
||||
PASSED_CASES=0
|
||||
FAILED_CASES_LIST=""
|
||||
INTERRUPTED=false # 新增:用于标记是否被中断
|
||||
INTERRUPTED=false
|
||||
|
||||
# =================================================================
|
||||
# --- 函数定义 ---
|
||||
@ -53,6 +54,7 @@ show_help() {
|
||||
echo " -gct N 设置 gcc 交叉编译超时为 N 秒 (默认: 10)。"
|
||||
echo " -et N 设置 qemu 执行超时为 N 秒 (默认: 30)。"
|
||||
echo " -ml N, --max-lines N 当输出对比失败时,最多显示 N 行内容 (默认: 20)。"
|
||||
echo " -mc N, --max-chars N 当输出对比失败时,最多显示 N 个字符 (默认: 1000)。"
|
||||
echo " -h, --help 显示此帮助信息并退出。"
|
||||
echo ""
|
||||
echo "注意: 默认行为 (无 -e 或 -eir) 是将 .sy 文件同时编译为 .s (汇编) 和 .ll (IR),不执行。"
|
||||
@ -60,18 +62,25 @@ show_help() {
|
||||
}
|
||||
|
||||
|
||||
# 显示文件内容并根据行数截断的函数
|
||||
# 显示文件内容并根据行数和字符数截断的函数
|
||||
display_file_content() {
|
||||
local file_path="$1"
|
||||
local title="$2"
|
||||
local max_lines="$3"
|
||||
local max_chars="$4" # 新增参数
|
||||
if [ ! -f "$file_path" ]; then return; fi
|
||||
echo -e "$title"
|
||||
local line_count
|
||||
local char_count
|
||||
line_count=$(wc -l < "$file_path")
|
||||
char_count=$(wc -c < "$file_path")
|
||||
|
||||
if [ "$line_count" -gt "$max_lines" ]; then
|
||||
head -n "$max_lines" "$file_path"
|
||||
echo -e "\e[33m[... 输出已截断,共 ${line_count} 行 ...]\e[0m"
|
||||
echo -e "\e[33m[... 输出因行数过多 (共 ${line_count} 行) 而截断 ...]\e[0m"
|
||||
elif [ "$char_count" -gt "$max_chars" ]; then
|
||||
head -c "$max_chars" "$file_path"
|
||||
echo -e "\n\e[33m[... 输出因字符数过多 (共 ${char_count} 字符) 而截断 ...]\e[0m"
|
||||
else
|
||||
cat "$file_path"
|
||||
fi
|
||||
@ -151,6 +160,7 @@ while [[ "$#" -gt 0 ]]; do
|
||||
-gct) if [[ -n "$2" && "$2" =~ ^[0-9]+$ ]]; then GCC_TIMEOUT="$2"; shift 2; else echo "错误: -gct 需要一个正整数参数。" >&2; exit 1; fi ;;
|
||||
-et) if [[ -n "$2" && "$2" =~ ^[0-9]+$ ]]; then EXEC_TIMEOUT="$2"; shift 2; else echo "错误: -et 需要一个正整数参数。" >&2; exit 1; fi ;;
|
||||
-ml|--max-lines) if [[ -n "$2" && "$2" =~ ^[0-9]+$ ]]; then MAX_OUTPUT_LINES="$2"; shift 2; else echo "错误: --max-lines 需要一个正整数参数。" >&2; exit 1; fi ;;
|
||||
-mc|--max-chars) if [[ -n "$2" && "$2" =~ ^[0-9]+$ ]]; then MAX_OUTPUT_CHARS="$2"; shift 2; else echo "错误: --max-chars 需要一个正整数参数。" >&2; exit 1; fi ;;
|
||||
-h|--help) show_help; exit 0 ;;
|
||||
*) echo "未知选项: $1"; show_help; exit 1 ;;
|
||||
esac
|
||||
@ -204,6 +214,7 @@ echo "运行模式: ${RUN_MODE_INFO}"
|
||||
echo "${TIMEOUT_INFO}"
|
||||
if ${EXECUTE_MODE} || ${IR_EXECUTE_MODE}; then
|
||||
echo "失败输出最大行数: ${MAX_OUTPUT_LINES}"
|
||||
echo "失败输出最大字符数: ${MAX_OUTPUT_CHARS}"
|
||||
fi
|
||||
echo ""
|
||||
|
||||
@ -298,8 +309,8 @@ while IFS= read -r sy_file; do
|
||||
[ "$test_logic_passed" -eq 1 ] && echo -e "\e[32m 标准输出测试成功\e[0m"
|
||||
else
|
||||
echo -e "\e[31m 标准输出测试失败\e[0m"
|
||||
display_file_content "${EXPECTED_STDOUT_FILE}" " \e[36m---------- 期望输出 ----------\e[0m" "${MAX_OUTPUT_LINES}"
|
||||
display_file_content "${output_actual_file_from_ir}" " \e[36m---------- 实际输出 ----------\e[0m" "${MAX_OUTPUT_LINES}"
|
||||
display_file_content "${EXPECTED_STDOUT_FILE}" " \e[36m---------- 期望输出 ----------\e[0m" "${MAX_OUTPUT_LINES}" "${MAX_OUTPUT_CHARS}"
|
||||
display_file_content "${output_actual_file_from_ir}" " \e[36m---------- 实际输出 ----------\e[0m" "${MAX_OUTPUT_LINES}" "${MAX_OUTPUT_CHARS}"
|
||||
test_logic_passed=0
|
||||
fi
|
||||
else
|
||||
@ -308,8 +319,8 @@ while IFS= read -r sy_file; do
|
||||
echo -e "\e[32m 成功: 输出与参考输出匹配\e[0m"
|
||||
else
|
||||
echo -e "\e[31m 失败: 输出不匹配\e[0m"
|
||||
display_file_content "${output_reference_file}" " \e[36m---------- 期望输出 ----------\e[0m" "${MAX_OUTPUT_LINES}"
|
||||
display_file_content "${output_actual_file_from_ir}" " \e[36m---------- 实际输出 ----------\e[0m" "${MAX_OUTPUT_LINES}"
|
||||
display_file_content "${output_reference_file}" " \e[36m---------- 期望输出 ----------\e[0m" "${MAX_OUTPUT_LINES}" "${MAX_OUTPUT_CHARS}"
|
||||
display_file_content "${output_actual_file_from_ir}" " \e[36m---------- 实际输出 ----------\e[0m" "${MAX_OUTPUT_LINES}" "${MAX_OUTPUT_CHARS}"
|
||||
test_logic_passed=0
|
||||
fi
|
||||
fi
|
||||
@ -375,8 +386,8 @@ while IFS= read -r sy_file; do
|
||||
[ "$test_logic_passed" -eq 1 ] && echo -e "\e[32m 标准输出测试成功\e[0m"
|
||||
else
|
||||
echo -e "\e[31m 标准输出测试失败\e[0m"
|
||||
display_file_content "${EXPECTED_STDOUT_FILE}" " \e[36m---------- 期望输出 ----------\e[0m" "${MAX_OUTPUT_LINES}"
|
||||
display_file_content "${output_actual_file_S}" " \e[36m---------- 实际输出 ----------\e[0m" "${MAX_OUTPUT_LINES}"
|
||||
display_file_content "${EXPECTED_STDOUT_FILE}" " \e[36m---------- 期望输出 ----------\e[0m" "${MAX_OUTPUT_LINES}" "${MAX_OUTPUT_CHARS}"
|
||||
display_file_content "${output_actual_file_S}" " \e[36m---------- 实际输出 ----------\e[0m" "${MAX_OUTPUT_LINES}" "${MAX_OUTPUT_CHARS}"
|
||||
test_logic_passed=0
|
||||
fi
|
||||
else
|
||||
@ -385,8 +396,8 @@ while IFS= read -r sy_file; do
|
||||
echo -e "\e[32m 成功: 输出与参考输出匹配\e[0m"
|
||||
else
|
||||
echo -e "\e[31m 失败: 输出不匹配\e[0m"
|
||||
display_file_content "${output_reference_file}" " \e[36m---------- 期望输出 ----------\e[0m" "${MAX_OUTPUT_LINES}"
|
||||
display_file_content "${output_actual_file_S}" " \e[36m---------- 实际输出 ----------\e[0m" "${MAX_OUTPUT_LINES}"
|
||||
display_file_content "${output_reference_file}" " \e[36m---------- 期望输出 ----------\e[0m" "${MAX_OUTPUT_LINES}" "${MAX_OUTPUT_CHARS}"
|
||||
display_file_content "${output_actual_file_S}" " \e[36m---------- 实际输出 ----------\e[0m" "${MAX_OUTPUT_LINES}" "${MAX_OUTPUT_CHARS}"
|
||||
test_logic_passed=0
|
||||
fi
|
||||
fi
|
||||
|
||||
@ -6,6 +6,7 @@ add_library(riscv64_backend_lib STATIC
|
||||
RISCv64LLIR.cpp
|
||||
RISCv64RegAlloc.cpp
|
||||
RISCv64LinearScan.cpp
|
||||
RISCv64SimpleRegAlloc.cpp
|
||||
RISCv64BasicBlockAlloc.cpp
|
||||
Handler/CalleeSavedHandler.cpp
|
||||
Handler/LegalizeImmediates.cpp
|
||||
|
||||
@ -95,7 +95,7 @@ void LegalizeImmediatesPass::runOnMachineFunction(MachineFunction* mfunc) {
|
||||
case RVOpcodes::LB: case RVOpcodes::LH: case RVOpcodes::LW: case RVOpcodes::LD:
|
||||
case RVOpcodes::LBU: case RVOpcodes::LHU: case RVOpcodes::LWU:
|
||||
case RVOpcodes::SB: case RVOpcodes::SH: case RVOpcodes::SW: case RVOpcodes::SD:
|
||||
case RVOpcodes::FLW: case RVOpcodes::FSW: {
|
||||
case RVOpcodes::FLW: case RVOpcodes::FSW: case RVOpcodes::FLD: case RVOpcodes::FSD: {
|
||||
auto& operands = instr_ptr->getOperands();
|
||||
auto mem_op = static_cast<MemOperand*>(operands.back().get());
|
||||
auto offset_op = mem_op->getOffset();
|
||||
|
||||
@ -185,7 +185,7 @@ void RISCv64AsmPrinter::printOperand(MachineOperand* op) {
|
||||
}
|
||||
}
|
||||
|
||||
std::string RISCv64AsmPrinter::regToString(PhysicalReg reg) {
|
||||
std::string RISCv64AsmPrinter::regToString(PhysicalReg reg) const {
|
||||
switch (reg) {
|
||||
case PhysicalReg::ZERO: return "x0"; case PhysicalReg::RA: return "ra";
|
||||
case PhysicalReg::SP: return "sp"; case PhysicalReg::GP: return "gp";
|
||||
|
||||
@ -2,12 +2,16 @@
|
||||
#include "RISCv64ISel.h"
|
||||
#include "RISCv64RegAlloc.h"
|
||||
#include "RISCv64LinearScan.h"
|
||||
#include "RISCv64SimpleRegAlloc.h"
|
||||
#include "RISCv64BasicBlockAlloc.h"
|
||||
#include "RISCv64AsmPrinter.h"
|
||||
#include "RISCv64Passes.h"
|
||||
#include <sstream>
|
||||
#include <future>
|
||||
#include <chrono>
|
||||
#include <atomic>
|
||||
#include <memory>
|
||||
#include <thread>
|
||||
#include <iostream>
|
||||
namespace sysy {
|
||||
|
||||
@ -197,8 +201,6 @@ std::string RISCv64CodeGen::module_gen() {
|
||||
}
|
||||
|
||||
std::string RISCv64CodeGen::function_gen(Function* func) {
|
||||
// === 完整的后端处理流水线 ===
|
||||
|
||||
// 阶段 1: 指令选择 (sysy::IR -> LLIR with virtual registers)
|
||||
RISCv64ISel isel;
|
||||
std::unique_ptr<MachineFunction> mfunc = isel.runOnFunction(func);
|
||||
@ -206,14 +208,13 @@ std::string RISCv64CodeGen::function_gen(Function* func) {
|
||||
std::stringstream ss_after_isel;
|
||||
RISCv64AsmPrinter printer_isel(mfunc.get());
|
||||
printer_isel.run(ss_after_isel, true);
|
||||
|
||||
// DEBUG = 1;
|
||||
if (DEBUG) {
|
||||
std::cerr << "====== Intermediate Representation after Instruction Selection ======\n"
|
||||
<< ss_after_isel.str();
|
||||
}
|
||||
|
||||
// DEBUG = 0;
|
||||
// 阶段 2: 消除帧索引 (展开伪指令,计算局部变量偏移)
|
||||
// 这个Pass必须在寄存器分配之前运行
|
||||
EliminateFrameIndicesPass efi_pass;
|
||||
efi_pass.runOnMachineFunction(mfunc.get());
|
||||
|
||||
@ -226,79 +227,106 @@ std::string RISCv64CodeGen::function_gen(Function* func) {
|
||||
<< ss_after_eli.str();
|
||||
}
|
||||
|
||||
// // 阶段 2: 除法强度削弱优化 (Division Strength Reduction)
|
||||
// DivStrengthReduction div_strength_reduction;
|
||||
// div_strength_reduction.runOnMachineFunction(mfunc.get());
|
||||
// 阶段 2.1: 除法强度削弱优化 (Division Strength Reduction)
|
||||
DivStrengthReduction div_strength_reduction;
|
||||
div_strength_reduction.runOnMachineFunction(mfunc.get());
|
||||
|
||||
// // 阶段 2.1: 指令调度 (Instruction Scheduling)
|
||||
// // 阶段 2.2: 指令调度 (Instruction Scheduling)
|
||||
// PreRA_Scheduler scheduler;
|
||||
// scheduler.runOnMachineFunction(mfunc.get());
|
||||
|
||||
// 阶段 3: 物理寄存器分配 (Register Allocation)
|
||||
bool allocation_succeeded = false;
|
||||
|
||||
// 首先尝试图着色分配器
|
||||
if (DEBUG) std::cerr << "Attempting Register Allocation with Graph Coloring...\n";
|
||||
if (!gc_failed) {
|
||||
RISCv64RegAlloc gc_alloc(mfunc.get());
|
||||
|
||||
bool success_gc = gc_alloc.run();
|
||||
|
||||
if (!success_gc) {
|
||||
gc_failed = 1; // 后续不再尝试图着色分配器
|
||||
std::cerr << "Warning: Graph coloring register allocation failed function '"
|
||||
<< func->getName()
|
||||
<< "'. Switching to Linear Scan allocator."
|
||||
<< std::endl;
|
||||
|
||||
RISCv64ISel isel_gc_fallback;
|
||||
mfunc = isel_gc_fallback.runOnFunction(func);
|
||||
EliminateFrameIndicesPass efi_pass_gc_fallback;
|
||||
efi_pass_gc_fallback.runOnMachineFunction(mfunc.get());
|
||||
RISCv64LinearScan ls_alloc(mfunc.get());
|
||||
bool success = ls_alloc.run();
|
||||
if (!success) {
|
||||
// 如果线性扫描最终失败,则调用基本块分配器作为终极后备
|
||||
std::cerr << "Info: Linear Scan failed. Switching to Basic Block Allocator as final fallback.\n";
|
||||
|
||||
// 注意:我们需要在一个“干净”的MachineFunction上运行。
|
||||
// 最安全的方式是重新运行指令选择。
|
||||
RISCv64ISel isel_fallback;
|
||||
mfunc = isel_fallback.runOnFunction(func);
|
||||
EliminateFrameIndicesPass efi_pass_fallback;
|
||||
efi_pass_fallback.runOnMachineFunction(mfunc.get());
|
||||
if (DEBUG) {
|
||||
std::cerr << "====== stack info after reg alloc ======\n";
|
||||
// 尝试迭代图着色 (IRC)
|
||||
if (!irc_failed) {
|
||||
if (DEBUG) std::cerr << "Attempting Register Allocation with Iterated Register Coloring (IRC)...\n";
|
||||
RISCv64RegAlloc irc_alloc(mfunc.get());
|
||||
auto stop_flag = std::make_shared<std::atomic<bool>>(false);
|
||||
auto future = std::async(std::launch::async, &RISCv64RegAlloc::run, &irc_alloc, stop_flag);
|
||||
std::future_status status = future.wait_for(std::chrono::seconds(25));
|
||||
bool success_irc = false;
|
||||
if (status == std::future_status::ready) {
|
||||
try {
|
||||
if (future.get()) {
|
||||
success_irc = true;
|
||||
} else {
|
||||
std::cerr << "Warning: IRC explicitly returned failure for function '" << func->getName() << "'.\n";
|
||||
}
|
||||
RISCv64BasicBlockAlloc bb_alloc(mfunc.get());
|
||||
bb_alloc.run();
|
||||
} catch (const std::exception& e) {
|
||||
std::cerr << "Error: IRC allocation threw an exception: " << e.what() << std::endl;
|
||||
}
|
||||
} else if (status == std::future_status::timeout) {
|
||||
std::cerr << "Warning: IRC allocation timed out after 25 seconds. Requesting cancellation...\n";
|
||||
stop_flag->store(true);
|
||||
try {
|
||||
future.get();
|
||||
} catch (const std::exception& e) {
|
||||
std::cerr << "Exception occurred during IRC thread shutdown after timeout: " << e.what() << std::endl;
|
||||
}
|
||||
} else {
|
||||
// 图着色成功完成
|
||||
if (DEBUG) std::cerr << "Graph Coloring allocation completed successfully.\n";
|
||||
}
|
||||
} else {
|
||||
std::cerr << "Info: Graph Coloring allocation failed in last function. Switching to Linear Scan allocator...\n";
|
||||
RISCv64LinearScan ls_alloc(mfunc.get());
|
||||
bool success = ls_alloc.run();
|
||||
if (!success) {
|
||||
// 如果线性扫描最终失败,则调用基本块分配器作为终极后备
|
||||
std::cerr << "Info: Linear Scan failed. Switching to Basic Block Allocator as final fallback.\n";
|
||||
|
||||
// 注意:我们需要在一个“干净”的MachineFunction上运行。
|
||||
// 最安全的方式是重新运行指令选择。
|
||||
RISCv64ISel isel_fallback;
|
||||
mfunc = isel_fallback.runOnFunction(func);
|
||||
EliminateFrameIndicesPass efi_pass_fallback;
|
||||
efi_pass_fallback.runOnMachineFunction(mfunc.get());
|
||||
if (DEBUG) {
|
||||
std::cerr << "====== stack info after reg alloc ======\n";
|
||||
}
|
||||
RISCv64BasicBlockAlloc bb_alloc(mfunc.get());
|
||||
bb_alloc.run();
|
||||
|
||||
if (success_irc) {
|
||||
allocation_succeeded = true;
|
||||
if (DEBUG) std::cerr << "IRC allocation succeeded.\n";
|
||||
} else {
|
||||
std::cerr << "Info: Blacklisting IRC for subsequent functions and falling back.\n";
|
||||
irc_failed = true;
|
||||
}
|
||||
}
|
||||
|
||||
// 尝试简单图着色 (SGC)
|
||||
if (!allocation_succeeded) {
|
||||
// 如果是从IRC失败回退过来的,需要重新创建干净的mfunc和ISel
|
||||
RISCv64ISel isel_for_sgc;
|
||||
if (irc_failed) {
|
||||
if (DEBUG) std::cerr << "Info: Resetting MachineFunction for SGC attempt.\n";
|
||||
mfunc = isel_for_sgc.runOnFunction(func);
|
||||
EliminateFrameIndicesPass efi_pass_for_sgc;
|
||||
efi_pass_for_sgc.runOnMachineFunction(mfunc.get());
|
||||
}
|
||||
|
||||
if (DEBUG) std::cerr << "Attempting Register Allocation with Simple Graph Coloring (SGC)...\n";
|
||||
|
||||
bool sgc_completed_in_time = false;
|
||||
{
|
||||
RISCv64SimpleRegAlloc sgc_alloc(mfunc.get());
|
||||
auto future = std::async(std::launch::async, &RISCv64SimpleRegAlloc::run, &sgc_alloc);
|
||||
std::future_status status = future.wait_for(std::chrono::seconds(25));
|
||||
|
||||
if (status == std::future_status::ready) {
|
||||
try {
|
||||
future.get(); // 检查是否有异常
|
||||
sgc_completed_in_time = true;
|
||||
if (DEBUG) std::cerr << "SGC allocation completed successfully within the time limit.\n";
|
||||
} catch (const std::exception& e) {
|
||||
std::cerr << "Error: SGC allocation threw an exception: " << e.what() << std::endl;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (sgc_completed_in_time) {
|
||||
allocation_succeeded = true;
|
||||
} else {
|
||||
std::cerr << "Warning: SGC allocation timed out or failed for function '" << func->getName()
|
||||
<< "'. Falling back.\n";
|
||||
}
|
||||
}
|
||||
|
||||
// 如果都失败了,则使用基本块分配器 (BBA)
|
||||
if (!allocation_succeeded) {
|
||||
// 为BBA准备干净的mfunc和ISel
|
||||
std::cerr << "Info: Resetting MachineFunction for BBA fallback.\n";
|
||||
RISCv64ISel isel_for_bba;
|
||||
mfunc = isel_for_bba.runOnFunction(func);
|
||||
EliminateFrameIndicesPass efi_pass_for_bba;
|
||||
efi_pass_for_bba.runOnMachineFunction(mfunc.get());
|
||||
|
||||
std::cerr << "Info: Using Basic Block Allocator as final fallback.\n";
|
||||
RISCv64BasicBlockAlloc bb_alloc(mfunc.get());
|
||||
bb_alloc.run();
|
||||
}
|
||||
|
||||
if (DEBUG) {
|
||||
std::cerr << "====== stack info after reg alloc ======\n";
|
||||
mfunc->dumpStackFrameInfo(std::cerr);
|
||||
@ -317,9 +345,9 @@ std::string RISCv64CodeGen::function_gen(Function* func) {
|
||||
PeepholeOptimizer peephole;
|
||||
peephole.runOnMachineFunction(mfunc.get());
|
||||
|
||||
// 阶段 5: 局部指令调度 (Local Scheduling)
|
||||
PostRA_Scheduler local_scheduler;
|
||||
local_scheduler.runOnMachineFunction(mfunc.get());
|
||||
// // 阶段 5: 局部指令调度 (Local Scheduling)
|
||||
// PostRA_Scheduler local_scheduler;
|
||||
// local_scheduler.runOnMachineFunction(mfunc.get());
|
||||
|
||||
// 阶段 3.2: 插入序言和尾声
|
||||
PrologueEpilogueInsertionPass pei_pass;
|
||||
|
||||
@ -103,6 +103,60 @@ void RISCv64ISel::select() {
|
||||
}
|
||||
}
|
||||
|
||||
if (optLevel > 0) {
|
||||
if (F && !F->getBasicBlocks().empty()) {
|
||||
// 定位到第一个MachineBasicBlock,也就是函数入口
|
||||
BasicBlock* first_ir_block = F->getBasicBlocks_NoRange().front().get();
|
||||
CurMBB = bb_map.at(first_ir_block);
|
||||
|
||||
int int_arg_idx = 0;
|
||||
int fp_arg_idx = 0;
|
||||
|
||||
for (Argument* arg : F->getArguments()) {
|
||||
Type* arg_type = arg->getType();
|
||||
|
||||
// --- 处理整数/指针参数 ---
|
||||
if (!arg_type->isFloat() && int_arg_idx < 8) {
|
||||
// 1. 获取参数原始的、将被预着色为 a0-a7 的 vreg
|
||||
unsigned original_vreg = getVReg(arg);
|
||||
|
||||
// 2. 创建一个新的、安全的 vreg 来持有参数的值
|
||||
unsigned saved_vreg = getNewVReg(arg_type);
|
||||
|
||||
// 3. 生成 mv saved_vreg, original_vreg 指令
|
||||
auto mv = std::make_unique<MachineInstr>(RVOpcodes::MV);
|
||||
mv->addOperand(std::make_unique<RegOperand>(saved_vreg));
|
||||
mv->addOperand(std::make_unique<RegOperand>(original_vreg));
|
||||
CurMBB->addInstruction(std::move(mv));
|
||||
|
||||
// 4.【关键】更新vreg映射表,将arg的vreg指向新的、安全的vreg
|
||||
// 这样,后续所有对该参数的 getVReg(arg) 调用都会自动获得 saved_vreg,
|
||||
// 使得函数体内的代码都使用这个被保存过的值。
|
||||
vreg_map[arg] = saved_vreg;
|
||||
|
||||
int_arg_idx++;
|
||||
}
|
||||
// --- 处理浮点参数 ---
|
||||
else if (arg_type->isFloat() && fp_arg_idx < 8) {
|
||||
unsigned original_vreg = getVReg(arg);
|
||||
unsigned saved_vreg = getNewVReg(arg_type);
|
||||
|
||||
// 对于浮点数,使用 fmv.s 指令
|
||||
auto fmv = std::make_unique<MachineInstr>(RVOpcodes::FMV_S);
|
||||
fmv->addOperand(std::make_unique<RegOperand>(saved_vreg));
|
||||
fmv->addOperand(std::make_unique<RegOperand>(original_vreg));
|
||||
CurMBB->addInstruction(std::move(fmv));
|
||||
|
||||
// 同样更新映射
|
||||
vreg_map[arg] = saved_vreg;
|
||||
|
||||
fp_arg_idx++;
|
||||
}
|
||||
// 对于栈传递的参数,则无需处理
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 遍历基本块,进行指令选择
|
||||
for (const auto& bb_ptr : F->getBasicBlocks()) {
|
||||
selectBasicBlock(bb_ptr.get());
|
||||
@ -526,6 +580,22 @@ void RISCv64ISel::selectNode(DAGNode* node) {
|
||||
CurMBB->addInstruction(std::move(instr));
|
||||
break;
|
||||
}
|
||||
case BinaryInst::kSll: { // 逻辑左移
|
||||
auto instr = std::make_unique<MachineInstr>(RVOpcodes::SLLW);
|
||||
instr->addOperand(std::make_unique<RegOperand>(dest_vreg));
|
||||
instr->addOperand(std::make_unique<RegOperand>(lhs_vreg));
|
||||
instr->addOperand(std::make_unique<RegOperand>(rhs_vreg));
|
||||
CurMBB->addInstruction(std::move(instr));
|
||||
break;
|
||||
}
|
||||
case BinaryInst::kSrl: { // 逻辑右移
|
||||
auto instr = std::make_unique<MachineInstr>(RVOpcodes::SRLW);
|
||||
instr->addOperand(std::make_unique<RegOperand>(dest_vreg));
|
||||
instr->addOperand(std::make_unique<RegOperand>(lhs_vreg));
|
||||
instr->addOperand(std::make_unique<RegOperand>(rhs_vreg));
|
||||
CurMBB->addInstruction(std::move(instr));
|
||||
break;
|
||||
}
|
||||
case BinaryInst::kICmpEQ: { // 等于 (a == b) -> (subw; seqz)
|
||||
auto sub = std::make_unique<MachineInstr>(RVOpcodes::SUBW);
|
||||
sub->addOperand(std::make_unique<RegOperand>(dest_vreg));
|
||||
@ -582,7 +652,7 @@ void RISCv64ISel::selectNode(DAGNode* node) {
|
||||
CurMBB->addInstruction(std::move(xori));
|
||||
break;
|
||||
}
|
||||
case BinaryInst::kICmpGE: { // 大于等于 (a >= b) -> !(a < b) -> (slt; xori)
|
||||
case BinaryInst::kICmpGE: { // 大于等于 (a >= b) -> !(a < b) -> (slt; xori)
|
||||
auto slt = std::make_unique<MachineInstr>(RVOpcodes::SLT);
|
||||
slt->addOperand(std::make_unique<RegOperand>(dest_vreg));
|
||||
slt->addOperand(std::make_unique<RegOperand>(lhs_vreg));
|
||||
|
||||
@ -7,6 +7,7 @@
|
||||
#include <cassert>
|
||||
#include <chrono>
|
||||
#include <thread>
|
||||
#include <atomic>
|
||||
|
||||
namespace sysy {
|
||||
|
||||
@ -47,7 +48,7 @@ RISCv64RegAlloc::RISCv64RegAlloc(MachineFunction* mfunc)
|
||||
}
|
||||
|
||||
// 主入口: 迭代运行分配算法直到无溢出
|
||||
bool RISCv64RegAlloc::run() {
|
||||
bool RISCv64RegAlloc::run(std::shared_ptr<std::atomic<bool>> stop_flag) {
|
||||
if (DEBUG) std::cerr << "===== LLIR Before Running Graph Coloring Register Allocation " << MFunc->getName() << " =====\n";
|
||||
std::stringstream ss_before_reg_alloc;
|
||||
if (DEBUG) {
|
||||
@ -68,6 +69,11 @@ bool RISCv64RegAlloc::run() {
|
||||
break;
|
||||
} else {
|
||||
rewriteProgram();
|
||||
if (stop_flag && stop_flag->load()) {
|
||||
// 如果从外部接收到停止信号
|
||||
std::cerr << "Info: IRC allocation cancelled due to timeout.\n";
|
||||
return false; // 提前退出,并返回失败
|
||||
}
|
||||
if (DEBUG) std::cerr << "--- Spilling detected, re-running allocation (iteration " << iteration << ") ---\n";
|
||||
|
||||
if (iteration >= MAX_ITERATIONS) {
|
||||
@ -121,20 +127,46 @@ void RISCv64RegAlloc::precolorByCallingConvention() {
|
||||
int int_arg_idx = 0;
|
||||
int float_arg_idx = 0;
|
||||
|
||||
for (Argument* arg : F->getArguments()) {
|
||||
unsigned vreg = ISel->getVReg(arg);
|
||||
|
||||
if (arg->getType()->isFloat()) {
|
||||
if (float_arg_idx < 8) { // fa0-fa7
|
||||
auto preg = static_cast<PhysicalReg>(static_cast<int>(PhysicalReg::F10) + float_arg_idx);
|
||||
color_map[vreg] = preg;
|
||||
float_arg_idx++;
|
||||
if (optLevel > 0)
|
||||
{
|
||||
for (const auto& pair : vreg_to_value_map) {
|
||||
unsigned vreg = pair.first;
|
||||
Value* val = pair.second;
|
||||
|
||||
// 检查这个 Value* 是不是一个 Argument 对象
|
||||
if (auto arg = dynamic_cast<Argument*>(val)) {
|
||||
// 如果是,那么 vreg 就是最初分配给这个参数的 vreg
|
||||
int arg_idx = arg->getIndex();
|
||||
|
||||
if (arg->getType()->isFloat()) {
|
||||
if (arg_idx < 8) { // fa0-fa7
|
||||
auto preg = static_cast<PhysicalReg>(static_cast<int>(PhysicalReg::F10) + arg_idx);
|
||||
color_map[vreg] = preg;
|
||||
}
|
||||
} else { // 整数或指针
|
||||
if (arg_idx < 8) { // a0-a7
|
||||
auto preg = static_cast<PhysicalReg>(static_cast<int>(PhysicalReg::A0) + arg_idx);
|
||||
color_map[vreg] = preg;
|
||||
}
|
||||
}
|
||||
}
|
||||
} else { // 整数或指针
|
||||
if (int_arg_idx < 8) { // a0-a7
|
||||
auto preg = static_cast<PhysicalReg>(static_cast<int>(PhysicalReg::A0) + int_arg_idx);
|
||||
color_map[vreg] = preg;
|
||||
int_arg_idx++;
|
||||
}
|
||||
} else {
|
||||
for (Argument* arg : F->getArguments()) {
|
||||
unsigned vreg = ISel->getVReg(arg);
|
||||
|
||||
if (arg->getType()->isFloat()) {
|
||||
if (float_arg_idx < 8) { // fa0-fa7
|
||||
auto preg = static_cast<PhysicalReg>(static_cast<int>(PhysicalReg::F10) + float_arg_idx);
|
||||
color_map[vreg] = preg;
|
||||
float_arg_idx++;
|
||||
}
|
||||
} else { // 整数或指针
|
||||
if (int_arg_idx < 8) { // a0-a7
|
||||
auto preg = static_cast<PhysicalReg>(static_cast<int>(PhysicalReg::A0) + int_arg_idx);
|
||||
color_map[vreg] = preg;
|
||||
int_arg_idx++;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -471,16 +503,18 @@ void RISCv64RegAlloc::coalesce() {
|
||||
unsigned x = getAlias(*def.begin());
|
||||
unsigned y = getAlias(*use.begin());
|
||||
unsigned u, v;
|
||||
if (precolored.count(y)) { u = y; v = x; } else { u = x; v = y; }
|
||||
|
||||
// 进一步修正:标准化u和v的逻辑,必须同时考虑物理寄存器和已预着色的虚拟寄存器。
|
||||
// 目标是确保如果两个操作数中有一个是预着色的,它一定会被赋给 u。
|
||||
if (precolored.count(y) || coloredNodes.count(y)) {
|
||||
u = y; v = x;
|
||||
} else {
|
||||
u = x; v = y;
|
||||
}
|
||||
|
||||
// 防御性检查,处理物理寄存器之间的传送指令
|
||||
if (precolored.count(u) && precolored.count(v)) {
|
||||
// 如果 u 和 v 都是物理寄存器,我们不能合并它们。
|
||||
// 这通常是一条寄存器拷贝指令,例如 `mv a2, a1`。
|
||||
// 把它加入 constrainedMoves 列表,然后直接返回,不再处理。
|
||||
constrainedMoves.insert(move);
|
||||
// addWorklist(u) 和 addWorklist(v) 在这里也不需要调用,
|
||||
// 因为它们只对虚拟寄存器有意义。
|
||||
return;
|
||||
}
|
||||
|
||||
@ -492,7 +526,7 @@ void RISCv64RegAlloc::coalesce() {
|
||||
if (DEEPERDEBUG) std::cerr << " -> Trivial coalesce (u == v).\n";
|
||||
coalescedMoves.insert(move);
|
||||
addWorklist(u);
|
||||
return; // 处理完毕,提前返回
|
||||
return;
|
||||
}
|
||||
|
||||
if (isFPVReg(u) != isFPVReg(v)) {
|
||||
@ -502,10 +536,13 @@ void RISCv64RegAlloc::coalesce() {
|
||||
constrainedMoves.insert(move);
|
||||
addWorklist(u);
|
||||
addWorklist(v);
|
||||
return; // 立即返回,不再进行后续检查
|
||||
return;
|
||||
}
|
||||
|
||||
bool pre_interfere = adjList.at(v).count(u);
|
||||
// 注意:如果v已经是u的邻居, pre_interfere 会为true。
|
||||
// 但如果v不在adjList中(例如v是预着色节点),我们需要检查u是否在v的邻居中。
|
||||
// 为了简化,我们假设adjList包含了所有虚拟寄存器。对于(Phys, Virt)对,冲突信息存储在Virt节点的邻接表中。
|
||||
bool pre_interfere = (adjList.count(v) && adjList.at(v).count(u)) || (adjList.count(u) && adjList.at(u).count(v));
|
||||
|
||||
if (pre_interfere) {
|
||||
if (DEEPERDEBUG) std::cerr << " -> Constrained (nodes already interfere).\n";
|
||||
@ -515,63 +552,50 @@ void RISCv64RegAlloc::coalesce() {
|
||||
return;
|
||||
}
|
||||
|
||||
bool is_u_precolored = precolored.count(u);
|
||||
// 考虑物理寄存器和已预着色的虚拟寄存器
|
||||
bool u_is_effectively_precolored = precolored.count(u) || coloredNodes.count(u);
|
||||
bool can_coalesce = false;
|
||||
|
||||
if (is_u_precolored) {
|
||||
// --- 场景1:u是物理寄存器,使用 George 启发式 ---
|
||||
if (DEEPERDEBUG) std::cerr << " -> Trying George Heuristic (u is precolored)...\n";
|
||||
if (u_is_effectively_precolored) {
|
||||
// --- 场景1:u是物理寄存器或已预着色虚拟寄存器,使用 George 启发式 ---
|
||||
if (DEEPERDEBUG) std::cerr << " -> Trying George Heuristic (u is effectively precolored)...\n";
|
||||
|
||||
// 步骤 1: 独立调用 adjacent(v) 获取邻居集合
|
||||
VRegSet neighbors_of_v = adjacent(v);
|
||||
if (DEEPERDEBUG) {
|
||||
std::cerr << " - Neighbors of " << regIdToString(v) << " to check are (" << neighbors_of_v.size() << "): { ";
|
||||
for (unsigned id : neighbors_of_v) std::cerr << regIdToString(id) << " ";
|
||||
std::cerr << "}\n";
|
||||
}
|
||||
|
||||
// 步骤 2: 使用显式的 for 循环来代替 std::all_of
|
||||
bool george_ok = true; // 默认假设成功,任何一个邻居失败都会将此设为 false
|
||||
|
||||
bool george_ok = true;
|
||||
for (unsigned t : neighbors_of_v) {
|
||||
if (DEEPERDEBUG) {
|
||||
std::cerr << " - Checking neighbor " << regIdToString(t) << ":\n";
|
||||
}
|
||||
if (DEEPERDEBUG) std::cerr << " - Checking neighbor " << regIdToString(t) << ":\n";
|
||||
|
||||
// 步骤 3: 独立调用启发式函数
|
||||
bool heuristic_result = georgeHeuristic(t, u);
|
||||
unsigned u_phys_id = precolored.count(u) ? u : (static_cast<unsigned>(PhysicalReg::PHYS_REG_START_ID) + static_cast<unsigned>(color_map.at(u)));
|
||||
bool heuristic_result = georgeHeuristic(t, u_phys_id);
|
||||
|
||||
if (DEEPERDEBUG) {
|
||||
std::cerr << " - georgeHeuristic(" << regIdToString(t) << ", " << regIdToString(u) << ") -> " << (heuristic_result ? "OK" : "FAIL") << "\n";
|
||||
std::cerr << " - georgeHeuristic(" << regIdToString(t) << ", " << regIdToString(u_phys_id) << ") -> " << (heuristic_result ? "OK" : "FAIL") << "\n";
|
||||
}
|
||||
|
||||
if (!heuristic_result) {
|
||||
george_ok = false; // 只要有一个邻居不满足条件,整个检查就失败
|
||||
break; // 并且可以立即停止检查其他邻居
|
||||
george_ok = false;
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
if (DEEPERDEBUG) {
|
||||
std::cerr << " -> George Heuristic final result: " << (george_ok ? "OK" : "FAIL") << "\n";
|
||||
}
|
||||
|
||||
if (george_ok) {
|
||||
can_coalesce = true;
|
||||
}
|
||||
if (DEEPERDEBUG) std::cerr << " -> George Heuristic final result: " << (george_ok ? "OK" : "FAIL") << "\n";
|
||||
if (george_ok) can_coalesce = true;
|
||||
|
||||
} else {
|
||||
// --- 场景2:u和v都是虚拟寄存器,使用 Briggs 启发式 ---
|
||||
// --- 场景2:u和v都是未着色的虚拟寄存器,使用 Briggs 启发式 ---
|
||||
if (DEEPERDEBUG) std::cerr << " -> Trying Briggs Heuristic (u and v are virtual)...\n";
|
||||
|
||||
bool briggs_ok = briggsHeuristic(u, v);
|
||||
if (DEEPERDEBUG) std::cerr << " - briggsHeuristic(" << regIdToString(u) << ", " << regIdToString(v) << ") -> " << (briggs_ok ? "OK" : "FAIL") << "\n";
|
||||
|
||||
if (briggs_ok) {
|
||||
can_coalesce = true;
|
||||
}
|
||||
if (briggs_ok) can_coalesce = true;
|
||||
}
|
||||
|
||||
// --- 根据启发式结果进行最终决策 ---
|
||||
|
||||
if (can_coalesce) {
|
||||
if (DEEPERDEBUG) std::cerr << " -> Heuristic OK. Combining " << regIdToString(v) << " into " << regIdToString(u) << ".\n";
|
||||
coalescedMoves.insert(move);
|
||||
@ -1127,7 +1151,7 @@ unsigned RISCv64RegAlloc::getAlias(unsigned n) {
|
||||
}
|
||||
|
||||
void RISCv64RegAlloc::addWorklist(unsigned u) {
|
||||
if (precolored.count(u)) return;
|
||||
if (precolored.count(u) || color_map.count(u)) return;
|
||||
|
||||
int K = isFPVReg(u) ? K_fp : K_int;
|
||||
if (!moveRelated(u) && degree.at(u) < K) {
|
||||
@ -1202,8 +1226,12 @@ bool RISCv64RegAlloc::georgeHeuristic(unsigned t, unsigned u) {
|
||||
}
|
||||
|
||||
int K = isFPVReg(t) ? K_fp : K_int;
|
||||
// adjList.at(t) 现在是安全的,因为 degree.count(t) > 0 保证了 adjList.count(t) > 0
|
||||
return degree.at(t) < K || precolored.count(u) || adjList.at(t).count(u);
|
||||
|
||||
// 缺陷 #2 修正: 移除了致命的 || precolored.count(u) 条件。
|
||||
// 在此函数的上下文中,u 总是预着色的物理寄存器ID,导致旧的条件永远为true,使整个启发式失效。
|
||||
// 正确的逻辑是检查:邻居t的度数是否小于K,或者t是否已经与u冲突。
|
||||
// return degree.at(t) < K || adjList.at(t).count(u);
|
||||
return degree.at(t) < K || !adjList.at(t).count(u);
|
||||
}
|
||||
|
||||
void RISCv64RegAlloc::combine(unsigned u, unsigned v) {
|
||||
@ -1251,7 +1279,7 @@ void RISCv64RegAlloc::freezeMoves(unsigned u) {
|
||||
activeMoves.erase(move);
|
||||
frozenMoves.insert(move);
|
||||
|
||||
if (!precolored.count(v_alias) && nodeMoves(v_alias).empty() && degree.at(v_alias) < (isFPVReg(v_alias) ? K_fp : K_int)) {
|
||||
if (!precolored.count(v_alias) && !coloredNodes.count(v_alias) && nodeMoves(v_alias).empty() && degree.at(v_alias) < (isFPVReg(v_alias) ? K_fp : K_int)) {
|
||||
freezeWorklist.erase(v_alias);
|
||||
simplifyWorklist.insert(v_alias);
|
||||
if (DEEPERDEBUG) {
|
||||
|
||||
716
src/backend/RISCv64/RISCv64SimpleRegAlloc.cpp
Normal file
716
src/backend/RISCv64/RISCv64SimpleRegAlloc.cpp
Normal file
@ -0,0 +1,716 @@
|
||||
#include "RISCv64SimpleRegAlloc.h"
|
||||
#include "RISCv64AsmPrinter.h"
|
||||
#include "RISCv64Info.h"
|
||||
#include <algorithm>
|
||||
#include <iostream>
|
||||
#include <cassert>
|
||||
|
||||
// 外部调试级别控制变量的定义
|
||||
// 假设这些变量在其他地方定义,例如主程序或一个通用的cpp文件
|
||||
extern int DEBUG;
|
||||
extern int DEEPDEBUG;
|
||||
|
||||
namespace sysy {
|
||||
|
||||
RISCv64SimpleRegAlloc::RISCv64SimpleRegAlloc(MachineFunction* mfunc) : MFunc(mfunc), ISel(mfunc->getISel()) {
|
||||
// 1. 初始化可分配的整数寄存器池
|
||||
// T5 被大立即数传送逻辑保留
|
||||
// T2, T3, T4 被本分配器保留为专用的溢出/临时寄存器
|
||||
allocable_int_regs = {
|
||||
PhysicalReg::T0, PhysicalReg::T1, /* T2,T3,T4,T5,T6 reserved */
|
||||
PhysicalReg::A0, PhysicalReg::A1, PhysicalReg::A2, PhysicalReg::A3, PhysicalReg::A4, PhysicalReg::A5, PhysicalReg::A6, PhysicalReg::A7,
|
||||
PhysicalReg::S1, PhysicalReg::S2, PhysicalReg::S3, PhysicalReg::S4, PhysicalReg::S5, PhysicalReg::S6, PhysicalReg::S7,
|
||||
PhysicalReg::S8, PhysicalReg::S9, PhysicalReg::S10, PhysicalReg::S11,
|
||||
};
|
||||
|
||||
// 2. 初始化可分配的浮点寄存器池
|
||||
// F0, F1, F2 被本分配器保留为专用的溢出/临时寄存器
|
||||
allocable_fp_regs = {
|
||||
/* F0,F1,F2 reserved */ PhysicalReg::F3, PhysicalReg::F4, PhysicalReg::F5, PhysicalReg::F6, PhysicalReg::F7,
|
||||
PhysicalReg::F10, PhysicalReg::F11, PhysicalReg::F12, PhysicalReg::F13, PhysicalReg::F14, PhysicalReg::F15, PhysicalReg::F16, PhysicalReg::F17,
|
||||
PhysicalReg::F8, PhysicalReg::F9, PhysicalReg::F18, PhysicalReg::F19, PhysicalReg::F20, PhysicalReg::F21, PhysicalReg::F22,
|
||||
PhysicalReg::F23, PhysicalReg::F24, PhysicalReg::F25, PhysicalReg::F26, PhysicalReg::F27,
|
||||
PhysicalReg::F28, PhysicalReg::F29, PhysicalReg::F30, PhysicalReg::F31,
|
||||
};
|
||||
|
||||
// 3. 映射所有物理寄存器到特殊的虚拟寄存器ID (保持不变)
|
||||
const unsigned offset = static_cast<unsigned>(PhysicalReg::PHYS_REG_START_ID);
|
||||
for (unsigned i = 0; i < static_cast<unsigned>(PhysicalReg::INVALID); ++i) {
|
||||
auto preg = static_cast<PhysicalReg>(i);
|
||||
preg_to_vreg_id_map[preg] = offset + i;
|
||||
}
|
||||
}
|
||||
|
||||
// 寄存器分配的主入口点
|
||||
void RISCv64SimpleRegAlloc::run() {
|
||||
if (DEBUG) std::cerr << "===== Running Simple Graph Coloring Allocator for function: " << MFunc->getName() << " =====\n";
|
||||
|
||||
// 实例化一个AsmPrinter用于调试输出,避免重复创建
|
||||
RISCv64AsmPrinter printer(MFunc);
|
||||
printer.setStream(std::cerr);
|
||||
|
||||
if (DEBUG) {
|
||||
std::cerr << "\n===== LLIR after VReg Unification =====\n";
|
||||
printer.run(std::cerr, true);
|
||||
std::cerr << "===== End of Unified LLIR =====\n\n";
|
||||
}
|
||||
|
||||
// 阶段 1: 处理函数调用约定(参数寄存器预着色)
|
||||
handleCallingConvention();
|
||||
if (DEBUG) {
|
||||
std::cerr << "--- After HandleCallingConvention ---\n";
|
||||
std::cerr << "Pre-colored vregs:\n";
|
||||
for (const auto& pair : color_map) {
|
||||
std::cerr << " %vreg" << pair.first << " -> " << printer.regToString(pair.second) << "\n";
|
||||
}
|
||||
}
|
||||
|
||||
// 阶段 2: 活跃性分析
|
||||
analyzeLiveness();
|
||||
|
||||
// 阶段 3: 构建干扰图
|
||||
buildInterferenceGraph();
|
||||
|
||||
// 阶段 4: 图着色算法分配物理寄存器
|
||||
colorGraph();
|
||||
if (DEBUG) {
|
||||
std::cerr << "\n--- After GraphColoring ---\n";
|
||||
std::cerr << "Assigned colors:\n";
|
||||
for (const auto& pair : color_map) {
|
||||
std::cerr << " %vreg" << pair.first << " -> " << printer.regToString(pair.second) << "\n";
|
||||
}
|
||||
std::cerr << "Spilled vregs:\n";
|
||||
if (spilled_vregs.empty()) {
|
||||
std::cerr << " (None)\n";
|
||||
} else {
|
||||
for (unsigned vreg : spilled_vregs) {
|
||||
std::cerr << " %vreg" << vreg << "\n";
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 阶段 5: 重写函数(插入溢出/填充代码,替换虚拟寄存器为物理寄存器)
|
||||
rewriteFunction();
|
||||
|
||||
// 将最终的寄存器分配结果保存到MachineFunction的帧信息中,供后续Pass使用
|
||||
MFunc->getFrameInfo().vreg_to_preg_map = this->color_map;
|
||||
|
||||
if (DEBUG) {
|
||||
std::cerr << "\n===== Final LLIR after Simple Register Allocation =====\n";
|
||||
printer.run(std::cerr, false); // 使用false来打印最终的物理寄存器
|
||||
std::cerr << "===== Finished Simple Graph Coloring Allocator =====\n\n";
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief [新增] 虚拟寄存器统一预处理
|
||||
* 扫描函数,找到通过栈帧传递的参数,并将后续从该栈帧加载的VReg统一为原始的参数VReg。
|
||||
*/
|
||||
void RISCv64SimpleRegAlloc::unifyArgumentVRegs() {
|
||||
if (MFunc->getBlocks().size() < 2) return; // 至少需要入口和函数体两个块
|
||||
|
||||
std::map<int, unsigned> stack_slot_to_vreg; // 映射: <栈偏移, 原始参数vreg>
|
||||
MachineBasicBlock* entry_block = MFunc->getBlocks().front().get();
|
||||
|
||||
// 步骤 1: 扫描入口块,找到所有参数的“家(home)”在栈上的位置
|
||||
for (const auto& instr : entry_block->getInstructions()) {
|
||||
// 我们寻找 sw %vreg_arg, 0(%vreg_addr) 的模式
|
||||
if (instr->getOpcode() == RVOpcodes::SW || instr->getOpcode() == RVOpcodes::SD || instr->getOpcode() == RVOpcodes::FSW) {
|
||||
auto& operands = instr->getOperands();
|
||||
if (operands.size() == 2 && operands[0]->getKind() == MachineOperand::KIND_REG && operands[1]->getKind() == MachineOperand::KIND_MEM) {
|
||||
auto src_reg_op = static_cast<RegOperand*>(operands[0].get());
|
||||
auto mem_op = static_cast<MemOperand*>(operands[1].get());
|
||||
unsigned addr_vreg = mem_op->getBase()->getVRegNum();
|
||||
|
||||
// 查找定义这个地址vreg的addi指令,以获取偏移量
|
||||
for (const auto& prev_instr : entry_block->getInstructions()) {
|
||||
if (prev_instr->getOpcode() == RVOpcodes::ADDI && prev_instr->getOperands().front()->getKind() == MachineOperand::KIND_REG) {
|
||||
auto def_op = static_cast<RegOperand*>(prev_instr->getOperands().front().get());
|
||||
if (def_op->isVirtual() && def_op->getVRegNum() == addr_vreg) {
|
||||
int offset = static_cast<ImmOperand*>(prev_instr->getOperands()[2].get())->getValue();
|
||||
stack_slot_to_vreg[offset] = src_reg_op->getVRegNum();
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (stack_slot_to_vreg.empty()) return; // 没有找到参数存储,无需处理
|
||||
|
||||
// 步骤 2: 扫描函数体,构建本地vreg到参数vreg的重映射表
|
||||
std::map<unsigned, unsigned> vreg_remap; // 映射: <本地vreg, 原始参数vreg>
|
||||
MachineBasicBlock* body_block = MFunc->getBlocks()[1].get();
|
||||
|
||||
for (const auto& instr : body_block->getInstructions()) {
|
||||
if (instr->getOpcode() == RVOpcodes::LW || instr->getOpcode() == RVOpcodes::LD || instr->getOpcode() == RVOpcodes::FLW) {
|
||||
auto& operands = instr->getOperands();
|
||||
if (operands.size() == 2 && operands[0]->getKind() == MachineOperand::KIND_REG && operands[1]->getKind() == MachineOperand::KIND_MEM) {
|
||||
auto dest_reg_op = static_cast<RegOperand*>(operands[0].get());
|
||||
auto mem_op = static_cast<MemOperand*>(operands[1].get());
|
||||
unsigned addr_vreg = mem_op->getBase()->getVRegNum();
|
||||
|
||||
// 同样地,查找定义地址的addi指令
|
||||
for (const auto& prev_instr : body_block->getInstructions()) {
|
||||
if (prev_instr->getOpcode() == RVOpcodes::ADDI && prev_instr->getOperands().front()->getKind() == MachineOperand::KIND_REG) {
|
||||
auto def_op = static_cast<RegOperand*>(prev_instr->getOperands().front().get());
|
||||
if (def_op->isVirtual() && def_op->getVRegNum() == addr_vreg) {
|
||||
int offset = static_cast<ImmOperand*>(prev_instr->getOperands()[2].get())->getValue();
|
||||
if (stack_slot_to_vreg.count(offset)) {
|
||||
unsigned old_vreg = dest_reg_op->getVRegNum();
|
||||
unsigned new_vreg = stack_slot_to_vreg.at(offset);
|
||||
vreg_remap[old_vreg] = new_vreg;
|
||||
}
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (vreg_remap.empty()) return;
|
||||
|
||||
// 步骤 3: 遍历所有指令,应用重映射
|
||||
// 定义一个lambda函数来替换vreg,避免代码重复
|
||||
auto replace_vreg_in_operand = [&](MachineOperand* op) {
|
||||
if (op->getKind() == MachineOperand::KIND_REG) {
|
||||
auto reg_op = static_cast<RegOperand*>(op);
|
||||
if (reg_op->isVirtual() && vreg_remap.count(reg_op->getVRegNum())) {
|
||||
reg_op->setVRegNum(vreg_remap.at(reg_op->getVRegNum()));
|
||||
}
|
||||
} else if (op->getKind() == MachineOperand::KIND_MEM) {
|
||||
auto base_reg_op = static_cast<MemOperand*>(op)->getBase();
|
||||
if (base_reg_op->isVirtual() && vreg_remap.count(base_reg_op->getVRegNum())) {
|
||||
base_reg_op->setVRegNum(vreg_remap.at(base_reg_op->getVRegNum()));
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
for (auto& mbb : MFunc->getBlocks()) {
|
||||
for (auto& instr : mbb->getInstructions()) {
|
||||
for (auto& op : instr->getOperands()) {
|
||||
replace_vreg_in_operand(op.get());
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void RISCv64SimpleRegAlloc::handleCallingConvention() {
|
||||
Function* F = MFunc->getFunc();
|
||||
if (!F) return;
|
||||
|
||||
// --- 1. 处理函数传入参数的预着色 ---
|
||||
int int_arg_idx = 0;
|
||||
int float_arg_idx = 0;
|
||||
|
||||
for (Argument* arg : F->getArguments()) {
|
||||
unsigned vreg = ISel->getVReg(arg);
|
||||
if (arg->getType()->isFloat()) {
|
||||
if (float_arg_idx < 8) { // fa0-fa7
|
||||
auto preg = static_cast<PhysicalReg>(static_cast<int>(PhysicalReg::F10) + float_arg_idx);
|
||||
color_map[vreg] = preg;
|
||||
}
|
||||
float_arg_idx++;
|
||||
} else {
|
||||
if (int_arg_idx < 8) { // a0-a7
|
||||
auto preg = static_cast<PhysicalReg>(static_cast<int>(PhysicalReg::A0) + int_arg_idx);
|
||||
color_map[vreg] = preg;
|
||||
}
|
||||
int_arg_idx++;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void RISCv64SimpleRegAlloc::analyzeLiveness() {
|
||||
if (DEBUG) std::cerr << "\n--- Starting Liveness Analysis ---\n";
|
||||
|
||||
// === 阶段 1: 预计算每个基本块的 use 和 def 集合 ===
|
||||
std::map<const MachineBasicBlock*, LiveSet> block_uses;
|
||||
std::map<const MachineBasicBlock*, LiveSet> block_defs;
|
||||
for (const auto& mbb_ptr : MFunc->getBlocks()) {
|
||||
const MachineBasicBlock* mbb = mbb_ptr.get();
|
||||
LiveSet uses, defs;
|
||||
for (const auto& instr_ptr : mbb->getInstructions()) {
|
||||
LiveSet instr_use, instr_def;
|
||||
getInstrUseDef_Liveness(instr_ptr.get(), instr_use, instr_def);
|
||||
// use[B] = use[B] U (instr_use - def[B])
|
||||
for (unsigned u : instr_use) {
|
||||
if (defs.find(u) == defs.end()) {
|
||||
uses.insert(u);
|
||||
}
|
||||
}
|
||||
// def[B] = def[B] U instr_def
|
||||
defs.insert(instr_def.begin(), instr_def.end());
|
||||
}
|
||||
block_uses[mbb] = uses;
|
||||
block_defs[mbb] = defs;
|
||||
}
|
||||
|
||||
// === 阶段 2: 在“块”粒度上进行迭代数据流分析,直到收敛 ===
|
||||
std::map<const MachineBasicBlock*, LiveSet> block_live_in;
|
||||
std::map<const MachineBasicBlock*, LiveSet> block_live_out;
|
||||
bool changed = true;
|
||||
while (changed) {
|
||||
changed = false;
|
||||
// 逆序遍历基本块,加速收敛
|
||||
for (auto it = MFunc->getBlocks().rbegin(); it != MFunc->getBlocks().rend(); ++it) {
|
||||
const auto& mbb_ptr = *it;
|
||||
const MachineBasicBlock* mbb = mbb_ptr.get();
|
||||
|
||||
// 2.1 计算 live_out[B] = U_{S in succ(B)} live_in[S]
|
||||
LiveSet new_live_out;
|
||||
for (auto succ : mbb->successors) {
|
||||
new_live_out.insert(block_live_in[succ].begin(), block_live_in[succ].end());
|
||||
}
|
||||
|
||||
// 2.2 计算 live_in[B] = use[B] U (live_out[B] - def[B])
|
||||
LiveSet live_out_minus_def = new_live_out;
|
||||
for (unsigned d : block_defs.at(mbb)) {
|
||||
live_out_minus_def.erase(d);
|
||||
}
|
||||
LiveSet new_live_in = block_uses.at(mbb);
|
||||
new_live_in.insert(live_out_minus_def.begin(), live_out_minus_def.end());
|
||||
|
||||
// 2.3 检查是否达到不动点
|
||||
if (block_live_out[mbb] != new_live_out || block_live_in[mbb] != new_live_in) {
|
||||
changed = true;
|
||||
block_live_out[mbb] = new_live_out;
|
||||
block_live_in[mbb] = new_live_in;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// === 阶段 3: 进行一次指令粒度的遍历,填充最终的 live_in_map 和 live_out_map ===
|
||||
for (const auto& mbb_ptr : MFunc->getBlocks()) {
|
||||
const MachineBasicBlock* mbb = mbb_ptr.get();
|
||||
LiveSet live_out = block_live_out.at(mbb);
|
||||
|
||||
for (auto instr_it = mbb->getInstructions().rbegin(); instr_it != mbb->getInstructions().rend(); ++instr_it) {
|
||||
const MachineInstr* instr = instr_it->get();
|
||||
live_out_map[instr] = live_out;
|
||||
|
||||
LiveSet use, def;
|
||||
getInstrUseDef_Liveness(instr, use, def);
|
||||
|
||||
LiveSet live_in = use;
|
||||
LiveSet diff = live_out;
|
||||
for (auto vreg : def) {
|
||||
diff.erase(vreg);
|
||||
}
|
||||
live_in.insert(diff.begin(), diff.end());
|
||||
live_in_map[instr] = live_in;
|
||||
|
||||
// 更新 live_out,为块内的上一条指令做准备
|
||||
live_out = live_in;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void RISCv64SimpleRegAlloc::buildInterferenceGraph() {
|
||||
if (DEBUG) std::cerr << "\n--- Starting Interference Graph Construction ---\n";
|
||||
RISCv64AsmPrinter printer(MFunc);
|
||||
printer.setStream(std::cerr);
|
||||
|
||||
// 1. 收集所有图中需要出现的节点 (所有虚拟寄存器和物理寄存器)
|
||||
std::set<unsigned> all_nodes;
|
||||
for (const auto& mbb : MFunc->getBlocks()) {
|
||||
for(const auto& instr : mbb->getInstructions()) {
|
||||
LiveSet use, def;
|
||||
getInstrUseDef_Liveness(instr.get(), use, def);
|
||||
all_nodes.insert(use.begin(), use.end());
|
||||
all_nodes.insert(def.begin(), def.end());
|
||||
}
|
||||
}
|
||||
// 确保所有物理寄存器节点也存在
|
||||
for (const auto& pair : preg_to_vreg_id_map) {
|
||||
all_nodes.insert(pair.second);
|
||||
}
|
||||
|
||||
// 2. 初始化干扰图邻接表
|
||||
for (unsigned vreg : all_nodes) { interference_graph[vreg] = {}; }
|
||||
|
||||
// 3. 遍历指令,添加冲突边
|
||||
for (const auto& mbb : MFunc->getBlocks()) {
|
||||
if (DEEPDEBUG) std::cerr << "--- Building Graph for Basic Block: " << mbb->getName() << " ---\n";
|
||||
for (const auto& instr_ptr : mbb->getInstructions()) {
|
||||
const MachineInstr* instr = instr_ptr.get();
|
||||
if (DEEPDEBUG) {
|
||||
std::cerr << " Instr: ";
|
||||
printer.printInstruction(const_cast<MachineInstr*>(instr), true);
|
||||
}
|
||||
|
||||
LiveSet def, use;
|
||||
getInstrUseDef_Liveness(instr, def, use); // 注意Use/Def顺序
|
||||
const LiveSet& live_out = live_out_map.at(instr);
|
||||
|
||||
if (DEEPDEBUG) {
|
||||
printLiveSet(use, "Use ", std::cerr, printer);
|
||||
printLiveSet(def, "Def ", std::cerr, printer);
|
||||
printLiveSet(live_out, "Live_Out", std::cerr, printer);
|
||||
}
|
||||
|
||||
// 规则1: 指令的定义(def)与该指令之后的所有活跃变量(live_out)冲突
|
||||
for (unsigned d : def) {
|
||||
for (unsigned l : live_out) {
|
||||
if (d != l) {
|
||||
if (DEEPDEBUG && interference_graph.at(d).find(l) == interference_graph.at(d).end()) {
|
||||
std::cerr << " Edge (Def-LiveOut): " << regIdToString(d, printer) << " <-> " << regIdToString(l, printer) << "\n";
|
||||
}
|
||||
interference_graph[d].insert(l);
|
||||
interference_graph[l].insert(d);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 规则2: 对于非MV指令, def与use也冲突
|
||||
if (instr->getOpcode() != RVOpcodes::MV) {
|
||||
for (unsigned d : def) {
|
||||
for (unsigned u : use) {
|
||||
if (d != u) {
|
||||
if (DEEPDEBUG && interference_graph.at(d).find(u) == interference_graph.at(d).end()) {
|
||||
std::cerr << " Edge (Def-Use): " << regIdToString(d, printer) << " <-> " << regIdToString(u, printer) << "\n";
|
||||
}
|
||||
interference_graph[d].insert(u);
|
||||
interference_graph[u].insert(d);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 所有在某一点上同时活跃的寄存器(即live_out集合中的所有成员),
|
||||
// 它们之间必须两两互相干扰。
|
||||
std::vector<unsigned> live_out_vec(live_out.begin(), live_out.end());
|
||||
for (size_t i = 0; i < live_out_vec.size(); ++i) {
|
||||
for (size_t j = i + 1; j < live_out_vec.size(); ++j) {
|
||||
unsigned u = live_out_vec[i];
|
||||
unsigned v = live_out_vec[j];
|
||||
if (DEEPDEBUG && interference_graph[u].find(v) == interference_graph[u].end()) {
|
||||
std::cerr << " Edge (Live-Live): %vreg" << u << " <-> %vreg" << v << "\n";
|
||||
}
|
||||
interference_graph[u].insert(v);
|
||||
interference_graph[v].insert(u);
|
||||
}
|
||||
}
|
||||
|
||||
// 规则3: CALL指令会破坏所有调用者保存(caller-saved)寄存器
|
||||
if (instr->getOpcode() == RVOpcodes::CALL) {
|
||||
const auto& caller_saved_int = getCallerSavedIntRegs();
|
||||
const auto& caller_saved_fp = getCallerSavedFpRegs();
|
||||
|
||||
for (unsigned live_vreg : live_out) {
|
||||
auto [type, size] = getTypeAndSize(live_vreg);
|
||||
if (type == Type::kFloat) {
|
||||
for (PhysicalReg cs_reg : caller_saved_fp) {
|
||||
unsigned cs_vreg_id = preg_to_vreg_id_map.at(cs_reg);
|
||||
if (live_vreg != cs_vreg_id) {
|
||||
interference_graph[live_vreg].insert(cs_vreg_id);
|
||||
interference_graph[cs_vreg_id].insert(live_vreg);
|
||||
}
|
||||
}
|
||||
} else {
|
||||
for (PhysicalReg cs_reg : caller_saved_int) {
|
||||
unsigned cs_vreg_id = preg_to_vreg_id_map.at(cs_reg);
|
||||
if (live_vreg != cs_vreg_id) {
|
||||
interference_graph[live_vreg].insert(cs_vreg_id);
|
||||
interference_graph[cs_vreg_id].insert(live_vreg);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
} // end if CALL
|
||||
if (DEEPDEBUG) std::cerr << " ----------------\n";
|
||||
} // end for instr
|
||||
} // end for mbb
|
||||
}
|
||||
|
||||
void RISCv64SimpleRegAlloc::colorGraph() {
|
||||
// 1. 收集所有需要着色的虚拟寄存器
|
||||
std::vector<unsigned> vregs_to_color;
|
||||
for (auto const& [vreg, neighbors] : interference_graph) {
|
||||
// 只为未预着色的、真正的虚拟寄存器进行着色
|
||||
if (color_map.find(vreg) == color_map.end() && vreg < static_cast<unsigned>(PhysicalReg::PHYS_REG_START_ID)) {
|
||||
vregs_to_color.push_back(vreg);
|
||||
}
|
||||
}
|
||||
|
||||
// 2. 按冲突度从高到低排序,进行贪心着色
|
||||
std::sort(vregs_to_color.begin(), vregs_to_color.end(), [&](unsigned a, unsigned b) {
|
||||
return interference_graph.at(a).size() > interference_graph.at(b).size();
|
||||
});
|
||||
|
||||
// 3. 遍历并着色
|
||||
for (unsigned vreg : vregs_to_color) {
|
||||
std::set<PhysicalReg> used_colors;
|
||||
// 收集所有邻居的颜色
|
||||
for (unsigned neighbor_id : interference_graph.at(vreg)) {
|
||||
// A. 邻居是已着色的vreg
|
||||
if (color_map.count(neighbor_id)) {
|
||||
used_colors.insert(color_map.at(neighbor_id));
|
||||
}
|
||||
// B. 邻居是物理寄存器本身
|
||||
else if (neighbor_id >= static_cast<unsigned>(PhysicalReg::PHYS_REG_START_ID)) {
|
||||
PhysicalReg neighbor_preg = static_cast<PhysicalReg>(neighbor_id - static_cast<unsigned>(PhysicalReg::PHYS_REG_START_ID));
|
||||
used_colors.insert(neighbor_preg);
|
||||
}
|
||||
}
|
||||
|
||||
// 根据vreg类型选择寄存器池
|
||||
auto [type, size] = getTypeAndSize(vreg);
|
||||
const auto& allocable_regs = (type == Type::kFloat) ? allocable_fp_regs : allocable_int_regs;
|
||||
|
||||
bool colored = false;
|
||||
for (PhysicalReg preg : allocable_regs) {
|
||||
if (used_colors.find(preg) == used_colors.end()) {
|
||||
color_map[vreg] = preg;
|
||||
colored = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
if (!colored) {
|
||||
spilled_vregs.insert(vreg);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void RISCv64SimpleRegAlloc::rewriteFunction() {
|
||||
if (DEBUG) std::cerr << "\n--- Starting Function Rewrite (Spilling & Substitution) ---\n";
|
||||
StackFrameInfo& frame_info = MFunc->getFrameInfo();
|
||||
|
||||
// 步骤 1: 为所有溢出的vreg计算唯一的栈偏移量 (此部分逻辑正确,予以保留)
|
||||
int current_offset = frame_info.locals_end_offset;
|
||||
for (unsigned vreg : spilled_vregs) {
|
||||
if (frame_info.spill_offsets.count(vreg)) continue;
|
||||
auto [type, size] = getTypeAndSize(vreg);
|
||||
current_offset -= size;
|
||||
current_offset = current_offset & ~7;
|
||||
frame_info.spill_offsets[vreg] = current_offset;
|
||||
}
|
||||
frame_info.spill_size = -(current_offset - frame_info.locals_end_offset);
|
||||
|
||||
// 步骤 2: 遍历所有指令,对CALL指令做简化处理
|
||||
for (auto& mbb : MFunc->getBlocks()) {
|
||||
std::vector<std::unique_ptr<MachineInstr>> new_instructions;
|
||||
for (auto& instr_ptr : mbb->getInstructions()) {
|
||||
|
||||
if (instr_ptr->getOpcode() != RVOpcodes::CALL) {
|
||||
std::vector<PhysicalReg> int_spill_pool = {PhysicalReg::T2, PhysicalReg::T3, PhysicalReg::T4, /*PhysicalReg::T5,*/ PhysicalReg::T6};
|
||||
std::vector<PhysicalReg> fp_spill_pool = {PhysicalReg::F0, PhysicalReg::F1, PhysicalReg::F2, PhysicalReg::F3};
|
||||
std::map<unsigned, PhysicalReg> vreg_to_preg_map_for_this_instr;
|
||||
LiveSet use, def;
|
||||
getInstrUseDef(instr_ptr.get(), use, def);
|
||||
LiveSet all_vregs_in_instr = use;
|
||||
all_vregs_in_instr.insert(def.begin(), def.end());
|
||||
for(unsigned vreg : all_vregs_in_instr) {
|
||||
if (spilled_vregs.count(vreg)) {
|
||||
auto [type, size] = getTypeAndSize(vreg);
|
||||
if (type == Type::kFloat) {
|
||||
assert(!fp_spill_pool.empty() && "FP spill pool exhausted for generic instruction!");
|
||||
vreg_to_preg_map_for_this_instr[vreg] = fp_spill_pool.front();
|
||||
fp_spill_pool.erase(fp_spill_pool.begin());
|
||||
} else {
|
||||
assert(!int_spill_pool.empty() && "Int spill pool exhausted for generic instruction!");
|
||||
vreg_to_preg_map_for_this_instr[vreg] = int_spill_pool.front();
|
||||
int_spill_pool.erase(int_spill_pool.begin());
|
||||
}
|
||||
}
|
||||
}
|
||||
for (unsigned vreg : use) {
|
||||
if (spilled_vregs.count(vreg)) {
|
||||
PhysicalReg target_preg = vreg_to_preg_map_for_this_instr.at(vreg);
|
||||
auto [type, size] = getTypeAndSize(vreg);
|
||||
RVOpcodes load_op = (type == Type::kFloat) ? RVOpcodes::FLW : ((type == Type::kPointer) ? RVOpcodes::LD : RVOpcodes::LW);
|
||||
auto load = std::make_unique<MachineInstr>(load_op);
|
||||
load->addOperand(std::make_unique<RegOperand>(target_preg));
|
||||
load->addOperand(std::make_unique<MemOperand>(
|
||||
std::make_unique<RegOperand>(PhysicalReg::S0),
|
||||
std::make_unique<ImmOperand>(frame_info.spill_offsets.at(vreg))
|
||||
));
|
||||
new_instructions.push_back(std::move(load));
|
||||
}
|
||||
}
|
||||
auto new_instr = std::make_unique<MachineInstr>(instr_ptr->getOpcode());
|
||||
for (const auto& op : instr_ptr->getOperands()) {
|
||||
const RegOperand* reg_op = nullptr;
|
||||
if (op->getKind() == MachineOperand::KIND_REG) reg_op = static_cast<const RegOperand*>(op.get());
|
||||
else if (op->getKind() == MachineOperand::KIND_MEM) reg_op = static_cast<const MemOperand*>(op.get())->getBase();
|
||||
if (reg_op) {
|
||||
PhysicalReg final_preg;
|
||||
if (reg_op->isVirtual()) {
|
||||
unsigned vreg = reg_op->getVRegNum();
|
||||
if (spilled_vregs.count(vreg)) {
|
||||
final_preg = vreg_to_preg_map_for_this_instr.at(vreg);
|
||||
} else {
|
||||
assert(color_map.count(vreg));
|
||||
final_preg = color_map.at(vreg);
|
||||
}
|
||||
} else {
|
||||
final_preg = reg_op->getPReg();
|
||||
}
|
||||
auto new_reg_op = std::make_unique<RegOperand>(final_preg);
|
||||
if (op->getKind() == MachineOperand::KIND_REG) {
|
||||
new_instr->addOperand(std::move(new_reg_op));
|
||||
} else {
|
||||
auto mem_op = static_cast<const MemOperand*>(op.get());
|
||||
new_instr->addOperand(std::make_unique<MemOperand>(std::move(new_reg_op), std::make_unique<ImmOperand>(*mem_op->getOffset())));
|
||||
}
|
||||
} else {
|
||||
if(op->getKind() == MachineOperand::KIND_IMM) new_instr->addOperand(std::make_unique<ImmOperand>(*static_cast<const ImmOperand*>(op.get())));
|
||||
else if (op->getKind() == MachineOperand::KIND_LABEL) new_instr->addOperand(std::make_unique<LabelOperand>(*static_cast<const LabelOperand*>(op.get())));
|
||||
}
|
||||
}
|
||||
new_instructions.push_back(std::move(new_instr));
|
||||
for (unsigned vreg : def) {
|
||||
if (spilled_vregs.count(vreg)) {
|
||||
PhysicalReg src_preg = vreg_to_preg_map_for_this_instr.at(vreg);
|
||||
auto [type, size] = getTypeAndSize(vreg);
|
||||
RVOpcodes store_op = (type == Type::kFloat) ? RVOpcodes::FSW : ((type == Type::kPointer) ? RVOpcodes::SD : RVOpcodes::SW);
|
||||
auto store = std::make_unique<MachineInstr>(store_op);
|
||||
store->addOperand(std::make_unique<RegOperand>(src_preg));
|
||||
store->addOperand(std::make_unique<MemOperand>(
|
||||
std::make_unique<RegOperand>(PhysicalReg::S0),
|
||||
std::make_unique<ImmOperand>(frame_info.spill_offsets.at(vreg))
|
||||
));
|
||||
new_instructions.push_back(std::move(store));
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// --- 对于CALL指令,只处理其自身和返回值,不再处理参数 ---
|
||||
const PhysicalReg INT_TEMP_REG = PhysicalReg::T6;
|
||||
const PhysicalReg FP_TEMP_REG = PhysicalReg::F7;
|
||||
|
||||
// 1. 克隆CALL指令本身,只保留标签操作数
|
||||
auto new_call = std::make_unique<MachineInstr>(RVOpcodes::CALL);
|
||||
for (const auto& op : instr_ptr->getOperands()) {
|
||||
if (op->getKind() == MachineOperand::KIND_LABEL) {
|
||||
new_call->addOperand(std::make_unique<LabelOperand>(*static_cast<const LabelOperand*>(op.get())));
|
||||
// 注意:只添加第一个标签,防止ISel的错误导致多个标签
|
||||
break;
|
||||
}
|
||||
}
|
||||
new_instructions.push_back(std::move(new_call));
|
||||
|
||||
// 2. 只处理返回值(def)的溢出和移动
|
||||
auto& operands = instr_ptr->getOperands();
|
||||
if (!operands.empty() && operands.front()->getKind() == MachineOperand::KIND_REG) {
|
||||
unsigned def_vreg = static_cast<RegOperand*>(operands.front().get())->getVRegNum();
|
||||
auto [type, size] = getTypeAndSize(def_vreg);
|
||||
PhysicalReg result_reg_abi = type == Type::kFloat ? PhysicalReg::F10 : PhysicalReg::A0;
|
||||
|
||||
if (spilled_vregs.count(def_vreg)) {
|
||||
// 返回值被溢出:a0/fa0 -> temp -> 溢出槽
|
||||
PhysicalReg temp_reg = type == Type::kFloat ? FP_TEMP_REG : INT_TEMP_REG;
|
||||
RVOpcodes store_op = (type == Type::kFloat) ? RVOpcodes::FSW : ((type == Type::kPointer) ? RVOpcodes::SD : RVOpcodes::SW);
|
||||
|
||||
auto mv_from_abi = std::make_unique<MachineInstr>(type == Type::kFloat ? RVOpcodes::FMV_S : RVOpcodes::MV);
|
||||
mv_from_abi->addOperand(std::make_unique<RegOperand>(temp_reg));
|
||||
mv_from_abi->addOperand(std::make_unique<RegOperand>(result_reg_abi));
|
||||
new_instructions.push_back(std::move(mv_from_abi));
|
||||
|
||||
auto store = std::make_unique<MachineInstr>(store_op);
|
||||
store->addOperand(std::make_unique<RegOperand>(temp_reg));
|
||||
store->addOperand(std::make_unique<MemOperand>(
|
||||
std::make_unique<RegOperand>(PhysicalReg::S0),
|
||||
std::make_unique<ImmOperand>(frame_info.spill_offsets.at(def_vreg))
|
||||
));
|
||||
new_instructions.push_back(std::move(store));
|
||||
} else {
|
||||
// 返回值未溢出:a0/fa0 -> 已着色的物理寄存器
|
||||
auto mv_to_dest = std::make_unique<MachineInstr>(type == Type::kFloat ? RVOpcodes::FMV_S : RVOpcodes::MV);
|
||||
mv_to_dest->addOperand(std::make_unique<RegOperand>(color_map.at(def_vreg)));
|
||||
mv_to_dest->addOperand(std::make_unique<RegOperand>(result_reg_abi));
|
||||
new_instructions.push_back(std::move(mv_to_dest));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
mbb->getInstructions() = std::move(new_instructions);
|
||||
}
|
||||
}
|
||||
|
||||
// --- 辅助函数实现 ---
|
||||
|
||||
void RISCv64SimpleRegAlloc::getInstrUseDef_Liveness(const MachineInstr* instr, LiveSet& use, LiveSet& def) {
|
||||
auto opcode = instr->getOpcode();
|
||||
const auto& operands = instr->getOperands();
|
||||
const unsigned offset = static_cast<unsigned>(PhysicalReg::PHYS_REG_START_ID);
|
||||
|
||||
auto get_any_reg_id = [&](const MachineOperand* op) -> unsigned {
|
||||
if (op->getKind() == MachineOperand::KIND_REG) {
|
||||
auto reg_op = static_cast<const RegOperand*>(op);
|
||||
return reg_op->isVirtual() ? reg_op->getVRegNum() : (offset + static_cast<unsigned>(reg_op->getPReg()));
|
||||
} else if (op->getKind() == MachineOperand::KIND_MEM) {
|
||||
auto reg_op = static_cast<const MemOperand*>(op)->getBase();
|
||||
return reg_op->isVirtual() ? reg_op->getVRegNum() : (offset + static_cast<unsigned>(reg_op->getPReg()));
|
||||
}
|
||||
return (unsigned)-1;
|
||||
};
|
||||
|
||||
if (op_info.count(opcode)) {
|
||||
const auto& info = op_info.at(opcode);
|
||||
for (int idx : info.first) if (idx < operands.size()) {
|
||||
unsigned reg_id = get_any_reg_id(operands[idx].get());
|
||||
if (reg_id != (unsigned)-1) def.insert(reg_id);
|
||||
}
|
||||
for (int idx : info.second) if (idx < operands.size()) {
|
||||
unsigned reg_id = get_any_reg_id(operands[idx].get());
|
||||
if (reg_id != (unsigned)-1) use.insert(reg_id);
|
||||
}
|
||||
for (const auto& op : operands) {
|
||||
if (op->getKind() == MachineOperand::KIND_MEM) {
|
||||
unsigned reg_id = get_any_reg_id(op.get());
|
||||
if (reg_id != (unsigned)-1) use.insert(reg_id);
|
||||
}
|
||||
}
|
||||
}
|
||||
else if (opcode == RVOpcodes::CALL) {
|
||||
if (!operands.empty() && operands[0]->getKind() == MachineOperand::KIND_REG) {
|
||||
def.insert(get_any_reg_id(operands[0].get()));
|
||||
}
|
||||
for (size_t i = 1; i < operands.size(); ++i) {
|
||||
if (operands[i]->getKind() == MachineOperand::KIND_REG) {
|
||||
use.insert(get_any_reg_id(operands[i].get()));
|
||||
}
|
||||
}
|
||||
for (auto preg : getCallerSavedIntRegs()) def.insert(offset + static_cast<unsigned>(preg));
|
||||
for (auto preg : getCallerSavedFpRegs()) def.insert(offset + static_cast<unsigned>(preg));
|
||||
def.insert(offset + static_cast<unsigned>(PhysicalReg::RA));
|
||||
}
|
||||
else if (opcode == RVOpcodes::RET) {
|
||||
use.insert(offset + static_cast<unsigned>(PhysicalReg::A0));
|
||||
use.insert(offset + static_cast<unsigned>(PhysicalReg::F10)); // fa0
|
||||
}
|
||||
}
|
||||
|
||||
std::pair<Type::Kind, unsigned> RISCv64SimpleRegAlloc::getTypeAndSize(unsigned vreg) {
|
||||
const auto& vreg_type_map = ISel->getVRegTypeMap();
|
||||
if (vreg_type_map.count(vreg)) {
|
||||
Type* type = vreg_type_map.at(vreg);
|
||||
if (type->isFloat()) return {Type::kFloat, 4};
|
||||
if (type->isPointer()) return {Type::kPointer, 8};
|
||||
}
|
||||
// 默认或未知类型按32位整数处理
|
||||
return {Type::kInt, 4};
|
||||
}
|
||||
|
||||
std::string RISCv64SimpleRegAlloc::regIdToString(unsigned id, const RISCv64AsmPrinter& printer) const {
|
||||
const unsigned offset = static_cast<unsigned>(PhysicalReg::PHYS_REG_START_ID);
|
||||
if (id >= offset) {
|
||||
PhysicalReg reg = static_cast<PhysicalReg>(id - offset);
|
||||
return printer.regToString(reg);
|
||||
} else {
|
||||
return "%vreg" + std::to_string(id);
|
||||
}
|
||||
}
|
||||
|
||||
void RISCv64SimpleRegAlloc::printLiveSet(const LiveSet& s, const std::string& name, std::ostream& os, const RISCv64AsmPrinter& printer) {
|
||||
os << " " << name << " (" << s.size() << "): { ";
|
||||
for (unsigned vreg : s) {
|
||||
os << regIdToString(vreg, printer) << " ";
|
||||
}
|
||||
os << "}\n";
|
||||
}
|
||||
|
||||
} // namespace sysy
|
||||
@ -19,7 +19,7 @@ public:
|
||||
// 辅助函数
|
||||
void setStream(std::ostream& os) { OS = &os; }
|
||||
// 辅助函数
|
||||
std::string regToString(PhysicalReg reg);
|
||||
std::string regToString(PhysicalReg reg) const;
|
||||
std::string formatInstr(const MachineInstr *instr);
|
||||
|
||||
private:
|
||||
|
||||
@ -26,7 +26,7 @@ private:
|
||||
unsigned getTypeSizeInBytes(Type* type);
|
||||
|
||||
Module* module;
|
||||
bool gc_failed = false;
|
||||
bool irc_failed = false;
|
||||
};
|
||||
|
||||
} // namespace sysy
|
||||
|
||||
@ -11,6 +11,7 @@ namespace sysy {
|
||||
|
||||
extern int DEBUG;
|
||||
extern int DEEPDEBUG;
|
||||
extern int optLevel;
|
||||
|
||||
namespace sysy {
|
||||
|
||||
|
||||
@ -12,6 +12,7 @@ extern int DEBUG;
|
||||
extern int DEEPDEBUG;
|
||||
extern int DEBUGLENGTH; // 用于限制调试输出的长度
|
||||
extern int DEEPERDEBUG; // 用于更深层次的调试输出
|
||||
extern int optLevel;
|
||||
|
||||
namespace sysy {
|
||||
|
||||
@ -20,7 +21,7 @@ public:
|
||||
RISCv64RegAlloc(MachineFunction* mfunc);
|
||||
|
||||
// 模块主入口
|
||||
bool run();
|
||||
bool run(std::shared_ptr<std::atomic<bool>> stop_flag);
|
||||
|
||||
private:
|
||||
// 类型定义,与Python版本对应
|
||||
|
||||
107
src/include/backend/RISCv64/RISCv64SimpleRegAlloc.h
Normal file
107
src/include/backend/RISCv64/RISCv64SimpleRegAlloc.h
Normal file
@ -0,0 +1,107 @@
|
||||
#ifndef RISCV64_SIMPLE_REGALLOC_H
|
||||
#define RISCV64_SIMPLE_REGALLOC_H
|
||||
|
||||
#include "RISCv64LLIR.h"
|
||||
#include "RISCv64ISel.h"
|
||||
#include <set>
|
||||
#include <vector>
|
||||
#include <map>
|
||||
|
||||
// 外部调试级别控制变量的声明
|
||||
extern int DEBUG;
|
||||
extern int DEEPDEBUG;
|
||||
|
||||
namespace sysy {
|
||||
|
||||
class RISCv64AsmPrinter; // 前向声明
|
||||
|
||||
/**
|
||||
* @class RISCv64SimpleRegAlloc
|
||||
* @brief 一个简单的一次性图着色寄存器分配器。
|
||||
* * 该分配器遵循一个线性的、非迭代的流程:
|
||||
* 1. 活跃性分析
|
||||
* 2. 构建冲突图
|
||||
* 3. 贪心图着色
|
||||
* 4. 重写函数代码,插入溢出指令
|
||||
* * 它与新版后端流水线兼容,但保留了旧版分配器的核心逻辑。
|
||||
* 溢出处理使用硬编码的物理寄存器。
|
||||
*/
|
||||
class RISCv64SimpleRegAlloc {
|
||||
public:
|
||||
RISCv64SimpleRegAlloc(MachineFunction* mfunc);
|
||||
|
||||
/**
|
||||
* @brief 运行寄存器分配的主函数。
|
||||
*/
|
||||
void run();
|
||||
|
||||
private:
|
||||
using LiveSet = std::set<unsigned>;
|
||||
using InterferenceGraph = std::map<unsigned, LiveSet>;
|
||||
|
||||
// --- 分配流程的各个阶段 ---
|
||||
void unifyArgumentVRegs();
|
||||
void handleCallingConvention();
|
||||
void analyzeLiveness();
|
||||
void buildInterferenceGraph();
|
||||
void colorGraph();
|
||||
void rewriteFunction();
|
||||
|
||||
// --- 辅助函数 ---
|
||||
|
||||
/**
|
||||
* @brief 获取指令的Use/Def集合,包含物理寄存器,用于活跃性分析。
|
||||
* @param instr 机器指令。
|
||||
* @param use 输出参数,存储使用的寄存器ID。
|
||||
* @param def 输出参数,存储定义的寄存器ID。
|
||||
*/
|
||||
void getInstrUseDef_Liveness(const MachineInstr* instr, LiveSet& use, LiveSet& def);
|
||||
|
||||
/**
|
||||
* @brief 根据vreg的类型信息返回其大小和类型种类。
|
||||
* @param vreg 虚拟寄存器号。
|
||||
* @return 一个包含类型信息和大小(字节)的pair。
|
||||
*/
|
||||
std::pair<Type::Kind, unsigned> getTypeAndSize(unsigned vreg);
|
||||
|
||||
/**
|
||||
* @brief 打印调试用的活跃集信息。
|
||||
*/
|
||||
void printLiveSet(const LiveSet& s, const std::string& name, std::ostream& os, const RISCv64AsmPrinter& printer);
|
||||
|
||||
/**
|
||||
* @brief 将寄存器ID(虚拟或物理)转换为可读字符串。
|
||||
*/
|
||||
std::string regIdToString(unsigned id, const RISCv64AsmPrinter& printer) const;
|
||||
|
||||
// --- 成员变量 ---
|
||||
MachineFunction* MFunc;
|
||||
RISCv64ISel* ISel;
|
||||
|
||||
// 可分配的寄存器池
|
||||
std::vector<PhysicalReg> allocable_int_regs;
|
||||
std::vector<PhysicalReg> allocable_fp_regs;
|
||||
|
||||
// 硬编码的溢出专用物理寄存器
|
||||
const PhysicalReg INT_SPILL_REG = PhysicalReg::T2; // 用于 32-bit int
|
||||
const PhysicalReg PTR_SPILL_REG = PhysicalReg::T3; // 用于 64-bit pointer
|
||||
const PhysicalReg FP_SPILL_REG = PhysicalReg::F4; // 用于 32-bit float (ft4)
|
||||
|
||||
// 活跃性分析结果
|
||||
std::map<const MachineInstr*, LiveSet> live_in_map;
|
||||
std::map<const MachineInstr*, LiveSet> live_out_map;
|
||||
|
||||
// 冲突图
|
||||
InterferenceGraph interference_graph;
|
||||
|
||||
// 着色结果和溢出列表
|
||||
std::map<unsigned, PhysicalReg> color_map;
|
||||
std::set<unsigned> spilled_vregs;
|
||||
|
||||
// 映射:将物理寄存器ID映射到它们在冲突图中的特殊虚拟ID
|
||||
std::map<PhysicalReg, unsigned> preg_to_vreg_id_map;
|
||||
};
|
||||
|
||||
} // namespace sysy
|
||||
|
||||
#endif // RISCV64_SIMPLE_REGALLOC_H
|
||||
@ -350,7 +350,11 @@ private:
|
||||
std::set<Value*>& visited
|
||||
);
|
||||
bool isBasicInductionVariable(Value* val, Loop* loop);
|
||||
bool hasSimpleMemoryPattern(Loop* loop); // 简单的内存模式检查
|
||||
// ========== 循环不变量分析辅助方法 ==========
|
||||
bool isInvariantOperands(Instruction* inst, Loop* loop, const std::unordered_set<Value*>& invariants);
|
||||
bool isMemoryLocationModifiedInLoop(Value* ptr, Loop* loop);
|
||||
bool isMemoryLocationLoadedInLoop(Value* ptr, Loop* loop, Instruction* excludeInst = nullptr);
|
||||
bool isPureFunction(Function* calledFunc);
|
||||
};
|
||||
|
||||
} // namespace sysy
|
||||
|
||||
87
src/include/midend/Pass/Optimize/GVN.h
Normal file
87
src/include/midend/Pass/Optimize/GVN.h
Normal file
@ -0,0 +1,87 @@
|
||||
#pragma once
|
||||
|
||||
#include "Pass.h"
|
||||
#include "IR.h"
|
||||
#include "Dom.h"
|
||||
#include "SideEffectAnalysis.h"
|
||||
#include <unordered_map>
|
||||
#include <unordered_set>
|
||||
#include <vector>
|
||||
#include <string>
|
||||
#include <sstream>
|
||||
|
||||
namespace sysy {
|
||||
|
||||
// GVN优化遍的核心逻辑封装类
|
||||
class GVNContext {
|
||||
public:
|
||||
// 运行GVN优化的主要方法
|
||||
void run(Function* func, AnalysisManager* AM, bool& changed);
|
||||
|
||||
private:
|
||||
// 新的值编号系统
|
||||
std::unordered_map<Value*, unsigned> valueToNumber; // Value -> 值编号
|
||||
std::unordered_map<unsigned, Value*> numberToValue; // 值编号 -> 代表值
|
||||
std::unordered_map<std::string, unsigned> expressionToNumber; // 表达式 -> 值编号
|
||||
unsigned nextValueNumber = 1;
|
||||
|
||||
// 已访问的基本块集合
|
||||
std::unordered_set<BasicBlock*> visited;
|
||||
|
||||
// 逆后序遍历的基本块列表
|
||||
std::vector<BasicBlock*> rpoBlocks;
|
||||
|
||||
// 需要删除的指令集合
|
||||
std::unordered_set<Instruction*> needRemove;
|
||||
|
||||
// 分析结果
|
||||
DominatorTree* domTree = nullptr;
|
||||
SideEffectAnalysisResult* sideEffectAnalysis = nullptr;
|
||||
|
||||
// 计算逆后序遍历
|
||||
void computeRPO(Function* func);
|
||||
void dfs(BasicBlock* bb);
|
||||
|
||||
// 新的值编号方法
|
||||
unsigned getValueNumber(Value* value);
|
||||
unsigned assignValueNumber(Value* value);
|
||||
|
||||
// 基本块处理
|
||||
void processBasicBlock(BasicBlock* bb, bool& changed);
|
||||
|
||||
// 指令处理
|
||||
bool processInstruction(Instruction* inst);
|
||||
|
||||
// 表达式构建和查找
|
||||
std::string buildExpressionKey(Instruction* inst);
|
||||
Value* findExistingValue(const std::string& exprKey, Instruction* inst);
|
||||
|
||||
// 支配关系和安全性检查
|
||||
bool dominates(Instruction* a, Instruction* b);
|
||||
bool isMemorySafe(LoadInst* earlierLoad, LoadInst* laterLoad);
|
||||
|
||||
// 清理方法
|
||||
void eliminateRedundantInstructions(bool& changed);
|
||||
void invalidateMemoryValues(StoreInst* store);
|
||||
};
|
||||
|
||||
// GVN优化遍类
|
||||
class GVN : public OptimizationPass {
|
||||
public:
|
||||
// 静态成员,作为该遍的唯一ID
|
||||
static void* ID;
|
||||
|
||||
GVN() : OptimizationPass("GVN", Granularity::Function) {}
|
||||
|
||||
// 在函数上运行优化
|
||||
bool runOnFunction(Function* func, AnalysisManager& AM) override;
|
||||
|
||||
// 返回该遍的唯一ID
|
||||
void* getPassID() const override { return ID; }
|
||||
|
||||
// 声明分析依赖
|
||||
void getAnalysisUsage(std::set<void*>& analysisDependencies,
|
||||
std::set<void*>& analysisInvalidations) const override;
|
||||
};
|
||||
|
||||
} // namespace sysy
|
||||
107
src/include/midend/Pass/Optimize/GlobalStrengthReduction.h
Normal file
107
src/include/midend/Pass/Optimize/GlobalStrengthReduction.h
Normal file
@ -0,0 +1,107 @@
|
||||
#pragma once
|
||||
|
||||
#include "Pass.h"
|
||||
#include "IR.h"
|
||||
#include "SideEffectAnalysis.h"
|
||||
#include <unordered_map>
|
||||
#include <unordered_set>
|
||||
#include <vector>
|
||||
#include <cstdint>
|
||||
|
||||
namespace sysy {
|
||||
|
||||
// 魔数乘法结构,用于除法优化
|
||||
struct MagicNumber {
|
||||
uint32_t multiplier;
|
||||
int shift;
|
||||
bool needAdd;
|
||||
|
||||
MagicNumber(uint32_t m, int s, bool add = false)
|
||||
: multiplier(m), shift(s), needAdd(add) {}
|
||||
};
|
||||
|
||||
// 全局强度削弱优化遍的核心逻辑封装类
|
||||
class GlobalStrengthReductionContext {
|
||||
public:
|
||||
// 构造函数,接受IRBuilder参数
|
||||
explicit GlobalStrengthReductionContext(IRBuilder* builder) : builder(builder) {}
|
||||
|
||||
// 运行优化的主要方法
|
||||
void run(Function* func, AnalysisManager* AM, bool& changed);
|
||||
|
||||
private:
|
||||
IRBuilder* builder; // IR构建器
|
||||
|
||||
// 分析结果
|
||||
SideEffectAnalysisResult* sideEffectAnalysis = nullptr;
|
||||
|
||||
// 优化计数
|
||||
int algebraicOptCount = 0;
|
||||
int strengthReductionCount = 0;
|
||||
int divisionOptCount = 0;
|
||||
|
||||
// 主要优化方法
|
||||
bool processBasicBlock(BasicBlock* bb);
|
||||
bool processInstruction(Instruction* inst);
|
||||
|
||||
// 代数优化方法
|
||||
bool tryAlgebraicOptimization(Instruction* inst);
|
||||
bool optimizeAddition(BinaryInst* inst);
|
||||
bool optimizeSubtraction(BinaryInst* inst);
|
||||
bool optimizeMultiplication(BinaryInst* inst);
|
||||
bool optimizeDivision(BinaryInst* inst);
|
||||
bool optimizeComparison(BinaryInst* inst);
|
||||
bool optimizeLogical(BinaryInst* inst);
|
||||
|
||||
// 强度削弱方法
|
||||
bool tryStrengthReduction(Instruction* inst);
|
||||
bool reduceMultiplication(BinaryInst* inst);
|
||||
bool reduceDivision(BinaryInst* inst);
|
||||
bool reducePower(CallInst* inst);
|
||||
|
||||
// 复杂乘法强度削弱方法
|
||||
bool tryComplexMultiplication(BinaryInst* inst, Value* variable, int constant);
|
||||
bool findOptimalShiftDecomposition(int constant, std::vector<int>& shifts);
|
||||
Value* createShiftDecomposition(BinaryInst* inst, Value* variable, const std::vector<int>& shifts);
|
||||
|
||||
// 魔数乘法相关方法
|
||||
MagicNumber computeMagicNumber(uint32_t divisor);
|
||||
std::pair<int, int> computeMulhMagicNumbers(int divisor);
|
||||
Value* createMagicDivision(BinaryInst* divInst, uint32_t divisor, const MagicNumber& magic);
|
||||
Value* createMagicDivisionLibdivide(BinaryInst* divInst, int divisor);
|
||||
bool isPowerOfTwo(uint32_t n);
|
||||
int log2OfPowerOfTwo(uint32_t n);
|
||||
|
||||
// 辅助方法
|
||||
bool isConstantInt(Value* val, int& constVal);
|
||||
bool isConstantInt(Value* val, uint32_t& constVal);
|
||||
ConstantInteger* getConstantInt(int val);
|
||||
bool hasOnlyLocalUses(Instruction* inst);
|
||||
void replaceWithOptimized(Instruction* original, Value* replacement);
|
||||
};
|
||||
|
||||
// 全局强度削弱优化遍类
|
||||
class GlobalStrengthReduction : public OptimizationPass {
|
||||
private:
|
||||
IRBuilder* builder; // IR构建器,用于创建新指令
|
||||
|
||||
public:
|
||||
// 静态成员,作为该遍的唯一ID
|
||||
static void* ID;
|
||||
|
||||
// 构造函数,接受IRBuilder参数
|
||||
explicit GlobalStrengthReduction(IRBuilder* builder)
|
||||
: OptimizationPass("GlobalStrengthReduction", Granularity::Function), builder(builder) {}
|
||||
|
||||
// 在函数上运行优化
|
||||
bool runOnFunction(Function* func, AnalysisManager& AM) override;
|
||||
|
||||
// 返回该遍的唯一ID
|
||||
void* getPassID() const override { return ID; }
|
||||
|
||||
// 声明分析依赖
|
||||
void getAnalysisUsage(std::set<void*>& analysisDependencies,
|
||||
std::set<void*>& analysisInvalidations) const override;
|
||||
};
|
||||
|
||||
} // namespace sysy
|
||||
@ -1,24 +0,0 @@
|
||||
#pragma once
|
||||
|
||||
#include "../Pass.h"
|
||||
|
||||
namespace sysy {
|
||||
|
||||
class LargeArrayToGlobalPass : public OptimizationPass {
|
||||
public:
|
||||
static void *ID;
|
||||
|
||||
LargeArrayToGlobalPass() : OptimizationPass("LargeArrayToGlobal", Granularity::Module) {}
|
||||
|
||||
bool runOnModule(Module *M, AnalysisManager &AM) override;
|
||||
void *getPassID() const override {
|
||||
return &ID;
|
||||
}
|
||||
|
||||
private:
|
||||
unsigned calculateTypeSize(Type *type);
|
||||
void convertAllocaToGlobal(AllocaInst *alloca, Function *F, Module *M);
|
||||
std::string generateUniqueGlobalName(AllocaInst *alloca, Function *F);
|
||||
};
|
||||
|
||||
} // namespace sysy
|
||||
@ -127,13 +127,6 @@ private:
|
||||
*/
|
||||
bool analyzeInductionVariableRange(const InductionVarInfo* ivInfo, Loop* loop) const;
|
||||
|
||||
/**
|
||||
* 计算用于除法优化的魔数和移位量
|
||||
* @param divisor 除数
|
||||
* @return {魔数, 移位量}
|
||||
*/
|
||||
std::pair<int, int> computeMulhMagicNumbers(int divisor) const;
|
||||
|
||||
/**
|
||||
* 生成除法替换代码
|
||||
* @param candidate 优化候选项
|
||||
|
||||
@ -107,6 +107,190 @@ public:
|
||||
// 所以当AllocaInst的basetype是PointerType时(一维数组)或者是指向ArrayType的PointerType(多位数组)时,返回true
|
||||
return aval && (baseType->isPointer() || baseType->as<PointerType>()->getBaseType()->isArray());
|
||||
}
|
||||
|
||||
|
||||
//该实现参考了libdivide的算法
|
||||
static std::pair<int, int> computeMulhMagicNumbers(int divisor) {
|
||||
|
||||
if (DEBUG) {
|
||||
std::cout << "\n[SR] ===== Computing magic numbers for divisor " << divisor << " (libdivide algorithm) =====" << std::endl;
|
||||
}
|
||||
|
||||
if (divisor == 0) {
|
||||
if (DEBUG) std::cout << "[SR] Error: divisor must be != 0" << std::endl;
|
||||
return {-1, -1};
|
||||
}
|
||||
|
||||
// libdivide 常数
|
||||
const uint8_t LIBDIVIDE_ADD_MARKER = 0x40;
|
||||
const uint8_t LIBDIVIDE_NEGATIVE_DIVISOR = 0x80;
|
||||
|
||||
// 辅助函数:计算前导零个数
|
||||
auto count_leading_zeros32 = [](uint32_t val) -> uint32_t {
|
||||
if (val == 0) return 32;
|
||||
return __builtin_clz(val);
|
||||
};
|
||||
|
||||
// 辅助函数:64位除法返回32位商和余数
|
||||
auto div_64_32 = [](uint32_t high, uint32_t low, uint32_t divisor, uint32_t* rem) -> uint32_t {
|
||||
uint64_t dividend = ((uint64_t)high << 32) | low;
|
||||
uint32_t quotient = dividend / divisor;
|
||||
*rem = dividend % divisor;
|
||||
return quotient;
|
||||
};
|
||||
|
||||
if (DEBUG) {
|
||||
std::cout << "[SR] Input divisor: " << divisor << std::endl;
|
||||
}
|
||||
|
||||
// libdivide_internal_s32_gen 算法实现
|
||||
int32_t d = divisor;
|
||||
uint32_t ud = (uint32_t)d;
|
||||
uint32_t absD = (d < 0) ? -ud : ud;
|
||||
|
||||
if (DEBUG) {
|
||||
std::cout << "[SR] absD = " << absD << std::endl;
|
||||
}
|
||||
|
||||
uint32_t floor_log_2_d = 31 - count_leading_zeros32(absD);
|
||||
|
||||
if (DEBUG) {
|
||||
std::cout << "[SR] floor_log_2_d = " << floor_log_2_d << std::endl;
|
||||
}
|
||||
|
||||
// 检查 absD 是否为2的幂
|
||||
if ((absD & (absD - 1)) == 0) {
|
||||
if (DEBUG) {
|
||||
std::cout << "[SR] " << absD << " 是2的幂,使用移位方法" << std::endl;
|
||||
}
|
||||
|
||||
// 对于2的幂,我们只使用移位,不需要魔数
|
||||
int shift = floor_log_2_d;
|
||||
if (d < 0) shift |= 0x80; // 标记负数
|
||||
|
||||
if (DEBUG) {
|
||||
std::cout << "[SR] Power of 2 result: magic=0, shift=" << shift << std::endl;
|
||||
std::cout << "[SR] ===== End magic computation =====" << std::endl;
|
||||
}
|
||||
|
||||
// 对于我们的目的,我们将在IR生成中以不同方式处理2的幂
|
||||
// 返回特殊标记
|
||||
return {0, shift};
|
||||
}
|
||||
|
||||
if (DEBUG) {
|
||||
std::cout << "[SR] " << absD << " is not a power of 2, computing magic number" << std::endl;
|
||||
}
|
||||
|
||||
// 非2的幂除数的魔数计算
|
||||
uint8_t more;
|
||||
uint32_t rem, proposed_m;
|
||||
|
||||
// 计算 proposed_m = floor(2^(floor_log_2_d + 31) / absD)
|
||||
proposed_m = div_64_32((uint32_t)1 << (floor_log_2_d - 1), 0, absD, &rem);
|
||||
const uint32_t e = absD - rem;
|
||||
|
||||
if (DEBUG) {
|
||||
std::cout << "[SR] proposed_m = " << proposed_m << ", rem = " << rem << ", e = " << e << std::endl;
|
||||
}
|
||||
|
||||
// 确定是否需要"加法"版本
|
||||
const bool branchfree = false; // 使用分支版本
|
||||
|
||||
if (!branchfree && e < ((uint32_t)1 << floor_log_2_d)) {
|
||||
// 这个幂次有效
|
||||
more = (uint8_t)(floor_log_2_d - 1);
|
||||
if (DEBUG) {
|
||||
std::cout << "[SR] Using basic algorithm, shift = " << (int)more << std::endl;
|
||||
}
|
||||
} else {
|
||||
// 我们需要上升一个等级
|
||||
proposed_m += proposed_m;
|
||||
const uint32_t twice_rem = rem + rem;
|
||||
if (twice_rem >= absD || twice_rem < rem) {
|
||||
proposed_m += 1;
|
||||
}
|
||||
more = (uint8_t)(floor_log_2_d | LIBDIVIDE_ADD_MARKER);
|
||||
if (DEBUG) {
|
||||
std::cout << "[SR] Using add algorithm, proposed_m = " << proposed_m << ", more = " << (int)more << std::endl;
|
||||
}
|
||||
}
|
||||
|
||||
proposed_m += 1;
|
||||
int32_t magic = (int32_t)proposed_m;
|
||||
|
||||
// 处理负除数
|
||||
if (d < 0) {
|
||||
more |= LIBDIVIDE_NEGATIVE_DIVISOR;
|
||||
if (!branchfree) {
|
||||
magic = -magic;
|
||||
}
|
||||
if (DEBUG) {
|
||||
std::cout << "[SR] Negative divisor, magic = " << magic << ", more = " << (int)more << std::endl;
|
||||
}
|
||||
}
|
||||
|
||||
// 为我们的IR生成提取移位量和标志
|
||||
int shift = more & 0x3F; // 移除标志,保留移位量(位0-5)
|
||||
bool need_add = (more & LIBDIVIDE_ADD_MARKER) != 0;
|
||||
bool is_negative = (more & LIBDIVIDE_NEGATIVE_DIVISOR) != 0;
|
||||
|
||||
if (DEBUG) {
|
||||
std::cout << "[SR] Final result: magic = " << magic << ", more = " << (int)more
|
||||
<< " (0x" << std::hex << (int)more << std::dec << ")" << std::endl;
|
||||
std::cout << "[SR] Shift = " << shift << ", need_add = " << need_add
|
||||
<< ", is_negative = " << is_negative << std::endl;
|
||||
|
||||
// Test the magic number using the correct libdivide algorithm
|
||||
std::cout << "[SR] Testing magic number (libdivide algorithm):" << std::endl;
|
||||
int test_values[] = {1, 7, 37, 100, 999, -1, -7, -37, -100};
|
||||
|
||||
for (int test_val : test_values) {
|
||||
int64_t quotient;
|
||||
|
||||
// 实现正确的libdivide算法
|
||||
int64_t product = (int64_t)test_val * magic;
|
||||
int64_t high_bits = product >> 32;
|
||||
|
||||
if (need_add) {
|
||||
// ADD_MARKER情况:移位前加上被除数
|
||||
// 这是libdivide的关键洞察!
|
||||
high_bits += test_val;
|
||||
quotient = high_bits >> shift;
|
||||
} else {
|
||||
// 正常情况:只是移位
|
||||
quotient = high_bits >> shift;
|
||||
}
|
||||
|
||||
// 符号修正:这是libdivide有符号除法的关键部分!
|
||||
// 如果被除数为负,商需要加1来匹配C语言的截断除法语义
|
||||
if (test_val < 0) {
|
||||
quotient += 1;
|
||||
}
|
||||
|
||||
int expected = test_val / divisor;
|
||||
|
||||
bool correct = (quotient == expected);
|
||||
std::cout << "[SR] " << test_val << " / " << divisor << " = " << quotient
|
||||
<< " (expected " << expected << ") " << (correct ? "✓" : "✗") << std::endl;
|
||||
}
|
||||
|
||||
std::cout << "[SR] ===== End magic computation =====" << std::endl;
|
||||
}
|
||||
|
||||
// 返回魔数、移位量,并在移位中编码ADD_MARKER标志
|
||||
// 我们将使用移位的第6位表示ADD_MARKER,第7位表示负数(如果需要)
|
||||
int encoded_shift = shift;
|
||||
if (need_add) {
|
||||
encoded_shift |= 0x40; // 设置第6位表示ADD_MARKER
|
||||
if (DEBUG) {
|
||||
std::cout << "[SR] Encoding ADD_MARKER in shift: " << encoded_shift << std::endl;
|
||||
}
|
||||
}
|
||||
|
||||
return {magic, encoded_shift};
|
||||
}
|
||||
|
||||
};
|
||||
|
||||
}// namespace sysy
|
||||
39
src/include/midend/Pass/Optimize/TailCallOpt.h
Normal file
39
src/include/midend/Pass/Optimize/TailCallOpt.h
Normal file
@ -0,0 +1,39 @@
|
||||
#pragma once
|
||||
|
||||
#include "Pass.h"
|
||||
#include "Dom.h"
|
||||
#include "Loop.h"
|
||||
|
||||
namespace sysy {
|
||||
|
||||
/**
|
||||
* @class TailCallOpt
|
||||
* @brief 优化尾调用的中端优化通道。
|
||||
*
|
||||
* 该类实现了一个针对函数级别的尾调用优化的优化通道(OptimizationPass)。
|
||||
* 通过分析和转换 IR(中间表示),将可优化的尾调用转换为更高效的形式,
|
||||
* 以减少函数调用的开销,提升程序性能。
|
||||
*
|
||||
* @note 需要传入 IRBuilder 指针用于 IR 构建和修改。
|
||||
*
|
||||
* @method runOnFunction
|
||||
* 对指定函数进行尾调用优化。
|
||||
*
|
||||
* @method getPassID
|
||||
* 获取当前优化通道的唯一标识符。
|
||||
*
|
||||
* @method getAnalysisUsage
|
||||
* 指定该优化通道所依赖和失效的分析集合。
|
||||
*/
|
||||
class TailCallOpt : public OptimizationPass {
|
||||
private:
|
||||
IRBuilder* builder;
|
||||
public:
|
||||
TailCallOpt(IRBuilder* builder) : OptimizationPass("TailCallOpt", Granularity::Function), builder(builder) {}
|
||||
static void *ID;
|
||||
bool runOnFunction(Function *F, AnalysisManager &AM) override;
|
||||
void *getPassID() const override { return &ID; }
|
||||
void getAnalysisUsage(std::set<void *> &analysisDependencies, std::set<void *> &analysisInvalidations) const override;
|
||||
};
|
||||
|
||||
} // namespace sysy
|
||||
@ -15,14 +15,16 @@ add_library(midend_lib STATIC
|
||||
Pass/Optimize/DCE.cpp
|
||||
Pass/Optimize/Mem2Reg.cpp
|
||||
Pass/Optimize/Reg2Mem.cpp
|
||||
Pass/Optimize/GVN.cpp
|
||||
Pass/Optimize/SysYIRCFGOpt.cpp
|
||||
Pass/Optimize/SCCP.cpp
|
||||
Pass/Optimize/LoopNormalization.cpp
|
||||
Pass/Optimize/LICM.cpp
|
||||
Pass/Optimize/LoopStrengthReduction.cpp
|
||||
Pass/Optimize/InductionVariableElimination.cpp
|
||||
Pass/Optimize/GlobalStrengthReduction.cpp
|
||||
Pass/Optimize/BuildCFG.cpp
|
||||
Pass/Optimize/LargeArrayToGlobal.cpp
|
||||
Pass/Optimize/TailCallOpt.cpp
|
||||
)
|
||||
|
||||
# 包含中端模块所需的头文件路径
|
||||
|
||||
@ -847,7 +847,7 @@ void CondBrInst::print(std::ostream &os) const {
|
||||
|
||||
os << "%tmp_cond_" << condName << "_" << uniqueSuffix << " = icmp ne i32 ";
|
||||
printOperand(os, condition);
|
||||
os << ", 0\n br i1 %tmp_cond_" << condName << "_" << uniqueSuffix;
|
||||
os << ", 0\n br i1 %tmp_cond_" << condName << "_" << uniqueSuffix;
|
||||
|
||||
os << ", label %";
|
||||
printBlockName(os, getThenBlock());
|
||||
@ -886,7 +886,7 @@ void MemsetInst::print(std::ostream &os) const {
|
||||
// This is done at print time to avoid modifying the IR structure
|
||||
os << "%tmp_bitcast_" << ptr->getName() << " = bitcast " << *ptr->getType() << " ";
|
||||
printOperand(os, ptr);
|
||||
os << " to i8*\n ";
|
||||
os << " to i8*\n ";
|
||||
|
||||
// Now call memset with the bitcast result
|
||||
os << "call void @llvm.memset.p0i8.i32(i8* %tmp_bitcast_" << ptr->getName() << ", i8 ";
|
||||
|
||||
@ -776,38 +776,324 @@ void LoopCharacteristicsPass::findDerivedInductionVars(
|
||||
}
|
||||
}
|
||||
|
||||
// 递归/推进式判定
|
||||
bool LoopCharacteristicsPass::isClassicLoopInvariant(Value* val, Loop* loop, const std::unordered_set<Value*>& invariants) {
|
||||
// 1. 常量
|
||||
if (auto* constval = dynamic_cast<ConstantValue*>(val)) return true;
|
||||
|
||||
// 2. 参数(函数参数)通常不在任何BasicBlock内,直接判定为不变量
|
||||
if (auto* arg = dynamic_cast<Argument*>(val)) return true;
|
||||
|
||||
// 3. 指令且定义在循环外
|
||||
if (auto* inst = dynamic_cast<Instruction*>(val)) {
|
||||
if (!loop->contains(inst->getParent()))
|
||||
return true;
|
||||
|
||||
// 4. 跳转 phi指令 副作用 不外提
|
||||
if (inst->isTerminator() || inst->isPhi() || sideEffectAnalysis->hasSideEffect(inst))
|
||||
// 检查操作数是否都是不变量
|
||||
bool LoopCharacteristicsPass::isInvariantOperands(Instruction* inst, Loop* loop, const std::unordered_set<Value*>& invariants) {
|
||||
for (size_t i = 0; i < inst->getNumOperands(); ++i) {
|
||||
Value* op = inst->getOperand(i);
|
||||
if (!isClassicLoopInvariant(op, loop, invariants) && !invariants.count(op)) {
|
||||
return false;
|
||||
|
||||
// 5. 所有操作数都是不变量
|
||||
for (size_t i = 0; i < inst->getNumOperands(); ++i) {
|
||||
Value* op = inst->getOperand(i);
|
||||
if (!isClassicLoopInvariant(op, loop, invariants) && !invariants.count(op))
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
// 其它情况
|
||||
return true;
|
||||
}
|
||||
|
||||
// 检查内存位置是否在循环中被修改
|
||||
bool LoopCharacteristicsPass::isMemoryLocationModifiedInLoop(Value* ptr, Loop* loop) {
|
||||
// 遍历循环中的所有指令,检查是否有对该内存位置的写入
|
||||
for (BasicBlock* bb : loop->getBlocks()) {
|
||||
for (auto& inst : bb->getInstructions()) {
|
||||
// 1. 检查直接的Store指令
|
||||
if (auto* storeInst = dynamic_cast<StoreInst*>(inst.get())) {
|
||||
Value* storeTar = storeInst->getPointer();
|
||||
|
||||
// 使用别名分析检查是否可能别名
|
||||
if (aliasAnalysis) {
|
||||
auto aliasType = aliasAnalysis->queryAlias(ptr, storeTar);
|
||||
if (aliasType != AliasType::NO_ALIAS) {
|
||||
if (DEBUG) {
|
||||
std::cout << " Memory location " << ptr->getName()
|
||||
<< " may be modified by store to " << storeTar->getName() << std::endl;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
} else {
|
||||
// 如果没有别名分析,保守处理 - 只检查精确匹配
|
||||
if (ptr == storeTar) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 2. 检查函数调用是否可能修改该内存位置
|
||||
else if (auto* callInst = dynamic_cast<CallInst*>(inst.get())) {
|
||||
Function* calledFunc = callInst->getCallee();
|
||||
|
||||
// 如果是纯函数,不会修改内存
|
||||
if (isPureFunction(calledFunc)) {
|
||||
continue;
|
||||
}
|
||||
|
||||
// 检查函数参数中是否有该内存位置的指针
|
||||
for (size_t i = 1; i < callInst->getNumOperands(); ++i) { // 跳过函数指针
|
||||
Value* arg = callInst->getOperand(i);
|
||||
|
||||
// 检查参数是否是指针类型且可能指向该内存位置
|
||||
if (auto* ptrType = dynamic_cast<PointerType*>(arg->getType())) {
|
||||
// 使用别名分析检查
|
||||
if (aliasAnalysis) {
|
||||
auto aliasType = aliasAnalysis->queryAlias(ptr, arg);
|
||||
if (aliasType != AliasType::NO_ALIAS) {
|
||||
if (DEBUG) {
|
||||
std::cout << " Memory location " << ptr->getName()
|
||||
<< " may be modified by function call " << calledFunc->getName()
|
||||
<< " through parameter " << arg->getName() << std::endl;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
} else {
|
||||
// 没有别名分析,检查精确匹配
|
||||
if (ptr == arg) {
|
||||
if (DEBUG) {
|
||||
std::cout << " Memory location " << ptr->getName()
|
||||
<< " may be modified by function call " << calledFunc->getName()
|
||||
<< " (exact match)" << std::endl;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
bool LoopCharacteristicsPass::hasSimpleMemoryPattern(Loop* loop) {
|
||||
// 检查是否有简单的内存访问模式
|
||||
return true; // 暂时简化处理
|
||||
// 检查内存位置是否在循环中被读取
|
||||
bool LoopCharacteristicsPass::isMemoryLocationLoadedInLoop(Value* ptr, Loop* loop, Instruction* excludeInst) {
|
||||
// 遍历循环中的所有Load指令,检查是否有对该内存位置的读取
|
||||
for (BasicBlock* bb : loop->getBlocks()) {
|
||||
for (auto& inst : bb->getInstructions()) {
|
||||
if (inst.get() == excludeInst) continue; // 排除当前指令本身
|
||||
|
||||
if (auto* loadInst = dynamic_cast<LoadInst*>(inst.get())) {
|
||||
Value* loadSrc = loadInst->getPointer();
|
||||
|
||||
// 使用别名分析检查是否可能别名
|
||||
if (aliasAnalysis) {
|
||||
auto aliasType = aliasAnalysis->queryAlias(ptr, loadSrc);
|
||||
if (aliasType != AliasType::NO_ALIAS) {
|
||||
return true;
|
||||
}
|
||||
} else {
|
||||
// 如果没有别名分析,保守处理 - 只检查精确匹配
|
||||
if (ptr == loadSrc) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
// 检查函数调用是否为纯函数
|
||||
bool LoopCharacteristicsPass::isPureFunction(Function* calledFunc) {
|
||||
if (!calledFunc) return false;
|
||||
|
||||
// 使用副作用分析检查函数是否为纯函数
|
||||
if (sideEffectAnalysis && sideEffectAnalysis->isPureFunction(calledFunc)) {
|
||||
return true;
|
||||
}
|
||||
|
||||
// 检查是否为内置纯函数(如数学函数)
|
||||
std::string funcName = calledFunc->getName();
|
||||
static const std::set<std::string> pureFunctions = {
|
||||
"abs", "fabs", "sqrt", "sin", "cos", "tan", "exp", "log", "pow",
|
||||
"floor", "ceil", "round", "min", "max"
|
||||
};
|
||||
|
||||
return pureFunctions.count(funcName) > 0;
|
||||
}
|
||||
|
||||
// 递归/推进式判定 - 完善版本
|
||||
bool LoopCharacteristicsPass::isClassicLoopInvariant(Value* val, Loop* loop, const std::unordered_set<Value*>& invariants) {
|
||||
if (DEBUG >= 2) {
|
||||
std::cout << " Checking loop invariant for: " << val->getName() << std::endl;
|
||||
}
|
||||
|
||||
// 1. 常量
|
||||
if (auto* constval = dynamic_cast<ConstantValue*>(val)) {
|
||||
if (DEBUG >= 2) std::cout << " -> Constant: YES" << std::endl;
|
||||
return true;
|
||||
}
|
||||
|
||||
// 2. 参数(函数参数)通常不在任何BasicBlock内,直接判定为不变量
|
||||
// 在SSA形式下,参数不会被重新赋值
|
||||
if (auto* arg = dynamic_cast<Argument*>(val)) {
|
||||
if (DEBUG >= 2) std::cout << " -> Function argument: YES" << std::endl;
|
||||
return true;
|
||||
}
|
||||
|
||||
// 3. 指令且定义在循环外
|
||||
if (auto* inst = dynamic_cast<Instruction*>(val)) {
|
||||
if (!loop->contains(inst->getParent())) {
|
||||
if (DEBUG >= 2) std::cout << " -> Defined outside loop: YES" << std::endl;
|
||||
return true;
|
||||
}
|
||||
|
||||
// 4. 跳转指令、phi指令不能外提
|
||||
if (inst->isTerminator() || inst->isPhi()) {
|
||||
if (DEBUG >= 2) std::cout << " -> Terminator or PHI: NO" << std::endl;
|
||||
return false;
|
||||
}
|
||||
|
||||
// 5. 根据指令类型进行具体分析
|
||||
switch (inst->getKind()) {
|
||||
case Instruction::Kind::kStore: {
|
||||
// Store指令:检查循环内是否有对该内存的load
|
||||
auto* storeInst = dynamic_cast<StoreInst*>(inst);
|
||||
Value* storePtr = storeInst->getPointer();
|
||||
|
||||
// 首先检查操作数是否不变
|
||||
if (!isInvariantOperands(inst, loop, invariants)) {
|
||||
if (DEBUG >= 2) std::cout << " -> Store: operands not invariant: NO" << std::endl;
|
||||
return false;
|
||||
}
|
||||
|
||||
// 检查是否有对该内存位置的load
|
||||
if (isMemoryLocationLoadedInLoop(storePtr, loop, inst)) {
|
||||
if (DEBUG >= 2) std::cout << " -> Store: memory location loaded in loop: NO" << std::endl;
|
||||
return false;
|
||||
}
|
||||
|
||||
if (DEBUG >= 2) std::cout << " -> Store: safe to hoist: YES" << std::endl;
|
||||
return true;
|
||||
}
|
||||
|
||||
case Instruction::Kind::kLoad: {
|
||||
// Load指令:检查循环内是否有对该内存的store
|
||||
auto* loadInst = dynamic_cast<LoadInst*>(inst);
|
||||
Value* loadPtr = loadInst->getPointer();
|
||||
|
||||
// 首先检查指针操作数是否不变
|
||||
if (!isInvariantOperands(inst, loop, invariants)) {
|
||||
if (DEBUG >= 2) std::cout << " -> Load: pointer not invariant: NO" << std::endl;
|
||||
return false;
|
||||
}
|
||||
|
||||
// 检查是否有对该内存位置的store
|
||||
if (isMemoryLocationModifiedInLoop(loadPtr, loop)) {
|
||||
if (DEBUG >= 2) std::cout << " -> Load: memory location modified in loop: NO" << std::endl;
|
||||
return false;
|
||||
}
|
||||
|
||||
if (DEBUG >= 2) std::cout << " -> Load: safe to hoist: YES" << std::endl;
|
||||
return true;
|
||||
}
|
||||
|
||||
case Instruction::Kind::kCall: {
|
||||
// Call指令:检查是否为纯函数且参数不变
|
||||
auto* callInst = dynamic_cast<CallInst*>(inst);
|
||||
Function* calledFunc = callInst->getCallee();
|
||||
|
||||
// 检查是否为纯函数
|
||||
if (!isPureFunction(calledFunc)) {
|
||||
if (DEBUG >= 2) std::cout << " -> Call: not pure function: NO" << std::endl;
|
||||
return false;
|
||||
}
|
||||
|
||||
// 检查参数是否都不变
|
||||
if (!isInvariantOperands(inst, loop, invariants)) {
|
||||
if (DEBUG >= 2) std::cout << " -> Call: arguments not invariant: NO" << std::endl;
|
||||
return false;
|
||||
}
|
||||
|
||||
if (DEBUG >= 2) std::cout << " -> Call: pure function with invariant args: YES" << std::endl;
|
||||
return true;
|
||||
}
|
||||
|
||||
case Instruction::Kind::kGetElementPtr: {
|
||||
// GEP指令:检查基址和索引是否都不变
|
||||
if (!isInvariantOperands(inst, loop, invariants)) {
|
||||
if (DEBUG >= 2) std::cout << " -> GEP: base or indices not invariant: NO" << std::endl;
|
||||
return false;
|
||||
}
|
||||
|
||||
if (DEBUG >= 2) std::cout << " -> GEP: base and indices invariant: YES" << std::endl;
|
||||
return true;
|
||||
}
|
||||
|
||||
// 一元运算指令
|
||||
case Instruction::Kind::kNeg:
|
||||
case Instruction::Kind::kNot:
|
||||
case Instruction::Kind::kFNeg:
|
||||
case Instruction::Kind::kFNot:
|
||||
case Instruction::Kind::kFtoI:
|
||||
case Instruction::Kind::kItoF:
|
||||
case Instruction::Kind::kBitItoF:
|
||||
case Instruction::Kind::kBitFtoI: {
|
||||
// 检查操作数是否不变
|
||||
if (!isInvariantOperands(inst, loop, invariants)) {
|
||||
if (DEBUG >= 2) std::cout << " -> Unary op: operand not invariant: NO" << std::endl;
|
||||
return false;
|
||||
}
|
||||
|
||||
if (DEBUG >= 2) std::cout << " -> Unary op: operand invariant: YES" << std::endl;
|
||||
return true;
|
||||
}
|
||||
|
||||
// 二元运算指令
|
||||
case Instruction::Kind::kAdd:
|
||||
case Instruction::Kind::kSub:
|
||||
case Instruction::Kind::kMul:
|
||||
case Instruction::Kind::kDiv:
|
||||
case Instruction::Kind::kRem:
|
||||
case Instruction::Kind::kSll:
|
||||
case Instruction::Kind::kSrl:
|
||||
case Instruction::Kind::kSra:
|
||||
case Instruction::Kind::kAnd:
|
||||
case Instruction::Kind::kOr:
|
||||
case Instruction::Kind::kFAdd:
|
||||
case Instruction::Kind::kFSub:
|
||||
case Instruction::Kind::kFMul:
|
||||
case Instruction::Kind::kFDiv:
|
||||
case Instruction::Kind::kICmpEQ:
|
||||
case Instruction::Kind::kICmpNE:
|
||||
case Instruction::Kind::kICmpLT:
|
||||
case Instruction::Kind::kICmpGT:
|
||||
case Instruction::Kind::kICmpLE:
|
||||
case Instruction::Kind::kICmpGE:
|
||||
case Instruction::Kind::kFCmpEQ:
|
||||
case Instruction::Kind::kFCmpNE:
|
||||
case Instruction::Kind::kFCmpLT:
|
||||
case Instruction::Kind::kFCmpGT:
|
||||
case Instruction::Kind::kFCmpLE:
|
||||
case Instruction::Kind::kFCmpGE:
|
||||
case Instruction::Kind::kMulh: {
|
||||
// 检查所有操作数是否不变
|
||||
if (!isInvariantOperands(inst, loop, invariants)) {
|
||||
if (DEBUG >= 2) std::cout << " -> Binary op: operands not invariant: NO" << std::endl;
|
||||
return false;
|
||||
}
|
||||
|
||||
if (DEBUG >= 2) std::cout << " -> Binary op: operands invariant: YES" << std::endl;
|
||||
return true;
|
||||
}
|
||||
|
||||
default: {
|
||||
// 其他指令:使用副作用分析
|
||||
if (sideEffectAnalysis && sideEffectAnalysis->hasSideEffect(inst)) {
|
||||
if (DEBUG >= 2) std::cout << " -> Other inst: has side effect: NO" << std::endl;
|
||||
return false;
|
||||
}
|
||||
|
||||
// 检查操作数是否都不变
|
||||
if (!isInvariantOperands(inst, loop, invariants)) {
|
||||
if (DEBUG >= 2) std::cout << " -> Other inst: operands not invariant: NO" << std::endl;
|
||||
return false;
|
||||
}
|
||||
|
||||
if (DEBUG >= 2) std::cout << " -> Other inst: no side effect, operands invariant: YES" << std::endl;
|
||||
return true;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 其它情况
|
||||
if (DEBUG >= 2) std::cout << " -> Other value type: NO" << std::endl;
|
||||
return false;
|
||||
}
|
||||
|
||||
|
||||
} // namespace sysy
|
||||
|
||||
@ -26,10 +26,21 @@ const SideEffectInfo &SideEffectAnalysisResult::getInstructionSideEffect(Instruc
|
||||
}
|
||||
|
||||
const SideEffectInfo &SideEffectAnalysisResult::getFunctionSideEffect(Function *func) const {
|
||||
// 首先检查分析过的用户定义函数
|
||||
auto it = functionSideEffects.find(func);
|
||||
if (it != functionSideEffects.end()) {
|
||||
return it->second;
|
||||
}
|
||||
|
||||
// 如果没有找到,检查是否为已知的库函数
|
||||
if (func) {
|
||||
std::string funcName = func->getName();
|
||||
const SideEffectInfo *knownInfo = getKnownFunctionSideEffect(funcName);
|
||||
if (knownInfo) {
|
||||
return *knownInfo;
|
||||
}
|
||||
}
|
||||
|
||||
// 返回默认的无副作用信息
|
||||
static SideEffectInfo noEffect;
|
||||
return noEffect;
|
||||
|
||||
492
src/midend/Pass/Optimize/GVN.cpp
Normal file
492
src/midend/Pass/Optimize/GVN.cpp
Normal file
@ -0,0 +1,492 @@
|
||||
#include "GVN.h"
|
||||
#include "Dom.h"
|
||||
#include "SysYIROptUtils.h"
|
||||
#include <algorithm>
|
||||
#include <cassert>
|
||||
#include <iostream>
|
||||
#include <unordered_map>
|
||||
#include <unordered_set>
|
||||
|
||||
extern int DEBUG;
|
||||
|
||||
namespace sysy {
|
||||
|
||||
// GVN 遍的静态 ID
|
||||
void *GVN::ID = (void *)&GVN::ID;
|
||||
|
||||
// ======================================================================
|
||||
// GVN 类的实现
|
||||
// ======================================================================
|
||||
|
||||
bool GVN::runOnFunction(Function *func, AnalysisManager &AM) {
|
||||
if (func->getBasicBlocks().empty()) {
|
||||
return false;
|
||||
}
|
||||
|
||||
if (DEBUG) {
|
||||
std::cout << "\n=== Running GVN on function: " << func->getName() << " ===" << std::endl;
|
||||
}
|
||||
|
||||
bool changed = false;
|
||||
GVNContext context;
|
||||
context.run(func, &AM, changed);
|
||||
|
||||
if (DEBUG) {
|
||||
if (changed) {
|
||||
std::cout << "GVN: Function " << func->getName() << " was modified" << std::endl;
|
||||
} else {
|
||||
std::cout << "GVN: Function " << func->getName() << " was not modified" << std::endl;
|
||||
}
|
||||
std::cout << "=== GVN completed for function: " << func->getName() << " ===" << std::endl;
|
||||
}
|
||||
|
||||
return changed;
|
||||
}
|
||||
|
||||
void GVN::getAnalysisUsage(std::set<void *> &analysisDependencies, std::set<void *> &analysisInvalidations) const {
|
||||
// GVN依赖以下分析:
|
||||
// 1. 支配树分析 - 用于检查指令的支配关系,确保替换的安全性
|
||||
analysisDependencies.insert(&DominatorTreeAnalysisPass::ID);
|
||||
|
||||
// 2. 副作用分析 - 用于判断函数调用是否可以进行GVN
|
||||
analysisDependencies.insert(&SysYSideEffectAnalysisPass::ID);
|
||||
|
||||
// GVN不会使任何分析失效,因为:
|
||||
// - GVN只删除冗余计算,不改变CFG结构
|
||||
// - GVN不修改程序的语义,只是消除重复计算
|
||||
// - 支配关系保持不变
|
||||
// - 副作用分析结果保持不变
|
||||
// analysisInvalidations 保持为空
|
||||
|
||||
if (DEBUG) {
|
||||
std::cout << "GVN: Declared analysis dependencies (DominatorTree, SideEffectAnalysis)" << std::endl;
|
||||
}
|
||||
}
|
||||
|
||||
// ======================================================================
|
||||
// GVNContext 类的实现 - 重构版本
|
||||
// ======================================================================
|
||||
|
||||
// 简单的表达式哈希结构
|
||||
struct ExpressionKey {
|
||||
enum Type { BINARY, UNARY, LOAD, GEP, CALL } type;
|
||||
int opcode;
|
||||
std::vector<Value*> operands;
|
||||
Type* resultType;
|
||||
|
||||
bool operator==(const ExpressionKey& other) const {
|
||||
return type == other.type && opcode == other.opcode &&
|
||||
operands == other.operands && resultType == other.resultType;
|
||||
}
|
||||
};
|
||||
|
||||
struct ExpressionKeyHash {
|
||||
size_t operator()(const ExpressionKey& key) const {
|
||||
size_t hash = std::hash<int>()(static_cast<int>(key.type)) ^
|
||||
std::hash<int>()(key.opcode);
|
||||
for (auto op : key.operands) {
|
||||
hash ^= std::hash<Value*>()(op) + 0x9e3779b9 + (hash << 6) + (hash >> 2);
|
||||
}
|
||||
return hash;
|
||||
}
|
||||
};
|
||||
|
||||
void GVNContext::run(Function *func, AnalysisManager *AM, bool &changed) {
|
||||
if (DEBUG) {
|
||||
std::cout << " Starting GVN analysis for function: " << func->getName() << std::endl;
|
||||
}
|
||||
|
||||
// 获取分析结果
|
||||
if (AM) {
|
||||
domTree = AM->getAnalysisResult<DominatorTree, DominatorTreeAnalysisPass>(func);
|
||||
sideEffectAnalysis = AM->getAnalysisResult<SideEffectAnalysisResult, SysYSideEffectAnalysisPass>();
|
||||
|
||||
if (DEBUG) {
|
||||
if (domTree) {
|
||||
std::cout << " GVN: Using dominator tree analysis" << std::endl;
|
||||
} else {
|
||||
std::cout << " GVN: Warning - dominator tree analysis not available" << std::endl;
|
||||
}
|
||||
if (sideEffectAnalysis) {
|
||||
std::cout << " GVN: Using side effect analysis" << std::endl;
|
||||
} else {
|
||||
std::cout << " GVN: Warning - side effect analysis not available" << std::endl;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 清空状态
|
||||
valueToNumber.clear();
|
||||
numberToValue.clear();
|
||||
expressionToNumber.clear();
|
||||
nextValueNumber = 1;
|
||||
visited.clear();
|
||||
rpoBlocks.clear();
|
||||
needRemove.clear();
|
||||
|
||||
// 计算逆后序遍历
|
||||
computeRPO(func);
|
||||
|
||||
if (DEBUG) {
|
||||
std::cout << " Computed RPO with " << rpoBlocks.size() << " blocks" << std::endl;
|
||||
}
|
||||
|
||||
// 按逆后序遍历基本块进行GVN
|
||||
int blockCount = 0;
|
||||
for (auto bb : rpoBlocks) {
|
||||
if (DEBUG) {
|
||||
std::cout << " Processing block " << ++blockCount << "/" << rpoBlocks.size()
|
||||
<< ": " << bb->getName() << std::endl;
|
||||
}
|
||||
|
||||
processBasicBlock(bb, changed);
|
||||
}
|
||||
|
||||
if (DEBUG) {
|
||||
std::cout << " Found " << needRemove.size() << " redundant instructions to remove" << std::endl;
|
||||
}
|
||||
|
||||
// 删除冗余指令
|
||||
eliminateRedundantInstructions(changed);
|
||||
|
||||
if (DEBUG) {
|
||||
std::cout << " GVN analysis completed for function: " << func->getName() << std::endl;
|
||||
std::cout << " Total values numbered: " << valueToNumber.size() << std::endl;
|
||||
std::cout << " Instructions eliminated: " << needRemove.size() << std::endl;
|
||||
}
|
||||
}
|
||||
|
||||
void GVNContext::computeRPO(Function *func) {
|
||||
rpoBlocks.clear();
|
||||
visited.clear();
|
||||
|
||||
auto entry = func->getEntryBlock();
|
||||
if (entry) {
|
||||
dfs(entry);
|
||||
std::reverse(rpoBlocks.begin(), rpoBlocks.end());
|
||||
}
|
||||
}
|
||||
|
||||
void GVNContext::dfs(BasicBlock *bb) {
|
||||
if (!bb || visited.count(bb)) {
|
||||
return;
|
||||
}
|
||||
|
||||
visited.insert(bb);
|
||||
|
||||
// 访问所有后继基本块
|
||||
for (auto succ : bb->getSuccessors()) {
|
||||
if (visited.find(succ) == visited.end()) {
|
||||
dfs(succ);
|
||||
}
|
||||
}
|
||||
|
||||
rpoBlocks.push_back(bb);
|
||||
}
|
||||
|
||||
unsigned GVNContext::getValueNumber(Value* value) {
|
||||
// 如果已经有值编号,直接返回
|
||||
auto it = valueToNumber.find(value);
|
||||
if (it != valueToNumber.end()) {
|
||||
return it->second;
|
||||
}
|
||||
|
||||
// 为新值分配编号
|
||||
return assignValueNumber(value);
|
||||
}
|
||||
|
||||
unsigned GVNContext::assignValueNumber(Value* value) {
|
||||
unsigned number = nextValueNumber++;
|
||||
valueToNumber[value] = number;
|
||||
numberToValue[number] = value;
|
||||
|
||||
if (DEBUG >= 2) {
|
||||
std::cout << " Assigned value number " << number
|
||||
<< " to " << value->getName() << std::endl;
|
||||
}
|
||||
|
||||
return number;
|
||||
}
|
||||
|
||||
void GVNContext::processBasicBlock(BasicBlock* bb, bool& changed) {
|
||||
int instCount = 0;
|
||||
for (auto &instPtr : bb->getInstructions()) {
|
||||
if (DEBUG) {
|
||||
std::cout << " Processing instruction " << ++instCount
|
||||
<< ": " << instPtr->getName() << std::endl;
|
||||
}
|
||||
|
||||
if (processInstruction(instPtr.get())) {
|
||||
changed = true;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
bool GVNContext::processInstruction(Instruction* inst) {
|
||||
// 跳过分支指令和其他不可优化的指令
|
||||
if (inst->isBranch() || dynamic_cast<ReturnInst*>(inst) ||
|
||||
dynamic_cast<AllocaInst*>(inst) || dynamic_cast<StoreInst*>(inst)) {
|
||||
|
||||
// 如果是store指令,需要使相关的内存值失效
|
||||
if (auto store = dynamic_cast<StoreInst*>(inst)) {
|
||||
invalidateMemoryValues(store);
|
||||
}
|
||||
|
||||
// 为这些指令分配值编号但不尝试优化
|
||||
getValueNumber(inst);
|
||||
return false;
|
||||
}
|
||||
|
||||
if (DEBUG) {
|
||||
std::cout << " Processing optimizable instruction: " << inst->getName()
|
||||
<< " (kind: " << static_cast<int>(inst->getKind()) << ")" << std::endl;
|
||||
}
|
||||
|
||||
// 构建表达式键
|
||||
std::string exprKey = buildExpressionKey(inst);
|
||||
if (exprKey.empty()) {
|
||||
// 不可优化的指令,只分配值编号
|
||||
getValueNumber(inst);
|
||||
return false;
|
||||
}
|
||||
|
||||
if (DEBUG >= 2) {
|
||||
std::cout << " Expression key: " << exprKey << std::endl;
|
||||
}
|
||||
|
||||
// 查找已存在的等价值
|
||||
Value* existing = findExistingValue(exprKey, inst);
|
||||
if (existing && existing != inst) {
|
||||
// 检查支配关系
|
||||
if (auto existingInst = dynamic_cast<Instruction*>(existing)) {
|
||||
if (dominates(existingInst, inst)) {
|
||||
if (DEBUG) {
|
||||
std::cout << " GVN: Replacing " << inst->getName()
|
||||
<< " with existing " << existing->getName() << std::endl;
|
||||
}
|
||||
|
||||
// 用已存在的值替换当前指令
|
||||
inst->replaceAllUsesWith(existing);
|
||||
needRemove.insert(inst);
|
||||
|
||||
// 将当前指令的值编号指向已存在的值
|
||||
unsigned existingNumber = getValueNumber(existing);
|
||||
valueToNumber[inst] = existingNumber;
|
||||
|
||||
return true;
|
||||
} else {
|
||||
if (DEBUG) {
|
||||
std::cout << " Found equivalent but dominance check failed" << std::endl;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 没有找到等价值,为这个表达式分配新的值编号
|
||||
unsigned number = assignValueNumber(inst);
|
||||
expressionToNumber[exprKey] = number;
|
||||
|
||||
if (DEBUG) {
|
||||
std::cout << " Instruction " << inst->getName() << " is unique" << std::endl;
|
||||
}
|
||||
|
||||
return false;
|
||||
}
|
||||
|
||||
std::string GVNContext::buildExpressionKey(Instruction* inst) {
|
||||
std::ostringstream oss;
|
||||
|
||||
if (auto binary = dynamic_cast<BinaryInst*>(inst)) {
|
||||
oss << "binary_" << static_cast<int>(binary->getKind()) << "_";
|
||||
oss << getValueNumber(binary->getLhs()) << "_" << getValueNumber(binary->getRhs());
|
||||
|
||||
// 对于可交换操作,确保操作数顺序一致
|
||||
if (binary->isCommutative()) {
|
||||
unsigned lhsNum = getValueNumber(binary->getLhs());
|
||||
unsigned rhsNum = getValueNumber(binary->getRhs());
|
||||
if (lhsNum > rhsNum) {
|
||||
oss.str("");
|
||||
oss << "binary_" << static_cast<int>(binary->getKind()) << "_";
|
||||
oss << rhsNum << "_" << lhsNum;
|
||||
}
|
||||
}
|
||||
} else if (auto unary = dynamic_cast<UnaryInst*>(inst)) {
|
||||
oss << "unary_" << static_cast<int>(unary->getKind()) << "_";
|
||||
oss << getValueNumber(unary->getOperand());
|
||||
} else if (auto gep = dynamic_cast<GetElementPtrInst*>(inst)) {
|
||||
oss << "gep_" << getValueNumber(gep->getBasePointer());
|
||||
for (unsigned i = 0; i < gep->getNumIndices(); ++i) {
|
||||
oss << "_" << getValueNumber(gep->getIndex(i));
|
||||
}
|
||||
} else if (auto load = dynamic_cast<LoadInst*>(inst)) {
|
||||
oss << "load_" << getValueNumber(load->getPointer());
|
||||
oss << "_" << reinterpret_cast<uintptr_t>(load->getType()); // 类型区分
|
||||
} else if (auto call = dynamic_cast<CallInst*>(inst)) {
|
||||
// 只为无副作用的函数调用建立表达式
|
||||
if (sideEffectAnalysis && sideEffectAnalysis->isPureFunction(call->getCallee())) {
|
||||
oss << "call_" << call->getCallee()->getName();
|
||||
for (size_t i = 1; i < call->getNumOperands(); ++i) { // 跳过函数指针
|
||||
oss << "_" << getValueNumber(call->getOperand(i));
|
||||
}
|
||||
} else {
|
||||
return ""; // 有副作用的函数调用不可优化
|
||||
}
|
||||
} else {
|
||||
return ""; // 不支持的指令类型
|
||||
}
|
||||
|
||||
return oss.str();
|
||||
}
|
||||
|
||||
Value* GVNContext::findExistingValue(const std::string& exprKey, Instruction* inst) {
|
||||
auto it = expressionToNumber.find(exprKey);
|
||||
if (it != expressionToNumber.end()) {
|
||||
unsigned number = it->second;
|
||||
auto valueIt = numberToValue.find(number);
|
||||
if (valueIt != numberToValue.end()) {
|
||||
Value* existing = valueIt->second;
|
||||
|
||||
// 对于load指令,需要额外检查内存安全性
|
||||
if (auto loadInst = dynamic_cast<LoadInst*>(inst)) {
|
||||
if (auto existingLoad = dynamic_cast<LoadInst*>(existing)) {
|
||||
if (!isMemorySafe(existingLoad, loadInst)) {
|
||||
return nullptr;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return existing;
|
||||
}
|
||||
}
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
bool GVNContext::dominates(Instruction* a, Instruction* b) {
|
||||
auto aBB = a->getParent();
|
||||
auto bBB = b->getParent();
|
||||
|
||||
// 同一基本块内的情况
|
||||
if (aBB == bBB) {
|
||||
auto &insts = aBB->getInstructions();
|
||||
auto aIt = std::find_if(insts.begin(), insts.end(),
|
||||
[a](const auto &ptr) { return ptr.get() == a; });
|
||||
auto bIt = std::find_if(insts.begin(), insts.end(),
|
||||
[b](const auto &ptr) { return ptr.get() == b; });
|
||||
|
||||
if (aIt == insts.end() || bIt == insts.end()) {
|
||||
return false;
|
||||
}
|
||||
|
||||
return std::distance(insts.begin(), aIt) < std::distance(insts.begin(), bIt);
|
||||
}
|
||||
|
||||
// 不同基本块的情况,使用支配树
|
||||
if (domTree) {
|
||||
auto dominators = domTree->getDominators(bBB);
|
||||
return dominators && dominators->count(aBB);
|
||||
}
|
||||
|
||||
return false; // 保守做法
|
||||
}
|
||||
|
||||
bool GVNContext::isMemorySafe(LoadInst* earlierLoad, LoadInst* laterLoad) {
|
||||
// 检查两个load是否访问相同的内存位置
|
||||
unsigned earlierPtr = getValueNumber(earlierLoad->getPointer());
|
||||
unsigned laterPtr = getValueNumber(laterLoad->getPointer());
|
||||
|
||||
if (earlierPtr != laterPtr) {
|
||||
return false; // 不同的内存位置
|
||||
}
|
||||
|
||||
// 检查类型是否匹配
|
||||
if (earlierLoad->getType() != laterLoad->getType()) {
|
||||
return false;
|
||||
}
|
||||
|
||||
// 简单情况:如果在同一个基本块且没有中间的store,则安全
|
||||
auto earlierBB = earlierLoad->getParent();
|
||||
auto laterBB = laterLoad->getParent();
|
||||
|
||||
if (earlierBB != laterBB) {
|
||||
// 跨基本块的情况需要更复杂的分析,暂时保守处理
|
||||
return false;
|
||||
}
|
||||
|
||||
// 同一基本块内检查是否有中间的store
|
||||
auto &insts = earlierBB->getInstructions();
|
||||
auto earlierIt = std::find_if(insts.begin(), insts.end(),
|
||||
[earlierLoad](const auto &ptr) { return ptr.get() == earlierLoad; });
|
||||
auto laterIt = std::find_if(insts.begin(), insts.end(),
|
||||
[laterLoad](const auto &ptr) { return ptr.get() == laterLoad; });
|
||||
|
||||
if (earlierIt == insts.end() || laterIt == insts.end()) {
|
||||
return false;
|
||||
}
|
||||
|
||||
// 确保earlierLoad真的在laterLoad之前
|
||||
if (std::distance(insts.begin(), earlierIt) >= std::distance(insts.begin(), laterIt)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
// 检查中间是否有store指令修改了相同的内存位置
|
||||
for (auto it = std::next(earlierIt); it != laterIt; ++it) {
|
||||
if (auto store = dynamic_cast<StoreInst*>(it->get())) {
|
||||
unsigned storePtr = getValueNumber(store->getPointer());
|
||||
if (storePtr == earlierPtr) {
|
||||
return false; // 找到中间的store
|
||||
}
|
||||
}
|
||||
|
||||
// 检查函数调用是否可能修改内存
|
||||
if (auto call = dynamic_cast<CallInst*>(it->get())) {
|
||||
if (sideEffectAnalysis && !sideEffectAnalysis->isPureFunction(call->getCallee())) {
|
||||
// 保守处理:有副作用的函数可能修改内存
|
||||
return false;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return true; // 安全
|
||||
}
|
||||
|
||||
void GVNContext::invalidateMemoryValues(StoreInst* store) {
|
||||
unsigned storePtr = getValueNumber(store->getPointer());
|
||||
|
||||
if (DEBUG) {
|
||||
std::cout << " Invalidating memory values affected by store" << std::endl;
|
||||
}
|
||||
|
||||
// 找到所有可能被这个store影响的load表达式
|
||||
std::vector<std::string> toRemove;
|
||||
|
||||
for (auto& [exprKey, number] : expressionToNumber) {
|
||||
if (exprKey.find("load_" + std::to_string(storePtr)) == 0) {
|
||||
toRemove.push_back(exprKey);
|
||||
if (DEBUG) {
|
||||
std::cout << " Invalidating expression: " << exprKey << std::endl;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 移除失效的表达式
|
||||
for (const auto& key : toRemove) {
|
||||
expressionToNumber.erase(key);
|
||||
}
|
||||
}
|
||||
|
||||
void GVNContext::eliminateRedundantInstructions(bool& changed) {
|
||||
int removeCount = 0;
|
||||
for (auto inst : needRemove) {
|
||||
if (DEBUG) {
|
||||
std::cout << " Removing redundant instruction " << ++removeCount
|
||||
<< "/" << needRemove.size() << ": " << inst->getName() << std::endl;
|
||||
}
|
||||
|
||||
// 删除指令前先断开所有使用关系
|
||||
// inst->replaceAllUsesWith 已在 processInstruction 中调用
|
||||
SysYIROptUtils::usedelete(inst);
|
||||
changed = true;
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace sysy
|
||||
897
src/midend/Pass/Optimize/GlobalStrengthReduction.cpp
Normal file
897
src/midend/Pass/Optimize/GlobalStrengthReduction.cpp
Normal file
@ -0,0 +1,897 @@
|
||||
#include "GlobalStrengthReduction.h"
|
||||
#include "SysYIROptUtils.h"
|
||||
#include "IRBuilder.h"
|
||||
#include <algorithm>
|
||||
#include <cassert>
|
||||
#include <iostream>
|
||||
#include <cmath>
|
||||
|
||||
extern int DEBUG;
|
||||
|
||||
namespace sysy {
|
||||
|
||||
// 全局强度削弱优化遍的静态 ID
|
||||
void *GlobalStrengthReduction::ID = (void *)&GlobalStrengthReduction::ID;
|
||||
|
||||
// ======================================================================
|
||||
// GlobalStrengthReduction 类的实现
|
||||
// ======================================================================
|
||||
|
||||
bool GlobalStrengthReduction::runOnFunction(Function *func, AnalysisManager &AM) {
|
||||
if (func->getBasicBlocks().empty()) {
|
||||
return false;
|
||||
}
|
||||
|
||||
if (DEBUG) {
|
||||
std::cout << "\n=== Running GlobalStrengthReduction on function: " << func->getName() << " ===" << std::endl;
|
||||
}
|
||||
|
||||
bool changed = false;
|
||||
GlobalStrengthReductionContext context(builder);
|
||||
context.run(func, &AM, changed);
|
||||
|
||||
if (DEBUG) {
|
||||
if (changed) {
|
||||
std::cout << "GlobalStrengthReduction: Function " << func->getName() << " was modified" << std::endl;
|
||||
} else {
|
||||
std::cout << "GlobalStrengthReduction: Function " << func->getName() << " was not modified" << std::endl;
|
||||
}
|
||||
std::cout << "=== GlobalStrengthReduction completed for function: " << func->getName() << " ===" << std::endl;
|
||||
}
|
||||
|
||||
return changed;
|
||||
}
|
||||
|
||||
void GlobalStrengthReduction::getAnalysisUsage(std::set<void *> &analysisDependencies, std::set<void *> &analysisInvalidations) const {
|
||||
// 强度削弱依赖副作用分析来判断指令是否可以安全优化
|
||||
analysisDependencies.insert(&SysYSideEffectAnalysisPass::ID);
|
||||
|
||||
// 强度削弱不会使分析失效,因为:
|
||||
// - 只替换计算指令,不改变控制流
|
||||
// - 不修改内存,不影响别名分析
|
||||
// - 保持程序语义不变
|
||||
// analysisInvalidations 保持为空
|
||||
|
||||
if (DEBUG) {
|
||||
std::cout << "GlobalStrengthReduction: Declared analysis dependencies (SideEffectAnalysis)" << std::endl;
|
||||
}
|
||||
}
|
||||
|
||||
// ======================================================================
|
||||
// GlobalStrengthReductionContext 类的实现
|
||||
// ======================================================================
|
||||
|
||||
void GlobalStrengthReductionContext::run(Function *func, AnalysisManager *AM, bool &changed) {
|
||||
if (DEBUG) {
|
||||
std::cout << " Starting GlobalStrengthReduction analysis for function: " << func->getName() << std::endl;
|
||||
}
|
||||
|
||||
// 获取分析结果
|
||||
if (AM) {
|
||||
sideEffectAnalysis = AM->getAnalysisResult<SideEffectAnalysisResult, SysYSideEffectAnalysisPass>();
|
||||
|
||||
if (DEBUG) {
|
||||
if (sideEffectAnalysis) {
|
||||
std::cout << " GlobalStrengthReduction: Using side effect analysis" << std::endl;
|
||||
} else {
|
||||
std::cout << " GlobalStrengthReduction: Warning - side effect analysis not available" << std::endl;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 重置计数器
|
||||
algebraicOptCount = 0;
|
||||
strengthReductionCount = 0;
|
||||
divisionOptCount = 0;
|
||||
|
||||
// 遍历所有基本块进行优化
|
||||
for (auto &bb_ptr : func->getBasicBlocks()) {
|
||||
if (processBasicBlock(bb_ptr.get())) {
|
||||
changed = true;
|
||||
}
|
||||
}
|
||||
|
||||
if (DEBUG) {
|
||||
std::cout << " GlobalStrengthReduction completed for function: " << func->getName() << std::endl;
|
||||
std::cout << " Algebraic optimizations: " << algebraicOptCount << std::endl;
|
||||
std::cout << " Strength reductions: " << strengthReductionCount << std::endl;
|
||||
std::cout << " Division optimizations: " << divisionOptCount << std::endl;
|
||||
}
|
||||
}
|
||||
|
||||
bool GlobalStrengthReductionContext::processBasicBlock(BasicBlock *bb) {
|
||||
bool changed = false;
|
||||
|
||||
if (DEBUG) {
|
||||
std::cout << " Processing block: " << bb->getName() << std::endl;
|
||||
}
|
||||
|
||||
// 收集需要处理的指令(避免迭代器失效)
|
||||
std::vector<Instruction*> instructions;
|
||||
for (auto &inst_ptr : bb->getInstructions()) {
|
||||
instructions.push_back(inst_ptr.get());
|
||||
}
|
||||
|
||||
// 处理每条指令
|
||||
for (auto inst : instructions) {
|
||||
if (processInstruction(inst)) {
|
||||
changed = true;
|
||||
}
|
||||
}
|
||||
|
||||
return changed;
|
||||
}
|
||||
|
||||
bool GlobalStrengthReductionContext::processInstruction(Instruction *inst) {
|
||||
|
||||
if (DEBUG) {
|
||||
std::cout << " Processing instruction: " << inst->getName() << std::endl;
|
||||
}
|
||||
|
||||
// 先尝试代数优化
|
||||
if (tryAlgebraicOptimization(inst)) {
|
||||
algebraicOptCount++;
|
||||
return true;
|
||||
}
|
||||
|
||||
// 再尝试强度削弱
|
||||
if (tryStrengthReduction(inst)) {
|
||||
strengthReductionCount++;
|
||||
return true;
|
||||
}
|
||||
|
||||
return false;
|
||||
}
|
||||
|
||||
// ======================================================================
|
||||
// 代数优化方法
|
||||
// ======================================================================
|
||||
|
||||
bool GlobalStrengthReductionContext::tryAlgebraicOptimization(Instruction *inst) {
|
||||
auto binary = dynamic_cast<BinaryInst*>(inst);
|
||||
if (!binary) {
|
||||
return false;
|
||||
}
|
||||
|
||||
switch (binary->getKind()) {
|
||||
case Instruction::kAdd:
|
||||
return optimizeAddition(binary);
|
||||
case Instruction::kSub:
|
||||
return optimizeSubtraction(binary);
|
||||
case Instruction::kMul:
|
||||
return optimizeMultiplication(binary);
|
||||
case Instruction::kDiv:
|
||||
return optimizeDivision(binary);
|
||||
case Instruction::kICmpEQ:
|
||||
case Instruction::kICmpNE:
|
||||
case Instruction::kICmpLT:
|
||||
case Instruction::kICmpGT:
|
||||
case Instruction::kICmpLE:
|
||||
case Instruction::kICmpGE:
|
||||
return optimizeComparison(binary);
|
||||
case Instruction::kAnd:
|
||||
case Instruction::kOr:
|
||||
return optimizeLogical(binary);
|
||||
default:
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
bool GlobalStrengthReductionContext::optimizeAddition(BinaryInst *inst) {
|
||||
Value *lhs = inst->getLhs();
|
||||
Value *rhs = inst->getRhs();
|
||||
int constVal;
|
||||
|
||||
// x + 0 = x
|
||||
if (isConstantInt(rhs, constVal) && constVal == 0) {
|
||||
if (DEBUG) {
|
||||
std::cout << " Algebraic: " << inst->getName() << " = x + 0 -> x" << std::endl;
|
||||
}
|
||||
replaceWithOptimized(inst, lhs);
|
||||
return true;
|
||||
}
|
||||
|
||||
// 0 + x = x
|
||||
if (isConstantInt(lhs, constVal) && constVal == 0) {
|
||||
if (DEBUG) {
|
||||
std::cout << " Algebraic: " << inst->getName() << " = 0 + x -> x" << std::endl;
|
||||
}
|
||||
replaceWithOptimized(inst, rhs);
|
||||
return true;
|
||||
}
|
||||
|
||||
// x + (-y) = x - y
|
||||
if (auto rhsInst = dynamic_cast<UnaryInst*>(rhs)) {
|
||||
if (rhsInst->getKind() == Instruction::kNeg) {
|
||||
if (DEBUG) {
|
||||
std::cout << " Algebraic: " << inst->getName() << " = x + (-y) -> x - y" << std::endl;
|
||||
}
|
||||
// 创建减法指令
|
||||
builder->setPosition(inst->getParent(), inst->getParent()->findInstIterator(inst));
|
||||
auto subInst = builder->createSubInst(lhs, rhsInst->getOperand());
|
||||
replaceWithOptimized(inst, subInst);
|
||||
return true;
|
||||
}
|
||||
}
|
||||
|
||||
return false;
|
||||
}
|
||||
|
||||
bool GlobalStrengthReductionContext::optimizeSubtraction(BinaryInst *inst) {
|
||||
Value *lhs = inst->getLhs();
|
||||
Value *rhs = inst->getRhs();
|
||||
int constVal;
|
||||
|
||||
// x - 0 = x
|
||||
if (isConstantInt(rhs, constVal) && constVal == 0) {
|
||||
if (DEBUG) {
|
||||
std::cout << " Algebraic: " << inst->getName() << " = x - 0 -> x" << std::endl;
|
||||
}
|
||||
replaceWithOptimized(inst, lhs);
|
||||
return true;
|
||||
}
|
||||
|
||||
// x - x = 0 (如果x没有副作用)
|
||||
if (lhs == rhs && hasOnlyLocalUses(dynamic_cast<Instruction*>(lhs))) {
|
||||
if (DEBUG) {
|
||||
std::cout << " Algebraic: " << inst->getName() << " = x - x -> 0" << std::endl;
|
||||
}
|
||||
replaceWithOptimized(inst, getConstantInt(0));
|
||||
return true;
|
||||
}
|
||||
|
||||
// x - (-y) = x + y
|
||||
if (auto rhsInst = dynamic_cast<UnaryInst*>(rhs)) {
|
||||
if (rhsInst->getKind() == Instruction::kNeg) {
|
||||
if (DEBUG) {
|
||||
std::cout << " Algebraic: " << inst->getName() << " = x - (-y) -> x + y" << std::endl;
|
||||
}
|
||||
builder->setPosition(inst->getParent(), inst->getParent()->findInstIterator(inst));
|
||||
auto addInst = builder->createAddInst(lhs, rhsInst->getOperand());
|
||||
replaceWithOptimized(inst, addInst);
|
||||
return true;
|
||||
}
|
||||
}
|
||||
|
||||
return false;
|
||||
}
|
||||
|
||||
bool GlobalStrengthReductionContext::optimizeMultiplication(BinaryInst *inst) {
|
||||
Value *lhs = inst->getLhs();
|
||||
Value *rhs = inst->getRhs();
|
||||
int constVal;
|
||||
|
||||
// x * 0 = 0
|
||||
if (isConstantInt(rhs, constVal) && constVal == 0) {
|
||||
if (DEBUG) {
|
||||
std::cout << " Algebraic: " << inst->getName() << " = x * 0 -> 0" << std::endl;
|
||||
}
|
||||
replaceWithOptimized(inst, getConstantInt(0));
|
||||
return true;
|
||||
}
|
||||
|
||||
// 0 * x = 0
|
||||
if (isConstantInt(lhs, constVal) && constVal == 0) {
|
||||
if (DEBUG) {
|
||||
std::cout << " Algebraic: " << inst->getName() << " = 0 * x -> 0" << std::endl;
|
||||
}
|
||||
replaceWithOptimized(inst, getConstantInt(0));
|
||||
return true;
|
||||
}
|
||||
|
||||
// x * 1 = x
|
||||
if (isConstantInt(rhs, constVal) && constVal == 1) {
|
||||
if (DEBUG) {
|
||||
std::cout << " Algebraic: " << inst->getName() << " = x * 1 -> x" << std::endl;
|
||||
}
|
||||
replaceWithOptimized(inst, lhs);
|
||||
return true;
|
||||
}
|
||||
|
||||
// 1 * x = x
|
||||
if (isConstantInt(lhs, constVal) && constVal == 1) {
|
||||
if (DEBUG) {
|
||||
std::cout << " Algebraic: " << inst->getName() << " = 1 * x -> x" << std::endl;
|
||||
}
|
||||
replaceWithOptimized(inst, rhs);
|
||||
return true;
|
||||
}
|
||||
|
||||
// x * (-1) = -x
|
||||
if (isConstantInt(rhs, constVal) && constVal == -1) {
|
||||
if (DEBUG) {
|
||||
std::cout << " Algebraic: " << inst->getName() << " = x * (-1) -> -x" << std::endl;
|
||||
}
|
||||
builder->setPosition(inst->getParent(), inst->getParent()->findInstIterator(inst));
|
||||
auto negInst = builder->createNegInst(lhs);
|
||||
replaceWithOptimized(inst, negInst);
|
||||
return true;
|
||||
}
|
||||
|
||||
return false;
|
||||
}
|
||||
|
||||
bool GlobalStrengthReductionContext::optimizeDivision(BinaryInst *inst) {
|
||||
Value *lhs = inst->getLhs();
|
||||
Value *rhs = inst->getRhs();
|
||||
int constVal;
|
||||
|
||||
// x / 1 = x
|
||||
if (isConstantInt(rhs, constVal) && constVal == 1) {
|
||||
if (DEBUG) {
|
||||
std::cout << " Algebraic: " << inst->getName() << " = x / 1 -> x" << std::endl;
|
||||
}
|
||||
replaceWithOptimized(inst, lhs);
|
||||
return true;
|
||||
}
|
||||
|
||||
// x / (-1) = -x
|
||||
if (isConstantInt(rhs, constVal) && constVal == -1) {
|
||||
if (DEBUG) {
|
||||
std::cout << " Algebraic: " << inst->getName() << " = x / (-1) -> -x" << std::endl;
|
||||
}
|
||||
builder->setPosition(inst->getParent(), inst->getParent()->findInstIterator(inst));
|
||||
auto negInst = builder->createNegInst(lhs);
|
||||
replaceWithOptimized(inst, negInst);
|
||||
return true;
|
||||
}
|
||||
|
||||
// x / x = 1 (如果x != 0且没有副作用)
|
||||
if (lhs == rhs && hasOnlyLocalUses(dynamic_cast<Instruction*>(lhs))) {
|
||||
if (DEBUG) {
|
||||
std::cout << " Algebraic: " << inst->getName() << " = x / x -> 1" << std::endl;
|
||||
}
|
||||
replaceWithOptimized(inst, getConstantInt(1));
|
||||
return true;
|
||||
}
|
||||
|
||||
return false;
|
||||
}
|
||||
|
||||
bool GlobalStrengthReductionContext::optimizeComparison(BinaryInst *inst) {
|
||||
Value *lhs = inst->getLhs();
|
||||
Value *rhs = inst->getRhs();
|
||||
|
||||
// x == x = true (如果x没有副作用)
|
||||
if (inst->getKind() == Instruction::kICmpEQ && lhs == rhs &&
|
||||
hasOnlyLocalUses(dynamic_cast<Instruction*>(lhs))) {
|
||||
if (DEBUG) {
|
||||
std::cout << " Algebraic: " << inst->getName() << " = x == x -> true" << std::endl;
|
||||
}
|
||||
replaceWithOptimized(inst, getConstantInt(1));
|
||||
return true;
|
||||
}
|
||||
|
||||
// x != x = false (如果x没有副作用)
|
||||
if (inst->getKind() == Instruction::kICmpNE && lhs == rhs &&
|
||||
hasOnlyLocalUses(dynamic_cast<Instruction*>(lhs))) {
|
||||
if (DEBUG) {
|
||||
std::cout << " Algebraic: " << inst->getName() << " = x != x -> false" << std::endl;
|
||||
}
|
||||
replaceWithOptimized(inst, getConstantInt(0));
|
||||
return true;
|
||||
}
|
||||
|
||||
return false;
|
||||
}
|
||||
|
||||
bool GlobalStrengthReductionContext::optimizeLogical(BinaryInst *inst) {
|
||||
Value *lhs = inst->getLhs();
|
||||
Value *rhs = inst->getRhs();
|
||||
int constVal;
|
||||
|
||||
if (inst->getKind() == Instruction::kAnd) {
|
||||
// x && 0 = 0
|
||||
if (isConstantInt(rhs, constVal) && constVal == 0) {
|
||||
if (DEBUG) {
|
||||
std::cout << " Algebraic: " << inst->getName() << " = x && 0 -> 0" << std::endl;
|
||||
}
|
||||
replaceWithOptimized(inst, getConstantInt(0));
|
||||
return true;
|
||||
}
|
||||
|
||||
// x && -1 = x
|
||||
if (isConstantInt(rhs, constVal) && constVal == -1) {
|
||||
if (DEBUG) {
|
||||
std::cout << " Algebraic: " << inst->getName() << " = x && 1 -> x" << std::endl;
|
||||
}
|
||||
replaceWithOptimized(inst, lhs);
|
||||
return true;
|
||||
}
|
||||
|
||||
// x && x = x
|
||||
if (lhs == rhs) {
|
||||
if (DEBUG) {
|
||||
std::cout << " Algebraic: " << inst->getName() << " = x && x -> x" << std::endl;
|
||||
}
|
||||
replaceWithOptimized(inst, lhs);
|
||||
return true;
|
||||
}
|
||||
} else if (inst->getKind() == Instruction::kOr) {
|
||||
// x || 0 = x
|
||||
if (isConstantInt(rhs, constVal) && constVal == 0) {
|
||||
if (DEBUG) {
|
||||
std::cout << " Algebraic: " << inst->getName() << " = x || 0 -> x" << std::endl;
|
||||
}
|
||||
replaceWithOptimized(inst, lhs);
|
||||
return true;
|
||||
}
|
||||
|
||||
// x || x = x
|
||||
if (lhs == rhs) {
|
||||
if (DEBUG) {
|
||||
std::cout << " Algebraic: " << inst->getName() << " = x || x -> x" << std::endl;
|
||||
}
|
||||
replaceWithOptimized(inst, lhs);
|
||||
return true;
|
||||
}
|
||||
}
|
||||
|
||||
return false;
|
||||
}
|
||||
|
||||
// ======================================================================
|
||||
// 强度削弱方法
|
||||
// ======================================================================
|
||||
|
||||
bool GlobalStrengthReductionContext::tryStrengthReduction(Instruction *inst) {
|
||||
if (auto binary = dynamic_cast<BinaryInst*>(inst)) {
|
||||
switch (binary->getKind()) {
|
||||
case Instruction::kMul:
|
||||
return reduceMultiplication(binary);
|
||||
case Instruction::kDiv:
|
||||
return reduceDivision(binary);
|
||||
default:
|
||||
return false;
|
||||
}
|
||||
} else if (auto call = dynamic_cast<CallInst*>(inst)) {
|
||||
return reducePower(call);
|
||||
}
|
||||
|
||||
return false;
|
||||
}
|
||||
|
||||
bool GlobalStrengthReductionContext::reduceMultiplication(BinaryInst *inst) {
|
||||
Value *lhs = inst->getLhs();
|
||||
Value *rhs = inst->getRhs();
|
||||
int constVal;
|
||||
|
||||
// 尝试右操作数为常数
|
||||
Value* variable = lhs;
|
||||
if (isConstantInt(rhs, constVal) && constVal > 0) {
|
||||
return tryComplexMultiplication(inst, variable, constVal);
|
||||
}
|
||||
|
||||
// 尝试左操作数为常数
|
||||
if (isConstantInt(lhs, constVal) && constVal > 0) {
|
||||
variable = rhs;
|
||||
return tryComplexMultiplication(inst, variable, constVal);
|
||||
}
|
||||
|
||||
return false;
|
||||
}
|
||||
|
||||
bool GlobalStrengthReductionContext::tryComplexMultiplication(BinaryInst* inst, Value* variable, int constant) {
|
||||
// 首先检查是否为2的幂,使用简单位移
|
||||
if (isPowerOfTwo(constant)) {
|
||||
int shiftAmount = log2OfPowerOfTwo(constant);
|
||||
if (DEBUG) {
|
||||
std::cout << " StrengthReduction: " << inst->getName()
|
||||
<< " = x * " << constant << " -> x << " << shiftAmount << std::endl;
|
||||
}
|
||||
|
||||
builder->setPosition(inst->getParent(), inst->getParent()->findInstIterator(inst));
|
||||
auto shiftInst = builder->createBinaryInst(Instruction::kSll, Type::getIntType(), variable, getConstantInt(shiftAmount));
|
||||
replaceWithOptimized(inst, shiftInst);
|
||||
return true;
|
||||
}
|
||||
|
||||
// 尝试分解为位移和加法的组合
|
||||
std::vector<int> shifts;
|
||||
if (findOptimalShiftDecomposition(constant, shifts)) {
|
||||
if (DEBUG) {
|
||||
std::cout << " StrengthReduction: " << inst->getName()
|
||||
<< " = x * " << constant << " -> shift decomposition with " << shifts.size() << " terms" << std::endl;
|
||||
}
|
||||
|
||||
Value* result = createShiftDecomposition(inst, variable, shifts);
|
||||
if (result) {
|
||||
replaceWithOptimized(inst, result);
|
||||
return true;
|
||||
}
|
||||
}
|
||||
|
||||
return false;
|
||||
}
|
||||
|
||||
bool GlobalStrengthReductionContext::findOptimalShiftDecomposition(int constant, std::vector<int>& shifts) {
|
||||
shifts.clear();
|
||||
|
||||
// 常见的有效分解模式
|
||||
switch (constant) {
|
||||
case 3: // 3 = 2^1 + 2^0 -> (x << 1) + x
|
||||
shifts = {1, 0};
|
||||
return true;
|
||||
case 5: // 5 = 2^2 + 2^0 -> (x << 2) + x
|
||||
shifts = {2, 0};
|
||||
return true;
|
||||
case 6: // 6 = 2^2 + 2^1 -> (x << 2) + (x << 1)
|
||||
shifts = {2, 1};
|
||||
return true;
|
||||
case 7: // 7 = 2^2 + 2^1 + 2^0 -> (x << 2) + (x << 1) + x
|
||||
shifts = {2, 1, 0};
|
||||
return true;
|
||||
case 9: // 9 = 2^3 + 2^0 -> (x << 3) + x
|
||||
shifts = {3, 0};
|
||||
return true;
|
||||
case 10: // 10 = 2^3 + 2^1 -> (x << 3) + (x << 1)
|
||||
shifts = {3, 1};
|
||||
return true;
|
||||
case 11: // 11 = 2^3 + 2^1 + 2^0 -> (x << 3) + (x << 1) + x
|
||||
shifts = {3, 1, 0};
|
||||
return true;
|
||||
case 12: // 12 = 2^3 + 2^2 -> (x << 3) + (x << 2)
|
||||
shifts = {3, 2};
|
||||
return true;
|
||||
case 13: // 13 = 2^3 + 2^2 + 2^0 -> (x << 3) + (x << 2) + x
|
||||
shifts = {3, 2, 0};
|
||||
return true;
|
||||
case 14: // 14 = 2^3 + 2^2 + 2^1 -> (x << 3) + (x << 2) + (x << 1)
|
||||
shifts = {3, 2, 1};
|
||||
return true;
|
||||
case 15: // 15 = 2^3 + 2^2 + 2^1 + 2^0 -> (x << 3) + (x << 2) + (x << 1) + x
|
||||
shifts = {3, 2, 1, 0};
|
||||
return true;
|
||||
case 17: // 17 = 2^4 + 2^0 -> (x << 4) + x
|
||||
shifts = {4, 0};
|
||||
return true;
|
||||
case 18: // 18 = 2^4 + 2^1 -> (x << 4) + (x << 1)
|
||||
shifts = {4, 1};
|
||||
return true;
|
||||
case 20: // 20 = 2^4 + 2^2 -> (x << 4) + (x << 2)
|
||||
shifts = {4, 2};
|
||||
return true;
|
||||
case 24: // 24 = 2^4 + 2^3 -> (x << 4) + (x << 3)
|
||||
shifts = {4, 3};
|
||||
return true;
|
||||
case 25: // 25 = 2^4 + 2^3 + 2^0 -> (x << 4) + (x << 3) + x
|
||||
shifts = {4, 3, 0};
|
||||
return true;
|
||||
case 100: // 100 = 2^6 + 2^5 + 2^2 -> (x << 6) + (x << 5) + (x << 2)
|
||||
shifts = {6, 5, 2};
|
||||
return true;
|
||||
}
|
||||
|
||||
// 通用二进制分解(最多4个项,避免过度复杂化)
|
||||
if (constant > 0 && constant < 256) {
|
||||
std::vector<int> binaryShifts;
|
||||
int temp = constant;
|
||||
int bit = 0;
|
||||
|
||||
while (temp > 0 && binaryShifts.size() < 4) {
|
||||
if (temp & 1) {
|
||||
binaryShifts.push_back(bit);
|
||||
}
|
||||
temp >>= 1;
|
||||
bit++;
|
||||
}
|
||||
|
||||
// 只有当项数不超过3个时才使用二进制分解(比直接乘法更有效)
|
||||
if (binaryShifts.size() <= 3 && binaryShifts.size() >= 2) {
|
||||
shifts = binaryShifts;
|
||||
return true;
|
||||
}
|
||||
}
|
||||
|
||||
return false;
|
||||
}
|
||||
|
||||
Value* GlobalStrengthReductionContext::createShiftDecomposition(BinaryInst* inst, Value* variable, const std::vector<int>& shifts) {
|
||||
if (shifts.empty()) return nullptr;
|
||||
|
||||
builder->setPosition(inst->getParent(), inst->getParent()->findInstIterator(inst));
|
||||
|
||||
Value* result = nullptr;
|
||||
|
||||
for (int shift : shifts) {
|
||||
Value* term;
|
||||
if (shift == 0) {
|
||||
// 0位移就是原变量
|
||||
term = variable;
|
||||
} else {
|
||||
// 创建位移指令
|
||||
term = builder->createBinaryInst(Instruction::kSll, Type::getIntType(), variable, getConstantInt(shift));
|
||||
}
|
||||
|
||||
if (result == nullptr) {
|
||||
result = term;
|
||||
} else {
|
||||
// 累加到结果中
|
||||
result = builder->createAddInst(result, term);
|
||||
}
|
||||
}
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
bool GlobalStrengthReductionContext::reduceDivision(BinaryInst *inst) {
|
||||
Value *lhs = inst->getLhs();
|
||||
Value *rhs = inst->getRhs();
|
||||
uint32_t constVal;
|
||||
|
||||
// x / 2^n = x >> n (对于无符号除法或已知为正数的情况)
|
||||
if (isConstantInt(rhs, constVal) && constVal > 0 && isPowerOfTwo(constVal)) {
|
||||
builder->setPosition(inst->getParent(), inst->getParent()->findInstIterator(inst));
|
||||
int shiftAmount = log2OfPowerOfTwo(constVal);
|
||||
// 有符号除法校正:(x + (x >> 31) & mask) >> k
|
||||
int maskValue = constVal - 1;
|
||||
|
||||
// x >> 31 (算术右移获取符号位)
|
||||
Value* signShift = ConstantInteger::get(31);
|
||||
Value* signBits = builder->createBinaryInst(
|
||||
Instruction::Kind::kSra, // 算术右移
|
||||
lhs->getType(),
|
||||
lhs,
|
||||
signShift
|
||||
);
|
||||
|
||||
// (x >> 31) & mask
|
||||
Value* mask = ConstantInteger::get(maskValue);
|
||||
Value* correction = builder->createBinaryInst(
|
||||
Instruction::Kind::kAnd,
|
||||
lhs->getType(),
|
||||
signBits,
|
||||
mask
|
||||
);
|
||||
|
||||
// x + correction
|
||||
Value* corrected = builder->createAddInst(lhs, correction);
|
||||
|
||||
// (x + correction) >> k
|
||||
Value* divShift = ConstantInteger::get(shiftAmount);
|
||||
Value* shiftInst = builder->createBinaryInst(
|
||||
Instruction::Kind::kSra, // 算术右移
|
||||
lhs->getType(),
|
||||
corrected,
|
||||
divShift
|
||||
);
|
||||
|
||||
if (DEBUG) {
|
||||
std::cout << " StrengthReduction: " << inst->getName()
|
||||
<< " = x / " << constVal << " -> (x + (x >> 31) & mask) >> " << shiftAmount << std::endl;
|
||||
}
|
||||
|
||||
// builder->setPosition(inst->getParent(), inst->getParent()->findInstIterator(inst));
|
||||
// Value* divisor_minus_1 = ConstantInteger::get(constVal - 1);
|
||||
// Value* adjusted = builder->createAddInst(lhs, divisor_minus_1);
|
||||
// Value* shiftInst = builder->createBinaryInst(Instruction::kSra, Type::getIntType(), adjusted, getConstantInt(shiftAmount));
|
||||
replaceWithOptimized(inst, shiftInst);
|
||||
strengthReductionCount++;
|
||||
return true;
|
||||
}
|
||||
|
||||
// x / c = x * magic_number (魔数乘法优化 - 使用libdivide算法)
|
||||
if (isConstantInt(rhs, constVal) && constVal > 1 && constVal != (uint32_t)(-1)) {
|
||||
// auto magicPair = computeMulhMagicNumbers(static_cast<int>(constVal));
|
||||
Value* magicResult = createMagicDivisionLibdivide(inst, static_cast<int>(constVal));
|
||||
replaceWithOptimized(inst, magicResult);
|
||||
divisionOptCount++;
|
||||
return true;
|
||||
}
|
||||
|
||||
return false;
|
||||
}
|
||||
|
||||
bool GlobalStrengthReductionContext::reducePower(CallInst *inst) {
|
||||
// 检查是否是pow函数调用
|
||||
Function* callee = inst->getCallee();
|
||||
if (!callee || callee->getName() != "pow") {
|
||||
return false;
|
||||
}
|
||||
|
||||
// pow(x, 2) = x * x
|
||||
if (inst->getNumOperands() >= 2) {
|
||||
int exponent;
|
||||
if (isConstantInt(inst->getOperand(1), exponent)) {
|
||||
if (exponent == 2) {
|
||||
if (DEBUG) {
|
||||
std::cout << " StrengthReduction: pow(x, 2) -> x * x" << std::endl;
|
||||
}
|
||||
|
||||
Value* base = inst->getOperand(0);
|
||||
builder->setPosition(inst->getParent(), inst->getParent()->findInstIterator(inst));
|
||||
auto mulInst = builder->createMulInst(base, base);
|
||||
replaceWithOptimized(inst, mulInst);
|
||||
strengthReductionCount++;
|
||||
return true;
|
||||
} else if (exponent >= 3 && exponent <= 8) {
|
||||
// 对于小的指数,展开为连续乘法
|
||||
if (DEBUG) {
|
||||
std::cout << " StrengthReduction: pow(x, " << exponent << ") -> repeated multiplication" << std::endl;
|
||||
}
|
||||
|
||||
Value* base = inst->getOperand(0);
|
||||
Value* result = base;
|
||||
builder->setPosition(inst->getParent(), inst->getParent()->findInstIterator(inst));
|
||||
|
||||
for (int i = 1; i < exponent; i++) {
|
||||
result = builder->createMulInst(result, base);
|
||||
}
|
||||
|
||||
replaceWithOptimized(inst, result);
|
||||
strengthReductionCount++;
|
||||
return true;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return false;
|
||||
}
|
||||
|
||||
Value* GlobalStrengthReductionContext::createMagicDivisionLibdivide(BinaryInst* divInst, int divisor) {
|
||||
builder->setPosition(divInst->getParent(), divInst->getParent()->findInstIterator(divInst));
|
||||
// 使用mulh指令优化任意常数除法
|
||||
auto [magic, shift] = SysYIROptUtils::computeMulhMagicNumbers(divisor);
|
||||
|
||||
// 检查是否无法优化(magic == -1, shift == -1 表示失败)
|
||||
if (magic == -1 && shift == -1) {
|
||||
if (DEBUG) {
|
||||
std::cout << "[SR] Cannot optimize division by " << divisor
|
||||
<< ", keeping original division" << std::endl;
|
||||
}
|
||||
// 返回 nullptr 表示无法优化,调用方应该保持原始除法
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
// 2的幂次方除法可以用移位优化(但这不是魔数法的情况)这种情况应该不会被分类到这里但是还是做一个保护措施
|
||||
if ((divisor & (divisor - 1)) == 0 && divisor > 0) {
|
||||
// 是2的幂次方,可以用移位
|
||||
int shift_amount = 0;
|
||||
int temp = divisor;
|
||||
while (temp > 1) {
|
||||
temp >>= 1;
|
||||
shift_amount++;
|
||||
}
|
||||
|
||||
Value* shiftConstant = ConstantInteger::get(shift_amount);
|
||||
// 对于有符号除法,需要先加上除数-1然后再移位(为了正确处理负数舍入)
|
||||
Value* divisor_minus_1 = ConstantInteger::get(divisor - 1);
|
||||
Value* adjusted = builder->createAddInst(divInst->getOperand(0), divisor_minus_1);
|
||||
return builder->createBinaryInst(
|
||||
Instruction::Kind::kSra, // 算术右移
|
||||
divInst->getOperand(0)->getType(),
|
||||
adjusted,
|
||||
shiftConstant
|
||||
);
|
||||
}
|
||||
|
||||
// 创建魔数常量
|
||||
// 检查魔数是否能放入32位,如果不能,则不进行优化
|
||||
if (magic > INT32_MAX || magic < INT32_MIN) {
|
||||
if (DEBUG) {
|
||||
std::cout << "[SR] Magic number " << magic << " exceeds 32-bit range, skipping optimization" << std::endl;
|
||||
}
|
||||
return nullptr; // 无法优化,保持原始除法
|
||||
}
|
||||
|
||||
Value* magicConstant = ConstantInteger::get((int32_t)magic);
|
||||
|
||||
// 检查是否需要ADD_MARKER处理(加法调整)
|
||||
bool needAdd = (shift & 0x40) != 0;
|
||||
int actualShift = shift & 0x3F; // 提取真实的移位量
|
||||
|
||||
if (DEBUG) {
|
||||
std::cout << "[SR] IR Generation: magic=" << magic << ", needAdd=" << needAdd
|
||||
<< ", actualShift=" << actualShift << std::endl;
|
||||
}
|
||||
|
||||
// 执行高位乘法:mulh(x, magic)
|
||||
Value* mulhResult = builder->createBinaryInst(
|
||||
Instruction::Kind::kMulh, // 高位乘法
|
||||
divInst->getOperand(0)->getType(),
|
||||
divInst->getOperand(0),
|
||||
magicConstant
|
||||
);
|
||||
|
||||
if (needAdd) {
|
||||
// ADD_MARKER 情况:需要在移位前加上被除数
|
||||
// 这对应于 libdivide 的加法调整算法
|
||||
if (DEBUG) {
|
||||
std::cout << "[SR] Applying ADD_MARKER: adding dividend before shift" << std::endl;
|
||||
}
|
||||
mulhResult = builder->createAddInst(mulhResult, divInst->getOperand(0));
|
||||
}
|
||||
|
||||
if (actualShift > 0) {
|
||||
// 如果需要额外移位
|
||||
Value* shiftConstant = ConstantInteger::get(actualShift);
|
||||
mulhResult = builder->createBinaryInst(
|
||||
Instruction::Kind::kSra, // 算术右移
|
||||
divInst->getOperand(0)->getType(),
|
||||
mulhResult,
|
||||
shiftConstant
|
||||
);
|
||||
}
|
||||
|
||||
// 标准的有符号除法符号修正:如果被除数为负,商需要加1
|
||||
// 这对所有有符号除法都需要,不管是否可能有负数
|
||||
Value* isNegative = builder->createICmpLTInst(divInst->getOperand(0), ConstantInteger::get(0));
|
||||
// 将i1转换为i32:负数时为1,非负数时为0 ICmpLTInst的结果会默认转化为32位
|
||||
mulhResult = builder->createAddInst(mulhResult, isNegative);
|
||||
|
||||
return mulhResult;
|
||||
}
|
||||
|
||||
// ======================================================================
|
||||
// 辅助方法
|
||||
// ======================================================================
|
||||
|
||||
bool GlobalStrengthReductionContext::isPowerOfTwo(uint32_t n) {
|
||||
return n > 0 && (n & (n - 1)) == 0;
|
||||
}
|
||||
|
||||
int GlobalStrengthReductionContext::log2OfPowerOfTwo(uint32_t n) {
|
||||
int result = 0;
|
||||
while (n > 1) {
|
||||
n >>= 1;
|
||||
result++;
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
bool GlobalStrengthReductionContext::isConstantInt(Value* val, int& constVal) {
|
||||
if (auto constInt = dynamic_cast<ConstantInteger*>(val)) {
|
||||
constVal = std::get<int>(constInt->getVal());
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
bool GlobalStrengthReductionContext::isConstantInt(Value* val, uint32_t& constVal) {
|
||||
if (auto constInt = dynamic_cast<ConstantInteger*>(val)) {
|
||||
int signedVal = std::get<int>(constInt->getVal());
|
||||
if (signedVal >= 0) {
|
||||
constVal = static_cast<uint32_t>(signedVal);
|
||||
return true;
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
ConstantInteger* GlobalStrengthReductionContext::getConstantInt(int val) {
|
||||
return ConstantInteger::get(val);
|
||||
}
|
||||
|
||||
bool GlobalStrengthReductionContext::hasOnlyLocalUses(Instruction* inst) {
|
||||
if (!inst) return true;
|
||||
|
||||
// 简单检查:如果指令没有副作用,则认为是本地的
|
||||
if (sideEffectAnalysis) {
|
||||
auto sideEffect = sideEffectAnalysis->getInstructionSideEffect(inst);
|
||||
return sideEffect.type == SideEffectType::NO_SIDE_EFFECT;
|
||||
}
|
||||
|
||||
// 没有副作用分析时,保守处理
|
||||
return !inst->isCall() && !inst->isStore() && !inst->isLoad();
|
||||
}
|
||||
|
||||
void GlobalStrengthReductionContext::replaceWithOptimized(Instruction* original, Value* replacement) {
|
||||
if (DEBUG) {
|
||||
std::cout << " Replacing " << original->getName()
|
||||
<< " with " << replacement->getName() << std::endl;
|
||||
}
|
||||
|
||||
original->replaceAllUsesWith(replacement);
|
||||
|
||||
// 如果替换值是新创建的指令,确保它有合适的名字
|
||||
// if (auto replInst = dynamic_cast<Instruction*>(replacement)) {
|
||||
// if (replInst->getName().empty()) {
|
||||
// replInst->setName(original->getName() + "_opt");
|
||||
// }
|
||||
// }
|
||||
|
||||
// 删除原指令,让调用者处理
|
||||
SysYIROptUtils::usedelete(original);
|
||||
}
|
||||
|
||||
} // namespace sysy
|
||||
@ -18,62 +18,214 @@ bool LICMContext::hoistInstructions() {
|
||||
// 1. 先收集所有可外提指令
|
||||
std::unordered_set<Instruction *> workSet(chars->invariantInsts.begin(), chars->invariantInsts.end());
|
||||
|
||||
if (DEBUG) {
|
||||
std::cout << "LICM: Found " << workSet.size() << " candidate invariant instructions to hoist:" << std::endl;
|
||||
for (auto *inst : workSet) {
|
||||
std::cout << " - " << inst->getName() << " (kind: " << static_cast<int>(inst->getKind())
|
||||
<< ", in BB: " << inst->getParent()->getName() << ")" << std::endl;
|
||||
}
|
||||
}
|
||||
|
||||
// 2. 计算每个指令被依赖的次数(入度)
|
||||
std::unordered_map<Instruction *, int> indegree;
|
||||
std::unordered_map<Instruction *, std::vector<Instruction *>> dependencies; // 记录依赖关系
|
||||
std::unordered_map<Instruction *, std::vector<Instruction *>> dependents; // 记录被依赖关系
|
||||
|
||||
for (auto *inst : workSet) {
|
||||
indegree[inst] = 0;
|
||||
dependencies[inst] = {};
|
||||
dependents[inst] = {};
|
||||
}
|
||||
|
||||
if (DEBUG) {
|
||||
std::cout << "LICM: Analyzing dependencies between invariant instructions..." << std::endl;
|
||||
}
|
||||
|
||||
for (auto *inst : workSet) {
|
||||
for (size_t i = 0; i < inst->getNumOperands(); ++i) {
|
||||
if (auto *dep = dynamic_cast<Instruction *>(inst->getOperand(i))) {
|
||||
if (workSet.count(dep)) {
|
||||
indegree[inst]++;
|
||||
dependencies[inst].push_back(dep);
|
||||
dependents[dep].push_back(inst);
|
||||
|
||||
if (DEBUG) {
|
||||
std::cout << " Dependency: " << inst->getName() << " depends on " << dep->getName() << std::endl;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (DEBUG) {
|
||||
std::cout << "LICM: Initial indegree analysis:" << std::endl;
|
||||
for (auto &[inst, deg] : indegree) {
|
||||
std::cout << " " << inst->getName() << ": indegree=" << deg;
|
||||
if (deg > 0) {
|
||||
std::cout << ", depends on: ";
|
||||
for (auto *dep : dependencies[inst]) {
|
||||
std::cout << dep->getName() << " ";
|
||||
}
|
||||
}
|
||||
std::cout << std::endl;
|
||||
}
|
||||
}
|
||||
|
||||
// 3. Kahn拓扑排序
|
||||
std::vector<Instruction *> sorted;
|
||||
std::queue<Instruction *> q;
|
||||
for (auto &[inst, deg] : indegree) {
|
||||
if (deg == 0)
|
||||
q.push(inst);
|
||||
|
||||
if (DEBUG) {
|
||||
std::cout << "LICM: Starting topological sort..." << std::endl;
|
||||
}
|
||||
|
||||
for (auto &[inst, deg] : indegree) {
|
||||
if (deg == 0) {
|
||||
q.push(inst);
|
||||
if (DEBUG) {
|
||||
std::cout << " Initial zero-indegree instruction: " << inst->getName() << std::endl;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
int sortStep = 0;
|
||||
while (!q.empty()) {
|
||||
auto *inst = q.front();
|
||||
q.pop();
|
||||
sorted.push_back(inst);
|
||||
for (size_t i = 0; i < inst->getNumOperands(); ++i) {
|
||||
if (auto *dep = dynamic_cast<Instruction *>(inst->getOperand(i))) {
|
||||
if (workSet.count(dep)) {
|
||||
indegree[dep]--;
|
||||
if (indegree[dep] == 0)
|
||||
q.push(dep);
|
||||
|
||||
if (DEBUG) {
|
||||
std::cout << " Step " << (++sortStep) << ": Processing " << inst->getName() << std::endl;
|
||||
}
|
||||
|
||||
if (DEBUG) {
|
||||
std::cout << " Reducing indegree of dependents of " << inst->getName() << std::endl;
|
||||
}
|
||||
|
||||
// 正确的拓扑排序:当处理一个指令时,应该减少其所有使用者(dependents)的入度
|
||||
for (auto *dependent : dependents[inst]) {
|
||||
indegree[dependent]--;
|
||||
if (DEBUG) {
|
||||
std::cout << " Reducing indegree of " << dependent->getName() << " to " << indegree[dependent] << std::endl;
|
||||
}
|
||||
if (indegree[dependent] == 0) {
|
||||
q.push(dependent);
|
||||
if (DEBUG) {
|
||||
std::cout << " Adding " << dependent->getName() << " to queue (indegree=0)" << std::endl;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 检查是否全部排序,若未全部排序,说明有环(理论上不会)
|
||||
// 检查是否全部排序,若未全部排序,打印错误信息
|
||||
// 这可能是因为存在循环依赖或其他问题导致无法完成拓扑排序
|
||||
if (sorted.size() != workSet.size()) {
|
||||
if (DEBUG)
|
||||
std::cerr << "LICM: Topological sort failed, possible dependency cycle." << std::endl;
|
||||
if (DEBUG) {
|
||||
std::cout << "LICM: Topological sort failed! Sorted " << sorted.size()
|
||||
<< " instructions out of " << workSet.size() << " total." << std::endl;
|
||||
|
||||
// 找出未被排序的指令(形成循环依赖的指令)
|
||||
std::unordered_set<Instruction *> remaining;
|
||||
for (auto *inst : workSet) {
|
||||
bool found = false;
|
||||
for (auto *sortedInst : sorted) {
|
||||
if (inst == sortedInst) {
|
||||
found = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
if (!found) {
|
||||
remaining.insert(inst);
|
||||
}
|
||||
}
|
||||
|
||||
std::cout << "LICM: Instructions involved in dependency cycle:" << std::endl;
|
||||
for (auto *inst : remaining) {
|
||||
std::cout << " - " << inst->getName() << " (indegree=" << indegree[inst] << ")" << std::endl;
|
||||
std::cout << " Dependencies within cycle: ";
|
||||
for (auto *dep : dependencies[inst]) {
|
||||
if (remaining.count(dep)) {
|
||||
std::cout << dep->getName() << " ";
|
||||
}
|
||||
}
|
||||
std::cout << std::endl;
|
||||
std::cout << " Dependents within cycle: ";
|
||||
for (auto *dependent : dependents[inst]) {
|
||||
if (remaining.count(dependent)) {
|
||||
std::cout << dependent->getName() << " ";
|
||||
}
|
||||
}
|
||||
std::cout << std::endl;
|
||||
}
|
||||
|
||||
// 尝试找出一个具体的循环路径
|
||||
std::cout << "LICM: Attempting to trace a dependency cycle:" << std::endl;
|
||||
if (!remaining.empty()) {
|
||||
auto *start = *remaining.begin();
|
||||
std::unordered_set<Instruction *> visited;
|
||||
std::vector<Instruction *> path;
|
||||
|
||||
std::function<bool(Instruction *)> findCycle = [&](Instruction *current) -> bool {
|
||||
if (visited.count(current)) {
|
||||
// 找到环
|
||||
auto it = std::find(path.begin(), path.end(), current);
|
||||
if (it != path.end()) {
|
||||
std::cout << " Cycle found: ";
|
||||
for (auto cycleIt = it; cycleIt != path.end(); ++cycleIt) {
|
||||
std::cout << (*cycleIt)->getName() << " -> ";
|
||||
}
|
||||
std::cout << current->getName() << std::endl;
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
visited.insert(current);
|
||||
path.push_back(current);
|
||||
|
||||
for (auto *dep : dependencies[current]) {
|
||||
if (remaining.count(dep)) {
|
||||
if (findCycle(dep)) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
path.pop_back();
|
||||
return false;
|
||||
};
|
||||
|
||||
findCycle(start);
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
// 4. 按拓扑序外提
|
||||
if (DEBUG) {
|
||||
std::cout << "LICM: Successfully completed topological sort. Hoisting instructions in order:" << std::endl;
|
||||
}
|
||||
|
||||
for (auto *inst : sorted) {
|
||||
if (!inst)
|
||||
continue;
|
||||
BasicBlock *parent = inst->getParent();
|
||||
if (parent && loop->contains(parent)) {
|
||||
if (DEBUG) {
|
||||
std::cout << " Hoisting " << inst->getName() << " from " << parent->getName()
|
||||
<< " to preheader " << preheader->getName() << std::endl;
|
||||
}
|
||||
auto sourcePos = parent->findInstIterator(inst);
|
||||
auto targetPos = preheader->terminator();
|
||||
parent->moveInst(sourcePos, targetPos, preheader);
|
||||
changed = true;
|
||||
}
|
||||
}
|
||||
|
||||
if (DEBUG && changed) {
|
||||
std::cout << "LICM: Successfully hoisted " << sorted.size() << " invariant instructions" << std::endl;
|
||||
}
|
||||
|
||||
return changed;
|
||||
}
|
||||
// ---- LICM Pass Implementation ----
|
||||
|
||||
@ -1,145 +0,0 @@
|
||||
#include "../../include/midend/Pass/Optimize/LargeArrayToGlobal.h"
|
||||
#include "../../IR.h"
|
||||
#include <unordered_map>
|
||||
#include <sstream>
|
||||
#include <string>
|
||||
|
||||
namespace sysy {
|
||||
|
||||
// Helper function to convert type to string
|
||||
static std::string typeToString(Type *type) {
|
||||
if (!type) return "null";
|
||||
|
||||
switch (type->getKind()) {
|
||||
case Type::kInt:
|
||||
return "int";
|
||||
case Type::kFloat:
|
||||
return "float";
|
||||
case Type::kPointer:
|
||||
return "ptr";
|
||||
case Type::kArray: {
|
||||
auto *arrayType = type->as<ArrayType>();
|
||||
return "[" + std::to_string(arrayType->getNumElements()) + " x " +
|
||||
typeToString(arrayType->getElementType()) + "]";
|
||||
}
|
||||
default:
|
||||
return "unknown";
|
||||
}
|
||||
}
|
||||
|
||||
void *LargeArrayToGlobalPass::ID = &LargeArrayToGlobalPass::ID;
|
||||
|
||||
bool LargeArrayToGlobalPass::runOnModule(Module *M, AnalysisManager &AM) {
|
||||
bool changed = false;
|
||||
|
||||
if (!M) {
|
||||
return false;
|
||||
}
|
||||
|
||||
// Collect all alloca instructions from all functions
|
||||
std::vector<std::pair<AllocaInst*, Function*>> allocasToConvert;
|
||||
|
||||
for (auto &funcPair : M->getFunctions()) {
|
||||
Function *F = funcPair.second.get();
|
||||
if (!F || F->getBasicBlocks().begin() == F->getBasicBlocks().end()) {
|
||||
continue;
|
||||
}
|
||||
|
||||
for (auto &BB : F->getBasicBlocks()) {
|
||||
for (auto &inst : BB->getInstructions()) {
|
||||
if (auto *alloca = dynamic_cast<AllocaInst*>(inst.get())) {
|
||||
Type *allocatedType = alloca->getAllocatedType();
|
||||
|
||||
// Calculate the size of the allocated type
|
||||
unsigned size = calculateTypeSize(allocatedType);
|
||||
if(DEBUG){
|
||||
// Debug: print size information
|
||||
std::cout << "LargeArrayToGlobalPass: Found alloca with size " << size
|
||||
<< " for type " << typeToString(allocatedType) << std::endl;
|
||||
}
|
||||
|
||||
// Convert arrays of 1KB (1024 bytes) or larger to global variables
|
||||
if (size >= 1024) {
|
||||
if(DEBUG)
|
||||
std::cout << "LargeArrayToGlobalPass: Converting array of size " << size << " to global" << std::endl;
|
||||
allocasToConvert.emplace_back(alloca, F);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Convert the collected alloca instructions to global variables
|
||||
for (auto [alloca, F] : allocasToConvert) {
|
||||
convertAllocaToGlobal(alloca, F, M);
|
||||
changed = true;
|
||||
}
|
||||
|
||||
return changed;
|
||||
}
|
||||
|
||||
unsigned LargeArrayToGlobalPass::calculateTypeSize(Type *type) {
|
||||
if (!type) return 0;
|
||||
|
||||
switch (type->getKind()) {
|
||||
case Type::kInt:
|
||||
case Type::kFloat:
|
||||
return 4;
|
||||
case Type::kPointer:
|
||||
return 8;
|
||||
case Type::kArray: {
|
||||
auto *arrayType = type->as<ArrayType>();
|
||||
return arrayType->getNumElements() * calculateTypeSize(arrayType->getElementType());
|
||||
}
|
||||
default:
|
||||
return 0;
|
||||
}
|
||||
}
|
||||
|
||||
void LargeArrayToGlobalPass::convertAllocaToGlobal(AllocaInst *alloca, Function *F, Module *M) {
|
||||
Type *allocatedType = alloca->getAllocatedType();
|
||||
|
||||
// Create a unique name for the global variable
|
||||
std::string globalName = generateUniqueGlobalName(alloca, F);
|
||||
|
||||
// Create the global variable - GlobalValue expects pointer type
|
||||
Type *pointerType = Type::getPointerType(allocatedType);
|
||||
GlobalValue *globalVar = M->createGlobalValue(globalName, pointerType);
|
||||
|
||||
if (!globalVar) {
|
||||
return;
|
||||
}
|
||||
|
||||
// Replace all uses of the alloca with the global variable
|
||||
alloca->replaceAllUsesWith(globalVar);
|
||||
|
||||
// Remove the alloca instruction from its basic block
|
||||
for (auto &BB : F->getBasicBlocks()) {
|
||||
auto &instructions = BB->getInstructions();
|
||||
for (auto it = instructions.begin(); it != instructions.end(); ++it) {
|
||||
if (it->get() == alloca) {
|
||||
instructions.erase(it);
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
std::string LargeArrayToGlobalPass::generateUniqueGlobalName(AllocaInst *alloca, Function *F) {
|
||||
std::string baseName = alloca->getName();
|
||||
if (baseName.empty()) {
|
||||
baseName = "array";
|
||||
}
|
||||
|
||||
// Ensure uniqueness by appending function name and counter
|
||||
static std::unordered_map<std::string, int> nameCounter;
|
||||
std::string key = F->getName() + "." + baseName;
|
||||
|
||||
int counter = nameCounter[key]++;
|
||||
std::ostringstream oss;
|
||||
oss << key << "." << counter;
|
||||
|
||||
return oss.str();
|
||||
}
|
||||
|
||||
} // namespace sysy
|
||||
@ -106,187 +106,6 @@ bool StrengthReductionContext::analyzeInductionVariableRange(
|
||||
return hasNegativePotential;
|
||||
}
|
||||
|
||||
//该实现参考了libdivide的算法
|
||||
std::pair<int, int> StrengthReductionContext::computeMulhMagicNumbers(int divisor) const {
|
||||
|
||||
if (DEBUG) {
|
||||
std::cout << "\n[SR] ===== Computing magic numbers for divisor " << divisor << " (libdivide algorithm) =====" << std::endl;
|
||||
}
|
||||
|
||||
if (divisor == 0) {
|
||||
if (DEBUG) std::cout << "[SR] Error: divisor must be != 0" << std::endl;
|
||||
return {-1, -1};
|
||||
}
|
||||
|
||||
// libdivide 常数
|
||||
const uint8_t LIBDIVIDE_ADD_MARKER = 0x40;
|
||||
const uint8_t LIBDIVIDE_NEGATIVE_DIVISOR = 0x80;
|
||||
|
||||
// 辅助函数:计算前导零个数
|
||||
auto count_leading_zeros32 = [](uint32_t val) -> uint32_t {
|
||||
if (val == 0) return 32;
|
||||
return __builtin_clz(val);
|
||||
};
|
||||
|
||||
// 辅助函数:64位除法返回32位商和余数
|
||||
auto div_64_32 = [](uint32_t high, uint32_t low, uint32_t divisor, uint32_t* rem) -> uint32_t {
|
||||
uint64_t dividend = ((uint64_t)high << 32) | low;
|
||||
uint32_t quotient = dividend / divisor;
|
||||
*rem = dividend % divisor;
|
||||
return quotient;
|
||||
};
|
||||
|
||||
if (DEBUG) {
|
||||
std::cout << "[SR] Input divisor: " << divisor << std::endl;
|
||||
}
|
||||
|
||||
// libdivide_internal_s32_gen 算法实现
|
||||
int32_t d = divisor;
|
||||
uint32_t ud = (uint32_t)d;
|
||||
uint32_t absD = (d < 0) ? -ud : ud;
|
||||
|
||||
if (DEBUG) {
|
||||
std::cout << "[SR] absD = " << absD << std::endl;
|
||||
}
|
||||
|
||||
uint32_t floor_log_2_d = 31 - count_leading_zeros32(absD);
|
||||
|
||||
if (DEBUG) {
|
||||
std::cout << "[SR] floor_log_2_d = " << floor_log_2_d << std::endl;
|
||||
}
|
||||
|
||||
// 检查 absD 是否为2的幂
|
||||
if ((absD & (absD - 1)) == 0) {
|
||||
if (DEBUG) {
|
||||
std::cout << "[SR] " << absD << " 是2的幂,使用移位方法" << std::endl;
|
||||
}
|
||||
|
||||
// 对于2的幂,我们只使用移位,不需要魔数
|
||||
int shift = floor_log_2_d;
|
||||
if (d < 0) shift |= 0x80; // 标记负数
|
||||
|
||||
if (DEBUG) {
|
||||
std::cout << "[SR] Power of 2 result: magic=0, shift=" << shift << std::endl;
|
||||
std::cout << "[SR] ===== End magic computation =====" << std::endl;
|
||||
}
|
||||
|
||||
// 对于我们的目的,我们将在IR生成中以不同方式处理2的幂
|
||||
// 返回特殊标记
|
||||
return {0, shift};
|
||||
}
|
||||
|
||||
if (DEBUG) {
|
||||
std::cout << "[SR] " << absD << " is not a power of 2, computing magic number" << std::endl;
|
||||
}
|
||||
|
||||
// 非2的幂除数的魔数计算
|
||||
uint8_t more;
|
||||
uint32_t rem, proposed_m;
|
||||
|
||||
// 计算 proposed_m = floor(2^(floor_log_2_d + 31) / absD)
|
||||
proposed_m = div_64_32((uint32_t)1 << (floor_log_2_d - 1), 0, absD, &rem);
|
||||
const uint32_t e = absD - rem;
|
||||
|
||||
if (DEBUG) {
|
||||
std::cout << "[SR] proposed_m = " << proposed_m << ", rem = " << rem << ", e = " << e << std::endl;
|
||||
}
|
||||
|
||||
// 确定是否需要"加法"版本
|
||||
const bool branchfree = false; // 使用分支版本
|
||||
|
||||
if (!branchfree && e < ((uint32_t)1 << floor_log_2_d)) {
|
||||
// 这个幂次有效
|
||||
more = (uint8_t)(floor_log_2_d - 1);
|
||||
if (DEBUG) {
|
||||
std::cout << "[SR] Using basic algorithm, shift = " << (int)more << std::endl;
|
||||
}
|
||||
} else {
|
||||
// 我们需要上升一个等级
|
||||
proposed_m += proposed_m;
|
||||
const uint32_t twice_rem = rem + rem;
|
||||
if (twice_rem >= absD || twice_rem < rem) {
|
||||
proposed_m += 1;
|
||||
}
|
||||
more = (uint8_t)(floor_log_2_d | LIBDIVIDE_ADD_MARKER);
|
||||
if (DEBUG) {
|
||||
std::cout << "[SR] Using add algorithm, proposed_m = " << proposed_m << ", more = " << (int)more << std::endl;
|
||||
}
|
||||
}
|
||||
|
||||
proposed_m += 1;
|
||||
int32_t magic = (int32_t)proposed_m;
|
||||
|
||||
// 处理负除数
|
||||
if (d < 0) {
|
||||
more |= LIBDIVIDE_NEGATIVE_DIVISOR;
|
||||
if (!branchfree) {
|
||||
magic = -magic;
|
||||
}
|
||||
if (DEBUG) {
|
||||
std::cout << "[SR] Negative divisor, magic = " << magic << ", more = " << (int)more << std::endl;
|
||||
}
|
||||
}
|
||||
|
||||
// 为我们的IR生成提取移位量和标志
|
||||
int shift = more & 0x3F; // 移除标志,保留移位量(位0-5)
|
||||
bool need_add = (more & LIBDIVIDE_ADD_MARKER) != 0;
|
||||
bool is_negative = (more & LIBDIVIDE_NEGATIVE_DIVISOR) != 0;
|
||||
|
||||
if (DEBUG) {
|
||||
std::cout << "[SR] Final result: magic = " << magic << ", more = " << (int)more
|
||||
<< " (0x" << std::hex << (int)more << std::dec << ")" << std::endl;
|
||||
std::cout << "[SR] Shift = " << shift << ", need_add = " << need_add
|
||||
<< ", is_negative = " << is_negative << std::endl;
|
||||
|
||||
// Test the magic number using the correct libdivide algorithm
|
||||
std::cout << "[SR] Testing magic number (libdivide algorithm):" << std::endl;
|
||||
int test_values[] = {1, 7, 37, 100, 999, -1, -7, -37, -100};
|
||||
|
||||
for (int test_val : test_values) {
|
||||
int64_t quotient;
|
||||
|
||||
// 实现正确的libdivide算法
|
||||
int64_t product = (int64_t)test_val * magic;
|
||||
int64_t high_bits = product >> 32;
|
||||
|
||||
if (need_add) {
|
||||
// ADD_MARKER情况:移位前加上被除数
|
||||
// 这是libdivide的关键洞察!
|
||||
high_bits += test_val;
|
||||
quotient = high_bits >> shift;
|
||||
} else {
|
||||
// 正常情况:只是移位
|
||||
quotient = high_bits >> shift;
|
||||
}
|
||||
|
||||
// 符号修正:这是libdivide有符号除法的关键部分!
|
||||
// 如果被除数为负,商需要加1来匹配C语言的截断除法语义
|
||||
if (test_val < 0) {
|
||||
quotient += 1;
|
||||
}
|
||||
|
||||
int expected = test_val / divisor;
|
||||
|
||||
bool correct = (quotient == expected);
|
||||
std::cout << "[SR] " << test_val << " / " << divisor << " = " << quotient
|
||||
<< " (expected " << expected << ") " << (correct ? "✓" : "✗") << std::endl;
|
||||
}
|
||||
|
||||
std::cout << "[SR] ===== End magic computation =====" << std::endl;
|
||||
}
|
||||
|
||||
// 返回魔数、移位量,并在移位中编码ADD_MARKER标志
|
||||
// 我们将使用移位的第6位表示ADD_MARKER,第7位表示负数(如果需要)
|
||||
int encoded_shift = shift;
|
||||
if (need_add) {
|
||||
encoded_shift |= 0x40; // 设置第6位表示ADD_MARKER
|
||||
if (DEBUG) {
|
||||
std::cout << "[SR] Encoding ADD_MARKER in shift: " << encoded_shift << std::endl;
|
||||
}
|
||||
}
|
||||
|
||||
return {magic, encoded_shift};
|
||||
}
|
||||
|
||||
bool LoopStrengthReduction::runOnFunction(Function* F, AnalysisManager& AM) {
|
||||
if (F->getBasicBlocks().empty()) {
|
||||
@ -1018,7 +837,7 @@ Value* StrengthReductionContext::generateConstantDivisionReplacement(
|
||||
IRBuilder* builder
|
||||
) const {
|
||||
// 使用mulh指令优化任意常数除法
|
||||
auto [magic, shift] = computeMulhMagicNumbers(candidate->multiplier);
|
||||
auto [magic, shift] = SysYIROptUtils::computeMulhMagicNumbers(candidate->multiplier);
|
||||
|
||||
// 检查是否无法优化(magic == -1, shift == -1 表示失败)
|
||||
if (magic == -1 && shift == -1) {
|
||||
|
||||
@ -70,20 +70,20 @@ void Reg2MemContext::allocateMemoryForSSAValues(Function *func) {
|
||||
|
||||
// 1. 为函数参数分配内存
|
||||
builder->setPosition(entryBlock, entryBlock->begin()); // 确保在入口块的开始位置插入
|
||||
for (auto arg : func->getArguments()) {
|
||||
// 默认情况下,将所有参数是提升到内存
|
||||
if (isPromotableToMemory(arg)) {
|
||||
// 参数的类型就是 AllocaInst 需要分配的类型
|
||||
AllocaInst *alloca = builder->createAllocaInst(Type::getPointerType(arg->getType()), arg->getName() + ".reg2mem");
|
||||
// 将参数值 store 到 alloca 中 (这是 Mem2Reg 逆转的关键一步)
|
||||
valueToAllocaMap[arg] = alloca;
|
||||
// for (auto arg : func->getArguments()) {
|
||||
// // 默认情况下,将所有参数是提升到内存
|
||||
// if (isPromotableToMemory(arg)) {
|
||||
// // 参数的类型就是 AllocaInst 需要分配的类型
|
||||
// AllocaInst *alloca = builder->createAllocaInst(Type::getPointerType(arg->getType()), arg->getName() + ".reg2mem");
|
||||
// // 将参数值 store 到 alloca 中 (这是 Mem2Reg 逆转的关键一步)
|
||||
// valueToAllocaMap[arg] = alloca;
|
||||
|
||||
// 确保 alloca 位于入口块的顶部,但在所有参数的 store 指令之前
|
||||
// 通常 alloca 都在 entry block 的最开始
|
||||
// 这里我们只是创建,并让 builder 决定插入位置 (通常在当前插入点)
|
||||
// 如果需要严格控制顺序,可能需要手动 insert 到 instruction list
|
||||
}
|
||||
}
|
||||
// // 确保 alloca 位于入口块的顶部,但在所有参数的 store 指令之前
|
||||
// // 通常 alloca 都在 entry block 的最开始
|
||||
// // 这里我们只是创建,并让 builder 决定插入位置 (通常在当前插入点)
|
||||
// // 如果需要严格控制顺序,可能需要手动 insert 到 instruction list
|
||||
// }
|
||||
// }
|
||||
|
||||
// 2. 为指令结果分配内存
|
||||
// 遍历所有基本块和指令,找出所有需要分配 Alloca 的指令结果
|
||||
@ -123,11 +123,11 @@ void Reg2MemContext::allocateMemoryForSSAValues(Function *func) {
|
||||
}
|
||||
|
||||
// 插入所有参数的初始 Store 指令
|
||||
for (auto arg : func->getArguments()) {
|
||||
if (valueToAllocaMap.count(arg)) { // 检查是否为其分配了 alloca
|
||||
builder->createStoreInst(arg, valueToAllocaMap[arg]);
|
||||
}
|
||||
}
|
||||
// for (auto arg : func->getArguments()) {
|
||||
// if (valueToAllocaMap.count(arg)) { // 检查是否为其分配了 alloca
|
||||
// builder->createStoreInst(arg, valueToAllocaMap[arg]);
|
||||
// }
|
||||
// }
|
||||
|
||||
builder->setPosition(entryBlock, entryBlock->terminator());
|
||||
}
|
||||
|
||||
125
src/midend/Pass/Optimize/TailCallOpt.cpp
Normal file
125
src/midend/Pass/Optimize/TailCallOpt.cpp
Normal file
@ -0,0 +1,125 @@
|
||||
#include "TailCallOpt.h"
|
||||
#include "IR.h"
|
||||
#include "IRBuilder.h"
|
||||
#include "SysYIROptUtils.h"
|
||||
#include <vector>
|
||||
// #include <iostream>
|
||||
#include <algorithm>
|
||||
|
||||
namespace sysy {
|
||||
|
||||
void *TailCallOpt::ID = (void *)&TailCallOpt::ID;
|
||||
|
||||
void TailCallOpt::getAnalysisUsage(std::set<void *> &analysisDependencies, std::set<void *> &analysisInvalidations) const {
|
||||
analysisInvalidations.insert(&DominatorTreeAnalysisPass::ID);
|
||||
analysisInvalidations.insert(&LoopAnalysisPass::ID);
|
||||
}
|
||||
|
||||
bool TailCallOpt::runOnFunction(Function *F, AnalysisManager &AM) {
|
||||
std::vector<CallInst *> tailCallInsts;
|
||||
// 遍历函数的所有基本块
|
||||
for (auto &bb_ptr : F->getBasicBlocks()) {
|
||||
auto BB = bb_ptr.get();
|
||||
if (BB->getInstructions().empty()) continue; // 跳过空基本块
|
||||
|
||||
auto term_iter = BB->terminator();
|
||||
if (term_iter == BB->getInstructions().end()) continue; // 没有终结指令则跳过
|
||||
auto term = (*term_iter).get();
|
||||
|
||||
if (!term || !term->isReturn()) continue; // 不是返回指令则跳过
|
||||
auto retInst = static_cast<ReturnInst *>(term);
|
||||
|
||||
Instruction *prevInst = nullptr;
|
||||
if (BB->getInstructions().size() > 1) {
|
||||
auto it = term_iter;
|
||||
--it; // 获取返回指令前的指令
|
||||
prevInst = (*it).get();
|
||||
}
|
||||
|
||||
if (!prevInst || !prevInst->isCall()) continue; // 前一条不是调用指令则跳过
|
||||
auto callInst = static_cast<CallInst *>(prevInst);
|
||||
|
||||
// 检查是否为尾递归调用:被调用函数与当前函数相同且返回值与调用结果匹配
|
||||
if (callInst->getCallee() == F) {
|
||||
// 对于尾递归,返回值应为调用结果或为 void 类型
|
||||
if (retInst->getReturnValue() == callInst ||
|
||||
(retInst->getReturnValue() == nullptr && callInst->getType()->isVoid())) {
|
||||
tailCallInsts.push_back(callInst);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (tailCallInsts.empty()) {
|
||||
return false;
|
||||
}
|
||||
|
||||
// 创建一个新的入口基本块,作为循环的前置块
|
||||
auto original_entry = F->getEntryBlock();
|
||||
auto new_entry = F->addBasicBlock("tco.entry." + F->getName());
|
||||
auto loop_header = F->addBasicBlock("tco.loop_header." + F->getName());
|
||||
|
||||
// 将原入口块中的所有指令移动到循环头块
|
||||
loop_header->getInstructions().splice(loop_header->end(), original_entry->getInstructions());
|
||||
original_entry->setName("tco.pre_header");
|
||||
|
||||
// 为函数参数创建 phi 节点
|
||||
builder->setPosition(loop_header, loop_header->begin());
|
||||
std::vector<PhiInst *> phis;
|
||||
auto original_args = F->getArguments();
|
||||
for (auto &arg : original_args) {
|
||||
auto phi = builder->createPhiInst(arg->getType(), {}, {}, "tco.phi."+arg->getName());
|
||||
phis.push_back(phi);
|
||||
}
|
||||
|
||||
// 用 phi 节点替换所有原始参数的使用
|
||||
for (size_t i = 0; i < original_args.size(); ++i) {
|
||||
original_args[i]->replaceAllUsesWith(phis[i]);
|
||||
}
|
||||
|
||||
// 设置 phi 节点的输入值
|
||||
for (size_t i = 0; i < phis.size(); ++i) {
|
||||
phis[i]->addIncoming(original_args[i], new_entry);
|
||||
}
|
||||
|
||||
// 连接各个基本块
|
||||
builder->setPosition(original_entry, original_entry->end());
|
||||
builder->createUncondBrInst(new_entry);
|
||||
original_entry->addSuccessor(new_entry);
|
||||
|
||||
builder->setPosition(new_entry, new_entry->end());
|
||||
builder->createUncondBrInst(loop_header);
|
||||
new_entry->addSuccessor(loop_header);
|
||||
loop_header->addPredecessor(new_entry);
|
||||
|
||||
// 处理每一个尾递归调用
|
||||
for (auto callInst : tailCallInsts) {
|
||||
auto tail_call_block = callInst->getParent();
|
||||
|
||||
// 收集尾递归调用的参数
|
||||
auto args_range = callInst->getArguments();
|
||||
std::vector<Value*> args;
|
||||
std::transform(args_range.begin(), args_range.end(), std::back_inserter(args),
|
||||
[](auto& use_ptr){ return use_ptr->getValue(); });
|
||||
|
||||
// 用新的参数值更新 phi 节点
|
||||
for (size_t i = 0; i < phis.size(); ++i) {
|
||||
phis[i]->addIncoming(args[i], tail_call_block);
|
||||
}
|
||||
|
||||
// 移除原有的调用和返回指令
|
||||
auto term_iter = tail_call_block->terminator();
|
||||
SysYIROptUtils::usedelete(term_iter);
|
||||
auto call_iter = tail_call_block->findInstIterator(callInst);
|
||||
SysYIROptUtils::usedelete(call_iter);
|
||||
|
||||
// 添加跳转回循环头块的分支指令
|
||||
builder->setPosition(tail_call_block, tail_call_block->end());
|
||||
builder->createUncondBrInst(loop_header);
|
||||
tail_call_block->addSuccessor(loop_header);
|
||||
loop_header->addPredecessor(tail_call_block);
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
} // namespace sysy
|
||||
@ -10,13 +10,15 @@
|
||||
#include "DCE.h"
|
||||
#include "Mem2Reg.h"
|
||||
#include "Reg2Mem.h"
|
||||
#include "GVN.h"
|
||||
#include "SCCP.h"
|
||||
#include "BuildCFG.h"
|
||||
#include "LargeArrayToGlobal.h"
|
||||
#include "LoopNormalization.h"
|
||||
#include "LICM.h"
|
||||
#include "LoopStrengthReduction.h"
|
||||
#include "InductionVariableElimination.h"
|
||||
#include "GlobalStrengthReduction.h"
|
||||
#include "TailCallOpt.h"
|
||||
#include "Pass.h"
|
||||
#include <iostream>
|
||||
#include <queue>
|
||||
@ -58,7 +60,7 @@ void PassManager::runOptimizationPipeline(Module* moduleIR, IRBuilder* builderIR
|
||||
|
||||
// 注册优化遍
|
||||
registerOptimizationPass<BuildCFG>();
|
||||
registerOptimizationPass<LargeArrayToGlobalPass>();
|
||||
registerOptimizationPass<GVN>();
|
||||
|
||||
registerOptimizationPass<SysYDelInstAfterBrPass>();
|
||||
registerOptimizationPass<SysYDelNoPreBLockPass>();
|
||||
@ -74,7 +76,10 @@ void PassManager::runOptimizationPipeline(Module* moduleIR, IRBuilder* builderIR
|
||||
registerOptimizationPass<LICM>(builderIR);
|
||||
registerOptimizationPass<LoopStrengthReduction>(builderIR);
|
||||
registerOptimizationPass<InductionVariableElimination>();
|
||||
|
||||
registerOptimizationPass<GlobalStrengthReduction>(builderIR);
|
||||
registerOptimizationPass<Reg2Mem>(builderIR);
|
||||
registerOptimizationPass<TailCallOpt>(builderIR);
|
||||
|
||||
registerOptimizationPass<SCCP>(builderIR);
|
||||
|
||||
@ -90,7 +95,6 @@ void PassManager::runOptimizationPipeline(Module* moduleIR, IRBuilder* builderIR
|
||||
|
||||
this->clearPasses();
|
||||
this->addPass(&BuildCFG::ID);
|
||||
this->addPass(&LargeArrayToGlobalPass::ID);
|
||||
this->run();
|
||||
|
||||
this->clearPasses();
|
||||
@ -129,6 +133,25 @@ void PassManager::runOptimizationPipeline(Module* moduleIR, IRBuilder* builderIR
|
||||
printPasses();
|
||||
}
|
||||
|
||||
this->clearPasses();
|
||||
this->addPass(&GVN::ID);
|
||||
this->run();
|
||||
|
||||
this->clearPasses();
|
||||
this->addPass(&TailCallOpt::ID);
|
||||
this->run();
|
||||
|
||||
if(DEBUG) {
|
||||
std::cout << "=== IR After TailCallOpt ===\n";
|
||||
SysYPrinter printer(moduleIR);
|
||||
printer.printIR();
|
||||
}
|
||||
|
||||
if(DEBUG) {
|
||||
std::cout << "=== IR After GVN Optimizations ===\n";
|
||||
printPasses();
|
||||
}
|
||||
|
||||
this->clearPasses();
|
||||
this->addPass(&SCCP::ID);
|
||||
this->run();
|
||||
@ -141,19 +164,46 @@ void PassManager::runOptimizationPipeline(Module* moduleIR, IRBuilder* builderIR
|
||||
this->clearPasses();
|
||||
this->addPass(&LoopNormalizationPass::ID);
|
||||
this->addPass(&InductionVariableElimination::ID);
|
||||
this->addPass(&LICM::ID);
|
||||
this->addPass(&LoopStrengthReduction::ID);
|
||||
this->run();
|
||||
|
||||
if(DEBUG) {
|
||||
std::cout << "=== IR After Loop Normalization, LICM, and Strength Reduction Optimizations ===\n";
|
||||
std::cout << "=== IR After Loop Normalization, Induction Variable Elimination ===\n";
|
||||
printPasses();
|
||||
}
|
||||
|
||||
|
||||
this->clearPasses();
|
||||
this->addPass(&LICM::ID);
|
||||
this->run();
|
||||
|
||||
if(DEBUG) {
|
||||
std::cout << "=== IR After LICM ===\n";
|
||||
printPasses();
|
||||
}
|
||||
|
||||
// this->clearPasses();
|
||||
// this->addPass(&LoopStrengthReduction::ID);
|
||||
// this->run();
|
||||
|
||||
if(DEBUG) {
|
||||
std::cout << "=== IR After Loop Normalization, and Strength Reduction Optimizations ===\n";
|
||||
printPasses();
|
||||
}
|
||||
|
||||
// // 全局强度削弱优化,包括代数优化和魔数除法
|
||||
// this->clearPasses();
|
||||
// this->addPass(&Reg2Mem::ID);
|
||||
// this->addPass(&GlobalStrengthReduction::ID);
|
||||
// this->run();
|
||||
|
||||
if(DEBUG) {
|
||||
std::cout << "=== IR After Global Strength Reduction Optimizations ===\n";
|
||||
printPasses();
|
||||
}
|
||||
|
||||
this->clearPasses();
|
||||
this->addPass(&Reg2Mem::ID);
|
||||
this->run();
|
||||
|
||||
if(DEBUG) {
|
||||
std::cout << "=== IR After Reg2Mem Optimizations ===\n";
|
||||
printPasses();
|
||||
|
||||
@ -28,7 +28,7 @@ static string argStopAfter;
|
||||
static string argInputFile;
|
||||
static bool argFormat = false; // 目前未使用,但保留
|
||||
static string argOutputFilename;
|
||||
static int optLevel = 0; // 优化级别,默认为0 (不加-O参数时)
|
||||
int optLevel = 0; // 优化级别,默认为0 (不加-O参数时)
|
||||
|
||||
void usage(int code) {
|
||||
const char *msg = "Usage: sysyc [options] inputfile\n\n"
|
||||
|
||||
Reference in New Issue
Block a user