[backend]引入了Memset指令在后端的展开

This commit is contained in:
Lixuanwang
2025-07-19 13:52:09 +08:00
parent 86d1de6696
commit c8308047df
2 changed files with 130 additions and 31 deletions

View File

@ -146,6 +146,7 @@ std::string RISCv64CodeGen::module_gen() {
// 函数级代码生成
std::string RISCv64CodeGen::function_gen(Function* func) {
this->local_label_counter = 0; // 为每个函数重置本地标签计数器
std::stringstream ss;
ss << ".globl " << func->getName() << "\n"; // 声明函数为全局符号
ss << func->getName() << ":\n"; // 函数入口标签
@ -401,6 +402,26 @@ std::vector<std::unique_ptr<RISCv64CodeGen::DAGNode>> RISCv64CodeGen::build_dag(
store_node->operands.push_back(val_node);
store_node->operands.push_back(ptr_node);
ptr_node->users.push_back(store_node);
} else if (auto memset = dynamic_cast<MemsetInst*>(inst)) {
auto memset_node = create_node(DAGNode::MEMSET, memset, value_to_node, nodes_storage);
// 根据 IR.h 中的定义,获取 MemsetInst 的操作数
DAGNode* pointer_node = get_operand_node(memset->getPointer(), value_to_node, nodes_storage);
DAGNode* begin_node = get_operand_node(memset->getBegin(), value_to_node, nodes_storage);
DAGNode* size_node = get_operand_node(memset->getSize(), value_to_node, nodes_storage);
DAGNode* value_node = get_operand_node(memset->getValue(), value_to_node, nodes_storage);
// 将操作数节点添加到 MEMSET 节点的依赖列表中
memset_node->operands.push_back(pointer_node);
memset_node->operands.push_back(begin_node);
memset_node->operands.push_back(size_node);
memset_node->operands.push_back(value_node);
// 建立反向链接
pointer_node->users.push_back(memset_node);
begin_node->users.push_back(memset_node);
size_node->users.push_back(memset_node);
value_node->users.push_back(memset_node);
} else if (auto load = dynamic_cast<LoadInst*>(inst)) {
auto load_node = create_node(DAGNode::LOAD, load, value_to_node, nodes_storage);
@ -490,6 +511,12 @@ std::vector<std::unique_ptr<RISCv64CodeGen::DAGNode>> RISCv64CodeGen::build_dag(
} else if (auto uncond_br = dynamic_cast<UncondBrInst*>(inst)) {
auto br_node = create_node(DAGNode::BRANCH, uncond_br, value_to_node, nodes_storage); // 传递参数
br_node->inst = "j " + uncond_br->getBlock()->getName();
} else {
// 其他指令类型(如 PHI, 可能的自定义指令等)
// 目前假设未处理的指令类型不需要特殊 DAGNode 类型
// 可以在这里添加更多的处理逻辑
if (DEBUG) std::cerr << "未处理的指令类型: " << inst->getKindString() << "\n";
continue; // 跳过未处理的指令
}
}
return nodes_storage;
@ -914,10 +941,10 @@ void RISCv64CodeGen::select_instructions(DAGNode* node, const RegAllocResult& al
// [V1 特性] 如果基本块没有名字(例如,匿名块),给它一个伪名称
if (then_block.empty()) {
then_block = ENTRY_BLOCK_PSEUDO_NAME + "_then_" + std::to_string((uintptr_t)br);
}
then_block = ENTRY_BLOCK_PSEUDO_NAME + "_then_" + std::to_string(this->local_label_counter++);
}
if (else_block.empty()) {
else_block = ENTRY_BLOCK_PSEUDO_NAME + "_else_" + std::to_string((uintptr_t)br);
else_block = ENTRY_BLOCK_PSEUDO_NAME + "_else_" + std::to_string(this->local_label_counter++);
}
ss_inst << "bnez " << cond_reg << ", " << then_block << "\n";
@ -931,6 +958,61 @@ void RISCv64CodeGen::select_instructions(DAGNode* node, const RegAllocResult& al
}
break;
}
case DAGNode::MEMSET: {
if (node->operands.size() < 4) break;
// 1. 获取操作数被分配到的物理寄存器
// 您的 IR 中 pointer 和 begin 可能是同一个值,这里我们假设 pointer 是基地址
DAGNode* ptr_node = node->operands[0];
DAGNode* size_node = node->operands[2];
DAGNode* value_node = node->operands[3];
std::string R_DEST_ADDR = get_preg_or_temp(ptr_node->result_vreg);
std::string R_NUM_BYTES = get_preg_or_temp(size_node->result_vreg);
std::string R_VALUE_BYTE = get_preg_or_temp(value_node->result_vreg);
// 2. 定义我们将要使用的临时寄存器
// 由于我们在冲突图中添加了规则,可以安全地使用这些调用者保存寄存器
std::string R_COUNTER = "t3"; // 循环计数器 (字节)
std::string R_END_ADDR = "t4"; // 结束地址
std::string R_CURRENT_ADDR = "t5"; // 当前写入地址
std::string R_TEMP_VAL = "t6"; // 64位填充值
// 使用 local_label_counter 为 memset 循环生成唯一的、整洁的标签
int unique_id = this->local_label_counter++;
std::string loop_start_label = "memset_loop_start_" + std::to_string(unique_id);
std::string loop_end_label = "memset_loop_end_" + std::to_string(unique_id);
std::string remainder_label = "memset_remainder_" + std::to_string(unique_id);
std::string done_label = "memset_done_" + std::to_string(unique_id);
// 3. 生成汇编代码
ss_inst << "# --- Memset Start ---\n";
ss_inst << " andi " << R_TEMP_VAL << ", " << R_VALUE_BYTE << ", 255\n";
ss_inst << " slli " << R_VALUE_BYTE << ", " << R_TEMP_VAL << ", 8\n";
ss_inst << " or " << R_TEMP_VAL << ", " << R_TEMP_VAL << ", " << R_VALUE_BYTE << "\n";
ss_inst << " slli " << R_VALUE_BYTE << ", " << R_TEMP_VAL << ", 16\n";
ss_inst << " or " << R_TEMP_VAL << ", " << R_TEMP_VAL << ", " << R_VALUE_BYTE << "\n";
ss_inst << " slli " << R_VALUE_BYTE << ", " << R_TEMP_VAL << ", 32\n";
ss_inst << " or " << R_TEMP_VAL << ", " << R_TEMP_VAL << ", " << R_VALUE_BYTE << "\n";
ss_inst << " add " << R_END_ADDR << ", " << R_DEST_ADDR << ", " << R_NUM_BYTES << "\n";
ss_inst << " mv " << R_CURRENT_ADDR << ", " << R_DEST_ADDR << "\n";
ss_inst << " andi " << R_COUNTER << ", " << R_NUM_BYTES << ", -8\n";
ss_inst << " add " << R_COUNTER << ", " << R_DEST_ADDR << ", " << R_COUNTER << "\n";
ss_inst << loop_start_label << ":\n";
ss_inst << " bgeu " << R_CURRENT_ADDR << ", " << R_COUNTER << ", " << loop_end_label << "\n";
ss_inst << " sd " << R_TEMP_VAL << ", 0(" << R_CURRENT_ADDR << ")\n";
ss_inst << " addi " << R_CURRENT_ADDR << ", " << R_CURRENT_ADDR << ", 8\n";
ss_inst << " j " << loop_start_label << "\n";
ss_inst << loop_end_label << ":\n";
ss_inst << remainder_label << ":\n";
ss_inst << " bgeu " << R_CURRENT_ADDR << ", " << R_END_ADDR << ", " << done_label << "\n";
ss_inst << " sb " << R_TEMP_VAL << ", 0(" << R_CURRENT_ADDR << ")\n";
ss_inst << " addi " << R_CURRENT_ADDR << ", " << R_CURRENT_ADDR << ", 1\n";
ss_inst << " j " << remainder_label << "\n";
ss_inst << done_label << ":\n";
ss_inst << "# --- Memset End ---";
break;
}
default:
throw std::runtime_error("不支持的节点类型: " + node->getNodeKindString());
}
@ -940,49 +1022,41 @@ void RISCv64CodeGen::select_instructions(DAGNode* node, const RegAllocResult& al
// 指令发射
void RISCv64CodeGen::emit_instructions(DAGNode* node, std::stringstream& ss, const RegAllocResult& alloc, std::set<DAGNode*>& emitted_nodes) {
if (!node || emitted_nodes.count(node)) {
return; // 已发射或为空
return; // 如果节点为空或已经发射过,则返回
}
// 递归地发射操作数以确保满足依赖关系
// 递归地发射操作数以确保满足指令依赖
for (auto operand : node->operands) {
if (operand) {
emit_instructions(operand, ss, alloc, emitted_nodes);
}
}
// 标记当前节点为已发射
// 标记当前节点为已发射,防止重复
emitted_nodes.insert(node);
// 分割多行指令并处理每一行
// node->inst 中可能包含由 \n 分隔的多行指令和标签
std::stringstream node_inst_ss(node->inst);
std::string line;
while (std::getline(node_inst_ss, line, '\n')) {
// 清除前导/尾随空白并移除行开头的潜在标签
line = std::regex_replace(line, std::regex("^\\s*[^\\s:]*:\\s*"), ""); // 移除标签(例如 `label: inst`
line = std::regex_replace(line, std::regex("^\\s+|\\s+$"), ""); // 清除空白
// 首先,移除行首和行尾的空白字符,方便后续判断
line = std::regex_replace(line, std::regex("^\\s+|\\s+$"), "");
if (line.empty()) continue;
// 处理虚拟寄存器替换和溢出/加载逻辑
std::string processed_line = line;
// 替换结果虚拟寄存器 (如果此行中存在)
if (!node->result_vreg.empty() && alloc.vreg_to_preg.count(node->result_vreg)) {
std::string preg = reg_to_string(alloc.vreg_to_preg.at(node->result_vreg));
processed_line = std::regex_replace(processed_line, std::regex("\\b" + node->result_vreg + "\\b"), preg);
}
// 替换操作数虚拟寄存器 (如果此行中存在)
for (auto operand : node->operands) {
if (operand && !operand->result_vreg.empty() && alloc.vreg_to_preg.count(operand->result_vreg)) {
std::string operand_preg = reg_to_string(alloc.vreg_to_preg.at(operand->result_vreg));
processed_line = std::regex_replace(processed_line, std::regex("\\b" + operand->result_vreg + "\\b"), operand_preg);
}
if (line.empty()) {
continue; // 跳过空行
}
// 添加处理后的指令
ss << " " << processed_line << "\n";
// ====================== 核心修正逻辑 ======================
// 判断当前行是否是一个标签(即,非空且以':'结尾)
if (!line.empty() && line.back() == ':') {
// 如果是标签,直接打印,不加前导缩进
ss << line << "\n";
} else {
// 如果是常规指令添加4个空格的前导缩进后再打印
ss << " " << line << "\n";
}
// ========================================================
}
}
@ -1178,6 +1252,28 @@ std::map<std::string, std::set<std::string>> RISCv64CodeGen::build_interference_
graph[arg_vregs[j]].insert(arg_vregs[i]);
}
}
} else if (auto memset = dynamic_cast<MemsetInst*>(inst)) {
// 规则MemsetInst 像一个函数调用,它会污染临时寄存器。
// 因此,所有跨越这条指令的活跃变量(live_out)
// 都应该与这条指令的操作数(use)互相冲突。
// 这会强制分配器将它们放入不同的寄存器中,或者安全地保存/恢复。
std::set<std::string> use_set;
for (const auto& operand_use : memset->getOperands()) {
Value* operand = operand_use->getValue();
if (value_vreg_map.count(operand)) {
use_set.insert(value_vreg_map.at(operand));
}
}
for (const auto& live_vreg : live_out_set) {
for (const auto& use_vreg : use_set) {
if (live_vreg != use_vreg) {
graph[live_vreg].insert(use_vreg);
graph[use_vreg].insert(live_vreg);
}
}
}
}
}
return graph;

