Compare commits
1 Commits
backend-IR
...
constPropa
| Author | SHA1 | Date | |
|---|---|---|---|
| a5318a2c5c |
14
src/include/midend/Pass/Optimize/ConstPropagation.h
Normal file
14
src/include/midend/Pass/Optimize/ConstPropagation.h
Normal file
@ -0,0 +1,14 @@
|
|||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include "Pass.h"
|
||||||
|
|
||||||
|
namespace sysy {
|
||||||
|
|
||||||
|
class ConstPropagation : public OptimizationPass {
|
||||||
|
public:
|
||||||
|
ConstPropagation() : OptimizationPass("ConstPropagation", Granularity::Function) {}
|
||||||
|
bool runOnFunction(Function *F, AnalysisManager& AM) override;
|
||||||
|
static char ID;
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace sysy
|
||||||
@ -10,6 +10,7 @@ add_library(midend_lib STATIC
|
|||||||
Pass/Optimize/Mem2Reg.cpp
|
Pass/Optimize/Mem2Reg.cpp
|
||||||
Pass/Optimize/Reg2Mem.cpp
|
Pass/Optimize/Reg2Mem.cpp
|
||||||
Pass/Optimize/SysYIRCFGOpt.cpp
|
Pass/Optimize/SysYIRCFGOpt.cpp
|
||||||
|
Pass/Optimize/ConstPropagation.cpp
|
||||||
)
|
)
|
||||||
|
|
||||||
# 包含中端模块所需的头文件路径
|
# 包含中端模块所需的头文件路径
|
||||||
|
|||||||
241
src/midend/Pass/Optimize/ConstPropagation.cpp
Normal file
241
src/midend/Pass/Optimize/ConstPropagation.cpp
Normal file
@ -0,0 +1,241 @@
|
|||||||
|
#include "Pass/Optimize/ConstPropagation.h"
|
||||||
|
#include "IR.h"
|
||||||
|
#include "Pass.h"
|
||||||
|
#include <climits>
|
||||||
|
#include <cmath>
|
||||||
|
|
||||||
|
namespace sysy {
|
||||||
|
|
||||||
|
char ConstPropagation::ID = 0;
|
||||||
|
|
||||||
|
bool ConstPropagation::runOnFunction(Function *func, AnalysisManager &am) {
|
||||||
|
bool changed = false;
|
||||||
|
bool localChanged = true;
|
||||||
|
|
||||||
|
while (localChanged) {
|
||||||
|
localChanged = false;
|
||||||
|
|
||||||
|
for (auto &bb : func->getBasicBlocks()) {
|
||||||
|
for (auto instIter = bb->getInstructions().begin();
|
||||||
|
instIter != bb->getInstructions().end();) {
|
||||||
|
auto &inst = *instIter;
|
||||||
|
bool shouldAdvanceIter = true;
|
||||||
|
|
||||||
|
// 处理二元运算指令
|
||||||
|
if (auto *binaryInst = dynamic_cast<BinaryInst *>(inst.get())) {
|
||||||
|
auto *lhs = binaryInst->getLhs();
|
||||||
|
auto *rhs = binaryInst->getRhs();
|
||||||
|
|
||||||
|
auto *lhsConst = dynamic_cast<ConstantValue *>(lhs);
|
||||||
|
auto *rhsConst = dynamic_cast<ConstantValue *>(rhs);
|
||||||
|
|
||||||
|
if (lhsConst && rhsConst) {
|
||||||
|
ConstantValue *newConst = nullptr;
|
||||||
|
|
||||||
|
try {
|
||||||
|
if (lhs->isInt() && rhs->isInt()) {
|
||||||
|
int l = lhsConst->getInt();
|
||||||
|
int r = rhsConst->getInt();
|
||||||
|
int result;
|
||||||
|
bool validOperation = true;
|
||||||
|
|
||||||
|
switch (binaryInst->getKind()) {
|
||||||
|
case Instruction::kAdd:
|
||||||
|
// 检查加法溢出
|
||||||
|
if ((r > 0 && l > INT_MAX - r) || (r < 0 && l < INT_MIN - r)) {
|
||||||
|
validOperation = false;
|
||||||
|
} else {
|
||||||
|
result = l + r;
|
||||||
|
}
|
||||||
|
break;
|
||||||
|
case Instruction::kSub:
|
||||||
|
// 检查减法溢出
|
||||||
|
if ((r < 0 && l > INT_MAX + r) || (r > 0 && l < INT_MIN + r)) {
|
||||||
|
validOperation = false;
|
||||||
|
} else {
|
||||||
|
result = l - r;
|
||||||
|
}
|
||||||
|
break;
|
||||||
|
case Instruction::kMul:
|
||||||
|
// 检查乘法溢出
|
||||||
|
if (l != 0 && r != 0 &&
|
||||||
|
(std::abs(l) > INT_MAX / std::abs(r))) {
|
||||||
|
validOperation = false;
|
||||||
|
} else {
|
||||||
|
result = l * r;
|
||||||
|
}
|
||||||
|
break;
|
||||||
|
case Instruction::kDiv:
|
||||||
|
if (r == 0) {
|
||||||
|
validOperation = false;
|
||||||
|
} else {
|
||||||
|
result = l / r;
|
||||||
|
}
|
||||||
|
break;
|
||||||
|
case Instruction::kRem:
|
||||||
|
if (r == 0) {
|
||||||
|
validOperation = false;
|
||||||
|
} else {
|
||||||
|
result = l % r;
|
||||||
|
}
|
||||||
|
break;
|
||||||
|
case Instruction::kICmpEQ: result = (l == r) ? 1 : 0; break;
|
||||||
|
case Instruction::kICmpNE: result = (l != r) ? 1 : 0; break;
|
||||||
|
case Instruction::kICmpLT: result = (l < r) ? 1 : 0; break;
|
||||||
|
case Instruction::kICmpGT: result = (l > r) ? 1 : 0; break;
|
||||||
|
case Instruction::kICmpLE: result = (l <= r) ? 1 : 0; break;
|
||||||
|
case Instruction::kICmpGE: result = (l >= r) ? 1 : 0; break;
|
||||||
|
case Instruction::kAnd: result = (l && r) ? 1 : 0; break;
|
||||||
|
case Instruction::kOr: result = (l || r) ? 1 : 0; break;
|
||||||
|
default:
|
||||||
|
validOperation = false;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (validOperation) {
|
||||||
|
if (binaryInst->isCmp() || binaryInst->getKind() == Instruction::kAnd ||
|
||||||
|
binaryInst->getKind() == Instruction::kOr) {
|
||||||
|
newConst = ConstantInteger::get(Type::getIntType(), result);
|
||||||
|
} else {
|
||||||
|
newConst = ConstantInteger::get(result);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else if (lhs->isFloat() && rhs->isFloat()) {
|
||||||
|
float l = lhsConst->getFloat();
|
||||||
|
float r = rhsConst->getFloat();
|
||||||
|
bool validOperation = true;
|
||||||
|
|
||||||
|
switch (binaryInst->getKind()) {
|
||||||
|
case Instruction::kFAdd: {
|
||||||
|
float result = l + r;
|
||||||
|
if (std::isfinite(result)) {
|
||||||
|
newConst = ConstantFloating::get(result);
|
||||||
|
} else {
|
||||||
|
validOperation = false;
|
||||||
|
}
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
case Instruction::kFSub: {
|
||||||
|
float result = l - r;
|
||||||
|
if (std::isfinite(result)) {
|
||||||
|
newConst = ConstantFloating::get(result);
|
||||||
|
} else {
|
||||||
|
validOperation = false;
|
||||||
|
}
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
case Instruction::kFMul: {
|
||||||
|
float result = l * r;
|
||||||
|
if (std::isfinite(result)) {
|
||||||
|
newConst = ConstantFloating::get(result);
|
||||||
|
} else {
|
||||||
|
validOperation = false;
|
||||||
|
}
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
case Instruction::kFDiv: {
|
||||||
|
if (std::abs(r) < std::numeric_limits<float>::epsilon()) {
|
||||||
|
validOperation = false;
|
||||||
|
} else {
|
||||||
|
float result = l / r;
|
||||||
|
if (std::isfinite(result)) {
|
||||||
|
newConst = ConstantFloating::get(result);
|
||||||
|
} else {
|
||||||
|
validOperation = false;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
case Instruction::kFCmpEQ:
|
||||||
|
newConst = ConstantInteger::get(Type::getIntType(), (l == r) ? 1 : 0);
|
||||||
|
break;
|
||||||
|
case Instruction::kFCmpNE:
|
||||||
|
newConst = ConstantInteger::get(Type::getIntType(), (l != r) ? 1 : 0);
|
||||||
|
break;
|
||||||
|
case Instruction::kFCmpLT:
|
||||||
|
newConst = ConstantInteger::get(Type::getIntType(), (l < r) ? 1 : 0);
|
||||||
|
break;
|
||||||
|
case Instruction::kFCmpGT:
|
||||||
|
newConst = ConstantInteger::get(Type::getIntType(), (l > r) ? 1 : 0);
|
||||||
|
break;
|
||||||
|
case Instruction::kFCmpLE:
|
||||||
|
newConst = ConstantInteger::get(Type::getIntType(), (l <= r) ? 1 : 0);
|
||||||
|
break;
|
||||||
|
case Instruction::kFCmpGE:
|
||||||
|
newConst = ConstantInteger::get(Type::getIntType(), (l >= r) ? 1 : 0);
|
||||||
|
break;
|
||||||
|
default:
|
||||||
|
validOperation = false;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} catch (...) {
|
||||||
|
// 捕获可能的异常,跳过优化
|
||||||
|
newConst = nullptr;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (newConst) {
|
||||||
|
binaryInst->replaceAllUsesWith(newConst);
|
||||||
|
instIter = bb->getInstructions().erase(instIter);
|
||||||
|
shouldAdvanceIter = false;
|
||||||
|
localChanged = true;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// 处理一元运算指令
|
||||||
|
else if (auto *unaryInst = dynamic_cast<UnaryInst *>(inst.get())) {
|
||||||
|
auto *operand = unaryInst->getOperand();
|
||||||
|
auto *operandConst = dynamic_cast<ConstantValue *>(operand);
|
||||||
|
|
||||||
|
if (operandConst) {
|
||||||
|
ConstantValue *newConst = nullptr;
|
||||||
|
|
||||||
|
if (operand->isInt()) {
|
||||||
|
int val = operandConst->getInt();
|
||||||
|
|
||||||
|
switch (unaryInst->getKind()) {
|
||||||
|
case Instruction::kNeg:
|
||||||
|
if (val != INT_MIN) { // 避免溢出
|
||||||
|
newConst = ConstantInteger::get(-val);
|
||||||
|
}
|
||||||
|
break;
|
||||||
|
case Instruction::kNot:
|
||||||
|
newConst = ConstantInteger::get(Type::getIntType(), (!val) ? 1 : 0);
|
||||||
|
break;
|
||||||
|
default:
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
} else if (operand->isFloat()) {
|
||||||
|
float val = operandConst->getFloat();
|
||||||
|
|
||||||
|
switch (unaryInst->getKind()) {
|
||||||
|
case Instruction::kFNeg:
|
||||||
|
newConst = ConstantFloating::get(-val);
|
||||||
|
break;
|
||||||
|
default:
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if (newConst) {
|
||||||
|
unaryInst->replaceAllUsesWith(newConst);
|
||||||
|
instIter = bb->getInstructions().erase(instIter);
|
||||||
|
shouldAdvanceIter = false;
|
||||||
|
localChanged = true;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if (shouldAdvanceIter) {
|
||||||
|
++instIter;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if (localChanged) {
|
||||||
|
changed = true;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return changed;
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace sysy
|
||||||
@ -5,6 +5,7 @@
|
|||||||
#include "DCE.h"
|
#include "DCE.h"
|
||||||
#include "Mem2Reg.h"
|
#include "Mem2Reg.h"
|
||||||
#include "Reg2Mem.h"
|
#include "Reg2Mem.h"
|
||||||
|
#include "ConstPropagation.h"
|
||||||
#include "Pass.h"
|
#include "Pass.h"
|
||||||
#include <iostream>
|
#include <iostream>
|
||||||
#include <queue>
|
#include <queue>
|
||||||
@ -80,6 +81,7 @@ void PassManager::runOptimizationPipeline(Module* moduleIR, IRBuilder* builderIR
|
|||||||
|
|
||||||
this->clearPasses();
|
this->clearPasses();
|
||||||
this->addPass(&Mem2Reg::ID);
|
this->addPass(&Mem2Reg::ID);
|
||||||
|
this->addPass(&ConstPropagation::ID);
|
||||||
this->run();
|
this->run();
|
||||||
|
|
||||||
if(DEBUG) {
|
if(DEBUG) {
|
||||||
|
|||||||
Reference in New Issue
Block a user