From c8308047df7d169fa0e67fe539e6a0b4e314eada Mon Sep 17 00:00:00 2001 From: Lixuanwang Date: Sat, 19 Jul 2025 13:52:09 +0800 Subject: [PATCH] =?UTF-8?q?[backend]=E5=BC=95=E5=85=A5=E4=BA=86Memset?= =?UTF-8?q?=E6=8C=87=E4=BB=A4=E5=9C=A8=E5=90=8E=E7=AB=AF=E7=9A=84=E5=B1=95?= =?UTF-8?q?=E5=BC=80?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/RISCv64Backend.cpp | 154 ++++++++++++++++++++++++++++------- src/include/RISCv64Backend.h | 7 +- 2 files changed, 130 insertions(+), 31 deletions(-) diff --git a/src/RISCv64Backend.cpp b/src/RISCv64Backend.cpp index 913f7ab..29b0bfd 100644 --- a/src/RISCv64Backend.cpp +++ b/src/RISCv64Backend.cpp @@ -146,6 +146,7 @@ std::string RISCv64CodeGen::module_gen() { // 函数级代码生成 std::string RISCv64CodeGen::function_gen(Function* func) { + this->local_label_counter = 0; // 为每个函数重置本地标签计数器 std::stringstream ss; ss << ".globl " << func->getName() << "\n"; // 声明函数为全局符号 ss << func->getName() << ":\n"; // 函数入口标签 @@ -401,6 +402,26 @@ std::vector> RISCv64CodeGen::build_dag( store_node->operands.push_back(val_node); store_node->operands.push_back(ptr_node); ptr_node->users.push_back(store_node); + } else if (auto memset = dynamic_cast(inst)) { + auto memset_node = create_node(DAGNode::MEMSET, memset, value_to_node, nodes_storage); + + // 根据 IR.h 中的定义,获取 MemsetInst 的操作数 + DAGNode* pointer_node = get_operand_node(memset->getPointer(), value_to_node, nodes_storage); + DAGNode* begin_node = get_operand_node(memset->getBegin(), value_to_node, nodes_storage); + DAGNode* size_node = get_operand_node(memset->getSize(), value_to_node, nodes_storage); + DAGNode* value_node = get_operand_node(memset->getValue(), value_to_node, nodes_storage); + + // 将操作数节点添加到 MEMSET 节点的依赖列表中 + memset_node->operands.push_back(pointer_node); + memset_node->operands.push_back(begin_node); + memset_node->operands.push_back(size_node); + memset_node->operands.push_back(value_node); + + // 建立反向链接 + pointer_node->users.push_back(memset_node); + begin_node->users.push_back(memset_node); + size_node->users.push_back(memset_node); + value_node->users.push_back(memset_node); } else if (auto load = dynamic_cast(inst)) { auto load_node = create_node(DAGNode::LOAD, load, value_to_node, nodes_storage); @@ -490,6 +511,12 @@ std::vector> RISCv64CodeGen::build_dag( } else if (auto uncond_br = dynamic_cast(inst)) { auto br_node = create_node(DAGNode::BRANCH, uncond_br, value_to_node, nodes_storage); // 传递参数 br_node->inst = "j " + uncond_br->getBlock()->getName(); + } else { + // 其他指令类型(如 PHI, 可能的自定义指令等) + // 目前假设未处理的指令类型不需要特殊 DAGNode 类型 + // 可以在这里添加更多的处理逻辑 + if (DEBUG) std::cerr << "未处理的指令类型: " << inst->getKindString() << "\n"; + continue; // 跳过未处理的指令 } } return nodes_storage; @@ -914,10 +941,10 @@ void RISCv64CodeGen::select_instructions(DAGNode* node, const RegAllocResult& al // [V1 特性] 如果基本块没有名字(例如,匿名块),给它一个伪名称 if (then_block.empty()) { - then_block = ENTRY_BLOCK_PSEUDO_NAME + "_then_" + std::to_string((uintptr_t)br); - } + then_block = ENTRY_BLOCK_PSEUDO_NAME + "_then_" + std::to_string(this->local_label_counter++); + } if (else_block.empty()) { - else_block = ENTRY_BLOCK_PSEUDO_NAME + "_else_" + std::to_string((uintptr_t)br); + else_block = ENTRY_BLOCK_PSEUDO_NAME + "_else_" + std::to_string(this->local_label_counter++); } ss_inst << "bnez " << cond_reg << ", " << then_block << "\n"; @@ -931,6 +958,61 @@ void RISCv64CodeGen::select_instructions(DAGNode* node, const RegAllocResult& al } break; } + case DAGNode::MEMSET: { + if (node->operands.size() < 4) break; + + // 1. 获取操作数被分配到的物理寄存器 + // 您的 IR 中 pointer 和 begin 可能是同一个值,这里我们假设 pointer 是基地址 + DAGNode* ptr_node = node->operands[0]; + DAGNode* size_node = node->operands[2]; + DAGNode* value_node = node->operands[3]; + + std::string R_DEST_ADDR = get_preg_or_temp(ptr_node->result_vreg); + std::string R_NUM_BYTES = get_preg_or_temp(size_node->result_vreg); + std::string R_VALUE_BYTE = get_preg_or_temp(value_node->result_vreg); + + // 2. 定义我们将要使用的临时寄存器 + // 由于我们在冲突图中添加了规则,可以安全地使用这些调用者保存寄存器 + std::string R_COUNTER = "t3"; // 循环计数器 (字节) + std::string R_END_ADDR = "t4"; // 结束地址 + std::string R_CURRENT_ADDR = "t5"; // 当前写入地址 + std::string R_TEMP_VAL = "t6"; // 64位填充值 + + // 使用 local_label_counter 为 memset 循环生成唯一的、整洁的标签 + int unique_id = this->local_label_counter++; + std::string loop_start_label = "memset_loop_start_" + std::to_string(unique_id); + std::string loop_end_label = "memset_loop_end_" + std::to_string(unique_id); + std::string remainder_label = "memset_remainder_" + std::to_string(unique_id); + std::string done_label = "memset_done_" + std::to_string(unique_id); + + // 3. 生成汇编代码 + ss_inst << "# --- Memset Start ---\n"; + ss_inst << " andi " << R_TEMP_VAL << ", " << R_VALUE_BYTE << ", 255\n"; + ss_inst << " slli " << R_VALUE_BYTE << ", " << R_TEMP_VAL << ", 8\n"; + ss_inst << " or " << R_TEMP_VAL << ", " << R_TEMP_VAL << ", " << R_VALUE_BYTE << "\n"; + ss_inst << " slli " << R_VALUE_BYTE << ", " << R_TEMP_VAL << ", 16\n"; + ss_inst << " or " << R_TEMP_VAL << ", " << R_TEMP_VAL << ", " << R_VALUE_BYTE << "\n"; + ss_inst << " slli " << R_VALUE_BYTE << ", " << R_TEMP_VAL << ", 32\n"; + ss_inst << " or " << R_TEMP_VAL << ", " << R_TEMP_VAL << ", " << R_VALUE_BYTE << "\n"; + ss_inst << " add " << R_END_ADDR << ", " << R_DEST_ADDR << ", " << R_NUM_BYTES << "\n"; + ss_inst << " mv " << R_CURRENT_ADDR << ", " << R_DEST_ADDR << "\n"; + ss_inst << " andi " << R_COUNTER << ", " << R_NUM_BYTES << ", -8\n"; + ss_inst << " add " << R_COUNTER << ", " << R_DEST_ADDR << ", " << R_COUNTER << "\n"; + ss_inst << loop_start_label << ":\n"; + ss_inst << " bgeu " << R_CURRENT_ADDR << ", " << R_COUNTER << ", " << loop_end_label << "\n"; + ss_inst << " sd " << R_TEMP_VAL << ", 0(" << R_CURRENT_ADDR << ")\n"; + ss_inst << " addi " << R_CURRENT_ADDR << ", " << R_CURRENT_ADDR << ", 8\n"; + ss_inst << " j " << loop_start_label << "\n"; + ss_inst << loop_end_label << ":\n"; + ss_inst << remainder_label << ":\n"; + ss_inst << " bgeu " << R_CURRENT_ADDR << ", " << R_END_ADDR << ", " << done_label << "\n"; + ss_inst << " sb " << R_TEMP_VAL << ", 0(" << R_CURRENT_ADDR << ")\n"; + ss_inst << " addi " << R_CURRENT_ADDR << ", " << R_CURRENT_ADDR << ", 1\n"; + ss_inst << " j " << remainder_label << "\n"; + ss_inst << done_label << ":\n"; + ss_inst << "# --- Memset End ---"; + break; + } default: throw std::runtime_error("不支持的节点类型: " + node->getNodeKindString()); } @@ -940,49 +1022,41 @@ void RISCv64CodeGen::select_instructions(DAGNode* node, const RegAllocResult& al // 指令发射 void RISCv64CodeGen::emit_instructions(DAGNode* node, std::stringstream& ss, const RegAllocResult& alloc, std::set& emitted_nodes) { if (!node || emitted_nodes.count(node)) { - return; // 已发射或为空 + return; // 如果节点为空或已经发射过,则返回 } - // 递归地发射操作数以确保满足依赖关系 + // 递归地发射操作数,以确保满足指令依赖 for (auto operand : node->operands) { if (operand) { emit_instructions(operand, ss, alloc, emitted_nodes); } } - // 标记当前节点为已发射 + // 标记当前节点为已发射,防止重复 emitted_nodes.insert(node); - // 分割多行指令并处理每一行 + // node->inst 中可能包含由 \n 分隔的多行指令和标签 std::stringstream node_inst_ss(node->inst); std::string line; while (std::getline(node_inst_ss, line, '\n')) { - // 清除前导/尾随空白并移除行开头的潜在标签 - line = std::regex_replace(line, std::regex("^\\s*[^\\s:]*:\\s*"), ""); // 移除标签(例如 `label: inst`) - line = std::regex_replace(line, std::regex("^\\s+|\\s+$"), ""); // 清除空白 + // 首先,移除行首和行尾的空白字符,方便后续判断 + line = std::regex_replace(line, std::regex("^\\s+|\\s+$"), ""); - if (line.empty()) continue; - - // 处理虚拟寄存器替换和溢出/加载逻辑 - std::string processed_line = line; - - // 替换结果虚拟寄存器 (如果此行中存在) - if (!node->result_vreg.empty() && alloc.vreg_to_preg.count(node->result_vreg)) { - std::string preg = reg_to_string(alloc.vreg_to_preg.at(node->result_vreg)); - processed_line = std::regex_replace(processed_line, std::regex("\\b" + node->result_vreg + "\\b"), preg); - } - - // 替换操作数虚拟寄存器 (如果此行中存在) - for (auto operand : node->operands) { - if (operand && !operand->result_vreg.empty() && alloc.vreg_to_preg.count(operand->result_vreg)) { - std::string operand_preg = reg_to_string(alloc.vreg_to_preg.at(operand->result_vreg)); - processed_line = std::regex_replace(processed_line, std::regex("\\b" + operand->result_vreg + "\\b"), operand_preg); - } + if (line.empty()) { + continue; // 跳过空行 } - // 添加处理后的指令 - ss << " " << processed_line << "\n"; + // ====================== 核心修正逻辑 ====================== + // 判断当前行是否是一个标签(即,非空且以':'结尾) + if (!line.empty() && line.back() == ':') { + // 如果是标签,直接打印,不加前导缩进 + ss << line << "\n"; + } else { + // 如果是常规指令,添加4个空格的前导缩进后再打印 + ss << " " << line << "\n"; + } + // ======================================================== } } @@ -1178,6 +1252,28 @@ std::map> RISCv64CodeGen::build_interference_ graph[arg_vregs[j]].insert(arg_vregs[i]); } } + } else if (auto memset = dynamic_cast(inst)) { + // 规则:MemsetInst 像一个函数调用,它会污染临时寄存器。 + // 因此,所有跨越这条指令的活跃变量(live_out), + // 都应该与这条指令的操作数(use)互相冲突。 + // 这会强制分配器将它们放入不同的寄存器中,或者安全地保存/恢复。 + + std::set use_set; + for (const auto& operand_use : memset->getOperands()) { + Value* operand = operand_use->getValue(); + if (value_vreg_map.count(operand)) { + use_set.insert(value_vreg_map.at(operand)); + } + } + + for (const auto& live_vreg : live_out_set) { + for (const auto& use_vreg : use_set) { + if (live_vreg != use_vreg) { + graph[live_vreg].insert(use_vreg); + graph[use_vreg].insert(live_vreg); + } + } + } } } return graph; diff --git a/src/include/RISCv64Backend.h b/src/include/RISCv64Backend.h index f929250..429aba2 100644 --- a/src/include/RISCv64Backend.h +++ b/src/include/RISCv64Backend.h @@ -32,7 +32,7 @@ public: // Move DAGNode and RegAllocResult to public section struct DAGNode { - enum NodeKind { CONSTANT, LOAD, STORE, BINARY, CALL, RETURN, BRANCH, ALLOCA_ADDR, UNARY }; + enum NodeKind { CONSTANT, LOAD, STORE, BINARY, CALL, RETURN, BRANCH, ALLOCA_ADDR, UNARY, MEMSET }; NodeKind kind; Value* value = nullptr; // For IR Value std::string inst; // Generated RISC-V instruction(s) for this node @@ -53,6 +53,7 @@ public: case BRANCH: return "BRANCH"; case ALLOCA_ADDR: return "ALLOCA_ADDR"; case UNARY: return "UNARY"; + case MEMSET: return "MEMSET"; default: return "UNKNOWN"; } } @@ -103,7 +104,9 @@ private: // 为空标签定义一个伪名称前缀,加上块索引以确保唯一性 const std::string ENTRY_BLOCK_PSEUDO_NAME = "entry_block_"; - + + int local_label_counter = 0; // 用于生成唯一的本地标签 (如 memset 循环, 匿名块跳转等) + // !!! 修改:get_operand_node 辅助函数现在需要传入 value_to_node 和 nodes_storage 的引用 // 因为它们是 build_dag 局部管理的 DAGNode* get_operand_node(