View File

@ -32,7 +32,7 @@ public:
// Move DAGNode and RegAllocResult to public section
struct DAGNode {
enum NodeKind { CONSTANT, LOAD, STORE, BINARY, CALL, RETURN, BRANCH, ALLOCA_ADDR, UNARY };
enum NodeKind { CONSTANT, LOAD, STORE, BINARY, CALL, RETURN, BRANCH, ALLOCA_ADDR, UNARY, MEMSET };
NodeKind kind;
Value* value = nullptr; // For IR Value
std::string inst; // Generated RISC-V instruction(s) for this node
@ -53,6 +53,7 @@ public:
case BRANCH: return "BRANCH";
case ALLOCA_ADDR: return "ALLOCA_ADDR";
case UNARY: return "UNARY";
case MEMSET: return "MEMSET";
default: return "UNKNOWN";
}
}
@ -103,7 +104,9 @@ private:
// 为空标签定义一个伪名称前缀,加上块索引以确保唯一性
const std::string ENTRY_BLOCK_PSEUDO_NAME = "entry_block_";
int local_label_counter = 0; // 用于生成唯一的本地标签 (如 memset 循环, 匿名块跳转等)
// !!! 修改get_operand_node 辅助函数现在需要传入 value_to_node 和 nodes_storage 的引用
// 因为它们是 build_dag 局部管理的
DAGNode* get_operand_node(