[backend] fixed 1 segmentation fault

This commit is contained in:
Lixuanwang
2025-06-23 22:38:29 +08:00
parent ab3eb253f9
commit 3c3f48ee87

View File

@ -124,13 +124,13 @@ std::vector<std::unique_ptr<RISCv32CodeGen::DAGNode>> RISCv32CodeGen::build_dag(
std::map<Value*, DAGNode*> value_to_node;
static int vreg_counter = 0;
auto create_node = [&](DAGNode::NodeKind kind, Value* val = nullptr) {
if (val && value_to_node.find(val) != value_to_node.end()) {
auto create_node = [&](DAGNode::NodeKind kind, Value* val = nullptr) -> DAGNode* {
if (val && value_to_node.count(val) && kind != DAGNode::STORE && kind != DAGNode::RETURN && kind != DAGNode::BRANCH) {
return value_to_node[val];
}
auto node = std::make_unique<DAGNode>(kind);
node->value = val;
if (val && kind != DAGNode::STORE && kind != DAGNode::RETURN) {
if (val && kind != DAGNode::STORE && kind != DAGNode::RETURN && kind != DAGNode::BRANCH) {
node->result_reg = "v" + std::to_string(vreg_counter++);
value_vreg_map[val] = node->result_reg;
value_to_node[val] = node.get();
@ -146,51 +146,103 @@ std::vector<std::unique_ptr<RISCv32CodeGen::DAGNode>> RISCv32CodeGen::build_dag(
auto store_node = create_node(DAGNode::STORE);
Value* val = store->getValue();
Value* ptr = store->getPointer();
DAGNode* val_node = nullptr;
if (auto constant = dynamic_cast<ConstantValue*>(val)) {
val_node = create_node(DAGNode::CONSTANT, val);
} else {
val_node = create_node(DAGNode::LOAD, val);
DAGNode* val_node = value_to_node.count(val) ? value_to_node[val] : nullptr;
if (!val_node) {
if (auto constant = dynamic_cast<ConstantValue*>(val)) {
val_node = create_node(DAGNode::CONSTANT, val);
} else {
val_node = create_node(DAGNode::LOAD, val);
}
}
auto ptr_node = create_node(DAGNode::CONSTANT, ptr);
auto ptr_node = value_to_node.count(ptr) ? value_to_node[ptr] : create_node(DAGNode::CONSTANT, ptr);
store_node->operands.push_back(val_node);
store_node->operands.push_back(ptr_node);
val_node->users.push_back(store_node);
ptr_node->users.push_back(store_node);
} else if (auto load = dynamic_cast<LoadInst*>(inst.get())) {
if (value_to_node.count(load)) continue; // CSE
auto load_node = create_node(DAGNode::LOAD, load);
auto ptr_node = create_node(DAGNode::CONSTANT, load->getPointer());
auto ptr = load->getPointer();
auto ptr_node = value_to_node.count(ptr) ? value_to_node[ptr] : create_node(DAGNode::CONSTANT, ptr);
load_node->operands.push_back(ptr_node);
ptr_node->users.push_back(load_node);
} else if (auto bin = dynamic_cast<BinaryInst*>(inst.get())) {
if (value_to_node.count(bin)) continue; // CSE
auto bin_node = create_node(DAGNode::BINARY, bin);
auto lhs_node = create_node(DAGNode::LOAD, bin->getLhs());
auto rhs_node = create_node(DAGNode::LOAD, bin->getRhs());
auto lhs = bin->getLhs();
auto rhs = bin->getRhs();
auto lhs_node = value_to_node.count(lhs) ? value_to_node[lhs] : nullptr;
if (!lhs_node) {
if (auto constant = dynamic_cast<ConstantValue*>(lhs)) {
lhs_node = create_node(DAGNode::CONSTANT, lhs);
} else {
lhs_node = create_node(DAGNode::LOAD, lhs);
}
}
auto rhs_node = value_to_node.count(rhs) ? value_to_node[rhs] : nullptr;
if (!rhs_node) {
if (auto constant = dynamic_cast<ConstantValue*>(rhs)) {
rhs_node = create_node(DAGNode::CONSTANT, rhs);
} else {
rhs_node = create_node(DAGNode::LOAD, rhs);
}
}
bin_node->operands.push_back(lhs_node);
bin_node->operands.push_back(rhs_node);
lhs_node->users.push_back(bin_node);
rhs_node->users.push_back(bin_node);
} else if (auto call = dynamic_cast<CallInst*>(inst.get())) {
if (value_to_node.count(call)) continue; // CSE
auto call_node = create_node(DAGNode::CALL, call);
for (auto arg : call->getArguments()) {
auto arg_node = create_node(DAGNode::CONSTANT, arg->getValue());
auto arg_val = arg->getValue();
auto arg_node = value_to_node.count(arg_val) ? value_to_node[arg_val] : nullptr;
if (!arg_node) {
if (auto constant = dynamic_cast<ConstantValue*>(arg_val)) {
arg_node = create_node(DAGNode::CONSTANT, arg_val);
} else {
arg_node = create_node(DAGNode::LOAD, arg_val);
}
}
call_node->operands.push_back(arg_node);
arg_node->users.push_back(call_node);
}
} else if (auto ret = dynamic_cast<ReturnInst*>(inst.get())) {
auto ret_node = create_node(DAGNode::RETURN);
if (ret->hasReturnValue()) {
auto val_node = create_node(DAGNode::LOAD, ret->getReturnValue());
auto val = ret->getReturnValue();
auto val_node = value_to_node.count(val) ? value_to_node[val] : nullptr;
if (!val_node) {
if (auto constant = dynamic_cast<ConstantValue*>(val)) {
val_node = create_node(DAGNode::CONSTANT, val);
} else {
val_node = create_node(DAGNode::LOAD, val);
}
}
ret_node->operands.push_back(val_node);
val_node->users.push_back(ret_node);
}
} else if (auto cond_br = dynamic_cast<CondBrInst*>(inst.get())) {
auto br_node = create_node(DAGNode::BRANCH);
auto br_node = create_node(DAGNode::BRANCH, cond_br);
auto cond = cond_br->getCondition();
auto cond_node = value_to_node.count(cond) ? value_to_node[cond] : create_node(DAGNode::LOAD, cond);
br_node->operands.push_back(cond_node);
br_node->value = cond_br; // 存储 CondBrInst 以获取 then/else 块
cond_node->users.push_back(br_node);
DAGNode* cond_node = nullptr;
if (auto constant = dynamic_cast<ConstantValue*>(cond)) {
// 常量条件,直接记录分支目标
br_node->operands.push_back(nullptr); // 占位符,表示无条件操作数
br_node->inst = constant->getInt() ? "j " + cond_br->getThenBlock()->getName()
: "j " + cond_br->getElseBlock()->getName();
} else {
cond_node = value_to_node.count(cond) ? value_to_node[cond] : nullptr;
if (!cond_node) {
if (auto bin = dynamic_cast<BinaryInst*>(cond)) {
cond_node = create_node(DAGNode::BINARY, cond);
} else {
cond_node = create_node(DAGNode::LOAD, cond);
}
}
br_node->operands.push_back(cond_node);
cond_node->users.push_back(br_node);
}
}
}
@ -277,7 +329,9 @@ void RISCv32CodeGen::select_instructions(DAGNode* node, const RegAllocResult& al
if (!node->inst.empty()) return;
for (auto operand : node->operands) {
select_instructions(operand, alloc);
if (operand) {
select_instructions(operand, alloc);
}
}
switch (node->kind) {
@ -298,38 +352,65 @@ void RISCv32CodeGen::select_instructions(DAGNode* node, const RegAllocResult& al
break;
}
case DAGNode::LOAD: {
if (node->operands.empty() || !node->operands[0]) {
if (auto constant = dynamic_cast<ConstantValue*>(node->value)) {
node->inst = "li " + node->result_reg + ", " + std::to_string(constant->getInt());
}
break;
}
auto ptr = node->operands[0]->value;
if (!ptr) break;
if (alloc.stack_map.count(ptr)) {
int offset = alloc.stack_map.at(ptr);
node->inst = "lw " + node->result_reg + ", " + std::to_string(offset) + "(s0)";
} else {
auto ptr_reg = node->operands[0]->result_reg;
node->inst = "lw " + node->result_reg + ", 0(" + ptr_reg + ")";
if (!ptr_reg.empty()) {
node->inst = "lw " + node->result_reg + ", 0(" + ptr_reg + ")";
}
}
break;
}
case DAGNode::STORE: {
if (node->operands.size() < 2 || !node->operands[0] || !node->operands[1]) break;
auto val = node->operands[0]->value;
auto ptr = node->operands[1]->value;
auto val_reg = node->operands[0]->result_reg;
if (!val || !ptr) break;
if (alloc.stack_map.count(ptr)) {
int offset = alloc.stack_map.at(ptr);
node->inst = "sw " + val_reg + ", " + std::to_string(offset) + "(s0)";
if (auto constant = dynamic_cast<ConstantValue*>(val)) {
// 直接存储常量
node->inst = "li t0, " + std::to_string(constant->getInt()) + "\nsw t0, " + std::to_string(offset) + "(s0)";
} else {
auto val_reg = node->operands[0]->result_reg;
if (!val_reg.empty()) {
node->inst = "sw " + val_reg + ", " + std::to_string(offset) + "(s0)";
}
}
} else {
auto ptr_reg = node->operands[1]->result_reg;
node->inst = "sw " + val_reg + ", 0(" + ptr_reg + ")";
auto val_reg = node->operands[0]->result_reg;
if (!ptr_reg.empty() && !val_reg.empty()) {
node->inst = "sw " + val_reg + ", 0(" + ptr_reg + ")";
}
}
break;
}
case DAGNode::BINARY: {
if (node->operands.size() < 2 || !node->operands[0] || !node->operands[1]) break;
auto bin = dynamic_cast<BinaryInst*>(node->value);
if (!bin) break;
auto lhs_reg = node->operands[0]->result_reg;
auto rhs_reg = node->operands[1]->result_reg;
if (lhs_reg.empty() || rhs_reg.empty()) break;
std::string opcode;
switch (bin->getKind()) {
case BinaryInst::kAdd: opcode = "add"; break;
case BinaryInst::kSub: opcode = "sub"; break;
case BinaryInst::kMul: opcode = "mul"; break;
case BinaryInst::kICmpEQ:
node->inst = "sub " + node->result_reg + ", " + lhs_reg + ", " + rhs_reg + "\nseqz " + node->result_reg + ", " + node->result_reg;
return;
default: break;
}
if (!opcode.empty()) {
@ -338,10 +419,14 @@ void RISCv32CodeGen::select_instructions(DAGNode* node, const RegAllocResult& al
break;
}
case DAGNode::CALL: {
if (!node->value) break;
auto call = dynamic_cast<CallInst*>(node->value);
if (!call) break;
std::string insts;
for (size_t i = 0; i < node->operands.size() && i < 8; ++i) {
insts += "mv a" + std::to_string(i) + ", " + node->operands[i]->result_reg + "\n";
if (node->operands[i] && !node->operands[i]->result_reg.empty()) {
insts += "mv a" + std::to_string(i) + ", " + node->operands[i]->result_reg + "\n";
}
}
insts += "jal " + call->getCallee()->getName();
if (call->getType()->isInt() || call->getType()->isFloat()) {
@ -351,17 +436,20 @@ void RISCv32CodeGen::select_instructions(DAGNode* node, const RegAllocResult& al
break;
}
case DAGNode::RETURN: {
if (!node->operands.empty()) {
if (!node->operands.empty() && node->operands[0] && !node->operands[0]->result_reg.empty()) {
node->inst = "mv a0, " + node->operands[0]->result_reg;
}
break;
}
case DAGNode::BRANCH: {
auto br = dynamic_cast<CondBrInst*>(node->value);
if (!br) break;
if (!node->inst.empty()) break; // 常量条件已预生成跳转
if (node->operands.empty() || !node->operands[0]) break;
auto cond_reg = node->operands[0]->result_reg;
if (cond_reg.empty()) break;
auto then_block = br->getThenBlock()->getName();
auto else_block = br->getElseBlock()->getName();
// 假设条件为 true 时跳转到 then否则到 else
node->inst = "bnez " + cond_reg + ", " + then_block + "\nj " + else_block;
break;
}
@ -371,56 +459,71 @@ void RISCv32CodeGen::select_instructions(DAGNode* node, const RegAllocResult& al
// 指令发射
void RISCv32CodeGen::emit_instructions(DAGNode* node, std::vector<std::string>& insts, const RegAllocResult& alloc) {
std::set<DAGNode*> emitted; // 局部变量,针对每个基本块独立
std::set<std::string> seen_insts; // 跟踪已发射的指令以去重
std::set<DAGNode*> emitted;
std::set<std::string> seen_insts;
std::function<void(DAGNode*)> emit = [&](DAGNode* n) {
if (!n || emitted.count(n)) return;
emitted.insert(n);
for (auto operand : n->operands) {
emit(operand);
if (operand) emit(operand);
}
if (!n->inst.empty()) {
std::stringstream ss(n->inst);
std::string line;
while (std::getline(ss, line, '\n')) {
if (!line.empty()) {
std::string new_line = line;
if (!n->result_reg.empty()) {
if (alloc.vreg_to_preg.count(n->result_reg)) {
new_line = std::regex_replace(new_line, std::regex("\\b" + n->result_reg + "\\b"), reg_to_string(alloc.vreg_to_preg.at(n->result_reg)));
} else if (alloc.stack_map.count(n->value)) {
// 清理空白和无效字符
line.erase(std::remove_if(line.begin(), line.end(), [](char c) { return std::isspace(c) || c == ':'; }), line.end());
if (line.empty()) continue;
std::string new_line = line;
// 替换结果寄存器
if (!n->result_reg.empty()) {
if (alloc.vreg_to_preg.count(n->result_reg)) {
new_line = std::regex_replace(new_line, std::regex("\\b" + n->result_reg + "\\b"), reg_to_string(alloc.vreg_to_preg.at(n->result_reg)));
} else {
// 如果虚拟寄存器未分配物理寄存器,使用 t0 并通过栈访问
if (n->value && alloc.stack_map.count(n->value)) {
int offset = alloc.stack_map.at(n->value);
new_line = std::regex_replace(new_line, std::regex("\\b" + n->result_reg + "\\b"), "t0");
std::string store = "sw t0, " + std::to_string(offset) + "(s0)";
if (!seen_insts.count(store)) {
insts.push_back(store);
seen_insts.insert(store);
}
}
}
for (auto operand : n->operands) {
if (!operand->result_reg.empty()) {
if (alloc.vreg_to_preg.count(operand->result_reg)) {
new_line = std::regex_replace(new_line, std::regex("\\b" + operand->result_reg + "\\b"), reg_to_string(alloc.vreg_to_preg.at(operand->result_reg)));
} else if (alloc.stack_map.count(operand->value)) {
int offset = alloc.stack_map.at(operand->value);
std::string load = "lw t0, " + std::to_string(offset) + "(s0)";
if (!seen_insts.count(load)) {
insts.push_back(load);
seen_insts.insert(load);
if (n->kind != DAGNode::STORE) { // 避免为 STORE 节点重复生成存储
std::string store = "sw t0, " + std::to_string(offset) + "(s0)";
if (!seen_insts.count(store)) {
insts.push_back(store);
seen_insts.insert(store);
}
new_line = std::regex_replace(new_line, std::regex("\\b" + operand->result_reg + "\\b"), "t0");
}
} else {
// 兜底:使用 t0
new_line = std::regex_replace(new_line, std::regex("\\b" + n->result_reg + "\\b"), "t0");
}
}
if (!seen_insts.count(new_line)) {
insts.push_back(new_line);
seen_insts.insert(new_line);
}
// 替换操作数寄存器
for (auto operand : n->operands) {
if (operand && !operand->result_reg.empty()) {
if (alloc.vreg_to_preg.count(operand->result_reg)) {
new_line = std::regex_replace(new_line, std::regex("\\b" + operand->result_reg + "\\b"), reg_to_string(alloc.vreg_to_preg.at(operand->result_reg)));
} else if (operand->value && alloc.stack_map.count(operand->value)) {
int offset = alloc.stack_map.at(operand->value);
std::string load = "lw t0, " + std::to_string(offset) + "(s0)";
if (!seen_insts.count(load)) {
insts.push_back(load);
seen_insts.insert(load);
}
new_line = std::regex_replace(new_line, std::regex("\\b" + operand->result_reg + "\\b"), "t0");
} else {
// 兜底:使用 t0
new_line = std::regex_replace(new_line, std::regex("\\b" + operand->result_reg + "\\b"), "t0");
}
}
}
if (!seen_insts.count(new_line)) {
insts.push_back(new_line);
seen_insts.insert(new_line);
}
}
}
};
@ -688,9 +791,10 @@ RISCv32CodeGen::RegAllocResult RISCv32CodeGen::register_allocation(Function* fun
// 分配溢出栈空间
for (const auto& pair : value_vreg_map) {
auto vreg = pair.second;
auto value = pair.first;
if (result.vreg_to_preg.find(vreg) == result.vreg_to_preg.end()) {
if (result.stack_map.find(pair.first) == result.stack_map.end()) {
result.stack_map[pair.first] = stack_offset;
if (result.stack_map.find(value) == result.stack_map.end()) {
result.stack_map[value] = stack_offset;
stack_offset += 4;
}
}