Files
mysysy/src/RISCv64RegAlloc.cpp
2025-07-19 17:50:14 +08:00

322 lines
13 KiB
C++
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

#include "RISCv64RegAlloc.h"
#include "RISCv64ISel.h"
#include <algorithm>
#include <map>
#include <stdexcept>
#include <vector>
namespace sysy {
RISCv64RegAlloc::RISCv64RegAlloc(MachineFunction* mfunc) : MFunc(mfunc) {
    // Integer registers the graph colorer may hand out, tried in this order.
    //
    // Deliberately excluded:
    //   - S0: the frame base. Alloca slots are addressed as `addi vN, s0, off`
    //     and spill slots as `off(s0)`, so allocating S0 to a vreg would
    //     clobber every frame access.
    //   - T6: the scratch register used for spill reload/store code; if it were
    //     allocatable, reloading a spilled value could overwrite a live
    //     allocated value that happens to be colored T6.
    //   - T5: held back as a second spill scratch so that an instruction with
    //     two distinct spilled source operands can reload both without one
    //     overwriting the other.
    //
    // NOTE(review): CALL is not modelled by liveness, so a vreg live across a
    // call may still be colored with a caller-saved T*/A* register; and
    // S1-S11 are callee-saved, so any that get used must be saved/restored by
    // the prologue/epilogue -- confirm both are handled elsewhere.
    allocable_int_regs = {
        PhysicalReg::T0, PhysicalReg::T1, PhysicalReg::T2, PhysicalReg::T3,
        PhysicalReg::T4,
        PhysicalReg::A0, PhysicalReg::A1, PhysicalReg::A2, PhysicalReg::A3,
        PhysicalReg::A4, PhysicalReg::A5, PhysicalReg::A6, PhysicalReg::A7,
        PhysicalReg::S1, PhysicalReg::S2, PhysicalReg::S3,
        PhysicalReg::S4, PhysicalReg::S5, PhysicalReg::S6, PhysicalReg::S7,
        PhysicalReg::S8, PhysicalReg::S9, PhysicalReg::S10, PhysicalReg::S11,
    };
}
void RISCv64RegAlloc::run() {
eliminateFrameIndices();
analyzeLiveness();
buildInterferenceGraph();
colorGraph();
rewriteFunction();
}
void RISCv64RegAlloc::eliminateFrameIndices() {
StackFrameInfo& frame_info = MFunc->getFrameInfo();
int current_offset = 0;
Function* F = MFunc->getFunc();
RISCv64ISel* isel = MFunc->getISel();
for (auto& bb : F->getBasicBlocks()) {
for (auto& inst : bb->getInstructions()) {
if (auto alloca = dynamic_cast<AllocaInst*>(inst.get())) {
int size = 4;
if (!alloca->getDims().empty()) {
int num_elements = 1;
for (const auto& dim_use : alloca->getDims()) {
if (auto const_dim = dynamic_cast<ConstantValue*>(dim_use->getValue())) {
num_elements *= const_dim->getInt();
}
}
size *= num_elements;
}
current_offset += size;
unsigned alloca_vreg = isel->getVReg(alloca);
frame_info.alloca_offsets[alloca_vreg] = -current_offset;
}
}
}
frame_info.locals_size = current_offset;
for (auto& mbb : MFunc->getBlocks()) {
std::vector<std::unique_ptr<MachineInstr>> new_instructions;
for (auto& instr_ptr : mbb->getInstructions()) {
if (instr_ptr->getOpcode() == RVOpcodes::FRAME_LOAD) {
auto& operands = instr_ptr->getOperands();
unsigned dest_vreg = static_cast<RegOperand*>(operands[0].get())->getVRegNum();
unsigned alloca_vreg = static_cast<RegOperand*>(operands[1].get())->getVRegNum();
int offset = frame_info.alloca_offsets.at(alloca_vreg);
auto addr_vreg = isel->getNewVReg();
auto addi = std::make_unique<MachineInstr>(RVOpcodes::ADDI);
addi->addOperand(std::make_unique<RegOperand>(addr_vreg));
addi->addOperand(std::make_unique<RegOperand>(PhysicalReg::S0));
addi->addOperand(std::make_unique<ImmOperand>(offset));
new_instructions.push_back(std::move(addi));
auto lw = std::make_unique<MachineInstr>(RVOpcodes::LW);
lw->addOperand(std::make_unique<RegOperand>(dest_vreg));
lw->addOperand(std::make_unique<MemOperand>(
std::make_unique<RegOperand>(addr_vreg),
std::make_unique<ImmOperand>(0)));
new_instructions.push_back(std::move(lw));
} else if (instr_ptr->getOpcode() == RVOpcodes::FRAME_STORE) {
auto& operands = instr_ptr->getOperands();
unsigned src_vreg = static_cast<RegOperand*>(operands[0].get())->getVRegNum();
unsigned alloca_vreg = static_cast<RegOperand*>(operands[1].get())->getVRegNum();
int offset = frame_info.alloca_offsets.at(alloca_vreg);
auto addr_vreg = isel->getNewVReg();
auto addi = std::make_unique<MachineInstr>(RVOpcodes::ADDI);
addi->addOperand(std::make_unique<RegOperand>(addr_vreg));
addi->addOperand(std::make_unique<RegOperand>(PhysicalReg::S0));
addi->addOperand(std::make_unique<ImmOperand>(offset));
new_instructions.push_back(std::move(addi));
auto sw = std::make_unique<MachineInstr>(RVOpcodes::SW);
sw->addOperand(std::make_unique<RegOperand>(src_vreg));
sw->addOperand(std::make_unique<MemOperand>(
std::make_unique<RegOperand>(addr_vreg),
std::make_unique<ImmOperand>(0)));
new_instructions.push_back(std::move(sw));
} else {
new_instructions.push_back(std::move(instr_ptr));
}
}
mbb->getInstructions() = std::move(new_instructions);
}
}
// Collects the virtual registers that `instr` reads (`use`) and writes (`def`).
// Results are inserted into the caller's sets; physical registers are ignored.
//
// Rules:
//   - For SW/SD, BEQ/BNE/BLT/BGE, RET and J nothing is written, so every
//     register operand is a use.
//   - For every other opcode the FIRST register operand is the destination
//     and any later register operands are uses.
//   - The base register of a memory operand is always a use.
void RISCv64RegAlloc::getInstrUseDef(MachineInstr* instr, LiveSet& use, LiveSet& def) {
    const auto opcode = instr->getOpcode();
    bool dest_slot_pending = !(opcode == RVOpcodes::SW || opcode == RVOpcodes::SD ||
                               opcode == RVOpcodes::BEQ || opcode == RVOpcodes::BNE ||
                               opcode == RVOpcodes::BLT || opcode == RVOpcodes::BGE ||
                               opcode == RVOpcodes::RET || opcode == RVOpcodes::J);
    // TODO(review): CALL clobbers all caller-saved registers and reads its
    // arguments from a0-a7, but neither effect is modelled here; CALL is
    // treated like any other instruction (first register operand = def).

    for (const auto& op : instr->getOperands()) {
        if (op->getKind() == MachineOperand::KIND_REG) {
            auto reg_op = static_cast<RegOperand*>(op.get());
            // The destination slot is consumed by the first register operand
            // even when that operand is physical: in `mv a0, vN` the def is a0,
            // and vN must stay a use rather than being misread as a def.
            const bool is_dest = dest_slot_pending;
            dest_slot_pending = false;
            if (reg_op->isVirtual()) {
                if (is_dest) {
                    def.insert(reg_op->getVRegNum());
                } else {
                    use.insert(reg_op->getVRegNum());
                }
            }
        } else if (op->getKind() == MachineOperand::KIND_MEM) {
            auto mem_op = static_cast<MemOperand*>(op.get());
            if (mem_op->getBase()->isVirtual()) {
                use.insert(mem_op->getBase()->getVRegNum());
            }
        }
    }
}
void RISCv64RegAlloc::analyzeLiveness() {
bool changed = true;
while (changed) {
changed = false;
for (auto it = MFunc->getBlocks().rbegin(); it != MFunc->getBlocks().rend(); ++it) {
auto& mbb = *it;
LiveSet live_out;
for (auto succ : mbb->successors) {
if (!succ->getInstructions().empty()) {
auto first_instr = succ->getInstructions().front().get();
if (live_in_map.count(first_instr)) {
live_out.insert(live_in_map.at(first_instr).begin(), live_in_map.at(first_instr).end());
}
}
}
for (auto instr_it = mbb->getInstructions().rbegin(); instr_it != mbb->getInstructions().rend(); ++instr_it) {
MachineInstr* instr = instr_it->get();
LiveSet old_live_in = live_in_map[instr];
live_out_map[instr] = live_out;
LiveSet use, def;
getInstrUseDef(instr, use, def);
LiveSet live_in = use;
LiveSet diff = live_out;
for (auto vreg : def) {
diff.erase(vreg);
}
live_in.insert(diff.begin(), diff.end());
live_in_map[instr] = live_in;
live_out = live_in;
if (live_in_map[instr] != old_live_in) {
changed = true;
}
}
}
}
}
void RISCv64RegAlloc::buildInterferenceGraph() {
std::set<unsigned> all_vregs;
for (auto& mbb : MFunc->getBlocks()) {
for(auto& instr : mbb->getInstructions()) {
LiveSet use, def;
getInstrUseDef(instr.get(), use, def);
for(auto u : use) all_vregs.insert(u);
for(auto d : def) all_vregs.insert(d);
}
}
for (auto vreg : all_vregs) { interference_graph[vreg] = {}; }
for (auto& mbb : MFunc->getBlocks()) {
for (auto& instr : mbb->getInstructions()) {
LiveSet def, use;
getInstrUseDef(instr.get(), use, def);
const LiveSet& live_out = live_out_map.at(instr.get());
for (unsigned d : def) {
for (unsigned l : live_out) {
if (d != l) {
interference_graph[d].insert(l);
interference_graph[l].insert(d);
}
}
}
}
}
}
void RISCv64RegAlloc::colorGraph() {
std::vector<unsigned> sorted_vregs;
for (auto const& [vreg, neighbors] : interference_graph) {
sorted_vregs.push_back(vreg);
}
std::sort(sorted_vregs.begin(), sorted_vregs.end(), [&](unsigned a, unsigned b) {
return interference_graph[a].size() > interference_graph[b].size();
});
for (unsigned vreg : sorted_vregs) {
std::set<PhysicalReg> used_colors;
for (unsigned neighbor : interference_graph.at(vreg)) {
if (color_map.count(neighbor)) {
used_colors.insert(color_map.at(neighbor));
}
}
bool colored = false;
for (PhysicalReg preg : allocable_int_regs) {
if (used_colors.find(preg) == used_colors.end()) {
color_map[vreg] = preg;
colored = true;
break;
}
}
if (!colored) {
spilled_vregs.insert(vreg);
}
}
}
void RISCv64RegAlloc::rewriteFunction() {
StackFrameInfo& frame_info = MFunc->getFrameInfo();
int current_offset = frame_info.locals_size;
for (unsigned vreg : spilled_vregs) {
current_offset += 4;
frame_info.spill_offsets[vreg] = -current_offset;
}
frame_info.spill_size = current_offset - frame_info.locals_size;
for (auto& mbb : MFunc->getBlocks()) {
std::vector<std::unique_ptr<MachineInstr>> new_instructions;
for (auto& instr_ptr : mbb->getInstructions()) {
LiveSet use, def;
getInstrUseDef(instr_ptr.get(), use, def);
for (unsigned vreg : use) {
if (spilled_vregs.count(vreg)) {
int offset = frame_info.spill_offsets.at(vreg);
auto load = std::make_unique<MachineInstr>(RVOpcodes::LW);
load->addOperand(std::make_unique<RegOperand>(vreg));
load->addOperand(std::make_unique<MemOperand>(
std::make_unique<RegOperand>(PhysicalReg::S0),
std::make_unique<ImmOperand>(offset)
));
new_instructions.push_back(std::move(load));
}
}
new_instructions.push_back(std::move(instr_ptr));
for (unsigned vreg : def) {
if (spilled_vregs.count(vreg)) {
int offset = frame_info.spill_offsets.at(vreg);
auto store = std::make_unique<MachineInstr>(RVOpcodes::SW);
store->addOperand(std::make_unique<RegOperand>(vreg));
store->addOperand(std::make_unique<MemOperand>(
std::make_unique<RegOperand>(PhysicalReg::S0),
std::make_unique<ImmOperand>(offset)
));
new_instructions.push_back(std::move(store));
}
}
}
mbb->getInstructions() = std::move(new_instructions);
}
for (auto& mbb : MFunc->getBlocks()) {
for (auto& instr_ptr : mbb->getInstructions()) {
for (auto& op_ptr : instr_ptr->getOperands()) {
if(op_ptr->getKind() == MachineOperand::KIND_REG) {
auto reg_op = static_cast<RegOperand*>(op_ptr.get());
if (reg_op->isVirtual()) {
unsigned vreg = reg_op->getVRegNum();
if (color_map.count(vreg)) {
reg_op->setPReg(color_map.at(vreg));
} else if (spilled_vregs.count(vreg)) {
reg_op->setPReg(PhysicalReg::T6); // 溢出统一用t6
}
}
} else if (op_ptr->getKind() == MachineOperand::KIND_MEM) {
auto mem_op = static_cast<MemOperand*>(op_ptr.get());
auto base_reg_op = mem_op->getBase();
if(base_reg_op->isVirtual()){
unsigned vreg = base_reg_op->getVRegNum();
if(color_map.count(vreg)) {
base_reg_op->setPReg(color_map.at(vreg));
} else if (spilled_vregs.count(vreg)) {
base_reg_op->setPReg(PhysicalReg::T6);
}
}
}
}
}
}
}
} // namespace sysy