Compare commits

...

1 Commits

Author SHA1 Message Date
a5318a2c5c 为中端加入常量传播Pass 2025-07-31 20:46:35 +08:00
4 changed files with 258 additions and 0 deletions

View File

@ -0,0 +1,14 @@
#pragma once
#include "Pass.h"
namespace sysy {
class ConstPropagation : public OptimizationPass {
public:
ConstPropagation() : OptimizationPass("ConstPropagation", Granularity::Function) {}
bool runOnFunction(Function *F, AnalysisManager& AM) override;
static char ID;
};
} // namespace sysy

View File

@ -10,6 +10,7 @@ add_library(midend_lib STATIC
Pass/Optimize/Mem2Reg.cpp
Pass/Optimize/Reg2Mem.cpp
Pass/Optimize/SysYIRCFGOpt.cpp
Pass/Optimize/ConstPropagation.cpp
)
# 包含中端模块所需的头文件路径

View File

@ -0,0 +1,241 @@
#include "Pass/Optimize/ConstPropagation.h"
#include "IR.h"
#include "Pass.h"
#include <climits>
#include <cmath>
namespace sysy {
char ConstPropagation::ID = 0;
bool ConstPropagation::runOnFunction(Function *func, AnalysisManager &am) {
bool changed = false;
bool localChanged = true;
while (localChanged) {
localChanged = false;
for (auto &bb : func->getBasicBlocks()) {
for (auto instIter = bb->getInstructions().begin();
instIter != bb->getInstructions().end();) {
auto &inst = *instIter;
bool shouldAdvanceIter = true;
// 处理二元运算指令
if (auto *binaryInst = dynamic_cast<BinaryInst *>(inst.get())) {
auto *lhs = binaryInst->getLhs();
auto *rhs = binaryInst->getRhs();
auto *lhsConst = dynamic_cast<ConstantValue *>(lhs);
auto *rhsConst = dynamic_cast<ConstantValue *>(rhs);
if (lhsConst && rhsConst) {
ConstantValue *newConst = nullptr;
try {
if (lhs->isInt() && rhs->isInt()) {
int l = lhsConst->getInt();
int r = rhsConst->getInt();
int result;
bool validOperation = true;
switch (binaryInst->getKind()) {
case Instruction::kAdd:
// 检查加法溢出
if ((r > 0 && l > INT_MAX - r) || (r < 0 && l < INT_MIN - r)) {
validOperation = false;
} else {
result = l + r;
}
break;
case Instruction::kSub:
// 检查减法溢出
if ((r < 0 && l > INT_MAX + r) || (r > 0 && l < INT_MIN + r)) {
validOperation = false;
} else {
result = l - r;
}
break;
case Instruction::kMul:
// 检查乘法溢出
if (l != 0 && r != 0 &&
(std::abs(l) > INT_MAX / std::abs(r))) {
validOperation = false;
} else {
result = l * r;
}
break;
case Instruction::kDiv:
if (r == 0) {
validOperation = false;
} else {
result = l / r;
}
break;
case Instruction::kRem:
if (r == 0) {
validOperation = false;
} else {
result = l % r;
}
break;
case Instruction::kICmpEQ: result = (l == r) ? 1 : 0; break;
case Instruction::kICmpNE: result = (l != r) ? 1 : 0; break;
case Instruction::kICmpLT: result = (l < r) ? 1 : 0; break;
case Instruction::kICmpGT: result = (l > r) ? 1 : 0; break;
case Instruction::kICmpLE: result = (l <= r) ? 1 : 0; break;
case Instruction::kICmpGE: result = (l >= r) ? 1 : 0; break;
case Instruction::kAnd: result = (l && r) ? 1 : 0; break;
case Instruction::kOr: result = (l || r) ? 1 : 0; break;
default:
validOperation = false;
}
if (validOperation) {
if (binaryInst->isCmp() || binaryInst->getKind() == Instruction::kAnd ||
binaryInst->getKind() == Instruction::kOr) {
newConst = ConstantInteger::get(Type::getIntType(), result);
} else {
newConst = ConstantInteger::get(result);
}
}
} else if (lhs->isFloat() && rhs->isFloat()) {
float l = lhsConst->getFloat();
float r = rhsConst->getFloat();
bool validOperation = true;
switch (binaryInst->getKind()) {
case Instruction::kFAdd: {
float result = l + r;
if (std::isfinite(result)) {
newConst = ConstantFloating::get(result);
} else {
validOperation = false;
}
break;
}
case Instruction::kFSub: {
float result = l - r;
if (std::isfinite(result)) {
newConst = ConstantFloating::get(result);
} else {
validOperation = false;
}
break;
}
case Instruction::kFMul: {
float result = l * r;
if (std::isfinite(result)) {
newConst = ConstantFloating::get(result);
} else {
validOperation = false;
}
break;
}
case Instruction::kFDiv: {
if (std::abs(r) < std::numeric_limits<float>::epsilon()) {
validOperation = false;
} else {
float result = l / r;
if (std::isfinite(result)) {
newConst = ConstantFloating::get(result);
} else {
validOperation = false;
}
}
break;
}
case Instruction::kFCmpEQ:
newConst = ConstantInteger::get(Type::getIntType(), (l == r) ? 1 : 0);
break;
case Instruction::kFCmpNE:
newConst = ConstantInteger::get(Type::getIntType(), (l != r) ? 1 : 0);
break;
case Instruction::kFCmpLT:
newConst = ConstantInteger::get(Type::getIntType(), (l < r) ? 1 : 0);
break;
case Instruction::kFCmpGT:
newConst = ConstantInteger::get(Type::getIntType(), (l > r) ? 1 : 0);
break;
case Instruction::kFCmpLE:
newConst = ConstantInteger::get(Type::getIntType(), (l <= r) ? 1 : 0);
break;
case Instruction::kFCmpGE:
newConst = ConstantInteger::get(Type::getIntType(), (l >= r) ? 1 : 0);
break;
default:
validOperation = false;
}
}
} catch (...) {
// 捕获可能的异常,跳过优化
newConst = nullptr;
}
if (newConst) {
binaryInst->replaceAllUsesWith(newConst);
instIter = bb->getInstructions().erase(instIter);
shouldAdvanceIter = false;
localChanged = true;
}
}
}
// 处理一元运算指令
else if (auto *unaryInst = dynamic_cast<UnaryInst *>(inst.get())) {
auto *operand = unaryInst->getOperand();
auto *operandConst = dynamic_cast<ConstantValue *>(operand);
if (operandConst) {
ConstantValue *newConst = nullptr;
if (operand->isInt()) {
int val = operandConst->getInt();
switch (unaryInst->getKind()) {
case Instruction::kNeg:
if (val != INT_MIN) { // 避免溢出
newConst = ConstantInteger::get(-val);
}
break;
case Instruction::kNot:
newConst = ConstantInteger::get(Type::getIntType(), (!val) ? 1 : 0);
break;
default:
break;
}
} else if (operand->isFloat()) {
float val = operandConst->getFloat();
switch (unaryInst->getKind()) {
case Instruction::kFNeg:
newConst = ConstantFloating::get(-val);
break;
default:
break;
}
}
if (newConst) {
unaryInst->replaceAllUsesWith(newConst);
instIter = bb->getInstructions().erase(instIter);
shouldAdvanceIter = false;
localChanged = true;
}
}
}
if (shouldAdvanceIter) {
++instIter;
}
}
}
if (localChanged) {
changed = true;
}
}
return changed;
}
} // namespace sysy

View File

@ -5,6 +5,7 @@
#include "DCE.h"
#include "Mem2Reg.h"
#include "Reg2Mem.h"
#include "ConstPropagation.h"
#include "Pass.h"
#include <iostream>
#include <queue>
@ -80,6 +81,7 @@ void PassManager::runOptimizationPipeline(Module* moduleIR, IRBuilder* builderIR
this->clearPasses();
this->addPass(&Mem2Reg::ID);
this->addPass(&ConstPropagation::ID);
this->run();
if(DEBUG) {