[optimize]添加更为通用的除法强度削减Pass, 不受除数限制替换div指令,不影响当前分数
This commit is contained in:
@ -1,4 +1,6 @@
|
|||||||
#include "DivStrengthReduction.h"
|
#include "DivStrengthReduction.h"
|
||||||
|
#include <cmath>
|
||||||
|
#include <cstdint>
|
||||||
|
|
||||||
namespace sysy {
|
namespace sysy {
|
||||||
|
|
||||||
@ -17,69 +19,49 @@ void DivStrengthReduction::runOnMachineFunction(MachineFunction *mfunc) {
|
|||||||
if (debug)
|
if (debug)
|
||||||
std::cout << "Running DivStrengthReduction optimization..." << std::endl;
|
std::cout << "Running DivStrengthReduction optimization..." << std::endl;
|
||||||
|
|
||||||
// 虚拟寄存器分配器
|
|
||||||
int next_temp_reg = 1000;
|
int next_temp_reg = 1000;
|
||||||
auto createTempReg = [&]() -> int {
|
auto createTempReg = [&]() -> int {
|
||||||
return next_temp_reg++;
|
return next_temp_reg++;
|
||||||
};
|
};
|
||||||
|
|
||||||
// Magic number 信息结构
|
|
||||||
struct MagicInfo {
|
struct MagicInfo {
|
||||||
int64_t magic;
|
int64_t magic;
|
||||||
int shift;
|
int shift;
|
||||||
bool add_indicator; // 是否需要额外的加法修正
|
|
||||||
};
|
};
|
||||||
|
|
||||||
// 针对缺少MULH指令的简化magic number计算
|
auto computeMagic = [](int64_t d, bool is_32bit) -> MagicInfo {
|
||||||
auto computeMagicNumber = [](int64_t divisor, bool is_32bit) -> MagicInfo {
|
|
||||||
if (divisor == 0) return {0, 0, false};
|
|
||||||
if (divisor == 1) return {1, 0, false};
|
|
||||||
if (divisor == -1) return {-1, 0, false};
|
|
||||||
|
|
||||||
// 对于没有MULH的情况,我们使用更简单但有效的算法
|
|
||||||
// 基于 2^n / divisor 的近似
|
|
||||||
|
|
||||||
bool neg = divisor < 0;
|
|
||||||
int64_t d = neg ? -divisor : divisor;
|
|
||||||
|
|
||||||
int word_size = is_32bit ? 32 : 64;
|
int word_size = is_32bit ? 32 : 64;
|
||||||
|
uint64_t ad = std::abs(d);
|
||||||
|
|
||||||
// 计算合适的移位量
|
if (ad == 0) return {0, 0};
|
||||||
int shift = word_size;
|
|
||||||
int64_t magic = ((1LL << shift) + d - 1) / d;
|
int l = std::floor(std::log2(ad));
|
||||||
|
if ((ad & (ad - 1)) == 0) { // power of 2
|
||||||
|
l = 0; // special case for power of 2, shift will be calculated differently
|
||||||
|
}
|
||||||
|
|
||||||
|
__int128_t one = 1;
|
||||||
|
__int128_t num;
|
||||||
|
int total_shift;
|
||||||
|
|
||||||
// 调整magic number以适应MUL指令
|
|
||||||
if (is_32bit) {
|
if (is_32bit) {
|
||||||
// 32位情况:调整magic使其适合符号扩展后的乘法
|
total_shift = 31 + l;
|
||||||
shift = 32;
|
num = one << total_shift;
|
||||||
magic = ((1LL << shift) + d - 1) / d;
|
|
||||||
} else {
|
} else {
|
||||||
// 64位情况:使用更保守的算法
|
total_shift = 63 + l;
|
||||||
shift = 32; // 使用32位作为基础移位
|
num = one << total_shift;
|
||||||
magic = ((1LL << shift) + d - 1) / d;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
bool add_indicator = false;
|
__int128_t den = ad;
|
||||||
|
int64_t magic = (num / den) + 1;
|
||||||
|
|
||||||
// 检查是否需要加法修正
|
return {magic, total_shift};
|
||||||
if (magic >= (1LL << (word_size - 1))) {
|
|
||||||
add_indicator = true;
|
|
||||||
magic -= (1LL << word_size);
|
|
||||||
}
|
|
||||||
|
|
||||||
if (neg) {
|
|
||||||
magic = -magic;
|
|
||||||
}
|
|
||||||
|
|
||||||
return {magic, shift, add_indicator};
|
|
||||||
};
|
};
|
||||||
|
|
||||||
// 检查是否为2的幂次
|
|
||||||
auto isPowerOfTwo = [](int64_t n) -> bool {
|
auto isPowerOfTwo = [](int64_t n) -> bool {
|
||||||
return n > 0 && (n & (n - 1)) == 0;
|
return n > 0 && (n & (n - 1)) == 0;
|
||||||
};
|
};
|
||||||
|
|
||||||
// 获取2的幂次的指数
|
|
||||||
auto getPowerOfTwoExponent = [](int64_t n) -> int {
|
auto getPowerOfTwoExponent = [](int64_t n) -> int {
|
||||||
if (n <= 0 || (n & (n - 1)) != 0) return -1;
|
if (n <= 0 || (n & (n - 1)) != 0) return -1;
|
||||||
int shift = 0;
|
int shift = 0;
|
||||||
@ -90,9 +72,9 @@ void DivStrengthReduction::runOnMachineFunction(MachineFunction *mfunc) {
|
|||||||
return shift;
|
return shift;
|
||||||
};
|
};
|
||||||
|
|
||||||
// 收集需要替换的指令
|
|
||||||
struct InstructionReplacement {
|
struct InstructionReplacement {
|
||||||
size_t index;
|
size_t index;
|
||||||
|
size_t count_to_erase;
|
||||||
std::vector<std::unique_ptr<MachineInstr>> newInstrs;
|
std::vector<std::unique_ptr<MachineInstr>> newInstrs;
|
||||||
};
|
};
|
||||||
|
|
||||||
@ -106,7 +88,6 @@ void DivStrengthReduction::runOnMachineFunction(MachineFunction *mfunc) {
|
|||||||
|
|
||||||
bool is_32bit = (instr->getOpcode() == RVOpcodes::DIVW);
|
bool is_32bit = (instr->getOpcode() == RVOpcodes::DIVW);
|
||||||
|
|
||||||
// 只处理 DIV 和 DIVW 指令
|
|
||||||
if (instr->getOpcode() != RVOpcodes::DIV && !is_32bit) {
|
if (instr->getOpcode() != RVOpcodes::DIV && !is_32bit) {
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
@ -119,99 +100,73 @@ void DivStrengthReduction::runOnMachineFunction(MachineFunction *mfunc) {
|
|||||||
auto *src1_op = instr->getOperands()[1].get();
|
auto *src1_op = instr->getOperands()[1].get();
|
||||||
auto *src2_op = instr->getOperands()[2].get();
|
auto *src2_op = instr->getOperands()[2].get();
|
||||||
|
|
||||||
// 检查操作数类型
|
int64_t divisor = 0;
|
||||||
if (dst_op->getKind() != MachineOperand::KIND_REG ||
|
bool const_divisor_found = false;
|
||||||
src1_op->getKind() != MachineOperand::KIND_REG ||
|
size_t instructions_to_replace = 1;
|
||||||
src2_op->getKind() != MachineOperand::KIND_IMM) {
|
|
||||||
|
if (src2_op->getKind() == MachineOperand::KIND_IMM) {
|
||||||
|
divisor = static_cast<ImmOperand *>(src2_op)->getValue();
|
||||||
|
const_divisor_found = true;
|
||||||
|
} else if (src2_op->getKind() == MachineOperand::KIND_REG) {
|
||||||
|
if (i > 0) {
|
||||||
|
auto *prev_instr = instrs[i - 1].get();
|
||||||
|
if (prev_instr->getOpcode() == RVOpcodes::LI && prev_instr->getOperands().size() == 2) {
|
||||||
|
auto *li_dst_op = prev_instr->getOperands()[0].get();
|
||||||
|
auto *li_imm_op = prev_instr->getOperands()[1].get();
|
||||||
|
if (li_dst_op->getKind() == MachineOperand::KIND_REG && li_imm_op->getKind() == MachineOperand::KIND_IMM) {
|
||||||
|
auto *div_reg_op = static_cast<RegOperand *>(src2_op);
|
||||||
|
auto *li_dst_reg_op = static_cast<RegOperand *>(li_dst_op);
|
||||||
|
if (div_reg_op->isVirtual() && li_dst_reg_op->isVirtual() &&
|
||||||
|
div_reg_op->getVRegNum() == li_dst_reg_op->getVRegNum()) {
|
||||||
|
divisor = static_cast<ImmOperand *>(li_imm_op)->getValue();
|
||||||
|
const_divisor_found = true;
|
||||||
|
instructions_to_replace = 2;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if (!const_divisor_found) {
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
|
||||||
auto *dst_reg = static_cast<RegOperand *>(dst_op);
|
auto *dst_reg = static_cast<RegOperand *>(dst_op);
|
||||||
auto *src1_reg = static_cast<RegOperand *>(src1_op);
|
auto *src1_reg = static_cast<RegOperand *>(src1_op);
|
||||||
auto *src2_imm = static_cast<ImmOperand *>(src2_op);
|
|
||||||
|
|
||||||
int64_t divisor = src2_imm->getValue();
|
|
||||||
|
|
||||||
// 跳过除数为0的情况
|
|
||||||
if (divisor == 0) continue;
|
if (divisor == 0) continue;
|
||||||
|
|
||||||
std::vector<std::unique_ptr<MachineInstr>> newInstrs;
|
std::vector<std::unique_ptr<MachineInstr>> newInstrs;
|
||||||
|
|
||||||
// 情况1: 除数为1
|
|
||||||
if (divisor == 1) {
|
if (divisor == 1) {
|
||||||
// dst = src1 (直接复制)
|
|
||||||
auto moveInstr = std::make_unique<MachineInstr>(is_32bit ? RVOpcodes::ADDW : RVOpcodes::ADD);
|
auto moveInstr = std::make_unique<MachineInstr>(is_32bit ? RVOpcodes::ADDW : RVOpcodes::ADD);
|
||||||
moveInstr->addOperand(std::make_unique<RegOperand>(*dst_reg));
|
moveInstr->addOperand(std::make_unique<RegOperand>(*dst_reg));
|
||||||
moveInstr->addOperand(std::make_unique<RegOperand>(*src1_reg));
|
moveInstr->addOperand(std::make_unique<RegOperand>(*src1_reg));
|
||||||
moveInstr->addOperand(std::make_unique<RegOperand>(PhysicalReg::ZERO));
|
moveInstr->addOperand(std::make_unique<RegOperand>(PhysicalReg::ZERO));
|
||||||
newInstrs.push_back(std::move(moveInstr));
|
newInstrs.push_back(std::move(moveInstr));
|
||||||
}
|
}
|
||||||
// 情况2: 除数为-1
|
|
||||||
else if (divisor == -1) {
|
else if (divisor == -1) {
|
||||||
// dst = -src1
|
|
||||||
auto negInstr = std::make_unique<MachineInstr>(is_32bit ? RVOpcodes::SUBW : RVOpcodes::SUB);
|
auto negInstr = std::make_unique<MachineInstr>(is_32bit ? RVOpcodes::SUBW : RVOpcodes::SUB);
|
||||||
negInstr->addOperand(std::make_unique<RegOperand>(*dst_reg));
|
negInstr->addOperand(std::make_unique<RegOperand>(*dst_reg));
|
||||||
negInstr->addOperand(std::make_unique<RegOperand>(PhysicalReg::ZERO));
|
negInstr->addOperand(std::make_unique<RegOperand>(PhysicalReg::ZERO));
|
||||||
negInstr->addOperand(std::make_unique<RegOperand>(*src1_reg));
|
negInstr->addOperand(std::make_unique<RegOperand>(*src1_reg));
|
||||||
newInstrs.push_back(std::move(negInstr));
|
newInstrs.push_back(std::move(negInstr));
|
||||||
}
|
}
|
||||||
// 情况3: 正的2的幂次除法
|
else if (isPowerOfTwo(std::abs(divisor))) {
|
||||||
else if (isPowerOfTwo(divisor)) {
|
int shift = getPowerOfTwoExponent(std::abs(divisor));
|
||||||
int shift = getPowerOfTwoExponent(divisor);
|
|
||||||
int temp_reg = createTempReg();
|
int temp_reg = createTempReg();
|
||||||
|
|
||||||
// 对于有符号除法,需要处理负数的舍入
|
|
||||||
// if (src1 < 0) src1 += (divisor - 1)
|
|
||||||
|
|
||||||
// 获取符号位:temp = src1 >> (word_size - 1)
|
|
||||||
auto sraSignInstr = std::make_unique<MachineInstr>(is_32bit ? RVOpcodes::SRAIW : RVOpcodes::SRAI);
|
auto sraSignInstr = std::make_unique<MachineInstr>(is_32bit ? RVOpcodes::SRAIW : RVOpcodes::SRAI);
|
||||||
sraSignInstr->addOperand(std::make_unique<RegOperand>(temp_reg));
|
sraSignInstr->addOperand(std::make_unique<RegOperand>(temp_reg));
|
||||||
sraSignInstr->addOperand(std::make_unique<RegOperand>(*src1_reg));
|
sraSignInstr->addOperand(std::make_unique<RegOperand>(*src1_reg));
|
||||||
sraSignInstr->addOperand(std::make_unique<ImmOperand>(is_32bit ? 31 : 63));
|
sraSignInstr->addOperand(std::make_unique<ImmOperand>(is_32bit ? 31 : 63));
|
||||||
newInstrs.push_back(std::move(sraSignInstr));
|
newInstrs.push_back(std::move(sraSignInstr));
|
||||||
|
|
||||||
// 计算偏移:temp = temp >> (word_size - shift)
|
|
||||||
if (shift < (is_32bit ? 32 : 64)) {
|
|
||||||
auto srlInstr = std::make_unique<MachineInstr>(is_32bit ? RVOpcodes::SRLIW : RVOpcodes::SRLI);
|
auto srlInstr = std::make_unique<MachineInstr>(is_32bit ? RVOpcodes::SRLIW : RVOpcodes::SRLI);
|
||||||
srlInstr->addOperand(std::make_unique<RegOperand>(temp_reg));
|
srlInstr->addOperand(std::make_unique<RegOperand>(temp_reg));
|
||||||
srlInstr->addOperand(std::make_unique<RegOperand>(temp_reg));
|
srlInstr->addOperand(std::make_unique<RegOperand>(temp_reg));
|
||||||
srlInstr->addOperand(std::make_unique<ImmOperand>((is_32bit ? 32 : 64) - shift));
|
srlInstr->addOperand(std::make_unique<ImmOperand>((is_32bit ? 32 : 64) - shift));
|
||||||
newInstrs.push_back(std::move(srlInstr));
|
newInstrs.push_back(std::move(srlInstr));
|
||||||
}
|
|
||||||
|
|
||||||
// 加上偏移:temp = src1 + temp
|
|
||||||
auto addInstr = std::make_unique<MachineInstr>(is_32bit ? RVOpcodes::ADDW : RVOpcodes::ADD);
|
|
||||||
addInstr->addOperand(std::make_unique<RegOperand>(temp_reg));
|
|
||||||
addInstr->addOperand(std::make_unique<RegOperand>(*src1_reg));
|
|
||||||
addInstr->addOperand(std::make_unique<RegOperand>(temp_reg));
|
|
||||||
newInstrs.push_back(std::move(addInstr));
|
|
||||||
|
|
||||||
// 最终右移:dst = temp >> shift
|
|
||||||
auto sraInstr = std::make_unique<MachineInstr>(is_32bit ? RVOpcodes::SRAIW : RVOpcodes::SRAI);
|
|
||||||
sraInstr->addOperand(std::make_unique<RegOperand>(*dst_reg));
|
|
||||||
sraInstr->addOperand(std::make_unique<RegOperand>(temp_reg));
|
|
||||||
sraInstr->addOperand(std::make_unique<ImmOperand>(shift));
|
|
||||||
newInstrs.push_back(std::move(sraInstr));
|
|
||||||
}
|
|
||||||
// 情况4: 负的2的幂次除法
|
|
||||||
else if (divisor < 0 && isPowerOfTwo(-divisor)) {
|
|
||||||
int shift = getPowerOfTwoExponent(-divisor);
|
|
||||||
int temp_reg = createTempReg();
|
|
||||||
|
|
||||||
// 先按正数处理
|
|
||||||
auto sraSignInstr = std::make_unique<MachineInstr>(is_32bit ? RVOpcodes::SRAIW : RVOpcodes::SRAI);
|
|
||||||
sraSignInstr->addOperand(std::make_unique<RegOperand>(temp_reg));
|
|
||||||
sraSignInstr->addOperand(std::make_unique<RegOperand>(*src1_reg));
|
|
||||||
sraSignInstr->addOperand(std::make_unique<ImmOperand>(is_32bit ? 31 : 63));
|
|
||||||
newInstrs.push_back(std::move(sraSignInstr));
|
|
||||||
|
|
||||||
if (shift < (is_32bit ? 32 : 64)) {
|
|
||||||
auto srlInstr = std::make_unique<MachineInstr>(is_32bit ? RVOpcodes::SRLIW : RVOpcodes::SRLI);
|
|
||||||
srlInstr->addOperand(std::make_unique<RegOperand>(temp_reg));
|
|
||||||
srlInstr->addOperand(std::make_unique<RegOperand>(temp_reg));
|
|
||||||
srlInstr->addOperand(std::make_unique<ImmOperand>((is_32bit ? 32 : 64) - shift));
|
|
||||||
newInstrs.push_back(std::move(srlInstr));
|
|
||||||
}
|
|
||||||
|
|
||||||
auto addInstr = std::make_unique<MachineInstr>(is_32bit ? RVOpcodes::ADDW : RVOpcodes::ADD);
|
auto addInstr = std::make_unique<MachineInstr>(is_32bit ? RVOpcodes::ADDW : RVOpcodes::ADD);
|
||||||
addInstr->addOperand(std::make_unique<RegOperand>(temp_reg));
|
addInstr->addOperand(std::make_unique<RegOperand>(temp_reg));
|
||||||
@ -225,100 +180,98 @@ void DivStrengthReduction::runOnMachineFunction(MachineFunction *mfunc) {
|
|||||||
sraInstr->addOperand(std::make_unique<ImmOperand>(shift));
|
sraInstr->addOperand(std::make_unique<ImmOperand>(shift));
|
||||||
newInstrs.push_back(std::move(sraInstr));
|
newInstrs.push_back(std::move(sraInstr));
|
||||||
|
|
||||||
// 然后取反
|
if (divisor < 0) {
|
||||||
auto negInstr = std::make_unique<MachineInstr>(is_32bit ? RVOpcodes::SUBW : RVOpcodes::SUB);
|
auto negInstr = std::make_unique<MachineInstr>(is_32bit ? RVOpcodes::SUBW : RVOpcodes::SUB);
|
||||||
negInstr->addOperand(std::make_unique<RegOperand>(*dst_reg));
|
negInstr->addOperand(std::make_unique<RegOperand>(*dst_reg));
|
||||||
negInstr->addOperand(std::make_unique<RegOperand>(PhysicalReg::ZERO));
|
negInstr->addOperand(std::make_unique<RegOperand>(PhysicalReg::ZERO));
|
||||||
negInstr->addOperand(std::make_unique<RegOperand>(temp_reg));
|
negInstr->addOperand(std::make_unique<RegOperand>(temp_reg));
|
||||||
newInstrs.push_back(std::move(negInstr));
|
newInstrs.push_back(std::move(negInstr));
|
||||||
|
} else {
|
||||||
|
auto moveInstr = std::make_unique<MachineInstr>(is_32bit ? RVOpcodes::ADDW : RVOpcodes::ADD);
|
||||||
|
moveInstr->addOperand(std::make_unique<RegOperand>(*dst_reg));
|
||||||
|
moveInstr->addOperand(std::make_unique<RegOperand>(temp_reg));
|
||||||
|
moveInstr->addOperand(std::make_unique<RegOperand>(PhysicalReg::ZERO));
|
||||||
|
newInstrs.push_back(std::move(moveInstr));
|
||||||
|
}
|
||||||
}
|
}
|
||||||
// 情况5: 通用magic number算法(针对没有MULH的情况进行了简化)
|
|
||||||
else {
|
else {
|
||||||
// 对于一般除法,在没有MULH的情况下,我们采用更保守的策略
|
auto magic_info = computeMagic(divisor, is_32bit);
|
||||||
// 只处理一些简单的常数除法,复杂的情况保持原始除法指令
|
|
||||||
|
|
||||||
// 检查是否为小的常数(可以用简单乘法处理)
|
|
||||||
if (std::abs(divisor) <= 1024) { // 限制在较小的除数范围内
|
|
||||||
auto magic_info = computeMagicNumber(divisor, is_32bit);
|
|
||||||
|
|
||||||
if (magic_info.magic == 0) continue;
|
|
||||||
|
|
||||||
int magic_reg = createTempReg();
|
int magic_reg = createTempReg();
|
||||||
int temp_reg = createTempReg();
|
int temp_reg = createTempReg();
|
||||||
|
|
||||||
// 加载magic number到寄存器
|
|
||||||
auto loadInstr = std::make_unique<MachineInstr>(RVOpcodes::LI);
|
auto loadInstr = std::make_unique<MachineInstr>(RVOpcodes::LI);
|
||||||
loadInstr->addOperand(std::make_unique<RegOperand>(magic_reg));
|
loadInstr->addOperand(std::make_unique<RegOperand>(magic_reg));
|
||||||
loadInstr->addOperand(std::make_unique<ImmOperand>(magic_info.magic));
|
loadInstr->addOperand(std::make_unique<ImmOperand>(magic_info.magic));
|
||||||
newInstrs.push_back(std::move(loadInstr));
|
newInstrs.push_back(std::move(loadInstr));
|
||||||
|
|
||||||
// 使用普通乘法模拟高位乘法
|
|
||||||
if (is_32bit) {
|
if (is_32bit) {
|
||||||
// 32位:使用MULW
|
|
||||||
auto mulInstr = std::make_unique<MachineInstr>(RVOpcodes::MULW);
|
|
||||||
mulInstr->addOperand(std::make_unique<RegOperand>(temp_reg));
|
|
||||||
mulInstr->addOperand(std::make_unique<RegOperand>(*src1_reg));
|
|
||||||
mulInstr->addOperand(std::make_unique<RegOperand>(magic_reg));
|
|
||||||
newInstrs.push_back(std::move(mulInstr));
|
|
||||||
|
|
||||||
// 右移得到近似结果
|
|
||||||
auto sraInstr = std::make_unique<MachineInstr>(RVOpcodes::SRAIW);
|
|
||||||
sraInstr->addOperand(std::make_unique<RegOperand>(temp_reg));
|
|
||||||
sraInstr->addOperand(std::make_unique<RegOperand>(temp_reg));
|
|
||||||
sraInstr->addOperand(std::make_unique<ImmOperand>(magic_info.shift));
|
|
||||||
newInstrs.push_back(std::move(sraInstr));
|
|
||||||
} else {
|
|
||||||
// 64位:使用MUL
|
|
||||||
auto mulInstr = std::make_unique<MachineInstr>(RVOpcodes::MUL);
|
auto mulInstr = std::make_unique<MachineInstr>(RVOpcodes::MUL);
|
||||||
mulInstr->addOperand(std::make_unique<RegOperand>(temp_reg));
|
mulInstr->addOperand(std::make_unique<RegOperand>(temp_reg));
|
||||||
mulInstr->addOperand(std::make_unique<RegOperand>(*src1_reg));
|
mulInstr->addOperand(std::make_unique<RegOperand>(*src1_reg));
|
||||||
mulInstr->addOperand(std::make_unique<RegOperand>(magic_reg));
|
mulInstr->addOperand(std::make_unique<RegOperand>(magic_reg));
|
||||||
newInstrs.push_back(std::move(mulInstr));
|
newInstrs.push_back(std::move(mulInstr));
|
||||||
|
|
||||||
// 右移得到近似结果
|
|
||||||
auto sraInstr = std::make_unique<MachineInstr>(RVOpcodes::SRAI);
|
auto sraInstr = std::make_unique<MachineInstr>(RVOpcodes::SRAI);
|
||||||
sraInstr->addOperand(std::make_unique<RegOperand>(temp_reg));
|
sraInstr->addOperand(std::make_unique<RegOperand>(temp_reg));
|
||||||
sraInstr->addOperand(std::make_unique<RegOperand>(temp_reg));
|
sraInstr->addOperand(std::make_unique<RegOperand>(temp_reg));
|
||||||
sraInstr->addOperand(std::make_unique<ImmOperand>(magic_info.shift));
|
sraInstr->addOperand(std::make_unique<ImmOperand>(magic_info.shift));
|
||||||
newInstrs.push_back(std::move(sraInstr));
|
newInstrs.push_back(std::move(sraInstr));
|
||||||
|
} else {
|
||||||
|
auto mulhInstr = std::make_unique<MachineInstr>(RVOpcodes::MULH);
|
||||||
|
mulhInstr->addOperand(std::make_unique<RegOperand>(temp_reg));
|
||||||
|
mulhInstr->addOperand(std::make_unique<RegOperand>(*src1_reg));
|
||||||
|
mulhInstr->addOperand(std::make_unique<RegOperand>(magic_reg));
|
||||||
|
newInstrs.push_back(std::move(mulhInstr));
|
||||||
|
|
||||||
|
int post_shift = magic_info.shift - 63;
|
||||||
|
if (post_shift > 0) {
|
||||||
|
auto sraInstr = std::make_unique<MachineInstr>(RVOpcodes::SRAI);
|
||||||
|
sraInstr->addOperand(std::make_unique<RegOperand>(temp_reg));
|
||||||
|
sraInstr->addOperand(std::make_unique<RegOperand>(temp_reg));
|
||||||
|
sraInstr->addOperand(std::make_unique<ImmOperand>(post_shift));
|
||||||
|
newInstrs.push_back(std::move(sraInstr));
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// 符号修正:处理负数被除数
|
|
||||||
int sign_reg = createTempReg();
|
int sign_reg = createTempReg();
|
||||||
|
|
||||||
// 获取被除数的符号位
|
|
||||||
auto sraSignInstr = std::make_unique<MachineInstr>(is_32bit ? RVOpcodes::SRAIW : RVOpcodes::SRAI);
|
auto sraSignInstr = std::make_unique<MachineInstr>(is_32bit ? RVOpcodes::SRAIW : RVOpcodes::SRAI);
|
||||||
sraSignInstr->addOperand(std::make_unique<RegOperand>(sign_reg));
|
sraSignInstr->addOperand(std::make_unique<RegOperand>(sign_reg));
|
||||||
sraSignInstr->addOperand(std::make_unique<RegOperand>(*src1_reg));
|
sraSignInstr->addOperand(std::make_unique<RegOperand>(*src1_reg));
|
||||||
sraSignInstr->addOperand(std::make_unique<ImmOperand>(is_32bit ? 31 : 63));
|
sraSignInstr->addOperand(std::make_unique<ImmOperand>(is_32bit ? 31 : 63));
|
||||||
newInstrs.push_back(std::move(sraSignInstr));
|
newInstrs.push_back(std::move(sraSignInstr));
|
||||||
|
|
||||||
// 最终结果:dst = temp - sign(对于正除数)或 dst = temp + sign(对于负除数)
|
auto subInstr = std::make_unique<MachineInstr>(is_32bit ? RVOpcodes::SUBW : RVOpcodes::SUB);
|
||||||
if (divisor > 0) {
|
subInstr->addOperand(std::make_unique<RegOperand>(temp_reg));
|
||||||
auto finalSubInstr = std::make_unique<MachineInstr>(is_32bit ? RVOpcodes::SUBW : RVOpcodes::SUB);
|
subInstr->addOperand(std::make_unique<RegOperand>(temp_reg));
|
||||||
finalSubInstr->addOperand(std::make_unique<RegOperand>(*dst_reg));
|
subInstr->addOperand(std::make_unique<RegOperand>(sign_reg));
|
||||||
finalSubInstr->addOperand(std::make_unique<RegOperand>(temp_reg));
|
newInstrs.push_back(std::move(subInstr));
|
||||||
finalSubInstr->addOperand(std::make_unique<RegOperand>(sign_reg));
|
|
||||||
newInstrs.push_back(std::move(finalSubInstr));
|
if (divisor < 0) {
|
||||||
|
auto negInstr = std::make_unique<MachineInstr>(is_32bit ? RVOpcodes::SUBW : RVOpcodes::SUB);
|
||||||
|
negInstr->addOperand(std::make_unique<RegOperand>(*dst_reg));
|
||||||
|
negInstr->addOperand(std::make_unique<RegOperand>(PhysicalReg::ZERO));
|
||||||
|
negInstr->addOperand(std::make_unique<RegOperand>(temp_reg));
|
||||||
|
newInstrs.push_back(std::move(negInstr));
|
||||||
} else {
|
} else {
|
||||||
auto finalAddInstr = std::make_unique<MachineInstr>(is_32bit ? RVOpcodes::ADDW : RVOpcodes::ADD);
|
auto moveInstr = std::make_unique<MachineInstr>(is_32bit ? RVOpcodes::ADDW : RVOpcodes::ADD);
|
||||||
finalAddInstr->addOperand(std::make_unique<RegOperand>(*dst_reg));
|
moveInstr->addOperand(std::make_unique<RegOperand>(*dst_reg));
|
||||||
finalAddInstr->addOperand(std::make_unique<RegOperand>(temp_reg));
|
moveInstr->addOperand(std::make_unique<RegOperand>(temp_reg));
|
||||||
finalAddInstr->addOperand(std::make_unique<RegOperand>(sign_reg));
|
moveInstr->addOperand(std::make_unique<RegOperand>(PhysicalReg::ZERO));
|
||||||
newInstrs.push_back(std::move(finalAddInstr));
|
newInstrs.push_back(std::move(moveInstr));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
// 对于大的除数或复杂情况,保持原始除法指令不变
|
|
||||||
}
|
|
||||||
|
|
||||||
if (!newInstrs.empty()) {
|
if (!newInstrs.empty()) {
|
||||||
replacements.push_back({i, std::move(newInstrs)});
|
size_t start_index = i;
|
||||||
|
if (instructions_to_replace == 2) {
|
||||||
|
start_index = i - 1;
|
||||||
|
}
|
||||||
|
replacements.push_back({start_index, instructions_to_replace, std::move(newInstrs)});
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// 批量应用替换(从后往前处理避免索引问题)
|
|
||||||
for (auto it = replacements.rbegin(); it != replacements.rend(); ++it) {
|
for (auto it = replacements.rbegin(); it != replacements.rend(); ++it) {
|
||||||
instrs.erase(instrs.begin() + it->index);
|
instrs.erase(instrs.begin() + it->index, instrs.begin() + it->index + it->count_to_erase);
|
||||||
instrs.insert(instrs.begin() + it->index,
|
instrs.insert(instrs.begin() + it->index,
|
||||||
std::make_move_iterator(it->newInstrs.begin()),
|
std::make_move_iterator(it->newInstrs.begin()),
|
||||||
std::make_move_iterator(it->newInstrs.end()));
|
std::make_move_iterator(it->newInstrs.end()));
|
||||||
|
|||||||
@ -60,7 +60,7 @@ void RISCv64AsmPrinter::printInstruction(MachineInstr* instr, bool debug) {
|
|||||||
case RVOpcodes::ADD: *OS << "add "; break; case RVOpcodes::ADDI: *OS << "addi "; break;
|
case RVOpcodes::ADD: *OS << "add "; break; case RVOpcodes::ADDI: *OS << "addi "; break;
|
||||||
case RVOpcodes::ADDW: *OS << "addw "; break; case RVOpcodes::ADDIW: *OS << "addiw "; break;
|
case RVOpcodes::ADDW: *OS << "addw "; break; case RVOpcodes::ADDIW: *OS << "addiw "; break;
|
||||||
case RVOpcodes::SUB: *OS << "sub "; break; case RVOpcodes::SUBW: *OS << "subw "; break;
|
case RVOpcodes::SUB: *OS << "sub "; break; case RVOpcodes::SUBW: *OS << "subw "; break;
|
||||||
case RVOpcodes::MUL: *OS << "mul "; break; case RVOpcodes::MULW: *OS << "mulw "; break;
|
case RVOpcodes::MUL: *OS << "mul "; break; case RVOpcodes::MULW: *OS << "mulw "; break; case RVOpcodes::MULH: *OS << "mulh "; break;
|
||||||
case RVOpcodes::DIV: *OS << "div "; break; case RVOpcodes::DIVW: *OS << "divw "; break;
|
case RVOpcodes::DIV: *OS << "div "; break; case RVOpcodes::DIVW: *OS << "divw "; break;
|
||||||
case RVOpcodes::REM: *OS << "rem "; break; case RVOpcodes::REMW: *OS << "remw "; break;
|
case RVOpcodes::REM: *OS << "rem "; break; case RVOpcodes::REMW: *OS << "remw "; break;
|
||||||
case RVOpcodes::XOR: *OS << "xor "; break; case RVOpcodes::XORI: *OS << "xori "; break;
|
case RVOpcodes::XOR: *OS << "xor "; break; case RVOpcodes::XORI: *OS << "xori "; break;
|
||||||
|
|||||||
@ -45,7 +45,7 @@ enum class PhysicalReg {
|
|||||||
// RISC-V 指令操作码枚举
|
// RISC-V 指令操作码枚举
|
||||||
enum class RVOpcodes {
|
enum class RVOpcodes {
|
||||||
// 算术指令
|
// 算术指令
|
||||||
ADD, ADDI, ADDW, ADDIW, SUB, SUBW, MUL, MULW, DIV, DIVW, REM, REMW,
|
ADD, ADDI, ADDW, ADDIW, SUB, SUBW, MUL, MULW, MULH, DIV, DIVW, REM, REMW,
|
||||||
// 逻辑指令
|
// 逻辑指令
|
||||||
XOR, XORI, OR, ORI, AND, ANDI,
|
XOR, XORI, OR, ORI, AND, ANDI,
|
||||||
// 移位指令
|
// 移位指令
|
||||||
|
|||||||
@ -709,7 +709,7 @@ class Instruction : public User {
|
|||||||
kBitItoF = 0x1UL << 40,
|
kBitItoF = 0x1UL << 40,
|
||||||
kBitFtoI = 0x1UL << 41,
|
kBitFtoI = 0x1UL << 41,
|
||||||
kSRA = 0x1UL << 42,
|
kSRA = 0x1UL << 42,
|
||||||
kMulh = 0x1UL << 43,
|
kMulh = 0x1UL << 43
|
||||||
};
|
};
|
||||||
|
|
||||||
protected:
|
protected:
|
||||||
@ -823,7 +823,7 @@ public:
|
|||||||
|
|
||||||
bool isBinary() const {
|
bool isBinary() const {
|
||||||
static constexpr uint64_t BinaryOpMask =
|
static constexpr uint64_t BinaryOpMask =
|
||||||
(kAdd | kSub | kMul | kDiv | kRem | kAnd | kOr | kSRA) |
|
(kAdd | kSub | kMul | kDiv | kRem | kAnd | kOr | kSRA | kMulh) |
|
||||||
(kICmpEQ | kICmpNE | kICmpLT | kICmpGT | kICmpLE | kICmpGE);
|
(kICmpEQ | kICmpNE | kICmpLT | kICmpGT | kICmpLE | kICmpGE);
|
||||||
return kind & BinaryOpMask;
|
return kind & BinaryOpMask;
|
||||||
}
|
}
|
||||||
|
|||||||
@ -220,6 +220,9 @@ class IRBuilder {
|
|||||||
BinaryInst * createSRAInst(Value *lhs, Value *rhs, const std::string &name = "") {
|
BinaryInst * createSRAInst(Value *lhs, Value *rhs, const std::string &name = "") {
|
||||||
return createBinaryInst(Instruction::kSRA, Type::getIntType(), lhs, rhs, name);
|
return createBinaryInst(Instruction::kSRA, Type::getIntType(), lhs, rhs, name);
|
||||||
} ///< 创建算术右移指令
|
} ///< 创建算术右移指令
|
||||||
|
BinaryInst * createMulhInst(Value *lhs, Value *rhs, const std::string &name = "") {
  // Delegates to the generic binary-instruction factory with the kMulh kind
  // and the integer result type.
  BinaryInst *mulhInst = createBinaryInst(Instruction::kMulh, Type::getIntType(), lhs, rhs, name);
  return mulhInst;
} ///< Creates a high-part multiplication (mulh) instruction
|
||||||
CallInst * createCallInst(Function *callee, const std::vector<Value *> &args, const std::string &name = "") {
|
CallInst * createCallInst(Function *callee, const std::vector<Value *> &args, const std::string &name = "") {
|
||||||
std::string newName;
|
std::string newName;
|
||||||
if (name.empty() && callee->getReturnType() != Type::getVoidType()) {
|
if (name.empty() && callee->getReturnType() != Type::getVoidType()) {
|
||||||
|
|||||||
@ -15,6 +15,29 @@
|
|||||||
using namespace std;
|
using namespace std;
|
||||||
namespace sysy {
|
namespace sysy {
|
||||||
|
|
||||||
|
/// Computes the multiplier/shift pair used to replace a signed 32-bit division
/// by the constant `d` with a multiplication.
///
/// With ad = |d| and k = ceil(log2(ad)), the result is
///   { floor(2^(31 + k) / ad) + 1, k }.
///
/// @param d  The constant divisor.
/// @return   {m, k} as described above; {0, 0} for d == 1 or d == -1, which
///           the strength-reduction caller is not expected to use.
/// @throws std::runtime_error if d == 0.
std::pair<long long, int> calculate_signed_magic(int d) {
    if (d == 0) throw std::runtime_error("Division by zero");
    if (d == 1 || d == -1) return {0, 0}; // Not used by strength reduction

    // Take |d| in unsigned arithmetic: negating INT_MIN as a signed int is
    // undefined behavior, whereas unsigned negation wraps to 2^31 as intended.
    const unsigned int ad = (d > 0) ? static_cast<unsigned int>(d)
                                    : 0u - static_cast<unsigned int>(d);

    // k starts as the bit width of ad and is reduced by one for exact powers
    // of two, which makes it ceil(log2(ad)).
    int k = 0;
    for (unsigned int rest = ad; rest > 0; rest >>= 1) {
        ++k;
    }
    if ((ad & (ad - 1)) == 0) {
        --k;
    }

    // ad <= 2^31 implies k <= 31, so the shift below is at most 62 and the
    // numerator fits in 64 bits; no 128-bit extension type is required.
    const unsigned long long numerator = 1ULL << (32 + k - 1);
    const long long m = static_cast<long long>(numerator / ad) + 1;

    return {m, k};
}
|
||||||
|
|
||||||
|
|
||||||
// std::vector<Value*> BinaryValueStack; ///< 用于存储value的栈
|
// std::vector<Value*> BinaryValueStack; ///< 用于存储value的栈
|
||||||
// std::vector<int> BinaryOpStack; ///< 用于存储二元表达式的操作符栈
|
// std::vector<int> BinaryOpStack; ///< 用于存储二元表达式的操作符栈
|
||||||
|
|||||||
9
test_div_optimization.sy
Normal file
9
test_div_optimization.sy
Normal file
@ -0,0 +1,9 @@
|
|||||||
|
// Test input for the division strength-reduction pass: divides the constant
// dividend 100 by powers of two (4, 8, 16) and by non-powers of two (7, 3).
int main() {
|
||||||
|
int a = 100;
|
||||||
|
// 100 / 4 = 25
int b = a / 4;
|
||||||
|
// 100 / 8 = 12
int c = a / 8;
|
||||||
|
// 100 / 16 = 6
int d = a / 16;
|
||||||
|
// 100 / 7 = 14
int e = a / 7;
|
||||||
|
// 100 / 3 = 33
// NOTE(review): f is computed but never used in the returned sum, so the
// divide-by-3 result is not observable from the exit value — confirm intended.
int f = a / 3;
|
||||||
|
// b + c + d + e = 25 + 12 + 6 + 14 = 57
return b + c + d + e;
|
||||||
|
}
|
||||||
Reference in New Issue
Block a user