Compare commits

...

14 Commits

6 changed files with 667 additions and 165 deletions

View File

@ -1,30 +1,10 @@
#include "RISCv32Backend.h"
#include <sstream>
#include <algorithm>
#include <stack>
namespace sysy {
// Registers listed as allocable: t0-t6 followed by a0-a7.
const std::vector<RISCv32CodeGen::PhysicalReg> RISCv32CodeGen::allocable_regs{
    // t0 - t6
    PhysicalReg::T0,
    PhysicalReg::T1,
    PhysicalReg::T2,
    PhysicalReg::T3,
    PhysicalReg::T4,
    PhysicalReg::T5,
    PhysicalReg::T6,
    // a0 - a7
    PhysicalReg::A0,
    PhysicalReg::A1,
    PhysicalReg::A2,
    PhysicalReg::A3,
    PhysicalReg::A4,
    PhysicalReg::A5,
    PhysicalReg::A6,
    PhysicalReg::A7,
};
/// Returns the assembly name of `reg` ("t0".."t6", "a0".."a7").
/// Any register not covered by the switch yields an empty string.
std::string RISCv32CodeGen::reg_to_string(PhysicalReg reg) {
    const char* name = "";
    switch (reg) {
        case PhysicalReg::T0: name = "t0"; break;
        case PhysicalReg::T1: name = "t1"; break;
        case PhysicalReg::T2: name = "t2"; break;
        case PhysicalReg::T3: name = "t3"; break;
        case PhysicalReg::T4: name = "t4"; break;
        case PhysicalReg::T5: name = "t5"; break;
        case PhysicalReg::T6: name = "t6"; break;
        case PhysicalReg::A0: name = "a0"; break;
        case PhysicalReg::A1: name = "a1"; break;
        case PhysicalReg::A2: name = "a2"; break;
        case PhysicalReg::A3: name = "a3"; break;
        case PhysicalReg::A4: name = "a4"; break;
        case PhysicalReg::A5: name = "a5"; break;
        case PhysicalReg::A6: name = "a6"; break;
        case PhysicalReg::A7: name = "a7"; break;
        default: break;
    }
    return name;
}
std::string RISCv32CodeGen::code_gen() {
std::stringstream ss;
ss << ".text\n";
@ -34,16 +14,22 @@ std::string RISCv32CodeGen::code_gen() {
std::string RISCv32CodeGen::module_gen() {
std::stringstream ss;
// 生成全局变量(数据段)
for (const auto& global : module->getGlobals()) {
ss << ".data\n";
ss << ".globl " << global->getName() << "\n";
for (auto& global : module->getGlobals()) {
ss << ".global " << global->getName() << "\n";
ss << ".section .data\n";
ss << ".align 2\n";
ss << global->getName() << ":\n";
ss << " .word 0\n"; // 假设初始化为0
for (auto value : global->getInitValues().getValues()) {
auto const_val = dynamic_cast<ConstantValue*>(value);
if (const_val->isInt()) {
ss << ".word " << const_val->getInt() << "\n";
} else {
ss << ".float " << const_val->getFloat() << "\n";
}
// 生成函数(文本段)
ss << ".text\n";
for (const auto& func : module->getFunctions()) {
}
}
ss << ".section .text\n";
for (auto& func : module->getFunctions()) {
ss << function_gen(func.second.get());
}
return ss.str();
@ -51,107 +37,582 @@ std::string RISCv32CodeGen::module_gen() {
std::string RISCv32CodeGen::function_gen(Function* func) {
std::stringstream ss;
// 函数标签
ss << ".globl " << func->getName() << "\n";
ss << ".global " << func->getName() << "\n";
ss << ".type " << func->getName() << ", @function\n";
ss << func->getName() << ":\n";
// 序言:保存 ra分配堆栈
bool is_leaf = true; // 简化假设
ss << " addi sp, sp, -16\n";
ss << " sw ra, 12(sp)\n";
// 寄存器分配
auto alloc = register_allocation(func);
// 生成基本块代码
for (const auto& bb : func->getBasicBlocks()) {
ss << basicBlock_gen(bb.get(), alloc);
// Perform register allocation
auto live_sets = liveness_analysis(func);
auto interference_graph = build_interference_graph(live_sets);
auto alloc = color_graph(func, interference_graph);
// Prologue: Adjust stack and save callee-saved registers
if (alloc.stack_size > 0) {
ss << " addi sp, sp, -" << alloc.stack_size << "\n";
ss << " sw ra, " << (alloc.stack_size - 4) << "(sp)\n";
}
for (auto preg : callee_saved) {
if (std::find_if(alloc.vreg_to_preg.begin(), alloc.vreg_to_preg.end(),
[preg](const auto& pair) { return pair.second == preg; }) != alloc.vreg_to_preg.end()) {
ss << " sw " << get_preg_str(preg) << ", " << (alloc.stack_size - 8) << "(sp)\n";
}
}
int block_idx = 0;
for (auto& bb : func->getBasicBlocks()) {
ss << basicBlock_gen(bb.get(), alloc, block_idx++);
}
// Epilogue: Restore callee-saved registers and stack
for (auto preg : callee_saved) {
if (std::find_if(alloc.vreg_to_preg.begin(), alloc.vreg_to_preg.end(),
[preg](const auto& pair) { return pair.second == preg; }) != alloc.vreg_to_preg.end()) {
ss << " lw " << get_preg_str(preg) << ", " << (alloc.stack_size - 8) << "(sp)\n";
}
}
if (alloc.stack_size > 0) {
ss << " lw ra, " << (alloc.stack_size - 4) << "(sp)\n";
ss << " addi sp, sp, " << alloc.stack_size << "\n";
}
// 结尾:恢复 ra释放堆栈
ss << " lw ra, 12(sp)\n";
ss << " addi sp, sp, 16\n";
ss << " ret\n";
return ss.str();
}
std::string RISCv32CodeGen::basicBlock_gen(BasicBlock* bb, const RegAllocResult& alloc) {
std::string RISCv32CodeGen::basicBlock_gen(BasicBlock* bb, const RegAllocResult& alloc, int block_idx) {
std::stringstream ss;
ss << bb->getName() << ":\n";
for (const auto& inst : bb->getInstructions()) {
auto riscv_insts = instruction_gen(inst.get());
for (const auto& riscv_inst : riscv_insts) {
ss << " " << riscv_inst.opcode;
for (size_t i = 0; i < riscv_inst.operands.size(); ++i) {
if (i > 0) ss << ", ";
if (riscv_inst.operands[i].kind == Operand::Kind::Reg) {
auto it = alloc.reg_map.find(riscv_inst.operands[i].value);
if (it != alloc.reg_map.end()) {
ss << reg_to_string(it->second);
} else {
auto stack_it = alloc.stack_map.find(riscv_inst.operands[i].value);
if (stack_it != alloc.stack_map.end()) {
ss << stack_it->second << "(sp)";
} else {
ss << "%" << riscv_inst.operands[i].value->getName();
}
}
} else if (riscv_inst.operands[i].kind == Operand::Kind::Imm) {
ss << riscv_inst.operands[i].label;
} else {
ss << riscv_inst.operands[i].label;
}
}
ss << "\n";
ss << ".L" << block_idx << ":\n";
auto dag_nodes = build_dag(bb);
for (auto& node : dag_nodes) {
select_instructions(node.get(), alloc);
}
std::set<DAGNode*> emitted_nodes;
for (auto& node : dag_nodes) {
emit_instructions(node.get(), ss, alloc, emitted_nodes);
}
return ss.str();
}
std::vector<RISCv32CodeGen::RISCv32Inst> RISCv32CodeGen::instruction_gen(Instruction* inst) {
std::vector<RISCv32Inst> insts;
if (auto bin = dynamic_cast<BinaryInst*>(inst)) {
std::string opcode;
if (bin->getKind() == BinaryInst::kAdd) opcode = "add";
else if (bin->getKind() == BinaryInst::kSub) opcode = "sub";
else if (bin->getKind() == BinaryInst::kMul) opcode = "mul";
else return insts; // 其他操作未实现
insts.emplace_back(opcode, std::vector<Operand>{
{Operand::Kind::Reg, bin},
{Operand::Kind::Reg, bin->getLhs()},
{Operand::Kind::Reg, bin->getRhs()}
});
} else if (auto load = dynamic_cast<LoadInst*>(inst)) {
insts.emplace_back("lw", std::vector<Operand>{
{Operand::Kind::Reg, load},
{Operand::Kind::Label, load->getPointer()->getName()}
});
} else if (auto store = dynamic_cast<StoreInst*>(inst)) {
insts.emplace_back("sw", std::vector<Operand>{
{Operand::Kind::Reg, store->getValue()},
{Operand::Kind::Label, store->getPointer()->getName()}
});
/// Builds the selection DAG for one basic block.
///
/// One node is created per alloca, load, store, binary, return and branch
/// instruction, in instruction order. Constant operands get their own
/// CONSTANT node, inserted immediately before their first user. Other
/// instruction kinds are ignored and produce no node.
///
/// Operand edges are only created for values that already have a node in
/// this block (or are constants); a value defined in another block leaves the
/// user with fewer operands than the IR instruction has.
///
/// @param bb  block to translate.
/// @return    owning list of nodes; `operands`/`users` point into this list.
std::vector<std::unique_ptr<RISCv32CodeGen::DAGNode>> RISCv32CodeGen::build_dag(BasicBlock* bb) {
    std::vector<std::unique_ptr<DAGNode>> nodes;
    std::map<Value*, DAGNode*> value_to_node;
    int const_counter = 0;

    // Connects `user` to the node that produces `operand`. When the value has
    // no node yet and `allow_constant` is set, a CONSTANT node is created for
    // it (and remembered so later users share it).
    const auto link_operand = [&](DAGNode* user, Value* operand, bool allow_constant) {
        DAGNode* producer = nullptr;
        if (auto found = value_to_node.find(operand); found != value_to_node.end()) {
            producer = found->second;
        } else if (allow_constant) {
            if (auto* const_val = dynamic_cast<ConstantValue*>(operand)) {
                auto const_node = std::make_unique<DAGNode>(DAGNode::CONSTANT);
                const_node->value = const_val;
                // A dedicated prefix keeps constant registers from colliding
                // with IR value names, which are themselves of the form %0, %1, ...
                const_node->result_vreg = "%const" + std::to_string(const_counter++);
                producer = const_node.get();
                value_to_node[operand] = producer;
                nodes.push_back(std::move(const_node));
            }
        }
        if (producer == nullptr) {
            return;
        }
        user->operands.push_back(producer);
        producer->users.push_back(user);
    };

    for (auto& inst : bb->getInstructions()) {
        if (auto alloca_inst = dynamic_cast<AllocaInst*>(inst.get())) {
            auto node = std::make_unique<DAGNode>(DAGNode::ALLOCA_ADDR);
            node->value = alloca_inst;
            node->result_vreg = "%" + inst->getName();  // IR name, e.g. %a(0)
            value_to_node[alloca_inst] = node.get();
            nodes.push_back(std::move(node));
        } else if (auto load = dynamic_cast<LoadInst*>(inst.get())) {
            auto node = std::make_unique<DAGNode>(DAGNode::LOAD);
            node->value = load;
            node->result_vreg = "%" + inst->getName();  // IR name, e.g. %0
            link_operand(node.get(), load->getPointer(), false);
            value_to_node[load] = node.get();
            nodes.push_back(std::move(node));
        } else if (auto store = dynamic_cast<StoreInst*>(inst.get())) {
            auto node = std::make_unique<DAGNode>(DAGNode::STORE);
            node->value = store;
            // Operand order matters to instruction selection: value first,
            // then the address.
            link_operand(node.get(), store->getValue(), true);
            link_operand(node.get(), store->getPointer(), false);
            nodes.push_back(std::move(node));
        } else if (auto binary = dynamic_cast<BinaryInst*>(inst.get())) {
            auto node = std::make_unique<DAGNode>(DAGNode::BINARY);
            node->value = binary;
            node->result_vreg = "%" + inst->getName();  // IR name, e.g. %2
            for (auto operand : binary->getOperands()) {
                link_operand(node.get(), operand->getValue(), true);
            }
            value_to_node[binary] = node.get();
            nodes.push_back(std::move(node));
        } else if (auto ret = dynamic_cast<ReturnInst*>(inst.get())) {
            auto node = std::make_unique<DAGNode>(DAGNode::RETURN);
            node->value = ret;
            if (ret->hasReturnValue()) {
                link_operand(node.get(), ret->getReturnValue(), true);
            }
            nodes.push_back(std::move(node));
        } else if (auto cond_br = dynamic_cast<CondBrInst*>(inst.get())) {
            auto node = std::make_unique<DAGNode>(DAGNode::BRANCH);
            node->value = cond_br;
            link_operand(node.get(), cond_br->getCondition(), true);
            nodes.push_back(std::move(node));
        } else if (auto uncond_br = dynamic_cast<UncondBrInst*>(inst.get())) {
            auto node = std::make_unique<DAGNode>(DAGNode::BRANCH);
            node->value = uncond_br;
            nodes.push_back(std::move(node));
        }
    }
    return nodes;
}
void RISCv32CodeGen::eliminate_phi(Function* func) {
// TODO: 实现 phi 指令消除
/// Chooses the assembly text for `node` and stores it in `node->inst`.
///
/// The text still refers to virtual registers (`result_vreg` names); they are
/// replaced by physical registers when the node is emitted. Nodes that already
/// carry text are left untouched. A node whose operands could not be linked
/// when the DAG was built gets a "# error: ..." comment instead of an
/// instruction, rather than indexing past the end of `operands`.
///
/// Sequences of more than one instruction are encoded as
/// "first\n second", the same shape the conditional branch uses.
///
/// @param node   DAG node to select for; ignored when null.
/// @param alloc  allocation result; only `stack_map` is read here.
void RISCv32CodeGen::select_instructions(DAGNode* node, const RegAllocResult& alloc) {
    if (node == nullptr || !node->inst.empty()) {
        return;
    }

    // ".L<i>" where i is the position of `target` in its function's block list
    // (0 when the block is not found).
    const auto label_of = [](auto* func, auto* target) {
        int found = 0;
        int index = 0;
        for (auto& bb : func->getBasicBlocks()) {
            if (bb.get() == target) {
                found = index;
            }
            ++index;
        }
        return ".L" + std::to_string(found);
    };

    switch (node->kind) {
        case DAGNode::CONSTANT: {
            auto* const_val = dynamic_cast<ConstantValue*>(node->value);
            if (const_val == nullptr) {
                node->inst = "# error: constant node without a constant value";
            } else if (const_val->isInt()) {
                node->inst = "li " + node->result_vreg + ", " + std::to_string(const_val->getInt());
            } else {
                node->inst = "# float constant not implemented";
            }
            break;
        }
        case DAGNode::LOAD: {
            auto* load = dynamic_cast<LoadInst*>(node->value);
            if (load == nullptr) {
                node->inst = "# error: load node without a load instruction";
                break;
            }
            auto pointer = load->getPointer();
            if (auto alloca_inst = dynamic_cast<AllocaInst*>(pointer)) {
                // No stack slot recorded: leave the text empty, nothing is emitted.
                if (alloc.stack_map.count(alloca_inst)) {
                    node->inst = "lw " + node->result_vreg + ", " +
                                 std::to_string(alloc.stack_map.at(alloca_inst)) + "(sp)";
                }
            } else if (auto global = dynamic_cast<GlobalValue*>(pointer)) {
                // NOTE(review): `symbol(gp)` addressing is kept as-is; confirm the
                // assembler/linker setup actually supports it for globals.
                node->inst = "lw " + node->result_vreg + ", " + global->getName() + "(gp)";
            }
            break;
        }
        case DAGNode::STORE: {
            auto* store = dynamic_cast<StoreInst*>(node->value);
            // operands[0] must be the stored value; if the value could not be
            // linked, operands[0] is missing or is the address node instead.
            if (store == nullptr || node->operands.empty() ||
                node->operands[0]->value != store->getValue()) {
                node->inst = "# error: store with unresolved value operand";
                break;
            }
            const std::string& value_vreg = node->operands[0]->result_vreg;
            auto pointer = store->getPointer();
            if (auto alloca_inst = dynamic_cast<AllocaInst*>(pointer)) {
                if (alloc.stack_map.count(alloca_inst)) {
                    node->inst = "sw " + value_vreg + ", " +
                                 std::to_string(alloc.stack_map.at(alloca_inst)) + "(sp)";
                }
            } else if (auto global = dynamic_cast<GlobalValue*>(pointer)) {
                // NOTE(review): see the matching note on global loads.
                node->inst = "sw " + value_vreg + ", " + global->getName() + "(gp)";
            }
            break;
        }
        case DAGNode::BINARY: {
            auto* binary = dynamic_cast<BinaryInst*>(node->value);
            if (binary == nullptr || node->operands.size() < 2) {
                node->inst = "# error: binary operation with unresolved operands";
                break;
            }
            const std::string& rd = node->result_vreg;
            const std::string& lhs = node->operands[0]->result_vreg;
            const std::string& rhs = node->operands[1]->result_vreg;
            const auto rrr = [&rd, &lhs, &rhs](const std::string& op) {
                return op + " " + rd + ", " + lhs + ", " + rhs;
            };
            // Comparisons without a single RISC-V instruction are built from
            // xor/slt/sgt followed by seqz/snez/xori on the result register.
            const std::string invert = "\n xori " + rd + ", " + rd + ", 1";
            switch (binary->getKind()) {
                case Instruction::kAdd: node->inst = rrr("add"); break;
                case Instruction::kSub: node->inst = rrr("sub"); break;
                case Instruction::kMul: node->inst = rrr("mul"); break;
                case Instruction::kDiv: node->inst = rrr("div"); break;
                case Instruction::kICmpEQ:
                    node->inst = rrr("xor") + "\n seqz " + rd + ", " + rd;
                    break;
                case Instruction::kICmpNE:
                    node->inst = rrr("xor") + "\n snez " + rd + ", " + rd;
                    break;
                case Instruction::kICmpLT: node->inst = rrr("slt"); break;
                case Instruction::kICmpGT: node->inst = rrr("sgt"); break;
                case Instruction::kICmpLE:  // !(lhs > rhs)
                    node->inst = rrr("sgt") + invert;
                    break;
                case Instruction::kICmpGE:  // !(lhs < rhs)
                    node->inst = rrr("slt") + invert;
                    break;
                default: node->inst = rrr("# unknown"); break;
            }
            break;
        }
        case DAGNode::RETURN: {
            auto* ret = dynamic_cast<ReturnInst*>(node->value);
            if (ret == nullptr) {
                node->inst = "# error: return node without a return instruction";
            } else if (!ret->hasReturnValue()) {
                node->inst = "ret";
            } else if (node->operands.empty()) {
                node->inst = "# error: return with unresolved value operand";
            } else {
                // Only the move into a0 is produced here; no jump or `ret`
                // follows it in this text.
                node->inst = "mv a0, " + node->operands[0]->result_vreg;
            }
            break;
        }
        case DAGNode::BRANCH: {
            if (auto cond_br = dynamic_cast<CondBrInst*>(node->value)) {
                if (node->operands.empty()) {
                    node->inst = "# error: branch with unresolved condition";
                    break;
                }
                const std::string& condition_vreg = node->operands[0]->result_vreg;
                const std::string then_label = label_of(cond_br->getFunction(), cond_br->getThenBlock());
                const std::string else_label = label_of(cond_br->getFunction(), cond_br->getElseBlock());
                node->inst = "bne " + condition_vreg + ", zero, " + then_label + "\n j " + else_label;
            } else if (auto uncond_br = dynamic_cast<UncondBrInst*>(node->value)) {
                node->inst = "j " + label_of(uncond_br->getFunction(), uncond_br->getBlock());
            }
            break;
        }
        default:
            node->inst = "# unimplemented";
            break;
    }
}
std::map<Instruction*, std::set<Value*>> RISCv32CodeGen::liveness_analysis(Function* func) {
std::map<Instruction*, std::set<Value*>> live_sets;
// TODO: 实现活跃性分析
return live_sets;
/// Writes the assembly for `node` (and, first, for every operand node that has
/// not been written yet) to `ss`, replacing virtual register names with the
/// physical registers chosen in `alloc`.
///
/// Spilled values go through scratch registers: a spilled result is computed
/// in t0 and stored to its slot after the instruction; spilled operands are
/// reloaded into t1 and t2 before it.
/// NOTE(review): this assumes t0-t2 are not holding live allocated values at
/// that point — confirm against the register allocator's palette.
///
/// @param node           node to emit; ignored when null or already emitted.
/// @param ss             output stream receiving the assembly text.
/// @param alloc          register/spill assignment for the enclosing function.
/// @param emitted_nodes  nodes already written; updated with `node`.
void RISCv32CodeGen::emit_instructions(DAGNode* node, std::stringstream& ss, const RegAllocResult& alloc, std::set<DAGNode*>& emitted_nodes) {
    if (node == nullptr || emitted_nodes.count(node)) {
        return;
    }
    for (auto* operand : node->operands) {
        emit_instructions(operand, ss, alloc, emitted_nodes);
    }
    emitted_nodes.insert(node);

    const bool emittable = !node->inst.empty() && node->inst != "# unimplemented" &&
                           node->inst.find("# alloca") == std::string::npos;
    if (!emittable) {
        return;
    }

    // True when `c` could extend a register name, i.e. the match found so far
    // is only a prefix of a longer name (%1 inside %10, %a inside %a(0)).
    const auto continues_name = [](char c) {
        return (c >= '0' && c <= '9') || (c >= 'a' && c <= 'z') || (c >= 'A' && c <= 'Z') ||
               c == '_' || c == '.' || c == '(';
    };
    // Replaces every whole-name occurrence of `vreg` in `text` with `preg`.
    const auto substitute = [&continues_name](std::string& text, const std::string& vreg,
                                              const std::string& preg) {
        std::size_t pos = text.find(vreg);
        while (pos != std::string::npos) {
            const std::size_t after = pos + vreg.length();
            if (after < text.size() && continues_name(text[after])) {
                pos = text.find(vreg, after);
                continue;
            }
            text.replace(pos, vreg.length(), preg);
            pos = text.find(vreg, pos + preg.length());
        }
    };

    std::string inst = node->inst;
    std::vector<std::pair<std::string, std::string>> replacements;
    const int* result_spill_slot = nullptr;  // set when the result must be stored back

    if (!node->result_vreg.empty() && node->kind != DAGNode::ALLOCA_ADDR) {
        const auto reg_it = alloc.vreg_to_preg.find(node->result_vreg);
        const auto spill_it = alloc.spill_map.find(node->result_vreg);
        if (reg_it != alloc.vreg_to_preg.end()) {
            replacements.emplace_back(node->result_vreg, get_preg_str(reg_it->second));
        } else if (spill_it != alloc.spill_map.end()) {
            replacements.emplace_back(node->result_vreg, get_preg_str(PhysicalReg::T0));
            result_spill_slot = &spill_it->second;
        } else {
            ss << "# Error: Virtual register " << node->result_vreg << " not allocated (kind: " << node->getNodeKindString() << ")\n";
        }
    }

    // Each distinct spilled operand gets its own scratch register so that two
    // spilled operands of one instruction do not overwrite each other.
    const PhysicalReg operand_scratch[] = {PhysicalReg::T1, PhysicalReg::T2};
    const std::size_t scratch_count = sizeof(operand_scratch) / sizeof(operand_scratch[0]);
    std::size_t scratch_used = 0;
    std::set<std::string> handled;
    for (auto* operand : node->operands) {
        const std::string& vreg = operand->result_vreg;
        if (vreg.empty() || operand->kind == DAGNode::ALLOCA_ADDR) {
            continue;
        }
        if (!handled.insert(vreg).second) {
            continue;  // same value used twice (e.g. x + x): one mapping is enough
        }
        const auto reg_it = alloc.vreg_to_preg.find(vreg);
        const auto spill_it = alloc.spill_map.find(vreg);
        if (reg_it != alloc.vreg_to_preg.end()) {
            replacements.emplace_back(vreg, get_preg_str(reg_it->second));
        } else if (spill_it != alloc.spill_map.end()) {
            const std::size_t slot = scratch_used < scratch_count ? scratch_used : scratch_count - 1;
            ++scratch_used;
            const PhysicalReg temp_reg = operand_scratch[slot];
            ss << " lw " << get_preg_str(temp_reg) << ", " << spill_it->second << "(sp)\n";
            replacements.emplace_back(vreg, get_preg_str(temp_reg));
        } else {
            ss << "# Error: Operand virtual register " << vreg << " not allocated (kind: " << operand->getNodeKindString() << ")\n";
        }
    }

    for (const auto& [vreg, preg] : replacements) {
        substitute(inst, vreg, preg);
    }

    // Branches and multi-line sequences are written verbatim; a bare "ret" is
    // not written here at all.
    if (node->kind == DAGNode::BRANCH || inst.find('\n') != std::string::npos) {
        ss << inst << "\n";
    } else if (inst != "ret") {
        ss << " " << inst << "\n";
    }

    if (result_spill_slot != nullptr) {
        ss << " sw " << get_preg_str(PhysicalReg::T0) << ", " << *result_spill_slot << "(sp)\n";
    }
}
std::map<Value*, std::set<Value*>> RISCv32CodeGen::build_interference_graph(
const std::map<Instruction*, std::set<Value*>>& live_sets) {
std::map<Value*, std::set<Value*>> graph;
// TODO: 实现干扰图构建
return graph;
/// Backward liveness analysis over the virtual registers of `func`.
///
/// For every instruction:
///   live_out = live_in of the next instruction in the same block, or, for the
///              last instruction of a block, the union of the live_in sets of
///              the first instruction of each successor block;
///   live_in  = use ∪ (live_out − def).
/// The equations are iterated until nothing changes.
///
/// def/use come from the IR instruction itself (allocas excluded) plus the
/// DAG nodes built for the block, so constant registers created by
/// build_dag are counted as uses of the instruction that consumes them.
///
/// @param func  function to analyse.
/// @return      live_out set per instruction, as virtual register names.
std::map<Instruction*, std::set<std::string>> RISCv32CodeGen::liveness_analysis(Function* func) {
    using VRegSet = std::set<std::string>;
    std::map<Instruction*, VRegSet> live_in, live_out, defs, uses;

    // def/use do not change between iterations, so compute them once.
    for (auto& bb : func->getBasicBlocks()) {
        const auto dag = build_dag(bb.get());
        for (auto& inst_ptr : bb->getInstructions()) {
            auto inst = inst_ptr.get();
            live_in[inst];
            live_out[inst];
            VRegSet& def = defs[inst];
            VRegSet& use = uses[inst];

            // Definition and uses visible on the IR instruction.
            if (inst->getName() != "" && !dynamic_cast<AllocaInst*>(inst)) {
                def.insert("%" + inst->getName());
            }
            for (auto operand : inst->getOperands()) {
                auto value = operand->getValue();
                if (auto op_inst = dynamic_cast<Instruction*>(value)) {
                    if (op_inst->getName() != "" && !dynamic_cast<AllocaInst*>(op_inst)) {
                        use.insert("%" + op_inst->getName());
                    }
                }
            }

            // Definition and uses visible on the DAG nodes of this instruction.
            for (const auto& node : dag) {
                if (node->value == inst && node->kind != DAGNode::ALLOCA_ADDR) {
                    if (node->result_vreg != "") {
                        def.insert(node->result_vreg);
                    }
                    for (auto operand : node->operands) {
                        if (operand->result_vreg != "" && operand->kind != DAGNode::ALLOCA_ADDR) {
                            use.insert(operand->result_vreg);
                        }
                    }
                }
                if (node->kind == DAGNode::CONSTANT) {
                    for (auto user : node->users) {
                        if (user->value == inst) {
                            use.insert(node->result_vreg);
                        }
                    }
                }
            }
        }
    }

    bool changed = true;
    while (changed) {
        changed = false;
        for (auto& bb : func->getBasicBlocks()) {
            auto&& insts = bb->getInstructions();
            Instruction* next_inst = nullptr;  // instruction that follows `inst` in this block
            for (auto it = insts.rbegin(); it != insts.rend(); ++it) {
                auto inst = it->get();

                VRegSet new_live_out;
                if (next_inst != nullptr) {
                    new_live_out = live_in[next_inst];
                } else {
                    // Last instruction of the block: look at the successors.
                    for (auto succ : bb->getSuccessors()) {
                        if (!succ->getInstructions().empty()) {
                            auto succ_inst = succ->getInstructions().front().get();
                            new_live_out.insert(live_in[succ_inst].begin(), live_in[succ_inst].end());
                        }
                    }
                }

                VRegSet new_live_in = uses[inst];
                const VRegSet& def = defs[inst];
                for (const auto& var : new_live_out) {
                    if (!def.count(var)) {
                        new_live_in.insert(var);
                    }
                }

                if (live_in[inst] != new_live_in || live_out[inst] != new_live_out) {
                    live_in[inst] = std::move(new_live_in);
                    live_out[inst] = std::move(new_live_out);
                    changed = true;
                }
                next_inst = inst;
            }
        }
    }

    // Debug: final sets, printed once after convergence.
    for (const auto& [inst, live_vars] : live_out) {
        std::cerr << "Instruction: " << (inst->getName() != "" ? "%" + inst->getName() : "none") << "\n";
        std::cerr << " def: "; for (const auto& d : defs[inst]) std::cerr << d << " "; std::cerr << "\n";
        std::cerr << " use: "; for (const auto& u : uses[inst]) std::cerr << u << " "; std::cerr << "\n";
        std::cerr << " live_in: "; for (const auto& v : live_in[inst]) std::cerr << v << " "; std::cerr << "\n";
        std::cerr << " live_out: "; for (const auto& v : live_vars) std::cerr << v << " "; std::cerr << "\n";
    }
    return live_out;
}
RISCv32CodeGen::RegAllocResult RISCv32CodeGen::register_allocation(Function* func) {
RegAllocResult result;
// TODO: 实现寄存器分配
return result;
/// Builds the interference graph from per-instruction live sets.
///
/// For each instruction, the register it defines (if any) is connected to
/// every register in its live set, and all registers in the live set are
/// connected to one another. Every register that appears gets a graph entry,
/// even when it has no neighbours.
///
/// @param live_sets  live set per instruction, as virtual register names.
/// @return           symmetric adjacency map: register -> interfering registers.
std::map<std::string, std::set<std::string>> RISCv32CodeGen::build_interference_graph(
    const std::map<Instruction*, std::set<std::string>>& live_sets) {
    std::map<std::string, std::set<std::string>> interference_graph;

    // NOTE(review): names beginning with "%a(" or "%b(" are left out of the
    // graph; this looks tied to specific variable names — confirm the intent.
    const auto is_excluded = [](const std::string& name) {
        return name.find("%a(") == 0 || name.find("%b(") == 0;
    };
    const auto add_edge = [&interference_graph](const std::string& a, const std::string& b) {
        interference_graph[a].insert(b);
        interference_graph[b].insert(a);
    };

    for (const auto& entry : live_sets) {
        Instruction* inst = entry.first;

        // Live registers that take part in the graph, in set order.
        std::vector<std::string> candidates;
        for (const auto& name : entry.second) {
            if (!is_excluded(name)) {
                candidates.push_back(name);
            }
        }

        // Register defined by this instruction; allocas define none.
        std::string def_var;
        if (inst->getName() != "" && !dynamic_cast<AllocaInst*>(inst)) {
            def_var = "%" + inst->getName();
        }

        if (def_var != "") {
            interference_graph[def_var];
            for (const auto& other : candidates) {
                if (other != def_var) {
                    add_edge(def_var, other);
                }
            }
        }

        for (const auto& name : candidates) {
            interference_graph[name];
        }

        for (std::size_t i = 0; i < candidates.size(); ++i) {
            for (std::size_t j = i + 1; j < candidates.size(); ++j) {
                add_edge(candidates[i], candidates[j]);
            }
        }
    }

    // Debug
    for (const auto& entry : interference_graph) {
        std::cerr << "Vreg " << entry.first << " interferes with: ";
        for (const auto& neighbor : entry.second) {
            std::cerr << neighbor << " ";
        }
        std::cerr << "\n";
    }
    return interference_graph;
}
/// Assigns physical registers to virtual registers by graph colouring and lays
/// out the stack frame.
///
/// Simplify: nodes with fewer neighbours than there are registers are pushed on
/// a stack and removed; when none qualifies, the node with the most neighbours
/// is spilled. Select: nodes are popped and given the first register not used
/// by an already-coloured neighbour; a node that cannot be coloured is spilled.
///
/// t0-t2 are the scratch registers used when spilled values are reloaded or
/// stored, so as soon as anything is spilled they are withheld from the
/// palette. Without spills all registers remain available.
///
/// Frame layout (offsets from sp): one 4-byte slot per AllocaInst, then one per
/// spilled register, then 8 bytes for ra and callee-saved storage, rounded up
/// to a multiple of 16.
///
/// @param func                function being allocated (scanned for allocas).
/// @param interference_graph  symmetric adjacency map of virtual registers.
/// @return                    register assignment, stack/spill slots, frame size.
RISCv32CodeGen::RegAllocResult RISCv32CodeGen::color_graph(Function* func, const std::map<std::string, std::set<std::string>>& interference_graph) {
    RegAllocResult alloc;
    std::map<std::string, std::set<std::string>> ig = interference_graph;
    std::stack<std::string> stack;
    std::set<std::string> spilled;

    // Available physical registers
    const std::vector<PhysicalReg> available_regs = {
        PhysicalReg::T0, PhysicalReg::T1, PhysicalReg::T2, PhysicalReg::T3, PhysicalReg::T4, PhysicalReg::T5, PhysicalReg::T6,
        PhysicalReg::S0, PhysicalReg::S1, PhysicalReg::S2, PhysicalReg::S3, PhysicalReg::S4, PhysicalReg::S5,
        PhysicalReg::S6, PhysicalReg::S7, PhysicalReg::S8, PhysicalReg::S9, PhysicalReg::S10, PhysicalReg::S11
    };
    // Scratch registers for spill code; see the function comment.
    const std::set<PhysicalReg> spill_scratch = {PhysicalReg::T0, PhysicalReg::T1, PhysicalReg::T2};

    // Simplify: push nodes with degree < number of registers
    while (!ig.empty()) {
        bool simplified = false;
        for (auto it = ig.begin(); it != ig.end();) {
            if (it->second.size() < available_regs.size()) {
                stack.push(it->first);
                for (auto& [vreg, neighbors] : ig) {
                    neighbors.erase(it->first);
                }
                it = ig.erase(it);
                simplified = true;
            } else {
                ++it;
            }
        }
        if (!simplified) {
            // Spill the node with the highest degree
            auto max_it = ig.begin();
            for (auto it = ig.begin(); it != ig.end(); ++it) {
                if (it->second.size() > max_it->second.size()) {
                    max_it = it;
                }
            }
            spilled.insert(max_it->first);
            for (auto& [vreg, neighbors] : ig) {
                neighbors.erase(max_it->first);
            }
            ig.erase(max_it);
        }
    }

    // Spills can only appear during colouring if the palette was reduced, and
    // the palette is only reduced when simplify already spilled something, so
    // deciding once here is enough.
    const bool reserve_scratch = !spilled.empty();

    // Select: assign colours (physical registers)
    while (!stack.empty()) {
        const std::string vreg = stack.top();
        stack.pop();

        std::set<PhysicalReg> used_colors;
        if (const auto node = interference_graph.find(vreg); node != interference_graph.end()) {
            for (const auto& neighbor : node->second) {
                if (const auto colored = alloc.vreg_to_preg.find(neighbor); colored != alloc.vreg_to_preg.end()) {
                    used_colors.insert(colored->second);
                }
            }
        }

        bool assigned = false;
        for (auto preg : available_regs) {
            if (reserve_scratch && spill_scratch.count(preg)) {
                continue;
            }
            if (!used_colors.count(preg)) {
                alloc.vreg_to_preg[vreg] = preg;
                assigned = true;
                break;
            }
        }
        if (!assigned) {
            spilled.insert(vreg);
        }
    }

    // Allocate stack space for AllocaInst and spilled virtual registers
    int stack_offset = 0;
    for (auto& bb : func->getBasicBlocks()) {
        for (auto& inst : bb->getInstructions()) {
            if (auto alloca_inst = dynamic_cast<AllocaInst*>(inst.get())) {
                alloc.stack_map[alloca_inst] = stack_offset;
                stack_offset += 4;  // 4 bytes per variable
            }
        }
    }
    for (const auto& vreg : spilled) {
        alloc.spill_map[vreg] = stack_offset;
        stack_offset += 4;
    }

    // Extra 8 bytes for ra and callee-saved storage; the calling convention
    // keeps sp 16-byte aligned, so round the frame up.
    constexpr int kStackAlignment = 16;
    alloc.stack_size = (stack_offset + 8 + kStackAlignment - 1) / kStackAlignment * kStackAlignment;

    // Debug output to verify register allocation
    for (const auto& [vreg, preg] : alloc.vreg_to_preg) {
        std::cerr << "Vreg " << vreg << " assigned to " << get_preg_str(preg) << "\n";
    }
    for (const auto& vreg : spilled) {
        std::cerr << "Vreg " << vreg << " spilled to stack offset " << alloc.spill_map.at(vreg) << "\n";
    }
    return alloc;
}
/// Looks up the physical register assigned to `vreg`.
/// Registers without an assignment (e.g. spilled ones) fall back to t0.
RISCv32CodeGen::PhysicalReg RISCv32CodeGen::get_preg_or_temp(const std::string& vreg, const RegAllocResult& alloc) const {
    const auto assigned = alloc.vreg_to_preg.find(vreg);
    return assigned != alloc.vreg_to_preg.end() ? assigned->second : PhysicalReg::T0;
}
} // namespace sysy

View File

@ -6,60 +6,96 @@
#include <vector>
#include <map>
#include <set>
#include <memory>
#include <iostream>
#include <functional>
#include <stack>
namespace sysy {
class RISCv32CodeGen {
public:
explicit RISCv32CodeGen(Module* mod) : module(mod) {}
std::string code_gen(); // 生成模块的汇编代码
// Integer registers, identified by their ABI names.
// NOTE(review): the enumerator order appears to follow the hardware numbering
// x0..x31 (zero = x0 ... t6 = x31) — confirm before relying on the numeric values.
enum class PhysicalReg {
ZERO, RA, SP, GP, TP, T0, T1, T2, S0, S1, A0, A1, A2, A3, A4, A5, A6, A7, S2, S3, S4, S5, S6, S7, S8, S9, S10, S11, T3, T4, T5, T6
};
// One node of the per-block instruction-selection DAG.
struct DAGNode {
    enum NodeKind { CONSTANT, LOAD, STORE, BINARY, CALL, RETURN, BRANCH, ALLOCA_ADDR };

    NodeKind kind;                   // what this node represents
    Value* value = nullptr;          // IR value the node was created from
    std::string inst;                // selected assembly text; empty until selection
    std::string result_vreg;         // virtual register holding the result; may be empty
    std::vector<DAGNode*> operands;  // nodes this node consumes
    std::vector<DAGNode*> users;     // nodes consuming this node

    DAGNode(NodeKind k) : kind(k) {}

    // Upper-case name of `kind`, or "UNKNOWN" for a value outside the enum.
    std::string getNodeKindString() const {
        static const char* const kKindNames[] = {
            "CONSTANT", "LOAD", "STORE", "BINARY", "CALL", "RETURN", "BRANCH", "ALLOCA_ADDR",
        };
        const auto index = static_cast<std::size_t>(kind);
        const std::size_t count = sizeof(kKindNames) / sizeof(kKindNames[0]);
        return index < count ? kKindNames[index] : "UNKNOWN";
    }
};
// Outcome of register allocation for one function.
struct RegAllocResult {
std::map<std::string, PhysicalReg> vreg_to_preg; // virtual register name -> assigned physical register
std::map<Value*, int> stack_map; // AllocaInst -> stack offset (used as an sp-relative offset)
std::map<std::string, int> spill_map; // spilled virtual register name -> stack offset
int stack_size = 0; // total stack frame size
};
RISCv32CodeGen(Module* mod) : module(mod) {}
std::string code_gen();
std::string module_gen();
std::string function_gen(Function* func);
std::string basicBlock_gen(BasicBlock* bb, const RegAllocResult& alloc, int block_idx);
std::vector<std::unique_ptr<DAGNode>> build_dag(BasicBlock* bb);
void select_instructions(DAGNode* node, const RegAllocResult& alloc);
void emit_instructions(DAGNode* node, std::stringstream& ss, const RegAllocResult& alloc, std::set<DAGNode*>& emitted_nodes);
std::map<Instruction*, std::set<std::string>> liveness_analysis(Function* func);
std::map<std::string, std::set<std::string>> build_interference_graph(
const std::map<Instruction*, std::set<std::string>>& live_sets);
RegAllocResult color_graph(Function* func, const std::map<std::string, std::set<std::string>>& interference_graph);
private:
Module* module;
// 物理寄存器
enum class PhysicalReg {
T0, T1, T2, T3, T4, T5, T6, // x5-x7, x28-x31
A0, A1, A2, A3, A4, A5, A6, A7 // x10-x17
};
static const std::vector<PhysicalReg> allocable_regs;
// 操作数
struct Operand {
enum class Kind { Reg, Imm, Label };
Kind kind;
Value* value; // 用于寄存器
std::string label; // 用于标签或立即数
Operand(Kind k, Value* v) : kind(k), value(v), label("") {}
Operand(Kind k, const std::string& l) : kind(k), value(nullptr), label(l) {}
std::map<PhysicalReg, std::string> preg_to_str = {
{PhysicalReg::ZERO, "zero"}, {PhysicalReg::RA, "ra"}, {PhysicalReg::SP, "sp"},
{PhysicalReg::GP, "gp"}, {PhysicalReg::TP, "tp"}, {PhysicalReg::T0, "t0"},
{PhysicalReg::T1, "t1"}, {PhysicalReg::T2, "t2"}, {PhysicalReg::S0, "s0"},
{PhysicalReg::S1, "s1"}, {PhysicalReg::A0, "a0"}, {PhysicalReg::A1, "a1"},
{PhysicalReg::A2, "a2"}, {PhysicalReg::A3, "a3"}, {PhysicalReg::A4, "a4"},
{PhysicalReg::A5, "a5"}, {PhysicalReg::A6, "a6"}, {PhysicalReg::A7, "a7"},
{PhysicalReg::S2, "s2"}, {PhysicalReg::S3, "s3"}, {PhysicalReg::S4, "s4"},
{PhysicalReg::S5, "s5"}, {PhysicalReg::S6, "s6"}, {PhysicalReg::S7, "s7"},
{PhysicalReg::S8, "s8"}, {PhysicalReg::S9, "s9"}, {PhysicalReg::S10, "s10"},
{PhysicalReg::S11, "s11"}, {PhysicalReg::T3, "t3"}, {PhysicalReg::T4, "t4"},
{PhysicalReg::T5, "t5"}, {PhysicalReg::T6, "t6"}
};
// RISC-V 指令
struct RISCv32Inst {
std::string opcode;
std::vector<Operand> operands;
RISCv32Inst(const std::string& op, const std::vector<Operand>& ops)
: opcode(op), operands(ops) {}
std::vector<PhysicalReg> caller_saved = {
PhysicalReg::T0, PhysicalReg::T1, PhysicalReg::T2, PhysicalReg::T3, PhysicalReg::T4, PhysicalReg::T5, PhysicalReg::T6,
PhysicalReg::A0, PhysicalReg::A1, PhysicalReg::A2, PhysicalReg::A3, PhysicalReg::A4, PhysicalReg::A5, PhysicalReg::A6, PhysicalReg::A7
};
// 寄存器分配结果
struct RegAllocResult {
std::map<Value*, PhysicalReg> reg_map; // 虚拟寄存器到物理寄存器的映射
std::map<Value*, int> stack_map; // 虚拟寄存器到堆栈槽的映射
int stack_size; // 堆栈帧大小
std::vector<PhysicalReg> callee_saved = {
PhysicalReg::S0, PhysicalReg::S1, PhysicalReg::S2, PhysicalReg::S3, PhysicalReg::S4, PhysicalReg::S5,
PhysicalReg::S6, PhysicalReg::S7, PhysicalReg::S8, PhysicalReg::S9, PhysicalReg::S10, PhysicalReg::S11
};
// 后端方法
std::string module_gen();
std::string function_gen(Function* func);
std::string basicBlock_gen(BasicBlock* bb, const RegAllocResult& alloc);
std::vector<RISCv32Inst> instruction_gen(Instruction* inst);
RegAllocResult register_allocation(Function* func);
void eliminate_phi(Function* func);
std::map<Instruction*, std::set<Value*>> liveness_analysis(Function* func);
std::map<Value*, std::set<Value*>> build_interference_graph(
const std::map<Instruction*, std::set<Value*>>& live_sets);
std::string reg_to_string(PhysicalReg reg);
// Returns the assembly name of `preg` as recorded in preg_to_str.
// Uses map::at, so a register missing from the table throws std::out_of_range.
std::string get_preg_str(PhysicalReg preg) const {
return preg_to_str.at(preg);
}
PhysicalReg get_preg_or_temp(const std::string& vreg, const RegAllocResult& alloc) const;
};
} // namespace sysy

View File

@ -10,6 +10,7 @@ using namespace antlr4;
#include "SysYIRGenerator.h"
#include "SysYIRPrinter.h"
#include "SysYIROptPre.h"
#include "RISCv32Backend.h"
// #include "LLVMIRGenerator.h"
using namespace sysy;
@ -73,18 +74,26 @@ int main(int argc, char **argv) {
// visit AST to generate IR
if (argStopAfter == "ir") {
SysYIRGenerator generator;
generator.visitCompUnit(moduleAST);
if (argStopAfter == "ir") {
auto moduleIR = generator.get();
SysYPrinter printer(moduleIR);
printer.printIR();
auto builder = generator.getBuilder();
SysYOptPre optPre(moduleIR, builder);
optPre.SysYOptimizateAfterIR();
printer.printIR();
return EXIT_SUCCESS;
}
// generate assembly
auto module = generator.get();
sysy::RISCv32CodeGen codegen(module);
string asmCode = codegen.code_gen();
if (argStopAfter == "asm") {
cout << asmCode << endl;
return EXIT_SUCCESS;
}
return EXIT_SUCCESS;
}

View File

@ -1,12 +1,8 @@
//test add
int main(){
int a, b;
float d;
a = 10;
b = 2;
int c = a;
d = 1.1 ;
return a + b + c;
return a + b;
}

View File

@ -5,10 +5,10 @@ int main() {
const int b = 2;
int c;
if (a == b)
c = a + b;
if (a != b)
c = b - a + 20; // 21 <- this
else
c = a * b;
c = a * b + b + b + 10; // 16
return c;
}

View File

@ -7,7 +7,7 @@ int mul(int x, int y) {
int main(){
int a, b;
a = 10;
b = 0;
a = mul(a, b);
return a + b;
b = 3;
a = mul(a, b); //60
return a + b; //66
}