@ -1,4 +1,6 @@
# include "DivStrengthReduction.h"
# include <cmath>
# include <cstdint>
namespace sysy {
@ -17,69 +19,49 @@ void DivStrengthReduction::runOnMachineFunction(MachineFunction *mfunc) {
if ( debug )
std : : cout < < " Running DivStrengthReduction optimization... " < < std : : endl ;
// 虚拟寄存器分配器
int next_temp_reg = 1000 ;
auto createTempReg = [ & ] ( ) - > int {
return next_temp_reg + + ;
} ;
// Magic number 信息结构
struct MagicInfo {
int64_t magic ;
int shift ;
bool add_indicator ; // 是否需要额外的加法修正
} ;
// 针对缺少MULH指令的简化magic number计算
auto computeMagicNumber = [ ] ( int64_t divisor , bool is_32bit ) - > MagicInfo {
if ( divisor = = 0 ) return { 0 , 0 , false } ;
if ( divisor = = 1 ) return { 1 , 0 , false } ;
if ( divisor = = - 1 ) return { - 1 , 0 , false } ;
// 对于没有MULH的情况, 我们使用更简单但有效的算法
// 基于 2^n / divisor 的近似
bool neg = divisor < 0 ;
int64_t d = neg ? - divisor : divisor ;
auto computeMagic = [ ] ( int64_t d , bool is_32bit ) - > MagicInfo {
int word_size = is_32bit ? 32 : 64 ;
uint64_t ad = std : : abs ( d ) ;
// 计算合适的移位量
int shift = word_size ;
int64_t magic = ( ( 1LL < < shift ) + d - 1 ) / d ;
if ( ad = = 0 ) return { 0 , 0 } ;
int l = std : : floor ( std : : log2 ( ad ) ) ;
if ( ( ad & ( ad - 1 ) ) = = 0 ) { // power of 2
l = 0 ; // special case for power of 2, shift will be calculated differently
}
__int128_t one = 1 ;
__int128_t num ;
int total_shift ;
// 调整magic number以适应MUL指令
if ( is_32bit ) {
// 32位情况: 调整magic使其适合符号扩展后的乘法
shift = 32 ;
magic = ( ( 1LL < < shift ) + d - 1 ) / d ;
total_shift = 31 + l ;
num = one < < total_shift ;
} else {
// 64位情况: 使用更保守的算法
shift = 32 ; // 使用32位作为基础移位
magic = ( ( 1LL < < shift ) + d - 1 ) / d ;
total_shift = 63 + l ;
num = one < < total_shift ;
}
bool add_indicator = false ;
__int128_t den = ad ;
int64_t magic = ( num / den ) + 1 ;
// 检查是否需要加法修正
if ( magic > = ( 1LL < < ( word_size - 1 ) ) ) {
add_indicator = true ;
magic - = ( 1LL < < word_size ) ;
}
if ( neg ) {
magic = - magic ;
}
return { magic , shift , add_indicator } ;
return { magic , total_shift } ;
} ;
// 检查是否为2的幂次
auto isPowerOfTwo = [ ] ( int64_t n ) - > bool {
return n > 0 & & ( n & ( n - 1 ) ) = = 0 ;
} ;
// 获取2的幂次的指数
auto getPowerOfTwoExponent = [ ] ( int64_t n ) - > int {
if ( n < = 0 | | ( n & ( n - 1 ) ) ! = 0 ) return - 1 ;
int shift = 0 ;
@ -90,9 +72,9 @@ void DivStrengthReduction::runOnMachineFunction(MachineFunction *mfunc) {
return shift ;
} ;
// 收集需要替换的指令
struct InstructionReplacement {
size_t index ;
size_t count_to_erase ;
std : : vector < std : : unique_ptr < MachineInstr > > newInstrs ;
} ;
@ -106,7 +88,6 @@ void DivStrengthReduction::runOnMachineFunction(MachineFunction *mfunc) {
bool is_32bit = ( instr - > getOpcode ( ) = = RVOpcodes : : DIVW ) ;
// 只处理 DIV 和 DIVW 指令
if ( instr - > getOpcode ( ) ! = RVOpcodes : : DIV & & ! is_32bit ) {
continue ;
}
@ -119,99 +100,73 @@ void DivStrengthReduction::runOnMachineFunction(MachineFunction *mfunc) {
auto * src1_op = instr - > getOperands ( ) [ 1 ] . get ( ) ;
auto * src2_op = instr - > getOperands ( ) [ 2 ] . get ( ) ;
// 检查操作数类型
if ( dst_op - > getKind ( ) ! = MachineOperand : : KIND_REG | |
src1_op - > getKind ( ) ! = MachineOperand : : KIND_REG | |
src2_op - > getKind ( ) ! = MachineOperand : : KIND_IMM ) {
int64_t divisor = 0 ;
bool const_divisor_found = false ;
size_t instructions_to_replace = 1 ;
if ( src2_op - > getKind ( ) = = MachineOperand : : KIND_IMM ) {
divisor = static_cast < ImmOperand * > ( src2_op ) - > getValue ( ) ;
const_divisor_found = true ;
} else if ( src2_op - > getKind ( ) = = MachineOperand : : KIND_REG ) {
if ( i > 0 ) {
auto * prev_instr = instrs [ i - 1 ] . get ( ) ;
if ( prev_instr - > getOpcode ( ) = = RVOpcodes : : LI & & prev_instr - > getOperands ( ) . size ( ) = = 2 ) {
auto * li_dst_op = prev_instr - > getOperands ( ) [ 0 ] . get ( ) ;
auto * li_imm_op = prev_instr - > getOperands ( ) [ 1 ] . get ( ) ;
if ( li_dst_op - > getKind ( ) = = MachineOperand : : KIND_REG & & li_imm_op - > getKind ( ) = = MachineOperand : : KIND_IMM ) {
auto * div_reg_op = static_cast < RegOperand * > ( src2_op ) ;
auto * li_dst_reg_op = static_cast < RegOperand * > ( li_dst_op ) ;
if ( div_reg_op - > isVirtual ( ) & & li_dst_reg_op - > isVirtual ( ) & &
div_reg_op - > getVRegNum ( ) = = li_dst_reg_op - > getVRegNum ( ) ) {
divisor = static_cast < ImmOperand * > ( li_imm_op ) - > getValue ( ) ;
const_divisor_found = true ;
instructions_to_replace = 2 ;
}
}
}
}
}
if ( ! const_divisor_found ) {
continue ;
}
auto * dst_reg = static_cast < RegOperand * > ( dst_op ) ;
auto * src1_reg = static_cast < RegOperand * > ( src1_op ) ;
auto * src2_imm = static_cast < ImmOperand * > ( src2_op ) ;
int64_t divisor = src2_imm - > getValue ( ) ;
// 跳过除数为0的情况
if ( divisor = = 0 ) continue ;
std : : vector < std : : unique_ptr < MachineInstr > > newInstrs ;
// 情况1: 除数为1
if ( divisor = = 1 ) {
// dst = src1 (直接复制)
auto moveInstr = std : : make_unique < MachineInstr > ( is_32bit ? RVOpcodes : : ADDW : RVOpcodes : : ADD ) ;
moveInstr - > addOperand ( std : : make_unique < RegOperand > ( * dst_reg ) ) ;
moveInstr - > addOperand ( std : : make_unique < RegOperand > ( * src1_reg ) ) ;
moveInstr - > addOperand ( std : : make_unique < RegOperand > ( PhysicalReg : : ZERO ) ) ;
newInstrs . push_back ( std : : move ( moveInstr ) ) ;
}
// 情况2: 除数为-1
else if ( divisor = = - 1 ) {
// dst = -src1
auto negInstr = std : : make_unique < MachineInstr > ( is_32bit ? RVOpcodes : : SUBW : RVOpcodes : : SUB ) ;
negInstr - > addOperand ( std : : make_unique < RegOperand > ( * dst_reg ) ) ;
negInstr - > addOperand ( std : : make_unique < RegOperand > ( PhysicalReg : : ZERO ) ) ;
negInstr - > addOperand ( std : : make_unique < RegOperand > ( * src1_reg ) ) ;
newInstrs . push_back ( std : : move ( negInstr ) ) ;
}
// 情况3: 正的2的幂次除法
else if ( is PowerOfTwo( divisor ) ) {
int shift = getPowerOfTwoExponent ( divisor ) ;
else if ( isPowerOfTwo ( std : : abs ( divisor ) ) ) {
int sh ift = get PowerOfTwoExponent ( std : : abs ( divisor ) ) ;
int temp_reg = createTempReg ( ) ;
// 对于有符号除法,需要处理负数的舍入
// if (src1 < 0) src1 += (divisor - 1)
// 获取符号位: temp = src1 >> (word_size - 1)
auto sraSignInstr = std : : make_unique < MachineInstr > ( is_32bit ? RVOpcodes : : SRAIW : RVOpcodes : : SRAI ) ;
sraSignInstr - > addOperand ( std : : make_unique < RegOperand > ( temp_reg ) ) ;
sraSignInstr - > addOperand ( std : : make_unique < RegOperand > ( * src1_reg ) ) ;
sraSignInstr - > addOperand ( std : : make_unique < ImmOperand > ( is_32bit ? 31 : 63 ) ) ;
newInstrs . push_back ( std : : move ( sraSignInstr ) ) ;
// 计算偏移: temp = temp >> (word_size - shift)
if ( shift < ( is_32bit ? 32 : 64 ) ) {
auto srlInstr = std : : make_unique < MachineInstr > ( is_32bit ? RVOpcodes : : SRLIW : RVOpcodes : : SRLI ) ;
srlInstr - > addOperand ( std : : make_unique < RegOperand > ( temp_reg ) ) ;
srlInstr - > addOperand ( std : : make_unique < RegOperand > ( temp_reg ) ) ;
srlInstr - > addOperand ( std : : make_unique < ImmOperand > ( ( is_32bit ? 32 : 64 ) - shift ) ) ;
newInstrs . push_back ( std : : move ( srlInstr ) ) ;
}
// 加上偏移: temp = src1 + temp
auto addInstr = std : : make_unique < MachineInstr > ( is_32bit ? RVOpcodes : : ADDW : RVOpcodes : : ADD ) ;
addInstr - > addOperand ( std : : make_unique < RegOperand > ( temp_reg ) ) ;
addInstr - > addOperand ( std : : make_unique < RegOperand > ( * src1_reg ) ) ;
addInstr - > addOperand ( std : : make_unique < RegOperand > ( temp_reg ) ) ;
newInstrs . push_back ( std : : move ( addInstr ) ) ;
// 最终右移: dst = temp >> shift
auto sraInstr = std : : make_unique < MachineInstr > ( is_32bit ? RVOpcodes : : SRAIW : RVOpcodes : : SRAI ) ;
sraInstr - > addOperand ( std : : make_unique < RegOperand > ( * dst_reg ) ) ;
sraInstr - > addOperand ( std : : make_unique < RegOperand > ( temp_reg ) ) ;
sraInstr - > addOperand ( std : : make_unique < ImmOperand > ( shift ) ) ;
newInstrs . push_back ( std : : move ( sraInstr ) ) ;
}
// 情况4: 负的2的幂次除法
else if ( divisor < 0 & & isPowerOfTwo ( - divisor ) ) {
int shift = getPowerOfTwoExponent ( - divisor ) ;
int temp_reg = createTempReg ( ) ;
// 先按正数处理
auto sraSignInstr = std : : make_unique < MachineInstr > ( is_32bit ? RVOpcodes : : SRAIW : RVOpcodes : : SRAI ) ;
sraSignInstr - > addOperand ( std : : make_unique < RegOperand > ( temp_reg ) ) ;
sraSignInstr - > addOperand ( std : : make_unique < RegOperand > ( * src1_reg ) ) ;
sraSignInstr - > addOperand ( std : : make_unique < ImmOperand > ( is_32bit ? 31 : 63 ) ) ;
newInstrs . push_back ( std : : move ( sraSignInstr ) ) ;
if ( shift < ( is_32bit ? 32 : 64 ) ) {
auto srlInstr = std : : make_unique < MachineInstr > ( is_32bit ? RVOpcodes : : SRLIW : RVOpcodes : : SRLI ) ;
srlInstr - > addOperand ( std : : make_unique < RegOperand > ( temp_reg ) ) ;
srlInstr - > addOperand ( std : : make_unique < RegOperand > ( temp_reg ) ) ;
srlInstr - > addOperand ( std : : make_unique < ImmOperand > ( ( is_32bit ? 32 : 64 ) - shift ) ) ;
newInstrs . push_back ( std : : move ( srlInstr ) ) ;
}
auto addInstr = std : : make_unique < MachineInstr > ( is_32bit ? RVOpcodes : : ADDW : RVOpcodes : : ADD ) ;
addInstr - > addOperand ( std : : make_unique < RegOperand > ( temp_reg ) ) ;
@ -225,100 +180,98 @@ void DivStrengthReduction::runOnMachineFunction(MachineFunction *mfunc) {
sraInstr - > addOperand ( std : : make_unique < ImmOperand > ( shift ) ) ;
newInstrs . push_back ( std : : move ( sraInstr ) ) ;
// 然后取反
if ( divisor < 0 ) {
auto negInstr = std : : make_unique < MachineInstr > ( is_32bit ? RVOpcodes : : SUBW : RVOpcodes : : SUB ) ;
negInstr - > addOperand ( std : : make_unique < RegOperand > ( * dst_reg ) ) ;
negInstr - > addOperand ( std : : make_unique < RegOperand > ( PhysicalReg : : ZERO ) ) ;
negInstr - > addOperand ( std : : make_unique < RegOperand > ( temp_reg ) ) ;
newInstrs . push_back ( std : : move ( negInstr ) ) ;
} else {
auto moveInstr = std : : make_unique < MachineInstr > ( is_32bit ? RVOpcodes : : ADDW : RVOpcodes : : ADD ) ;
moveInstr - > addOperand ( std : : make_unique < RegOperand > ( * dst_reg ) ) ;
moveInstr - > addOperand ( std : : make_unique < RegOperand > ( temp_reg ) ) ;
moveInstr - > addOperand ( std : : make_unique < RegOperand > ( PhysicalReg : : ZERO ) ) ;
newInstrs . push_back ( std : : move ( moveInstr ) ) ;
}
}
// 情况5: 通用magic number算法( 针对没有MULH的情况进行了简化)
else {
// 对于一般除法, 在没有MULH的情况下, 我们采用更保守的策略
// 只处理一些简单的常数除法,复杂的情况保持原始除法指令
// 检查是否为小的常数(可以用简单乘法处理)
if ( std : : abs ( divisor ) < = 1024 ) { // 限制在较小的除数范围内
auto magic_info = computeMagicNumber ( divisor , is_32bit ) ;
if ( magic_info . magic = = 0 ) continue ;
auto magic_info = computeMagic ( divisor , is_32bit ) ;
int magic_reg = createTempReg ( ) ;
int temp_reg = createTempReg ( ) ;
// 加载magic number到寄存器
auto loadInstr = std : : make_unique < MachineInstr > ( RVOpcodes : : LI ) ;
loadInstr - > addOperand ( std : : make_unique < RegOperand > ( magic_reg ) ) ;
loadInstr - > addOperand ( std : : make_unique < ImmOperand > ( magic_info . magic ) ) ;
newInstrs . push_back ( std : : move ( loadInstr ) ) ;
// 使用普通乘法模拟高位乘法
if ( is_32bit ) {
// 32位: 使用MULW
auto mulInstr = std : : make_unique < MachineInstr > ( RVOpcodes : : MULW ) ;
mulInstr - > addOperand ( std : : make_unique < RegOperand > ( temp_reg ) ) ;
mulInstr - > addOperand ( std : : make_unique < RegOperand > ( * src1_reg ) ) ;
mulInstr - > addOperand ( std : : make_unique < RegOperand > ( magic_reg ) ) ;
newInstrs . push_back ( std : : move ( mulInstr ) ) ;
// 右移得到近似结果
auto sraInstr = std : : make_unique < MachineInstr > ( RVOpcodes : : SRAIW ) ;
sraInstr - > addOperand ( std : : make_unique < RegOperand > ( temp_reg ) ) ;
sraInstr - > addOperand ( std : : make_unique < RegOperand > ( temp_reg ) ) ;
sraInstr - > addOperand ( std : : make_unique < ImmOperand > ( magic_info . shift ) ) ;
newInstrs . push_back ( std : : move ( sraInstr ) ) ;
} else {
// 64位: 使用MUL
auto mulInstr = std : : make_unique < MachineInstr > ( RVOpcodes : : MUL ) ;
mulInstr - > addOperand ( std : : make_unique < RegOperand > ( temp_reg ) ) ;
mulInstr - > addOperand ( std : : make_unique < RegOperand > ( * src1_reg ) ) ;
mulInstr - > addOperand ( std : : make_unique < RegOperand > ( magic_reg ) ) ;
newInstrs . push_back ( std : : move ( mulInstr ) ) ;
// 右移得到近似结果
auto sraInstr = std : : make_unique < MachineInstr > ( RVOpcodes : : SRAI ) ;
sraInstr - > addOperand ( std : : make_unique < RegOperand > ( temp_reg ) ) ;
sraInstr - > addOperand ( std : : make_unique < RegOperand > ( temp_reg ) ) ;
sraInstr - > addOperand ( std : : make_unique < ImmOperand > ( magic_info . shift ) ) ;
newInstrs . push_back ( std : : move ( sraInstr ) ) ;
} else {
auto mulhInstr = std : : make_unique < MachineInstr > ( RVOpcodes : : MULH ) ;
mulhInstr - > addOperand ( std : : make_unique < RegOperand > ( temp_reg ) ) ;
mulhInstr - > addOperand ( std : : make_unique < RegOperand > ( * src1_reg ) ) ;
mulhInstr - > addOperand ( std : : make_unique < RegOperand > ( magic_reg ) ) ;
newInstrs . push_back ( std : : move ( mulhInstr ) ) ;
int post_shift = magic_info . shift - 63 ;
if ( post_shift > 0 ) {
auto sraInstr = std : : make_unique < MachineInstr > ( RVOpcodes : : SRAI ) ;
sraInstr - > addOperand ( std : : make_unique < RegOperand > ( temp_reg ) ) ;
sraInstr - > addOperand ( std : : make_unique < RegOperand > ( temp_reg ) ) ;
sraInstr - > addOperand ( std : : make_unique < ImmOperand > ( post_shift ) ) ;
newInstrs . push_back ( std : : move ( sraInstr ) ) ;
}
}
// 符号修正:处理负数被除数
int sign_reg = createTempReg ( ) ;
// 获取被除数的符号位
auto sraSignInstr = std : : make_unique < MachineInstr > ( is_32bit ? RVOpcodes : : SRAIW : RVOpcodes : : SRAI ) ;
sraSignInstr - > addOperand ( std : : make_unique < RegOperand > ( sign_reg ) ) ;
sraSignInstr - > addOperand ( std : : make_unique < RegOperand > ( * src1_reg ) ) ;
sraSignInstr - > addOperand ( std : : make_unique < ImmOperand > ( is_32bit ? 31 : 63 ) ) ;
newInstrs . push_back ( std : : move ( sraSignInstr ) ) ;
// 最终结果: dst = temp - sign( 对于正除数) 或 dst = temp + sign( 对于负除数)
if ( divisor > 0 ) {
auto finalSubInstr = std : : make_unique < MachineInstr > ( is_32bit ? RVOpcodes : : SUBW : RVOpcodes : : SUB ) ;
finalS ubInstr - > addOperand ( std : : make_unique < RegOperand > ( * dst _reg) ) ;
finalSubInstr - > addOperand ( std : : make_unique < RegOperand > ( temp_reg ) ) ;
finalSubInstr - > addOperand ( std : : make_unique < RegOperand > ( sign_reg ) ) ;
newInstrs . push_back ( std : : move ( finalSubInstr ) ) ;
auto subInstr = std : : make_unique < MachineInstr > ( is_32bit ? RVOpcodes : : SUBW : RVOpcodes : : SUB ) ;
subInstr - > addOperand ( std : : make_unique < RegOperand > ( temp_reg ) ) ;
subInstr - > addOperand ( std : : make_unique < RegOperand > ( temp_reg ) ) ;
s ubInstr- > addOperand ( std : : make_unique < RegOperand > ( sign _reg) ) ;
newInstrs . push_back ( std : : move ( subInstr ) ) ;
if ( divisor < 0 ) {
auto negInstr = std : : make_unique < MachineInstr > ( is_32bit ? RVOpcodes : : SUBW : RVOpcodes : : SUB ) ;
negInstr - > addOperand ( std : : make_unique < RegOperand > ( * dst_reg ) ) ;
negInstr - > addOperand ( std : : make_unique < RegOperand > ( PhysicalReg : : ZERO ) ) ;
negInstr - > addOperand ( std : : make_unique < RegOperand > ( temp_reg ) ) ;
newInstrs . push_back ( std : : move ( negInstr ) ) ;
} else {
auto finalAdd Instr = std : : make_unique < MachineInstr > ( is_32bit ? RVOpcodes : : ADDW : RVOpcodes : : ADD ) ;
finalAdd Instr - > addOperand ( std : : make_unique < RegOperand > ( * dst_reg ) ) ;
finalAdd Instr - > addOperand ( std : : make_unique < RegOperand > ( temp_reg ) ) ;
finalAdd Instr - > addOperand ( std : : make_unique < RegOperand > ( sign_reg ) ) ;
newInstrs . push_back ( std : : move ( finalAdd Instr) ) ;
auto move Instr = std : : make_unique < MachineInstr > ( is_32bit ? RVOpcodes : : ADDW : RVOpcodes : : ADD ) ;
move Instr- > addOperand ( std : : make_unique < RegOperand > ( * dst_reg ) ) ;
move Instr- > addOperand ( std : : make_unique < RegOperand > ( temp_reg ) ) ;
move Instr- > addOperand ( std : : make_unique < RegOperand > ( PhysicalReg : : ZERO ) ) ;
newInstrs . push_back ( std : : move ( move Instr) ) ;
}
}
// 对于大的除数或复杂情况,保持原始除法指令不变
}
if ( ! newInstrs . empty ( ) ) {
replacements . push_back ( { i , std : : move ( newInstrs ) } ) ;
size_t start_index = i ;
if ( instructions_to_replace = = 2 ) {
start_index = i - 1 ;
}
replacements . push_back ( { start_index , instructions_to_replace , std : : move ( newInstrs ) } ) ;
}
}
// 批量应用替换(从后往前处理避免索引问题)
for ( auto it = replacements . rbegin ( ) ; it ! = replacements . rend ( ) ; + + it ) {
instrs . erase ( instrs . begin ( ) + it - > index ) ;
instrs . erase ( instrs . begin ( ) + it - > index , instrs . begin ( ) + it - > index + it - > count_to_erase );
instrs . insert ( instrs . begin ( ) + it - > index ,
std : : make_move_iterator ( it - > newInstrs . begin ( ) ) ,
std : : make_move_iterator ( it - > newInstrs . end ( ) ) ) ;