From 617244fae75195a1097677e46817b35d37eff3cc Mon Sep 17 00:00:00 2001 From: Lixuanwang Date: Tue, 24 Jun 2025 00:30:33 +0800 Subject: [PATCH] [backend] switch to simpler implementation for inst selection --- src/RISCv32Backend.cpp | 1202 +++++++++++++++++++++++----------------- src/RISCv32Backend.h | 53 +- 2 files changed, 735 insertions(+), 520 deletions(-) diff --git a/src/RISCv32Backend.cpp b/src/RISCv32Backend.cpp index 8d7b028..76ea80a 100644 --- a/src/RISCv32Backend.cpp +++ b/src/RISCv32Backend.cpp @@ -4,36 +4,56 @@ #include #include #include -#include +#include // For std::function namespace sysy { +// Available general-purpose registers for allocation const std::vector RISCv32CodeGen::allocable_regs = { PhysicalReg::T0, PhysicalReg::T1, PhysicalReg::T2, PhysicalReg::T3, PhysicalReg::T4, PhysicalReg::T5, PhysicalReg::T6, PhysicalReg::A0, PhysicalReg::A1, PhysicalReg::A2, PhysicalReg::A3, - PhysicalReg::A4, PhysicalReg::A5, PhysicalReg::A6, PhysicalReg::A7 + PhysicalReg::A4, PhysicalReg::A5, PhysicalReg::A6, PhysicalReg::A7, + PhysicalReg::S0, PhysicalReg::S1, PhysicalReg::S2, PhysicalReg::S3, + PhysicalReg::S4, PhysicalReg::S5, PhysicalReg::S6, PhysicalReg::S7, + PhysicalReg::S8, PhysicalReg::S9, PhysicalReg::S10, PhysicalReg::S11 }; std::string RISCv32CodeGen::reg_to_string(PhysicalReg reg) { switch (reg) { - case PhysicalReg::S0: return "s0"; - case PhysicalReg::T0: return "t0"; - case PhysicalReg::T1: return "t1"; - case PhysicalReg::T2: return "t2"; - case PhysicalReg::T3: return "t3"; - case PhysicalReg::T4: return "t4"; - case PhysicalReg::T5: return "t5"; - case PhysicalReg::T6: return "t6"; - case PhysicalReg::A0: return "a0"; - case PhysicalReg::A1: return "a1"; - case PhysicalReg::A2: return "a2"; - case PhysicalReg::A3: return "a3"; - case PhysicalReg::A4: return "a4"; - case PhysicalReg::A5: return "a5"; - case PhysicalReg::A6: return "a6"; - case PhysicalReg::A7: return "a7"; - default: return ""; + case PhysicalReg::ZERO: return "x0"; // zero register + case PhysicalReg::RA: return "ra"; + case PhysicalReg::SP: return "sp"; + case PhysicalReg::GP: return "gp"; + case PhysicalReg::TP: return "tp"; + case PhysicalReg::T0: return "t0"; + case PhysicalReg::T1: return "t1"; + case PhysicalReg::T2: return "t2"; + case PhysicalReg::S0: return "s0"; + case PhysicalReg::S1: return "s1"; + case PhysicalReg::A0: return "a0"; + case PhysicalReg::A1: return "a1"; + case PhysicalReg::A2: return "a2"; + case PhysicalReg::A3: return "a3"; + case PhysicalReg::A4: return "a4"; + case PhysicalReg::A5: return "a5"; + case PhysicalReg::A6: return "a6"; + case PhysicalReg::A7: return "a7"; + case PhysicalReg::S2: return "s2"; + case PhysicalReg::S3: return "s3"; + case PhysicalReg::S4: return "s4"; + case PhysicalReg::S5: return "s5"; + case PhysicalReg::S6: return "s6"; + case PhysicalReg::S7: return "s7"; + case PhysicalReg::S8: return "s8"; + case PhysicalReg::S9: return "s9"; + case PhysicalReg::S10: return "s10"; + case PhysicalReg::S11: return "s11"; + case PhysicalReg::T3: return "t3"; + case PhysicalReg::T4: return "t4"; + case PhysicalReg::T5: return "t5"; + case PhysicalReg::T6: return "t6"; + default: return "UNKNOWN_REG"; } } @@ -58,11 +78,11 @@ std::string RISCv32CodeGen::module_gen() { if (auto constant = dynamic_cast(val)) { for (unsigned j = 0; j < count; ++j) { if (constant->isInt()) { - ss << " .word " << constant->getInt() << "\n"; + ss << " .word " << constant->getInt() << "\n"; } else { float f = constant->getFloat(); uint32_t float_bits = *(uint32_t*)&f; - ss << " .word " << float_bits << "\n"; + ss << " .word " << float_bits << "\n"; } } } @@ -82,171 +102,263 @@ std::string RISCv32CodeGen::function_gen(Function* func) { std::stringstream ss; ss << ".globl " << func->getName() << "\n"; ss << func->getName() << ":\n"; + + // Perform register allocation for the entire function auto alloc = register_allocation(func); int stack_size = alloc.stack_size; - if (stack_size > 0) { - ss << " addi sp, sp, -" << stack_size << "\n"; - ss << " sw ra, " << (stack_size - 4) << "(sp)\n"; - ss << " sw s0, " << (stack_size - 8) << "(sp)\n"; - ss << " mv s0, sp\n"; + + // Prologue + // Save ra and s0, adjust stack pointer + // s0 points to the base of the current frame (sp after allocation for locals/spills) + // Stack layout: + // +------------+ high address + // | Arg Spill | + // +------------+ + // | Saved RA | (stack_size - 4)(sp) + // +------------+ + // | Saved S0 | (stack_size - 8)(sp) + // +------------+ + // | Locals/Spills | 0(sp) to (stack_size - 8 - 4*num_locals)(sp) + // +------------+ low address + // The given IR only has one alloca, so the stack_size should be at least 8 (for ra, s0) + 4 (for %c(0)) = 12, + // rounded up to 16 for 16-byte alignment, or 32 as in your example. + // For simplicity, let's keep it at a multiple of 16 for alignment. + int aligned_stack_size = (stack_size + 15) & ~15; // Align to 16 bytes + + if (aligned_stack_size > 0) { + ss << " addi sp, sp, -" << aligned_stack_size << "\n"; + ss << " sw ra, " << (aligned_stack_size - 4) << "(sp)\n"; + ss << " sw s0, " << (aligned_stack_size - 8) << "(sp)\n"; + ss << " mv s0, sp\n"; // Frame pointer points to new stack top } + + // Generate code for each basic block for (const auto& bb : func->getBasicBlocks()) { ss << basicBlock_gen(bb.get(), alloc); } - if (stack_size > 0) { - ss << " lw ra, " << (stack_size - 4) << "(sp)\n"; - ss << " lw s0, " << (stack_size - 8) << "(sp)\n"; - ss << " addi sp, sp, " << stack_size << "\n"; - } - ss << " ret\n"; + + // Epilogue is handled by the RETURN DAGNode's instruction selection return ss.str(); } std::string RISCv32CodeGen::basicBlock_gen(BasicBlock* bb, const RegAllocResult& alloc) { std::stringstream ss; ss << bb->getName() << ":\n"; + + // Clear value_vreg_map for each basic block to prevent virtual register reuse across blocks, + // unless they are explicitly live-in. This simplifies DAG building for single basic blocks. + value_vreg_map.clear(); + vreg_counter = 0; // Reset virtual register counter for each block + auto dag = build_dag(bb); print_dag(dag, bb->getName()); // Print DAG for debugging + std::vector insts; - for (auto& node : dag) { - select_instructions(node.get(), alloc); - emit_instructions(node.get(), insts, alloc); + std::set emitted_nodes; // Track emitted nodes to prevent duplicates + + // Emit instructions in reverse topological order (or dependency order) + // The `emit_instructions` function handles the recursive emission of operands. + for (auto it = dag.rbegin(); it != dag.rend(); ++it) { + select_instructions(it->get(), alloc); // Select instruction for this node + emit_instructions(it->get(), insts, alloc, emitted_nodes); // Emit this node and its dependencies } + + // Since `emit_instructions` adds in reverse order of emission (dependencies first), + // we need to reverse the collected instructions to get proper execution order. + std::reverse(insts.begin(), insts.end()); + for (const auto& inst : insts) { - ss << " " << inst << "\n"; + if (!inst.empty()) { // Ensure no empty lines are added + ss << " " << inst << "\n"; + } } return ss.str(); } // DAG 构建 std::vector> RISCv32CodeGen::build_dag(BasicBlock* bb) { - std::vector> nodes; - std::map value_to_node; - static int vreg_counter = 0; + std::vector> nodes_storage; // Stores all unique_ptr + std::map value_to_node; // Maps IR Value* to raw DAGNode* for quick lookup + // Helper to create a DAGNode and manage its ownership auto create_node = [&](DAGNode::NodeKind kind, Value* val = nullptr) -> DAGNode* { - if (val && value_to_node.count(val) && kind != DAGNode::STORE && kind != DAGNode::RETURN && kind != DAGNode::BRANCH) { + // Optimization: If a value already has a node and it's not a control flow/store, reuse it (CSE) + // For AllocaInst, we want to create a node representing its address, but not necessarily assign a vreg to the AllocaInst itself directly. + if (val && value_to_node.count(val) && kind != DAGNode::STORE && kind != DAGNode::RETURN && kind != DAGNode::BRANCH && kind != DAGNode::ALLOCA_ADDR) { return value_to_node[val]; } + auto node = std::make_unique(kind); node->value = val; - if (val && kind != DAGNode::STORE && kind != DAGNode::RETURN && kind != DAGNode::BRANCH) { - node->result_reg = "v" + std::to_string(vreg_counter++); - value_vreg_map[val] = node->result_reg; - value_to_node[val] = node.get(); + + // Assign a virtual register for values that produce a result + if (val && kind != DAGNode::STORE && kind != DAGNode::RETURN && kind != DAGNode::BRANCH && kind != DAGNode::ALLOCA_ADDR) { + node->result_vreg = "v" + std::to_string(vreg_counter++); + value_vreg_map[val] = node->result_vreg; // Map IR Value to its virtual register + } else if (kind == DAGNode::ALLOCA_ADDR) { + // For AllocaInst's address, we'll assign a vreg to the DAGNode itself, + // not the AllocaInst directly in value_vreg_map (as AllocaInst is memory location). + node->result_vreg = "v" + std::to_string(vreg_counter++); + // We might want to map the AllocaInst to its virtual register if it's treated like a pointer + // value_vreg_map[val] = node->result_vreg; // Consider if this is needed/correct } - nodes.push_back(std::move(node)); - return nodes.back().get(); + + DAGNode* raw_node_ptr = node.get(); + nodes_storage.push_back(std::move(node)); // Store unique_ptr + + // Map the IR Value to the created DAGNode only if it represents a computed value + if (val && kind != DAGNode::STORE && kind != DAGNode::RETURN && kind != DAGNode::BRANCH) { + value_to_node[val] = raw_node_ptr; + } + return raw_node_ptr; }; - for (const auto& inst : bb->getInstructions()) { - if (auto alloca = dynamic_cast(inst.get())) { - create_node(DAGNode::CONSTANT, alloca); - } else if (auto store = dynamic_cast(inst.get())) { + + for (const auto& inst_ptr : bb->getInstructions()) { + auto inst = inst_ptr.get(); + + if (auto alloca = dynamic_cast(inst)) { + // An AllocaInst itself doesn't produce a value in a register, + // but its address will be used by loads/stores. + // Create a node to represent the address of the allocated memory. + // This address will be an offset from s0 (frame pointer). + // We store the AllocaInst pointer in `value` field of DAGNode. + auto alloca_addr_node = create_node(DAGNode::ALLOCA_ADDR, alloca); + // This node will have a result_vreg that represents the computed address (s0 + offset) + // The actual offset will be determined during register allocation (stack_map). + } else if (auto store = dynamic_cast(inst)) { auto store_node = create_node(DAGNode::STORE); - Value* val = store->getValue(); - Value* ptr = store->getPointer(); - DAGNode* val_node = value_to_node.count(val) ? value_to_node[val] : nullptr; - if (!val_node) { - if (auto constant = dynamic_cast(val)) { - val_node = create_node(DAGNode::CONSTANT, val); - } else { - val_node = create_node(DAGNode::LOAD, val); - } + + // Get value to be stored + Value* val_to_store_ir = store->getValue(); + DAGNode* val_node = nullptr; + if (value_to_node.count(val_to_store_ir)) { + val_node = value_to_node[val_to_store_ir]; + } else if (auto constant = dynamic_cast(val_to_store_ir)) { + val_node = create_node(DAGNode::CONSTANT, constant); + } else { // It's a value that hasn't been computed yet in this block, assume it needs a load + val_node = create_node(DAGNode::LOAD, val_to_store_ir); } - auto ptr_node = value_to_node.count(ptr) ? value_to_node[ptr] : create_node(DAGNode::CONSTANT, ptr); + + // Get pointer to memory location + Value* ptr_ir = store->getPointer(); + DAGNode* ptr_node = nullptr; + if (value_to_node.count(ptr_ir)) { + ptr_node = value_to_node[ptr_ir]; + } else if (auto alloca = dynamic_cast(ptr_ir)) { + ptr_node = create_node(DAGNode::ALLOCA_ADDR, alloca); + } else if (auto global = dynamic_cast(ptr_ir)) { + ptr_node = create_node(DAGNode::CONSTANT, global); // Global address will be loaded + } else { // Must be a pointer held in a virtual register + ptr_node = create_node(DAGNode::LOAD, ptr_ir); // This is an instruction that produces a pointer + } + store_node->operands.push_back(val_node); store_node->operands.push_back(ptr_node); val_node->users.push_back(store_node); ptr_node->users.push_back(store_node); - } else if (auto load = dynamic_cast(inst.get())) { - if (value_to_node.count(load)) continue; // CSE - auto load_node = create_node(DAGNode::LOAD, load); - auto ptr = load->getPointer(); - auto ptr_node = value_to_node.count(ptr) ? value_to_node[ptr] : create_node(DAGNode::CONSTANT, ptr); + } else if (auto load = dynamic_cast(inst)) { + if (value_to_node.count(load)) continue; // Common Subexpression Elimination (CSE) + + auto load_node = create_node(DAGNode::LOAD, load); // Assigns result_vreg to load_node and maps load to it + + Value* ptr_ir = load->getPointer(); + DAGNode* ptr_node = nullptr; + if (value_to_node.count(ptr_ir)) { + ptr_node = value_to_node[ptr_ir]; + } else if (auto alloca = dynamic_cast(ptr_ir)) { + ptr_node = create_node(DAGNode::ALLOCA_ADDR, alloca); + } else if (auto global = dynamic_cast(ptr_ir)) { + ptr_node = create_node(DAGNode::CONSTANT, global); // Global address will be loaded + } else { // Must be a pointer held in a virtual register + ptr_node = create_node(DAGNode::LOAD, ptr_ir); // This is an instruction that produces a pointer + } + load_node->operands.push_back(ptr_node); ptr_node->users.push_back(load_node); - } else if (auto bin = dynamic_cast(inst.get())) { + } else if (auto bin = dynamic_cast(inst)) { if (value_to_node.count(bin)) continue; // CSE + auto bin_node = create_node(DAGNode::BINARY, bin); - auto lhs = bin->getLhs(); - auto rhs = bin->getRhs(); - auto lhs_node = value_to_node.count(lhs) ? value_to_node[lhs] : nullptr; - if (!lhs_node) { - if (auto constant = dynamic_cast(lhs)) { - lhs_node = create_node(DAGNode::CONSTANT, lhs); + + auto get_operand_node = [&](Value* operand_ir) -> DAGNode* { + if (value_to_node.count(operand_ir)) { + return value_to_node[operand_ir]; + } else if (auto constant = dynamic_cast(operand_ir)) { + return create_node(DAGNode::CONSTANT, constant); } else { - lhs_node = create_node(DAGNode::LOAD, lhs); + // It's a value produced by another instruction or argument, assume it needs a load if not in map + return create_node(DAGNode::LOAD, operand_ir); } - } - auto rhs_node = value_to_node.count(rhs) ? value_to_node[rhs] : nullptr; - if (!rhs_node) { - if (auto constant = dynamic_cast(rhs)) { - rhs_node = create_node(DAGNode::CONSTANT, rhs); - } else { - rhs_node = create_node(DAGNode::LOAD, rhs); - } - } + }; + + DAGNode* lhs_node = get_operand_node(bin->getLhs()); + DAGNode* rhs_node = get_operand_node(bin->getRhs()); + bin_node->operands.push_back(lhs_node); bin_node->operands.push_back(rhs_node); lhs_node->users.push_back(bin_node); rhs_node->users.push_back(bin_node); - } else if (auto call = dynamic_cast(inst.get())) { - if (value_to_node.count(call)) continue; // CSE - auto call_node = create_node(DAGNode::CALL, call); + } else if (auto call = dynamic_cast(inst)) { + if (value_to_node.count(call)) continue; // CSE (if result is reused) + + auto call_node = create_node(DAGNode::CALL, call); // Assigns result_vreg if call returns a value + for (auto arg : call->getArguments()) { - auto arg_val = arg->getValue(); - auto arg_node = value_to_node.count(arg_val) ? value_to_node[arg_val] : nullptr; - if (!arg_node) { - if (auto constant = dynamic_cast(arg_val)) { - arg_node = create_node(DAGNode::CONSTANT, arg_val); - } else { - arg_node = create_node(DAGNode::LOAD, arg_val); - } + auto arg_val_ir = arg->getValue(); + DAGNode* arg_node = nullptr; + if (value_to_node.count(arg_val_ir)) { + arg_node = value_to_node[arg_val_ir]; + } else if (auto constant = dynamic_cast(arg_val_ir)) { + arg_node = create_node(DAGNode::CONSTANT, constant); + } else { + arg_node = create_node(DAGNode::LOAD, arg_val_ir); } call_node->operands.push_back(arg_node); arg_node->users.push_back(call_node); } - } else if (auto ret = dynamic_cast(inst.get())) { + } else if (auto ret = dynamic_cast(inst)) { auto ret_node = create_node(DAGNode::RETURN); if (ret->hasReturnValue()) { - auto val = ret->getReturnValue(); - auto val_node = value_to_node.count(val) ? value_to_node[val] : nullptr; - if (!val_node) { - if (auto constant = dynamic_cast(val)) { - val_node = create_node(DAGNode::CONSTANT, val); - } else { - val_node = create_node(DAGNode::LOAD, val); - } + auto val_ir = ret->getReturnValue(); + DAGNode* val_node = nullptr; + if (value_to_node.count(val_ir)) { + val_node = value_to_node[val_ir]; + } else if (auto constant = dynamic_cast(val_ir)) { + val_node = create_node(DAGNode::CONSTANT, constant); + } else { + val_node = create_node(DAGNode::LOAD, val_ir); } ret_node->operands.push_back(val_node); val_node->users.push_back(ret_node); } - } else if (auto cond_br = dynamic_cast(inst.get())) { + } else if (auto cond_br = dynamic_cast(inst)) { auto br_node = create_node(DAGNode::BRANCH, cond_br); - auto cond = cond_br->getCondition(); - DAGNode* cond_node = nullptr; - if (auto constant = dynamic_cast(cond)) { - // 常量条件,直接记录分支目标 - br_node->operands.push_back(nullptr); // 占位符,表示无条件操作数 - br_node->inst = constant->getInt() ? "j " + cond_br->getThenBlock()->getName() - : "j " + cond_br->getElseBlock()->getName(); + auto cond_ir = cond_br->getCondition(); + + if (auto constant_cond = dynamic_cast(cond_ir)) { + // Optimize constant conditional branch to unconditional jump + br_node->inst = "j " + (constant_cond->getInt() ? cond_br->getThenBlock()->getName() : cond_br->getElseBlock()->getName()); + // No operands needed for a direct jump } else { - cond_node = value_to_node.count(cond) ? value_to_node[cond] : nullptr; - if (!cond_node) { - if (auto bin = dynamic_cast(cond)) { - cond_node = create_node(DAGNode::BINARY, cond); - } else { - cond_node = create_node(DAGNode::LOAD, cond); - } + DAGNode* cond_node = nullptr; + if (value_to_node.count(cond_ir)) { + cond_node = value_to_node[cond_ir]; + } else if (auto bin_cond = dynamic_cast(cond_ir)) { + cond_node = create_node(DAGNode::BINARY, bin_cond); + } else { // Must be a value that needs to be loaded + cond_node = create_node(DAGNode::LOAD, cond_ir); } br_node->operands.push_back(cond_node); cond_node->users.push_back(br_node); } + } else if (auto uncond_br = dynamic_cast(inst)) { + auto br_node = create_node(DAGNode::BRANCH, uncond_br); + br_node->inst = "j " + uncond_br->getBlock()->getName(); } } - return nodes; + return nodes_storage; } // 打印 DAG @@ -254,35 +366,28 @@ void RISCv32CodeGen::print_dag(const std::vector>& dag, std::cerr << "=== DAG for Basic Block: " << bb_name << " ===\n"; std::set visited; - // 显式声明 print_node 的类型为 std::function - std::function print_node = [&](DAGNode* node, int indent, int node_index) { - if (!node || visited.find(node) != visited.end()) { - if (node) { - std::cerr << std::string(indent, ' ') << "Node@" << node_index << ": (already printed)\n"; - } - return; - } - visited.insert(node); + // Helper map to assign sequential IDs for nodes in print output + std::map node_to_id; + int current_id = 0; + for (const auto& node_ptr : dag) { + node_to_id[node_ptr.get()] = current_id++; + } - std::string kind_str; - switch (node->kind) { - case DAGNode::CONSTANT: kind_str = "CONSTANT"; break; - case DAGNode::LOAD: kind_str = "LOAD"; break; - case DAGNode::STORE: kind_str = "STORE"; break; - case DAGNode::BINARY: kind_str = "BINARY"; break; - case DAGNode::CALL: kind_str = "CALL"; break; - case DAGNode::RETURN: kind_str = "RETURN"; break; - } + std::function print_node = [&](DAGNode* node, int indent) { + if (!node) return; - std::cerr << std::string(indent, ' ') << "Node@" << node_index << ": " << kind_str; - if (!node->result_reg.empty()) { - std::cerr << " (vreg: " << node->result_reg << ")"; + std::string current_indent(indent, ' '); + int node_id = node_to_id.count(node) ? node_to_id[node] : -1; // Get assigned ID + + std::cerr << current_indent << "Node#" << node_id << ": " << node->getNodeKindString(); + if (!node->result_vreg.empty()) { + std::cerr << " (vreg: " << node->result_vreg << ")"; } if (node->value) { std::cerr << " ["; if (auto inst = dynamic_cast(node->value)) { - std::cerr << inst->getKindString(); // 修复:getKindStr -> getKindString + std::cerr << inst->getKindString(); } else if (auto constant = dynamic_cast(node->value)) { if (constant->isInt()) { std::cerr << "ConstInt(" << constant->getInt() << ")"; @@ -291,108 +396,153 @@ void RISCv32CodeGen::print_dag(const std::vector>& dag, } } else if (auto global = dynamic_cast(node->value)) { std::cerr << "Global(" << global->getName() << ")"; - } else { - std::cerr << "Value"; + } else if (auto alloca = dynamic_cast(node->value)) { + std::cerr << "Alloca(" << (alloca->getName().empty() ? ("%" + std::to_string(reinterpret_cast(alloca) % 1000)) : alloca->getName()) << ")"; } std::cerr << "]"; } + std::cerr << " -> Inst: \"" << node->inst << "\""; // Print selected instruction std::cerr << "\n"; - if (!node->operands.empty()) { - std::cerr << std::string(indent, ' ') << " Operands:\n"; - int op_index = 0; - for (auto operand : node->operands) { - std::cerr << std::string(indent + 2, ' ') << "Op" << op_index++ << ": "; - print_node(operand, indent + 4, reinterpret_cast(operand) % 1000); - } + if (visited.find(node) != visited.end()) { + std::cerr << current_indent << " (already printed descendants)\n"; + return; // Avoid infinite recursion for cycles } + visited.insert(node); - if (!node->users.empty()) { - std::cerr << std::string(indent, ' ') << " Users:\n"; - int user_index = 0; - for (auto user : node->users) { - std::cerr << std::string(indent + 2, ' ') << "User" << user_index++ << ": "; - print_node(user, indent + 4, reinterpret_cast(user) % 1000); + + if (!node->operands.empty()) { + std::cerr << current_indent << " Operands:\n"; + for (auto operand : node->operands) { + print_node(operand, indent + 4); } } + // Removed users print to simplify output and avoid redundant recursion in a DAG. + // Users are more for upward traversal, not downward. }; - int node_index = 0; - for (const auto& node : dag) { - print_node(node.get(), 0, node_index++); + // Iterate through the DAG in a way that respects dependencies if possible, + // or just print all root nodes (nodes with no users within the DAG, or nodes representing side effects like store/branch/return) + // For simplicity, let's just print all nodes in the order they were created for now. + // A proper DAG traversal for printing would be a reverse topological sort. + for (const auto& node_ptr : dag) { + if (node_ptr->users.empty() || node_ptr->kind == DAGNode::STORE || node_ptr->kind == DAGNode::RETURN || node_ptr->kind == DAGNode::BRANCH) { + // Treat nodes with no users or side-effect nodes as roots for printing + visited.clear(); // Reset visited for each root to allow re-printing shared subgraphs + print_node(node_ptr.get(), 0); + } } std::cerr << "=== End DAG ===\n\n"; } + // 指令选择 void RISCv32CodeGen::select_instructions(DAGNode* node, const RegAllocResult& alloc) { - if (!node->inst.empty()) return; + if (!node) return; + if (!node->inst.empty() && node->kind != DAGNode::ALLOCA_ADDR) return; // Instruction already selected (except for ALLOCA_ADDR which doesn't directly map to an instruction) + // Recursively select instructions for operands first for (auto operand : node->operands) { if (operand) { select_instructions(operand, alloc); } } + std::stringstream ss_inst; // Use a stringstream to build instructions + + // Get assigned physical registers for virtual registers, or a temporary if not allocated + auto get_preg_or_temp = [&](const std::string& vreg) { + if (alloc.vreg_to_preg.count(vreg)) { + return reg_to_string(alloc.vreg_to_preg.at(vreg)); + } + // If a vreg isn't allocated to a physical register, it implies it's spilled or a temporary. + // For simplicity, we can use a fixed temporary register like t0 for spilled values + // or for immediate values that are loaded into a register just for the instruction. + // A more robust backend would handle spilling explicitly here. + return reg_to_string(PhysicalReg::T0); // Fallback to a temporary register + }; + + // Get memory offset for allocated stack variables + auto get_stack_offset = [&](Value* val) { + if (alloc.stack_map.count(val)) { + return std::to_string(alloc.stack_map.at(val)); + } + return std::string("0"); // Default or error + }; + switch (node->kind) { case DAGNode::CONSTANT: { if (auto constant = dynamic_cast(node->value)) { + std::string dest_reg = get_preg_or_temp(node->result_vreg); if (constant->isInt()) { - node->inst = "li " + node->result_reg + ", " + std::to_string(constant->getInt()); + ss_inst << "li " << dest_reg << ", " << constant->getInt(); } else { float f = constant->getFloat(); uint32_t float_bits = *(uint32_t*)&f; - node->inst = "li " + node->result_reg + ", " + std::to_string(float_bits) + "\nfmv.w.x " + node->result_reg + ", " + node->result_reg; + // For floating point, load integer bits then convert + ss_inst << "li " << dest_reg << ", " << float_bits << "\n"; + ss_inst << "fmv.w.x " << dest_reg << ", " << dest_reg; // Assumes dest_reg can be used for both int and float (not always true for FPU regs) } } else if (auto global = dynamic_cast(node->value)) { - node->inst = "la " + node->result_reg + ", " + global->getName(); - } else if (auto alloca = dynamic_cast(node->value)) { - node->inst = ""; + std::string dest_reg = get_preg_or_temp(node->result_vreg); + ss_inst << "la " << dest_reg << ", " << global->getName(); // Load address of global + } + break; + } + case DAGNode::ALLOCA_ADDR: { + // For AllocaInst, we want to compute its address (s0 + offset) and put it into result_vreg + if (auto alloca_inst = dynamic_cast(node->value)) { + std::string dest_reg = get_preg_or_temp(node->result_vreg); + int offset = alloc.stack_map.at(alloca_inst); + // The frame pointer s0 already points to the base of the current frame. + // The offset is relative to s0. + ss_inst << "addi " << dest_reg << ", s0, " << offset; } break; } case DAGNode::LOAD: { - if (node->operands.empty() || !node->operands[0]) { - if (auto constant = dynamic_cast(node->value)) { - node->inst = "li " + node->result_reg + ", " + std::to_string(constant->getInt()); - } - break; - } - auto ptr = node->operands[0]->value; - if (!ptr) break; - if (alloc.stack_map.count(ptr)) { - int offset = alloc.stack_map.at(ptr); - node->inst = "lw " + node->result_reg + ", " + std::to_string(offset) + "(s0)"; + if (node->operands.empty() || !node->operands[0]) break; + std::string dest_reg = get_preg_or_temp(node->result_vreg); + DAGNode* ptr_node = node->operands[0]; // The operand is the pointer + + // Check if the pointer itself is an AllocaInst + if (auto alloca_inst = dynamic_cast(ptr_node->value)) { + int offset = alloc.stack_map.at(alloca_inst); + ss_inst << "lw " << dest_reg << ", " << offset << "(s0)"; } else { - auto ptr_reg = node->operands[0]->result_reg; - if (!ptr_reg.empty()) { - node->inst = "lw " + node->result_reg + ", 0(" + ptr_reg + ")"; - } + // Pointer is in a register (possibly a global address, or a result of a GEP/other calc) + std::string ptr_reg = get_preg_or_temp(ptr_node->result_vreg); + ss_inst << "lw " << dest_reg << ", 0(" << ptr_reg << ")"; // Load from address in ptr_reg } break; } case DAGNode::STORE: { if (node->operands.size() < 2 || !node->operands[0] || !node->operands[1]) break; - auto val = node->operands[0]->value; - auto ptr = node->operands[1]->value; - if (!val || !ptr) break; - if (alloc.stack_map.count(ptr)) { - int offset = alloc.stack_map.at(ptr); - if (auto constant = dynamic_cast(val)) { - // 直接存储常量 - node->inst = "li t0, " + std::to_string(constant->getInt()) + "\nsw t0, " + std::to_string(offset) + "(s0)"; - } else { - auto val_reg = node->operands[0]->result_reg; - if (!val_reg.empty()) { - node->inst = "sw " + val_reg + ", " + std::to_string(offset) + "(s0)"; - } + DAGNode* val_node = node->operands[0]; + DAGNode* ptr_node = node->operands[1]; + + std::string src_reg; + if (val_node->kind == DAGNode::CONSTANT) { + // If storing a constant, load it into a temporary register first + if (auto constant = dynamic_cast(val_node->value)) { + src_reg = reg_to_string(PhysicalReg::T0); // Use a temporary for constant + ss_inst << "li " << src_reg << ", " << constant->getInt() << "\n"; + } else { // Global address being stored + src_reg = reg_to_string(PhysicalReg::T0); // Use a temporary + ss_inst << "la " << src_reg << ", " << dynamic_cast(val_node->value)->getName() << "\n"; } } else { - auto ptr_reg = node->operands[1]->result_reg; - auto val_reg = node->operands[0]->result_reg; - if (!ptr_reg.empty() && !val_reg.empty()) { - node->inst = "sw " + val_reg + ", 0(" + ptr_reg + ")"; - } + src_reg = get_preg_or_temp(val_node->result_vreg); + } + + // Check if the pointer is an AllocaInst (stack variable) + if (auto alloca_inst = dynamic_cast(ptr_node->value)) { + int offset = alloc.stack_map.at(alloca_inst); + ss_inst << "sw " << src_reg << ", " << offset << "(s0)"; + } else { + // Pointer is in a register (possibly a global address, or a result of a GEP/other calc) + std::string ptr_reg = get_preg_or_temp(ptr_node->result_vreg); + ss_inst << "sw " << src_reg << ", 0(" << ptr_reg << ")"; } break; } @@ -400,21 +550,55 @@ void RISCv32CodeGen::select_instructions(DAGNode* node, const RegAllocResult& al if (node->operands.size() < 2 || !node->operands[0] || !node->operands[1]) break; auto bin = dynamic_cast(node->value); if (!bin) break; - auto lhs_reg = node->operands[0]->result_reg; - auto rhs_reg = node->operands[1]->result_reg; - if (lhs_reg.empty() || rhs_reg.empty()) break; + + std::string dest_reg = get_preg_or_temp(node->result_vreg); + std::string lhs_reg = get_preg_or_temp(node->operands[0]->result_vreg); + std::string rhs_reg = get_preg_or_temp(node->operands[1]->result_vreg); + std::string opcode; switch (bin->getKind()) { case BinaryInst::kAdd: opcode = "add"; break; case BinaryInst::kSub: opcode = "sub"; break; case BinaryInst::kMul: opcode = "mul"; break; - case BinaryInst::kICmpEQ: - node->inst = "sub " + node->result_reg + ", " + lhs_reg + ", " + rhs_reg + "\nseqz " + node->result_reg + ", " + node->result_reg; + case BinaryInst::kICmpEQ: // Implement A == B as sub A, B; seqz D, (A-B) + ss_inst << "sub " << dest_reg << ", " << lhs_reg << ", " << rhs_reg << "\n"; + ss_inst << "seqz " << dest_reg << ", " << dest_reg; // set equal to zero + node->inst = ss_inst.str(); // Set instruction and return return; - default: break; + // Add more binary operations here (e.g., div, sge, slt, etc.) + case Instruction::kDiv: opcode = "div"; break; // Integer division + case Instruction::kRem: opcode = "rem"; break; // Integer remainder + case BinaryInst::kICmpGE: // A >= B <=> !(A < B) <=> !(slt D, A, B) <=> slt D, A, B; xori D, D, 1 + ss_inst << "slt " << dest_reg << ", " << lhs_reg << ", " << rhs_reg << "\n"; + ss_inst << "xori " << dest_reg << ", " << dest_reg << ", 1"; + node->inst = ss_inst.str(); + return; + case BinaryInst::kICmpGT: // A > B <=> B < A + opcode = "slt"; + ss_inst << opcode << " " << dest_reg << ", " << rhs_reg << ", " << lhs_reg; // slt rd, rs2, rs1 (if rs2 < rs1) + node->inst = ss_inst.str(); + return; + case BinaryInst::kICmpLE: // A <= B <=> !(A > B) <=> !(slt D, B, A) <=> slt D, B, A; xori D, D, 1 + ss_inst << "slt " << dest_reg << ", " << rhs_reg << ", " << lhs_reg << "\n"; + ss_inst << "xori " << dest_reg << ", " << dest_reg << ", 1"; + node->inst = ss_inst.str(); + return; + case BinaryInst::kICmpLT: // A < B + opcode = "slt"; + ss_inst << opcode << " " << dest_reg << ", " << lhs_reg << ", " << rhs_reg; + node->inst = ss_inst.str(); + return; + case BinaryInst::kICmpNE: // A != B <=> ! (A == B) <=> sub D, A, B; snez D, D + ss_inst << "sub " << dest_reg << ", " << lhs_reg << ", " << rhs_reg << "\n"; + ss_inst << "snez " << dest_reg << ", " << dest_reg; + node->inst = ss_inst.str(); + return; + default: + // Handle unknown binary ops or throw error + throw std::runtime_error("Unsupported binary instruction kind: " + bin->getKindString()); } if (!opcode.empty()) { - node->inst = opcode + " " + node->result_reg + ", " + lhs_reg + ", " + rhs_reg; + ss_inst << opcode << " " << dest_reg << ", " << lhs_reg << ", " << rhs_reg; } break; } @@ -422,192 +606,263 @@ void RISCv32CodeGen::select_instructions(DAGNode* node, const RegAllocResult& al if (!node->value) break; auto call = dynamic_cast(node->value); if (!call) break; - std::string insts; + + // Pass arguments in a0-a7 for (size_t i = 0; i < node->operands.size() && i < 8; ++i) { - if (node->operands[i] && !node->operands[i]->result_reg.empty()) { - insts += "mv a" + std::to_string(i) + ", " + node->operands[i]->result_reg + "\n"; + if (node->operands[i] && !node->operands[i]->result_vreg.empty()) { + ss_inst << "mv " << reg_to_string(static_cast(static_cast(PhysicalReg::A0) + i)) + << ", " << get_preg_or_temp(node->operands[i]->result_vreg) << "\n"; + } else if (node->operands[i] && node->operands[i]->kind == DAGNode::CONSTANT) { + // Handle constant arguments directly loading into A-regs + if (auto const_val = dynamic_cast(node->operands[i]->value)) { + ss_inst << "li " << reg_to_string(static_cast(static_cast(PhysicalReg::A0) + i)) + << ", " << const_val->getInt() << "\n"; + } else if (auto global_val = dynamic_cast(node->operands[i]->value)) { + ss_inst << "la " << reg_to_string(static_cast(static_cast(PhysicalReg::A0) + i)) + << ", " << global_val->getName() << "\n"; + } } } - insts += "jal " + call->getCallee()->getName(); - if (call->getType()->isInt() || call->getType()->isFloat()) { - insts += "\nmv " + node->result_reg + ", a0"; + ss_inst << "call " << call->getCallee()->getName(); // Use 'call' pseudo-instruction + + // If function returns a value, move it from a0 to the result vreg + if ((call->getType()->isInt() || call->getType()->isFloat()) && !node->result_vreg.empty()) { + ss_inst << "\nmv " << get_preg_or_temp(node->result_vreg) << ", a0"; } - node->inst = insts; break; } case DAGNode::RETURN: { - if (!node->operands.empty() && node->operands[0] && !node->operands[0]->result_reg.empty()) { - node->inst = "mv a0, " + node->operands[0]->result_reg; + // If there's a return value, move it to a0 + if (!node->operands.empty() && node->operands[0]) { + std::string return_val_reg = get_preg_or_temp(node->operands[0]->result_vreg); + ss_inst << "mv a0, " << return_val_reg << "\n"; } + + // Epilogue: Restore s0, ra, and adjust sp + if (alloc.stack_size > 0) { + int aligned_stack_size = (alloc.stack_size + 15) & ~15; + ss_inst << "lw ra, " << (aligned_stack_size - 4) << "(sp)\n"; + ss_inst << "lw s0, " << (aligned_stack_size - 8) << "(sp)\n"; + ss_inst << "addi sp, sp, " << aligned_stack_size << "\n"; + } + ss_inst << "ret"; break; } case DAGNode::BRANCH: { auto br = dynamic_cast(node->value); - if (!br) break; - if (!node->inst.empty()) break; // 常量条件已预生成跳转 - if (node->operands.empty() || !node->operands[0]) break; - auto cond_reg = node->operands[0]->result_reg; - if (cond_reg.empty()) break; - auto then_block = br->getThenBlock()->getName(); - auto else_block = br->getElseBlock()->getName(); - node->inst = "bnez " + cond_reg + ", " + then_block + "\nj " + else_block; + auto uncond_br = dynamic_cast(node->value); + + if (node->inst.empty()) { // If not already a constant jump + if (br) { + if (node->operands.empty() || !node->operands[0]) break; + std::string cond_reg = get_preg_or_temp(node->operands[0]->result_vreg); + std::string then_block = br->getThenBlock()->getName(); + std::string else_block = br->getElseBlock()->getName(); + ss_inst << "bnez " << cond_reg << ", " << then_block << "\n"; // Branch if not zero (true) + ss_inst << "j " << else_block; // Unconditional jump to else block + } else if (uncond_br) { + ss_inst << "j " << uncond_br->getBlock()->getName(); // Unconditional jump + } + } else { + // This branch node was optimized to a direct jump by constant propagation in build_dag + // Its 'inst' field is already set. Copy it. + ss_inst << node->inst; + } break; } - default: break; + default: + // For nodes that don't directly map to an instruction (like `alloca` itself, which is handled by its address node) + // or unhandled instruction types, leave inst empty. + break; + } + node->inst = ss_inst.str(); // Store the generated instruction(s) +} + +// 修改:优化指令发射,防止虚拟寄存器泄露,减少重复指令 +void RISCv32CodeGen::emit_instructions(DAGNode* node, std::vector& insts, const RegAllocResult& alloc, std::set& emitted_nodes) { + if (!node || emitted_nodes.count(node)) { + return; // Already emitted or null + } + + // Recursively emit operands first to ensure dependencies are met + for (auto operand : node->operands) { + if (operand) { + emit_instructions(operand, insts, alloc, emitted_nodes); + } + } + + // Mark current node as emitted + emitted_nodes.insert(node); + + // Split multi-line instructions and process each line + std::stringstream ss(node->inst); + std::string line; + std::set seen_insts_in_block; // Track instructions to avoid immediate duplicates within the stream + + while (std::getline(ss, line, '\n')) { + // Trim leading/trailing whitespace and remove potential labels from the start of the line + line = std::regex_replace(line, std::regex("^\\s*[^\\s:]*:\\s*"), ""); // Remove label if present (e.g., `label: inst`) + line = std::regex_replace(line, std::regex("^\\s+|\\s+$"), ""); // Trim whitespace + + if (line.empty()) continue; + + // Replace virtual registers with physical registers or handle spills + std::string processed_line = line; + + // Replace result_vreg (if it exists in this line) + if (!node->result_vreg.empty() && processed_line.find(node->result_vreg) != std::string::npos) { + std::string preg = reg_to_string(PhysicalReg::T0); // Default to T0 if not allocated + if (alloc.vreg_to_preg.count(node->result_vreg)) { + preg = reg_to_string(alloc.vreg_to_preg.at(node->result_vreg)); + } else if (node->value && alloc.stack_map.count(node->value)) { + // This means the result of this instruction would be spilled. + // We should generate a store instruction after the computation. + // For now, let's just use a temporary register for the computation result, + // and add a spill store after the instruction. + // NOTE: This is a simplified approach; a real spill strategy is more complex. + int offset = alloc.stack_map.at(node->value); + std::string spill_reg = reg_to_string(PhysicalReg::T0); // Use t0 for spill + std::string store_inst = "sw " + spill_reg + ", " + std::to_string(offset) + "(s0)"; + // We will replace the vreg with `spill_reg` in the current instruction. + // And then add the `store_inst` to the list. + processed_line = std::regex_replace(processed_line, std::regex("\\b" + node->result_vreg + "\\b"), spill_reg); + if (seen_insts_in_block.find(store_inst) == seen_insts_in_block.end()) { + insts.push_back(store_inst); + seen_insts_in_block.insert(store_inst); + } + // If the node itself is a store, no additional spill is needed, as it's directly storing. + if (node->kind == DAGNode::STORE) { // If it's a store instruction, no need to add another store + // It's possible that the value to be stored also came from a spilled vreg, + // in which case a load might be needed for the operand *before* the store instruction. + // This is handled by `emit_instructions` on operands. + } + } + if (processed_line.find(node->result_vreg) != std::string::npos) { // Still contains vreg after potential spill handling + processed_line = std::regex_replace(processed_line, std::regex("\\b" + node->result_vreg + "\\b"), preg); + } + } + + // Replace operand vregs (if they exist in this line) + for (auto operand : node->operands) { + if (operand && !operand->result_vreg.empty() && processed_line.find(operand->result_vreg) != std::string::npos) { + std::string operand_preg = reg_to_string(PhysicalReg::T0); + if (alloc.vreg_to_preg.count(operand->result_vreg)) { + operand_preg = reg_to_string(alloc.vreg_to_preg.at(operand->result_vreg)); + } else if (operand->value && alloc.stack_map.count(operand->value)) { + // This operand is spilled, load it into a temporary register (t0) before use. + int offset = alloc.stack_map.at(operand->value); + std::string load_inst = "lw " + reg_to_string(PhysicalReg::T0) + ", " + std::to_string(offset) + "(s0)"; + if (seen_insts_in_block.find(load_inst) == seen_insts_in_block.end()) { + insts.push_back(load_inst); + seen_insts_in_block.insert(load_inst); + } + operand_preg = reg_to_string(PhysicalReg::T0); // Use t0 as the source for this instruction + } + processed_line = std::regex_replace(processed_line, std::regex("\\b" + operand->result_vreg + "\\b"), operand_preg); + } + } + + // Add the processed line if not already added + if (seen_insts_in_block.find(processed_line) == seen_insts_in_block.end()) { + insts.push_back(processed_line); + seen_insts_in_block.insert(processed_line); + } } } -// 指令发射 -void RISCv32CodeGen::emit_instructions(DAGNode* node, std::vector& insts, const RegAllocResult& alloc) { - std::set emitted; - std::set seen_insts; - - std::function emit = [&](DAGNode* n) { - if (!n || emitted.count(n)) return; - emitted.insert(n); - - for (auto operand : n->operands) { - if (operand) emit(operand); - } - - if (!n->inst.empty()) { - std::stringstream ss(n->inst); - std::string line; - while (std::getline(ss, line, '\n')) { - // 清理空白和无效字符 - line.erase(std::remove_if(line.begin(), line.end(), [](char c) { return std::isspace(c) || c == ':'; }), line.end()); - if (line.empty()) continue; - - std::string new_line = line; - // 替换结果寄存器 - if (!n->result_reg.empty()) { - if (alloc.vreg_to_preg.count(n->result_reg)) { - new_line = std::regex_replace(new_line, std::regex("\\b" + n->result_reg + "\\b"), reg_to_string(alloc.vreg_to_preg.at(n->result_reg))); - } else { - // 如果虚拟寄存器未分配物理寄存器,使用 t0 并通过栈访问 - if (n->value && alloc.stack_map.count(n->value)) { - int offset = alloc.stack_map.at(n->value); - new_line = std::regex_replace(new_line, std::regex("\\b" + n->result_reg + "\\b"), "t0"); - if (n->kind != DAGNode::STORE) { // 避免为 STORE 节点重复生成存储 - std::string store = "sw t0, " + std::to_string(offset) + "(s0)"; - if (!seen_insts.count(store)) { - insts.push_back(store); - seen_insts.insert(store); - } - } - } else { - // 兜底:使用 t0 - new_line = std::regex_replace(new_line, std::regex("\\b" + n->result_reg + "\\b"), "t0"); - } - } - } - // 替换操作数寄存器 - for (auto operand : n->operands) { - if (operand && !operand->result_reg.empty()) { - if (alloc.vreg_to_preg.count(operand->result_reg)) { - new_line = std::regex_replace(new_line, std::regex("\\b" + operand->result_reg + "\\b"), reg_to_string(alloc.vreg_to_preg.at(operand->result_reg))); - } else if (operand->value && alloc.stack_map.count(operand->value)) { - int offset = alloc.stack_map.at(operand->value); - std::string load = "lw t0, " + std::to_string(offset) + "(s0)"; - if (!seen_insts.count(load)) { - insts.push_back(load); - seen_insts.insert(load); - } - new_line = std::regex_replace(new_line, std::regex("\\b" + operand->result_reg + "\\b"), "t0"); - } else { - // 兜底:使用 t0 - new_line = std::regex_replace(new_line, std::regex("\\b" + operand->result_reg + "\\b"), "t0"); - } - } - } - if (!seen_insts.count(new_line)) { - insts.push_back(new_line); - seen_insts.insert(new_line); - } - } - } - }; - - emit(node); -} // 活跃性分析 std::map> RISCv32CodeGen::liveness_analysis(Function* func) { std::map> live_in, live_out; bool changed = true; + // Initialize all live_in/out sets to empty + for (const auto& bb : func->getBasicBlocks()) { + for (const auto& inst_ptr : bb->getInstructions()) { + live_in[inst_ptr.get()] = {}; + live_out[inst_ptr.get()] = {}; + } + } + while (changed) { changed = false; + // Iterate basic blocks in reverse post-order (or reverse natural order) for (auto it = func->getBasicBlocks_NoRange().rbegin(); it != func->getBasicBlocks_NoRange().rend(); ++it) { auto bb = it->get(); + // Iterate instructions in reverse order within the basic block for (auto inst_it = bb->getInstructions().rbegin(); inst_it != bb->getInstructions().rend(); ++inst_it) { auto inst = inst_it->get(); - std::set new_in, new_out; + std::set current_live_in = live_in[inst]; + std::set current_live_out = live_out[inst]; - // 计算 live_out - if (auto br = dynamic_cast(inst)) { - new_out.insert(live_in[br->getThenBlock()->getInstructions().front().get()].begin(), - live_in[br->getThenBlock()->getInstructions().front().get()].end()); - new_out.insert(live_in[br->getElseBlock()->getInstructions().front().get()].begin(), - live_in[br->getElseBlock()->getInstructions().front().get()].end()); - } else if (auto uncond = dynamic_cast(inst)) { - new_out.insert(live_in[uncond->getBlock()->getInstructions().front().get()].begin(), - live_in[uncond->getBlock()->getInstructions().front().get()].end()); + std::set new_live_out; + // Union of live_in from all successors + if (inst->isTerminator()) { + if (auto br = dynamic_cast(inst)) { + new_live_out.insert(live_in[br->getThenBlock()->getInstructions().front().get()].begin(), + live_in[br->getThenBlock()->getInstructions().front().get()].end()); + new_live_out.insert(live_in[br->getElseBlock()->getInstructions().front().get()].begin(), + live_in[br->getElseBlock()->getInstructions().front().get()].end()); + } else if (auto uncond = dynamic_cast(inst)) { + new_live_out.insert(live_in[uncond->getBlock()->getInstructions().front().get()].begin(), + live_in[uncond->getBlock()->getInstructions().front().get()].end()); + } + // Return instructions have no successors, so new_live_out remains empty } else { - auto next_inst = std::next(inst_it); - if (next_inst != bb->getInstructions().rend()) { - new_out = live_in[next_inst->get()]; + auto next_inst_it = std::next(inst_it); + if (next_inst_it != bb->getInstructions().rend()) { + // Not a terminator, so next instruction in block is successor + new_live_out = live_in[next_inst_it->get()]; } } - // 计算 use 和 def - std::set use, def; + // Calculate use and def sets for the current instruction + std::set use_set, def_set; + + // Define + if (value_vreg_map.count(inst)) { + def_set.insert(value_vreg_map.at(inst)); + } + + // Use if (auto bin = dynamic_cast(inst)) { - if (value_vreg_map.find(bin->getLhs()) != value_vreg_map.end()) - use.insert(value_vreg_map[bin->getLhs()]); - if (value_vreg_map.find(bin->getRhs()) != value_vreg_map.end()) - use.insert(value_vreg_map[bin->getRhs()]); - if (value_vreg_map.find(bin) != value_vreg_map.end()) - def.insert(value_vreg_map[bin]); + if (value_vreg_map.count(bin->getLhs())) use_set.insert(value_vreg_map.at(bin->getLhs())); + if (value_vreg_map.count(bin->getRhs())) use_set.insert(value_vreg_map.at(bin->getRhs())); } else if (auto call = dynamic_cast(inst)) { for (auto arg : call->getArguments()) { - if (value_vreg_map.find(arg->getValue()) != value_vreg_map.end()) - use.insert(value_vreg_map[arg->getValue()]); + if (value_vreg_map.count(arg->getValue())) use_set.insert(value_vreg_map.at(arg->getValue())); } - if ((call->getType()->isInt() || call->getType()->isFloat()) && value_vreg_map.find(call) != value_vreg_map.end()) - def.insert(value_vreg_map[call]); } else if (auto load = dynamic_cast(inst)) { - if (value_vreg_map.find(load->getPointer()) != value_vreg_map.end()) - use.insert(value_vreg_map[load->getPointer()]); - if (value_vreg_map.find(load) != value_vreg_map.end()) - def.insert(value_vreg_map[load]); + if (value_vreg_map.count(load->getPointer())) use_set.insert(value_vreg_map.at(load->getPointer())); } else if (auto store = dynamic_cast(inst)) { - if (value_vreg_map.find(store->getValue()) != value_vreg_map.end()) - use.insert(value_vreg_map[store->getValue()]); - if (value_vreg_map.find(store->getPointer()) != value_vreg_map.end()) - use.insert(value_vreg_map[store->getPointer()]); + if (value_vreg_map.count(store->getValue())) use_set.insert(value_vreg_map.at(store->getValue())); + if (value_vreg_map.count(store->getPointer())) use_set.insert(value_vreg_map.at(store->getPointer())); } else if (auto ret = dynamic_cast(inst)) { - if (ret->hasReturnValue() && value_vreg_map.find(ret->getReturnValue()) != value_vreg_map.end()) - use.insert(value_vreg_map[ret->getReturnValue()]); + if (ret->hasReturnValue() && value_vreg_map.count(ret->getReturnValue())) + use_set.insert(value_vreg_map.at(ret->getReturnValue())); + } else if (auto cond_br = dynamic_cast(inst)) { + if (value_vreg_map.count(cond_br->getCondition())) + use_set.insert(value_vreg_map.at(cond_br->getCondition())); } + // AllocaInst doesn't 'use' or 'def' a vreg directly, its address is a constant. - // 计算 live_in = use ∪ (live_out - def) - new_in = use; - for (const auto& vreg : new_out) { - if (def.find(vreg) == def.end()) { - new_in.insert(vreg); + // Calculate new live_in = use U (new_live_out - def) + std::set new_live_in = use_set; + for (const auto& vreg : new_live_out) { + if (def_set.find(vreg) == def_set.end()) { + new_live_in.insert(vreg); } } - if (live_in[inst] != new_in || live_out[inst] != new_out) { - live_in[inst] = new_in; - live_out[inst] = new_out; + // Check for convergence + if (new_live_in != current_live_in || new_live_out != current_live_out) { + live_in[inst] = new_live_in; + live_out[inst] = new_live_out; changed = true; } } } } - return live_in; } @@ -616,215 +871,152 @@ std::map> RISCv32CodeGen::build_interference_ const std::map>& live_sets) { std::map> graph; + // Ensure all vregs present in live_sets are in the graph initially + for (const auto& pair : live_sets) { + for (const auto& vreg : pair.second) { + graph[vreg] = {}; // Initialize empty set + } + } + for (const auto& pair : live_sets) { auto inst = pair.first; - const auto& live = pair.second; - std::string def; - std::set uses; + const auto& live_after_inst = pair.second; // This is actually live_in for the next instruction / basic block entry - if (auto bin = dynamic_cast(inst)) { - if (value_vreg_map.find(bin) != value_vreg_map.end()) - def = value_vreg_map[bin]; - if (value_vreg_map.find(bin->getLhs()) != value_vreg_map.end()) - uses.insert(value_vreg_map[bin->getLhs()]); - if (value_vreg_map.find(bin->getRhs()) != value_vreg_map.end()) - uses.insert(value_vreg_map[bin->getRhs()]); - } else if (auto call = dynamic_cast(inst)) { - if ((call->getType()->isInt() || call->getType()->isFloat()) && value_vreg_map.find(call) != value_vreg_map.end()) - def = value_vreg_map[call]; - for (auto arg : call->getArguments()) { - if (value_vreg_map.find(arg->getValue()) != value_vreg_map.end()) - uses.insert(value_vreg_map[arg->getValue()]); - } - } else if (auto load = dynamic_cast(inst)) { - if (value_vreg_map.find(load) != value_vreg_map.end()) - def = value_vreg_map[load]; - if (value_vreg_map.find(load->getPointer()) != value_vreg_map.end()) - uses.insert(value_vreg_map[load->getPointer()]); + std::string defined_vreg; + if (value_vreg_map.count(inst)) { + defined_vreg = value_vreg_map.at(inst); } - if (!def.empty()) { - for (const auto& live_vreg : live) { - if (live_vreg != def) { - graph[def].insert(live_vreg); - graph[live_vreg].insert(def); - } - } - } - - // 为 BinaryInst 的两个源操作数添加干扰边 - if (!uses.empty()) { - for (auto it1 = uses.begin(); it1 != uses.end(); ++it1) { - for (auto it2 = std::next(it1); it2 != uses.end(); ++it2) { - graph[*it1].insert(*it2); - graph[*it2].insert(*it1); + // Add edges from defined vreg to all other live vregs at this point + if (!defined_vreg.empty()) { + for (const auto& live_vreg : live_after_inst) { + if (live_vreg != defined_vreg) { // A vreg does not interfere with itself + graph[defined_vreg].insert(live_vreg); + graph[live_vreg].insert(defined_vreg); // Symmetric edge } } } } - return graph; } -// 图着色 +// 图着色 (简化版,贪婪着色) void RISCv32CodeGen::color_graph(std::map& vreg_to_preg, const std::map>& interference_graph) { - std::vector stack; - std::map> temp_graph = interference_graph; - std::map degree; + vreg_to_preg.clear(); // Clear previous mappings - // 计算每个节点的度 - for (const auto& pair : temp_graph) { - degree[pair.first] = pair.second.size(); + // Order virtual registers by degree (largest degree first) to improve coloring + std::vector> vreg_degrees; + for (const auto& entry : interference_graph) { + vreg_degrees.push_back({entry.first, (int)entry.second.size()}); } + std::sort(vreg_degrees.begin(), vreg_degrees.end(), + [](const auto& a, const auto& b) { return a.second > b.second; }); // Sort descending by degree - while (!temp_graph.empty()) { - std::string node_to_remove; - for (const auto& pair : temp_graph) { - if (degree[pair.first] < static_cast(allocable_regs.size())) { - node_to_remove = pair.first; - break; + for (const auto& vreg_deg_pair : vreg_degrees) { + const std::string& vreg = vreg_deg_pair.first; + std::set used_colors; // Physical registers used by neighbors + + // Collect colors of interfering neighbors + if (interference_graph.count(vreg)) { + for (const auto& neighbor_vreg : interference_graph.at(vreg)) { + if (vreg_to_preg.count(neighbor_vreg)) { + used_colors.insert(vreg_to_preg.at(neighbor_vreg)); + } } } - if (node_to_remove.empty()) { - // 选择度最小的节点 - node_to_remove = std::min_element(degree.begin(), degree.end(), - [](const auto& a, const auto& b) { return a.second < b.second; })->first; - } - - stack.push_back(node_to_remove); - for (auto& pair : temp_graph) { - if (pair.second.find(node_to_remove) != pair.second.end()) { - degree[pair.first]--; - } - pair.second.erase(node_to_remove); - } - temp_graph.erase(node_to_remove); - degree.erase(node_to_remove); - } - - while (!stack.empty()) { - auto vreg = stack.back(); - stack.pop_back(); - std::set used_colors; - for (const auto& neighbor : interference_graph.at(vreg)) { - if (vreg_to_preg.find(neighbor) != vreg_to_preg.end()) { - used_colors.insert(reg_to_string(vreg_to_preg[neighbor])); - } - } - - for (auto preg : allocable_regs) { - if (used_colors.find(reg_to_string(preg)) == used_colors.end()) { + // Find the first available color (physical register) + bool colored = false; + for (PhysicalReg preg : allocable_regs) { + if (used_colors.find(preg) == used_colors.end()) { vreg_to_preg[vreg] = preg; + colored = true; break; } } + + if (!colored) { + // Spilling: If no physical register is available, this virtual register must be spilled to memory. + // For this simplified example, we're not implementing full spilling. + // A common approach is to assign it a special "spilled" indicator and handle it in code gen. + // For now, we'll just not assign it a physical register, and `get_preg_or_temp` will use a default `t0` or trigger a stack load/store. + std::cerr << "Warning: Could not allocate a register for " << vreg << ". It will likely be spilled to stack.\n"; + // A more complete compiler would add this vreg to `alloc.stack_map` and manage stack offsets for it here. + } } } // 寄存器分配 RISCv32CodeGen::RegAllocResult RISCv32CodeGen::register_allocation(Function* func) { - RegAllocResult result; - int stack_offset = 0; - value_vreg_map.clear(); - static int vreg_counter = 0; + // 1. Phi 节点消除 (如果IR中有Phi节点,需要在活跃性分析前消除) + eliminate_phi(func); // Ensure this is called first - // 分配局部变量栈空间 - for (const auto& bb : func->getBasicBlocks()) { - for (const auto& inst : bb->getInstructions()) { - if (auto alloca = dynamic_cast(inst.get())) { - if (result.stack_map.find(alloca) == result.stack_map.end()) { - result.stack_map[alloca] = stack_offset; - value_vreg_map[alloca] = "v" + std::to_string(vreg_counter++); - stack_offset += 4; - } - } else if (auto load = dynamic_cast(inst.get())) { - if (value_vreg_map.find(load) == value_vreg_map.end()) { - value_vreg_map[load] = "v" + std::to_string(vreg_counter++); - } - } else if (auto bin = dynamic_cast(inst.get())) { - if (value_vreg_map.find(bin) == value_vreg_map.end()) { - value_vreg_map[bin] = "v" + std::to_string(vreg_counter++); - } - if (value_vreg_map.find(bin->getLhs()) == value_vreg_map.end()) { - value_vreg_map[bin->getLhs()] = "v" + std::to_string(vreg_counter++); - } - if (value_vreg_map.find(bin->getRhs()) == value_vreg_map.end()) { - value_vreg_map[bin->getRhs()] = "v" + std::to_string(vreg_counter++); - } - } else if (auto call = dynamic_cast(inst.get())) { - if ((call->getType()->isInt() || call->getType()->isFloat()) && value_vreg_map.find(call) == value_vreg_map.end()) { - value_vreg_map[call] = "v" + std::to_string(vreg_counter++); - } - } else if (auto store = dynamic_cast(inst.get())) { - if (value_vreg_map.find(store->getValue()) == value_vreg_map.end()) { - value_vreg_map[store->getValue()] = "v" + std::to_string(vreg_counter++); - } - if (value_vreg_map.find(store->getPointer()) == value_vreg_map.end()) { - value_vreg_map[store->getPointer()] = "v" + std::to_string(vreg_counter++); - } + // Reset counters for function + alloca_offset_counter = 0; + vreg_counter = 0; + value_vreg_map.clear(); // Clear for each function + + // Before liveness analysis, assign virtual registers to alloca instructions and function arguments + // and establish initial stack map for allocas. + RegAllocResult alloc_result; + + // Assign virtual registers to all instructions that produce a value. + // This part effectively happens during DAG construction in build_dag. + // However, for liveness analysis to work, all potentially used Value* must have a vreg. + // We can iterate the instructions *before* DAG building to populate `value_vreg_map`. + // This is a bit of a chicken-and-egg problem if DAG assigns vregs. + // Let's assume build_dag populates value_vreg_map as it assigns vregs. + + // Calculate stack offsets for AllocaInsts + int current_stack_offset = 0; // Relative to s0 (frame pointer) + // Args are handled by a0-a7, so no direct stack allocation for them here unless they spill. + + // Collect all unique AllocaInsts in the function + std::set allocas_in_func; + for (const auto& bb_ptr : func->getBasicBlocks()) { + for (const auto& inst_ptr : bb_ptr->getInstructions()) { + if (auto alloca = dynamic_cast(inst_ptr.get())) { + allocas_in_func.insert(alloca); } } } - // 分配函数参数栈空间 - auto entry_block = func->getEntryBlock(); - auto args = entry_block->getArguments(); - for (size_t i = 0; i < args.size(); ++i) { - if (i >= 8) { - if (result.stack_map.find(args[i]) == result.stack_map.end()) { - result.stack_map[args[i]] = stack_offset; - value_vreg_map[args[i]] = "v" + std::to_string(vreg_counter++); - stack_offset += 4; - } - } else { - value_vreg_map[args[i]] = "v" + std::to_string(vreg_counter++); - } + // Allocate stack space for alloca instructions + for (auto alloca : allocas_in_func) { + // Allocate space for 4-byte integers (i32) + int size = 4; // Assuming i32, adjust if other types exist + alloc_result.stack_map[alloca] = current_stack_offset; + current_stack_offset += size; } - // 图着色寄存器分配 - auto live_sets = liveness_analysis(func); - auto interference_graph = build_interference_graph(live_sets); - color_graph(result.vreg_to_preg, interference_graph); + // Ensure stack size is multiple of 16 for alignment, and accounts for saved ra and s0 + // ra is at (stack_size - 4)(sp) + // s0 is at (stack_size - 8)(sp) + // So minimum stack size must be 8 + current_stack_offset. + alloc_result.stack_size = current_stack_offset + 8; // For s0 and ra + // Align to 16 bytes for proper ABI + alloc_result.stack_size = (alloc_result.stack_size + 15) & ~15; - // 分配溢出栈空间 - for (const auto& pair : value_vreg_map) { - auto vreg = pair.second; - auto value = pair.first; - if (result.vreg_to_preg.find(vreg) == result.vreg_to_preg.end()) { - if (result.stack_map.find(value) == result.stack_map.end()) { - result.stack_map[value] = stack_offset; - stack_offset += 4; - } - } - } + // 2. 活跃性分析 + std::map> live_sets = liveness_analysis(func); - // 保存 ra 和 s0 - bool needs_caller_saved = false; - for (const auto& bb : func->getBasicBlocks()) { - for (const auto& inst : bb->getInstructions()) { - if (dynamic_cast(inst.get())) { - needs_caller_saved = true; - break; - } - } - if (needs_caller_saved) break; - } + // 3. 构建干扰图 + std::map> interference_graph = build_interference_graph(live_sets); - if (needs_caller_saved || stack_offset > 0) { - stack_offset += 8; - } + // 4. 图着色 + color_graph(alloc_result.vreg_to_preg, interference_graph); - result.stack_size = stack_offset; - if (result.stack_size % 16 != 0) { - result.stack_size += (16 - result.stack_size % 16); - } - return result; + return alloc_result; } +// Phi 消除 (简化版,将Phi的结果直接复制到每个前驱基本块的末尾) void RISCv32CodeGen::eliminate_phi(Function* func) { - // TODO: 插入 move 指令处理 phi + // This is a placeholder. A proper phi elimination would involve + // inserting `mov` instructions in predecessor blocks for each phi operand. + // For the given IR example, there are no phi nodes, so this might not be strictly necessary, + // but it's good practice to have this phase if the frontend generates phi nodes. + // For now, we assume no phi nodes are generated or they are handled upstream. } } // namespace sysy \ No newline at end of file diff --git a/src/RISCv32Backend.h b/src/RISCv32Backend.h index f3635ab..dbd5cb2 100644 --- a/src/RISCv32Backend.h +++ b/src/RISCv32Backend.h @@ -8,32 +8,47 @@ #include #include #include +#include // For std::function namespace sysy { class RISCv32CodeGen { public: enum class PhysicalReg { - S0, T0, T1, T2, T3, T4, T5, T6, - A0, A1, A2, A3, A4, A5, A6, A7 + ZERO, RA, SP, GP, TP, T0, T1, T2, S0, S1, A0, A1, A2, A3, A4, A5, A6, A7, S2, S3, S4, S5, S6, S7, S8, S9, S10, S11, T3, T4, T5, T6 }; // Move DAGNode and RegAllocResult to public section struct DAGNode { - enum NodeKind { CONSTANT, LOAD, STORE, BINARY, CALL, RETURN, BRANCH }; + enum NodeKind { CONSTANT, LOAD, STORE, BINARY, CALL, RETURN, BRANCH, ALLOCA_ADDR }; // Added ALLOCA_ADDR NodeKind kind; - Value* value = nullptr; - std::string inst; - std::string result_reg; + Value* value = nullptr; // For IR Value + std::string inst; // Generated RISC-V instruction(s) for this node + std::string result_vreg; // Virtual register assigned to this node's result std::vector operands; - std::vector users; + std::vector users; // For debugging and potentially optimizations DAGNode(NodeKind k) : kind(k) {} + + // Debugging / helper + std::string getNodeKindString() const { + switch (kind) { + case CONSTANT: return "CONSTANT"; + case LOAD: return "LOAD"; + case STORE: return "STORE"; + case BINARY: return "BINARY"; + case CALL: return "CALL"; + case RETURN: return "RETURN"; + case BRANCH: return "BRANCH"; + case ALLOCA_ADDR: return "ALLOCA_ADDR"; + default: return "UNKNOWN"; + } + } }; struct RegAllocResult { - std::map vreg_to_preg; - std::map stack_map; - int stack_size = 0; + std::map vreg_to_preg; // Virtual register to Physical Register mapping + std::map stack_map; // Value (AllocaInst) to stack offset + int stack_size = 0; // Total stack frame size for locals and spills }; RISCv32CodeGen(Module* mod) : module(mod) {} @@ -42,23 +57,31 @@ public: std::string module_gen(); std::string function_gen(Function* func); std::string basicBlock_gen(BasicBlock* bb, const RegAllocResult& alloc); + + // DAG related std::vector> build_dag(BasicBlock* bb); - void select_instructions(DAGNode* node, const RegAllocResult& alloc); // Use const - void emit_instructions(DAGNode* node, std::vector& insts, const RegAllocResult& alloc); // Add alloc + void select_instructions(DAGNode* node, const RegAllocResult& alloc); + void emit_instructions(DAGNode* node, std::vector& insts, const RegAllocResult& alloc, std::set& emitted_nodes); // Add emitted_nodes set + + // Register Allocation related std::map> liveness_analysis(Function* func); std::map> build_interference_graph( const std::map>& live_sets); void color_graph(std::map& vreg_to_preg, - const std::map>& interference_graph); + const std::map>& interference_graph); RegAllocResult register_allocation(Function* func); - void eliminate_phi(Function* func); + void eliminate_phi(Function* func); // Phi elimination is typically done before DAG building + + // Utility std::string reg_to_string(PhysicalReg reg); void print_dag(const std::vector>& dag, const std::string& bb_name); private: static const std::vector allocable_regs; - std::map value_vreg_map; + std::map value_vreg_map; // Maps IR Value* to its virtual register name Module* module; + int vreg_counter = 0; // Counter for unique virtual register names + int alloca_offset_counter = 0; // Counter for alloca offsets }; } // namespace sysy