[optimze]添加基础的除法指令优化,目前只对除以2的幂数生效

This commit is contained in:
2025-08-03 13:46:42 +08:00
parent e8699d6d25
commit f312792fe9
10 changed files with 419 additions and 6 deletions

View File

@ -11,6 +11,7 @@ add_library(riscv64_backend_lib STATIC
Optimize/Peephole.cpp
Optimize/PostRA_Scheduler.cpp
Optimize/PreRA_Scheduler.cpp
Optimize/DivStrengthReduction.cpp
)
# 包含后端模块所需的头文件路径

View File

@ -0,0 +1,329 @@
#include "DivStrengthReduction.h"
namespace sysy {
char DivStrengthReduction::ID = 0;
bool DivStrengthReduction::runOnFunction(Function *F, AnalysisManager& AM) {
// This pass works on MachineFunction level, not IR level
return false;
}
void DivStrengthReduction::runOnMachineFunction(MachineFunction *mfunc) {
if (!mfunc)
return;
bool debug = false; // Set to true for debugging
if (debug)
std::cout << "Running DivStrengthReduction optimization..." << std::endl;
// 虚拟寄存器分配器
int next_temp_reg = 1000;
auto createTempReg = [&]() -> int {
return next_temp_reg++;
};
// Magic number 信息结构
struct MagicInfo {
int64_t magic;
int shift;
bool add_indicator; // 是否需要额外的加法修正
};
// 针对缺少MULH指令的简化magic number计算
auto computeMagicNumber = [](int64_t divisor, bool is_32bit) -> MagicInfo {
if (divisor == 0) return {0, 0, false};
if (divisor == 1) return {1, 0, false};
if (divisor == -1) return {-1, 0, false};
// 对于没有MULH的情况我们使用更简单但有效的算法
// 基于 2^n / divisor 的近似
bool neg = divisor < 0;
int64_t d = neg ? -divisor : divisor;
int word_size = is_32bit ? 32 : 64;
// 计算合适的移位量
int shift = word_size;
int64_t magic = ((1LL << shift) + d - 1) / d;
// 调整magic number以适应MUL指令
if (is_32bit) {
// 32位情况调整magic使其适合符号扩展后的乘法
shift = 32;
magic = ((1LL << shift) + d - 1) / d;
} else {
// 64位情况使用更保守的算法
shift = 32; // 使用32位作为基础移位
magic = ((1LL << shift) + d - 1) / d;
}
bool add_indicator = false;
// 检查是否需要加法修正
if (magic >= (1LL << (word_size - 1))) {
add_indicator = true;
magic -= (1LL << word_size);
}
if (neg) {
magic = -magic;
}
return {magic, shift, add_indicator};
};
// 检查是否为2的幂次
auto isPowerOfTwo = [](int64_t n) -> bool {
return n > 0 && (n & (n - 1)) == 0;
};
// 获取2的幂次的指数
auto getPowerOfTwoExponent = [](int64_t n) -> int {
if (n <= 0 || (n & (n - 1)) != 0) return -1;
int shift = 0;
while (n > 1) {
n >>= 1;
shift++;
}
return shift;
};
// 收集需要替换的指令
struct InstructionReplacement {
size_t index;
std::vector<std::unique_ptr<MachineInstr>> newInstrs;
};
for (auto &mbb_uptr : mfunc->getBlocks()) {
auto &mbb = *mbb_uptr;
auto &instrs = mbb.getInstructions();
std::vector<InstructionReplacement> replacements;
for (size_t i = 0; i < instrs.size(); ++i) {
auto *instr = instrs[i].get();
bool is_32bit = (instr->getOpcode() == RVOpcodes::DIVW);
// 只处理 DIV 和 DIVW 指令
if (instr->getOpcode() != RVOpcodes::DIV && !is_32bit) {
continue;
}
if (instr->getOperands().size() != 3) {
continue;
}
auto *dst_op = instr->getOperands()[0].get();
auto *src1_op = instr->getOperands()[1].get();
auto *src2_op = instr->getOperands()[2].get();
// 检查操作数类型
if (dst_op->getKind() != MachineOperand::KIND_REG ||
src1_op->getKind() != MachineOperand::KIND_REG ||
src2_op->getKind() != MachineOperand::KIND_IMM) {
continue;
}
auto *dst_reg = static_cast<RegOperand *>(dst_op);
auto *src1_reg = static_cast<RegOperand *>(src1_op);
auto *src2_imm = static_cast<ImmOperand *>(src2_op);
int64_t divisor = src2_imm->getValue();
// 跳过除数为0的情况
if (divisor == 0) continue;
std::vector<std::unique_ptr<MachineInstr>> newInstrs;
// 情况1: 除数为1
if (divisor == 1) {
// dst = src1 (直接复制)
auto moveInstr = std::make_unique<MachineInstr>(is_32bit ? RVOpcodes::ADDW : RVOpcodes::ADD);
moveInstr->addOperand(std::make_unique<RegOperand>(*dst_reg));
moveInstr->addOperand(std::make_unique<RegOperand>(*src1_reg));
moveInstr->addOperand(std::make_unique<RegOperand>(PhysicalReg::ZERO));
newInstrs.push_back(std::move(moveInstr));
}
// 情况2: 除数为-1
else if (divisor == -1) {
// dst = -src1
auto negInstr = std::make_unique<MachineInstr>(is_32bit ? RVOpcodes::SUBW : RVOpcodes::SUB);
negInstr->addOperand(std::make_unique<RegOperand>(*dst_reg));
negInstr->addOperand(std::make_unique<RegOperand>(PhysicalReg::ZERO));
negInstr->addOperand(std::make_unique<RegOperand>(*src1_reg));
newInstrs.push_back(std::move(negInstr));
}
// 情况3: 正的2的幂次除法
else if (isPowerOfTwo(divisor)) {
int shift = getPowerOfTwoExponent(divisor);
int temp_reg = createTempReg();
// 对于有符号除法,需要处理负数的舍入
// if (src1 < 0) src1 += (divisor - 1)
// 获取符号位temp = src1 >> (word_size - 1)
auto sraSignInstr = std::make_unique<MachineInstr>(is_32bit ? RVOpcodes::SRAIW : RVOpcodes::SRAI);
sraSignInstr->addOperand(std::make_unique<RegOperand>(temp_reg));
sraSignInstr->addOperand(std::make_unique<RegOperand>(*src1_reg));
sraSignInstr->addOperand(std::make_unique<ImmOperand>(is_32bit ? 31 : 63));
newInstrs.push_back(std::move(sraSignInstr));
// 计算偏移temp = temp >> (word_size - shift)
if (shift < (is_32bit ? 32 : 64)) {
auto srlInstr = std::make_unique<MachineInstr>(is_32bit ? RVOpcodes::SRLIW : RVOpcodes::SRLI);
srlInstr->addOperand(std::make_unique<RegOperand>(temp_reg));
srlInstr->addOperand(std::make_unique<RegOperand>(temp_reg));
srlInstr->addOperand(std::make_unique<ImmOperand>((is_32bit ? 32 : 64) - shift));
newInstrs.push_back(std::move(srlInstr));
}
// 加上偏移temp = src1 + temp
auto addInstr = std::make_unique<MachineInstr>(is_32bit ? RVOpcodes::ADDW : RVOpcodes::ADD);
addInstr->addOperand(std::make_unique<RegOperand>(temp_reg));
addInstr->addOperand(std::make_unique<RegOperand>(*src1_reg));
addInstr->addOperand(std::make_unique<RegOperand>(temp_reg));
newInstrs.push_back(std::move(addInstr));
// 最终右移dst = temp >> shift
auto sraInstr = std::make_unique<MachineInstr>(is_32bit ? RVOpcodes::SRAIW : RVOpcodes::SRAI);
sraInstr->addOperand(std::make_unique<RegOperand>(*dst_reg));
sraInstr->addOperand(std::make_unique<RegOperand>(temp_reg));
sraInstr->addOperand(std::make_unique<ImmOperand>(shift));
newInstrs.push_back(std::move(sraInstr));
}
// 情况4: 负的2的幂次除法
else if (divisor < 0 && isPowerOfTwo(-divisor)) {
int shift = getPowerOfTwoExponent(-divisor);
int temp_reg = createTempReg();
// 先按正数处理
auto sraSignInstr = std::make_unique<MachineInstr>(is_32bit ? RVOpcodes::SRAIW : RVOpcodes::SRAI);
sraSignInstr->addOperand(std::make_unique<RegOperand>(temp_reg));
sraSignInstr->addOperand(std::make_unique<RegOperand>(*src1_reg));
sraSignInstr->addOperand(std::make_unique<ImmOperand>(is_32bit ? 31 : 63));
newInstrs.push_back(std::move(sraSignInstr));
if (shift < (is_32bit ? 32 : 64)) {
auto srlInstr = std::make_unique<MachineInstr>(is_32bit ? RVOpcodes::SRLIW : RVOpcodes::SRLI);
srlInstr->addOperand(std::make_unique<RegOperand>(temp_reg));
srlInstr->addOperand(std::make_unique<RegOperand>(temp_reg));
srlInstr->addOperand(std::make_unique<ImmOperand>((is_32bit ? 32 : 64) - shift));
newInstrs.push_back(std::move(srlInstr));
}
auto addInstr = std::make_unique<MachineInstr>(is_32bit ? RVOpcodes::ADDW : RVOpcodes::ADD);
addInstr->addOperand(std::make_unique<RegOperand>(temp_reg));
addInstr->addOperand(std::make_unique<RegOperand>(*src1_reg));
addInstr->addOperand(std::make_unique<RegOperand>(temp_reg));
newInstrs.push_back(std::move(addInstr));
auto sraInstr = std::make_unique<MachineInstr>(is_32bit ? RVOpcodes::SRAIW : RVOpcodes::SRAI);
sraInstr->addOperand(std::make_unique<RegOperand>(temp_reg));
sraInstr->addOperand(std::make_unique<RegOperand>(temp_reg));
sraInstr->addOperand(std::make_unique<ImmOperand>(shift));
newInstrs.push_back(std::move(sraInstr));
// 然后取反
auto negInstr = std::make_unique<MachineInstr>(is_32bit ? RVOpcodes::SUBW : RVOpcodes::SUB);
negInstr->addOperand(std::make_unique<RegOperand>(*dst_reg));
negInstr->addOperand(std::make_unique<RegOperand>(PhysicalReg::ZERO));
negInstr->addOperand(std::make_unique<RegOperand>(temp_reg));
newInstrs.push_back(std::move(negInstr));
}
// 情况5: 通用magic number算法针对没有MULH的情况进行了简化
else {
// 对于一般除法在没有MULH的情况下我们采用更保守的策略
// 只处理一些简单的常数除法,复杂的情况保持原始除法指令
// 检查是否为小的常数(可以用简单乘法处理)
if (std::abs(divisor) <= 1024) { // 限制在较小的除数范围内
auto magic_info = computeMagicNumber(divisor, is_32bit);
if (magic_info.magic == 0) continue;
int magic_reg = createTempReg();
int temp_reg = createTempReg();
// 加载magic number到寄存器
auto loadInstr = std::make_unique<MachineInstr>(RVOpcodes::LI);
loadInstr->addOperand(std::make_unique<RegOperand>(magic_reg));
loadInstr->addOperand(std::make_unique<ImmOperand>(magic_info.magic));
newInstrs.push_back(std::move(loadInstr));
// 使用普通乘法模拟高位乘法
if (is_32bit) {
// 32位使用MULW
auto mulInstr = std::make_unique<MachineInstr>(RVOpcodes::MULW);
mulInstr->addOperand(std::make_unique<RegOperand>(temp_reg));
mulInstr->addOperand(std::make_unique<RegOperand>(*src1_reg));
mulInstr->addOperand(std::make_unique<RegOperand>(magic_reg));
newInstrs.push_back(std::move(mulInstr));
// 右移得到近似结果
auto sraInstr = std::make_unique<MachineInstr>(RVOpcodes::SRAIW);
sraInstr->addOperand(std::make_unique<RegOperand>(temp_reg));
sraInstr->addOperand(std::make_unique<RegOperand>(temp_reg));
sraInstr->addOperand(std::make_unique<ImmOperand>(magic_info.shift));
newInstrs.push_back(std::move(sraInstr));
} else {
// 64位使用MUL
auto mulInstr = std::make_unique<MachineInstr>(RVOpcodes::MUL);
mulInstr->addOperand(std::make_unique<RegOperand>(temp_reg));
mulInstr->addOperand(std::make_unique<RegOperand>(*src1_reg));
mulInstr->addOperand(std::make_unique<RegOperand>(magic_reg));
newInstrs.push_back(std::move(mulInstr));
// 右移得到近似结果
auto sraInstr = std::make_unique<MachineInstr>(RVOpcodes::SRAI);
sraInstr->addOperand(std::make_unique<RegOperand>(temp_reg));
sraInstr->addOperand(std::make_unique<RegOperand>(temp_reg));
sraInstr->addOperand(std::make_unique<ImmOperand>(magic_info.shift));
newInstrs.push_back(std::move(sraInstr));
}
// 符号修正:处理负数被除数
int sign_reg = createTempReg();
// 获取被除数的符号位
auto sraSignInstr = std::make_unique<MachineInstr>(is_32bit ? RVOpcodes::SRAIW : RVOpcodes::SRAI);
sraSignInstr->addOperand(std::make_unique<RegOperand>(sign_reg));
sraSignInstr->addOperand(std::make_unique<RegOperand>(*src1_reg));
sraSignInstr->addOperand(std::make_unique<ImmOperand>(is_32bit ? 31 : 63));
newInstrs.push_back(std::move(sraSignInstr));
// 最终结果dst = temp - sign对于正除数或 dst = temp + sign对于负除数
if (divisor > 0) {
auto finalSubInstr = std::make_unique<MachineInstr>(is_32bit ? RVOpcodes::SUBW : RVOpcodes::SUB);
finalSubInstr->addOperand(std::make_unique<RegOperand>(*dst_reg));
finalSubInstr->addOperand(std::make_unique<RegOperand>(temp_reg));
finalSubInstr->addOperand(std::make_unique<RegOperand>(sign_reg));
newInstrs.push_back(std::move(finalSubInstr));
} else {
auto finalAddInstr = std::make_unique<MachineInstr>(is_32bit ? RVOpcodes::ADDW : RVOpcodes::ADD);
finalAddInstr->addOperand(std::make_unique<RegOperand>(*dst_reg));
finalAddInstr->addOperand(std::make_unique<RegOperand>(temp_reg));
finalAddInstr->addOperand(std::make_unique<RegOperand>(sign_reg));
newInstrs.push_back(std::move(finalAddInstr));
}
}
// 对于大的除数或复杂情况,保持原始除法指令不变
}
if (!newInstrs.empty()) {
replacements.push_back({i, std::move(newInstrs)});
}
}
// 批量应用替换(从后往前处理避免索引问题)
for (auto it = replacements.rbegin(); it != replacements.rend(); ++it) {
instrs.erase(instrs.begin() + it->index);
instrs.insert(instrs.begin() + it->index,
std::make_move_iterator(it->newInstrs.begin()),
std::make_move_iterator(it->newInstrs.end()));
}
}
}
} // namespace sysy

