Compare commits
1 Commits
deploy-202
...
constPropa
| Author | SHA1 | Date | |
|---|---|---|---|
| a5318a2c5c |
14
src/include/midend/Pass/Optimize/ConstPropagation.h
Normal file
14
src/include/midend/Pass/Optimize/ConstPropagation.h
Normal file
@ -0,0 +1,14 @@
|
||||
#pragma once
|
||||
|
||||
#include "Pass.h"
|
||||
|
||||
namespace sysy {
|
||||
|
||||
class ConstPropagation : public OptimizationPass {
|
||||
public:
|
||||
ConstPropagation() : OptimizationPass("ConstPropagation", Granularity::Function) {}
|
||||
bool runOnFunction(Function *F, AnalysisManager& AM) override;
|
||||
static char ID;
|
||||
};
|
||||
|
||||
} // namespace sysy
|
||||
@ -10,6 +10,7 @@ add_library(midend_lib STATIC
|
||||
Pass/Optimize/Mem2Reg.cpp
|
||||
Pass/Optimize/Reg2Mem.cpp
|
||||
Pass/Optimize/SysYIRCFGOpt.cpp
|
||||
Pass/Optimize/ConstPropagation.cpp
|
||||
)
|
||||
|
||||
# 包含中端模块所需的头文件路径
|
||||
|
||||
241
src/midend/Pass/Optimize/ConstPropagation.cpp
Normal file
241
src/midend/Pass/Optimize/ConstPropagation.cpp
Normal file
@ -0,0 +1,241 @@
|
||||
#include "Pass/Optimize/ConstPropagation.h"
|
||||
#include "IR.h"
|
||||
#include "Pass.h"
|
||||
#include <climits>
|
||||
#include <cmath>
|
||||
|
||||
namespace sysy {
|
||||
|
||||
char ConstPropagation::ID = 0;
|
||||
|
||||
bool ConstPropagation::runOnFunction(Function *func, AnalysisManager &am) {
|
||||
bool changed = false;
|
||||
bool localChanged = true;
|
||||
|
||||
while (localChanged) {
|
||||
localChanged = false;
|
||||
|
||||
for (auto &bb : func->getBasicBlocks()) {
|
||||
for (auto instIter = bb->getInstructions().begin();
|
||||
instIter != bb->getInstructions().end();) {
|
||||
auto &inst = *instIter;
|
||||
bool shouldAdvanceIter = true;
|
||||
|
||||
// 处理二元运算指令
|
||||
if (auto *binaryInst = dynamic_cast<BinaryInst *>(inst.get())) {
|
||||
auto *lhs = binaryInst->getLhs();
|
||||
auto *rhs = binaryInst->getRhs();
|
||||
|
||||
auto *lhsConst = dynamic_cast<ConstantValue *>(lhs);
|
||||
auto *rhsConst = dynamic_cast<ConstantValue *>(rhs);
|
||||
|
||||
if (lhsConst && rhsConst) {
|
||||
ConstantValue *newConst = nullptr;
|
||||
|
||||
try {
|
||||
if (lhs->isInt() && rhs->isInt()) {
|
||||
int l = lhsConst->getInt();
|
||||
int r = rhsConst->getInt();
|
||||
int result;
|
||||
bool validOperation = true;
|
||||
|
||||
switch (binaryInst->getKind()) {
|
||||
case Instruction::kAdd:
|
||||
// 检查加法溢出
|
||||
if ((r > 0 && l > INT_MAX - r) || (r < 0 && l < INT_MIN - r)) {
|
||||
validOperation = false;
|
||||
} else {
|
||||
result = l + r;
|
||||
}
|
||||
break;
|
||||
case Instruction::kSub:
|
||||
// 检查减法溢出
|
||||
if ((r < 0 && l > INT_MAX + r) || (r > 0 && l < INT_MIN + r)) {
|
||||
validOperation = false;
|
||||
} else {
|
||||
result = l - r;
|
||||
}
|
||||
break;
|
||||
case Instruction::kMul:
|
||||
// 检查乘法溢出
|
||||
if (l != 0 && r != 0 &&
|
||||
(std::abs(l) > INT_MAX / std::abs(r))) {
|
||||
validOperation = false;
|
||||
} else {
|
||||
result = l * r;
|
||||
}
|
||||
break;
|
||||
case Instruction::kDiv:
|
||||
if (r == 0) {
|
||||
validOperation = false;
|
||||
} else {
|
||||
result = l / r;
|
||||
}
|
||||
break;
|
||||
case Instruction::kRem:
|
||||
if (r == 0) {
|
||||
validOperation = false;
|
||||
} else {
|
||||
result = l % r;
|
||||
}
|
||||
break;
|
||||
case Instruction::kICmpEQ: result = (l == r) ? 1 : 0; break;
|
||||
case Instruction::kICmpNE: result = (l != r) ? 1 : 0; break;
|
||||
case Instruction::kICmpLT: result = (l < r) ? 1 : 0; break;
|
||||
case Instruction::kICmpGT: result = (l > r) ? 1 : 0; break;
|
||||
case Instruction::kICmpLE: result = (l <= r) ? 1 : 0; break;
|
||||
case Instruction::kICmpGE: result = (l >= r) ? 1 : 0; break;
|
||||
case Instruction::kAnd: result = (l && r) ? 1 : 0; break;
|
||||
case Instruction::kOr: result = (l || r) ? 1 : 0; break;
|
||||
default:
|
||||
validOperation = false;
|
||||
}
|
||||
|
||||
if (validOperation) {
|
||||
if (binaryInst->isCmp() || binaryInst->getKind() == Instruction::kAnd ||
|
||||
binaryInst->getKind() == Instruction::kOr) {
|
||||
newConst = ConstantInteger::get(Type::getIntType(), result);
|
||||
} else {
|
||||
newConst = ConstantInteger::get(result);
|
||||
}
|
||||
}
|
||||
} else if (lhs->isFloat() && rhs->isFloat()) {
|
||||
float l = lhsConst->getFloat();
|
||||
float r = rhsConst->getFloat();
|
||||
bool validOperation = true;
|
||||
|
||||
switch (binaryInst->getKind()) {
|
||||
case Instruction::kFAdd: {
|
||||
float result = l + r;
|
||||
if (std::isfinite(result)) {
|
||||
newConst = ConstantFloating::get(result);
|
||||
} else {
|
||||
validOperation = false;
|
||||
}
|
||||
break;
|
||||
}
|
||||
case Instruction::kFSub: {
|
||||
float result = l - r;
|
||||
if (std::isfinite(result)) {
|
||||
newConst = ConstantFloating::get(result);
|
||||
} else {
|
||||
validOperation = false;
|
||||
}
|
||||
break;
|
||||
}
|
||||
case Instruction::kFMul: {
|
||||
float result = l * r;
|
||||
if (std::isfinite(result)) {
|
||||
newConst = ConstantFloating::get(result);
|
||||
} else {
|
||||
validOperation = false;
|
||||
}
|
||||
break;
|
||||
}
|
||||
case Instruction::kFDiv: {
|
||||
if (std::abs(r) < std::numeric_limits<float>::epsilon()) {
|
||||
validOperation = false;
|
||||
} else {
|
||||
float result = l / r;
|
||||
if (std::isfinite(result)) {
|
||||
newConst = ConstantFloating::get(result);
|
||||
} else {
|
||||
validOperation = false;
|
||||
}
|
||||
}
|
||||
break;
|
||||
}
|
||||
case Instruction::kFCmpEQ:
|
||||
newConst = ConstantInteger::get(Type::getIntType(), (l == r) ? 1 : 0);
|
||||
break;
|
||||
case Instruction::kFCmpNE:
|
||||
newConst = ConstantInteger::get(Type::getIntType(), (l != r) ? 1 : 0);
|
||||
break;
|
||||
case Instruction::kFCmpLT:
|
||||
newConst = ConstantInteger::get(Type::getIntType(), (l < r) ? 1 : 0);
|
||||
break;
|
||||
case Instruction::kFCmpGT:
|
||||
newConst = ConstantInteger::get(Type::getIntType(), (l > r) ? 1 : 0);
|
||||
break;
|
||||
case Instruction::kFCmpLE:
|
||||
newConst = ConstantInteger::get(Type::getIntType(), (l <= r) ? 1 : 0);
|
||||
break;
|
||||
case Instruction::kFCmpGE:
|
||||
newConst = ConstantInteger::get(Type::getIntType(), (l >= r) ? 1 : 0);
|
||||
break;
|
||||
default:
|
||||
validOperation = false;
|
||||
}
|
||||
}
|
||||
} catch (...) {
|
||||
// 捕获可能的异常,跳过优化
|
||||
newConst = nullptr;
|
||||
}
|
||||
|
||||
if (newConst) {
|
||||
binaryInst->replaceAllUsesWith(newConst);
|
||||
instIter = bb->getInstructions().erase(instIter);
|
||||
shouldAdvanceIter = false;
|
||||
localChanged = true;
|
||||
}
|
||||
}
|
||||
}
|
||||
// 处理一元运算指令
|
||||
else if (auto *unaryInst = dynamic_cast<UnaryInst *>(inst.get())) {
|
||||
auto *operand = unaryInst->getOperand();
|
||||
auto *operandConst = dynamic_cast<ConstantValue *>(operand);
|
||||
|
||||
if (operandConst) {
|
||||
ConstantValue *newConst = nullptr;
|
||||
|
||||
if (operand->isInt()) {
|
||||
int val = operandConst->getInt();
|
||||
|
||||
switch (unaryInst->getKind()) {
|
||||
case Instruction::kNeg:
|
||||
if (val != INT_MIN) { // 避免溢出
|
||||
newConst = ConstantInteger::get(-val);
|
||||
}
|
||||
break;
|
||||
case Instruction::kNot:
|
||||
newConst = ConstantInteger::get(Type::getIntType(), (!val) ? 1 : 0);
|
||||
break;
|
||||
default:
|
||||
break;
|
||||
}
|
||||
} else if (operand->isFloat()) {
|
||||
float val = operandConst->getFloat();
|
||||
|
||||
switch (unaryInst->getKind()) {
|
||||
case Instruction::kFNeg:
|
||||
newConst = ConstantFloating::get(-val);
|
||||
break;
|
||||
default:
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
if (newConst) {
|
||||
unaryInst->replaceAllUsesWith(newConst);
|
||||
instIter = bb->getInstructions().erase(instIter);
|
||||
shouldAdvanceIter = false;
|
||||
localChanged = true;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (shouldAdvanceIter) {
|
||||
++instIter;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (localChanged) {
|
||||
changed = true;
|
||||
}
|
||||
}
|
||||
|
||||
return changed;
|
||||
}
|
||||
|
||||
} // namespace sysy
|
||||
@ -5,6 +5,7 @@
|
||||
#include "DCE.h"
|
||||
#include "Mem2Reg.h"
|
||||
#include "Reg2Mem.h"
|
||||
#include "ConstPropagation.h"
|
||||
#include "Pass.h"
|
||||
#include <iostream>
|
||||
#include <queue>
|
||||
@ -80,6 +81,7 @@ void PassManager::runOptimizationPipeline(Module* moduleIR, IRBuilder* builderIR
|
||||
|
||||
this->clearPasses();
|
||||
this->addPass(&Mem2Reg::ID);
|
||||
this->addPass(&ConstPropagation::ID);
|
||||
this->run();
|
||||
|
||||
if(DEBUG) {
|
||||
|
||||
Reference in New Issue
Block a user