[optimize]添加更为通用的除法强度削减Pass, 不受除数限制替换div指令,不影响当前分数

This commit is contained in:
2025-08-03 14:37:33 +08:00
parent f312792fe9
commit 0ce742a86e
7 changed files with 175 additions and 187 deletions

View File

@ -1,4 +1,6 @@
#include "DivStrengthReduction.h"
#include <cmath>
#include <cstdint>
namespace sysy {
@ -17,69 +19,49 @@ void DivStrengthReduction::runOnMachineFunction(MachineFunction *mfunc) {
if (debug)
std::cout << "Running DivStrengthReduction optimization..." << std::endl;
// 虚拟寄存器分配器
int next_temp_reg = 1000;
auto createTempReg = [&]() -> int {
return next_temp_reg++;
};
// Magic number 信息结构
struct MagicInfo {
int64_t magic;
int shift;
bool add_indicator; // 是否需要额外的加法修正
};
// 针对缺少MULH指令的简化magic number计算
auto computeMagicNumber = [](int64_t divisor, bool is_32bit) -> MagicInfo {
if (divisor == 0) return {0, 0, false};
if (divisor == 1) return {1, 0, false};
if (divisor == -1) return {-1, 0, false};
// 对于没有MULH的情况我们使用更简单但有效的算法
// 基于 2^n / divisor 的近似
bool neg = divisor < 0;
int64_t d = neg ? -divisor : divisor;
auto computeMagic = [](int64_t d, bool is_32bit) -> MagicInfo {
int word_size = is_32bit ? 32 : 64;
uint64_t ad = std::abs(d);
// 计算合适的移位量
int shift = word_size;
int64_t magic = ((1LL << shift) + d - 1) / d;
if (ad == 0) return {0, 0};
// 调整magic number以适应MUL指令
int l = std::floor(std::log2(ad));
if ((ad & (ad - 1)) == 0) { // power of 2
l = 0; // special case for power of 2, shift will be calculated differently
}
__int128_t one = 1;
__int128_t num;
int total_shift;
if (is_32bit) {
// 32位情况调整magic使其适合符号扩展后的乘法
shift = 32;
magic = ((1LL << shift) + d - 1) / d;
total_shift = 31 + l;
num = one << total_shift;
} else {
// 64位情况使用更保守的算法
shift = 32; // 使用32位作为基础移位
magic = ((1LL << shift) + d - 1) / d;
total_shift = 63 + l;
num = one << total_shift;
}
bool add_indicator = false;
__int128_t den = ad;
int64_t magic = (num / den) + 1;
// 检查是否需要加法修正
if (magic >= (1LL << (word_size - 1))) {
add_indicator = true;
magic -= (1LL << word_size);
}
if (neg) {
magic = -magic;
}
return {magic, shift, add_indicator};
return {magic, total_shift};
};
// 检查是否为2的幂次
auto isPowerOfTwo = [](int64_t n) -> bool {
return n > 0 && (n & (n - 1)) == 0;
};
// 获取2的幂次的指数
auto getPowerOfTwoExponent = [](int64_t n) -> int {
if (n <= 0 || (n & (n - 1)) != 0) return -1;
int shift = 0;
@ -90,9 +72,9 @@ void DivStrengthReduction::runOnMachineFunction(MachineFunction *mfunc) {
return shift;
};
// 收集需要替换的指令
struct InstructionReplacement {
size_t index;
size_t count_to_erase;
std::vector<std::unique_ptr<MachineInstr>> newInstrs;
};
@ -106,7 +88,6 @@ void DivStrengthReduction::runOnMachineFunction(MachineFunction *mfunc) {
bool is_32bit = (instr->getOpcode() == RVOpcodes::DIVW);
// 只处理 DIV 和 DIVW 指令
if (instr->getOpcode() != RVOpcodes::DIV && !is_32bit) {
continue;
}
@ -118,100 +99,74 @@ void DivStrengthReduction::runOnMachineFunction(MachineFunction *mfunc) {
auto *dst_op = instr->getOperands()[0].get();
auto *src1_op = instr->getOperands()[1].get();
auto *src2_op = instr->getOperands()[2].get();
// 检查操作数类型
if (dst_op->getKind() != MachineOperand::KIND_REG ||
src1_op->getKind() != MachineOperand::KIND_REG ||
src2_op->getKind() != MachineOperand::KIND_IMM) {
int64_t divisor = 0;
bool const_divisor_found = false;
size_t instructions_to_replace = 1;
if (src2_op->getKind() == MachineOperand::KIND_IMM) {
divisor = static_cast<ImmOperand *>(src2_op)->getValue();
const_divisor_found = true;
} else if (src2_op->getKind() == MachineOperand::KIND_REG) {
if (i > 0) {
auto *prev_instr = instrs[i - 1].get();
if (prev_instr->getOpcode() == RVOpcodes::LI && prev_instr->getOperands().size() == 2) {
auto *li_dst_op = prev_instr->getOperands()[0].get();
auto *li_imm_op = prev_instr->getOperands()[1].get();
if (li_dst_op->getKind() == MachineOperand::KIND_REG && li_imm_op->getKind() == MachineOperand::KIND_IMM) {
auto *div_reg_op = static_cast<RegOperand *>(src2_op);
auto *li_dst_reg_op = static_cast<RegOperand *>(li_dst_op);
if (div_reg_op->isVirtual() && li_dst_reg_op->isVirtual() &&
div_reg_op->getVRegNum() == li_dst_reg_op->getVRegNum()) {
divisor = static_cast<ImmOperand *>(li_imm_op)->getValue();
const_divisor_found = true;
instructions_to_replace = 2;
}
}
}
}
}
if (!const_divisor_found) {
continue;
}
auto *dst_reg = static_cast<RegOperand *>(dst_op);
auto *src1_reg = static_cast<RegOperand *>(src1_op);
auto *src2_imm = static_cast<ImmOperand *>(src2_op);
int64_t divisor = src2_imm->getValue();
// 跳过除数为0的情况
if (divisor == 0) continue;
std::vector<std::unique_ptr<MachineInstr>> newInstrs;
// 情况1: 除数为1
if (divisor == 1) {
// dst = src1 (直接复制)
auto moveInstr = std::make_unique<MachineInstr>(is_32bit ? RVOpcodes::ADDW : RVOpcodes::ADD);
moveInstr->addOperand(std::make_unique<RegOperand>(*dst_reg));
moveInstr->addOperand(std::make_unique<RegOperand>(*src1_reg));
moveInstr->addOperand(std::make_unique<RegOperand>(PhysicalReg::ZERO));
newInstrs.push_back(std::move(moveInstr));
}
// 情况2: 除数为-1
else if (divisor == -1) {
// dst = -src1
auto negInstr = std::make_unique<MachineInstr>(is_32bit ? RVOpcodes::SUBW : RVOpcodes::SUB);
negInstr->addOperand(std::make_unique<RegOperand>(*dst_reg));
negInstr->addOperand(std::make_unique<RegOperand>(PhysicalReg::ZERO));
negInstr->addOperand(std::make_unique<RegOperand>(*src1_reg));
newInstrs.push_back(std::move(negInstr));
}
// 情况3: 正的2的幂次除法
else if (isPowerOfTwo(divisor)) {
int shift = getPowerOfTwoExponent(divisor);
else if (isPowerOfTwo(std::abs(divisor))) {
int shift = getPowerOfTwoExponent(std::abs(divisor));
int temp_reg = createTempReg();
// 对于有符号除法,需要处理负数的舍入
// if (src1 < 0) src1 += (divisor - 1)
// 获取符号位temp = src1 >> (word_size - 1)
auto sraSignInstr = std::make_unique<MachineInstr>(is_32bit ? RVOpcodes::SRAIW : RVOpcodes::SRAI);
sraSignInstr->addOperand(std::make_unique<RegOperand>(temp_reg));
sraSignInstr->addOperand(std::make_unique<RegOperand>(*src1_reg));
sraSignInstr->addOperand(std::make_unique<ImmOperand>(is_32bit ? 31 : 63));
newInstrs.push_back(std::move(sraSignInstr));
// 计算偏移temp = temp >> (word_size - shift)
if (shift < (is_32bit ? 32 : 64)) {
auto srlInstr = std::make_unique<MachineInstr>(is_32bit ? RVOpcodes::SRLIW : RVOpcodes::SRLI);
srlInstr->addOperand(std::make_unique<RegOperand>(temp_reg));
srlInstr->addOperand(std::make_unique<RegOperand>(temp_reg));
srlInstr->addOperand(std::make_unique<ImmOperand>((is_32bit ? 32 : 64) - shift));
newInstrs.push_back(std::move(srlInstr));
}
// 加上偏移temp = src1 + temp
auto addInstr = std::make_unique<MachineInstr>(is_32bit ? RVOpcodes::ADDW : RVOpcodes::ADD);
addInstr->addOperand(std::make_unique<RegOperand>(temp_reg));
addInstr->addOperand(std::make_unique<RegOperand>(*src1_reg));
addInstr->addOperand(std::make_unique<RegOperand>(temp_reg));
newInstrs.push_back(std::move(addInstr));
// 最终右移dst = temp >> shift
auto sraInstr = std::make_unique<MachineInstr>(is_32bit ? RVOpcodes::SRAIW : RVOpcodes::SRAI);
sraInstr->addOperand(std::make_unique<RegOperand>(*dst_reg));
sraInstr->addOperand(std::make_unique<RegOperand>(temp_reg));
sraInstr->addOperand(std::make_unique<ImmOperand>(shift));
newInstrs.push_back(std::move(sraInstr));
}
// 情况4: 负的2的幂次除法
else if (divisor < 0 && isPowerOfTwo(-divisor)) {
int shift = getPowerOfTwoExponent(-divisor);
int temp_reg = createTempReg();
// 先按正数处理
auto sraSignInstr = std::make_unique<MachineInstr>(is_32bit ? RVOpcodes::SRAIW : RVOpcodes::SRAI);
sraSignInstr->addOperand(std::make_unique<RegOperand>(temp_reg));
sraSignInstr->addOperand(std::make_unique<RegOperand>(*src1_reg));
sraSignInstr->addOperand(std::make_unique<ImmOperand>(is_32bit ? 31 : 63));
newInstrs.push_back(std::move(sraSignInstr));
if (shift < (is_32bit ? 32 : 64)) {
auto srlInstr = std::make_unique<MachineInstr>(is_32bit ? RVOpcodes::SRLIW : RVOpcodes::SRLI);
srlInstr->addOperand(std::make_unique<RegOperand>(temp_reg));
srlInstr->addOperand(std::make_unique<RegOperand>(temp_reg));
srlInstr->addOperand(std::make_unique<ImmOperand>((is_32bit ? 32 : 64) - shift));
newInstrs.push_back(std::move(srlInstr));
}
auto srlInstr = std::make_unique<MachineInstr>(is_32bit ? RVOpcodes::SRLIW : RVOpcodes::SRLI);
srlInstr->addOperand(std::make_unique<RegOperand>(temp_reg));
srlInstr->addOperand(std::make_unique<RegOperand>(temp_reg));
srlInstr->addOperand(std::make_unique<ImmOperand>((is_32bit ? 32 : 64) - shift));
newInstrs.push_back(std::move(srlInstr));
auto addInstr = std::make_unique<MachineInstr>(is_32bit ? RVOpcodes::ADDW : RVOpcodes::ADD);
addInstr->addOperand(std::make_unique<RegOperand>(temp_reg));
@ -224,101 +179,99 @@ void DivStrengthReduction::runOnMachineFunction(MachineFunction *mfunc) {
sraInstr->addOperand(std::make_unique<RegOperand>(temp_reg));
sraInstr->addOperand(std::make_unique<ImmOperand>(shift));
newInstrs.push_back(std::move(sraInstr));
// 然后取反
auto negInstr = std::make_unique<MachineInstr>(is_32bit ? RVOpcodes::SUBW : RVOpcodes::SUB);
negInstr->addOperand(std::make_unique<RegOperand>(*dst_reg));
negInstr->addOperand(std::make_unique<RegOperand>(PhysicalReg::ZERO));
negInstr->addOperand(std::make_unique<RegOperand>(temp_reg));
newInstrs.push_back(std::move(negInstr));
if (divisor < 0) {
auto negInstr = std::make_unique<MachineInstr>(is_32bit ? RVOpcodes::SUBW : RVOpcodes::SUB);
negInstr->addOperand(std::make_unique<RegOperand>(*dst_reg));
negInstr->addOperand(std::make_unique<RegOperand>(PhysicalReg::ZERO));
negInstr->addOperand(std::make_unique<RegOperand>(temp_reg));
newInstrs.push_back(std::move(negInstr));
} else {
auto moveInstr = std::make_unique<MachineInstr>(is_32bit ? RVOpcodes::ADDW : RVOpcodes::ADD);
moveInstr->addOperand(std::make_unique<RegOperand>(*dst_reg));
moveInstr->addOperand(std::make_unique<RegOperand>(temp_reg));
moveInstr->addOperand(std::make_unique<RegOperand>(PhysicalReg::ZERO));
newInstrs.push_back(std::move(moveInstr));
}
}
// 情况5: 通用magic number算法针对没有MULH的情况进行了简化
else {
// 对于一般除法在没有MULH的情况下我们采用更保守的策略
// 只处理一些简单的常数除法,复杂的情况保持原始除法指令
// 检查是否为小的常数(可以用简单乘法处理)
if (std::abs(divisor) <= 1024) { // 限制在较小的除数范围内
auto magic_info = computeMagicNumber(divisor, is_32bit);
auto magic_info = computeMagic(divisor, is_32bit);
int magic_reg = createTempReg();
int temp_reg = createTempReg();
auto loadInstr = std::make_unique<MachineInstr>(RVOpcodes::LI);
loadInstr->addOperand(std::make_unique<RegOperand>(magic_reg));
loadInstr->addOperand(std::make_unique<ImmOperand>(magic_info.magic));
newInstrs.push_back(std::move(loadInstr));
if (is_32bit) {
auto mulInstr = std::make_unique<MachineInstr>(RVOpcodes::MUL);
mulInstr->addOperand(std::make_unique<RegOperand>(temp_reg));
mulInstr->addOperand(std::make_unique<RegOperand>(*src1_reg));
mulInstr->addOperand(std::make_unique<RegOperand>(magic_reg));
newInstrs.push_back(std::move(mulInstr));
auto sraInstr = std::make_unique<MachineInstr>(RVOpcodes::SRAI);
sraInstr->addOperand(std::make_unique<RegOperand>(temp_reg));
sraInstr->addOperand(std::make_unique<RegOperand>(temp_reg));
sraInstr->addOperand(std::make_unique<ImmOperand>(magic_info.shift));
newInstrs.push_back(std::move(sraInstr));
} else {
auto mulhInstr = std::make_unique<MachineInstr>(RVOpcodes::MULH);
mulhInstr->addOperand(std::make_unique<RegOperand>(temp_reg));
mulhInstr->addOperand(std::make_unique<RegOperand>(*src1_reg));
mulhInstr->addOperand(std::make_unique<RegOperand>(magic_reg));
newInstrs.push_back(std::move(mulhInstr));
if (magic_info.magic == 0) continue;
int magic_reg = createTempReg();
int temp_reg = createTempReg();
// 加载magic number到寄存器
auto loadInstr = std::make_unique<MachineInstr>(RVOpcodes::LI);
loadInstr->addOperand(std::make_unique<RegOperand>(magic_reg));
loadInstr->addOperand(std::make_unique<ImmOperand>(magic_info.magic));
newInstrs.push_back(std::move(loadInstr));
// 使用普通乘法模拟高位乘法
if (is_32bit) {
// 32位使用MULW
auto mulInstr = std::make_unique<MachineInstr>(RVOpcodes::MULW);
mulInstr->addOperand(std::make_unique<RegOperand>(temp_reg));
mulInstr->addOperand(std::make_unique<RegOperand>(*src1_reg));
mulInstr->addOperand(std::make_unique<RegOperand>(magic_reg));
newInstrs.push_back(std::move(mulInstr));
// 右移得到近似结果
auto sraInstr = std::make_unique<MachineInstr>(RVOpcodes::SRAIW);
sraInstr->addOperand(std::make_unique<RegOperand>(temp_reg));
sraInstr->addOperand(std::make_unique<RegOperand>(temp_reg));
sraInstr->addOperand(std::make_unique<ImmOperand>(magic_info.shift));
newInstrs.push_back(std::move(sraInstr));
} else {
// 64位使用MUL
auto mulInstr = std::make_unique<MachineInstr>(RVOpcodes::MUL);
mulInstr->addOperand(std::make_unique<RegOperand>(temp_reg));
mulInstr->addOperand(std::make_unique<RegOperand>(*src1_reg));
mulInstr->addOperand(std::make_unique<RegOperand>(magic_reg));
newInstrs.push_back(std::move(mulInstr));
// 右移得到近似结果
int post_shift = magic_info.shift - 63;
if (post_shift > 0) {
auto sraInstr = std::make_unique<MachineInstr>(RVOpcodes::SRAI);
sraInstr->addOperand(std::make_unique<RegOperand>(temp_reg));
sraInstr->addOperand(std::make_unique<RegOperand>(temp_reg));
sraInstr->addOperand(std::make_unique<ImmOperand>(magic_info.shift));
sraInstr->addOperand(std::make_unique<ImmOperand>(post_shift));
newInstrs.push_back(std::move(sraInstr));
}
// 符号修正:处理负数被除数
int sign_reg = createTempReg();
// 获取被除数的符号位
auto sraSignInstr = std::make_unique<MachineInstr>(is_32bit ? RVOpcodes::SRAIW : RVOpcodes::SRAI);
sraSignInstr->addOperand(std::make_unique<RegOperand>(sign_reg));
sraSignInstr->addOperand(std::make_unique<RegOperand>(*src1_reg));
sraSignInstr->addOperand(std::make_unique<ImmOperand>(is_32bit ? 31 : 63));
newInstrs.push_back(std::move(sraSignInstr));
// 最终结果dst = temp - sign对于正除数或 dst = temp + sign对于负除数
if (divisor > 0) {
auto finalSubInstr = std::make_unique<MachineInstr>(is_32bit ? RVOpcodes::SUBW : RVOpcodes::SUB);
finalSubInstr->addOperand(std::make_unique<RegOperand>(*dst_reg));
finalSubInstr->addOperand(std::make_unique<RegOperand>(temp_reg));
finalSubInstr->addOperand(std::make_unique<RegOperand>(sign_reg));
newInstrs.push_back(std::move(finalSubInstr));
} else {
auto finalAddInstr = std::make_unique<MachineInstr>(is_32bit ? RVOpcodes::ADDW : RVOpcodes::ADD);
finalAddInstr->addOperand(std::make_unique<RegOperand>(*dst_reg));
finalAddInstr->addOperand(std::make_unique<RegOperand>(temp_reg));
finalAddInstr->addOperand(std::make_unique<RegOperand>(sign_reg));
newInstrs.push_back(std::move(finalAddInstr));
}
}
// 对于大的除数或复杂情况,保持原始除法指令不变
int sign_reg = createTempReg();
auto sraSignInstr = std::make_unique<MachineInstr>(is_32bit ? RVOpcodes::SRAIW : RVOpcodes::SRAI);
sraSignInstr->addOperand(std::make_unique<RegOperand>(sign_reg));
sraSignInstr->addOperand(std::make_unique<RegOperand>(*src1_reg));
sraSignInstr->addOperand(std::make_unique<ImmOperand>(is_32bit ? 31 : 63));
newInstrs.push_back(std::move(sraSignInstr));
auto subInstr = std::make_unique<MachineInstr>(is_32bit ? RVOpcodes::SUBW : RVOpcodes::SUB);
subInstr->addOperand(std::make_unique<RegOperand>(temp_reg));
subInstr->addOperand(std::make_unique<RegOperand>(temp_reg));
subInstr->addOperand(std::make_unique<RegOperand>(sign_reg));
newInstrs.push_back(std::move(subInstr));
if (divisor < 0) {
auto negInstr = std::make_unique<MachineInstr>(is_32bit ? RVOpcodes::SUBW : RVOpcodes::SUB);
negInstr->addOperand(std::make_unique<RegOperand>(*dst_reg));
negInstr->addOperand(std::make_unique<RegOperand>(PhysicalReg::ZERO));
negInstr->addOperand(std::make_unique<RegOperand>(temp_reg));
newInstrs.push_back(std::move(negInstr));
} else {
auto moveInstr = std::make_unique<MachineInstr>(is_32bit ? RVOpcodes::ADDW : RVOpcodes::ADD);
moveInstr->addOperand(std::make_unique<RegOperand>(*dst_reg));
moveInstr->addOperand(std::make_unique<RegOperand>(temp_reg));
moveInstr->addOperand(std::make_unique<RegOperand>(PhysicalReg::ZERO));
newInstrs.push_back(std::move(moveInstr));
}
}
if (!newInstrs.empty()) {
replacements.push_back({i, std::move(newInstrs)});
size_t start_index = i;
if (instructions_to_replace == 2) {
start_index = i - 1;
}
replacements.push_back({start_index, instructions_to_replace, std::move(newInstrs)});
}
}
// 批量应用替换(从后往前处理避免索引问题)
for (auto it = replacements.rbegin(); it != replacements.rend(); ++it) {
instrs.erase(instrs.begin() + it->index);
instrs.erase(instrs.begin() + it->index, instrs.begin() + it->index + it->count_to_erase);
instrs.insert(instrs.begin() + it->index,
std::make_move_iterator(it->newInstrs.begin()),
std::make_move_iterator(it->newInstrs.end()));
@ -326,4 +279,4 @@ void DivStrengthReduction::runOnMachineFunction(MachineFunction *mfunc) {
}
}
} // namespace sysy
} // namespace sysy

View File

@ -60,7 +60,7 @@ void RISCv64AsmPrinter::printInstruction(MachineInstr* instr, bool debug) {
case RVOpcodes::ADD: *OS << "add "; break; case RVOpcodes::ADDI: *OS << "addi "; break;
case RVOpcodes::ADDW: *OS << "addw "; break; case RVOpcodes::ADDIW: *OS << "addiw "; break;
case RVOpcodes::SUB: *OS << "sub "; break; case RVOpcodes::SUBW: *OS << "subw "; break;
case RVOpcodes::MUL: *OS << "mul "; break; case RVOpcodes::MULW: *OS << "mulw "; break;
case RVOpcodes::MUL: *OS << "mul "; break; case RVOpcodes::MULW: *OS << "mulw "; break; case RVOpcodes::MULH: *OS << "mulh "; break;
case RVOpcodes::DIV: *OS << "div "; break; case RVOpcodes::DIVW: *OS << "divw "; break;
case RVOpcodes::REM: *OS << "rem "; break; case RVOpcodes::REMW: *OS << "remw "; break;
case RVOpcodes::XOR: *OS << "xor "; break; case RVOpcodes::XORI: *OS << "xori "; break;

View File

@ -45,7 +45,7 @@ enum class PhysicalReg {
// RISC-V 指令操作码枚举
enum class RVOpcodes {
// 算术指令
ADD, ADDI, ADDW, ADDIW, SUB, SUBW, MUL, MULW, DIV, DIVW, REM, REMW,
ADD, ADDI, ADDW, ADDIW, SUB, SUBW, MUL, MULW, MULH, DIV, DIVW, REM, REMW,
// 逻辑指令
XOR, XORI, OR, ORI, AND, ANDI,
// 移位指令

View File

@ -709,7 +709,7 @@ class Instruction : public User {
kBitItoF = 0x1UL << 40,
kBitFtoI = 0x1UL << 41,
kSRA = 0x1UL << 42,
kMulh = 0x1UL << 43,
kMulh = 0x1UL << 43
};
protected:
@ -823,7 +823,7 @@ public:
bool isBinary() const {
static constexpr uint64_t BinaryOpMask =
(kAdd | kSub | kMul | kDiv | kRem | kAnd | kOr | kSRA) |
(kAdd | kSub | kMul | kDiv | kRem | kAnd | kOr | kSRA | kMulh) |
(kICmpEQ | kICmpNE | kICmpLT | kICmpGT | kICmpLE | kICmpGE);
return kind & BinaryOpMask;
}

View File

@ -220,6 +220,9 @@ class IRBuilder {
BinaryInst * createSRAInst(Value *lhs, Value *rhs, const std::string &name = "") {
return createBinaryInst(Instruction::kSRA, Type::getIntType(), lhs, rhs, name);
} ///< 创建算术右移指令
BinaryInst * createMulhInst(Value *lhs, Value *rhs, const std::string &name = "") {
return createBinaryInst(Instruction::kMulh, Type::getIntType(), lhs, rhs, name);
} ///< 创建高位乘法指令
CallInst * createCallInst(Function *callee, const std::vector<Value *> &args, const std::string &name = "") {
std::string newName;
if (name.empty() && callee->getReturnType() != Type::getVoidType()) {

View File

@ -15,6 +15,29 @@
using namespace std;
namespace sysy {
std::pair<long long, int> calculate_signed_magic(int d) {
if (d == 0) throw std::runtime_error("Division by zero");
if (d == 1 || d == -1) return {0, 0}; // Not used by strength reduction
int k = 0;
unsigned int ad = (d > 0) ? d : -d;
unsigned int temp = ad;
while (temp > 0) {
temp >>= 1;
k++;
}
if ((ad & (ad - 1)) == 0) { // if power of 2
k--;
}
unsigned __int128 m_val = 1;
m_val <<= (32 + k - 1);
unsigned __int128 m_prime = m_val / ad;
long long m = m_prime + 1;
return {m, k};
}
// std::vector<Value*> BinaryValueStack; ///< 用于存储value的栈
// std::vector<int> BinaryOpStack; ///< 用于存储二元表达式的操作符栈

9
test_div_optimization.sy Normal file
View File

@ -0,0 +1,9 @@
int main() {
int a = 100;
int b = a / 4;
int c = a / 8;
int d = a / 16;
int e = a / 7;
int f = a / 3;
return b + c + d + e;
}