Merge branch 'backend-divopt' into midend

This commit is contained in:
Lixuanwang
2025-08-03 14:53:22 +08:00
12 changed files with 400 additions and 8 deletions

View File

@ -11,6 +11,7 @@ add_library(riscv64_backend_lib STATIC
Optimize/Peephole.cpp
Optimize/PostRA_Scheduler.cpp
Optimize/PreRA_Scheduler.cpp
Optimize/DivStrengthReduction.cpp
)
# 包含后端模块所需的头文件路径

View File

@ -0,0 +1,282 @@
#include "DivStrengthReduction.h"
#include <cmath>
#include <cstdint>
namespace sysy {
char DivStrengthReduction::ID = 0;
bool DivStrengthReduction::runOnFunction(Function *F, AnalysisManager& AM) {
// This pass works on MachineFunction level, not IR level
return false;
}
void DivStrengthReduction::runOnMachineFunction(MachineFunction *mfunc) {
if (!mfunc)
return;
bool debug = false; // Set to true for debugging
if (debug)
std::cout << "Running DivStrengthReduction optimization..." << std::endl;
int next_temp_reg = 1000;
auto createTempReg = [&]() -> int {
return next_temp_reg++;
};
struct MagicInfo {
int64_t magic;
int shift;
};
auto computeMagic = [](int64_t d, bool is_32bit) -> MagicInfo {
int word_size = is_32bit ? 32 : 64;
uint64_t ad = std::abs(d);
if (ad == 0) return {0, 0};
int l = std::floor(std::log2(ad));
if ((ad & (ad - 1)) == 0) { // power of 2
l = 0; // special case for power of 2, shift will be calculated differently
}
__int128_t one = 1;
__int128_t num;
int total_shift;
if (is_32bit) {
total_shift = 31 + l;
num = one << total_shift;
} else {
total_shift = 63 + l;
num = one << total_shift;
}
__int128_t den = ad;
int64_t magic = (num / den) + 1;
return {magic, total_shift};
};
auto isPowerOfTwo = [](int64_t n) -> bool {
return n > 0 && (n & (n - 1)) == 0;
};
auto getPowerOfTwoExponent = [](int64_t n) -> int {
if (n <= 0 || (n & (n - 1)) != 0) return -1;
int shift = 0;
while (n > 1) {
n >>= 1;
shift++;
}
return shift;
};
struct InstructionReplacement {
size_t index;
size_t count_to_erase;
std::vector<std::unique_ptr<MachineInstr>> newInstrs;
};
for (auto &mbb_uptr : mfunc->getBlocks()) {
auto &mbb = *mbb_uptr;
auto &instrs = mbb.getInstructions();
std::vector<InstructionReplacement> replacements;
for (size_t i = 0; i < instrs.size(); ++i) {
auto *instr = instrs[i].get();
bool is_32bit = (instr->getOpcode() == RVOpcodes::DIVW);
if (instr->getOpcode() != RVOpcodes::DIV && !is_32bit) {
continue;
}
if (instr->getOperands().size() != 3) {
continue;
}
auto *dst_op = instr->getOperands()[0].get();
auto *src1_op = instr->getOperands()[1].get();
auto *src2_op = instr->getOperands()[2].get();
int64_t divisor = 0;
bool const_divisor_found = false;
size_t instructions_to_replace = 1;
if (src2_op->getKind() == MachineOperand::KIND_IMM) {
divisor = static_cast<ImmOperand *>(src2_op)->getValue();
const_divisor_found = true;
} else if (src2_op->getKind() == MachineOperand::KIND_REG) {
if (i > 0) {
auto *prev_instr = instrs[i - 1].get();
if (prev_instr->getOpcode() == RVOpcodes::LI && prev_instr->getOperands().size() == 2) {
auto *li_dst_op = prev_instr->getOperands()[0].get();
auto *li_imm_op = prev_instr->getOperands()[1].get();
if (li_dst_op->getKind() == MachineOperand::KIND_REG && li_imm_op->getKind() == MachineOperand::KIND_IMM) {
auto *div_reg_op = static_cast<RegOperand *>(src2_op);
auto *li_dst_reg_op = static_cast<RegOperand *>(li_dst_op);
if (div_reg_op->isVirtual() && li_dst_reg_op->isVirtual() &&
div_reg_op->getVRegNum() == li_dst_reg_op->getVRegNum()) {
divisor = static_cast<ImmOperand *>(li_imm_op)->getValue();
const_divisor_found = true;
instructions_to_replace = 2;
}
}
}
}
}
if (!const_divisor_found) {
continue;
}
auto *dst_reg = static_cast<RegOperand *>(dst_op);
auto *src1_reg = static_cast<RegOperand *>(src1_op);
if (divisor == 0) continue;
std::vector<std::unique_ptr<MachineInstr>> newInstrs;
if (divisor == 1) {
auto moveInstr = std::make_unique<MachineInstr>(is_32bit ? RVOpcodes::ADDW : RVOpcodes::ADD);
moveInstr->addOperand(std::make_unique<RegOperand>(*dst_reg));
moveInstr->addOperand(std::make_unique<RegOperand>(*src1_reg));
moveInstr->addOperand(std::make_unique<RegOperand>(PhysicalReg::ZERO));
newInstrs.push_back(std::move(moveInstr));
}
else if (divisor == -1) {
auto negInstr = std::make_unique<MachineInstr>(is_32bit ? RVOpcodes::SUBW : RVOpcodes::SUB);
negInstr->addOperand(std::make_unique<RegOperand>(*dst_reg));
negInstr->addOperand(std::make_unique<RegOperand>(PhysicalReg::ZERO));
negInstr->addOperand(std::make_unique<RegOperand>(*src1_reg));
newInstrs.push_back(std::move(negInstr));
}
else if (isPowerOfTwo(std::abs(divisor))) {
int shift = getPowerOfTwoExponent(std::abs(divisor));
int temp_reg = createTempReg();
auto sraSignInstr = std::make_unique<MachineInstr>(is_32bit ? RVOpcodes::SRAIW : RVOpcodes::SRAI);
sraSignInstr->addOperand(std::make_unique<RegOperand>(temp_reg));
sraSignInstr->addOperand(std::make_unique<RegOperand>(*src1_reg));
sraSignInstr->addOperand(std::make_unique<ImmOperand>(is_32bit ? 31 : 63));
newInstrs.push_back(std::move(sraSignInstr));
auto srlInstr = std::make_unique<MachineInstr>(is_32bit ? RVOpcodes::SRLIW : RVOpcodes::SRLI);
srlInstr->addOperand(std::make_unique<RegOperand>(temp_reg));
srlInstr->addOperand(std::make_unique<RegOperand>(temp_reg));
srlInstr->addOperand(std::make_unique<ImmOperand>((is_32bit ? 32 : 64) - shift));
newInstrs.push_back(std::move(srlInstr));
auto addInstr = std::make_unique<MachineInstr>(is_32bit ? RVOpcodes::ADDW : RVOpcodes::ADD);
addInstr->addOperand(std::make_unique<RegOperand>(temp_reg));
addInstr->addOperand(std::make_unique<RegOperand>(*src1_reg));
addInstr->addOperand(std::make_unique<RegOperand>(temp_reg));
newInstrs.push_back(std::move(addInstr));
auto sraInstr = std::make_unique<MachineInstr>(is_32bit ? RVOpcodes::SRAIW : RVOpcodes::SRAI);
sraInstr->addOperand(std::make_unique<RegOperand>(temp_reg));
sraInstr->addOperand(std::make_unique<RegOperand>(temp_reg));
sraInstr->addOperand(std::make_unique<ImmOperand>(shift));
newInstrs.push_back(std::move(sraInstr));
if (divisor < 0) {
auto negInstr = std::make_unique<MachineInstr>(is_32bit ? RVOpcodes::SUBW : RVOpcodes::SUB);
negInstr->addOperand(std::make_unique<RegOperand>(*dst_reg));
negInstr->addOperand(std::make_unique<RegOperand>(PhysicalReg::ZERO));
negInstr->addOperand(std::make_unique<RegOperand>(temp_reg));
newInstrs.push_back(std::move(negInstr));
} else {
auto moveInstr = std::make_unique<MachineInstr>(is_32bit ? RVOpcodes::ADDW : RVOpcodes::ADD);
moveInstr->addOperand(std::make_unique<RegOperand>(*dst_reg));
moveInstr->addOperand(std::make_unique<RegOperand>(temp_reg));
moveInstr->addOperand(std::make_unique<RegOperand>(PhysicalReg::ZERO));
newInstrs.push_back(std::move(moveInstr));
}
}
else {
auto magic_info = computeMagic(divisor, is_32bit);
int magic_reg = createTempReg();
int temp_reg = createTempReg();
auto loadInstr = std::make_unique<MachineInstr>(RVOpcodes::LI);
loadInstr->addOperand(std::make_unique<RegOperand>(magic_reg));
loadInstr->addOperand(std::make_unique<ImmOperand>(magic_info.magic));
newInstrs.push_back(std::move(loadInstr));
if (is_32bit) {
auto mulInstr = std::make_unique<MachineInstr>(RVOpcodes::MUL);
mulInstr->addOperand(std::make_unique<RegOperand>(temp_reg));
mulInstr->addOperand(std::make_unique<RegOperand>(*src1_reg));
mulInstr->addOperand(std::make_unique<RegOperand>(magic_reg));
newInstrs.push_back(std::move(mulInstr));
auto sraInstr = std::make_unique<MachineInstr>(RVOpcodes::SRAI);
sraInstr->addOperand(std::make_unique<RegOperand>(temp_reg));
sraInstr->addOperand(std::make_unique<RegOperand>(temp_reg));
sraInstr->addOperand(std::make_unique<ImmOperand>(magic_info.shift));
newInstrs.push_back(std::move(sraInstr));
} else {
auto mulhInstr = std::make_unique<MachineInstr>(RVOpcodes::MULH);
mulhInstr->addOperand(std::make_unique<RegOperand>(temp_reg));
mulhInstr->addOperand(std::make_unique<RegOperand>(*src1_reg));
mulhInstr->addOperand(std::make_unique<RegOperand>(magic_reg));
newInstrs.push_back(std::move(mulhInstr));
int post_shift = magic_info.shift - 63;
if (post_shift > 0) {
auto sraInstr = std::make_unique<MachineInstr>(RVOpcodes::SRAI);
sraInstr->addOperand(std::make_unique<RegOperand>(temp_reg));
sraInstr->addOperand(std::make_unique<RegOperand>(temp_reg));
sraInstr->addOperand(std::make_unique<ImmOperand>(post_shift));
newInstrs.push_back(std::move(sraInstr));
}
}
int sign_reg = createTempReg();
auto sraSignInstr = std::make_unique<MachineInstr>(is_32bit ? RVOpcodes::SRAIW : RVOpcodes::SRAI);
sraSignInstr->addOperand(std::make_unique<RegOperand>(sign_reg));
sraSignInstr->addOperand(std::make_unique<RegOperand>(*src1_reg));
sraSignInstr->addOperand(std::make_unique<ImmOperand>(is_32bit ? 31 : 63));
newInstrs.push_back(std::move(sraSignInstr));
auto subInstr = std::make_unique<MachineInstr>(is_32bit ? RVOpcodes::SUBW : RVOpcodes::SUB);
subInstr->addOperand(std::make_unique<RegOperand>(temp_reg));
subInstr->addOperand(std::make_unique<RegOperand>(temp_reg));
subInstr->addOperand(std::make_unique<RegOperand>(sign_reg));
newInstrs.push_back(std::move(subInstr));
if (divisor < 0) {
auto negInstr = std::make_unique<MachineInstr>(is_32bit ? RVOpcodes::SUBW : RVOpcodes::SUB);
negInstr->addOperand(std::make_unique<RegOperand>(*dst_reg));
negInstr->addOperand(std::make_unique<RegOperand>(PhysicalReg::ZERO));
negInstr->addOperand(std::make_unique<RegOperand>(temp_reg));
newInstrs.push_back(std::move(negInstr));
} else {
auto moveInstr = std::make_unique<MachineInstr>(is_32bit ? RVOpcodes::ADDW : RVOpcodes::ADD);
moveInstr->addOperand(std::make_unique<RegOperand>(*dst_reg));
moveInstr->addOperand(std::make_unique<RegOperand>(temp_reg));
moveInstr->addOperand(std::make_unique<RegOperand>(PhysicalReg::ZERO));
newInstrs.push_back(std::move(moveInstr));
}
}
if (!newInstrs.empty()) {
size_t start_index = i;
if (instructions_to_replace == 2) {
start_index = i - 1;
}
replacements.push_back({start_index, instructions_to_replace, std::move(newInstrs)});
}
}
for (auto it = replacements.rbegin(); it != replacements.rend(); ++it) {
instrs.erase(instrs.begin() + it->index, instrs.begin() + it->index + it->count_to_erase);
instrs.insert(instrs.begin() + it->index,
std::make_move_iterator(it->newInstrs.begin()),
std::make_move_iterator(it->newInstrs.end()));
}
}
}
} // namespace sysy

