diff --git a/src/backend/RISCv64/CMakeLists.txt b/src/backend/RISCv64/CMakeLists.txt
index eb9f37f..f1e8f55 100644
--- a/src/backend/RISCv64/CMakeLists.txt
+++ b/src/backend/RISCv64/CMakeLists.txt
@@ -11,6 +11,7 @@ add_library(riscv64_backend_lib STATIC
     Optimize/Peephole.cpp
     Optimize/PostRA_Scheduler.cpp
     Optimize/PreRA_Scheduler.cpp
+    Optimize/DivStrengthReduction.cpp
 )
 
 # 包含后端模块所需的头文件路径
diff --git a/src/backend/RISCv64/Optimize/DivStrengthReduction.cpp b/src/backend/RISCv64/Optimize/DivStrengthReduction.cpp
new file mode 100644
index 0000000..c052c5d
--- /dev/null
+++ b/src/backend/RISCv64/Optimize/DivStrengthReduction.cpp
@@ -0,0 +1,171 @@
+#include "DivStrengthReduction.h"
+
+#include <cstdint>
+#include <iterator>
+#include <limits>
+#include <memory>
+#include <utility>
+#include <vector>
+
+namespace sysy {
+
+char DivStrengthReduction::ID = 0;
+
+// The pass rewrites machine instructions, so the IR-level hook changes nothing.
+bool DivStrengthReduction::runOnFunction(Function *F, AnalysisManager& AM) {
+  (void)F;
+  (void)AM;
+  return false;
+}
+
+// Replaces DIV/DIVW by an immediate divisor d with cheaper sequences:
+//   d == 1 / d == -1      -> copy / negate
+//   |d| == 2^k            -> add (2^k - 1) to negative dividends, shift, negate if d < 0
+//   DIVW with |d| <= 1024 -> 64-bit multiply by a precomputed reciprocal
+// Everything else, including 64-bit DIV by a non power of two (which would
+// need the high half of a 128-bit product), keeps the original instruction.
+void DivStrengthReduction::runOnMachineFunction(MachineFunction *mfunc) {
+  if (!mfunc)
+    return;
+
+  // NOTE(review): temporaries are numbered from a fixed base; this assumes the
+  // function has no virtual register >= 1000 -- confirm, or allocate them via
+  // the MachineFunction.
+  int next_temp_reg = 1000;
+  auto newTemp = [&next_temp_reg]() -> int { return next_temp_reg++; };
+
+  // Returns k when n == 2^k, otherwise -1.
+  auto log2Exact = [](int64_t n) -> int {
+    if (n <= 0 || (n & (n - 1)) != 0)
+      return -1;
+    int k = 0;
+    for (; n > 1; n >>= 1)
+      ++k;
+    return k;
+  };
+
+  using InstrList = std::vector<std::unique_ptr<MachineInstr>>;
+
+  // A pending rewrite: the division at `index` is replaced by `seq`.
+  struct Replacement {
+    size_t index;
+    InstrList seq;
+  };
+
+  for (auto &mbb_uptr : mfunc->getBlocks()) {
+    auto &instrs = mbb_uptr->getInstructions();
+    std::vector<Replacement> pending;
+
+    for (size_t i = 0; i < instrs.size(); ++i) {
+      auto *instr = instrs[i].get();
+      const bool is_32bit = instr->getOpcode() == RVOpcodes::DIVW;
+      if (!is_32bit && instr->getOpcode() != RVOpcodes::DIV)
+        continue;
+      if (instr->getOperands().size() != 3)
+        continue;
+
+      auto *dst_op = instr->getOperands()[0].get();
+      auto *lhs_op = instr->getOperands()[1].get();
+      auto *rhs_op = instr->getOperands()[2].get();
+      if (dst_op->getKind() != MachineOperand::KIND_REG ||
+          lhs_op->getKind() != MachineOperand::KIND_REG ||
+          rhs_op->getKind() != MachineOperand::KIND_IMM)
+        continue;
+
+      auto *dst = static_cast<RegOperand *>(dst_op);
+      auto *lhs = static_cast<RegOperand *>(lhs_op);
+      const int64_t divisor = static_cast<ImmOperand *>(rhs_op)->getValue();
+
+      // Keep the hardware behaviour for division by zero, never negate
+      // INT64_MIN, and only accept DIVW divisors that fit in 32 bits.
+      if (divisor == 0 || divisor == std::numeric_limits<int64_t>::min())
+        continue;
+      if (is_32bit && (divisor < std::numeric_limits<int32_t>::min() ||
+                       divisor > std::numeric_limits<int32_t>::max()))
+        continue;
+
+      const bool negate = divisor < 0;
+      const int64_t abs_d = negate ? -divisor : divisor;
+      const int word_bits = is_32bit ? 32 : 64;
+      const RVOpcodes add_op = is_32bit ? RVOpcodes::ADDW : RVOpcodes::ADD;
+      const RVOpcodes sub_op = is_32bit ? RVOpcodes::SUBW : RVOpcodes::SUB;
+      const RVOpcodes srai_op = is_32bit ? RVOpcodes::SRAIW : RVOpcodes::SRAI;
+      const RVOpcodes srli_op = is_32bit ? RVOpcodes::SRLIW : RVOpcodes::SRLI;
+
+      InstrList seq;
+      auto reg = [](auto r) { return std::make_unique<RegOperand>(r); };
+      auto imm = [](int64_t v) { return std::make_unique<ImmOperand>(v); };
+      auto emit = [&seq](RVOpcodes opcode, auto a, auto b, auto c) {
+        auto mi = std::make_unique<MachineInstr>(opcode);
+        mi->addOperand(std::move(a));
+        mi->addOperand(std::move(b));
+        mi->addOperand(std::move(c));
+        seq.push_back(std::move(mi));
+      };
+
+      const int k = log2Exact(abs_d);
+      if (abs_d == 1) {
+        if (negate)
+          emit(sub_op, reg(*dst), reg(PhysicalReg::ZERO), reg(*lhs));
+        else
+          emit(add_op, reg(*dst), reg(*lhs), reg(PhysicalReg::ZERO));
+      } else if (k > 0) {
+        // sign = x >> (bits - 1) is 0 or -1; shifting it right logically by
+        // (bits - k) yields the bias 0 or 2^k - 1 that makes the final
+        // arithmetic shift round toward zero.
+        const int tmp = newTemp();
+        emit(srai_op, reg(tmp), reg(*lhs), imm(word_bits - 1));
+        emit(srli_op, reg(tmp), reg(tmp), imm(word_bits - k));
+        emit(add_op, reg(tmp), reg(*lhs), reg(tmp));
+        if (negate) {
+          emit(srai_op, reg(tmp), reg(tmp), imm(k));
+          emit(sub_op, reg(*dst), reg(PhysicalReg::ZERO), reg(tmp));
+        } else {
+          emit(srai_op, reg(*dst), reg(tmp), imm(k));
+        }
+      } else if (is_32bit && abs_d <= 1024) {
+        // Granlund-Montgomery: with l = ceil(log2 |d|) and
+        // m = floor(2^(31+l) / |d|) + 1, trunc(x / |d|) equals
+        // ((m * x) >> (31 + l)) - (x < 0 ? -1 : 0) for every 32-bit x.
+        // m < 2^32 and |x| <= 2^31, so the product fits in 64 bits.
+        int l = 0;
+        while ((int64_t{1} << l) < abs_d)
+          ++l;
+        const int shift = 31 + l;
+        const int64_t magic = (int64_t{1} << shift) / abs_d + 1;
+
+        const int x = newTemp();
+        const int m = newTemp();
+        const int sign = newTemp();
+
+        // ADDW sign-extends the low 32 bits, so MUL sees the true dividend.
+        emit(RVOpcodes::ADDW, reg(x), reg(*lhs), reg(PhysicalReg::ZERO));
+        auto li = std::make_unique<MachineInstr>(RVOpcodes::LI);
+        li->addOperand(reg(m));
+        li->addOperand(imm(magic));
+        seq.push_back(std::move(li));
+        emit(RVOpcodes::MUL, reg(m), reg(x), reg(m));
+        emit(RVOpcodes::SRAI, reg(m), reg(m), imm(shift));
+        emit(RVOpcodes::SRAI, reg(sign), reg(x), imm(63));
+        // q = m - sign for d > 0; the quotient is negated for d < 0.
+        if (negate)
+          emit(RVOpcodes::SUBW, reg(*dst), reg(sign), reg(m));
+        else
+          emit(RVOpcodes::SUBW, reg(*dst), reg(m), reg(sign));
+      }
+
+      if (!seq.empty())
+        pending.push_back({i, std::move(seq)});
+    }
+
+    // Apply back to front so earlier indices stay valid.
+    for (auto it = pending.rbegin(); it != pending.rend(); ++it) {
+      instrs.erase(instrs.begin() + it->index);
+      instrs.insert(instrs.begin() + it->index,
+                    std::make_move_iterator(it->seq.begin()),
+                    std::make_move_iterator(it->seq.end()));
+    }
+  }
+}
+
+} // namespace sysy
diff --git a/src/backend/RISCv64/RISCv64Backend.cpp b/src/backend/RISCv64/RISCv64Backend.cpp
index 2797eb7..ae45d65 100644
--- a/src/backend/RISCv64/RISCv64Backend.cpp
+++ b/src/backend/RISCv64/RISCv64Backend.cpp
@@ -119,7 +119,11 @@ std::string RISCv64CodeGen::function_gen(Function* func) {
     RISCv64AsmPrinter printer1(mfunc.get());
     printer1.run(ss1, true);
 
-    // 阶段 2: 指令调度 (Instruction Scheduling)
+    // Stage 2: division strength reduction
+    DivStrengthReduction div_strength_reduction;
+    div_strength_reduction.runOnMachineFunction(mfunc.get());
+
+    // Stage 2.1: instruction scheduling
     PreRA_Scheduler scheduler;
     scheduler.runOnMachineFunction(mfunc.get());
 
diff --git a/src/backend/RISCv64/RISCv64ISel.cpp b/src/backend/RISCv64/RISCv64ISel.cpp
index 871a3e7..e6b0929 100644
--- a/src/backend/RISCv64/RISCv64ISel.cpp
+++ b/src/backend/RISCv64/RISCv64ISel.cpp
@@ -539,6 +539,15 @@ void RISCv64ISel::selectNode(DAGNode* node) {
                     CurMBB->addInstruction(std::move(instr));
                     break;
                 }
+                case Instruction::kSRA: {
+                    auto rhs_const = dynamic_cast<ConstantInteger*>(rhs);
+                    auto instr = std::make_unique<MachineInstr>(RVOpcodes::SRAIW);
+                    instr->addOperand(std::make_unique<RegOperand>(dest_vreg));
+                    instr->addOperand(std::make_unique<RegOperand>(lhs_vreg));
+                    instr->addOperand(std::make_unique<ImmOperand>(rhs_const->getInt()));
+                    CurMBB->addInstruction(std::move(instr));
+                    break;
+                }
                 case BinaryInst::kICmpEQ: { // 等于 (a == b) -> (subw; seqz)
                     auto sub = std::make_unique<MachineInstr>(RVOpcodes::SUBW);
                     sub->addOperand(std::make_unique<RegOperand>(dest_vreg));
@@ -1473,7 +1482,7 @@ std::vector<std::unique_ptr<RISCv64ISel::DAGNode>> RISCv64ISel::build_dag(BasicB
                     }
                 }
             }
