[backend] fixed 1 segmentation fault

This commit is contained in:
Lixuanwang
2025-06-23 22:38:29 +08:00
parent ab3eb253f9
commit 3c3f48ee87

View File

@ -124,13 +124,13 @@ std::vector<std::unique_ptr<RISCv32CodeGen::DAGNode>> RISCv32CodeGen::build_dag(
std::map<Value*, DAGNode*> value_to_node; std::map<Value*, DAGNode*> value_to_node;
static int vreg_counter = 0; static int vreg_counter = 0;
auto create_node = [&](DAGNode::NodeKind kind, Value* val = nullptr) { auto create_node = [&](DAGNode::NodeKind kind, Value* val = nullptr) -> DAGNode* {
if (val && value_to_node.find(val) != value_to_node.end()) { if (val && value_to_node.count(val) && kind != DAGNode::STORE && kind != DAGNode::RETURN && kind != DAGNode::BRANCH) {
return value_to_node[val]; return value_to_node[val];
} }
auto node = std::make_unique<DAGNode>(kind); auto node = std::make_unique<DAGNode>(kind);
node->value = val; node->value = val;
if (val && kind != DAGNode::STORE && kind != DAGNode::RETURN) { if (val && kind != DAGNode::STORE && kind != DAGNode::RETURN && kind != DAGNode::BRANCH) {
node->result_reg = "v" + std::to_string(vreg_counter++); node->result_reg = "v" + std::to_string(vreg_counter++);
value_vreg_map[val] = node->result_reg; value_vreg_map[val] = node->result_reg;
value_to_node[val] = node.get(); value_to_node[val] = node.get();
@ -146,53 +146,105 @@ std::vector<std::unique_ptr<RISCv32CodeGen::DAGNode>> RISCv32CodeGen::build_dag(
auto store_node = create_node(DAGNode::STORE); auto store_node = create_node(DAGNode::STORE);
Value* val = store->getValue(); Value* val = store->getValue();
Value* ptr = store->getPointer(); Value* ptr = store->getPointer();
DAGNode* val_node = nullptr; DAGNode* val_node = value_to_node.count(val) ? value_to_node[val] : nullptr;
if (!val_node) {
if (auto constant = dynamic_cast<ConstantValue*>(val)) { if (auto constant = dynamic_cast<ConstantValue*>(val)) {
val_node = create_node(DAGNode::CONSTANT, val); val_node = create_node(DAGNode::CONSTANT, val);
} else { } else {
val_node = create_node(DAGNode::LOAD, val); val_node = create_node(DAGNode::LOAD, val);
} }
auto ptr_node = create_node(DAGNode::CONSTANT, ptr); }
auto ptr_node = value_to_node.count(ptr) ? value_to_node[ptr] : create_node(DAGNode::CONSTANT, ptr);
store_node->operands.push_back(val_node); store_node->operands.push_back(val_node);
store_node->operands.push_back(ptr_node); store_node->operands.push_back(ptr_node);
val_node->users.push_back(store_node); val_node->users.push_back(store_node);
ptr_node->users.push_back(store_node); ptr_node->users.push_back(store_node);
} else if (auto load = dynamic_cast<LoadInst*>(inst.get())) { } else if (auto load = dynamic_cast<LoadInst*>(inst.get())) {
if (value_to_node.count(load)) continue; // CSE
auto load_node = create_node(DAGNode::LOAD, load); auto load_node = create_node(DAGNode::LOAD, load);
auto ptr_node = create_node(DAGNode::CONSTANT, load->getPointer()); auto ptr = load->getPointer();
auto ptr_node = value_to_node.count(ptr) ? value_to_node[ptr] : create_node(DAGNode::CONSTANT, ptr);
load_node->operands.push_back(ptr_node); load_node->operands.push_back(ptr_node);
ptr_node->users.push_back(load_node); ptr_node->users.push_back(load_node);
} else if (auto bin = dynamic_cast<BinaryInst*>(inst.get())) { } else if (auto bin = dynamic_cast<BinaryInst*>(inst.get())) {
if (value_to_node.count(bin)) continue; // CSE
auto bin_node = create_node(DAGNode::BINARY, bin); auto bin_node = create_node(DAGNode::BINARY, bin);
auto lhs_node = create_node(DAGNode::LOAD, bin->getLhs()); auto lhs = bin->getLhs();
auto rhs_node = create_node(DAGNode::LOAD, bin->getRhs()); auto rhs = bin->getRhs();
auto lhs_node = value_to_node.count(lhs) ? value_to_node[lhs] : nullptr;
if (!lhs_node) {
if (auto constant = dynamic_cast<ConstantValue*>(lhs)) {
lhs_node = create_node(DAGNode::CONSTANT, lhs);
} else {
lhs_node = create_node(DAGNode::LOAD, lhs);
}
}
auto rhs_node = value_to_node.count(rhs) ? value_to_node[rhs] : nullptr;
if (!rhs_node) {
if (auto constant = dynamic_cast<ConstantValue*>(rhs)) {
rhs_node = create_node(DAGNode::CONSTANT, rhs);
} else {
rhs_node = create_node(DAGNode::LOAD, rhs);
}
}
bin_node->operands.push_back(lhs_node); bin_node->operands.push_back(lhs_node);
bin_node->operands.push_back(rhs_node); bin_node->operands.push_back(rhs_node);
lhs_node->users.push_back(bin_node); lhs_node->users.push_back(bin_node);
rhs_node->users.push_back(bin_node); rhs_node->users.push_back(bin_node);
} else if (auto call = dynamic_cast<CallInst*>(inst.get())) { } else if (auto call = dynamic_cast<CallInst*>(inst.get())) {
if (value_to_node.count(call)) continue; // CSE
auto call_node = create_node(DAGNode::CALL, call); auto call_node = create_node(DAGNode::CALL, call);
for (auto arg : call->getArguments()) { for (auto arg : call->getArguments()) {
auto arg_node = create_node(DAGNode::CONSTANT, arg->getValue()); auto arg_val = arg->getValue();
auto arg_node = value_to_node.count(arg_val) ? value_to_node[arg_val] : nullptr;
if (!arg_node) {
if (auto constant = dynamic_cast<ConstantValue*>(arg_val)) {
arg_node = create_node(DAGNode::CONSTANT, arg_val);
} else {
arg_node = create_node(DAGNode::LOAD, arg_val);
}
}
call_node->operands.push_back(arg_node); call_node->operands.push_back(arg_node);
arg_node->users.push_back(call_node); arg_node->users.push_back(call_node);
} }
} else if (auto ret = dynamic_cast<ReturnInst*>(inst.get())) { } else if (auto ret = dynamic_cast<ReturnInst*>(inst.get())) {
auto ret_node = create_node(DAGNode::RETURN); auto ret_node = create_node(DAGNode::RETURN);
if (ret->hasReturnValue()) { if (ret->hasReturnValue()) {
auto val_node = create_node(DAGNode::LOAD, ret->getReturnValue()); auto val = ret->getReturnValue();
auto val_node = value_to_node.count(val) ? value_to_node[val] : nullptr;
if (!val_node) {
if (auto constant = dynamic_cast<ConstantValue*>(val)) {
val_node = create_node(DAGNode::CONSTANT, val);
} else {
val_node = create_node(DAGNode::LOAD, val);
}
}
ret_node->operands.push_back(val_node); ret_node->operands.push_back(val_node);
val_node->users.push_back(ret_node); val_node->users.push_back(ret_node);
} }
} else if (auto cond_br = dynamic_cast<CondBrInst*>(inst.get())) { } else if (auto cond_br = dynamic_cast<CondBrInst*>(inst.get())) {
auto br_node = create_node(DAGNode::BRANCH); auto br_node = create_node(DAGNode::BRANCH, cond_br);
auto cond = cond_br->getCondition(); auto cond = cond_br->getCondition();
auto cond_node = value_to_node.count(cond) ? value_to_node[cond] : create_node(DAGNode::LOAD, cond); DAGNode* cond_node = nullptr;
if (auto constant = dynamic_cast<ConstantValue*>(cond)) {
// 常量条件,直接记录分支目标
br_node->operands.push_back(nullptr); // 占位符,表示无条件操作数
br_node->inst = constant->getInt() ? "j " + cond_br->getThenBlock()->getName()
: "j " + cond_br->getElseBlock()->getName();
} else {
cond_node = value_to_node.count(cond) ? value_to_node[cond] : nullptr;
if (!cond_node) {
if (auto bin = dynamic_cast<BinaryInst*>(cond)) {
cond_node = create_node(DAGNode::BINARY, cond);
} else {
cond_node = create_node(DAGNode::LOAD, cond);
}
}
br_node->operands.push_back(cond_node); br_node->operands.push_back(cond_node);
br_node->value = cond_br; // 存储 CondBrInst 以获取 then/else 块
cond_node->users.push_back(br_node); cond_node->users.push_back(br_node);
} }
} }
}
return nodes; return nodes;
} }
@ -277,8 +329,10 @@ void RISCv32CodeGen::select_instructions(DAGNode* node, const RegAllocResult& al
if (!node->inst.empty()) return; if (!node->inst.empty()) return;
for (auto operand : node->operands) { for (auto operand : node->operands) {
if (operand) {
select_instructions(operand, alloc); select_instructions(operand, alloc);
} }
}
switch (node->kind) { switch (node->kind) {
case DAGNode::CONSTANT: { case DAGNode::CONSTANT: {
@ -298,38 +352,65 @@ void RISCv32CodeGen::select_instructions(DAGNode* node, const RegAllocResult& al
break; break;
} }
case DAGNode::LOAD: { case DAGNode::LOAD: {
if (node->operands.empty() || !node->operands[0]) {
if (auto constant = dynamic_cast<ConstantValue*>(node->value)) {
node->inst = "li " + node->result_reg + ", " + std::to_string(constant->getInt());
}
break;
}
auto ptr = node->operands[0]->value; auto ptr = node->operands[0]->value;
if (!ptr) break;
if (alloc.stack_map.count(ptr)) { if (alloc.stack_map.count(ptr)) {
int offset = alloc.stack_map.at(ptr); int offset = alloc.stack_map.at(ptr);
node->inst = "lw " + node->result_reg + ", " + std::to_string(offset) + "(s0)"; node->inst = "lw " + node->result_reg + ", " + std::to_string(offset) + "(s0)";
} else { } else {
auto ptr_reg = node->operands[0]->result_reg; auto ptr_reg = node->operands[0]->result_reg;
if (!ptr_reg.empty()) {
node->inst = "lw " + node->result_reg + ", 0(" + ptr_reg + ")"; node->inst = "lw " + node->result_reg + ", 0(" + ptr_reg + ")";
} }
}
break; break;
} }
case DAGNode::STORE: { case DAGNode::STORE: {
if (node->operands.size() < 2 || !node->operands[0] || !node->operands[1]) break;
auto val = node->operands[0]->value; auto val = node->operands[0]->value;
auto ptr = node->operands[1]->value; auto ptr = node->operands[1]->value;
auto val_reg = node->operands[0]->result_reg; if (!val || !ptr) break;
if (alloc.stack_map.count(ptr)) { if (alloc.stack_map.count(ptr)) {
int offset = alloc.stack_map.at(ptr); int offset = alloc.stack_map.at(ptr);
if (auto constant = dynamic_cast<ConstantValue*>(val)) {
// 直接存储常量
node->inst = "li t0, " + std::to_string(constant->getInt()) + "\nsw t0, " + std::to_string(offset) + "(s0)";
} else {
auto val_reg = node->operands[0]->result_reg;
if (!val_reg.empty()) {
node->inst = "sw " + val_reg + ", " + std::to_string(offset) + "(s0)"; node->inst = "sw " + val_reg + ", " + std::to_string(offset) + "(s0)";
}
}
} else { } else {
auto ptr_reg = node->operands[1]->result_reg; auto ptr_reg = node->operands[1]->result_reg;
auto val_reg = node->operands[0]->result_reg;
if (!ptr_reg.empty() && !val_reg.empty()) {
node->inst = "sw " + val_reg + ", 0(" + ptr_reg + ")"; node->inst = "sw " + val_reg + ", 0(" + ptr_reg + ")";
} }
}
break; break;
} }
case DAGNode::BINARY: { case DAGNode::BINARY: {
if (node->operands.size() < 2 || !node->operands[0] || !node->operands[1]) break;
auto bin = dynamic_cast<BinaryInst*>(node->value); auto bin = dynamic_cast<BinaryInst*>(node->value);
if (!bin) break;
auto lhs_reg = node->operands[0]->result_reg; auto lhs_reg = node->operands[0]->result_reg;
auto rhs_reg = node->operands[1]->result_reg; auto rhs_reg = node->operands[1]->result_reg;
if (lhs_reg.empty() || rhs_reg.empty()) break;
std::string opcode; std::string opcode;
switch (bin->getKind()) { switch (bin->getKind()) {
case BinaryInst::kAdd: opcode = "add"; break; case BinaryInst::kAdd: opcode = "add"; break;
case BinaryInst::kSub: opcode = "sub"; break; case BinaryInst::kSub: opcode = "sub"; break;
case BinaryInst::kMul: opcode = "mul"; break; case BinaryInst::kMul: opcode = "mul"; break;
case BinaryInst::kICmpEQ:
node->inst = "sub " + node->result_reg + ", " + lhs_reg + ", " + rhs_reg + "\nseqz " + node->result_reg + ", " + node->result_reg;
return;
default: break; default: break;
} }
if (!opcode.empty()) { if (!opcode.empty()) {
@ -338,11 +419,15 @@ void RISCv32CodeGen::select_instructions(DAGNode* node, const RegAllocResult& al
break; break;
} }
case DAGNode::CALL: { case DAGNode::CALL: {
if (!node->value) break;
auto call = dynamic_cast<CallInst*>(node->value); auto call = dynamic_cast<CallInst*>(node->value);
if (!call) break;
std::string insts; std::string insts;
for (size_t i = 0; i < node->operands.size() && i < 8; ++i) { for (size_t i = 0; i < node->operands.size() && i < 8; ++i) {
if (node->operands[i] && !node->operands[i]->result_reg.empty()) {
insts += "mv a" + std::to_string(i) + ", " + node->operands[i]->result_reg + "\n"; insts += "mv a" + std::to_string(i) + ", " + node->operands[i]->result_reg + "\n";
} }
}
insts += "jal " + call->getCallee()->getName(); insts += "jal " + call->getCallee()->getName();
if (call->getType()->isInt() || call->getType()->isFloat()) { if (call->getType()->isInt() || call->getType()->isFloat()) {
insts += "\nmv " + node->result_reg + ", a0"; insts += "\nmv " + node->result_reg + ", a0";
@ -351,17 +436,20 @@ void RISCv32CodeGen::select_instructions(DAGNode* node, const RegAllocResult& al
break; break;
} }
case DAGNode::RETURN: { case DAGNode::RETURN: {
if (!node->operands.empty()) { if (!node->operands.empty() && node->operands[0] && !node->operands[0]->result_reg.empty()) {
node->inst = "mv a0, " + node->operands[0]->result_reg; node->inst = "mv a0, " + node->operands[0]->result_reg;
} }
break; break;
} }
case DAGNode::BRANCH: { case DAGNode::BRANCH: {
auto br = dynamic_cast<CondBrInst*>(node->value); auto br = dynamic_cast<CondBrInst*>(node->value);
if (!br) break;
if (!node->inst.empty()) break; // 常量条件已预生成跳转
if (node->operands.empty() || !node->operands[0]) break;
auto cond_reg = node->operands[0]->result_reg; auto cond_reg = node->operands[0]->result_reg;
if (cond_reg.empty()) break;
auto then_block = br->getThenBlock()->getName(); auto then_block = br->getThenBlock()->getName();
auto else_block = br->getElseBlock()->getName(); auto else_block = br->getElseBlock()->getName();
// 假设条件为 true 时跳转到 then否则到 else
node->inst = "bnez " + cond_reg + ", " + then_block + "\nj " + else_block; node->inst = "bnez " + cond_reg + ", " + then_block + "\nj " + else_block;
break; break;
} }
@ -371,41 +459,54 @@ void RISCv32CodeGen::select_instructions(DAGNode* node, const RegAllocResult& al
// 指令发射 // 指令发射
void RISCv32CodeGen::emit_instructions(DAGNode* node, std::vector<std::string>& insts, const RegAllocResult& alloc) { void RISCv32CodeGen::emit_instructions(DAGNode* node, std::vector<std::string>& insts, const RegAllocResult& alloc) {
std::set<DAGNode*> emitted; // 局部变量,针对每个基本块独立 std::set<DAGNode*> emitted;
std::set<std::string> seen_insts; // 跟踪已发射的指令以去重 std::set<std::string> seen_insts;
std::function<void(DAGNode*)> emit = [&](DAGNode* n) { std::function<void(DAGNode*)> emit = [&](DAGNode* n) {
if (!n || emitted.count(n)) return; if (!n || emitted.count(n)) return;
emitted.insert(n); emitted.insert(n);
for (auto operand : n->operands) { for (auto operand : n->operands) {
emit(operand); if (operand) emit(operand);
} }
if (!n->inst.empty()) { if (!n->inst.empty()) {
std::stringstream ss(n->inst); std::stringstream ss(n->inst);
std::string line; std::string line;
while (std::getline(ss, line, '\n')) { while (std::getline(ss, line, '\n')) {
if (!line.empty()) { // 清理空白和无效字符
line.erase(std::remove_if(line.begin(), line.end(), [](char c) { return std::isspace(c) || c == ':'; }), line.end());
if (line.empty()) continue;
std::string new_line = line; std::string new_line = line;
// 替换结果寄存器
if (!n->result_reg.empty()) { if (!n->result_reg.empty()) {
if (alloc.vreg_to_preg.count(n->result_reg)) { if (alloc.vreg_to_preg.count(n->result_reg)) {
new_line = std::regex_replace(new_line, std::regex("\\b" + n->result_reg + "\\b"), reg_to_string(alloc.vreg_to_preg.at(n->result_reg))); new_line = std::regex_replace(new_line, std::regex("\\b" + n->result_reg + "\\b"), reg_to_string(alloc.vreg_to_preg.at(n->result_reg)));
} else if (alloc.stack_map.count(n->value)) { } else {
// 如果虚拟寄存器未分配物理寄存器,使用 t0 并通过栈访问
if (n->value && alloc.stack_map.count(n->value)) {
int offset = alloc.stack_map.at(n->value); int offset = alloc.stack_map.at(n->value);
new_line = std::regex_replace(new_line, std::regex("\\b" + n->result_reg + "\\b"), "t0"); new_line = std::regex_replace(new_line, std::regex("\\b" + n->result_reg + "\\b"), "t0");
if (n->kind != DAGNode::STORE) { // 避免为 STORE 节点重复生成存储
std::string store = "sw t0, " + std::to_string(offset) + "(s0)"; std::string store = "sw t0, " + std::to_string(offset) + "(s0)";
if (!seen_insts.count(store)) { if (!seen_insts.count(store)) {
insts.push_back(store); insts.push_back(store);
seen_insts.insert(store); seen_insts.insert(store);
} }
} }
} else {
// 兜底:使用 t0
new_line = std::regex_replace(new_line, std::regex("\\b" + n->result_reg + "\\b"), "t0");
} }
}
}
// 替换操作数寄存器
for (auto operand : n->operands) { for (auto operand : n->operands) {
if (!operand->result_reg.empty()) { if (operand && !operand->result_reg.empty()) {
if (alloc.vreg_to_preg.count(operand->result_reg)) { if (alloc.vreg_to_preg.count(operand->result_reg)) {
new_line = std::regex_replace(new_line, std::regex("\\b" + operand->result_reg + "\\b"), reg_to_string(alloc.vreg_to_preg.at(operand->result_reg))); new_line = std::regex_replace(new_line, std::regex("\\b" + operand->result_reg + "\\b"), reg_to_string(alloc.vreg_to_preg.at(operand->result_reg)));
} else if (alloc.stack_map.count(operand->value)) { } else if (operand->value && alloc.stack_map.count(operand->value)) {
int offset = alloc.stack_map.at(operand->value); int offset = alloc.stack_map.at(operand->value);
std::string load = "lw t0, " + std::to_string(offset) + "(s0)"; std::string load = "lw t0, " + std::to_string(offset) + "(s0)";
if (!seen_insts.count(load)) { if (!seen_insts.count(load)) {
@ -413,6 +514,9 @@ void RISCv32CodeGen::emit_instructions(DAGNode* node, std::vector<std::string>&
seen_insts.insert(load); seen_insts.insert(load);
} }
new_line = std::regex_replace(new_line, std::regex("\\b" + operand->result_reg + "\\b"), "t0"); new_line = std::regex_replace(new_line, std::regex("\\b" + operand->result_reg + "\\b"), "t0");
} else {
// 兜底:使用 t0
new_line = std::regex_replace(new_line, std::regex("\\b" + operand->result_reg + "\\b"), "t0");
} }
} }
} }
@ -422,7 +526,6 @@ void RISCv32CodeGen::emit_instructions(DAGNode* node, std::vector<std::string>&
} }
} }
} }
}
}; };
emit(node); emit(node);
@ -688,9 +791,10 @@ RISCv32CodeGen::RegAllocResult RISCv32CodeGen::register_allocation(Function* fun
// 分配溢出栈空间 // 分配溢出栈空间
for (const auto& pair : value_vreg_map) { for (const auto& pair : value_vreg_map) {
auto vreg = pair.second; auto vreg = pair.second;
auto value = pair.first;
if (result.vreg_to_preg.find(vreg) == result.vreg_to_preg.end()) { if (result.vreg_to_preg.find(vreg) == result.vreg_to_preg.end()) {
if (result.stack_map.find(pair.first) == result.stack_map.end()) { if (result.stack_map.find(value) == result.stack_map.end()) {
result.stack_map[pair.first] = stack_offset; result.stack_map[value] = stack_offset;
stack_offset += 4; stack_offset += 4;
} }
} }