diff --git a/src/RISCv64AsmPrinter.cpp b/src/RISCv64AsmPrinter.cpp index 3c9d0c6..9cbddb6 100644 --- a/src/RISCv64AsmPrinter.cpp +++ b/src/RISCv64AsmPrinter.cpp @@ -115,7 +115,18 @@ void RISCv64AsmPrinter::printInstruction(MachineInstr* instr, bool debug) { case RVOpcodes::MV: *OS << "mv "; break; case RVOpcodes::NEG: *OS << "neg "; break; case RVOpcodes::NEGW: *OS << "negw "; break; case RVOpcodes::SEQZ: *OS << "seqz "; break; case RVOpcodes::SNEZ: *OS << "snez "; break; - case RVOpcodes::CALL: *OS << "call "; break; + case RVOpcodes::CALL: { // [核心修改] 为CALL指令添加特殊处理逻辑 + *OS << "call "; + // 遍历所有操作数,只寻找并打印函数名标签 + for (const auto& op : instr->getOperands()) { + if (op->getKind() == MachineOperand::KIND_LABEL) { + printOperand(op.get()); + break; // 找到标签后即可退出 + } + } + *OS << "\n"; + return; // 处理完毕,直接返回,不再执行后续的通用操作数打印 + } case RVOpcodes::LABEL: break; case RVOpcodes::FRAME_LOAD_W: diff --git a/src/RISCv64ISel.cpp b/src/RISCv64ISel.cpp index c5a7144..4ffad0c 100644 --- a/src/RISCv64ISel.cpp +++ b/src/RISCv64ISel.cpp @@ -582,7 +582,22 @@ void RISCv64ISel::selectNode(DAGNode* node) { } auto call_instr = std::make_unique(RVOpcodes::CALL); + // [协议] 如果函数有返回值,将它的目标虚拟寄存器作为第一个操作数 + if (!call->getType()->isVoid()) { + unsigned dest_vreg = getVReg(call); + call_instr->addOperand(std::make_unique(dest_vreg)); + } + + // 将函数名标签作为后续操作数 call_instr->addOperand(std::make_unique(call->getCallee()->getName())); + + // 将所有参数的虚拟寄存器也作为后续操作数,供getInstrUseDef分析 + for (size_t i = 0; i < num_operands; ++i) { + if (node->operands[i]->kind != DAGNode::CONSTANT) { // 常量参数已直接加载,无需作为use + call_instr->addOperand(std::make_unique(getVReg(node->operands[i]->value))); + } + } + CurMBB->addInstruction(std::move(call_instr)); if (num_operands > 8) { @@ -596,12 +611,12 @@ void RISCv64ISel::selectNode(DAGNode* node) { CurMBB->addInstruction(std::move(dealloc_instr)); } // 处理返回值,从a0移动到目标虚拟寄存器 - if (!call->getType()->isVoid()) { - auto mv_instr = std::make_unique(RVOpcodes::MV); - mv_instr->addOperand(std::make_unique(getVReg(call))); - mv_instr->addOperand(std::make_unique(PhysicalReg::A0)); - CurMBB->addInstruction(std::move(mv_instr)); - } + // if (!call->getType()->isVoid()) { + // auto mv_instr = std::make_unique(RVOpcodes::MV); + // mv_instr->addOperand(std::make_unique(getVReg(call))); + // mv_instr->addOperand(std::make_unique(PhysicalReg::A0)); + // CurMBB->addInstruction(std::move(mv_instr)); + // } break; } diff --git a/src/RISCv64RegAlloc.cpp b/src/RISCv64RegAlloc.cpp index 279fea9..53ee673 100644 --- a/src/RISCv64RegAlloc.cpp +++ b/src/RISCv64RegAlloc.cpp @@ -42,18 +42,22 @@ void RISCv64RegAlloc::run() { } /** - * @brief 处理调用约定,预先为函数参数分配物理寄存器。 + * @brief 处理调用约定,预先为函数参数和调用返回值分配物理寄存器。 + * 这个函数现在负责处理调用约定的两个方面: + * 1. 作为被调用者(callee),如何接收传入的参数。 + * 2. 作为调用者(caller),如何接收调用的其他函数的返回值。 */ void RISCv64RegAlloc::handleCallingConvention() { Function* F = MFunc->getFunc(); RISCv64ISel* isel = MFunc->getISel(); + // --- 部分1:处理函数传入参数的预着色 --- // 获取函数的Argument对象列表 if (F) { auto& args = F->getArguments(); // RISC-V RV64G调用约定:前8个整型/指针参数通过 a0-a7 传递 int arg_idx = 0; - // 遍历 AllocaInst* 列表 + // 遍历 Argument* 列表 for (Argument* arg : args) { if (arg_idx >= 8) { break; @@ -74,6 +78,30 @@ void RISCv64RegAlloc::handleCallingConvention() { arg_idx++; } } + + // // --- 部分2:[新逻辑] 遍历所有指令,为CALL指令的返回值预着色为 a0 --- + // // 这是为了强制寄存器分配器知道,call的结果物理上出现在a0寄存器。 + for (auto& mbb : MFunc->getBlocks()) { + for (auto& instr : mbb->getInstructions()) { + if (instr->getOpcode() == RVOpcodes::CALL) { + // 根据协议,如果CALL有返回值,其目标vreg是第一个操作数 + if (!instr->getOperands().empty() && + instr->getOperands().front()->getKind() == MachineOperand::KIND_REG) + { + auto reg_op = static_cast(instr->getOperands().front().get()); + if (reg_op->isVirtual()) { + unsigned ret_vreg = reg_op->getVRegNum(); + // 强制将这个虚拟寄存器预着色为 a0 + color_map[ret_vreg] = PhysicalReg::A0; + if (DEBUG) { + std::cout << "[DEBUG] Pre-coloring vreg" << ret_vreg + << " to a0 for CALL instruction." << std::endl; + } + } + } + } + } + } } /** @@ -236,35 +264,58 @@ void RISCv64RegAlloc::getInstrUseDef(MachineInstr* instr, LiveSet& use, LiveSet& // 2. CALL 指令的特殊处理 if (opcode == RVOpcodes::CALL) { - // 1.1 处理返回值 (def) - // 约定:如果CALL指令有返回值,IR阶段会将返回值vreg作为指令的第一个操作数。 - if (!instr->getOperands().empty() && instr->getOperands().front()->getKind() == MachineOperand::KIND_REG) { - auto reg_op = static_cast(instr->getOperands().front().get()); - if (reg_op->isVirtual()) { - def.insert(reg_op->getVRegNum()); - } - } + // // [协议] 我们约定,ISel生成的CALL指令会遵循以下格式: + // // call %vreg_ret, @func_name, %vreg_arg1, %vreg_arg2, ... + // // 其中,第一个操作数(如果存在且是vreg)是返回值(def),函数名是标签,其余是参数(use)。 + // bool has_return_val = false; + // // 1.1 处理返回值 (def) + // if (!instr->getOperands().empty() && instr->getOperands().front()->getKind() == MachineOperand::KIND_REG) { + // auto reg_op = static_cast(instr->getOperands().front().get()); + // if (reg_op->isVirtual()) { + // def.insert(reg_op->getVRegNum()); + // has_return_val = true; + // } + // } - // 1.2 处理参数 (use) - // 参数通常是指令的后续操作数 - bool first_operand_processed = false; // 用于跳过已作为def处理的返回值 - for (const auto& op : instr->getOperands()) { + // // 1.2 处理参数 (use) + // // 遍历所有操作数,跳过返回值(第一个)和函数名标签 + // for (size_t i = 1; i < instr->getOperands().size(); ++i) { + // auto& op = instr->getOperands()[i]; + // if (op->getKind() == MachineOperand::KIND_REG) { + // auto reg_op = static_cast(op.get()); + // if (reg_op->isVirtual()) { + // // 如果第一个操作数是返回值,则跳过它(因为它不是use) + // if (i == 0 && has_return_val) { + // continue; + // } + // use.insert(reg_op->getVRegNum()); + // } + // } + // } + // [新增] 根据我们在ISel中定义的新协议,解析操作数列表 + bool first_reg_operand_is_def = true; + for (auto& op : instr->getOperands()) { if (op->getKind() == MachineOperand::KIND_REG) { - if (!first_operand_processed) { // 如果是第一个操作数 - first_operand_processed = true; - // 如果第一个操作数是返回值(已被加入def),则跳过 - if (def.count(static_cast(op.get())->getVRegNum())) { - continue; - } - } - // 否则,该寄存器是 use auto reg_op = static_cast(op.get()); if (reg_op->isVirtual()) { - use.insert(reg_op->getVRegNum()); + // 协议:第一个寄存器操作数是返回值 (def) + if (first_reg_operand_is_def) { + def.insert(reg_op->getVRegNum()); + first_reg_operand_is_def = false; + } else { + // 后续所有寄存器操作数都是参数 (use) + use.insert(reg_op->getVRegNum()); + } } } } + // [新增] CALL指令隐式地使用了通过物理寄存器(a0-a7)传递的参数 + // 并且隐式地定义了a0(返回值)和所有调用者保存的寄存器。 + // a0-a7的use/def关系已经被显式操作数和预着色处理。 + // 调用者保存寄存器的冲突在 buildInterferenceGraph 中处理。 + // 所以这里只需要解析我们协议中定义的显式操作数即可。 + // **重要**: CALL指令隐式定义(杀死)了所有调用者保存的寄存器。 // **这部分逻辑不在getInstrUseDef中直接处理**。 // 而是通过`buildInterferenceGraph`中添加物理寄存器节点与活跃虚拟寄存器之间的干扰边来完成。 @@ -277,22 +328,38 @@ void RISCv64RegAlloc::getInstrUseDef(MachineInstr* instr, LiveSet& use, LiveSet& for (const auto& op : instr->getOperands()) { if (op->getKind() == MachineOperand::KIND_REG) { auto reg_op = static_cast(op.get()); - if (reg_op->isVirtual()) { // 只有虚拟寄存器才需要处理 Use/Def - // 如果是第一个寄存器操作数,且指令类型表明它是定义 (def),则加入 def 集合 - // 否则,它是 use (读取) + + if (reg_op->isVirtual()) { // 如果是虚拟寄存器 if (first_reg_is_def) { def.insert(reg_op->getVRegNum()); - first_reg_is_def = false; // 确保每条指令只定义一个目标寄存器 + first_reg_is_def = false; } else { use.insert(reg_op->getVRegNum()); } + } else { // 如果是物理寄存器 + if (!first_reg_is_def) { + PhysicalReg preg = reg_op->getPReg(); + // [核心修复] 在访问map前,先检查key是否存在 + // 我们只关心那些参与图着色的物理寄存器节点的活跃性 + if (preg_to_vreg_id_map.count(preg)) { + // 将物理寄存器对应的特殊ID加入Use集合 + use.insert(preg_to_vreg_id_map.at(preg)); + } + } } } else if (op->getKind() == MachineOperand::KIND_MEM) { - // 内存操作数 `offset(base)` 中的 `base` 寄存器是 `use` auto mem_op = static_cast(op.get()); - if (mem_op->getBase()->isVirtual()) { - use.insert(mem_op->getBase()->getVRegNum()); + auto base_reg = mem_op->getBase(); + if (base_reg->isVirtual()) { + use.insert(base_reg->getVRegNum()); + } else { + // [核心修复] 同样地,检查物理基址寄存器是否存在于map中 + PhysicalReg preg = base_reg->getPReg(); + if (preg_to_vreg_id_map.count(preg)) { + use.insert(preg_to_vreg_id_map.at(preg)); + } } + // 对于存储内存指令 (SW, SD),要存储的值(第一个操作数)也是 `use` if ((opcode == RVOpcodes::SW || opcode == RVOpcodes::SD) && !instr->getOperands().empty() && // 确保有操作数 @@ -300,6 +367,9 @@ void RISCv64RegAlloc::getInstrUseDef(MachineInstr* instr, LiveSet& use, LiveSet& auto src_reg_op = static_cast(instr->getOperands().front().get()); if (src_reg_op->isVirtual()) { use.insert(src_reg_op->getVRegNum()); + } else { + // 同样可以处理基址是物理寄存器的情况 + use.insert(preg_to_vreg_id_map.at(mem_op->getBase()->getPReg())); } } } @@ -345,8 +415,26 @@ unsigned RISCv64RegAlloc::getTypeSizeInBytes(Type* type) { void RISCv64RegAlloc::analyzeLiveness() { bool changed = true; + int iteration = 0; // [新] 添加一个迭代计数器 + + // [新] 辅助函数,用于将LiveSet打印为字符串 + auto liveset_to_string = [](const LiveSet& s) { + std::string out = "{"; + for (unsigned vreg : s) { + out += "%vreg" + std::to_string(vreg) + " "; + } + if (!s.empty()) out.pop_back(); + out += "}"; + return out; + }; + while (changed) { changed = false; + iteration++; // [新] 迭代计数 + if (DEEPDEBUG) { + std::cout << "\n===== Liveness Analysis Iteration " << iteration << " =====\n"; + } + for (auto it = MFunc->getBlocks().rbegin(); it != MFunc->getBlocks().rend(); ++it) { auto& mbb = *it; LiveSet live_out; @@ -375,6 +463,14 @@ void RISCv64RegAlloc::analyzeLiveness() { live_in.insert(diff.begin(), diff.end()); live_in_map[instr] = live_in; + if (DEEPDEBUG && mbb->getName() == "if_exit.L1") { + std::cout << " Instr (" << (void*)instr << "): \n" + << " Use: " << liveset_to_string(use) << "\n" + << " Def: " << liveset_to_string(def) << "\n" + << " Live Out: " << liveset_to_string(live_out_map[instr]) << "\n" + << " Live In: " << liveset_to_string(live_in) << std::endl; + } + live_out = live_in; if (live_in_map[instr] != old_live_in) { @@ -422,6 +518,14 @@ void RISCv64RegAlloc::buildInterferenceGraph() { // *** 核心修改点:处理 CALL 指令的隐式 def *** if (instr->getOpcode() == RVOpcodes::CALL) { + if (DEBUG) { + std::string live_out_str; + for (unsigned vreg : live_out) { + live_out_str += "%vreg" + std::to_string(vreg) + " "; + } + std::cout << "[DEBUG] buildInterferenceGraph: CALL instruction found. Live out set is: {" + << live_out_str << "}" << std::endl; + } // CALL 指令会定义(杀死)所有调用者保存的寄存器。 // 因此,所有调用者保存的物理寄存器都与 CALL 指令的 live_out 中的所有变量冲突。 const std::vector& caller_saved_regs = getCallerSavedIntRegs();