From 0ce742a86eca84071c32c4215a3099586326772f Mon Sep 17 00:00:00 2001 From: CGH0S7 <776459475@qq.com> Date: Sun, 3 Aug 2025 14:37:33 +0800 Subject: [PATCH] =?UTF-8?q?[optimize]=E6=B7=BB=E5=8A=A0=E6=9B=B4=E4=B8=BA?= =?UTF-8?q?=E9=80=9A=E7=94=A8=E7=9A=84=E9=99=A4=E6=B3=95=E5=BC=BA=E5=BA=A6?= =?UTF-8?q?=E5=89=8A=E5=87=8FPass,=20=E4=B8=8D=E5=8F=97=E9=99=A4=E6=95=B0?= =?UTF-8?q?=E9=99=90=E5=88=B6=E6=9B=BF=E6=8D=A2div=E6=8C=87=E4=BB=A4?= =?UTF-8?q?=EF=BC=8C=E4=B8=8D=E5=BD=B1=E5=93=8D=E5=BD=93=E5=89=8D=E5=88=86?= =?UTF-8?q?=E6=95=B0?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../RISCv64/Optimize/DivStrengthReduction.cpp | 319 ++++++++---------- src/backend/RISCv64/RISCv64AsmPrinter.cpp | 2 +- src/include/backend/RISCv64/RISCv64LLIR.h | 2 +- src/include/midend/IR.h | 4 +- src/include/midend/IRBuilder.h | 3 + src/midend/SysYIRGenerator.cpp | 23 ++ test_div_optimization.sy | 9 + 7 files changed, 175 insertions(+), 187 deletions(-) create mode 100644 test_div_optimization.sy diff --git a/src/backend/RISCv64/Optimize/DivStrengthReduction.cpp b/src/backend/RISCv64/Optimize/DivStrengthReduction.cpp index c052c5d..bcd7bda 100644 --- a/src/backend/RISCv64/Optimize/DivStrengthReduction.cpp +++ b/src/backend/RISCv64/Optimize/DivStrengthReduction.cpp @@ -1,4 +1,6 @@ #include "DivStrengthReduction.h" +#include +#include namespace sysy { @@ -17,69 +19,49 @@ void DivStrengthReduction::runOnMachineFunction(MachineFunction *mfunc) { if (debug) std::cout << "Running DivStrengthReduction optimization..." << std::endl; - // 虚拟寄存器分配器 int next_temp_reg = 1000; auto createTempReg = [&]() -> int { return next_temp_reg++; }; - // Magic number 信息结构 struct MagicInfo { int64_t magic; int shift; - bool add_indicator; // 是否需要额外的加法修正 }; - // 针对缺少MULH指令的简化magic number计算 - auto computeMagicNumber = [](int64_t divisor, bool is_32bit) -> MagicInfo { - if (divisor == 0) return {0, 0, false}; - if (divisor == 1) return {1, 0, false}; - if (divisor == -1) return {-1, 0, false}; - - // 对于没有MULH的情况,我们使用更简单但有效的算法 - // 基于 2^n / divisor 的近似 - - bool neg = divisor < 0; - int64_t d = neg ? -divisor : divisor; - + auto computeMagic = [](int64_t d, bool is_32bit) -> MagicInfo { int word_size = is_32bit ? 32 : 64; + uint64_t ad = std::abs(d); - // 计算合适的移位量 - int shift = word_size; - int64_t magic = ((1LL << shift) + d - 1) / d; + if (ad == 0) return {0, 0}; - // 调整magic number以适应MUL指令 + int l = std::floor(std::log2(ad)); + if ((ad & (ad - 1)) == 0) { // power of 2 + l = 0; // special case for power of 2, shift will be calculated differently + } + + __int128_t one = 1; + __int128_t num; + int total_shift; + if (is_32bit) { - // 32位情况:调整magic使其适合符号扩展后的乘法 - shift = 32; - magic = ((1LL << shift) + d - 1) / d; + total_shift = 31 + l; + num = one << total_shift; } else { - // 64位情况:使用更保守的算法 - shift = 32; // 使用32位作为基础移位 - magic = ((1LL << shift) + d - 1) / d; + total_shift = 63 + l; + num = one << total_shift; } - bool add_indicator = false; + __int128_t den = ad; + int64_t magic = (num / den) + 1; - // 检查是否需要加法修正 - if (magic >= (1LL << (word_size - 1))) { - add_indicator = true; - magic -= (1LL << word_size); - } - - if (neg) { - magic = -magic; - } - - return {magic, shift, add_indicator}; + return {magic, total_shift}; }; - // 检查是否为2的幂次 auto isPowerOfTwo = [](int64_t n) -> bool { return n > 0 && (n & (n - 1)) == 0; }; - // 获取2的幂次的指数 auto getPowerOfTwoExponent = [](int64_t n) -> int { if (n <= 0 || (n & (n - 1)) != 0) return -1; int shift = 0; @@ -90,9 +72,9 @@ void DivStrengthReduction::runOnMachineFunction(MachineFunction *mfunc) { return shift; }; - // 收集需要替换的指令 struct InstructionReplacement { size_t index; + size_t count_to_erase; std::vector> newInstrs; }; @@ -106,7 +88,6 @@ void DivStrengthReduction::runOnMachineFunction(MachineFunction *mfunc) { bool is_32bit = (instr->getOpcode() == RVOpcodes::DIVW); - // 只处理 DIV 和 DIVW 指令 if (instr->getOpcode() != RVOpcodes::DIV && !is_32bit) { continue; } @@ -118,100 +99,74 @@ void DivStrengthReduction::runOnMachineFunction(MachineFunction *mfunc) { auto *dst_op = instr->getOperands()[0].get(); auto *src1_op = instr->getOperands()[1].get(); auto *src2_op = instr->getOperands()[2].get(); - - // 检查操作数类型 - if (dst_op->getKind() != MachineOperand::KIND_REG || - src1_op->getKind() != MachineOperand::KIND_REG || - src2_op->getKind() != MachineOperand::KIND_IMM) { + + int64_t divisor = 0; + bool const_divisor_found = false; + size_t instructions_to_replace = 1; + + if (src2_op->getKind() == MachineOperand::KIND_IMM) { + divisor = static_cast(src2_op)->getValue(); + const_divisor_found = true; + } else if (src2_op->getKind() == MachineOperand::KIND_REG) { + if (i > 0) { + auto *prev_instr = instrs[i - 1].get(); + if (prev_instr->getOpcode() == RVOpcodes::LI && prev_instr->getOperands().size() == 2) { + auto *li_dst_op = prev_instr->getOperands()[0].get(); + auto *li_imm_op = prev_instr->getOperands()[1].get(); + if (li_dst_op->getKind() == MachineOperand::KIND_REG && li_imm_op->getKind() == MachineOperand::KIND_IMM) { + auto *div_reg_op = static_cast(src2_op); + auto *li_dst_reg_op = static_cast(li_dst_op); + if (div_reg_op->isVirtual() && li_dst_reg_op->isVirtual() && + div_reg_op->getVRegNum() == li_dst_reg_op->getVRegNum()) { + divisor = static_cast(li_imm_op)->getValue(); + const_divisor_found = true; + instructions_to_replace = 2; + } + } + } + } + } + + if (!const_divisor_found) { continue; } auto *dst_reg = static_cast(dst_op); auto *src1_reg = static_cast(src1_op); - auto *src2_imm = static_cast(src2_op); - int64_t divisor = src2_imm->getValue(); - - // 跳过除数为0的情况 if (divisor == 0) continue; std::vector> newInstrs; - // 情况1: 除数为1 if (divisor == 1) { - // dst = src1 (直接复制) auto moveInstr = std::make_unique(is_32bit ? RVOpcodes::ADDW : RVOpcodes::ADD); moveInstr->addOperand(std::make_unique(*dst_reg)); moveInstr->addOperand(std::make_unique(*src1_reg)); moveInstr->addOperand(std::make_unique(PhysicalReg::ZERO)); newInstrs.push_back(std::move(moveInstr)); } - // 情况2: 除数为-1 else if (divisor == -1) { - // dst = -src1 auto negInstr = std::make_unique(is_32bit ? RVOpcodes::SUBW : RVOpcodes::SUB); negInstr->addOperand(std::make_unique(*dst_reg)); negInstr->addOperand(std::make_unique(PhysicalReg::ZERO)); negInstr->addOperand(std::make_unique(*src1_reg)); newInstrs.push_back(std::move(negInstr)); } - // 情况3: 正的2的幂次除法 - else if (isPowerOfTwo(divisor)) { - int shift = getPowerOfTwoExponent(divisor); + else if (isPowerOfTwo(std::abs(divisor))) { + int shift = getPowerOfTwoExponent(std::abs(divisor)); int temp_reg = createTempReg(); - // 对于有符号除法,需要处理负数的舍入 - // if (src1 < 0) src1 += (divisor - 1) - - // 获取符号位:temp = src1 >> (word_size - 1) auto sraSignInstr = std::make_unique(is_32bit ? RVOpcodes::SRAIW : RVOpcodes::SRAI); sraSignInstr->addOperand(std::make_unique(temp_reg)); sraSignInstr->addOperand(std::make_unique(*src1_reg)); sraSignInstr->addOperand(std::make_unique(is_32bit ? 31 : 63)); newInstrs.push_back(std::move(sraSignInstr)); - // 计算偏移:temp = temp >> (word_size - shift) - if (shift < (is_32bit ? 32 : 64)) { - auto srlInstr = std::make_unique(is_32bit ? RVOpcodes::SRLIW : RVOpcodes::SRLI); - srlInstr->addOperand(std::make_unique(temp_reg)); - srlInstr->addOperand(std::make_unique(temp_reg)); - srlInstr->addOperand(std::make_unique((is_32bit ? 32 : 64) - shift)); - newInstrs.push_back(std::move(srlInstr)); - } - - // 加上偏移:temp = src1 + temp - auto addInstr = std::make_unique(is_32bit ? RVOpcodes::ADDW : RVOpcodes::ADD); - addInstr->addOperand(std::make_unique(temp_reg)); - addInstr->addOperand(std::make_unique(*src1_reg)); - addInstr->addOperand(std::make_unique(temp_reg)); - newInstrs.push_back(std::move(addInstr)); - - // 最终右移:dst = temp >> shift - auto sraInstr = std::make_unique(is_32bit ? RVOpcodes::SRAIW : RVOpcodes::SRAI); - sraInstr->addOperand(std::make_unique(*dst_reg)); - sraInstr->addOperand(std::make_unique(temp_reg)); - sraInstr->addOperand(std::make_unique(shift)); - newInstrs.push_back(std::move(sraInstr)); - } - // 情况4: 负的2的幂次除法 - else if (divisor < 0 && isPowerOfTwo(-divisor)) { - int shift = getPowerOfTwoExponent(-divisor); - int temp_reg = createTempReg(); - - // 先按正数处理 - auto sraSignInstr = std::make_unique(is_32bit ? RVOpcodes::SRAIW : RVOpcodes::SRAI); - sraSignInstr->addOperand(std::make_unique(temp_reg)); - sraSignInstr->addOperand(std::make_unique(*src1_reg)); - sraSignInstr->addOperand(std::make_unique(is_32bit ? 31 : 63)); - newInstrs.push_back(std::move(sraSignInstr)); - - if (shift < (is_32bit ? 32 : 64)) { - auto srlInstr = std::make_unique(is_32bit ? RVOpcodes::SRLIW : RVOpcodes::SRLI); - srlInstr->addOperand(std::make_unique(temp_reg)); - srlInstr->addOperand(std::make_unique(temp_reg)); - srlInstr->addOperand(std::make_unique((is_32bit ? 32 : 64) - shift)); - newInstrs.push_back(std::move(srlInstr)); - } + auto srlInstr = std::make_unique(is_32bit ? RVOpcodes::SRLIW : RVOpcodes::SRLI); + srlInstr->addOperand(std::make_unique(temp_reg)); + srlInstr->addOperand(std::make_unique(temp_reg)); + srlInstr->addOperand(std::make_unique((is_32bit ? 32 : 64) - shift)); + newInstrs.push_back(std::move(srlInstr)); auto addInstr = std::make_unique(is_32bit ? RVOpcodes::ADDW : RVOpcodes::ADD); addInstr->addOperand(std::make_unique(temp_reg)); @@ -224,101 +179,99 @@ void DivStrengthReduction::runOnMachineFunction(MachineFunction *mfunc) { sraInstr->addOperand(std::make_unique(temp_reg)); sraInstr->addOperand(std::make_unique(shift)); newInstrs.push_back(std::move(sraInstr)); - - // 然后取反 - auto negInstr = std::make_unique(is_32bit ? RVOpcodes::SUBW : RVOpcodes::SUB); - negInstr->addOperand(std::make_unique(*dst_reg)); - negInstr->addOperand(std::make_unique(PhysicalReg::ZERO)); - negInstr->addOperand(std::make_unique(temp_reg)); - newInstrs.push_back(std::move(negInstr)); + + if (divisor < 0) { + auto negInstr = std::make_unique(is_32bit ? RVOpcodes::SUBW : RVOpcodes::SUB); + negInstr->addOperand(std::make_unique(*dst_reg)); + negInstr->addOperand(std::make_unique(PhysicalReg::ZERO)); + negInstr->addOperand(std::make_unique(temp_reg)); + newInstrs.push_back(std::move(negInstr)); + } else { + auto moveInstr = std::make_unique(is_32bit ? RVOpcodes::ADDW : RVOpcodes::ADD); + moveInstr->addOperand(std::make_unique(*dst_reg)); + moveInstr->addOperand(std::make_unique(temp_reg)); + moveInstr->addOperand(std::make_unique(PhysicalReg::ZERO)); + newInstrs.push_back(std::move(moveInstr)); + } } - // 情况5: 通用magic number算法(针对没有MULH的情况进行了简化) else { - // 对于一般除法,在没有MULH的情况下,我们采用更保守的策略 - // 只处理一些简单的常数除法,复杂的情况保持原始除法指令 - - // 检查是否为小的常数(可以用简单乘法处理) - if (std::abs(divisor) <= 1024) { // 限制在较小的除数范围内 - auto magic_info = computeMagicNumber(divisor, is_32bit); + auto magic_info = computeMagic(divisor, is_32bit); + int magic_reg = createTempReg(); + int temp_reg = createTempReg(); + + auto loadInstr = std::make_unique(RVOpcodes::LI); + loadInstr->addOperand(std::make_unique(magic_reg)); + loadInstr->addOperand(std::make_unique(magic_info.magic)); + newInstrs.push_back(std::move(loadInstr)); + + if (is_32bit) { + auto mulInstr = std::make_unique(RVOpcodes::MUL); + mulInstr->addOperand(std::make_unique(temp_reg)); + mulInstr->addOperand(std::make_unique(*src1_reg)); + mulInstr->addOperand(std::make_unique(magic_reg)); + newInstrs.push_back(std::move(mulInstr)); + + auto sraInstr = std::make_unique(RVOpcodes::SRAI); + sraInstr->addOperand(std::make_unique(temp_reg)); + sraInstr->addOperand(std::make_unique(temp_reg)); + sraInstr->addOperand(std::make_unique(magic_info.shift)); + newInstrs.push_back(std::move(sraInstr)); + } else { + auto mulhInstr = std::make_unique(RVOpcodes::MULH); + mulhInstr->addOperand(std::make_unique(temp_reg)); + mulhInstr->addOperand(std::make_unique(*src1_reg)); + mulhInstr->addOperand(std::make_unique(magic_reg)); + newInstrs.push_back(std::move(mulhInstr)); - if (magic_info.magic == 0) continue; - - int magic_reg = createTempReg(); - int temp_reg = createTempReg(); - - // 加载magic number到寄存器 - auto loadInstr = std::make_unique(RVOpcodes::LI); - loadInstr->addOperand(std::make_unique(magic_reg)); - loadInstr->addOperand(std::make_unique(magic_info.magic)); - newInstrs.push_back(std::move(loadInstr)); - - // 使用普通乘法模拟高位乘法 - if (is_32bit) { - // 32位:使用MULW - auto mulInstr = std::make_unique(RVOpcodes::MULW); - mulInstr->addOperand(std::make_unique(temp_reg)); - mulInstr->addOperand(std::make_unique(*src1_reg)); - mulInstr->addOperand(std::make_unique(magic_reg)); - newInstrs.push_back(std::move(mulInstr)); - - // 右移得到近似结果 - auto sraInstr = std::make_unique(RVOpcodes::SRAIW); - sraInstr->addOperand(std::make_unique(temp_reg)); - sraInstr->addOperand(std::make_unique(temp_reg)); - sraInstr->addOperand(std::make_unique(magic_info.shift)); - newInstrs.push_back(std::move(sraInstr)); - } else { - // 64位:使用MUL - auto mulInstr = std::make_unique(RVOpcodes::MUL); - mulInstr->addOperand(std::make_unique(temp_reg)); - mulInstr->addOperand(std::make_unique(*src1_reg)); - mulInstr->addOperand(std::make_unique(magic_reg)); - newInstrs.push_back(std::move(mulInstr)); - - // 右移得到近似结果 + int post_shift = magic_info.shift - 63; + if (post_shift > 0) { auto sraInstr = std::make_unique(RVOpcodes::SRAI); sraInstr->addOperand(std::make_unique(temp_reg)); sraInstr->addOperand(std::make_unique(temp_reg)); - sraInstr->addOperand(std::make_unique(magic_info.shift)); + sraInstr->addOperand(std::make_unique(post_shift)); newInstrs.push_back(std::move(sraInstr)); } - - // 符号修正:处理负数被除数 - int sign_reg = createTempReg(); - - // 获取被除数的符号位 - auto sraSignInstr = std::make_unique(is_32bit ? RVOpcodes::SRAIW : RVOpcodes::SRAI); - sraSignInstr->addOperand(std::make_unique(sign_reg)); - sraSignInstr->addOperand(std::make_unique(*src1_reg)); - sraSignInstr->addOperand(std::make_unique(is_32bit ? 31 : 63)); - newInstrs.push_back(std::move(sraSignInstr)); - - // 最终结果:dst = temp - sign(对于正除数)或 dst = temp + sign(对于负除数) - if (divisor > 0) { - auto finalSubInstr = std::make_unique(is_32bit ? RVOpcodes::SUBW : RVOpcodes::SUB); - finalSubInstr->addOperand(std::make_unique(*dst_reg)); - finalSubInstr->addOperand(std::make_unique(temp_reg)); - finalSubInstr->addOperand(std::make_unique(sign_reg)); - newInstrs.push_back(std::move(finalSubInstr)); - } else { - auto finalAddInstr = std::make_unique(is_32bit ? RVOpcodes::ADDW : RVOpcodes::ADD); - finalAddInstr->addOperand(std::make_unique(*dst_reg)); - finalAddInstr->addOperand(std::make_unique(temp_reg)); - finalAddInstr->addOperand(std::make_unique(sign_reg)); - newInstrs.push_back(std::move(finalAddInstr)); - } } - // 对于大的除数或复杂情况,保持原始除法指令不变 + + int sign_reg = createTempReg(); + auto sraSignInstr = std::make_unique(is_32bit ? RVOpcodes::SRAIW : RVOpcodes::SRAI); + sraSignInstr->addOperand(std::make_unique(sign_reg)); + sraSignInstr->addOperand(std::make_unique(*src1_reg)); + sraSignInstr->addOperand(std::make_unique(is_32bit ? 31 : 63)); + newInstrs.push_back(std::move(sraSignInstr)); + + auto subInstr = std::make_unique(is_32bit ? RVOpcodes::SUBW : RVOpcodes::SUB); + subInstr->addOperand(std::make_unique(temp_reg)); + subInstr->addOperand(std::make_unique(temp_reg)); + subInstr->addOperand(std::make_unique(sign_reg)); + newInstrs.push_back(std::move(subInstr)); + + if (divisor < 0) { + auto negInstr = std::make_unique(is_32bit ? RVOpcodes::SUBW : RVOpcodes::SUB); + negInstr->addOperand(std::make_unique(*dst_reg)); + negInstr->addOperand(std::make_unique(PhysicalReg::ZERO)); + negInstr->addOperand(std::make_unique(temp_reg)); + newInstrs.push_back(std::move(negInstr)); + } else { + auto moveInstr = std::make_unique(is_32bit ? RVOpcodes::ADDW : RVOpcodes::ADD); + moveInstr->addOperand(std::make_unique(*dst_reg)); + moveInstr->addOperand(std::make_unique(temp_reg)); + moveInstr->addOperand(std::make_unique(PhysicalReg::ZERO)); + newInstrs.push_back(std::move(moveInstr)); + } } if (!newInstrs.empty()) { - replacements.push_back({i, std::move(newInstrs)}); + size_t start_index = i; + if (instructions_to_replace == 2) { + start_index = i - 1; + } + replacements.push_back({start_index, instructions_to_replace, std::move(newInstrs)}); } } - // 批量应用替换(从后往前处理避免索引问题) for (auto it = replacements.rbegin(); it != replacements.rend(); ++it) { - instrs.erase(instrs.begin() + it->index); + instrs.erase(instrs.begin() + it->index, instrs.begin() + it->index + it->count_to_erase); instrs.insert(instrs.begin() + it->index, std::make_move_iterator(it->newInstrs.begin()), std::make_move_iterator(it->newInstrs.end())); @@ -326,4 +279,4 @@ void DivStrengthReduction::runOnMachineFunction(MachineFunction *mfunc) { } } -} // namespace sysy \ No newline at end of file +} // namespace sysy diff --git a/src/backend/RISCv64/RISCv64AsmPrinter.cpp b/src/backend/RISCv64/RISCv64AsmPrinter.cpp index 183003c..5e73cd2 100644 --- a/src/backend/RISCv64/RISCv64AsmPrinter.cpp +++ b/src/backend/RISCv64/RISCv64AsmPrinter.cpp @@ -60,7 +60,7 @@ void RISCv64AsmPrinter::printInstruction(MachineInstr* instr, bool debug) { case RVOpcodes::ADD: *OS << "add "; break; case RVOpcodes::ADDI: *OS << "addi "; break; case RVOpcodes::ADDW: *OS << "addw "; break; case RVOpcodes::ADDIW: *OS << "addiw "; break; case RVOpcodes::SUB: *OS << "sub "; break; case RVOpcodes::SUBW: *OS << "subw "; break; - case RVOpcodes::MUL: *OS << "mul "; break; case RVOpcodes::MULW: *OS << "mulw "; break; + case RVOpcodes::MUL: *OS << "mul "; break; case RVOpcodes::MULW: *OS << "mulw "; break; case RVOpcodes::MULH: *OS << "mulh "; break; case RVOpcodes::DIV: *OS << "div "; break; case RVOpcodes::DIVW: *OS << "divw "; break; case RVOpcodes::REM: *OS << "rem "; break; case RVOpcodes::REMW: *OS << "remw "; break; case RVOpcodes::XOR: *OS << "xor "; break; case RVOpcodes::XORI: *OS << "xori "; break; diff --git a/src/include/backend/RISCv64/RISCv64LLIR.h b/src/include/backend/RISCv64/RISCv64LLIR.h index be11528..3c8710e 100644 --- a/src/include/backend/RISCv64/RISCv64LLIR.h +++ b/src/include/backend/RISCv64/RISCv64LLIR.h @@ -45,7 +45,7 @@ enum class PhysicalReg { // RISC-V 指令操作码枚举 enum class RVOpcodes { // 算术指令 - ADD, ADDI, ADDW, ADDIW, SUB, SUBW, MUL, MULW, DIV, DIVW, REM, REMW, + ADD, ADDI, ADDW, ADDIW, SUB, SUBW, MUL, MULW, MULH, DIV, DIVW, REM, REMW, // 逻辑指令 XOR, XORI, OR, ORI, AND, ANDI, // 移位指令 diff --git a/src/include/midend/IR.h b/src/include/midend/IR.h index 2e4d72b..9423a8c 100644 --- a/src/include/midend/IR.h +++ b/src/include/midend/IR.h @@ -709,7 +709,7 @@ class Instruction : public User { kBitItoF = 0x1UL << 40, kBitFtoI = 0x1UL << 41, kSRA = 0x1UL << 42, - kMulh = 0x1UL << 43, + kMulh = 0x1UL << 43 }; protected: @@ -823,7 +823,7 @@ public: bool isBinary() const { static constexpr uint64_t BinaryOpMask = - (kAdd | kSub | kMul | kDiv | kRem | kAnd | kOr | kSRA) | + (kAdd | kSub | kMul | kDiv | kRem | kAnd | kOr | kSRA | kMulh) | (kICmpEQ | kICmpNE | kICmpLT | kICmpGT | kICmpLE | kICmpGE); return kind & BinaryOpMask; } diff --git a/src/include/midend/IRBuilder.h b/src/include/midend/IRBuilder.h index c232578..ab50173 100644 --- a/src/include/midend/IRBuilder.h +++ b/src/include/midend/IRBuilder.h @@ -220,6 +220,9 @@ class IRBuilder { BinaryInst * createSRAInst(Value *lhs, Value *rhs, const std::string &name = "") { return createBinaryInst(Instruction::kSRA, Type::getIntType(), lhs, rhs, name); } ///< 创建算术右移指令 + BinaryInst * createMulhInst(Value *lhs, Value *rhs, const std::string &name = "") { + return createBinaryInst(Instruction::kMulh, Type::getIntType(), lhs, rhs, name); + } ///< 创建高位乘法指令 CallInst * createCallInst(Function *callee, const std::vector &args, const std::string &name = "") { std::string newName; if (name.empty() && callee->getReturnType() != Type::getVoidType()) { diff --git a/src/midend/SysYIRGenerator.cpp b/src/midend/SysYIRGenerator.cpp index 0b2d751..3228257 100644 --- a/src/midend/SysYIRGenerator.cpp +++ b/src/midend/SysYIRGenerator.cpp @@ -15,6 +15,29 @@ using namespace std; namespace sysy { +std::pair calculate_signed_magic(int d) { + if (d == 0) throw std::runtime_error("Division by zero"); + if (d == 1 || d == -1) return {0, 0}; // Not used by strength reduction + + int k = 0; + unsigned int ad = (d > 0) ? d : -d; + unsigned int temp = ad; + while (temp > 0) { + temp >>= 1; + k++; + } + if ((ad & (ad - 1)) == 0) { // if power of 2 + k--; + } + + unsigned __int128 m_val = 1; + m_val <<= (32 + k - 1); + unsigned __int128 m_prime = m_val / ad; + long long m = m_prime + 1; + + return {m, k}; +} + // std::vector BinaryValueStack; ///< 用于存储value的栈 // std::vector BinaryOpStack; ///< 用于存储二元表达式的操作符栈 diff --git a/test_div_optimization.sy b/test_div_optimization.sy new file mode 100644 index 0000000..30c0bb1 --- /dev/null +++ b/test_div_optimization.sy @@ -0,0 +1,9 @@ +int main() { + int a = 100; + int b = a / 4; + int c = a / 8; + int d = a / 16; + int e = a / 7; + int f = a / 3; + return b + c + d + e; +}