View File

@ -60,7 +60,7 @@ void RISCv64AsmPrinter::printInstruction(MachineInstr* instr, bool debug) {
case RVOpcodes::ADD: *OS << "add "; break; case RVOpcodes::ADDI: *OS << "addi "; break;
case RVOpcodes::ADDW: *OS << "addw "; break; case RVOpcodes::ADDIW: *OS << "addiw "; break;
case RVOpcodes::SUB: *OS << "sub "; break; case RVOpcodes::SUBW: *OS << "subw "; break;
case RVOpcodes::MUL: *OS << "mul "; break; case RVOpcodes::MULW: *OS << "mulw "; break;
case RVOpcodes::MUL: *OS << "mul "; break; case RVOpcodes::MULW: *OS << "mulw "; break; case RVOpcodes::MULH: *OS << "mulh "; break;
case RVOpcodes::DIV: *OS << "div "; break; case RVOpcodes::DIVW: *OS << "divw "; break;
case RVOpcodes::REM: *OS << "rem "; break; case RVOpcodes::REMW: *OS << "remw "; break;
case RVOpcodes::XOR: *OS << "xor "; break; case RVOpcodes::XORI: *OS << "xori "; break;

View File

@ -182,7 +182,11 @@ std::string RISCv64CodeGen::function_gen(Function* func) {
RISCv64AsmPrinter printer1(mfunc.get());
printer1.run(ss1, true);
// 阶段 2: 指令调度 (Instruction Scheduling)
// 阶段 2: 除法强度削弱优化 (Division Strength Reduction)
DivStrengthReduction div_strength_reduction;
div_strength_reduction.runOnMachineFunction(mfunc.get());
// 阶段 2.1: 指令调度 (Instruction Scheduling)
PreRA_Scheduler scheduler;
scheduler.runOnMachineFunction(mfunc.get());

View File

@ -539,6 +539,15 @@ void RISCv64ISel::selectNode(DAGNode* node) {
CurMBB->addInstruction(std::move(instr));
break;
}
case Instruction::kSRA: {
auto rhs_const = dynamic_cast<ConstantInteger*>(rhs);
auto instr = std::make_unique<MachineInstr>(RVOpcodes::SRAIW);
instr->addOperand(std::make_unique<RegOperand>(dest_vreg));
instr->addOperand(std::make_unique<RegOperand>(lhs_vreg));
instr->addOperand(std::make_unique<ImmOperand>(rhs_const->getInt()));
CurMBB->addInstruction(std::move(instr));
break;
}
case BinaryInst::kICmpEQ: { // 等于 (a == b) -> (subw; seqz)
auto sub = std::make_unique<MachineInstr>(RVOpcodes::SUBW);
sub->addOperand(std::make_unique<RegOperand>(dest_vreg));
@ -1473,7 +1482,7 @@ std::vector<std::unique_ptr<RISCv64ISel::DAGNode>> RISCv64ISel::build_dag(BasicB
}
}
}
if (bin->getKind() >= Instruction::kFAdd) { // 假设浮点指令枚举值更大
if (bin->isFPBinary()) { // 假设浮点指令枚举值更大
auto fbin_node = create_node(DAGNode::FBINARY, bin, value_to_node, nodes_storage);
fbin_node->operands.push_back(get_operand_node(bin->getLhs(), value_to_node, nodes_storage));
fbin_node->operands.push_back(get_operand_node(bin->getRhs(), value_to_node, nodes_storage));

View File

@ -0,0 +1,30 @@
#ifndef RISCV64_DIV_STRENGTH_REDUCTION_H
#define RISCV64_DIV_STRENGTH_REDUCTION_H
#include "RISCv64LLIR.h"
#include "Pass.h"
namespace sysy {
/**
* @class DivStrengthReduction
* @brief 除法强度削弱优化器
* * 将除法运算转换为乘法运算使用magic number算法
* 适用于除数为常数的情况,可以显著提高性能
*/
class DivStrengthReduction : public Pass {
public:
static char ID;
DivStrengthReduction() : Pass("div-strength-reduction", Granularity::Function, PassKind::Optimization) {}
void *getPassID() const override { return &ID; }
bool runOnFunction(Function *F, AnalysisManager& AM) override;
void runOnMachineFunction(MachineFunction* mfunc);
};
} // namespace sysy
#endif // RISCV64_DIV_STRENGTH_REDUCTION_H

View File

@ -45,7 +45,7 @@ enum class PhysicalReg {
// RISC-V 指令操作码枚举
enum class RVOpcodes {
// 算术指令
ADD, ADDI, ADDW, ADDIW, SUB, SUBW, MUL, MULW, DIV, DIVW, REM, REMW,
ADD, ADDI, ADDW, ADDIW, SUB, SUBW, MUL, MULW, MULH, DIV, DIVW, REM, REMW,
// 逻辑指令
XOR, XORI, OR, ORI, AND, ANDI,
// 移位指令

View File

@ -9,6 +9,8 @@
#include "LegalizeImmediates.h"
#include "PrologueEpilogueInsertion.h"
#include "Pass.h"
#include "DivStrengthReduction.h"
namespace sysy {

View File

@ -728,6 +728,8 @@ class Instruction : public User {
kPhi = 0x1UL << 39,
kBitItoF = 0x1UL << 40,
kBitFtoI = 0x1UL << 41,
kSRA = 0x1UL << 42,
kMulh = 0x1UL << 43
};
protected:
@ -824,6 +826,12 @@ public:
return "Memset";
case kPhi:
return "Phi";
case kBitItoF:
return "BitItoF";
case kBitFtoI:
return "BitFtoI";
case kSRA:
return "SRA";
default:
return "Unknown";
}
@ -835,11 +843,15 @@ public:
bool isBinary() const {
static constexpr uint64_t BinaryOpMask =
(kAdd | kSub | kMul | kDiv | kRem | kAnd | kOr) |
(kICmpEQ | kICmpNE | kICmpLT | kICmpGT | kICmpLE | kICmpGE) |
(kAdd | kSub | kMul | kDiv | kRem | kAnd | kOr | kSRA | kMulh) |
(kICmpEQ | kICmpNE | kICmpLT | kICmpGT | kICmpLE | kICmpGE);
return kind & BinaryOpMask;
}
bool isFPBinary() const {
static constexpr uint64_t FPBinaryOpMask =
(kFAdd | kFSub | kFMul | kFDiv) |
(kFCmpEQ | kFCmpNE | kFCmpLT | kFCmpGT | kFCmpLE | kFCmpGE);
return kind & BinaryOpMask;
return kind & FPBinaryOpMask;
}
bool isUnary() const {
static constexpr uint64_t UnaryOpMask =

View File

@ -217,6 +217,12 @@ class IRBuilder {
BinaryInst * createOrInst(Value *lhs, Value *rhs, const std::string &name = "") {
return createBinaryInst(Instruction::kOr, Type::getIntType(), lhs, rhs, name);
} ///< 创建按位或指令
BinaryInst * createSRAInst(Value *lhs, Value *rhs, const std::string &name = "") {
return createBinaryInst(Instruction::kSRA, Type::getIntType(), lhs, rhs, name);
} ///< 创建算术右移指令
BinaryInst * createMulhInst(Value *lhs, Value *rhs, const std::string &name = "") {
return createBinaryInst(Instruction::kMulh, Type::getIntType(), lhs, rhs, name);
} ///< 创建高位乘法指令
CallInst * createCallInst(Function *callee, const std::vector<Value *> &args, const std::string &name = "") {
std::string newName;
if (name.empty() && callee->getReturnType() != Type::getVoidType()) {

View File

@ -15,6 +15,29 @@
using namespace std;
namespace sysy {
std::pair<long long, int> calculate_signed_magic(int d) {
if (d == 0) throw std::runtime_error("Division by zero");
if (d == 1 || d == -1) return {0, 0}; // Not used by strength reduction
int k = 0;
unsigned int ad = (d > 0) ? d : -d;
unsigned int temp = ad;
while (temp > 0) {
temp >>= 1;
k++;
}
if ((ad & (ad - 1)) == 0) { // if power of 2
k--;
}
unsigned __int128 m_val = 1;
m_val <<= (32 + k - 1);
unsigned __int128 m_prime = m_val / ad;
long long m = m_prime + 1;
return {m, k};
}
// std::vector<Value*> BinaryValueStack; ///< 用于存储value的栈
// std::vector<int> BinaryOpStack; ///< 用于存储二元表达式的操作符栈
@ -249,7 +272,26 @@ void SysYIRGenerator::compute() {
case BinaryOp::ADD: resultValue = builder.createAddInst(lhs, rhs); break;
case BinaryOp::SUB: resultValue = builder.createSubInst(lhs, rhs); break;
case BinaryOp::MUL: resultValue = builder.createMulInst(lhs, rhs); break;
case BinaryOp::DIV: resultValue = builder.createDivInst(lhs, rhs); break;
case BinaryOp::DIV: {
ConstantInteger *rhsConst = dynamic_cast<ConstantInteger *>(rhs);
if (rhsConst) {
int divisor = rhsConst->getInt();
if (divisor > 0 && (divisor & (divisor - 1)) == 0) {
int shift = 0;
int temp = divisor;
while (temp > 1) {
temp >>= 1;
shift++;
}
resultValue = builder.createSRAInst(lhs, ConstantInteger::get(shift));
} else {
resultValue = builder.createDivInst(lhs, rhs);
}
} else {
resultValue = builder.createDivInst(lhs, rhs);
}
break;
}
case BinaryOp::MOD: resultValue = builder.createRemInst(lhs, rhs); break;
}
} else if (commonType == Type::getFloatType()) {

View File

@ -240,6 +240,8 @@ void SysYPrinter::printInst(Instruction *pInst) {
case Kind::kMul:
case Kind::kDiv:
case Kind::kRem:
case Kind::kSRA:
case Kind::kMulh:
case Kind::kFAdd:
case Kind::kFSub:
case Kind::kFMul:
@ -272,6 +274,8 @@ void SysYPrinter::printInst(Instruction *pInst) {
case Kind::kMul: std::cout << "mul"; break;
case Kind::kDiv: std::cout << "sdiv"; break;
case Kind::kRem: std::cout << "srem"; break;
case Kind::kSRA: std::cout << "ashr"; break;
case Kind::kMulh: std::cout << "mulh"; break;
case Kind::kFAdd: std::cout << "fadd"; break;
case Kind::kFSub: std::cout << "fsub"; break;
case Kind::kFMul: std::cout << "fmul"; break;