Compare commits
13 Commits
midend-Loo
...
midend
| Author | SHA1 | Date | |
|---|---|---|---|
| d72601d9db | |||
| 8094fd5705 | |||
| ad5f35c1a0 | |||
| 839791e862 | |||
| 751d3df2ac | |||
| 1d59e9e256 | |||
| db122cabbd | |||
| ce4d4b5f5b | |||
| 042b1a5d99 | |||
| 937833117e | |||
| ad74e435ba | |||
| 5c34cbc7b8 | |||
| c9a0c700e1 |
2
.gitignore
vendored
2
.gitignore
vendored
@ -36,7 +36,7 @@ doxygen
|
||||
|
||||
!/testdata/functional/*.out
|
||||
!/testdata/h_functional/*.out
|
||||
!/testdata/performance/*.out
|
||||
testdata/performance/
|
||||
build/
|
||||
.antlr
|
||||
.vscode/
|
||||
|
||||
@ -20,18 +20,19 @@ QEMU_RISCV64="qemu-riscv64"
|
||||
|
||||
# --- 初始化变量 ---
|
||||
EXECUTE_MODE=false
|
||||
IR_EXECUTE_MODE=false # 新增
|
||||
IR_EXECUTE_MODE=false
|
||||
CLEAN_MODE=false
|
||||
OPTIMIZE_FLAG=""
|
||||
SYSYC_TIMEOUT=30
|
||||
LLC_TIMEOUT=10 # 新增
|
||||
LLC_TIMEOUT=10
|
||||
GCC_TIMEOUT=10
|
||||
EXEC_TIMEOUT=30
|
||||
MAX_OUTPUT_LINES=20
|
||||
MAX_OUTPUT_CHARS=1000
|
||||
SY_FILES=()
|
||||
PASSED_CASES=0
|
||||
FAILED_CASES_LIST=""
|
||||
INTERRUPTED=false # 新增
|
||||
INTERRUPTED=false
|
||||
|
||||
# =================================================================
|
||||
# --- 函数定义 ---
|
||||
@ -50,22 +51,31 @@ show_help() {
|
||||
echo " -gct N 设置 gcc 交叉编译超时为 N 秒 (默认: 10)。"
|
||||
echo " -et N 设置 qemu 自动化执行超时为 N 秒 (默认: 30)。"
|
||||
echo " -ml N, --max-lines N 当输出对比失败时,最多显示 N 行内容 (默认: 20)。"
|
||||
echo " -mc N, --max-chars N 当输出对比失败时,最多显示 N 个字符 (默认: 1000)。"
|
||||
echo " -h, --help 显示此帮助信息并退出。"
|
||||
echo ""
|
||||
echo "可在任何时候按 Ctrl+C 来中断测试并显示当前已完成的测例总结。"
|
||||
}
|
||||
|
||||
# 显示文件内容并根据行数和字符数截断的函数
|
||||
display_file_content() {
|
||||
local file_path="$1"
|
||||
local title="$2"
|
||||
local max_lines="$3"
|
||||
local max_chars="$4" # 新增参数
|
||||
if [ ! -f "$file_path" ]; then return; fi
|
||||
echo -e "$title"
|
||||
local line_count
|
||||
local char_count
|
||||
line_count=$(wc -l < "$file_path")
|
||||
char_count=$(wc -c < "$file_path")
|
||||
|
||||
if [ "$line_count" -gt "$max_lines" ]; then
|
||||
head -n "$max_lines" "$file_path"
|
||||
echo -e "\e[33m[... 输出已截断,共 ${line_count} 行 ...]\e[0m"
|
||||
echo -e "\e[33m[... 输出因行数过多 (共 ${line_count} 行) 而截断 ...]\e[0m"
|
||||
elif [ "$char_count" -gt "$max_chars" ]; then
|
||||
head -c "$max_chars" "$file_path"
|
||||
echo -e "\n\e[33m[... 输出因字符数过多 (共 ${char_count} 字符) 而截断 ...]\e[0m"
|
||||
else
|
||||
cat "$file_path"
|
||||
fi
|
||||
@ -131,6 +141,7 @@ while [[ "$#" -gt 0 ]]; do
|
||||
-gct) if [[ -n "$2" && "$2" =~ ^[0-9]+$ ]]; then GCC_TIMEOUT="$2"; shift 2; else echo "错误: -gct 需要一个正整数参数。" >&2; exit 1; fi ;;
|
||||
-et) if [[ -n "$2" && "$2" =~ ^[0-9]+$ ]]; then EXEC_TIMEOUT="$2"; shift 2; else echo "错误: -et 需要一个正整数参数。" >&2; exit 1; fi ;;
|
||||
-ml|--max-lines) if [[ -n "$2" && "$2" =~ ^[0-9]+$ ]]; then MAX_OUTPUT_LINES="$2"; shift 2; else echo "错误: --max-lines 需要一个正整数参数。" >&2; exit 1; fi ;;
|
||||
-mc|--max-chars) if [[ -n "$2" && "$2" =~ ^[0-9]+$ ]]; then MAX_OUTPUT_CHARS="$2"; shift 2; else echo "错误: --max-chars 需要一个正整数参数。" >&2; exit 1; fi ;;
|
||||
-h|--help) show_help; exit 0 ;;
|
||||
-*) echo "未知选项: $1"; show_help; exit 1 ;;
|
||||
*)
|
||||
@ -180,6 +191,8 @@ TOTAL_CASES=${#SY_FILES[@]}
|
||||
echo "SysY 单例测试运行器启动..."
|
||||
if [ -n "$OPTIMIZE_FLAG" ]; then echo "优化等级: ${OPTIMIZE_FLAG}"; fi
|
||||
echo "超时设置: sysyc=${SYSYC_TIMEOUT}s, llc=${LLC_TIMEOUT}s, gcc=${GCC_TIMEOUT}s, qemu=${EXEC_TIMEOUT}s"
|
||||
echo "失败输出最大行数: ${MAX_OUTPUT_LINES}"
|
||||
echo "失败输出最大字符数: ${MAX_OUTPUT_CHARS}"
|
||||
echo ""
|
||||
|
||||
for sy_file in "${SY_FILES[@]}"; do
|
||||
@ -260,8 +273,8 @@ for sy_file in "${SY_FILES[@]}"; do
|
||||
out_ok=1
|
||||
if ! diff -q <(tr -d '[:space:]' < "${output_actual_file}") <(tr -d '[:space:]' < "${EXPECTED_STDOUT_FILE}") >/dev/null 2>&1; then
|
||||
echo -e "\e[31m 标准输出测试失败。\e[0m"; out_ok=0
|
||||
display_file_content "${EXPECTED_STDOUT_FILE}" " \e[36m--- 期望输出 ---\e[0m" "${MAX_OUTPUT_LINES}"
|
||||
display_file_content "${output_actual_file}" " \e[36m--- 实际输出 ---\e[0m" "${MAX_OUTPUT_LINES}"
|
||||
display_file_content "${EXPECTED_STDOUT_FILE}" " \e[36m--- 期望输出 ---\e[0m" "${MAX_OUTPUT_LINES}" "${MAX_OUTPUT_CHARS}"
|
||||
display_file_content "${output_actual_file}" " \e[36m--- 实际输出 ---\e[0m" "${MAX_OUTPUT_LINES}" "${MAX_OUTPUT_CHARS}"
|
||||
fi
|
||||
|
||||
if [ "$ret_ok" -eq 1 ] && [ "$out_ok" -eq 1 ]; then echo -e "\e[32m 返回码与标准输出测试成功。\e[0m"; else is_passed=0; fi
|
||||
@ -271,8 +284,8 @@ for sy_file in "${SY_FILES[@]}"; do
|
||||
echo -e "\e[32m 标准输出测试成功。\e[0m"
|
||||
else
|
||||
echo -e "\e[31m 标准输出测试失败。\e[0m"; is_passed=0
|
||||
display_file_content "${output_reference_file}" " \e[36m--- 期望输出 ---\e[0m" "${MAX_OUTPUT_LINES}"
|
||||
display_file_content "${output_actual_file}" " \e[36m--- 实际输出 ---\e[0m" "${MAX_OUTPUT_LINES}"
|
||||
display_file_content "${output_reference_file}" " \e[36m--- 期望输出 ---\e[0m" "${MAX_OUTPUT_LINES}" "${MAX_OUTPUT_CHARS}"
|
||||
display_file_content "${output_actual_file}" " \e[36m--- 实际输出 ---\e[0m" "${MAX_OUTPUT_LINES}" "${MAX_OUTPUT_CHARS}"
|
||||
fi
|
||||
fi
|
||||
else
|
||||
@ -301,4 +314,4 @@ for sy_file in "${SY_FILES[@]}"; do
|
||||
done
|
||||
|
||||
# --- 打印最终总结 ---
|
||||
print_summary
|
||||
print_summary
|
||||
|
||||
@ -27,11 +27,12 @@ LLC_TIMEOUT=10
|
||||
GCC_TIMEOUT=10
|
||||
EXEC_TIMEOUT=30
|
||||
MAX_OUTPUT_LINES=20
|
||||
MAX_OUTPUT_CHARS=1000
|
||||
TEST_SETS=()
|
||||
TOTAL_CASES=0
|
||||
PASSED_CASES=0
|
||||
FAILED_CASES_LIST=""
|
||||
INTERRUPTED=false # 新增:用于标记是否被中断
|
||||
INTERRUPTED=false
|
||||
|
||||
# =================================================================
|
||||
# --- 函数定义 ---
|
||||
@ -53,6 +54,7 @@ show_help() {
|
||||
echo " -gct N 设置 gcc 交叉编译超时为 N 秒 (默认: 10)。"
|
||||
echo " -et N 设置 qemu 执行超时为 N 秒 (默认: 30)。"
|
||||
echo " -ml N, --max-lines N 当输出对比失败时,最多显示 N 行内容 (默认: 20)。"
|
||||
echo " -mc N, --max-chars N 当输出对比失败时,最多显示 N 个字符 (默认: 1000)。"
|
||||
echo " -h, --help 显示此帮助信息并退出。"
|
||||
echo ""
|
||||
echo "注意: 默认行为 (无 -e 或 -eir) 是将 .sy 文件同时编译为 .s (汇编) 和 .ll (IR),不执行。"
|
||||
@ -60,18 +62,25 @@ show_help() {
|
||||
}
|
||||
|
||||
|
||||
# 显示文件内容并根据行数截断的函数
|
||||
# 显示文件内容并根据行数和字符数截断的函数
|
||||
display_file_content() {
|
||||
local file_path="$1"
|
||||
local title="$2"
|
||||
local max_lines="$3"
|
||||
local max_chars="$4" # 新增参数
|
||||
if [ ! -f "$file_path" ]; then return; fi
|
||||
echo -e "$title"
|
||||
local line_count
|
||||
local char_count
|
||||
line_count=$(wc -l < "$file_path")
|
||||
char_count=$(wc -c < "$file_path")
|
||||
|
||||
if [ "$line_count" -gt "$max_lines" ]; then
|
||||
head -n "$max_lines" "$file_path"
|
||||
echo -e "\e[33m[... 输出已截断,共 ${line_count} 行 ...]\e[0m"
|
||||
echo -e "\e[33m[... 输出因行数过多 (共 ${line_count} 行) 而截断 ...]\e[0m"
|
||||
elif [ "$char_count" -gt "$max_chars" ]; then
|
||||
head -c "$max_chars" "$file_path"
|
||||
echo -e "\n\e[33m[... 输出因字符数过多 (共 ${char_count} 字符) 而截断 ...]\e[0m"
|
||||
else
|
||||
cat "$file_path"
|
||||
fi
|
||||
@ -151,6 +160,7 @@ while [[ "$#" -gt 0 ]]; do
|
||||
-gct) if [[ -n "$2" && "$2" =~ ^[0-9]+$ ]]; then GCC_TIMEOUT="$2"; shift 2; else echo "错误: -gct 需要一个正整数参数。" >&2; exit 1; fi ;;
|
||||
-et) if [[ -n "$2" && "$2" =~ ^[0-9]+$ ]]; then EXEC_TIMEOUT="$2"; shift 2; else echo "错误: -et 需要一个正整数参数。" >&2; exit 1; fi ;;
|
||||
-ml|--max-lines) if [[ -n "$2" && "$2" =~ ^[0-9]+$ ]]; then MAX_OUTPUT_LINES="$2"; shift 2; else echo "错误: --max-lines 需要一个正整数参数。" >&2; exit 1; fi ;;
|
||||
-mc|--max-chars) if [[ -n "$2" && "$2" =~ ^[0-9]+$ ]]; then MAX_OUTPUT_CHARS="$2"; shift 2; else echo "错误: --max-chars 需要一个正整数参数。" >&2; exit 1; fi ;;
|
||||
-h|--help) show_help; exit 0 ;;
|
||||
*) echo "未知选项: $1"; show_help; exit 1 ;;
|
||||
esac
|
||||
@ -204,6 +214,7 @@ echo "运行模式: ${RUN_MODE_INFO}"
|
||||
echo "${TIMEOUT_INFO}"
|
||||
if ${EXECUTE_MODE} || ${IR_EXECUTE_MODE}; then
|
||||
echo "失败输出最大行数: ${MAX_OUTPUT_LINES}"
|
||||
echo "失败输出最大字符数: ${MAX_OUTPUT_CHARS}"
|
||||
fi
|
||||
echo ""
|
||||
|
||||
@ -298,8 +309,8 @@ while IFS= read -r sy_file; do
|
||||
[ "$test_logic_passed" -eq 1 ] && echo -e "\e[32m 标准输出测试成功\e[0m"
|
||||
else
|
||||
echo -e "\e[31m 标准输出测试失败\e[0m"
|
||||
display_file_content "${EXPECTED_STDOUT_FILE}" " \e[36m---------- 期望输出 ----------\e[0m" "${MAX_OUTPUT_LINES}"
|
||||
display_file_content "${output_actual_file_from_ir}" " \e[36m---------- 实际输出 ----------\e[0m" "${MAX_OUTPUT_LINES}"
|
||||
display_file_content "${EXPECTED_STDOUT_FILE}" " \e[36m---------- 期望输出 ----------\e[0m" "${MAX_OUTPUT_LINES}" "${MAX_OUTPUT_CHARS}"
|
||||
display_file_content "${output_actual_file_from_ir}" " \e[36m---------- 实际输出 ----------\e[0m" "${MAX_OUTPUT_LINES}" "${MAX_OUTPUT_CHARS}"
|
||||
test_logic_passed=0
|
||||
fi
|
||||
else
|
||||
@ -308,8 +319,8 @@ while IFS= read -r sy_file; do
|
||||
echo -e "\e[32m 成功: 输出与参考输出匹配\e[0m"
|
||||
else
|
||||
echo -e "\e[31m 失败: 输出不匹配\e[0m"
|
||||
display_file_content "${output_reference_file}" " \e[36m---------- 期望输出 ----------\e[0m" "${MAX_OUTPUT_LINES}"
|
||||
display_file_content "${output_actual_file_from_ir}" " \e[36m---------- 实际输出 ----------\e[0m" "${MAX_OUTPUT_LINES}"
|
||||
display_file_content "${output_reference_file}" " \e[36m---------- 期望输出 ----------\e[0m" "${MAX_OUTPUT_LINES}" "${MAX_OUTPUT_CHARS}"
|
||||
display_file_content "${output_actual_file_from_ir}" " \e[36m---------- 实际输出 ----------\e[0m" "${MAX_OUTPUT_LINES}" "${MAX_OUTPUT_CHARS}"
|
||||
test_logic_passed=0
|
||||
fi
|
||||
fi
|
||||
@ -375,8 +386,8 @@ while IFS= read -r sy_file; do
|
||||
[ "$test_logic_passed" -eq 1 ] && echo -e "\e[32m 标准输出测试成功\e[0m"
|
||||
else
|
||||
echo -e "\e[31m 标准输出测试失败\e[0m"
|
||||
display_file_content "${EXPECTED_STDOUT_FILE}" " \e[36m---------- 期望输出 ----------\e[0m" "${MAX_OUTPUT_LINES}"
|
||||
display_file_content "${output_actual_file_S}" " \e[36m---------- 实际输出 ----------\e[0m" "${MAX_OUTPUT_LINES}"
|
||||
display_file_content "${EXPECTED_STDOUT_FILE}" " \e[36m---------- 期望输出 ----------\e[0m" "${MAX_OUTPUT_LINES}" "${MAX_OUTPUT_CHARS}"
|
||||
display_file_content "${output_actual_file_S}" " \e[36m---------- 实际输出 ----------\e[0m" "${MAX_OUTPUT_LINES}" "${MAX_OUTPUT_CHARS}"
|
||||
test_logic_passed=0
|
||||
fi
|
||||
else
|
||||
@ -385,8 +396,8 @@ while IFS= read -r sy_file; do
|
||||
echo -e "\e[32m 成功: 输出与参考输出匹配\e[0m"
|
||||
else
|
||||
echo -e "\e[31m 失败: 输出不匹配\e[0m"
|
||||
display_file_content "${output_reference_file}" " \e[36m---------- 期望输出 ----------\e[0m" "${MAX_OUTPUT_LINES}"
|
||||
display_file_content "${output_actual_file_S}" " \e[36m---------- 实际输出 ----------\e[0m" "${MAX_OUTPUT_LINES}"
|
||||
display_file_content "${output_reference_file}" " \e[36m---------- 期望输出 ----------\e[0m" "${MAX_OUTPUT_LINES}" "${MAX_OUTPUT_CHARS}"
|
||||
display_file_content "${output_actual_file_S}" " \e[36m---------- 实际输出 ----------\e[0m" "${MAX_OUTPUT_LINES}" "${MAX_OUTPUT_CHARS}"
|
||||
test_logic_passed=0
|
||||
fi
|
||||
fi
|
||||
|
||||
@ -634,6 +634,22 @@ void PeepholeOptimizer::runOnMachineFunction(MachineFunction *mfunc) {
|
||||
}
|
||||
}
|
||||
}
|
||||
// 8. 消除无用移动指令: mv a, a -> (删除)
|
||||
else if (mi1->getOpcode() == RVOpcodes::MV &&
|
||||
mi1->getOperands().size() == 2) {
|
||||
if (mi1->getOperands()[0]->getKind() == MachineOperand::KIND_REG &&
|
||||
mi1->getOperands()[1]->getKind() == MachineOperand::KIND_REG) {
|
||||
auto *dst = static_cast<RegOperand *>(mi1->getOperands()[0].get());
|
||||
auto *src = static_cast<RegOperand *>(mi1->getOperands()[1].get());
|
||||
|
||||
// 检查源和目标寄存器是否相同
|
||||
if (areRegsEqual(dst, src)) {
|
||||
// 删除这条无用指令
|
||||
instrs.erase(instrs.begin() + i);
|
||||
changed = true;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 根据是否发生变化调整遍历索引
|
||||
if (!changed) {
|
||||
|
||||
@ -1007,6 +1007,7 @@ class PhiInst : public Instruction {
|
||||
void replaceIncomingBlock(BasicBlock *oldBlock, BasicBlock *newBlock, Value *newValue);
|
||||
void refreshMap() {
|
||||
blk2val.clear();
|
||||
vsize = getNumOperands() / 2;
|
||||
for (unsigned i = 0; i < vsize; ++i) {
|
||||
blk2val[getIncomingBlock(i)] = getIncomingValue(i);
|
||||
}
|
||||
|
||||
107
src/include/midend/Pass/Optimize/GlobalStrengthReduction.h
Normal file
107
src/include/midend/Pass/Optimize/GlobalStrengthReduction.h
Normal file
@ -0,0 +1,107 @@
|
||||
#pragma once
|
||||
|
||||
#include "Pass.h"
|
||||
#include "IR.h"
|
||||
#include "SideEffectAnalysis.h"
|
||||
#include <unordered_map>
|
||||
#include <unordered_set>
|
||||
#include <vector>
|
||||
#include <cstdint>
|
||||
|
||||
namespace sysy {
|
||||
|
||||
// 魔数乘法结构,用于除法优化
|
||||
struct MagicNumber {
|
||||
uint32_t multiplier;
|
||||
int shift;
|
||||
bool needAdd;
|
||||
|
||||
MagicNumber(uint32_t m, int s, bool add = false)
|
||||
: multiplier(m), shift(s), needAdd(add) {}
|
||||
};
|
||||
|
||||
// 全局强度削弱优化遍的核心逻辑封装类
|
||||
class GlobalStrengthReductionContext {
|
||||
public:
|
||||
// 构造函数,接受IRBuilder参数
|
||||
explicit GlobalStrengthReductionContext(IRBuilder* builder) : builder(builder) {}
|
||||
|
||||
// 运行优化的主要方法
|
||||
void run(Function* func, AnalysisManager* AM, bool& changed);
|
||||
|
||||
private:
|
||||
IRBuilder* builder; // IR构建器
|
||||
|
||||
// 分析结果
|
||||
SideEffectAnalysisResult* sideEffectAnalysis = nullptr;
|
||||
|
||||
// 优化计数
|
||||
int algebraicOptCount = 0;
|
||||
int strengthReductionCount = 0;
|
||||
int divisionOptCount = 0;
|
||||
|
||||
// 主要优化方法
|
||||
bool processBasicBlock(BasicBlock* bb);
|
||||
bool processInstruction(Instruction* inst);
|
||||
|
||||
// 代数优化方法
|
||||
bool tryAlgebraicOptimization(Instruction* inst);
|
||||
bool optimizeAddition(BinaryInst* inst);
|
||||
bool optimizeSubtraction(BinaryInst* inst);
|
||||
bool optimizeMultiplication(BinaryInst* inst);
|
||||
bool optimizeDivision(BinaryInst* inst);
|
||||
bool optimizeComparison(BinaryInst* inst);
|
||||
bool optimizeLogical(BinaryInst* inst);
|
||||
|
||||
// 强度削弱方法
|
||||
bool tryStrengthReduction(Instruction* inst);
|
||||
bool reduceMultiplication(BinaryInst* inst);
|
||||
bool reduceDivision(BinaryInst* inst);
|
||||
bool reducePower(CallInst* inst);
|
||||
|
||||
// 复杂乘法强度削弱方法
|
||||
bool tryComplexMultiplication(BinaryInst* inst, Value* variable, int constant);
|
||||
bool findOptimalShiftDecomposition(int constant, std::vector<int>& shifts);
|
||||
Value* createShiftDecomposition(BinaryInst* inst, Value* variable, const std::vector<int>& shifts);
|
||||
|
||||
// 魔数乘法相关方法
|
||||
MagicNumber computeMagicNumber(uint32_t divisor);
|
||||
std::pair<int, int> computeMulhMagicNumbers(int divisor);
|
||||
Value* createMagicDivision(BinaryInst* divInst, uint32_t divisor, const MagicNumber& magic);
|
||||
Value* createMagicDivisionLibdivide(BinaryInst* divInst, int divisor);
|
||||
bool isPowerOfTwo(uint32_t n);
|
||||
int log2OfPowerOfTwo(uint32_t n);
|
||||
|
||||
// 辅助方法
|
||||
bool isConstantInt(Value* val, int& constVal);
|
||||
bool isConstantInt(Value* val, uint32_t& constVal);
|
||||
ConstantInteger* getConstantInt(int val);
|
||||
bool hasOnlyLocalUses(Instruction* inst);
|
||||
void replaceWithOptimized(Instruction* original, Value* replacement);
|
||||
};
|
||||
|
||||
// 全局强度削弱优化遍类
|
||||
class GlobalStrengthReduction : public OptimizationPass {
|
||||
private:
|
||||
IRBuilder* builder; // IR构建器,用于创建新指令
|
||||
|
||||
public:
|
||||
// 静态成员,作为该遍的唯一ID
|
||||
static void* ID;
|
||||
|
||||
// 构造函数,接受IRBuilder参数
|
||||
explicit GlobalStrengthReduction(IRBuilder* builder)
|
||||
: OptimizationPass("GlobalStrengthReduction", Granularity::Function), builder(builder) {}
|
||||
|
||||
// 在函数上运行优化
|
||||
bool runOnFunction(Function* func, AnalysisManager& AM) override;
|
||||
|
||||
// 返回该遍的唯一ID
|
||||
void* getPassID() const override { return ID; }
|
||||
|
||||
// 声明分析依赖
|
||||
void getAnalysisUsage(std::set<void*>& analysisDependencies,
|
||||
std::set<void*>& analysisInvalidations) const override;
|
||||
};
|
||||
|
||||
} // namespace sysy
|
||||
@ -127,13 +127,6 @@ private:
|
||||
*/
|
||||
bool analyzeInductionVariableRange(const InductionVarInfo* ivInfo, Loop* loop) const;
|
||||
|
||||
/**
|
||||
* 计算用于除法优化的魔数和移位量
|
||||
* @param divisor 除数
|
||||
* @return {魔数, 移位量}
|
||||
*/
|
||||
std::pair<int, int> computeMulhMagicNumbers(int divisor) const;
|
||||
|
||||
/**
|
||||
* 生成除法替换代码
|
||||
* @param candidate 优化候选项
|
||||
|
||||
@ -107,6 +107,218 @@ public:
|
||||
// 所以当AllocaInst的basetype是PointerType时(一维数组)或者是指向ArrayType的PointerType(多位数组)时,返回true
|
||||
return aval && (baseType->isPointer() || baseType->as<PointerType>()->getBaseType()->isArray());
|
||||
}
|
||||
|
||||
|
||||
// PHI指令消除相关方法
|
||||
static bool eliminateRedundantPhisInFunction(Function* func){
|
||||
bool changed = false;
|
||||
std::vector<Instruction *> toDelete;
|
||||
for (auto &bb : func->getBasicBlocks()) {
|
||||
for (auto &inst : bb->getInstructions()) {
|
||||
if (auto phi = dynamic_cast<PhiInst *>(inst.get())) {
|
||||
auto incoming = phi->getIncomingValues();
|
||||
if(DEBUG){
|
||||
std::cout << "Checking Phi: " << phi->getName() << " with " << incoming.size() << " incoming values." << std::endl;
|
||||
}
|
||||
if (incoming.size() == 1) {
|
||||
Value *singleVal = incoming[0].second;
|
||||
inst->replaceAllUsesWith(singleVal);
|
||||
toDelete.push_back(inst.get());
|
||||
}
|
||||
}
|
||||
else
|
||||
break; // 只处理Phi指令
|
||||
}
|
||||
}
|
||||
for (auto *phi : toDelete) {
|
||||
usedelete(phi);
|
||||
changed = true; // 标记为已更改
|
||||
}
|
||||
return changed; // 返回是否有删除发生
|
||||
}
|
||||
|
||||
//该实现参考了libdivide的算法
|
||||
static std::pair<int, int> computeMulhMagicNumbers(int divisor) {
|
||||
|
||||
if (DEBUG) {
|
||||
std::cout << "\n[SR] ===== Computing magic numbers for divisor " << divisor << " (libdivide algorithm) =====" << std::endl;
|
||||
}
|
||||
|
||||
if (divisor == 0) {
|
||||
if (DEBUG) std::cout << "[SR] Error: divisor must be != 0" << std::endl;
|
||||
return {-1, -1};
|
||||
}
|
||||
|
||||
// libdivide 常数
|
||||
const uint8_t LIBDIVIDE_ADD_MARKER = 0x40;
|
||||
const uint8_t LIBDIVIDE_NEGATIVE_DIVISOR = 0x80;
|
||||
|
||||
// 辅助函数:计算前导零个数
|
||||
auto count_leading_zeros32 = [](uint32_t val) -> uint32_t {
|
||||
if (val == 0) return 32;
|
||||
return __builtin_clz(val);
|
||||
};
|
||||
|
||||
// 辅助函数:64位除法返回32位商和余数
|
||||
auto div_64_32 = [](uint32_t high, uint32_t low, uint32_t divisor, uint32_t* rem) -> uint32_t {
|
||||
uint64_t dividend = ((uint64_t)high << 32) | low;
|
||||
uint32_t quotient = dividend / divisor;
|
||||
*rem = dividend % divisor;
|
||||
return quotient;
|
||||
};
|
||||
|
||||
if (DEBUG) {
|
||||
std::cout << "[SR] Input divisor: " << divisor << std::endl;
|
||||
}
|
||||
|
||||
// libdivide_internal_s32_gen 算法实现
|
||||
int32_t d = divisor;
|
||||
uint32_t ud = (uint32_t)d;
|
||||
uint32_t absD = (d < 0) ? -ud : ud;
|
||||
|
||||
if (DEBUG) {
|
||||
std::cout << "[SR] absD = " << absD << std::endl;
|
||||
}
|
||||
|
||||
uint32_t floor_log_2_d = 31 - count_leading_zeros32(absD);
|
||||
|
||||
if (DEBUG) {
|
||||
std::cout << "[SR] floor_log_2_d = " << floor_log_2_d << std::endl;
|
||||
}
|
||||
|
||||
// 检查 absD 是否为2的幂
|
||||
if ((absD & (absD - 1)) == 0) {
|
||||
if (DEBUG) {
|
||||
std::cout << "[SR] " << absD << " 是2的幂,使用移位方法" << std::endl;
|
||||
}
|
||||
|
||||
// 对于2的幂,我们只使用移位,不需要魔数
|
||||
int shift = floor_log_2_d;
|
||||
if (d < 0) shift |= 0x80; // 标记负数
|
||||
|
||||
if (DEBUG) {
|
||||
std::cout << "[SR] Power of 2 result: magic=0, shift=" << shift << std::endl;
|
||||
std::cout << "[SR] ===== End magic computation =====" << std::endl;
|
||||
}
|
||||
|
||||
// 对于我们的目的,我们将在IR生成中以不同方式处理2的幂
|
||||
// 返回特殊标记
|
||||
return {0, shift};
|
||||
}
|
||||
|
||||
if (DEBUG) {
|
||||
std::cout << "[SR] " << absD << " is not a power of 2, computing magic number" << std::endl;
|
||||
}
|
||||
|
||||
// 非2的幂除数的魔数计算
|
||||
uint8_t more;
|
||||
uint32_t rem, proposed_m;
|
||||
|
||||
// 计算 proposed_m = floor(2^(floor_log_2_d + 31) / absD)
|
||||
proposed_m = div_64_32((uint32_t)1 << (floor_log_2_d - 1), 0, absD, &rem);
|
||||
const uint32_t e = absD - rem;
|
||||
|
||||
if (DEBUG) {
|
||||
std::cout << "[SR] proposed_m = " << proposed_m << ", rem = " << rem << ", e = " << e << std::endl;
|
||||
}
|
||||
|
||||
// 确定是否需要"加法"版本
|
||||
const bool branchfree = false; // 使用分支版本
|
||||
|
||||
if (!branchfree && e < ((uint32_t)1 << floor_log_2_d)) {
|
||||
// 这个幂次有效
|
||||
more = (uint8_t)(floor_log_2_d - 1);
|
||||
if (DEBUG) {
|
||||
std::cout << "[SR] Using basic algorithm, shift = " << (int)more << std::endl;
|
||||
}
|
||||
} else {
|
||||
// 我们需要上升一个等级
|
||||
proposed_m += proposed_m;
|
||||
const uint32_t twice_rem = rem + rem;
|
||||
if (twice_rem >= absD || twice_rem < rem) {
|
||||
proposed_m += 1;
|
||||
}
|
||||
more = (uint8_t)(floor_log_2_d | LIBDIVIDE_ADD_MARKER);
|
||||
if (DEBUG) {
|
||||
std::cout << "[SR] Using add algorithm, proposed_m = " << proposed_m << ", more = " << (int)more << std::endl;
|
||||
}
|
||||
}
|
||||
|
||||
proposed_m += 1;
|
||||
int32_t magic = (int32_t)proposed_m;
|
||||
|
||||
// 处理负除数
|
||||
if (d < 0) {
|
||||
more |= LIBDIVIDE_NEGATIVE_DIVISOR;
|
||||
if (!branchfree) {
|
||||
magic = -magic;
|
||||
}
|
||||
if (DEBUG) {
|
||||
std::cout << "[SR] Negative divisor, magic = " << magic << ", more = " << (int)more << std::endl;
|
||||
}
|
||||
}
|
||||
|
||||
// 为我们的IR生成提取移位量和标志
|
||||
int shift = more & 0x3F; // 移除标志,保留移位量(位0-5)
|
||||
bool need_add = (more & LIBDIVIDE_ADD_MARKER) != 0;
|
||||
bool is_negative = (more & LIBDIVIDE_NEGATIVE_DIVISOR) != 0;
|
||||
|
||||
if (DEBUG) {
|
||||
std::cout << "[SR] Final result: magic = " << magic << ", more = " << (int)more
|
||||
<< " (0x" << std::hex << (int)more << std::dec << ")" << std::endl;
|
||||
std::cout << "[SR] Shift = " << shift << ", need_add = " << need_add
|
||||
<< ", is_negative = " << is_negative << std::endl;
|
||||
|
||||
// Test the magic number using the correct libdivide algorithm
|
||||
std::cout << "[SR] Testing magic number (libdivide algorithm):" << std::endl;
|
||||
int test_values[] = {1, 7, 37, 100, 999, -1, -7, -37, -100};
|
||||
|
||||
for (int test_val : test_values) {
|
||||
int64_t quotient;
|
||||
|
||||
// 实现正确的libdivide算法
|
||||
int64_t product = (int64_t)test_val * magic;
|
||||
int64_t high_bits = product >> 32;
|
||||
|
||||
if (need_add) {
|
||||
// ADD_MARKER情况:移位前加上被除数
|
||||
// 这是libdivide的关键洞察!
|
||||
high_bits += test_val;
|
||||
quotient = high_bits >> shift;
|
||||
} else {
|
||||
// 正常情况:只是移位
|
||||
quotient = high_bits >> shift;
|
||||
}
|
||||
|
||||
// 符号修正:这是libdivide有符号除法的关键部分!
|
||||
// 如果被除数为负,商需要加1来匹配C语言的截断除法语义
|
||||
if (test_val < 0) {
|
||||
quotient += 1;
|
||||
}
|
||||
|
||||
int expected = test_val / divisor;
|
||||
|
||||
bool correct = (quotient == expected);
|
||||
std::cout << "[SR] " << test_val << " / " << divisor << " = " << quotient
|
||||
<< " (expected " << expected << ") " << (correct ? "✓" : "✗") << std::endl;
|
||||
}
|
||||
|
||||
std::cout << "[SR] ===== End magic computation =====" << std::endl;
|
||||
}
|
||||
|
||||
// 返回魔数、移位量,并在移位中编码ADD_MARKER标志
|
||||
// 我们将使用移位的第6位表示ADD_MARKER,第7位表示负数(如果需要)
|
||||
int encoded_shift = shift;
|
||||
if (need_add) {
|
||||
encoded_shift |= 0x40; // 设置第6位表示ADD_MARKER
|
||||
if (DEBUG) {
|
||||
std::cout << "[SR] Encoding ADD_MARKER in shift: " << encoded_shift << std::endl;
|
||||
}
|
||||
}
|
||||
|
||||
return {magic, encoded_shift};
|
||||
}
|
||||
|
||||
};
|
||||
|
||||
}// namespace sysy
|
||||
39
src/include/midend/Pass/Optimize/TailCallOpt.h
Normal file
39
src/include/midend/Pass/Optimize/TailCallOpt.h
Normal file
@ -0,0 +1,39 @@
|
||||
#pragma once
|
||||
|
||||
#include "Pass.h"
|
||||
#include "Dom.h"
|
||||
#include "Loop.h"
|
||||
|
||||
namespace sysy {
|
||||
|
||||
/**
|
||||
* @class TailCallOpt
|
||||
* @brief 优化尾调用的中端优化通道。
|
||||
*
|
||||
* 该类实现了一个针对函数级别的尾调用优化的优化通道(OptimizationPass)。
|
||||
* 通过分析和转换 IR(中间表示),将可优化的尾调用转换为更高效的形式,
|
||||
* 以减少函数调用的开销,提升程序性能。
|
||||
*
|
||||
* @note 需要传入 IRBuilder 指针用于 IR 构建和修改。
|
||||
*
|
||||
* @method runOnFunction
|
||||
* 对指定函数进行尾调用优化。
|
||||
*
|
||||
* @method getPassID
|
||||
* 获取当前优化通道的唯一标识符。
|
||||
*
|
||||
* @method getAnalysisUsage
|
||||
* 指定该优化通道所依赖和失效的分析集合。
|
||||
*/
|
||||
class TailCallOpt : public OptimizationPass {
|
||||
private:
|
||||
IRBuilder* builder;
|
||||
public:
|
||||
TailCallOpt(IRBuilder* builder) : OptimizationPass("TailCallOpt", Granularity::Function), builder(builder) {}
|
||||
static void *ID;
|
||||
bool runOnFunction(Function *F, AnalysisManager &AM) override;
|
||||
void *getPassID() const override { return &ID; }
|
||||
void getAnalysisUsage(std::set<void *> &analysisDependencies, std::set<void *> &analysisInvalidations) const override;
|
||||
};
|
||||
|
||||
} // namespace sysy
|
||||
@ -22,8 +22,10 @@ add_library(midend_lib STATIC
|
||||
Pass/Optimize/LICM.cpp
|
||||
Pass/Optimize/LoopStrengthReduction.cpp
|
||||
Pass/Optimize/InductionVariableElimination.cpp
|
||||
Pass/Optimize/GlobalStrengthReduction.cpp
|
||||
Pass/Optimize/BuildCFG.cpp
|
||||
Pass/Optimize/LargeArrayToGlobal.cpp
|
||||
Pass/Optimize/TailCallOpt.cpp
|
||||
)
|
||||
|
||||
# 包含中端模块所需的头文件路径
|
||||
|
||||
@ -757,7 +757,7 @@ void BinaryInst::print(std::ostream &os) const {
|
||||
auto lhs_hash = std::hash<const void*>{}(static_cast<const void*>(getLhs()));
|
||||
auto rhs_hash = std::hash<const void*>{}(static_cast<const void*>(getRhs()));
|
||||
size_t combined_hash = inst_hash ^ (lhs_hash << 1) ^ (rhs_hash << 2);
|
||||
std::string tmpName = "tmp_icmp_" + std::to_string(combined_hash % 1000000);
|
||||
std::string tmpName = "tmp_icmp_" + std::to_string(combined_hash % 1000000007);
|
||||
os << "%" << tmpName << " = " << getKindString() << " " << *getLhs()->getType() << " ";
|
||||
printOperand(os, getLhs());
|
||||
os << ", ";
|
||||
@ -772,7 +772,7 @@ void BinaryInst::print(std::ostream &os) const {
|
||||
auto lhs_hash = std::hash<const void*>{}(static_cast<const void*>(getLhs()));
|
||||
auto rhs_hash = std::hash<const void*>{}(static_cast<const void*>(getRhs()));
|
||||
size_t combined_hash = inst_hash ^ (lhs_hash << 1) ^ (rhs_hash << 2);
|
||||
std::string tmpName = "tmp_fcmp_" + std::to_string(combined_hash % 1000000);
|
||||
std::string tmpName = "tmp_fcmp_" + std::to_string(combined_hash % 1000000007);
|
||||
os << "%" << tmpName << " = " << getKindString() << " " << *getLhs()->getType() << " ";
|
||||
printOperand(os, getLhs());
|
||||
os << ", ";
|
||||
@ -834,7 +834,7 @@ void CondBrInst::print(std::ostream &os) const {
|
||||
if (condName.empty()) {
|
||||
// 使用条件值地址的哈希值作为唯一标识
|
||||
auto ptr_hash = std::hash<const void*>{}(static_cast<const void*>(condition));
|
||||
condName = "const_" + std::to_string(ptr_hash % 100000);
|
||||
condName = "const_" + std::to_string(ptr_hash % 1000000007);
|
||||
}
|
||||
|
||||
// 组合指令地址、条件地址和目标块地址的哈希来确保唯一性
|
||||
@ -843,7 +843,7 @@ void CondBrInst::print(std::ostream &os) const {
|
||||
auto then_hash = std::hash<const void*>{}(static_cast<const void*>(getThenBlock()));
|
||||
auto else_hash = std::hash<const void*>{}(static_cast<const void*>(getElseBlock()));
|
||||
size_t combined_hash = inst_hash ^ (cond_hash << 1) ^ (then_hash << 2) ^ (else_hash << 3);
|
||||
std::string uniqueSuffix = std::to_string(combined_hash % 1000000);
|
||||
std::string uniqueSuffix = std::to_string(combined_hash % 1000000007);
|
||||
|
||||
os << "%tmp_cond_" << condName << "_" << uniqueSuffix << " = icmp ne i32 ";
|
||||
printOperand(os, condition);
|
||||
|
||||
@ -74,6 +74,7 @@ void DCEContext::run(Function *func, AnalysisManager *AM, bool &changed) {
|
||||
}
|
||||
}
|
||||
}
|
||||
changed |= SysYIROptUtils::eliminateRedundantPhisInFunction(func); // 如果有活跃指令,则标记为已更改
|
||||
}
|
||||
|
||||
// 判断指令是否是"天然活跃"的实现
|
||||
|
||||
@ -39,7 +39,7 @@ bool GVN::runOnFunction(Function *func, AnalysisManager &AM) {
|
||||
}
|
||||
std::cout << "=== GVN completed for function: " << func->getName() << " ===" << std::endl;
|
||||
}
|
||||
|
||||
changed |= SysYIROptUtils::eliminateRedundantPhisInFunction(func);
|
||||
return changed;
|
||||
}
|
||||
|
||||
|
||||
897
src/midend/Pass/Optimize/GlobalStrengthReduction.cpp
Normal file
897
src/midend/Pass/Optimize/GlobalStrengthReduction.cpp
Normal file
@ -0,0 +1,897 @@
|
||||
#include "GlobalStrengthReduction.h"
|
||||
#include "SysYIROptUtils.h"
|
||||
#include "IRBuilder.h"
|
||||
#include <algorithm>
|
||||
#include <cassert>
|
||||
#include <iostream>
|
||||
#include <cmath>
|
||||
|
||||
extern int DEBUG;
|
||||
|
||||
namespace sysy {
|
||||
|
||||
// 全局强度削弱优化遍的静态 ID
|
||||
void *GlobalStrengthReduction::ID = (void *)&GlobalStrengthReduction::ID;
|
||||
|
||||
// ======================================================================
|
||||
// GlobalStrengthReduction 类的实现
|
||||
// ======================================================================
|
||||
|
||||
bool GlobalStrengthReduction::runOnFunction(Function *func, AnalysisManager &AM) {
|
||||
if (func->getBasicBlocks().empty()) {
|
||||
return false;
|
||||
}
|
||||
|
||||
if (DEBUG) {
|
||||
std::cout << "\n=== Running GlobalStrengthReduction on function: " << func->getName() << " ===" << std::endl;
|
||||
}
|
||||
|
||||
bool changed = false;
|
||||
GlobalStrengthReductionContext context(builder);
|
||||
context.run(func, &AM, changed);
|
||||
|
||||
if (DEBUG) {
|
||||
if (changed) {
|
||||
std::cout << "GlobalStrengthReduction: Function " << func->getName() << " was modified" << std::endl;
|
||||
} else {
|
||||
std::cout << "GlobalStrengthReduction: Function " << func->getName() << " was not modified" << std::endl;
|
||||
}
|
||||
std::cout << "=== GlobalStrengthReduction completed for function: " << func->getName() << " ===" << std::endl;
|
||||
}
|
||||
|
||||
return changed;
|
||||
}
|
||||
|
||||
void GlobalStrengthReduction::getAnalysisUsage(std::set<void *> &analysisDependencies, std::set<void *> &analysisInvalidations) const {
|
||||
// 强度削弱依赖副作用分析来判断指令是否可以安全优化
|
||||
analysisDependencies.insert(&SysYSideEffectAnalysisPass::ID);
|
||||
|
||||
// 强度削弱不会使分析失效,因为:
|
||||
// - 只替换计算指令,不改变控制流
|
||||
// - 不修改内存,不影响别名分析
|
||||
// - 保持程序语义不变
|
||||
// analysisInvalidations 保持为空
|
||||
|
||||
if (DEBUG) {
|
||||
std::cout << "GlobalStrengthReduction: Declared analysis dependencies (SideEffectAnalysis)" << std::endl;
|
||||
}
|
||||
}
|
||||
|
||||
// ======================================================================
|
||||
// GlobalStrengthReductionContext 类的实现
|
||||
// ======================================================================
|
||||
|
||||
void GlobalStrengthReductionContext::run(Function *func, AnalysisManager *AM, bool &changed) {
|
||||
if (DEBUG) {
|
||||
std::cout << " Starting GlobalStrengthReduction analysis for function: " << func->getName() << std::endl;
|
||||
}
|
||||
|
||||
// 获取分析结果
|
||||
if (AM) {
|
||||
sideEffectAnalysis = AM->getAnalysisResult<SideEffectAnalysisResult, SysYSideEffectAnalysisPass>();
|
||||
|
||||
if (DEBUG) {
|
||||
if (sideEffectAnalysis) {
|
||||
std::cout << " GlobalStrengthReduction: Using side effect analysis" << std::endl;
|
||||
} else {
|
||||
std::cout << " GlobalStrengthReduction: Warning - side effect analysis not available" << std::endl;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 重置计数器
|
||||
algebraicOptCount = 0;
|
||||
strengthReductionCount = 0;
|
||||
divisionOptCount = 0;
|
||||
|
||||
// 遍历所有基本块进行优化
|
||||
for (auto &bb_ptr : func->getBasicBlocks()) {
|
||||
if (processBasicBlock(bb_ptr.get())) {
|
||||
changed = true;
|
||||
}
|
||||
}
|
||||
|
||||
if (DEBUG) {
|
||||
std::cout << " GlobalStrengthReduction completed for function: " << func->getName() << std::endl;
|
||||
std::cout << " Algebraic optimizations: " << algebraicOptCount << std::endl;
|
||||
std::cout << " Strength reductions: " << strengthReductionCount << std::endl;
|
||||
std::cout << " Division optimizations: " << divisionOptCount << std::endl;
|
||||
}
|
||||
}
|
||||
|
||||
bool GlobalStrengthReductionContext::processBasicBlock(BasicBlock *bb) {
|
||||
bool changed = false;
|
||||
|
||||
if (DEBUG) {
|
||||
std::cout << " Processing block: " << bb->getName() << std::endl;
|
||||
}
|
||||
|
||||
// 收集需要处理的指令(避免迭代器失效)
|
||||
std::vector<Instruction*> instructions;
|
||||
for (auto &inst_ptr : bb->getInstructions()) {
|
||||
instructions.push_back(inst_ptr.get());
|
||||
}
|
||||
|
||||
// 处理每条指令
|
||||
for (auto inst : instructions) {
|
||||
if (processInstruction(inst)) {
|
||||
changed = true;
|
||||
}
|
||||
}
|
||||
|
||||
return changed;
|
||||
}
|
||||
|
||||
bool GlobalStrengthReductionContext::processInstruction(Instruction *inst) {
|
||||
|
||||
if (DEBUG) {
|
||||
std::cout << " Processing instruction: " << inst->getName() << std::endl;
|
||||
}
|
||||
|
||||
// 先尝试代数优化
|
||||
if (tryAlgebraicOptimization(inst)) {
|
||||
algebraicOptCount++;
|
||||
return true;
|
||||
}
|
||||
|
||||
// 再尝试强度削弱
|
||||
if (tryStrengthReduction(inst)) {
|
||||
strengthReductionCount++;
|
||||
return true;
|
||||
}
|
||||
|
||||
return false;
|
||||
}
|
||||
|
||||
// ======================================================================
|
||||
// 代数优化方法
|
||||
// ======================================================================
|
||||
|
||||
bool GlobalStrengthReductionContext::tryAlgebraicOptimization(Instruction *inst) {
|
||||
auto binary = dynamic_cast<BinaryInst*>(inst);
|
||||
if (!binary) {
|
||||
return false;
|
||||
}
|
||||
|
||||
switch (binary->getKind()) {
|
||||
case Instruction::kAdd:
|
||||
return optimizeAddition(binary);
|
||||
case Instruction::kSub:
|
||||
return optimizeSubtraction(binary);
|
||||
case Instruction::kMul:
|
||||
return optimizeMultiplication(binary);
|
||||
case Instruction::kDiv:
|
||||
return optimizeDivision(binary);
|
||||
case Instruction::kICmpEQ:
|
||||
case Instruction::kICmpNE:
|
||||
case Instruction::kICmpLT:
|
||||
case Instruction::kICmpGT:
|
||||
case Instruction::kICmpLE:
|
||||
case Instruction::kICmpGE:
|
||||
return optimizeComparison(binary);
|
||||
case Instruction::kAnd:
|
||||
case Instruction::kOr:
|
||||
return optimizeLogical(binary);
|
||||
default:
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
bool GlobalStrengthReductionContext::optimizeAddition(BinaryInst *inst) {
|
||||
Value *lhs = inst->getLhs();
|
||||
Value *rhs = inst->getRhs();
|
||||
int constVal;
|
||||
|
||||
// x + 0 = x
|
||||
if (isConstantInt(rhs, constVal) && constVal == 0) {
|
||||
if (DEBUG) {
|
||||
std::cout << " Algebraic: " << inst->getName() << " = x + 0 -> x" << std::endl;
|
||||
}
|
||||
replaceWithOptimized(inst, lhs);
|
||||
return true;
|
||||
}
|
||||
|
||||
// 0 + x = x
|
||||
if (isConstantInt(lhs, constVal) && constVal == 0) {
|
||||
if (DEBUG) {
|
||||
std::cout << " Algebraic: " << inst->getName() << " = 0 + x -> x" << std::endl;
|
||||
}
|
||||
replaceWithOptimized(inst, rhs);
|
||||
return true;
|
||||
}
|
||||
|
||||
// x + (-y) = x - y
|
||||
if (auto rhsInst = dynamic_cast<UnaryInst*>(rhs)) {
|
||||
if (rhsInst->getKind() == Instruction::kNeg) {
|
||||
if (DEBUG) {
|
||||
std::cout << " Algebraic: " << inst->getName() << " = x + (-y) -> x - y" << std::endl;
|
||||
}
|
||||
// 创建减法指令
|
||||
builder->setPosition(inst->getParent(), inst->getParent()->findInstIterator(inst));
|
||||
auto subInst = builder->createSubInst(lhs, rhsInst->getOperand());
|
||||
replaceWithOptimized(inst, subInst);
|
||||
return true;
|
||||
}
|
||||
}
|
||||
|
||||
return false;
|
||||
}
|
||||
|
||||
bool GlobalStrengthReductionContext::optimizeSubtraction(BinaryInst *inst) {
|
||||
Value *lhs = inst->getLhs();
|
||||
Value *rhs = inst->getRhs();
|
||||
int constVal;
|
||||
|
||||
// x - 0 = x
|
||||
if (isConstantInt(rhs, constVal) && constVal == 0) {
|
||||
if (DEBUG) {
|
||||
std::cout << " Algebraic: " << inst->getName() << " = x - 0 -> x" << std::endl;
|
||||
}
|
||||
replaceWithOptimized(inst, lhs);
|
||||
return true;
|
||||
}
|
||||
|
||||
// x - x = 0 (如果x没有副作用)
|
||||
if (lhs == rhs && hasOnlyLocalUses(dynamic_cast<Instruction*>(lhs))) {
|
||||
if (DEBUG) {
|
||||
std::cout << " Algebraic: " << inst->getName() << " = x - x -> 0" << std::endl;
|
||||
}
|
||||
replaceWithOptimized(inst, getConstantInt(0));
|
||||
return true;
|
||||
}
|
||||
|
||||
// x - (-y) = x + y
|
||||
if (auto rhsInst = dynamic_cast<UnaryInst*>(rhs)) {
|
||||
if (rhsInst->getKind() == Instruction::kNeg) {
|
||||
if (DEBUG) {
|
||||
std::cout << " Algebraic: " << inst->getName() << " = x - (-y) -> x + y" << std::endl;
|
||||
}
|
||||
builder->setPosition(inst->getParent(), inst->getParent()->findInstIterator(inst));
|
||||
auto addInst = builder->createAddInst(lhs, rhsInst->getOperand());
|
||||
replaceWithOptimized(inst, addInst);
|
||||
return true;
|
||||
}
|
||||
}
|
||||
|
||||
return false;
|
||||
}
|
||||
|
||||
bool GlobalStrengthReductionContext::optimizeMultiplication(BinaryInst *inst) {
|
||||
Value *lhs = inst->getLhs();
|
||||
Value *rhs = inst->getRhs();
|
||||
int constVal;
|
||||
|
||||
// x * 0 = 0
|
||||
if (isConstantInt(rhs, constVal) && constVal == 0) {
|
||||
if (DEBUG) {
|
||||
std::cout << " Algebraic: " << inst->getName() << " = x * 0 -> 0" << std::endl;
|
||||
}
|
||||
replaceWithOptimized(inst, getConstantInt(0));
|
||||
return true;
|
||||
}
|
||||
|
||||
// 0 * x = 0
|
||||
if (isConstantInt(lhs, constVal) && constVal == 0) {
|
||||
if (DEBUG) {
|
||||
std::cout << " Algebraic: " << inst->getName() << " = 0 * x -> 0" << std::endl;
|
||||
}
|
||||
replaceWithOptimized(inst, getConstantInt(0));
|
||||
return true;
|
||||
}
|
||||
|
||||
// x * 1 = x
|
||||
if (isConstantInt(rhs, constVal) && constVal == 1) {
|
||||
if (DEBUG) {
|
||||
std::cout << " Algebraic: " << inst->getName() << " = x * 1 -> x" << std::endl;
|
||||
}
|
||||
replaceWithOptimized(inst, lhs);
|
||||
return true;
|
||||
}
|
||||
|
||||
// 1 * x = x
|
||||
if (isConstantInt(lhs, constVal) && constVal == 1) {
|
||||
if (DEBUG) {
|
||||
std::cout << " Algebraic: " << inst->getName() << " = 1 * x -> x" << std::endl;
|
||||
}
|
||||
replaceWithOptimized(inst, rhs);
|
||||
return true;
|
||||
}
|
||||
|
||||
// x * (-1) = -x
|
||||
if (isConstantInt(rhs, constVal) && constVal == -1) {
|
||||
if (DEBUG) {
|
||||
std::cout << " Algebraic: " << inst->getName() << " = x * (-1) -> -x" << std::endl;
|
||||
}
|
||||
builder->setPosition(inst->getParent(), inst->getParent()->findInstIterator(inst));
|
||||
auto negInst = builder->createNegInst(lhs);
|
||||
replaceWithOptimized(inst, negInst);
|
||||
return true;
|
||||
}
|
||||
|
||||
return false;
|
||||
}
|
||||
|
||||
bool GlobalStrengthReductionContext::optimizeDivision(BinaryInst *inst) {
|
||||
Value *lhs = inst->getLhs();
|
||||
Value *rhs = inst->getRhs();
|
||||
int constVal;
|
||||
|
||||
// x / 1 = x
|
||||
if (isConstantInt(rhs, constVal) && constVal == 1) {
|
||||
if (DEBUG) {
|
||||
std::cout << " Algebraic: " << inst->getName() << " = x / 1 -> x" << std::endl;
|
||||
}
|
||||
replaceWithOptimized(inst, lhs);
|
||||
return true;
|
||||
}
|
||||
|
||||
// x / (-1) = -x
|
||||
if (isConstantInt(rhs, constVal) && constVal == -1) {
|
||||
if (DEBUG) {
|
||||
std::cout << " Algebraic: " << inst->getName() << " = x / (-1) -> -x" << std::endl;
|
||||
}
|
||||
builder->setPosition(inst->getParent(), inst->getParent()->findInstIterator(inst));
|
||||
auto negInst = builder->createNegInst(lhs);
|
||||
replaceWithOptimized(inst, negInst);
|
||||
return true;
|
||||
}
|
||||
|
||||
// x / x = 1 (如果x != 0且没有副作用)
|
||||
if (lhs == rhs && hasOnlyLocalUses(dynamic_cast<Instruction*>(lhs))) {
|
||||
if (DEBUG) {
|
||||
std::cout << " Algebraic: " << inst->getName() << " = x / x -> 1" << std::endl;
|
||||
}
|
||||
replaceWithOptimized(inst, getConstantInt(1));
|
||||
return true;
|
||||
}
|
||||
|
||||
return false;
|
||||
}
|
||||
|
||||
bool GlobalStrengthReductionContext::optimizeComparison(BinaryInst *inst) {
|
||||
Value *lhs = inst->getLhs();
|
||||
Value *rhs = inst->getRhs();
|
||||
|
||||
// x == x = true (如果x没有副作用)
|
||||
if (inst->getKind() == Instruction::kICmpEQ && lhs == rhs &&
|
||||
hasOnlyLocalUses(dynamic_cast<Instruction*>(lhs))) {
|
||||
if (DEBUG) {
|
||||
std::cout << " Algebraic: " << inst->getName() << " = x == x -> true" << std::endl;
|
||||
}
|
||||
replaceWithOptimized(inst, getConstantInt(1));
|
||||
return true;
|
||||
}
|
||||
|
||||
// x != x = false (如果x没有副作用)
|
||||
if (inst->getKind() == Instruction::kICmpNE && lhs == rhs &&
|
||||
hasOnlyLocalUses(dynamic_cast<Instruction*>(lhs))) {
|
||||
if (DEBUG) {
|
||||
std::cout << " Algebraic: " << inst->getName() << " = x != x -> false" << std::endl;
|
||||
}
|
||||
replaceWithOptimized(inst, getConstantInt(0));
|
||||
return true;
|
||||
}
|
||||
|
||||
return false;
|
||||
}
|
||||
|
||||
bool GlobalStrengthReductionContext::optimizeLogical(BinaryInst *inst) {
|
||||
Value *lhs = inst->getLhs();
|
||||
Value *rhs = inst->getRhs();
|
||||
int constVal;
|
||||
|
||||
if (inst->getKind() == Instruction::kAnd) {
|
||||
// x && 0 = 0
|
||||
if (isConstantInt(rhs, constVal) && constVal == 0) {
|
||||
if (DEBUG) {
|
||||
std::cout << " Algebraic: " << inst->getName() << " = x && 0 -> 0" << std::endl;
|
||||
}
|
||||
replaceWithOptimized(inst, getConstantInt(0));
|
||||
return true;
|
||||
}
|
||||
|
||||
// x && -1 = x
|
||||
if (isConstantInt(rhs, constVal) && constVal == -1) {
|
||||
if (DEBUG) {
|
||||
std::cout << " Algebraic: " << inst->getName() << " = x && 1 -> x" << std::endl;
|
||||
}
|
||||
replaceWithOptimized(inst, lhs);
|
||||
return true;
|
||||
}
|
||||
|
||||
// x && x = x
|
||||
if (lhs == rhs) {
|
||||
if (DEBUG) {
|
||||
std::cout << " Algebraic: " << inst->getName() << " = x && x -> x" << std::endl;
|
||||
}
|
||||
replaceWithOptimized(inst, lhs);
|
||||
return true;
|
||||
}
|
||||
} else if (inst->getKind() == Instruction::kOr) {
|
||||
// x || 0 = x
|
||||
if (isConstantInt(rhs, constVal) && constVal == 0) {
|
||||
if (DEBUG) {
|
||||
std::cout << " Algebraic: " << inst->getName() << " = x || 0 -> x" << std::endl;
|
||||
}
|
||||
replaceWithOptimized(inst, lhs);
|
||||
return true;
|
||||
}
|
||||
|
||||
// x || x = x
|
||||
if (lhs == rhs) {
|
||||
if (DEBUG) {
|
||||
std::cout << " Algebraic: " << inst->getName() << " = x || x -> x" << std::endl;
|
||||
}
|
||||
replaceWithOptimized(inst, lhs);
|
||||
return true;
|
||||
}
|
||||
}
|
||||
|
||||
return false;
|
||||
}
|
||||
|
||||
// ======================================================================
|
||||
// 强度削弱方法
|
||||
// ======================================================================
|
||||
|
||||
bool GlobalStrengthReductionContext::tryStrengthReduction(Instruction *inst) {
|
||||
if (auto binary = dynamic_cast<BinaryInst*>(inst)) {
|
||||
switch (binary->getKind()) {
|
||||
case Instruction::kMul:
|
||||
return reduceMultiplication(binary);
|
||||
case Instruction::kDiv:
|
||||
return reduceDivision(binary);
|
||||
default:
|
||||
return false;
|
||||
}
|
||||
} else if (auto call = dynamic_cast<CallInst*>(inst)) {
|
||||
return reducePower(call);
|
||||
}
|
||||
|
||||
return false;
|
||||
}
|
||||
|
||||
bool GlobalStrengthReductionContext::reduceMultiplication(BinaryInst *inst) {
|
||||
Value *lhs = inst->getLhs();
|
||||
Value *rhs = inst->getRhs();
|
||||
int constVal;
|
||||
|
||||
// 尝试右操作数为常数
|
||||
Value* variable = lhs;
|
||||
if (isConstantInt(rhs, constVal) && constVal > 0) {
|
||||
return tryComplexMultiplication(inst, variable, constVal);
|
||||
}
|
||||
|
||||
// 尝试左操作数为常数
|
||||
if (isConstantInt(lhs, constVal) && constVal > 0) {
|
||||
variable = rhs;
|
||||
return tryComplexMultiplication(inst, variable, constVal);
|
||||
}
|
||||
|
||||
return false;
|
||||
}
|
||||
|
||||
bool GlobalStrengthReductionContext::tryComplexMultiplication(BinaryInst* inst, Value* variable, int constant) {
|
||||
// 首先检查是否为2的幂,使用简单位移
|
||||
if (isPowerOfTwo(constant)) {
|
||||
int shiftAmount = log2OfPowerOfTwo(constant);
|
||||
if (DEBUG) {
|
||||
std::cout << " StrengthReduction: " << inst->getName()
|
||||
<< " = x * " << constant << " -> x << " << shiftAmount << std::endl;
|
||||
}
|
||||
|
||||
builder->setPosition(inst->getParent(), inst->getParent()->findInstIterator(inst));
|
||||
auto shiftInst = builder->createBinaryInst(Instruction::kSll, Type::getIntType(), variable, getConstantInt(shiftAmount));
|
||||
replaceWithOptimized(inst, shiftInst);
|
||||
return true;
|
||||
}
|
||||
|
||||
// 尝试分解为位移和加法的组合
|
||||
std::vector<int> shifts;
|
||||
if (findOptimalShiftDecomposition(constant, shifts)) {
|
||||
if (DEBUG) {
|
||||
std::cout << " StrengthReduction: " << inst->getName()
|
||||
<< " = x * " << constant << " -> shift decomposition with " << shifts.size() << " terms" << std::endl;
|
||||
}
|
||||
|
||||
Value* result = createShiftDecomposition(inst, variable, shifts);
|
||||
if (result) {
|
||||
replaceWithOptimized(inst, result);
|
||||
return true;
|
||||
}
|
||||
}
|
||||
|
||||
return false;
|
||||
}
|
||||
|
||||
bool GlobalStrengthReductionContext::findOptimalShiftDecomposition(int constant, std::vector<int>& shifts) {
|
||||
shifts.clear();
|
||||
|
||||
// 常见的有效分解模式
|
||||
switch (constant) {
|
||||
case 3: // 3 = 2^1 + 2^0 -> (x << 1) + x
|
||||
shifts = {1, 0};
|
||||
return true;
|
||||
case 5: // 5 = 2^2 + 2^0 -> (x << 2) + x
|
||||
shifts = {2, 0};
|
||||
return true;
|
||||
case 6: // 6 = 2^2 + 2^1 -> (x << 2) + (x << 1)
|
||||
shifts = {2, 1};
|
||||
return true;
|
||||
case 7: // 7 = 2^2 + 2^1 + 2^0 -> (x << 2) + (x << 1) + x
|
||||
shifts = {2, 1, 0};
|
||||
return true;
|
||||
case 9: // 9 = 2^3 + 2^0 -> (x << 3) + x
|
||||
shifts = {3, 0};
|
||||
return true;
|
||||
case 10: // 10 = 2^3 + 2^1 -> (x << 3) + (x << 1)
|
||||
shifts = {3, 1};
|
||||
return true;
|
||||
case 11: // 11 = 2^3 + 2^1 + 2^0 -> (x << 3) + (x << 1) + x
|
||||
shifts = {3, 1, 0};
|
||||
return true;
|
||||
case 12: // 12 = 2^3 + 2^2 -> (x << 3) + (x << 2)
|
||||
shifts = {3, 2};
|
||||
return true;
|
||||
case 13: // 13 = 2^3 + 2^2 + 2^0 -> (x << 3) + (x << 2) + x
|
||||
shifts = {3, 2, 0};
|
||||
return true;
|
||||
case 14: // 14 = 2^3 + 2^2 + 2^1 -> (x << 3) + (x << 2) + (x << 1)
|
||||
shifts = {3, 2, 1};
|
||||
return true;
|
||||
case 15: // 15 = 2^3 + 2^2 + 2^1 + 2^0 -> (x << 3) + (x << 2) + (x << 1) + x
|
||||
shifts = {3, 2, 1, 0};
|
||||
return true;
|
||||
case 17: // 17 = 2^4 + 2^0 -> (x << 4) + x
|
||||
shifts = {4, 0};
|
||||
return true;
|
||||
case 18: // 18 = 2^4 + 2^1 -> (x << 4) + (x << 1)
|
||||
shifts = {4, 1};
|
||||
return true;
|
||||
case 20: // 20 = 2^4 + 2^2 -> (x << 4) + (x << 2)
|
||||
shifts = {4, 2};
|
||||
return true;
|
||||
case 24: // 24 = 2^4 + 2^3 -> (x << 4) + (x << 3)
|
||||
shifts = {4, 3};
|
||||
return true;
|
||||
case 25: // 25 = 2^4 + 2^3 + 2^0 -> (x << 4) + (x << 3) + x
|
||||
shifts = {4, 3, 0};
|
||||
return true;
|
||||
case 100: // 100 = 2^6 + 2^5 + 2^2 -> (x << 6) + (x << 5) + (x << 2)
|
||||
shifts = {6, 5, 2};
|
||||
return true;
|
||||
}
|
||||
|
||||
// 通用二进制分解(最多4个项,避免过度复杂化)
|
||||
if (constant > 0 && constant < 256) {
|
||||
std::vector<int> binaryShifts;
|
||||
int temp = constant;
|
||||
int bit = 0;
|
||||
|
||||
while (temp > 0 && binaryShifts.size() < 4) {
|
||||
if (temp & 1) {
|
||||
binaryShifts.push_back(bit);
|
||||
}
|
||||
temp >>= 1;
|
||||
bit++;
|
||||
}
|
||||
|
||||
// 只有当项数不超过3个时才使用二进制分解(比直接乘法更有效)
|
||||
if (binaryShifts.size() <= 3 && binaryShifts.size() >= 2) {
|
||||
shifts = binaryShifts;
|
||||
return true;
|
||||
}
|
||||
}
|
||||
|
||||
return false;
|
||||
}
|
||||
|
||||
Value* GlobalStrengthReductionContext::createShiftDecomposition(BinaryInst* inst, Value* variable, const std::vector<int>& shifts) {
|
||||
if (shifts.empty()) return nullptr;
|
||||
|
||||
builder->setPosition(inst->getParent(), inst->getParent()->findInstIterator(inst));
|
||||
|
||||
Value* result = nullptr;
|
||||
|
||||
for (int shift : shifts) {
|
||||
Value* term;
|
||||
if (shift == 0) {
|
||||
// 0位移就是原变量
|
||||
term = variable;
|
||||
} else {
|
||||
// 创建位移指令
|
||||
term = builder->createBinaryInst(Instruction::kSll, Type::getIntType(), variable, getConstantInt(shift));
|
||||
}
|
||||
|
||||
if (result == nullptr) {
|
||||
result = term;
|
||||
} else {
|
||||
// 累加到结果中
|
||||
result = builder->createAddInst(result, term);
|
||||
}
|
||||
}
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
bool GlobalStrengthReductionContext::reduceDivision(BinaryInst *inst) {
|
||||
Value *lhs = inst->getLhs();
|
||||
Value *rhs = inst->getRhs();
|
||||
uint32_t constVal;
|
||||
|
||||
// x / 2^n = x >> n (对于无符号除法或已知为正数的情况)
|
||||
if (isConstantInt(rhs, constVal) && constVal > 0 && isPowerOfTwo(constVal)) {
|
||||
builder->setPosition(inst->getParent(), inst->getParent()->findInstIterator(inst));
|
||||
int shiftAmount = log2OfPowerOfTwo(constVal);
|
||||
// 有符号除法校正:(x + (x >> 31) & mask) >> k
|
||||
int maskValue = constVal - 1;
|
||||
|
||||
// x >> 31 (算术右移获取符号位)
|
||||
Value* signShift = ConstantInteger::get(31);
|
||||
Value* signBits = builder->createBinaryInst(
|
||||
Instruction::Kind::kSra, // 算术右移
|
||||
lhs->getType(),
|
||||
lhs,
|
||||
signShift
|
||||
);
|
||||
|
||||
// (x >> 31) & mask
|
||||
Value* mask = ConstantInteger::get(maskValue);
|
||||
Value* correction = builder->createBinaryInst(
|
||||
Instruction::Kind::kAnd,
|
||||
lhs->getType(),
|
||||
signBits,
|
||||
mask
|
||||
);
|
||||
|
||||
// x + correction
|
||||
Value* corrected = builder->createAddInst(lhs, correction);
|
||||
|
||||
// (x + correction) >> k
|
||||
Value* divShift = ConstantInteger::get(shiftAmount);
|
||||
Value* shiftInst = builder->createBinaryInst(
|
||||
Instruction::Kind::kSra, // 算术右移
|
||||
lhs->getType(),
|
||||
corrected,
|
||||
divShift
|
||||
);
|
||||
|
||||
if (DEBUG) {
|
||||
std::cout << " StrengthReduction: " << inst->getName()
|
||||
<< " = x / " << constVal << " -> (x + (x >> 31) & mask) >> " << shiftAmount << std::endl;
|
||||
}
|
||||
|
||||
// builder->setPosition(inst->getParent(), inst->getParent()->findInstIterator(inst));
|
||||
// Value* divisor_minus_1 = ConstantInteger::get(constVal - 1);
|
||||
// Value* adjusted = builder->createAddInst(lhs, divisor_minus_1);
|
||||
// Value* shiftInst = builder->createBinaryInst(Instruction::kSra, Type::getIntType(), adjusted, getConstantInt(shiftAmount));
|
||||
replaceWithOptimized(inst, shiftInst);
|
||||
strengthReductionCount++;
|
||||
return true;
|
||||
}
|
||||
|
||||
// x / c = x * magic_number (魔数乘法优化 - 使用libdivide算法)
|
||||
// if (isConstantInt(rhs, constVal) && constVal > 1 && constVal != (uint32_t)(-1)) {
|
||||
// // auto magicPair = computeMulhMagicNumbers(static_cast<int>(constVal));
|
||||
// Value* magicResult = createMagicDivisionLibdivide(inst, static_cast<int>(constVal));
|
||||
// replaceWithOptimized(inst, magicResult);
|
||||
// divisionOptCount++;
|
||||
// return true;
|
||||
// }
|
||||
|
||||
return false;
|
||||
}
|
||||
|
||||
bool GlobalStrengthReductionContext::reducePower(CallInst *inst) {
|
||||
// 检查是否是pow函数调用
|
||||
Function* callee = inst->getCallee();
|
||||
if (!callee || callee->getName() != "pow") {
|
||||
return false;
|
||||
}
|
||||
|
||||
// pow(x, 2) = x * x
|
||||
if (inst->getNumOperands() >= 2) {
|
||||
int exponent;
|
||||
if (isConstantInt(inst->getOperand(1), exponent)) {
|
||||
if (exponent == 2) {
|
||||
if (DEBUG) {
|
||||
std::cout << " StrengthReduction: pow(x, 2) -> x * x" << std::endl;
|
||||
}
|
||||
|
||||
Value* base = inst->getOperand(0);
|
||||
builder->setPosition(inst->getParent(), inst->getParent()->findInstIterator(inst));
|
||||
auto mulInst = builder->createMulInst(base, base);
|
||||
replaceWithOptimized(inst, mulInst);
|
||||
strengthReductionCount++;
|
||||
return true;
|
||||
} else if (exponent >= 3 && exponent <= 8) {
|
||||
// 对于小的指数,展开为连续乘法
|
||||
if (DEBUG) {
|
||||
std::cout << " StrengthReduction: pow(x, " << exponent << ") -> repeated multiplication" << std::endl;
|
||||
}
|
||||
|
||||
Value* base = inst->getOperand(0);
|
||||
Value* result = base;
|
||||
builder->setPosition(inst->getParent(), inst->getParent()->findInstIterator(inst));
|
||||
|
||||
for (int i = 1; i < exponent; i++) {
|
||||
result = builder->createMulInst(result, base);
|
||||
}
|
||||
|
||||
replaceWithOptimized(inst, result);
|
||||
strengthReductionCount++;
|
||||
return true;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return false;
|
||||
}
|
||||
|
||||
Value* GlobalStrengthReductionContext::createMagicDivisionLibdivide(BinaryInst* divInst, int divisor) {
|
||||
builder->setPosition(divInst->getParent(), divInst->getParent()->findInstIterator(divInst));
|
||||
// 使用mulh指令优化任意常数除法
|
||||
auto [magic, shift] = SysYIROptUtils::computeMulhMagicNumbers(divisor);
|
||||
|
||||
// 检查是否无法优化(magic == -1, shift == -1 表示失败)
|
||||
if (magic == -1 && shift == -1) {
|
||||
if (DEBUG) {
|
||||
std::cout << "[SR] Cannot optimize division by " << divisor
|
||||
<< ", keeping original division" << std::endl;
|
||||
}
|
||||
// 返回 nullptr 表示无法优化,调用方应该保持原始除法
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
// 2的幂次方除法可以用移位优化(但这不是魔数法的情况)这种情况应该不会被分类到这里但是还是做一个保护措施
|
||||
if ((divisor & (divisor - 1)) == 0 && divisor > 0) {
|
||||
// 是2的幂次方,可以用移位
|
||||
int shift_amount = 0;
|
||||
int temp = divisor;
|
||||
while (temp > 1) {
|
||||
temp >>= 1;
|
||||
shift_amount++;
|
||||
}
|
||||
|
||||
Value* shiftConstant = ConstantInteger::get(shift_amount);
|
||||
// 对于有符号除法,需要先加上除数-1然后再移位(为了正确处理负数舍入)
|
||||
Value* divisor_minus_1 = ConstantInteger::get(divisor - 1);
|
||||
Value* adjusted = builder->createAddInst(divInst->getOperand(0), divisor_minus_1);
|
||||
return builder->createBinaryInst(
|
||||
Instruction::Kind::kSra, // 算术右移
|
||||
divInst->getOperand(0)->getType(),
|
||||
adjusted,
|
||||
shiftConstant
|
||||
);
|
||||
}
|
||||
|
||||
// 创建魔数常量
|
||||
// 检查魔数是否能放入32位,如果不能,则不进行优化
|
||||
if (magic > INT32_MAX || magic < INT32_MIN) {
|
||||
if (DEBUG) {
|
||||
std::cout << "[SR] Magic number " << magic << " exceeds 32-bit range, skipping optimization" << std::endl;
|
||||
}
|
||||
return nullptr; // 无法优化,保持原始除法
|
||||
}
|
||||
|
||||
Value* magicConstant = ConstantInteger::get((int32_t)magic);
|
||||
|
||||
// 检查是否需要ADD_MARKER处理(加法调整)
|
||||
bool needAdd = (shift & 0x40) != 0;
|
||||
int actualShift = shift & 0x3F; // 提取真实的移位量
|
||||
|
||||
if (DEBUG) {
|
||||
std::cout << "[SR] IR Generation: magic=" << magic << ", needAdd=" << needAdd
|
||||
<< ", actualShift=" << actualShift << std::endl;
|
||||
}
|
||||
|
||||
// 执行高位乘法:mulh(x, magic)
|
||||
Value* mulhResult = builder->createBinaryInst(
|
||||
Instruction::Kind::kMulh, // 高位乘法
|
||||
divInst->getOperand(0)->getType(),
|
||||
divInst->getOperand(0),
|
||||
magicConstant
|
||||
);
|
||||
|
||||
if (needAdd) {
|
||||
// ADD_MARKER 情况:需要在移位前加上被除数
|
||||
// 这对应于 libdivide 的加法调整算法
|
||||
if (DEBUG) {
|
||||
std::cout << "[SR] Applying ADD_MARKER: adding dividend before shift" << std::endl;
|
||||
}
|
||||
mulhResult = builder->createAddInst(mulhResult, divInst->getOperand(0));
|
||||
}
|
||||
|
||||
if (actualShift > 0) {
|
||||
// 如果需要额外移位
|
||||
Value* shiftConstant = ConstantInteger::get(actualShift);
|
||||
mulhResult = builder->createBinaryInst(
|
||||
Instruction::Kind::kSra, // 算术右移
|
||||
divInst->getOperand(0)->getType(),
|
||||
mulhResult,
|
||||
shiftConstant
|
||||
);
|
||||
}
|
||||
|
||||
// 标准的有符号除法符号修正:如果被除数为负,商需要加1
|
||||
// 这对所有有符号除法都需要,不管是否可能有负数
|
||||
Value* isNegative = builder->createICmpLTInst(divInst->getOperand(0), ConstantInteger::get(0));
|
||||
// 将i1转换为i32:负数时为1,非负数时为0 ICmpLTInst的结果会默认转化为32位
|
||||
mulhResult = builder->createAddInst(mulhResult, isNegative);
|
||||
|
||||
return mulhResult;
|
||||
}
|
||||
|
||||
// ======================================================================
|
||||
// 辅助方法
|
||||
// ======================================================================
|
||||
|
||||
bool GlobalStrengthReductionContext::isPowerOfTwo(uint32_t n) {
|
||||
return n > 0 && (n & (n - 1)) == 0;
|
||||
}
|
||||
|
||||
int GlobalStrengthReductionContext::log2OfPowerOfTwo(uint32_t n) {
|
||||
int result = 0;
|
||||
while (n > 1) {
|
||||
n >>= 1;
|
||||
result++;
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
bool GlobalStrengthReductionContext::isConstantInt(Value* val, int& constVal) {
|
||||
if (auto constInt = dynamic_cast<ConstantInteger*>(val)) {
|
||||
constVal = std::get<int>(constInt->getVal());
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
bool GlobalStrengthReductionContext::isConstantInt(Value* val, uint32_t& constVal) {
|
||||
if (auto constInt = dynamic_cast<ConstantInteger*>(val)) {
|
||||
int signedVal = std::get<int>(constInt->getVal());
|
||||
if (signedVal >= 0) {
|
||||
constVal = static_cast<uint32_t>(signedVal);
|
||||
return true;
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
ConstantInteger* GlobalStrengthReductionContext::getConstantInt(int val) {
|
||||
return ConstantInteger::get(val);
|
||||
}
|
||||
|
||||
bool GlobalStrengthReductionContext::hasOnlyLocalUses(Instruction* inst) {
|
||||
if (!inst) return true;
|
||||
|
||||
// 简单检查:如果指令没有副作用,则认为是本地的
|
||||
if (sideEffectAnalysis) {
|
||||
auto sideEffect = sideEffectAnalysis->getInstructionSideEffect(inst);
|
||||
return sideEffect.type == SideEffectType::NO_SIDE_EFFECT;
|
||||
}
|
||||
|
||||
// 没有副作用分析时,保守处理
|
||||
return !inst->isCall() && !inst->isStore() && !inst->isLoad();
|
||||
}
|
||||
|
||||
void GlobalStrengthReductionContext::replaceWithOptimized(Instruction* original, Value* replacement) {
|
||||
if (DEBUG) {
|
||||
std::cout << " Replacing " << original->getName()
|
||||
<< " with " << replacement->getName() << std::endl;
|
||||
}
|
||||
|
||||
original->replaceAllUsesWith(replacement);
|
||||
|
||||
// 如果替换值是新创建的指令,确保它有合适的名字
|
||||
// if (auto replInst = dynamic_cast<Instruction*>(replacement)) {
|
||||
// if (replInst->getName().empty()) {
|
||||
// replInst->setName(original->getName() + "_opt");
|
||||
// }
|
||||
// }
|
||||
|
||||
// 删除原指令,让调用者处理
|
||||
SysYIROptUtils::usedelete(original);
|
||||
}
|
||||
|
||||
} // namespace sysy
|
||||
@ -133,6 +133,7 @@ bool InductionVariableEliminationContext::run(Function* F, AnalysisManager& AM)
|
||||
printDebugInfo();
|
||||
}
|
||||
|
||||
modified |= SysYIROptUtils::eliminateRedundantPhisInFunction(F);
|
||||
return modified;
|
||||
}
|
||||
|
||||
|
||||
@ -106,187 +106,6 @@ bool StrengthReductionContext::analyzeInductionVariableRange(
|
||||
return hasNegativePotential;
|
||||
}
|
||||
|
||||
//该实现参考了libdivide的算法
|
||||
std::pair<int, int> StrengthReductionContext::computeMulhMagicNumbers(int divisor) const {
|
||||
|
||||
if (DEBUG) {
|
||||
std::cout << "\n[SR] ===== Computing magic numbers for divisor " << divisor << " (libdivide algorithm) =====" << std::endl;
|
||||
}
|
||||
|
||||
if (divisor == 0) {
|
||||
if (DEBUG) std::cout << "[SR] Error: divisor must be != 0" << std::endl;
|
||||
return {-1, -1};
|
||||
}
|
||||
|
||||
// libdivide 常数
|
||||
const uint8_t LIBDIVIDE_ADD_MARKER = 0x40;
|
||||
const uint8_t LIBDIVIDE_NEGATIVE_DIVISOR = 0x80;
|
||||
|
||||
// 辅助函数:计算前导零个数
|
||||
auto count_leading_zeros32 = [](uint32_t val) -> uint32_t {
|
||||
if (val == 0) return 32;
|
||||
return __builtin_clz(val);
|
||||
};
|
||||
|
||||
// 辅助函数:64位除法返回32位商和余数
|
||||
auto div_64_32 = [](uint32_t high, uint32_t low, uint32_t divisor, uint32_t* rem) -> uint32_t {
|
||||
uint64_t dividend = ((uint64_t)high << 32) | low;
|
||||
uint32_t quotient = dividend / divisor;
|
||||
*rem = dividend % divisor;
|
||||
return quotient;
|
||||
};
|
||||
|
||||
if (DEBUG) {
|
||||
std::cout << "[SR] Input divisor: " << divisor << std::endl;
|
||||
}
|
||||
|
||||
// libdivide_internal_s32_gen 算法实现
|
||||
int32_t d = divisor;
|
||||
uint32_t ud = (uint32_t)d;
|
||||
uint32_t absD = (d < 0) ? -ud : ud;
|
||||
|
||||
if (DEBUG) {
|
||||
std::cout << "[SR] absD = " << absD << std::endl;
|
||||
}
|
||||
|
||||
uint32_t floor_log_2_d = 31 - count_leading_zeros32(absD);
|
||||
|
||||
if (DEBUG) {
|
||||
std::cout << "[SR] floor_log_2_d = " << floor_log_2_d << std::endl;
|
||||
}
|
||||
|
||||
// 检查 absD 是否为2的幂
|
||||
if ((absD & (absD - 1)) == 0) {
|
||||
if (DEBUG) {
|
||||
std::cout << "[SR] " << absD << " 是2的幂,使用移位方法" << std::endl;
|
||||
}
|
||||
|
||||
// 对于2的幂,我们只使用移位,不需要魔数
|
||||
int shift = floor_log_2_d;
|
||||
if (d < 0) shift |= 0x80; // 标记负数
|
||||
|
||||
if (DEBUG) {
|
||||
std::cout << "[SR] Power of 2 result: magic=0, shift=" << shift << std::endl;
|
||||
std::cout << "[SR] ===== End magic computation =====" << std::endl;
|
||||
}
|
||||
|
||||
// 对于我们的目的,我们将在IR生成中以不同方式处理2的幂
|
||||
// 返回特殊标记
|
||||
return {0, shift};
|
||||
}
|
||||
|
||||
if (DEBUG) {
|
||||
std::cout << "[SR] " << absD << " is not a power of 2, computing magic number" << std::endl;
|
||||
}
|
||||
|
||||
// 非2的幂除数的魔数计算
|
||||
uint8_t more;
|
||||
uint32_t rem, proposed_m;
|
||||
|
||||
// 计算 proposed_m = floor(2^(floor_log_2_d + 31) / absD)
|
||||
proposed_m = div_64_32((uint32_t)1 << (floor_log_2_d - 1), 0, absD, &rem);
|
||||
const uint32_t e = absD - rem;
|
||||
|
||||
if (DEBUG) {
|
||||
std::cout << "[SR] proposed_m = " << proposed_m << ", rem = " << rem << ", e = " << e << std::endl;
|
||||
}
|
||||
|
||||
// 确定是否需要"加法"版本
|
||||
const bool branchfree = false; // 使用分支版本
|
||||
|
||||
if (!branchfree && e < ((uint32_t)1 << floor_log_2_d)) {
|
||||
// 这个幂次有效
|
||||
more = (uint8_t)(floor_log_2_d - 1);
|
||||
if (DEBUG) {
|
||||
std::cout << "[SR] Using basic algorithm, shift = " << (int)more << std::endl;
|
||||
}
|
||||
} else {
|
||||
// 我们需要上升一个等级
|
||||
proposed_m += proposed_m;
|
||||
const uint32_t twice_rem = rem + rem;
|
||||
if (twice_rem >= absD || twice_rem < rem) {
|
||||
proposed_m += 1;
|
||||
}
|
||||
more = (uint8_t)(floor_log_2_d | LIBDIVIDE_ADD_MARKER);
|
||||
if (DEBUG) {
|
||||
std::cout << "[SR] Using add algorithm, proposed_m = " << proposed_m << ", more = " << (int)more << std::endl;
|
||||
}
|
||||
}
|
||||
|
||||
proposed_m += 1;
|
||||
int32_t magic = (int32_t)proposed_m;
|
||||
|
||||
// 处理负除数
|
||||
if (d < 0) {
|
||||
more |= LIBDIVIDE_NEGATIVE_DIVISOR;
|
||||
if (!branchfree) {
|
||||
magic = -magic;
|
||||
}
|
||||
if (DEBUG) {
|
||||
std::cout << "[SR] Negative divisor, magic = " << magic << ", more = " << (int)more << std::endl;
|
||||
}
|
||||
}
|
||||
|
||||
// 为我们的IR生成提取移位量和标志
|
||||
int shift = more & 0x3F; // 移除标志,保留移位量(位0-5)
|
||||
bool need_add = (more & LIBDIVIDE_ADD_MARKER) != 0;
|
||||
bool is_negative = (more & LIBDIVIDE_NEGATIVE_DIVISOR) != 0;
|
||||
|
||||
if (DEBUG) {
|
||||
std::cout << "[SR] Final result: magic = " << magic << ", more = " << (int)more
|
||||
<< " (0x" << std::hex << (int)more << std::dec << ")" << std::endl;
|
||||
std::cout << "[SR] Shift = " << shift << ", need_add = " << need_add
|
||||
<< ", is_negative = " << is_negative << std::endl;
|
||||
|
||||
// Test the magic number using the correct libdivide algorithm
|
||||
std::cout << "[SR] Testing magic number (libdivide algorithm):" << std::endl;
|
||||
int test_values[] = {1, 7, 37, 100, 999, -1, -7, -37, -100};
|
||||
|
||||
for (int test_val : test_values) {
|
||||
int64_t quotient;
|
||||
|
||||
// 实现正确的libdivide算法
|
||||
int64_t product = (int64_t)test_val * magic;
|
||||
int64_t high_bits = product >> 32;
|
||||
|
||||
if (need_add) {
|
||||
// ADD_MARKER情况:移位前加上被除数
|
||||
// 这是libdivide的关键洞察!
|
||||
high_bits += test_val;
|
||||
quotient = high_bits >> shift;
|
||||
} else {
|
||||
// 正常情况:只是移位
|
||||
quotient = high_bits >> shift;
|
||||
}
|
||||
|
||||
// 符号修正:这是libdivide有符号除法的关键部分!
|
||||
// 如果被除数为负,商需要加1来匹配C语言的截断除法语义
|
||||
if (test_val < 0) {
|
||||
quotient += 1;
|
||||
}
|
||||
|
||||
int expected = test_val / divisor;
|
||||
|
||||
bool correct = (quotient == expected);
|
||||
std::cout << "[SR] " << test_val << " / " << divisor << " = " << quotient
|
||||
<< " (expected " << expected << ") " << (correct ? "✓" : "✗") << std::endl;
|
||||
}
|
||||
|
||||
std::cout << "[SR] ===== End magic computation =====" << std::endl;
|
||||
}
|
||||
|
||||
// 返回魔数、移位量,并在移位中编码ADD_MARKER标志
|
||||
// 我们将使用移位的第6位表示ADD_MARKER,第7位表示负数(如果需要)
|
||||
int encoded_shift = shift;
|
||||
if (need_add) {
|
||||
encoded_shift |= 0x40; // 设置第6位表示ADD_MARKER
|
||||
if (DEBUG) {
|
||||
std::cout << "[SR] Encoding ADD_MARKER in shift: " << encoded_shift << std::endl;
|
||||
}
|
||||
}
|
||||
|
||||
return {magic, encoded_shift};
|
||||
}
|
||||
|
||||
bool LoopStrengthReduction::runOnFunction(Function* F, AnalysisManager& AM) {
|
||||
if (F->getBasicBlocks().empty()) {
|
||||
@ -842,9 +661,9 @@ bool StrengthReductionContext::replaceOriginalInstruction(StrengthReductionCandi
|
||||
|
||||
case StrengthReductionCandidate::DIVIDE_CONST: {
|
||||
// 任意常数除法
|
||||
builder->setPosition(candidate->containingBlock,
|
||||
candidate->containingBlock->findInstIterator(candidate->originalInst));
|
||||
replacementValue = generateConstantDivisionReplacement(candidate, builder);
|
||||
// builder->setPosition(candidate->containingBlock,
|
||||
// candidate->containingBlock->findInstIterator(candidate->originalInst));
|
||||
// replacementValue = generateConstantDivisionReplacement(candidate, builder);
|
||||
break;
|
||||
}
|
||||
|
||||
@ -864,17 +683,19 @@ bool StrengthReductionContext::replaceOriginalInstruction(StrengthReductionCandi
|
||||
);
|
||||
|
||||
// 检查原值是否为负数
|
||||
Value* zero = ConstantInteger::get(0);
|
||||
Value* isNegative = builder->createICmpLTInst(candidate->inductionVar, zero);
|
||||
Value* shift31condidata = builder->createBinaryInst(
|
||||
Instruction::Kind::kSra, candidate->inductionVar->getType(),
|
||||
candidate->inductionVar, ConstantInteger::get(31)
|
||||
);
|
||||
|
||||
// 如果为负数,需要调整结果
|
||||
Value* adjustment = ConstantInteger::get(candidate->multiplier);
|
||||
Value* adjustedTemp = builder->createAddInst(temp, adjustment);
|
||||
|
||||
// 使用条件分支来模拟select操作
|
||||
// 为简化起见,这里先用一个更复杂但可工作的方式
|
||||
// 实际应该创建条件分支,但这里先简化处理
|
||||
replacementValue = temp; // 简化版本,假设大多数情况下不是负数
|
||||
Value* adjustment = builder->createAndInst(shift31condidata, maskConstant);
|
||||
Value* adjustedTemp = builder->createAddInst(candidate->inductionVar, adjustment);
|
||||
Value* adjustedResult = builder->createBinaryInst(
|
||||
Instruction::Kind::kAnd, candidate->inductionVar->getType(),
|
||||
adjustedTemp, maskConstant
|
||||
);
|
||||
replacementValue = adjustedResult;
|
||||
} else {
|
||||
// 非负数的取模,直接使用位与
|
||||
replacementValue = builder->createBinaryInst(
|
||||
@ -1018,7 +839,7 @@ Value* StrengthReductionContext::generateConstantDivisionReplacement(
|
||||
IRBuilder* builder
|
||||
) const {
|
||||
// 使用mulh指令优化任意常数除法
|
||||
auto [magic, shift] = computeMulhMagicNumbers(candidate->multiplier);
|
||||
auto [magic, shift] = SysYIROptUtils::computeMulhMagicNumbers(candidate->multiplier);
|
||||
|
||||
// 检查是否无法优化(magic == -1, shift == -1 表示失败)
|
||||
if (magic == -1 && shift == -1) {
|
||||
|
||||
@ -1357,9 +1357,8 @@ void SCCPContext::run(Function *func, AnalysisManager &AM) {
|
||||
bool changed_control_flow = SimplifyControlFlow(func);
|
||||
|
||||
// 如果任何一个阶段修改了 IR,标记分析结果为失效
|
||||
if (changed_constant_propagation || changed_control_flow) {
|
||||
// AM.invalidate(); // 假设有这样的方法来使所有分析结果失效
|
||||
}
|
||||
bool changed = changed_constant_propagation || changed_control_flow;
|
||||
changed |= SysYIROptUtils::eliminateRedundantPhisInFunction(func);
|
||||
}
|
||||
|
||||
// SCCP Pass methods
|
||||
|
||||
125
src/midend/Pass/Optimize/TailCallOpt.cpp
Normal file
125
src/midend/Pass/Optimize/TailCallOpt.cpp
Normal file
@ -0,0 +1,125 @@
|
||||
#include "TailCallOpt.h"
|
||||
#include "IR.h"
|
||||
#include "IRBuilder.h"
|
||||
#include "SysYIROptUtils.h"
|
||||
#include <vector>
|
||||
// #include <iostream>
|
||||
#include <algorithm>
|
||||
|
||||
namespace sysy {
|
||||
|
||||
void *TailCallOpt::ID = (void *)&TailCallOpt::ID;
|
||||
|
||||
void TailCallOpt::getAnalysisUsage(std::set<void *> &analysisDependencies, std::set<void *> &analysisInvalidations) const {
|
||||
analysisInvalidations.insert(&DominatorTreeAnalysisPass::ID);
|
||||
analysisInvalidations.insert(&LoopAnalysisPass::ID);
|
||||
}
|
||||
|
||||
bool TailCallOpt::runOnFunction(Function *F, AnalysisManager &AM) {
|
||||
std::vector<CallInst *> tailCallInsts;
|
||||
// 遍历函数的所有基本块
|
||||
for (auto &bb_ptr : F->getBasicBlocks()) {
|
||||
auto BB = bb_ptr.get();
|
||||
if (BB->getInstructions().empty()) continue; // 跳过空基本块
|
||||
|
||||
auto term_iter = BB->terminator();
|
||||
if (term_iter == BB->getInstructions().end()) continue; // 没有终结指令则跳过
|
||||
auto term = (*term_iter).get();
|
||||
|
||||
if (!term || !term->isReturn()) continue; // 不是返回指令则跳过
|
||||
auto retInst = static_cast<ReturnInst *>(term);
|
||||
|
||||
Instruction *prevInst = nullptr;
|
||||
if (BB->getInstructions().size() > 1) {
|
||||
auto it = term_iter;
|
||||
--it; // 获取返回指令前的指令
|
||||
prevInst = (*it).get();
|
||||
}
|
||||
|
||||
if (!prevInst || !prevInst->isCall()) continue; // 前一条不是调用指令则跳过
|
||||
auto callInst = static_cast<CallInst *>(prevInst);
|
||||
|
||||
// 检查是否为尾递归调用:被调用函数与当前函数相同且返回值与调用结果匹配
|
||||
if (callInst->getCallee() == F) {
|
||||
// 对于尾递归,返回值应为调用结果或为 void 类型
|
||||
if (retInst->getReturnValue() == callInst ||
|
||||
(retInst->getReturnValue() == nullptr && callInst->getType()->isVoid())) {
|
||||
tailCallInsts.push_back(callInst);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (tailCallInsts.empty()) {
|
||||
return false;
|
||||
}
|
||||
|
||||
// 创建一个新的入口基本块,作为循环的前置块
|
||||
auto original_entry = F->getEntryBlock();
|
||||
auto new_entry = F->addBasicBlock("tco.entry." + F->getName());
|
||||
auto loop_header = F->addBasicBlock("tco.loop_header." + F->getName());
|
||||
|
||||
// 将原入口块中的所有指令移动到循环头块
|
||||
loop_header->getInstructions().splice(loop_header->end(), original_entry->getInstructions());
|
||||
original_entry->setName("tco.pre_header");
|
||||
|
||||
// 为函数参数创建 phi 节点
|
||||
builder->setPosition(loop_header, loop_header->begin());
|
||||
std::vector<PhiInst *> phis;
|
||||
auto original_args = F->getArguments();
|
||||
for (auto &arg : original_args) {
|
||||
auto phi = builder->createPhiInst(arg->getType(), {}, {}, "tco.phi."+arg->getName());
|
||||
phis.push_back(phi);
|
||||
}
|
||||
|
||||
// 用 phi 节点替换所有原始参数的使用
|
||||
for (size_t i = 0; i < original_args.size(); ++i) {
|
||||
original_args[i]->replaceAllUsesWith(phis[i]);
|
||||
}
|
||||
|
||||
// 设置 phi 节点的输入值
|
||||
for (size_t i = 0; i < phis.size(); ++i) {
|
||||
phis[i]->addIncoming(original_args[i], new_entry);
|
||||
}
|
||||
|
||||
// 连接各个基本块
|
||||
builder->setPosition(original_entry, original_entry->end());
|
||||
builder->createUncondBrInst(new_entry);
|
||||
original_entry->addSuccessor(new_entry);
|
||||
|
||||
builder->setPosition(new_entry, new_entry->end());
|
||||
builder->createUncondBrInst(loop_header);
|
||||
new_entry->addSuccessor(loop_header);
|
||||
loop_header->addPredecessor(new_entry);
|
||||
|
||||
// 处理每一个尾递归调用
|
||||
for (auto callInst : tailCallInsts) {
|
||||
auto tail_call_block = callInst->getParent();
|
||||
|
||||
// 收集尾递归调用的参数
|
||||
auto args_range = callInst->getArguments();
|
||||
std::vector<Value*> args;
|
||||
std::transform(args_range.begin(), args_range.end(), std::back_inserter(args),
|
||||
[](auto& use_ptr){ return use_ptr->getValue(); });
|
||||
|
||||
// 用新的参数值更新 phi 节点
|
||||
for (size_t i = 0; i < phis.size(); ++i) {
|
||||
phis[i]->addIncoming(args[i], tail_call_block);
|
||||
}
|
||||
|
||||
// 移除原有的调用和返回指令
|
||||
auto term_iter = tail_call_block->terminator();
|
||||
SysYIROptUtils::usedelete(term_iter);
|
||||
auto call_iter = tail_call_block->findInstIterator(callInst);
|
||||
SysYIROptUtils::usedelete(call_iter);
|
||||
|
||||
// 添加跳转回循环头块的分支指令
|
||||
builder->setPosition(tail_call_block, tail_call_block->end());
|
||||
builder->createUncondBrInst(loop_header);
|
||||
tail_call_block->addSuccessor(loop_header);
|
||||
loop_header->addPredecessor(tail_call_block);
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
} // namespace sysy
|
||||
@ -18,6 +18,8 @@
|
||||
#include "LICM.h"
|
||||
#include "LoopStrengthReduction.h"
|
||||
#include "InductionVariableElimination.h"
|
||||
#include "GlobalStrengthReduction.h"
|
||||
#include "TailCallOpt.h"
|
||||
#include "Pass.h"
|
||||
#include <iostream>
|
||||
#include <queue>
|
||||
@ -77,7 +79,10 @@ void PassManager::runOptimizationPipeline(Module* moduleIR, IRBuilder* builderIR
|
||||
registerOptimizationPass<LICM>(builderIR);
|
||||
registerOptimizationPass<LoopStrengthReduction>(builderIR);
|
||||
registerOptimizationPass<InductionVariableElimination>();
|
||||
|
||||
registerOptimizationPass<GlobalStrengthReduction>(builderIR);
|
||||
registerOptimizationPass<Reg2Mem>(builderIR);
|
||||
registerOptimizationPass<TailCallOpt>(builderIR);
|
||||
|
||||
registerOptimizationPass<SCCP>(builderIR);
|
||||
|
||||
@ -136,6 +141,16 @@ void PassManager::runOptimizationPipeline(Module* moduleIR, IRBuilder* builderIR
|
||||
this->addPass(&GVN::ID);
|
||||
this->run();
|
||||
|
||||
this->clearPasses();
|
||||
this->addPass(&TailCallOpt::ID);
|
||||
this->run();
|
||||
|
||||
if(DEBUG) {
|
||||
std::cout << "=== IR After TailCallOpt ===\n";
|
||||
SysYPrinter printer(moduleIR);
|
||||
printer.printIR();
|
||||
}
|
||||
|
||||
if(DEBUG) {
|
||||
std::cout << "=== IR After GVN Optimizations ===\n";
|
||||
printPasses();
|
||||
@ -179,9 +194,19 @@ void PassManager::runOptimizationPipeline(Module* moduleIR, IRBuilder* builderIR
|
||||
printPasses();
|
||||
}
|
||||
|
||||
// this->clearPasses();
|
||||
// this->addPass(&Reg2Mem::ID);
|
||||
// this->run();
|
||||
// 全局强度削弱优化,包括代数优化和魔数除法
|
||||
this->clearPasses();
|
||||
this->addPass(&GlobalStrengthReduction::ID);
|
||||
this->run();
|
||||
|
||||
if(DEBUG) {
|
||||
std::cout << "=== IR After Global Strength Reduction Optimizations ===\n";
|
||||
printPasses();
|
||||
}
|
||||
|
||||
this->clearPasses();
|
||||
this->addPass(&Reg2Mem::ID);
|
||||
this->run();
|
||||
|
||||
if(DEBUG) {
|
||||
std::cout << "=== IR After Reg2Mem Optimizations ===\n";
|
||||
|
||||
Reference in New Issue
Block a user