-            if (bin->getKind() >= Instruction::kFAdd) { // 假设浮点指令枚举值更大
+            if (bin->isFPBinary()) { // floating-point arithmetic or comparison
                 auto fbin_node = create_node(DAGNode::FBINARY, bin, value_to_node, nodes_storage);
                 fbin_node->operands.push_back(get_operand_node(bin->getLhs(), value_to_node, nodes_storage));
                 fbin_node->operands.push_back(get_operand_node(bin->getRhs(), value_to_node, nodes_storage));
diff --git a/src/include/backend/RISCv64/Optimize/DivStrengthReduction.h b/src/include/backend/RISCv64/Optimize/DivStrengthReduction.h
new file mode 100644
index 0000000..685bb19
--- /dev/null
+++ b/src/include/backend/RISCv64/Optimize/DivStrengthReduction.h
@@ -0,0 +1,34 @@
+#ifndef RISCV64_DIV_STRENGTH_REDUCTION_H
+#define RISCV64_DIV_STRENGTH_REDUCTION_H
+
+#include "RISCv64LLIR.h"
+#include "Pass.h"
+
+namespace sysy {
+
+/**
+ * @class DivStrengthReduction
+ * @brief Strength reduction for division by a constant.
+ *
+ * Rewrites DIV/DIVW machine instructions whose divisor is an immediate into
+ * shift/add sequences (powers of two) or a multiplication by a precomputed
+ * reciprocal ("magic number").
+ */
+class DivStrengthReduction : public Pass {
+public:
+  static char ID;
+
+  DivStrengthReduction() : Pass("div-strength-reduction", Granularity::Function, PassKind::Optimization) {}
+
+  void *getPassID() const override { return &ID; }
+
+  /// IR-level entry point required by Pass; performs no work.
+  bool runOnFunction(Function *F, AnalysisManager& AM) override;
+
+  /// Machine-level entry point; rewrites eligible divisions in place.
+  void runOnMachineFunction(MachineFunction* mfunc);
+};
+
+} // namespace sysy
+
+#endif // RISCV64_DIV_STRENGTH_REDUCTION_H
diff --git a/src/include/backend/RISCv64/RISCv64Passes.h b/src/include/backend/RISCv64/RISCv64Passes.h
index de08882..b456994 100644
--- a/src/include/backend/RISCv64/RISCv64Passes.h
+++ b/src/include/backend/RISCv64/RISCv64Passes.h
@@ -9,6 +9,8 @@
 #include "LegalizeImmediates.h"
 #include "PrologueEpilogueInsertion.h"
 #include "Pass.h"
+#include "DivStrengthReduction.h"
+
 
 namespace sysy {
 
diff --git a/src/include/midend/IR.h b/src/include/midend/IR.h
index 901a3b7..2e4d72b 100644
--- a/src/include/midend/IR.h
+++ b/src/include/midend/IR.h
@@ -708,6 +708,8 @@ class Instruction : public User {
     kPhi = 0x1UL << 39,
     kBitItoF = 0x1UL << 40,
     kBitFtoI = 0x1UL << 41,
+    kSRA = 0x1UL << 42,
+    kMulh = 0x1UL << 43,
   };
 
 protected:
@@ -804,6 +806,14 @@ public:
       return "Memset";
     case kPhi:
       return "Phi";
+    case kBitItoF:
+      return "BitItoF";
+    case kBitFtoI:
+      return "BitFtoI";
+    case kSRA:
+      return "SRA";
+    case kMulh:
+      return "Mulh";
     default:
       return "Unknown";
     }
@@ -815,11 +825,15 @@ public:
 
   bool isBinary() const {
     static constexpr uint64_t BinaryOpMask =
-      (kAdd | kSub | kMul | kDiv | kRem | kAnd | kOr) |
-      (kICmpEQ | kICmpNE | kICmpLT | kICmpGT | kICmpLE | kICmpGE) |
+      (kAdd | kSub | kMul | kDiv | kRem | kAnd | kOr | kSRA) |
+      (kICmpEQ | kICmpNE | kICmpLT | kICmpGT | kICmpLE | kICmpGE);
+    return kind & BinaryOpMask;
+  }
+  bool isFPBinary() const {
+    static constexpr uint64_t FPBinaryOpMask =
       (kFAdd | kFSub | kFMul | kFDiv) |
       (kFCmpEQ | kFCmpNE | kFCmpLT | kFCmpGT | kFCmpLE | kFCmpGE);
-    return kind & BinaryOpMask;
+    return kind & FPBinaryOpMask;
   }
   bool isUnary() const {
     static constexpr uint64_t UnaryOpMask =
diff --git a/src/include/midend/IRBuilder.h b/src/include/midend/IRBuilder.h
index 73e40e5..c232578 100644
--- a/src/include/midend/IRBuilder.h
+++ b/src/include/midend/IRBuilder.h
@@ -217,6 +217,9 @@ class IRBuilder {
   BinaryInst * createOrInst(Value *lhs, Value *rhs, const std::string &name = "") {
     return createBinaryInst(Instruction::kOr, Type::getIntType(), lhs, rhs, name);
   } ///< 创建按位或指令
+  BinaryInst * createSRAInst(Value *lhs, Value *rhs, const std::string &name = "") {
+    return createBinaryInst(Instruction::kSRA, Type::getIntType(), lhs, rhs, name);
+  } ///< Creates an arithmetic shift-right instruction
   CallInst * createCallInst(Function *callee, const std::vector<Value *> &args, const std::string &name = "") {
     std::string newName;
     if (name.empty() && callee->getReturnType() != Type::getVoidType()) {
diff --git a/src/midend/SysYIRGenerator.cpp b/src/midend/SysYIRGenerator.cpp
index 812bd39..0b2d751 100644
--- a/src/midend/SysYIRGenerator.cpp
+++ b/src/midend/SysYIRGenerator.cpp
@@ -249,7 +249,27 @@ void SysYIRGenerator::compute() {
         case BinaryOp::ADD: resultValue = builder.createAddInst(lhs, rhs); break;
         case BinaryOp::SUB: resultValue = builder.createSubInst(lhs, rhs); break;
         case BinaryOp::MUL: resultValue = builder.createMulInst(lhs, rhs); break;
-        case BinaryOp::DIV: resultValue = builder.createDivInst(lhs, rhs); break;
+        case BinaryOp::DIV: {
+          // Signed division truncates toward zero, whereas an arithmetic right
+          // shift rounds toward negative infinity (-7 / 2 == -3, -7 >> 1 == -4).
+          // For a divisor 2^k (k >= 1) add a bias of 2^k - 1 to negative
+          // dividends first: q = (x + (x >> 31) * (1 - 2^k)) >> k.
+          ConstantInteger *rhsConst = dynamic_cast<ConstantInteger *>(rhs);
+          int divisor = rhsConst ? rhsConst->getInt() : 0;
+          if (divisor > 1 && (divisor & (divisor - 1)) == 0) {
+            int shift = 0;
+            for (int rest = divisor; rest > 1; rest >>= 1) {
+              ++shift;
+            }
+            auto *sign = builder.createSRAInst(lhs, ConstantInteger::get(31));
+            auto *bias = builder.createMulInst(sign, ConstantInteger::get(1 - divisor));
+            auto *biased = builder.createAddInst(lhs, bias);
+            resultValue = builder.createSRAInst(biased, ConstantInteger::get(shift));
+          } else {
+            resultValue = builder.createDivInst(lhs, rhs);
+          }
+          break;
+        }
         case BinaryOp::MOD: resultValue = builder.createRemInst(lhs, rhs); break;
       }
     } else if (commonType == Type::getFloatType()) {
diff --git a/src/midend/SysYIRPrinter.cpp b/src/midend/SysYIRPrinter.cpp
index 0a024d2..7891bca 100644
--- a/src/midend/SysYIRPrinter.cpp
+++ b/src/midend/SysYIRPrinter.cpp
@@ -240,6 +240,8 @@ void SysYPrinter::printInst(Instruction *pInst) {
   case Kind::kMul:
   case Kind::kDiv:
   case Kind::kRem:
+  case Kind::kSRA:
+  case Kind::kMulh:
   case Kind::kFAdd:
   case Kind::kFSub:
   case Kind::kFMul:
@@ -272,6 +274,8 @@ void SysYPrinter::printInst(Instruction *pInst) {
     case Kind::kMul: std::cout << "mul"; break;
     case Kind::kDiv: std::cout << "sdiv"; break;
     case Kind::kRem: std::cout << "srem"; break;
+    case Kind::kSRA: std::cout << "ashr"; break;
+    case Kind::kMulh: std::cout << "mulh"; break;
     case Kind::kFAdd: std::cout << "fadd"; break;
     case Kind::kFSub: std::cout << "fsub"; break;
     case Kind::kFMul: std::cout << "fmul"; break;