[midend-GSR]修复错误的代数简化
This commit is contained in:
@ -390,8 +390,8 @@ bool GlobalStrengthReductionContext::optimizeLogical(BinaryInst *inst) {
|
||||
return true;
|
||||
}
|
||||
|
||||
// x && 1 = x
|
||||
if (isConstantInt(rhs, constVal) && constVal == 1) {
|
||||
// x && -1 = x
|
||||
if (isConstantInt(rhs, constVal) && constVal == -1) {
|
||||
if (DEBUG) {
|
||||
std::cout << " Algebraic: " << inst->getName() << " = x && 1 -> x" << std::endl;
|
||||
}
|
||||
@ -416,15 +416,6 @@ bool GlobalStrengthReductionContext::optimizeLogical(BinaryInst *inst) {
|
||||
replaceWithOptimized(inst, lhs);
|
||||
return true;
|
||||
}
|
||||
|
||||
// x || 1 = 1
|
||||
if (isConstantInt(rhs, constVal) && constVal == 1) {
|
||||
if (DEBUG) {
|
||||
std::cout << " Algebraic: " << inst->getName() << " = x || 1 -> 1" << std::endl;
|
||||
}
|
||||
replaceWithOptimized(inst, getConstantInt(1));
|
||||
return true;
|
||||
}
|
||||
|
||||
// x || x = x
|
||||
if (lhs == rhs) {
|
||||
@ -630,16 +621,50 @@ bool GlobalStrengthReductionContext::reduceDivision(BinaryInst *inst) {
|
||||
|
||||
// x / 2^n = x >> n (对于无符号除法或已知为正数的情况)
|
||||
if (isConstantInt(rhs, constVal) && constVal > 0 && isPowerOfTwo(constVal)) {
|
||||
builder->setPosition(inst->getParent(), inst->getParent()->findInstIterator(inst));
|
||||
int shiftAmount = log2OfPowerOfTwo(constVal);
|
||||
// 有符号除法校正:(x + (x >> 31) & mask) >> k
|
||||
int maskValue = constVal - 1;
|
||||
|
||||
// x >> 31 (算术右移获取符号位)
|
||||
Value* signShift = ConstantInteger::get(31);
|
||||
Value* signBits = builder->createBinaryInst(
|
||||
Instruction::Kind::kSra, // 算术右移
|
||||
lhs->getType(),
|
||||
lhs,
|
||||
signShift
|
||||
);
|
||||
|
||||
// (x >> 31) & mask
|
||||
Value* mask = ConstantInteger::get(maskValue);
|
||||
Value* correction = builder->createBinaryInst(
|
||||
Instruction::Kind::kAnd,
|
||||
lhs->getType(),
|
||||
signBits,
|
||||
mask
|
||||
);
|
||||
|
||||
// x + correction
|
||||
Value* corrected = builder->createAddInst(lhs, correction);
|
||||
|
||||
// (x + correction) >> k
|
||||
Value* divShift = ConstantInteger::get(shiftAmount);
|
||||
Value* shiftInst = builder->createBinaryInst(
|
||||
Instruction::Kind::kSra, // 算术右移
|
||||
lhs->getType(),
|
||||
corrected,
|
||||
divShift
|
||||
);
|
||||
|
||||
if (DEBUG) {
|
||||
std::cout << " StrengthReduction: " << inst->getName()
|
||||
<< " = x / " << constVal << " -> x >> " << shiftAmount << std::endl;
|
||||
<< " = x / " << constVal << " -> (x + (x >> 31) & mask) >> " << shiftAmount << std::endl;
|
||||
}
|
||||
|
||||
builder->setPosition(inst->getParent(), inst->getParent()->findInstIterator(inst));
|
||||
Value* divisor_minus_1 = ConstantInteger::get(constVal - 1);
|
||||
Value* adjusted = builder->createAddInst(lhs, divisor_minus_1);
|
||||
Value* shiftInst = builder->createBinaryInst(Instruction::kSra, Type::getIntType(), adjusted, getConstantInt(shiftAmount));
|
||||
// builder->setPosition(inst->getParent(), inst->getParent()->findInstIterator(inst));
|
||||
// Value* divisor_minus_1 = ConstantInteger::get(constVal - 1);
|
||||
// Value* adjusted = builder->createAddInst(lhs, divisor_minus_1);
|
||||
// Value* shiftInst = builder->createBinaryInst(Instruction::kSra, Type::getIntType(), adjusted, getConstantInt(shiftAmount));
|
||||
replaceWithOptimized(inst, shiftInst);
|
||||
strengthReductionCount++;
|
||||
return true;
|
||||
|
||||
Reference in New Issue
Block a user