diff --git a/src/RISCv32Backend.cpp b/src/RISCv32Backend.cpp index 86cb0cc..8d7b028 100644 --- a/src/RISCv32Backend.cpp +++ b/src/RISCv32Backend.cpp @@ -124,13 +124,13 @@ std::vector> RISCv32CodeGen::build_dag( std::map value_to_node; static int vreg_counter = 0; - auto create_node = [&](DAGNode::NodeKind kind, Value* val = nullptr) { - if (val && value_to_node.find(val) != value_to_node.end()) { + auto create_node = [&](DAGNode::NodeKind kind, Value* val = nullptr) -> DAGNode* { + if (val && value_to_node.count(val) && kind != DAGNode::STORE && kind != DAGNode::RETURN && kind != DAGNode::BRANCH) { return value_to_node[val]; } auto node = std::make_unique(kind); node->value = val; - if (val && kind != DAGNode::STORE && kind != DAGNode::RETURN) { + if (val && kind != DAGNode::STORE && kind != DAGNode::RETURN && kind != DAGNode::BRANCH) { node->result_reg = "v" + std::to_string(vreg_counter++); value_vreg_map[val] = node->result_reg; value_to_node[val] = node.get(); @@ -146,51 +146,103 @@ std::vector> RISCv32CodeGen::build_dag( auto store_node = create_node(DAGNode::STORE); Value* val = store->getValue(); Value* ptr = store->getPointer(); - DAGNode* val_node = nullptr; - if (auto constant = dynamic_cast(val)) { - val_node = create_node(DAGNode::CONSTANT, val); - } else { - val_node = create_node(DAGNode::LOAD, val); + DAGNode* val_node = value_to_node.count(val) ? value_to_node[val] : nullptr; + if (!val_node) { + if (auto constant = dynamic_cast(val)) { + val_node = create_node(DAGNode::CONSTANT, val); + } else { + val_node = create_node(DAGNode::LOAD, val); + } } - auto ptr_node = create_node(DAGNode::CONSTANT, ptr); + auto ptr_node = value_to_node.count(ptr) ? value_to_node[ptr] : create_node(DAGNode::CONSTANT, ptr); store_node->operands.push_back(val_node); store_node->operands.push_back(ptr_node); val_node->users.push_back(store_node); ptr_node->users.push_back(store_node); } else if (auto load = dynamic_cast(inst.get())) { + if (value_to_node.count(load)) continue; // CSE auto load_node = create_node(DAGNode::LOAD, load); - auto ptr_node = create_node(DAGNode::CONSTANT, load->getPointer()); + auto ptr = load->getPointer(); + auto ptr_node = value_to_node.count(ptr) ? value_to_node[ptr] : create_node(DAGNode::CONSTANT, ptr); load_node->operands.push_back(ptr_node); ptr_node->users.push_back(load_node); } else if (auto bin = dynamic_cast(inst.get())) { + if (value_to_node.count(bin)) continue; // CSE auto bin_node = create_node(DAGNode::BINARY, bin); - auto lhs_node = create_node(DAGNode::LOAD, bin->getLhs()); - auto rhs_node = create_node(DAGNode::LOAD, bin->getRhs()); + auto lhs = bin->getLhs(); + auto rhs = bin->getRhs(); + auto lhs_node = value_to_node.count(lhs) ? value_to_node[lhs] : nullptr; + if (!lhs_node) { + if (auto constant = dynamic_cast(lhs)) { + lhs_node = create_node(DAGNode::CONSTANT, lhs); + } else { + lhs_node = create_node(DAGNode::LOAD, lhs); + } + } + auto rhs_node = value_to_node.count(rhs) ? value_to_node[rhs] : nullptr; + if (!rhs_node) { + if (auto constant = dynamic_cast(rhs)) { + rhs_node = create_node(DAGNode::CONSTANT, rhs); + } else { + rhs_node = create_node(DAGNode::LOAD, rhs); + } + } bin_node->operands.push_back(lhs_node); bin_node->operands.push_back(rhs_node); lhs_node->users.push_back(bin_node); rhs_node->users.push_back(bin_node); } else if (auto call = dynamic_cast(inst.get())) { + if (value_to_node.count(call)) continue; // CSE auto call_node = create_node(DAGNode::CALL, call); for (auto arg : call->getArguments()) { - auto arg_node = create_node(DAGNode::CONSTANT, arg->getValue()); + auto arg_val = arg->getValue(); + auto arg_node = value_to_node.count(arg_val) ? value_to_node[arg_val] : nullptr; + if (!arg_node) { + if (auto constant = dynamic_cast(arg_val)) { + arg_node = create_node(DAGNode::CONSTANT, arg_val); + } else { + arg_node = create_node(DAGNode::LOAD, arg_val); + } + } call_node->operands.push_back(arg_node); arg_node->users.push_back(call_node); } } else if (auto ret = dynamic_cast(inst.get())) { auto ret_node = create_node(DAGNode::RETURN); if (ret->hasReturnValue()) { - auto val_node = create_node(DAGNode::LOAD, ret->getReturnValue()); + auto val = ret->getReturnValue(); + auto val_node = value_to_node.count(val) ? value_to_node[val] : nullptr; + if (!val_node) { + if (auto constant = dynamic_cast(val)) { + val_node = create_node(DAGNode::CONSTANT, val); + } else { + val_node = create_node(DAGNode::LOAD, val); + } + } ret_node->operands.push_back(val_node); val_node->users.push_back(ret_node); } } else if (auto cond_br = dynamic_cast(inst.get())) { - auto br_node = create_node(DAGNode::BRANCH); + auto br_node = create_node(DAGNode::BRANCH, cond_br); auto cond = cond_br->getCondition(); - auto cond_node = value_to_node.count(cond) ? value_to_node[cond] : create_node(DAGNode::LOAD, cond); - br_node->operands.push_back(cond_node); - br_node->value = cond_br; // 存储 CondBrInst 以获取 then/else 块 - cond_node->users.push_back(br_node); + DAGNode* cond_node = nullptr; + if (auto constant = dynamic_cast(cond)) { + // 常量条件,直接记录分支目标 + br_node->operands.push_back(nullptr); // 占位符,表示无条件操作数 + br_node->inst = constant->getInt() ? "j " + cond_br->getThenBlock()->getName() + : "j " + cond_br->getElseBlock()->getName(); + } else { + cond_node = value_to_node.count(cond) ? value_to_node[cond] : nullptr; + if (!cond_node) { + if (auto bin = dynamic_cast(cond)) { + cond_node = create_node(DAGNode::BINARY, cond); + } else { + cond_node = create_node(DAGNode::LOAD, cond); + } + } + br_node->operands.push_back(cond_node); + cond_node->users.push_back(br_node); + } } } @@ -277,7 +329,9 @@ void RISCv32CodeGen::select_instructions(DAGNode* node, const RegAllocResult& al if (!node->inst.empty()) return; for (auto operand : node->operands) { - select_instructions(operand, alloc); + if (operand) { + select_instructions(operand, alloc); + } } switch (node->kind) { @@ -298,38 +352,65 @@ void RISCv32CodeGen::select_instructions(DAGNode* node, const RegAllocResult& al break; } case DAGNode::LOAD: { + if (node->operands.empty() || !node->operands[0]) { + if (auto constant = dynamic_cast(node->value)) { + node->inst = "li " + node->result_reg + ", " + std::to_string(constant->getInt()); + } + break; + } auto ptr = node->operands[0]->value; + if (!ptr) break; if (alloc.stack_map.count(ptr)) { int offset = alloc.stack_map.at(ptr); node->inst = "lw " + node->result_reg + ", " + std::to_string(offset) + "(s0)"; } else { auto ptr_reg = node->operands[0]->result_reg; - node->inst = "lw " + node->result_reg + ", 0(" + ptr_reg + ")"; + if (!ptr_reg.empty()) { + node->inst = "lw " + node->result_reg + ", 0(" + ptr_reg + ")"; + } } break; } case DAGNode::STORE: { + if (node->operands.size() < 2 || !node->operands[0] || !node->operands[1]) break; auto val = node->operands[0]->value; auto ptr = node->operands[1]->value; - auto val_reg = node->operands[0]->result_reg; + if (!val || !ptr) break; if (alloc.stack_map.count(ptr)) { int offset = alloc.stack_map.at(ptr); - node->inst = "sw " + val_reg + ", " + std::to_string(offset) + "(s0)"; + if (auto constant = dynamic_cast(val)) { + // 直接存储常量 + node->inst = "li t0, " + std::to_string(constant->getInt()) + "\nsw t0, " + std::to_string(offset) + "(s0)"; + } else { + auto val_reg = node->operands[0]->result_reg; + if (!val_reg.empty()) { + node->inst = "sw " + val_reg + ", " + std::to_string(offset) + "(s0)"; + } + } } else { auto ptr_reg = node->operands[1]->result_reg; - node->inst = "sw " + val_reg + ", 0(" + ptr_reg + ")"; + auto val_reg = node->operands[0]->result_reg; + if (!ptr_reg.empty() && !val_reg.empty()) { + node->inst = "sw " + val_reg + ", 0(" + ptr_reg + ")"; + } } break; } case DAGNode::BINARY: { + if (node->operands.size() < 2 || !node->operands[0] || !node->operands[1]) break; auto bin = dynamic_cast(node->value); + if (!bin) break; auto lhs_reg = node->operands[0]->result_reg; auto rhs_reg = node->operands[1]->result_reg; + if (lhs_reg.empty() || rhs_reg.empty()) break; std::string opcode; switch (bin->getKind()) { case BinaryInst::kAdd: opcode = "add"; break; case BinaryInst::kSub: opcode = "sub"; break; case BinaryInst::kMul: opcode = "mul"; break; + case BinaryInst::kICmpEQ: + node->inst = "sub " + node->result_reg + ", " + lhs_reg + ", " + rhs_reg + "\nseqz " + node->result_reg + ", " + node->result_reg; + return; default: break; } if (!opcode.empty()) { @@ -338,10 +419,14 @@ void RISCv32CodeGen::select_instructions(DAGNode* node, const RegAllocResult& al break; } case DAGNode::CALL: { + if (!node->value) break; auto call = dynamic_cast(node->value); + if (!call) break; std::string insts; for (size_t i = 0; i < node->operands.size() && i < 8; ++i) { - insts += "mv a" + std::to_string(i) + ", " + node->operands[i]->result_reg + "\n"; + if (node->operands[i] && !node->operands[i]->result_reg.empty()) { + insts += "mv a" + std::to_string(i) + ", " + node->operands[i]->result_reg + "\n"; + } } insts += "jal " + call->getCallee()->getName(); if (call->getType()->isInt() || call->getType()->isFloat()) { @@ -351,17 +436,20 @@ void RISCv32CodeGen::select_instructions(DAGNode* node, const RegAllocResult& al break; } case DAGNode::RETURN: { - if (!node->operands.empty()) { + if (!node->operands.empty() && node->operands[0] && !node->operands[0]->result_reg.empty()) { node->inst = "mv a0, " + node->operands[0]->result_reg; } break; } case DAGNode::BRANCH: { auto br = dynamic_cast(node->value); + if (!br) break; + if (!node->inst.empty()) break; // 常量条件已预生成跳转 + if (node->operands.empty() || !node->operands[0]) break; auto cond_reg = node->operands[0]->result_reg; + if (cond_reg.empty()) break; auto then_block = br->getThenBlock()->getName(); auto else_block = br->getElseBlock()->getName(); - // 假设条件为 true 时跳转到 then,否则到 else node->inst = "bnez " + cond_reg + ", " + then_block + "\nj " + else_block; break; } @@ -371,56 +459,71 @@ void RISCv32CodeGen::select_instructions(DAGNode* node, const RegAllocResult& al // 指令发射 void RISCv32CodeGen::emit_instructions(DAGNode* node, std::vector& insts, const RegAllocResult& alloc) { - std::set emitted; // 局部变量,针对每个基本块独立 - std::set seen_insts; // 跟踪已发射的指令以去重 + std::set emitted; + std::set seen_insts; std::function emit = [&](DAGNode* n) { if (!n || emitted.count(n)) return; emitted.insert(n); for (auto operand : n->operands) { - emit(operand); + if (operand) emit(operand); } if (!n->inst.empty()) { std::stringstream ss(n->inst); std::string line; while (std::getline(ss, line, '\n')) { - if (!line.empty()) { - std::string new_line = line; - if (!n->result_reg.empty()) { - if (alloc.vreg_to_preg.count(n->result_reg)) { - new_line = std::regex_replace(new_line, std::regex("\\b" + n->result_reg + "\\b"), reg_to_string(alloc.vreg_to_preg.at(n->result_reg))); - } else if (alloc.stack_map.count(n->value)) { + // 清理空白和无效字符 + line.erase(std::remove_if(line.begin(), line.end(), [](char c) { return std::isspace(c) || c == ':'; }), line.end()); + if (line.empty()) continue; + + std::string new_line = line; + // 替换结果寄存器 + if (!n->result_reg.empty()) { + if (alloc.vreg_to_preg.count(n->result_reg)) { + new_line = std::regex_replace(new_line, std::regex("\\b" + n->result_reg + "\\b"), reg_to_string(alloc.vreg_to_preg.at(n->result_reg))); + } else { + // 如果虚拟寄存器未分配物理寄存器,使用 t0 并通过栈访问 + if (n->value && alloc.stack_map.count(n->value)) { int offset = alloc.stack_map.at(n->value); new_line = std::regex_replace(new_line, std::regex("\\b" + n->result_reg + "\\b"), "t0"); - std::string store = "sw t0, " + std::to_string(offset) + "(s0)"; - if (!seen_insts.count(store)) { - insts.push_back(store); - seen_insts.insert(store); - } - } - } - for (auto operand : n->operands) { - if (!operand->result_reg.empty()) { - if (alloc.vreg_to_preg.count(operand->result_reg)) { - new_line = std::regex_replace(new_line, std::regex("\\b" + operand->result_reg + "\\b"), reg_to_string(alloc.vreg_to_preg.at(operand->result_reg))); - } else if (alloc.stack_map.count(operand->value)) { - int offset = alloc.stack_map.at(operand->value); - std::string load = "lw t0, " + std::to_string(offset) + "(s0)"; - if (!seen_insts.count(load)) { - insts.push_back(load); - seen_insts.insert(load); + if (n->kind != DAGNode::STORE) { // 避免为 STORE 节点重复生成存储 + std::string store = "sw t0, " + std::to_string(offset) + "(s0)"; + if (!seen_insts.count(store)) { + insts.push_back(store); + seen_insts.insert(store); } - new_line = std::regex_replace(new_line, std::regex("\\b" + operand->result_reg + "\\b"), "t0"); } + } else { + // 兜底:使用 t0 + new_line = std::regex_replace(new_line, std::regex("\\b" + n->result_reg + "\\b"), "t0"); } } - if (!seen_insts.count(new_line)) { - insts.push_back(new_line); - seen_insts.insert(new_line); + } + // 替换操作数寄存器 + for (auto operand : n->operands) { + if (operand && !operand->result_reg.empty()) { + if (alloc.vreg_to_preg.count(operand->result_reg)) { + new_line = std::regex_replace(new_line, std::regex("\\b" + operand->result_reg + "\\b"), reg_to_string(alloc.vreg_to_preg.at(operand->result_reg))); + } else if (operand->value && alloc.stack_map.count(operand->value)) { + int offset = alloc.stack_map.at(operand->value); + std::string load = "lw t0, " + std::to_string(offset) + "(s0)"; + if (!seen_insts.count(load)) { + insts.push_back(load); + seen_insts.insert(load); + } + new_line = std::regex_replace(new_line, std::regex("\\b" + operand->result_reg + "\\b"), "t0"); + } else { + // 兜底:使用 t0 + new_line = std::regex_replace(new_line, std::regex("\\b" + operand->result_reg + "\\b"), "t0"); + } } } + if (!seen_insts.count(new_line)) { + insts.push_back(new_line); + seen_insts.insert(new_line); + } } } }; @@ -688,9 +791,10 @@ RISCv32CodeGen::RegAllocResult RISCv32CodeGen::register_allocation(Function* fun // 分配溢出栈空间 for (const auto& pair : value_vreg_map) { auto vreg = pair.second; + auto value = pair.first; if (result.vreg_to_preg.find(vreg) == result.vreg_to_preg.end()) { - if (result.stack_map.find(pair.first) == result.stack_map.end()) { - result.stack_map[pair.first] = stack_offset; + if (result.stack_map.find(value) == result.stack_map.end()) { + result.stack_map[value] = stack_offset; stack_offset += 4; } }