View File

@ -119,7 +119,11 @@ std::string RISCv64CodeGen::function_gen(Function* func) {
RISCv64AsmPrinter printer1(mfunc.get());
printer1.run(ss1, true);
// 阶段 2: 指令调度 (Instruction Scheduling)
// 阶段 2: 除法强度削弱优化 (Division Strength Reduction)
DivStrengthReduction div_strength_reduction;
div_strength_reduction.runOnMachineFunction(mfunc.get());
// 阶段 2.1: 指令调度 (Instruction Scheduling)
PreRA_Scheduler scheduler;
scheduler.runOnMachineFunction(mfunc.get());

View File

@ -539,6 +539,15 @@ void RISCv64ISel::selectNode(DAGNode* node) {
CurMBB->addInstruction(std::move(instr));
break;
}
case Instruction::kSRA: {
auto rhs_const = dynamic_cast<ConstantInteger*>(rhs);
auto instr = std::make_unique<MachineInstr>(RVOpcodes::SRAIW);
instr->addOperand(std::make_unique<RegOperand>(dest_vreg));
instr->addOperand(std::make_unique<RegOperand>(lhs_vreg));
instr->addOperand(std::make_unique<ImmOperand>(rhs_const->getInt()));
CurMBB->addInstruction(std::move(instr));
break;
}
case BinaryInst::kICmpEQ: { // 等于 (a == b) -> (subw; seqz)
auto sub = std::make_unique<MachineInstr>(RVOpcodes::SUBW);
sub->addOperand(std::make_unique<RegOperand>(dest_vreg));
@ -1473,7 +1482,7 @@ std::vector<std::unique_ptr<RISCv64ISel::DAGNode>> RISCv64ISel::build_dag(BasicB
}
}
}
if (bin->getKind() >= Instruction::kFAdd) { // 假设浮点指令枚举值更大
if (bin->isFPBinary()) { // 假设浮点指令枚举值更大
auto fbin_node = create_node(DAGNode::FBINARY, bin, value_to_node, nodes_storage);
fbin_node->operands.push_back(get_operand_node(bin->getLhs(), value_to_node, nodes_storage));
fbin_node->operands.push_back(get_operand_node(bin->getRhs(), value_to_node, nodes_storage));

View File

@ -0,0 +1,30 @@
#ifndef RISCV64_DIV_STRENGTH_REDUCTION_H
#define RISCV64_DIV_STRENGTH_REDUCTION_H
#include "RISCv64LLIR.h"
#include "Pass.h"
namespace sysy {
/**
* @class DivStrengthReduction
* @brief 除法强度削弱优化器
* * 将除法运算转换为乘法运算使用magic number算法
* 适用于除数为常数的情况,可以显著提高性能
*/
class DivStrengthReduction : public Pass {
public:
static char ID;
DivStrengthReduction() : Pass("div-strength-reduction", Granularity::Function, PassKind::Optimization) {}
void *getPassID() const override { return &ID; }
bool runOnFunction(Function *F, AnalysisManager& AM) override;
void runOnMachineFunction(MachineFunction* mfunc);
};
} // namespace sysy
#endif // RISCV64_DIV_STRENGTH_REDUCTION_H

View File

@ -9,6 +9,8 @@
#include "LegalizeImmediates.h"
#include "PrologueEpilogueInsertion.h"
#include "Pass.h"
#include "DivStrengthReduction.h"
namespace sysy {

View File

@ -708,6 +708,8 @@ class Instruction : public User {
kPhi = 0x1UL << 39,
kBitItoF = 0x1UL << 40,
kBitFtoI = 0x1UL << 41,
kSRA = 0x1UL << 42,
kMulh = 0x1UL << 43,
};
protected:
@ -804,6 +806,12 @@ public:
return "Memset";
case kPhi:
return "Phi";
case kBitItoF:
return "BitItoF";
case kBitFtoI:
return "BitFtoI";
case kSRA:
return "SRA";
default:
return "Unknown";
}
@ -815,11 +823,15 @@ public:
bool isBinary() const {
static constexpr uint64_t BinaryOpMask =
(kAdd | kSub | kMul | kDiv | kRem | kAnd | kOr) |
(kICmpEQ | kICmpNE | kICmpLT | kICmpGT | kICmpLE | kICmpGE) |
(kAdd | kSub | kMul | kDiv | kRem | kAnd | kOr | kSRA) |
(kICmpEQ | kICmpNE | kICmpLT | kICmpGT | kICmpLE | kICmpGE);
return kind & BinaryOpMask;
}
bool isFPBinary() const {
static constexpr uint64_t FPBinaryOpMask =
(kFAdd | kFSub | kFMul | kFDiv) |
(kFCmpEQ | kFCmpNE | kFCmpLT | kFCmpGT | kFCmpLE | kFCmpGE);
return kind & BinaryOpMask;
return kind & FPBinaryOpMask;
}
bool isUnary() const {
static constexpr uint64_t UnaryOpMask =

View File

@ -217,6 +217,9 @@ class IRBuilder {
BinaryInst * createOrInst(Value *lhs, Value *rhs, const std::string &name = "") {
return createBinaryInst(Instruction::kOr, Type::getIntType(), lhs, rhs, name);
} ///< 创建按位或指令
BinaryInst * createSRAInst(Value *lhs, Value *rhs, const std::string &name = "") {
return createBinaryInst(Instruction::kSRA, Type::getIntType(), lhs, rhs, name);
} ///< 创建算术右移指令
CallInst * createCallInst(Function *callee, const std::vector<Value *> &args, const std::string &name = "") {
std::string newName;
if (name.empty() && callee->getReturnType() != Type::getVoidType()) {

View File

@ -249,7 +249,26 @@ void SysYIRGenerator::compute() {
case BinaryOp::ADD: resultValue = builder.createAddInst(lhs, rhs); break;
case BinaryOp::SUB: resultValue = builder.createSubInst(lhs, rhs); break;
case BinaryOp::MUL: resultValue = builder.createMulInst(lhs, rhs); break;
case BinaryOp::DIV: resultValue = builder.createDivInst(lhs, rhs); break;
case BinaryOp::DIV: {
ConstantInteger *rhsConst = dynamic_cast<ConstantInteger *>(rhs);
if (rhsConst) {
int divisor = rhsConst->getInt();
if (divisor > 0 && (divisor & (divisor - 1)) == 0) {
int shift = 0;
int temp = divisor;
while (temp > 1) {
temp >>= 1;
shift++;
}
resultValue = builder.createSRAInst(lhs, ConstantInteger::get(shift));
} else {
resultValue = builder.createDivInst(lhs, rhs);
}
} else {
resultValue = builder.createDivInst(lhs, rhs);
}
break;
}
case BinaryOp::MOD: resultValue = builder.createRemInst(lhs, rhs); break;
}
} else if (commonType == Type::getFloatType()) {

View File

@ -240,6 +240,8 @@ void SysYPrinter::printInst(Instruction *pInst) {
case Kind::kMul:
case Kind::kDiv:
case Kind::kRem:
case Kind::kSRA:
case Kind::kMulh:
case Kind::kFAdd:
case Kind::kFSub:
case Kind::kFMul:
@ -272,6 +274,8 @@ void SysYPrinter::printInst(Instruction *pInst) {
case Kind::kMul: std::cout << "mul"; break;
case Kind::kDiv: std::cout << "sdiv"; break;
case Kind::kRem: std::cout << "srem"; break;
case Kind::kSRA: std::cout << "ashr"; break;
case Kind::kMulh: std::cout << "mulh"; break;
case Kind::kFAdd: std::cout << "fadd"; break;
case Kind::kFSub: std::cout << "fsub"; break;
case Kind::kFMul: std::cout << "fmul"; break;