[backend] debugging segmentation fault caused by branch instr
This commit is contained in:
@ -3,6 +3,8 @@
|
||||
#include <algorithm>
|
||||
#include <stdexcept>
|
||||
#include <regex>
|
||||
#include <iomanip>
|
||||
#include <functional>
|
||||
|
||||
namespace sysy {
|
||||
|
||||
@ -104,6 +106,7 @@ std::string RISCv32CodeGen::basicBlock_gen(BasicBlock* bb, const RegAllocResult&
|
||||
std::stringstream ss;
|
||||
ss << bb->getName() << ":\n";
|
||||
auto dag = build_dag(bb);
|
||||
print_dag(dag, bb->getName()); // Print DAG for debugging
|
||||
std::vector<std::string> insts;
|
||||
for (auto& node : dag) {
|
||||
select_instructions(node.get(), alloc);
|
||||
@ -119,47 +122,50 @@ std::string RISCv32CodeGen::basicBlock_gen(BasicBlock* bb, const RegAllocResult&
|
||||
std::vector<std::unique_ptr<RISCv32CodeGen::DAGNode>> RISCv32CodeGen::build_dag(BasicBlock* bb) {
|
||||
std::vector<std::unique_ptr<DAGNode>> nodes;
|
||||
std::map<Value*, DAGNode*> value_to_node;
|
||||
static int vreg_counter = 0; // Counter for unique vreg names
|
||||
static int vreg_counter = 0;
|
||||
|
||||
auto create_node = [&](DAGNode::NodeKind kind, Value* val = nullptr) {
|
||||
if (val && value_to_node.find(val) != value_to_node.end()) {
|
||||
return value_to_node[val];
|
||||
}
|
||||
auto node = std::make_unique<DAGNode>(kind);
|
||||
node->value = val;
|
||||
node->result_reg = val ? "v" + std::to_string(vreg_counter++) : "";
|
||||
if (val) value_to_node[val] = node.get();
|
||||
if (val && kind != DAGNode::STORE && kind != DAGNode::RETURN) {
|
||||
node->result_reg = "v" + std::to_string(vreg_counter++);
|
||||
value_vreg_map[val] = node->result_reg;
|
||||
value_to_node[val] = node.get();
|
||||
}
|
||||
nodes.push_back(std::move(node));
|
||||
return nodes.back().get();
|
||||
};
|
||||
|
||||
for (const auto& inst : bb->getInstructions()) {
|
||||
if (auto alloca = dynamic_cast<AllocaInst*>(inst.get())) {
|
||||
create_node(DAGNode::CONSTANT, alloca); // Allocate stack space
|
||||
create_node(DAGNode::CONSTANT, alloca);
|
||||
} else if (auto store = dynamic_cast<StoreInst*>(inst.get())) {
|
||||
auto store_node = create_node(DAGNode::STORE);
|
||||
auto val_node = value_to_node.find(store->getValue()) != value_to_node.end()
|
||||
? value_to_node[store->getValue()]
|
||||
: create_node(DAGNode::CONSTANT, store->getValue());
|
||||
auto ptr_node = value_to_node.find(store->getPointer()) != value_to_node.end()
|
||||
? value_to_node[store->getPointer()]
|
||||
: create_node(DAGNode::CONSTANT, store->getPointer());
|
||||
Value* val = store->getValue();
|
||||
Value* ptr = store->getPointer();
|
||||
DAGNode* val_node = nullptr;
|
||||
if (auto constant = dynamic_cast<ConstantValue*>(val)) {
|
||||
val_node = create_node(DAGNode::CONSTANT, val);
|
||||
} else {
|
||||
val_node = create_node(DAGNode::LOAD, val);
|
||||
}
|
||||
auto ptr_node = create_node(DAGNode::CONSTANT, ptr);
|
||||
store_node->operands.push_back(val_node);
|
||||
store_node->operands.push_back(ptr_node);
|
||||
val_node->users.push_back(store_node);
|
||||
ptr_node->users.push_back(store_node);
|
||||
} else if (auto load = dynamic_cast<LoadInst*>(inst.get())) {
|
||||
auto load_node = create_node(DAGNode::LOAD, load);
|
||||
auto ptr_node = value_to_node.find(load->getPointer()) != value_to_node.end()
|
||||
? value_to_node[load->getPointer()]
|
||||
: create_node(DAGNode::CONSTANT, load->getPointer());
|
||||
auto ptr_node = create_node(DAGNode::CONSTANT, load->getPointer());
|
||||
load_node->operands.push_back(ptr_node);
|
||||
ptr_node->users.push_back(load_node);
|
||||
} else if (auto bin = dynamic_cast<BinaryInst*>(inst.get())) {
|
||||
auto bin_node = create_node(DAGNode::BINARY, bin);
|
||||
auto lhs_node = value_to_node.find(bin->getLhs()) != value_to_node.end()
|
||||
? value_to_node[bin->getLhs()]
|
||||
: create_node(DAGNode::CONSTANT, bin->getLhs());
|
||||
auto rhs_node = value_to_node.find(bin->getRhs()) != value_to_node.end()
|
||||
? value_to_node[bin->getRhs()]
|
||||
: create_node(DAGNode::CONSTANT, bin->getRhs());
|
||||
auto lhs_node = create_node(DAGNode::LOAD, bin->getLhs());
|
||||
auto rhs_node = create_node(DAGNode::LOAD, bin->getRhs());
|
||||
bin_node->operands.push_back(lhs_node);
|
||||
bin_node->operands.push_back(rhs_node);
|
||||
lhs_node->users.push_back(bin_node);
|
||||
@ -167,27 +173,105 @@ std::vector<std::unique_ptr<RISCv32CodeGen::DAGNode>> RISCv32CodeGen::build_dag(
|
||||
} else if (auto call = dynamic_cast<CallInst*>(inst.get())) {
|
||||
auto call_node = create_node(DAGNode::CALL, call);
|
||||
for (auto arg : call->getArguments()) {
|
||||
auto arg_node = value_to_node.find(arg->getValue()) != value_to_node.end()
|
||||
? value_to_node[arg->getValue()]
|
||||
: create_node(DAGNode::CONSTANT, arg->getValue());
|
||||
auto arg_node = create_node(DAGNode::CONSTANT, arg->getValue());
|
||||
call_node->operands.push_back(arg_node);
|
||||
arg_node->users.push_back(call_node);
|
||||
}
|
||||
} else if (auto ret = dynamic_cast<ReturnInst*>(inst.get())) {
|
||||
auto ret_node = create_node(DAGNode::RETURN);
|
||||
if (ret->hasReturnValue()) {
|
||||
auto val_node = value_to_node.find(ret->getReturnValue()) != value_to_node.end()
|
||||
? value_to_node[ret->getReturnValue()]
|
||||
: create_node(DAGNode::CONSTANT, ret->getReturnValue());
|
||||
auto val_node = create_node(DAGNode::LOAD, ret->getReturnValue());
|
||||
ret_node->operands.push_back(val_node);
|
||||
val_node->users.push_back(ret_node);
|
||||
}
|
||||
} else if (auto cond_br = dynamic_cast<CondBrInst*>(inst.get())) {
|
||||
auto br_node = create_node(DAGNode::BRANCH);
|
||||
auto cond = cond_br->getCondition();
|
||||
auto cond_node = value_to_node.count(cond) ? value_to_node[cond] : create_node(DAGNode::LOAD, cond);
|
||||
br_node->operands.push_back(cond_node);
|
||||
br_node->value = cond_br; // 存储 CondBrInst 以获取 then/else 块
|
||||
cond_node->users.push_back(br_node);
|
||||
}
|
||||
}
|
||||
|
||||
return nodes;
|
||||
}
|
||||
|
||||
// 打印 DAG
|
||||
void RISCv32CodeGen::print_dag(const std::vector<std::unique_ptr<DAGNode>>& dag, const std::string& bb_name) {
|
||||
std::cerr << "=== DAG for Basic Block: " << bb_name << " ===\n";
|
||||
std::set<DAGNode*> visited;
|
||||
|
||||
// 显式声明 print_node 的类型为 std::function
|
||||
std::function<void(DAGNode*, int, int)> print_node = [&](DAGNode* node, int indent, int node_index) {
|
||||
if (!node || visited.find(node) != visited.end()) {
|
||||
if (node) {
|
||||
std::cerr << std::string(indent, ' ') << "Node@" << node_index << ": (already printed)\n";
|
||||
}
|
||||
return;
|
||||
}
|
||||
visited.insert(node);
|
||||
|
||||
std::string kind_str;
|
||||
switch (node->kind) {
|
||||
case DAGNode::CONSTANT: kind_str = "CONSTANT"; break;
|
||||
case DAGNode::LOAD: kind_str = "LOAD"; break;
|
||||
case DAGNode::STORE: kind_str = "STORE"; break;
|
||||
case DAGNode::BINARY: kind_str = "BINARY"; break;
|
||||
case DAGNode::CALL: kind_str = "CALL"; break;
|
||||
case DAGNode::RETURN: kind_str = "RETURN"; break;
|
||||
}
|
||||
|
||||
std::cerr << std::string(indent, ' ') << "Node@" << node_index << ": " << kind_str;
|
||||
if (!node->result_reg.empty()) {
|
||||
std::cerr << " (vreg: " << node->result_reg << ")";
|
||||
}
|
||||
|
||||
if (node->value) {
|
||||
std::cerr << " [";
|
||||
if (auto inst = dynamic_cast<Instruction*>(node->value)) {
|
||||
std::cerr << inst->getKindString(); // 修复:getKindStr -> getKindString
|
||||
} else if (auto constant = dynamic_cast<ConstantValue*>(node->value)) {
|
||||
if (constant->isInt()) {
|
||||
std::cerr << "ConstInt(" << constant->getInt() << ")";
|
||||
} else {
|
||||
std::cerr << "ConstFloat(" << constant->getFloat() << ")";
|
||||
}
|
||||
} else if (auto global = dynamic_cast<GlobalValue*>(node->value)) {
|
||||
std::cerr << "Global(" << global->getName() << ")";
|
||||
} else {
|
||||
std::cerr << "Value";
|
||||
}
|
||||
std::cerr << "]";
|
||||
}
|
||||
std::cerr << "\n";
|
||||
|
||||
if (!node->operands.empty()) {
|
||||
std::cerr << std::string(indent, ' ') << " Operands:\n";
|
||||
int op_index = 0;
|
||||
for (auto operand : node->operands) {
|
||||
std::cerr << std::string(indent + 2, ' ') << "Op" << op_index++ << ": ";
|
||||
print_node(operand, indent + 4, reinterpret_cast<uintptr_t>(operand) % 1000);
|
||||
}
|
||||
}
|
||||
|
||||
if (!node->users.empty()) {
|
||||
std::cerr << std::string(indent, ' ') << " Users:\n";
|
||||
int user_index = 0;
|
||||
for (auto user : node->users) {
|
||||
std::cerr << std::string(indent + 2, ' ') << "User" << user_index++ << ": ";
|
||||
print_node(user, indent + 4, reinterpret_cast<uintptr_t>(user) % 1000);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
int node_index = 0;
|
||||
for (const auto& node : dag) {
|
||||
print_node(node.get(), 0, node_index++);
|
||||
}
|
||||
std::cerr << "=== End DAG ===\n\n";
|
||||
}
|
||||
|
||||
// 指令选择
|
||||
void RISCv32CodeGen::select_instructions(DAGNode* node, const RegAllocResult& alloc) {
|
||||
if (!node->inst.empty()) return;
|
||||
@ -209,29 +293,30 @@ void RISCv32CodeGen::select_instructions(DAGNode* node, const RegAllocResult& al
|
||||
} else if (auto global = dynamic_cast<GlobalValue*>(node->value)) {
|
||||
node->inst = "la " + node->result_reg + ", " + global->getName();
|
||||
} else if (auto alloca = dynamic_cast<AllocaInst*>(node->value)) {
|
||||
if (alloc.stack_map.find(alloca) != alloc.stack_map.end()) {
|
||||
node->inst = ""; // Stack address handled in LOAD/STORE
|
||||
}
|
||||
node->inst = "";
|
||||
}
|
||||
break;
|
||||
}
|
||||
case DAGNode::LOAD: {
|
||||
auto ptr_reg = node->operands[0]->result_reg;
|
||||
if (alloc.stack_map.find(node->operands[0]->value) != alloc.stack_map.end()) {
|
||||
int offset = alloc.stack_map.at(node->operands[0]->value);
|
||||
auto ptr = node->operands[0]->value;
|
||||
if (alloc.stack_map.count(ptr)) {
|
||||
int offset = alloc.stack_map.at(ptr);
|
||||
node->inst = "lw " + node->result_reg + ", " + std::to_string(offset) + "(s0)";
|
||||
} else {
|
||||
auto ptr_reg = node->operands[0]->result_reg;
|
||||
node->inst = "lw " + node->result_reg + ", 0(" + ptr_reg + ")";
|
||||
}
|
||||
break;
|
||||
}
|
||||
case DAGNode::STORE: {
|
||||
auto val = node->operands[0]->value;
|
||||
auto ptr = node->operands[1]->value;
|
||||
auto val_reg = node->operands[0]->result_reg;
|
||||
auto ptr_reg = node->operands[1]->result_reg;
|
||||
if (alloc.stack_map.find(node->operands[1]->value) != alloc.stack_map.end()) {
|
||||
int offset = alloc.stack_map.at(node->operands[1]->value);
|
||||
if (alloc.stack_map.count(ptr)) {
|
||||
int offset = alloc.stack_map.at(ptr);
|
||||
node->inst = "sw " + val_reg + ", " + std::to_string(offset) + "(s0)";
|
||||
} else {
|
||||
auto ptr_reg = node->operands[1]->result_reg;
|
||||
node->inst = "sw " + val_reg + ", 0(" + ptr_reg + ")";
|
||||
}
|
||||
break;
|
||||
@ -243,6 +328,7 @@ void RISCv32CodeGen::select_instructions(DAGNode* node, const RegAllocResult& al
|
||||
std::string opcode;
|
||||
switch (bin->getKind()) {
|
||||
case BinaryInst::kAdd: opcode = "add"; break;
|
||||
case BinaryInst::kSub: opcode = "sub"; break;
|
||||
case BinaryInst::kMul: opcode = "mul"; break;
|
||||
default: break;
|
||||
}
|
||||
@ -270,33 +356,76 @@ void RISCv32CodeGen::select_instructions(DAGNode* node, const RegAllocResult& al
|
||||
}
|
||||
break;
|
||||
}
|
||||
case DAGNode::BRANCH: {
|
||||
auto br = dynamic_cast<CondBrInst*>(node->value);
|
||||
auto cond_reg = node->operands[0]->result_reg;
|
||||
auto then_block = br->getThenBlock()->getName();
|
||||
auto else_block = br->getElseBlock()->getName();
|
||||
// 假设条件为 true 时跳转到 then,否则到 else
|
||||
node->inst = "bnez " + cond_reg + ", " + then_block + "\nj " + else_block;
|
||||
break;
|
||||
}
|
||||
default: break;
|
||||
}
|
||||
}
|
||||
|
||||
// 指令发射
|
||||
void RISCv32CodeGen::emit_instructions(DAGNode* node, std::vector<std::string>& insts, const RegAllocResult& alloc) {
|
||||
for (auto operand : node->operands) {
|
||||
emit_instructions(operand, insts, alloc);
|
||||
std::set<DAGNode*> emitted; // 局部变量,针对每个基本块独立
|
||||
std::set<std::string> seen_insts; // 跟踪已发射的指令以去重
|
||||
|
||||
std::function<void(DAGNode*)> emit = [&](DAGNode* n) {
|
||||
if (!n || emitted.count(n)) return;
|
||||
emitted.insert(n);
|
||||
|
||||
for (auto operand : n->operands) {
|
||||
emit(operand);
|
||||
}
|
||||
if (!node->inst.empty()) {
|
||||
std::stringstream ss(node->inst);
|
||||
|
||||
if (!n->inst.empty()) {
|
||||
std::stringstream ss(n->inst);
|
||||
std::string line;
|
||||
while (std::getline(ss, line, '\n')) {
|
||||
if (!line.empty()) {
|
||||
// Replace virtual registers with physical registers
|
||||
if (!node->result_reg.empty() && alloc.vreg_to_preg.find(node->result_reg) != alloc.vreg_to_preg.end()) {
|
||||
line = std::regex_replace(line, std::regex("\\b" + node->result_reg + "\\b"), reg_to_string(alloc.vreg_to_preg.at(node->result_reg)));
|
||||
}
|
||||
for (auto operand : node->operands) {
|
||||
if (!operand->result_reg.empty() && alloc.vreg_to_preg.find(operand->result_reg) != alloc.vreg_to_preg.end()) {
|
||||
line = std::regex_replace(line, std::regex("\\b" + operand->result_reg + "\\b"), reg_to_string(alloc.vreg_to_preg.at(operand->result_reg)));
|
||||
}
|
||||
}
|
||||
insts.push_back(line);
|
||||
std::string new_line = line;
|
||||
if (!n->result_reg.empty()) {
|
||||
if (alloc.vreg_to_preg.count(n->result_reg)) {
|
||||
new_line = std::regex_replace(new_line, std::regex("\\b" + n->result_reg + "\\b"), reg_to_string(alloc.vreg_to_preg.at(n->result_reg)));
|
||||
} else if (alloc.stack_map.count(n->value)) {
|
||||
int offset = alloc.stack_map.at(n->value);
|
||||
new_line = std::regex_replace(new_line, std::regex("\\b" + n->result_reg + "\\b"), "t0");
|
||||
std::string store = "sw t0, " + std::to_string(offset) + "(s0)";
|
||||
if (!seen_insts.count(store)) {
|
||||
insts.push_back(store);
|
||||
seen_insts.insert(store);
|
||||
}
|
||||
}
|
||||
}
|
||||
for (auto operand : n->operands) {
|
||||
if (!operand->result_reg.empty()) {
|
||||
if (alloc.vreg_to_preg.count(operand->result_reg)) {
|
||||
new_line = std::regex_replace(new_line, std::regex("\\b" + operand->result_reg + "\\b"), reg_to_string(alloc.vreg_to_preg.at(operand->result_reg)));
|
||||
} else if (alloc.stack_map.count(operand->value)) {
|
||||
int offset = alloc.stack_map.at(operand->value);
|
||||
std::string load = "lw t0, " + std::to_string(offset) + "(s0)";
|
||||
if (!seen_insts.count(load)) {
|
||||
insts.push_back(load);
|
||||
seen_insts.insert(load);
|
||||
}
|
||||
new_line = std::regex_replace(new_line, std::regex("\\b" + operand->result_reg + "\\b"), "t0");
|
||||
}
|
||||
}
|
||||
}
|
||||
if (!seen_insts.count(new_line)) {
|
||||
insts.push_back(new_line);
|
||||
seen_insts.insert(new_line);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
emit(node);
|
||||
}
|
||||
|
||||
// 活跃性分析
|
||||
@ -312,7 +441,7 @@ std::map<Instruction*, std::set<std::string>> RISCv32CodeGen::liveness_analysis(
|
||||
auto inst = inst_it->get();
|
||||
std::set<std::string> new_in, new_out;
|
||||
|
||||
// Calculate live_out
|
||||
// 计算 live_out
|
||||
if (auto br = dynamic_cast<CondBrInst*>(inst)) {
|
||||
new_out.insert(live_in[br->getThenBlock()->getInstructions().front().get()].begin(),
|
||||
live_in[br->getThenBlock()->getInstructions().front().get()].end());
|
||||
@ -328,7 +457,7 @@ std::map<Instruction*, std::set<std::string>> RISCv32CodeGen::liveness_analysis(
|
||||
}
|
||||
}
|
||||
|
||||
// Calculate live_in = use ∪ (live_out - def)
|
||||
// 计算 use 和 def
|
||||
std::set<std::string> use, def;
|
||||
if (auto bin = dynamic_cast<BinaryInst*>(inst)) {
|
||||
if (value_vreg_map.find(bin->getLhs()) != value_vreg_map.end())
|
||||
@ -342,9 +471,8 @@ std::map<Instruction*, std::set<std::string>> RISCv32CodeGen::liveness_analysis(
|
||||
if (value_vreg_map.find(arg->getValue()) != value_vreg_map.end())
|
||||
use.insert(value_vreg_map[arg->getValue()]);
|
||||
}
|
||||
if ((call->getType()->isInt() || call->getType()->isFloat()) && value_vreg_map.find(call) != value_vreg_map.end()) {
|
||||
if ((call->getType()->isInt() || call->getType()->isFloat()) && value_vreg_map.find(call) != value_vreg_map.end())
|
||||
def.insert(value_vreg_map[call]);
|
||||
}
|
||||
} else if (auto load = dynamic_cast<LoadInst*>(inst)) {
|
||||
if (value_vreg_map.find(load->getPointer()) != value_vreg_map.end())
|
||||
use.insert(value_vreg_map[load->getPointer()]);
|
||||
@ -356,11 +484,11 @@ std::map<Instruction*, std::set<std::string>> RISCv32CodeGen::liveness_analysis(
|
||||
if (value_vreg_map.find(store->getPointer()) != value_vreg_map.end())
|
||||
use.insert(value_vreg_map[store->getPointer()]);
|
||||
} else if (auto ret = dynamic_cast<ReturnInst*>(inst)) {
|
||||
if (ret->hasReturnValue() && value_vreg_map.find(ret->getReturnValue()) != value_vreg_map.end()) {
|
||||
if (ret->hasReturnValue() && value_vreg_map.find(ret->getReturnValue()) != value_vreg_map.end())
|
||||
use.insert(value_vreg_map[ret->getReturnValue()]);
|
||||
}
|
||||
}
|
||||
|
||||
// 计算 live_in = use ∪ (live_out - def)
|
||||
new_in = use;
|
||||
for (const auto& vreg : new_out) {
|
||||
if (def.find(vreg) == def.end()) {
|
||||
@ -389,16 +517,27 @@ std::map<std::string, std::set<std::string>> RISCv32CodeGen::build_interference_
|
||||
auto inst = pair.first;
|
||||
const auto& live = pair.second;
|
||||
std::string def;
|
||||
std::set<std::string> uses;
|
||||
|
||||
if (auto bin = dynamic_cast<BinaryInst*>(inst)) {
|
||||
if (value_vreg_map.find(bin) != value_vreg_map.end())
|
||||
def = value_vreg_map[bin];
|
||||
if (value_vreg_map.find(bin->getLhs()) != value_vreg_map.end())
|
||||
uses.insert(value_vreg_map[bin->getLhs()]);
|
||||
if (value_vreg_map.find(bin->getRhs()) != value_vreg_map.end())
|
||||
uses.insert(value_vreg_map[bin->getRhs()]);
|
||||
} else if (auto call = dynamic_cast<CallInst*>(inst)) {
|
||||
if ((call->getType()->isInt() || call->getType()->isFloat()) && value_vreg_map.find(call) != value_vreg_map.end()) {
|
||||
if ((call->getType()->isInt() || call->getType()->isFloat()) && value_vreg_map.find(call) != value_vreg_map.end())
|
||||
def = value_vreg_map[call];
|
||||
for (auto arg : call->getArguments()) {
|
||||
if (value_vreg_map.find(arg->getValue()) != value_vreg_map.end())
|
||||
uses.insert(value_vreg_map[arg->getValue()]);
|
||||
}
|
||||
} else if (auto load = dynamic_cast<LoadInst*>(inst)) {
|
||||
if (value_vreg_map.find(load) != value_vreg_map.end())
|
||||
def = value_vreg_map[load];
|
||||
if (value_vreg_map.find(load->getPointer()) != value_vreg_map.end())
|
||||
uses.insert(value_vreg_map[load->getPointer()]);
|
||||
}
|
||||
|
||||
if (!def.empty()) {
|
||||
@ -409,6 +548,16 @@ std::map<std::string, std::set<std::string>> RISCv32CodeGen::build_interference_
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 为 BinaryInst 的两个源操作数添加干扰边
|
||||
if (!uses.empty()) {
|
||||
for (auto it1 = uses.begin(); it1 != uses.end(); ++it1) {
|
||||
for (auto it2 = std::next(it1); it2 != uses.end(); ++it2) {
|
||||
graph[*it1].insert(*it2);
|
||||
graph[*it2].insert(*it1);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return graph;
|
||||
@ -419,25 +568,37 @@ void RISCv32CodeGen::color_graph(std::map<std::string, PhysicalReg>& vreg_to_pre
|
||||
const std::map<std::string, std::set<std::string>>& interference_graph) {
|
||||
std::vector<std::string> stack;
|
||||
std::map<std::string, std::set<std::string>> temp_graph = interference_graph;
|
||||
std::map<std::string, int> degree;
|
||||
|
||||
// 计算每个节点的度
|
||||
for (const auto& pair : temp_graph) {
|
||||
degree[pair.first] = pair.second.size();
|
||||
}
|
||||
|
||||
while (!temp_graph.empty()) {
|
||||
std::string node_to_remove;
|
||||
for (const auto& pair : temp_graph) {
|
||||
if (pair.second.size() < allocable_regs.size()) {
|
||||
if (degree[pair.first] < static_cast<int>(allocable_regs.size())) {
|
||||
node_to_remove = pair.first;
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
if (node_to_remove.empty()) {
|
||||
node_to_remove = temp_graph.begin()->first; // Spill if necessary
|
||||
// 选择度最小的节点
|
||||
node_to_remove = std::min_element(degree.begin(), degree.end(),
|
||||
[](const auto& a, const auto& b) { return a.second < b.second; })->first;
|
||||
}
|
||||
|
||||
stack.push_back(node_to_remove);
|
||||
for (auto& pair : temp_graph) {
|
||||
if (pair.second.find(node_to_remove) != pair.second.end()) {
|
||||
degree[pair.first]--;
|
||||
}
|
||||
pair.second.erase(node_to_remove);
|
||||
}
|
||||
temp_graph.erase(node_to_remove);
|
||||
degree.erase(node_to_remove);
|
||||
}
|
||||
|
||||
while (!stack.empty()) {
|
||||
@ -450,23 +611,21 @@ void RISCv32CodeGen::color_graph(std::map<std::string, PhysicalReg>& vreg_to_pre
|
||||
}
|
||||
}
|
||||
|
||||
bool assigned = false;
|
||||
for (auto preg : allocable_regs) {
|
||||
if (used_colors.find(reg_to_string(preg)) == used_colors.end()) {
|
||||
vreg_to_preg[vreg] = preg;
|
||||
assigned = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
// If no register is available, spill to stack (handled in register_allocation)
|
||||
}
|
||||
}
|
||||
|
||||
// 寄存器分配
|
||||
RISCv32CodeGen::RegAllocResult RISCv32CodeGen::register_allocation(Function* func) {
|
||||
RegAllocResult result;
|
||||
int stack_offset = 0;
|
||||
value_vreg_map.clear(); // Clear vreg map for new function
|
||||
static int vreg_counter = 0; // Counter for unique vreg names
|
||||
value_vreg_map.clear();
|
||||
static int vreg_counter = 0;
|
||||
|
||||
// 分配局部变量栈空间
|
||||
for (const auto& bb : func->getBasicBlocks()) {
|
||||
@ -485,10 +644,23 @@ RISCv32CodeGen::RegAllocResult RISCv32CodeGen::register_allocation(Function* fun
|
||||
if (value_vreg_map.find(bin) == value_vreg_map.end()) {
|
||||
value_vreg_map[bin] = "v" + std::to_string(vreg_counter++);
|
||||
}
|
||||
if (value_vreg_map.find(bin->getLhs()) == value_vreg_map.end()) {
|
||||
value_vreg_map[bin->getLhs()] = "v" + std::to_string(vreg_counter++);
|
||||
}
|
||||
if (value_vreg_map.find(bin->getRhs()) == value_vreg_map.end()) {
|
||||
value_vreg_map[bin->getRhs()] = "v" + std::to_string(vreg_counter++);
|
||||
}
|
||||
} else if (auto call = dynamic_cast<CallInst*>(inst.get())) {
|
||||
if ((call->getType()->isInt() || call->getType()->isFloat()) && value_vreg_map.find(call) == value_vreg_map.end()) {
|
||||
value_vreg_map[call] = "v" + std::to_string(vreg_counter++);
|
||||
}
|
||||
} else if (auto store = dynamic_cast<StoreInst*>(inst.get())) {
|
||||
if (value_vreg_map.find(store->getValue()) == value_vreg_map.end()) {
|
||||
value_vreg_map[store->getValue()] = "v" + std::to_string(vreg_counter++);
|
||||
}
|
||||
if (value_vreg_map.find(store->getPointer()) == value_vreg_map.end()) {
|
||||
value_vreg_map[store->getPointer()] = "v" + std::to_string(vreg_counter++);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -514,29 +686,13 @@ RISCv32CodeGen::RegAllocResult RISCv32CodeGen::register_allocation(Function* fun
|
||||
color_graph(result.vreg_to_preg, interference_graph);
|
||||
|
||||
// 分配溢出栈空间
|
||||
for (const auto& bb : func->getBasicBlocks()) {
|
||||
for (const auto& inst : bb->getInstructions()) {
|
||||
if (auto bin = dynamic_cast<BinaryInst*>(inst.get())) {
|
||||
std::string vreg = value_vreg_map[bin];
|
||||
for (const auto& pair : value_vreg_map) {
|
||||
auto vreg = pair.second;
|
||||
if (result.vreg_to_preg.find(vreg) == result.vreg_to_preg.end()) {
|
||||
result.stack_map[bin] = stack_offset;
|
||||
if (result.stack_map.find(pair.first) == result.stack_map.end()) {
|
||||
result.stack_map[pair.first] = stack_offset;
|
||||
stack_offset += 4;
|
||||
}
|
||||
} else if (auto call = dynamic_cast<CallInst*>(inst.get())) {
|
||||
if (call->getType()->isInt() || call->getType()->isFloat()) {
|
||||
std::string vreg = value_vreg_map[call];
|
||||
if (result.vreg_to_preg.find(vreg) == result.vreg_to_preg.end()) {
|
||||
result.stack_map[call] = stack_offset;
|
||||
stack_offset += 4;
|
||||
}
|
||||
}
|
||||
} else if (auto load = dynamic_cast<LoadInst*>(inst.get())) {
|
||||
std::string vreg = value_vreg_map[load];
|
||||
if (result.vreg_to_preg.find(vreg) == result.vreg_to_preg.end()) {
|
||||
result.stack_map[load] = stack_offset;
|
||||
stack_offset += 4;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -553,7 +709,7 @@ RISCv32CodeGen::RegAllocResult RISCv32CodeGen::register_allocation(Function* fun
|
||||
}
|
||||
|
||||
if (needs_caller_saved || stack_offset > 0) {
|
||||
stack_offset += 8; // 保存 ra 和 s0
|
||||
stack_offset += 8;
|
||||
}
|
||||
|
||||
result.stack_size = stack_offset;
|
||||
|
||||
@ -7,6 +7,7 @@
|
||||
#include <map>
|
||||
#include <set>
|
||||
#include <memory>
|
||||
#include <iostream>
|
||||
|
||||
namespace sysy {
|
||||
|
||||
@ -19,7 +20,7 @@ public:
|
||||
|
||||
// Move DAGNode and RegAllocResult to public section
|
||||
struct DAGNode {
|
||||
enum NodeKind { CONSTANT, LOAD, STORE, BINARY, CALL, RETURN };
|
||||
enum NodeKind { CONSTANT, LOAD, STORE, BINARY, CALL, RETURN, BRANCH };
|
||||
NodeKind kind;
|
||||
Value* value = nullptr;
|
||||
std::string inst;
|
||||
@ -52,6 +53,7 @@ public:
|
||||
RegAllocResult register_allocation(Function* func);
|
||||
void eliminate_phi(Function* func);
|
||||
std::string reg_to_string(PhysicalReg reg);
|
||||
void print_dag(const std::vector<std::unique_ptr<DAGNode>>& dag, const std::string& bb_name);
|
||||
|
||||
private:
|
||||
static const std::vector<PhysicalReg> allocable_regs;
|
||||
|
||||
Reference in New Issue
Block a user