[midend-GSR]修复错误的代数简化

This commit is contained in:
rain2133
2025-08-18 21:55:57 +08:00
parent 5c34cbc7b8
commit ad74e435ba

View File

@ -390,8 +390,8 @@ bool GlobalStrengthReductionContext::optimizeLogical(BinaryInst *inst) {
return true;
}
// x && 1 = x
if (isConstantInt(rhs, constVal) && constVal == 1) {
// x && -1 = x
if (isConstantInt(rhs, constVal) && constVal == -1) {
if (DEBUG) {
std::cout << " Algebraic: " << inst->getName() << " = x && 1 -> x" << std::endl;
}
@ -416,15 +416,6 @@ bool GlobalStrengthReductionContext::optimizeLogical(BinaryInst *inst) {
replaceWithOptimized(inst, lhs);
return true;
}
// x || 1 = 1
if (isConstantInt(rhs, constVal) && constVal == 1) {
if (DEBUG) {
std::cout << " Algebraic: " << inst->getName() << " = x || 1 -> 1" << std::endl;
}
replaceWithOptimized(inst, getConstantInt(1));
return true;
}
// x || x = x
if (lhs == rhs) {
@ -630,16 +621,50 @@ bool GlobalStrengthReductionContext::reduceDivision(BinaryInst *inst) {
// x / 2^n = x >> n (对于无符号除法或已知为正数的情况)
if (isConstantInt(rhs, constVal) && constVal > 0 && isPowerOfTwo(constVal)) {
builder->setPosition(inst->getParent(), inst->getParent()->findInstIterator(inst));
int shiftAmount = log2OfPowerOfTwo(constVal);
// 有符号除法校正:(x + (x >> 31) & mask) >> k
int maskValue = constVal - 1;
// x >> 31 (算术右移获取符号位)
Value* signShift = ConstantInteger::get(31);
Value* signBits = builder->createBinaryInst(
Instruction::Kind::kSra, // 算术右移
lhs->getType(),
lhs,
signShift
);
// (x >> 31) & mask
Value* mask = ConstantInteger::get(maskValue);
Value* correction = builder->createBinaryInst(
Instruction::Kind::kAnd,
lhs->getType(),
signBits,
mask
);
// x + correction
Value* corrected = builder->createAddInst(lhs, correction);
// (x + correction) >> k
Value* divShift = ConstantInteger::get(shiftAmount);
Value* shiftInst = builder->createBinaryInst(
Instruction::Kind::kSra, // 算术右移
lhs->getType(),
corrected,
divShift
);
if (DEBUG) {
std::cout << " StrengthReduction: " << inst->getName()
<< " = x / " << constVal << " -> x >> " << shiftAmount << std::endl;
<< " = x / " << constVal << " -> (x + (x >> 31) & mask) >> " << shiftAmount << std::endl;
}
builder->setPosition(inst->getParent(), inst->getParent()->findInstIterator(inst));
Value* divisor_minus_1 = ConstantInteger::get(constVal - 1);
Value* adjusted = builder->createAddInst(lhs, divisor_minus_1);
Value* shiftInst = builder->createBinaryInst(Instruction::kSra, Type::getIntType(), adjusted, getConstantInt(shiftAmount));
// builder->setPosition(inst->getParent(), inst->getParent()->findInstIterator(inst));
// Value* divisor_minus_1 = ConstantInteger::get(constVal - 1);
// Value* adjusted = builder->createAddInst(lhs, divisor_minus_1);
// Value* shiftInst = builder->createBinaryInst(Instruction::kSra, Type::getIntType(), adjusted, getConstantInt(shiftAmount));
replaceWithOptimized(inst, shiftInst);
strengthReductionCount++;
return true;