feat: complete Lab3 instruction selection and assembly generation
This commit is contained in:
@@ -1,7 +1,11 @@
|
||||
#include "mir/MIR.h"
|
||||
#include "ir/IR.h"
|
||||
|
||||
#include <ostream>
|
||||
#include <stdexcept>
|
||||
#include <cstdint>
|
||||
#include <vector>
|
||||
#include <cstring>
|
||||
|
||||
#include "utils/Log.h"
|
||||
|
||||
@@ -16,10 +20,34 @@ const FrameSlot& GetFrameSlot(const MachineFunction& function,
|
||||
return function.GetFrameSlot(operand.GetFrameIndex());
|
||||
}
|
||||
|
||||
bool IsFloatReg(PhysReg reg) {
|
||||
return reg >= PhysReg::S0 && reg <= PhysReg::S15;
|
||||
}
|
||||
|
||||
void PrintStackAccess(std::ostream& os, const char* mnemonic, PhysReg reg,
|
||||
int offset) {
|
||||
os << " " << mnemonic << " " << PhysRegName(reg) << ", [x29, #" << offset
|
||||
<< "]\n";
|
||||
bool is_float = IsFloatReg(reg);
|
||||
const char* ldr_cmd = is_float ? "ldr" : "ldr";
|
||||
const char* str_cmd = is_float ? "str" : "str";
|
||||
const char* base_mnemonic = (std::strcmp(mnemonic, "ldur") == 0) ? ldr_cmd : str_cmd;
|
||||
|
||||
if (offset >= -256 && offset <= 255) {
|
||||
if (is_float) {
|
||||
os << " " << base_mnemonic << " " << PhysRegName(reg) << ", [x29, #" << offset << "]\n";
|
||||
} else {
|
||||
os << " " << mnemonic << " " << PhysRegName(reg) << ", [x29, #" << offset << "]\n";
|
||||
}
|
||||
} else {
|
||||
os << " mov x10, #" << offset << "\n";
|
||||
os << " " << base_mnemonic << " " << PhysRegName(reg) << ", [x29, x10]\n";
|
||||
}
|
||||
}
|
||||
|
||||
std::string GetBlockLabel(const std::string& func_name, const std::string& block_name) {
|
||||
if (block_name == "entry" || block_name.empty()) {
|
||||
return func_name;
|
||||
}
|
||||
return ".L_" + func_name + "_" + block_name;
|
||||
}
|
||||
|
||||
} // namespace
|
||||
@@ -28,51 +56,269 @@ void PrintAsm(const MachineFunction& function, std::ostream& os) {
|
||||
os << ".text\n";
|
||||
os << ".global " << function.GetName() << "\n";
|
||||
os << ".type " << function.GetName() << ", %function\n";
|
||||
os << function.GetName() << ":\n";
|
||||
|
||||
for (const auto& inst : function.GetEntry().GetInstructions()) {
|
||||
const auto& ops = inst.GetOperands();
|
||||
switch (inst.GetOpcode()) {
|
||||
case Opcode::Prologue:
|
||||
os << " stp x29, x30, [sp, #-16]!\n";
|
||||
os << " mov x29, sp\n";
|
||||
if (function.GetFrameSize() > 0) {
|
||||
os << " sub sp, sp, #" << function.GetFrameSize() << "\n";
|
||||
struct FloatConstant {
|
||||
std::string label;
|
||||
float value;
|
||||
};
|
||||
std::vector<FloatConstant> float_constants;
|
||||
|
||||
for (size_t b = 0; b < function.GetBlocks().size(); ++b) {
|
||||
const auto& block = function.GetBlocks()[b];
|
||||
|
||||
// Print the block label
|
||||
if (b == 0) {
|
||||
os << function.GetName() << ":\n";
|
||||
} else {
|
||||
os << GetBlockLabel(function.GetName(), block.GetName()) << ":\n";
|
||||
}
|
||||
|
||||
for (const auto& inst : block.GetInstructions()) {
|
||||
const auto& ops = inst.GetOperands();
|
||||
switch (inst.GetOpcode()) {
|
||||
case Opcode::Prologue:
|
||||
os << " stp x29, x30, [sp, #-16]!\n";
|
||||
os << " mov x29, sp\n";
|
||||
if (function.GetFrameSize() > 0) {
|
||||
os << " sub sp, sp, #" << function.GetFrameSize() << "\n";
|
||||
}
|
||||
break;
|
||||
case Opcode::Epilogue:
|
||||
if (function.GetFrameSize() > 0) {
|
||||
os << " add sp, sp, #" << function.GetFrameSize() << "\n";
|
||||
}
|
||||
os << " ldp x29, x30, [sp], #16\n";
|
||||
break;
|
||||
case Opcode::MovImm: {
|
||||
PhysReg dst = ops.at(0).GetReg();
|
||||
if (IsFloatReg(dst)) {
|
||||
// Load float constant
|
||||
int bits = ops.at(1).GetImm();
|
||||
float val;
|
||||
std::memcpy(&val, &bits, sizeof(float));
|
||||
std::string flabel = ".LC_" + function.GetName() + "_" + std::to_string(float_constants.size());
|
||||
float_constants.push_back({flabel, val});
|
||||
|
||||
os << " adrp x8, " << flabel << "\n";
|
||||
os << " ldr " << PhysRegName(dst) << ", [x8, :lo12:" << flabel << "]\n";
|
||||
} else {
|
||||
os << " mov " << PhysRegName(dst) << ", #" << ops.at(1).GetImm() << "\n";
|
||||
}
|
||||
break;
|
||||
}
|
||||
break;
|
||||
case Opcode::Epilogue:
|
||||
if (function.GetFrameSize() > 0) {
|
||||
os << " add sp, sp, #" << function.GetFrameSize() << "\n";
|
||||
case Opcode::LoadStack: {
|
||||
const auto& slot = GetFrameSlot(function, ops.at(1));
|
||||
PrintStackAccess(os, "ldur", ops.at(0).GetReg(), slot.offset);
|
||||
break;
|
||||
}
|
||||
os << " ldp x29, x30, [sp], #16\n";
|
||||
break;
|
||||
case Opcode::MovImm:
|
||||
os << " mov " << PhysRegName(ops.at(0).GetReg()) << ", #"
|
||||
<< ops.at(1).GetImm() << "\n";
|
||||
break;
|
||||
case Opcode::LoadStack: {
|
||||
const auto& slot = GetFrameSlot(function, ops.at(1));
|
||||
PrintStackAccess(os, "ldur", ops.at(0).GetReg(), slot.offset);
|
||||
break;
|
||||
case Opcode::StoreStack: {
|
||||
const auto& slot = GetFrameSlot(function, ops.at(1));
|
||||
PrintStackAccess(os, "stur", ops.at(0).GetReg(), slot.offset);
|
||||
break;
|
||||
}
|
||||
case Opcode::AddRR:
|
||||
os << " add " << PhysRegName(ops.at(0).GetReg()) << ", "
|
||||
<< PhysRegName(ops.at(1).GetReg()) << ", "
|
||||
<< PhysRegName(ops.at(2).GetReg()) << "\n";
|
||||
break;
|
||||
case Opcode::SubRR:
|
||||
os << " sub " << PhysRegName(ops.at(0).GetReg()) << ", "
|
||||
<< PhysRegName(ops.at(1).GetReg()) << ", "
|
||||
<< PhysRegName(ops.at(2).GetReg()) << "\n";
|
||||
break;
|
||||
case Opcode::MulRR:
|
||||
os << " mul " << PhysRegName(ops.at(0).GetReg()) << ", "
|
||||
<< PhysRegName(ops.at(1).GetReg()) << ", "
|
||||
<< PhysRegName(ops.at(2).GetReg()) << "\n";
|
||||
break;
|
||||
case Opcode::SDivRR:
|
||||
os << " sdiv " << PhysRegName(ops.at(0).GetReg()) << ", "
|
||||
<< PhysRegName(ops.at(1).GetReg()) << ", "
|
||||
<< PhysRegName(ops.at(2).GetReg()) << "\n";
|
||||
break;
|
||||
case Opcode::MSubRRRR:
|
||||
os << " msub " << PhysRegName(ops.at(0).GetReg()) << ", "
|
||||
<< PhysRegName(ops.at(1).GetReg()) << ", "
|
||||
<< PhysRegName(ops.at(2).GetReg()) << ", "
|
||||
<< PhysRegName(ops.at(3).GetReg()) << "\n";
|
||||
break;
|
||||
case Opcode::FAddRRR:
|
||||
os << " fadd " << PhysRegName(ops.at(0).GetReg()) << ", "
|
||||
<< PhysRegName(ops.at(1).GetReg()) << ", "
|
||||
<< PhysRegName(ops.at(2).GetReg()) << "\n";
|
||||
break;
|
||||
case Opcode::FSubRRR:
|
||||
os << " fsub " << PhysRegName(ops.at(0).GetReg()) << ", "
|
||||
<< PhysRegName(ops.at(1).GetReg()) << ", "
|
||||
<< PhysRegName(ops.at(2).GetReg()) << "\n";
|
||||
break;
|
||||
case Opcode::FMulRRR:
|
||||
os << " fmul " << PhysRegName(ops.at(0).GetReg()) << ", "
|
||||
<< PhysRegName(ops.at(1).GetReg()) << ", "
|
||||
<< PhysRegName(ops.at(2).GetReg()) << "\n";
|
||||
break;
|
||||
case Opcode::FDivRRR:
|
||||
os << " fdiv " << PhysRegName(ops.at(0).GetReg()) << ", "
|
||||
<< PhysRegName(ops.at(1).GetReg()) << ", "
|
||||
<< PhysRegName(ops.at(2).GetReg()) << "\n";
|
||||
break;
|
||||
case Opcode::CmpRR:
|
||||
os << " cmp " << PhysRegName(ops.at(0).GetReg()) << ", "
|
||||
<< PhysRegName(ops.at(1).GetReg()) << "\n";
|
||||
break;
|
||||
case Opcode::FCmpRR:
|
||||
os << " fcmp " << PhysRegName(ops.at(0).GetReg()) << ", "
|
||||
<< PhysRegName(ops.at(1).GetReg()) << "\n";
|
||||
break;
|
||||
case Opcode::Cset:
|
||||
os << " cset " << PhysRegName(ops.at(0).GetReg()) << ", "
|
||||
<< ops.at(1).GetCondCode() << "\n";
|
||||
break;
|
||||
case Opcode::B:
|
||||
os << " b " << GetBlockLabel(function.GetName(), ops.at(0).GetLabelName()) << "\n";
|
||||
break;
|
||||
case Opcode::BCond:
|
||||
os << " b." << ops.at(0).GetCondCode() << " "
|
||||
<< GetBlockLabel(function.GetName(), ops.at(1).GetLabelName()) << "\n";
|
||||
break;
|
||||
case Opcode::Call:
|
||||
os << " bl " << ops.at(0).GetGlobalName() << "\n";
|
||||
break;
|
||||
case Opcode::Ret:
|
||||
os << " ret\n";
|
||||
break;
|
||||
case Opcode::MovReg:
|
||||
if (IsFloatReg(ops.at(0).GetReg()) || IsFloatReg(ops.at(1).GetReg())) {
|
||||
os << " fmov " << PhysRegName(ops.at(0).GetReg()) << ", "
|
||||
<< PhysRegName(ops.at(1).GetReg()) << "\n";
|
||||
} else {
|
||||
os << " mov " << PhysRegName(ops.at(0).GetReg()) << ", "
|
||||
<< PhysRegName(ops.at(1).GetReg()) << "\n";
|
||||
}
|
||||
break;
|
||||
case Opcode::Adrp:
|
||||
os << " adrp " << PhysRegName(ops.at(0).GetReg()) << ", "
|
||||
<< ops.at(1).GetGlobalName() << "\n";
|
||||
break;
|
||||
case Opcode::AddRegImm: {
|
||||
os << " add " << PhysRegName(ops.at(0).GetReg()) << ", "
|
||||
<< PhysRegName(ops.at(1).GetReg()) << ", ";
|
||||
if (ops.at(2).GetKind() == Operand::Kind::FrameIndex) {
|
||||
const auto& slot = function.GetFrameSlot(ops.at(2).GetFrameIndex());
|
||||
os << "#" << slot.offset << "\n";
|
||||
} else if (ops.at(2).GetKind() == Operand::Kind::Global) {
|
||||
os << ":lo12:" << ops.at(2).GetGlobalName() << "\n";
|
||||
} else {
|
||||
os << "#" << ops.at(2).GetImm() << "\n";
|
||||
}
|
||||
break;
|
||||
}
|
||||
case Opcode::LdrRegReg: {
|
||||
PhysReg reg = ops.at(0).GetReg();
|
||||
const char* ldr_cmd = IsFloatReg(reg) ? "ldr" : "ldr";
|
||||
os << " " << ldr_cmd << " " << PhysRegName(reg) << ", ["
|
||||
<< PhysRegName(ops.at(1).GetReg()) << "]\n";
|
||||
break;
|
||||
}
|
||||
case Opcode::StrRegReg: {
|
||||
PhysReg reg = ops.at(0).GetReg();
|
||||
const char* str_cmd = IsFloatReg(reg) ? "str" : "str";
|
||||
os << " " << str_cmd << " " << PhysRegName(reg) << ", ["
|
||||
<< PhysRegName(ops.at(1).GetReg()) << "]\n";
|
||||
break;
|
||||
}
|
||||
case Opcode::SIToFP:
|
||||
os << " scvtf " << PhysRegName(ops.at(0).GetReg()) << ", "
|
||||
<< PhysRegName(ops.at(1).GetReg()) << "\n";
|
||||
break;
|
||||
case Opcode::FPToSI:
|
||||
os << " fcvtzs " << PhysRegName(ops.at(0).GetReg()) << ", "
|
||||
<< PhysRegName(ops.at(1).GetReg()) << "\n";
|
||||
break;
|
||||
case Opcode::ZExt:
|
||||
if (ops.at(0).GetReg() >= PhysReg::X0 && ops.at(0).GetReg() <= PhysReg::X28) {
|
||||
os << " sxtw " << PhysRegName(ops.at(0).GetReg()) << ", " << PhysRegName(ops.at(1).GetReg()) << "\n";
|
||||
} else {
|
||||
os << " and " << PhysRegName(ops.at(0).GetReg()) << ", " << PhysRegName(ops.at(1).GetReg()) << ", #1\n";
|
||||
}
|
||||
break;
|
||||
}
|
||||
case Opcode::StoreStack: {
|
||||
const auto& slot = GetFrameSlot(function, ops.at(1));
|
||||
PrintStackAccess(os, "stur", ops.at(0).GetReg(), slot.offset);
|
||||
break;
|
||||
}
|
||||
case Opcode::AddRR:
|
||||
os << " add " << PhysRegName(ops.at(0).GetReg()) << ", "
|
||||
<< PhysRegName(ops.at(1).GetReg()) << ", "
|
||||
<< PhysRegName(ops.at(2).GetReg()) << "\n";
|
||||
break;
|
||||
case Opcode::Ret:
|
||||
os << " ret\n";
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
os << ".size " << function.GetName() << ", .-" << function.GetName()
|
||||
<< "\n";
|
||||
os << ".size " << function.GetName() << ", .-" << function.GetName() << "\n";
|
||||
|
||||
// Print read-only data segment if there are float constants
|
||||
if (!float_constants.empty()) {
|
||||
os << ".section .rodata\n";
|
||||
os << ".align 2\n";
|
||||
for (const auto& fc : float_constants) {
|
||||
os << fc.label << ":\n";
|
||||
uint32_t bits;
|
||||
std::memcpy(&bits, &fc.value, sizeof(float));
|
||||
os << " .word " << bits << " // float " << fc.value << "\n";
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
static uint32_t GetTypeSize(const ir::Type* type) {
|
||||
if (type->IsInt32() || type->IsFloat()) {
|
||||
return 4;
|
||||
}
|
||||
if (type->IsPtrInt32() || type->IsPtrFloat()) {
|
||||
return 8;
|
||||
}
|
||||
if (type->IsArray()) {
|
||||
auto* arr_ty = const_cast<ir::Type*>(type)->GetAsArrayType().get();
|
||||
return arr_ty->GetNumElements() * GetTypeSize(arr_ty->GetElementType().get());
|
||||
}
|
||||
return 4;
|
||||
}
|
||||
|
||||
void PrintGlobals(const ir::Module& module, std::ostream& os) {
|
||||
for (const auto& gv : module.GetGlobalValues()) {
|
||||
os << ".global " << gv->GetName() << "\n";
|
||||
|
||||
std::shared_ptr<ir::Type> actual_ty = gv->GetType();
|
||||
if (actual_ty->IsPtrInt32()) actual_ty = ir::Type::GetInt32Type();
|
||||
else if (actual_ty->IsPtrFloat()) actual_ty = ir::Type::GetFloatType();
|
||||
|
||||
uint32_t actual_size = GetTypeSize(actual_ty.get());
|
||||
|
||||
if (gv->GetInitializer()) {
|
||||
os << ".data\n";
|
||||
os << ".align 2\n";
|
||||
os << ".size " << gv->GetName() << ", " << actual_size << "\n";
|
||||
os << gv->GetName() << ":\n";
|
||||
|
||||
if (actual_ty->IsFloat()) {
|
||||
float val = 0.0f;
|
||||
if (auto* cf = dynamic_cast<const ir::ConstantFloat*>(gv->GetInitializer())) {
|
||||
val = cf->GetValue();
|
||||
} else if (auto* ci = dynamic_cast<const ir::ConstantInt*>(gv->GetInitializer())) {
|
||||
val = static_cast<float>(ci->GetValue());
|
||||
}
|
||||
uint32_t bits;
|
||||
std::memcpy(&bits, &val, sizeof(float));
|
||||
os << " .word " << bits << " // float " << val << "\n";
|
||||
} else {
|
||||
int val = 0;
|
||||
if (auto* ci = dynamic_cast<const ir::ConstantInt*>(gv->GetInitializer())) {
|
||||
val = ci->GetValue();
|
||||
} else if (auto* cf = dynamic_cast<const ir::ConstantFloat*>(gv->GetInitializer())) {
|
||||
val = static_cast<int>(cf->GetValue());
|
||||
}
|
||||
os << " .word " << val << "\n";
|
||||
}
|
||||
} else {
|
||||
os << ".bss\n";
|
||||
os << ".align 4\n";
|
||||
os << ".size " << gv->GetName() << ", " << actual_size << "\n";
|
||||
os << gv->GetName() << ":\n";
|
||||
os << " .zero " << actual_size << "\n";
|
||||
}
|
||||
os << "\n";
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace mir
|
||||
|
||||
@@ -18,10 +18,10 @@ void RunFrameLowering(MachineFunction& function) {
|
||||
int cursor = 0;
|
||||
for (const auto& slot : function.GetFrameSlots()) {
|
||||
cursor += slot.size;
|
||||
if (-cursor < -256) {
|
||||
throw std::runtime_error(FormatError("mir", "暂不支持过大的栈帧"));
|
||||
}
|
||||
}
|
||||
|
||||
// Align stack frames to 16 bytes for AArch64
|
||||
cursor = AlignTo(cursor, 16);
|
||||
|
||||
cursor = 0;
|
||||
for (const auto& slot : function.GetFrameSlots()) {
|
||||
@@ -30,16 +30,25 @@ void RunFrameLowering(MachineFunction& function) {
|
||||
}
|
||||
function.SetFrameSize(AlignTo(cursor, 16));
|
||||
|
||||
auto& insts = function.GetEntry().GetInstructions();
|
||||
std::vector<MachineInstr> lowered;
|
||||
lowered.emplace_back(Opcode::Prologue);
|
||||
for (const auto& inst : insts) {
|
||||
if (inst.GetOpcode() == Opcode::Ret) {
|
||||
lowered.emplace_back(Opcode::Epilogue);
|
||||
auto& blocks = function.GetBlocks();
|
||||
if (blocks.empty()) return;
|
||||
|
||||
// Insert Prologue at the start of the first block
|
||||
auto& entry_insts = blocks.front().GetInstructions();
|
||||
entry_insts.insert(entry_insts.begin(), MachineInstr(Opcode::Prologue));
|
||||
|
||||
// Insert Epilogue before every Ret in all blocks
|
||||
for (auto& block : blocks) {
|
||||
auto& insts = block.GetInstructions();
|
||||
std::vector<MachineInstr> lowered;
|
||||
for (const auto& inst : insts) {
|
||||
if (inst.GetOpcode() == Opcode::Ret) {
|
||||
lowered.emplace_back(Opcode::Epilogue);
|
||||
}
|
||||
lowered.push_back(inst);
|
||||
}
|
||||
lowered.push_back(inst);
|
||||
insts = std::move(lowered);
|
||||
}
|
||||
insts = std::move(lowered);
|
||||
}
|
||||
|
||||
} // namespace mir
|
||||
|
||||
@@ -2,122 +2,467 @@
|
||||
|
||||
#include <stdexcept>
|
||||
#include <unordered_map>
|
||||
#include <vector>
|
||||
#include <cstring>
|
||||
|
||||
#include "ir/IR.h"
|
||||
#include "utils/Log.h"
|
||||
#include <iostream>
|
||||
|
||||
namespace mir {
|
||||
namespace {
|
||||
|
||||
using ValueSlotMap = std::unordered_map<const ir::Value*, int>;
|
||||
|
||||
uint32_t GetTypeSize(const ir::Type* type) {
|
||||
if (type->IsInt32() || type->IsFloat()) {
|
||||
return 4;
|
||||
}
|
||||
if (type->IsPtrInt32() || type->IsPtrFloat()) {
|
||||
return 8; // 64-bit pointers
|
||||
}
|
||||
if (type->IsArray()) {
|
||||
auto* arr_ty = const_cast<ir::Type*>(type)->GetAsArrayType().get();
|
||||
return arr_ty->GetNumElements() * GetTypeSize(arr_ty->GetElementType().get());
|
||||
}
|
||||
return 4;
|
||||
}
|
||||
|
||||
uint32_t GetAllocaSize(const ir::Instruction& inst) {
|
||||
auto type = inst.GetType();
|
||||
if (type->IsPtrInt32() || type->IsPtrFloat()) {
|
||||
return 4;
|
||||
}
|
||||
return GetTypeSize(type.get());
|
||||
}
|
||||
|
||||
std::vector<uint32_t> GetGepStrides(const ir::GetElementPtrInst& gep) {
|
||||
std::vector<uint32_t> strides;
|
||||
auto curr_type = gep.GetPtr()->GetType();
|
||||
if (curr_type->IsPtrInt32() || curr_type->IsPtrFloat()) {
|
||||
strides.push_back(4);
|
||||
} else if (curr_type->IsArray()) {
|
||||
strides.push_back(GetTypeSize(curr_type.get()));
|
||||
for (size_t i = 2; i < gep.GetNumOperands(); ++i) {
|
||||
curr_type = curr_type->GetAsArrayType()->GetElementType();
|
||||
strides.push_back(GetTypeSize(curr_type.get()));
|
||||
}
|
||||
}
|
||||
return strides;
|
||||
}
|
||||
|
||||
void EmitAddressToReg(const ir::Value* value, PhysReg target,
|
||||
const ValueSlotMap& slots, MachineBasicBlock& block) {
|
||||
if (auto* alloca = dynamic_cast<const ir::Instruction*>(value)) {
|
||||
if (alloca->GetOpcode() == ir::Opcode::Alloca) {
|
||||
auto it = slots.find(value);
|
||||
if (it == slots.end()) {
|
||||
throw std::runtime_error(FormatError("mir", "找不到局部变量的栈槽: " + value->GetName()));
|
||||
}
|
||||
block.Append(Opcode::AddRegImm, {Operand::Reg(target), Operand::Reg(PhysReg::X29), Operand::FrameIndex(it->second)});
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
if (value->IsGlobalValue()) {
|
||||
block.Append(Opcode::Adrp, {Operand::Reg(target), Operand::Global(value->GetName())});
|
||||
block.Append(Opcode::AddRegImm, {Operand::Reg(target), Operand::Reg(target), Operand::Global(value->GetName())});
|
||||
return;
|
||||
}
|
||||
|
||||
// Otherwise, the address itself is stored in a stack slot
|
||||
auto it = slots.find(value);
|
||||
if (it == slots.end()) {
|
||||
throw std::runtime_error(FormatError("mir", "找不到指针的值槽: " + value->GetName()));
|
||||
}
|
||||
block.Append(Opcode::LoadStack, {Operand::Reg(target), Operand::FrameIndex(it->second)});
|
||||
}
|
||||
|
||||
void EmitValueToReg(const ir::Value* value, PhysReg target,
|
||||
const ValueSlotMap& slots, MachineBasicBlock& block) {
|
||||
if (auto* constant = dynamic_cast<const ir::ConstantInt*>(value)) {
|
||||
block.Append(Opcode::MovImm,
|
||||
{Operand::Reg(target), Operand::Imm(constant->GetValue())});
|
||||
block.Append(Opcode::MovImm, {Operand::Reg(target), Operand::Imm(constant->GetValue())});
|
||||
return;
|
||||
}
|
||||
|
||||
if (auto* constant = dynamic_cast<const ir::ConstantFloat*>(value)) {
|
||||
float fval = constant->GetValue();
|
||||
int bits;
|
||||
std::memcpy(&bits, &fval, sizeof(float));
|
||||
block.Append(Opcode::MovImm, {Operand::Reg(target), Operand::Imm(bits)});
|
||||
return;
|
||||
}
|
||||
|
||||
if (value->IsGlobalValue()) {
|
||||
EmitAddressToReg(value, target, slots, block);
|
||||
return;
|
||||
}
|
||||
|
||||
auto it = slots.find(value);
|
||||
if (it == slots.end()) {
|
||||
throw std::runtime_error(
|
||||
FormatError("mir", "找不到值对应的栈槽: " + value->GetName()));
|
||||
throw std::runtime_error(FormatError("mir", "找不到值对应的栈槽: " + value->GetName()));
|
||||
}
|
||||
|
||||
block.Append(Opcode::LoadStack,
|
||||
{Operand::Reg(target), Operand::FrameIndex(it->second)});
|
||||
block.Append(Opcode::LoadStack, {Operand::Reg(target), Operand::FrameIndex(it->second)});
|
||||
}
|
||||
|
||||
void LowerInstruction(const ir::Instruction& inst, MachineFunction& function,
|
||||
ValueSlotMap& slots) {
|
||||
auto& block = function.GetEntry();
|
||||
|
||||
ValueSlotMap& slots, MachineBasicBlock& block) {
|
||||
switch (inst.GetOpcode()) {
|
||||
case ir::Opcode::Alloca: {
|
||||
slots.emplace(&inst, function.CreateFrameIndex());
|
||||
slots.emplace(&inst, function.CreateFrameIndex(GetAllocaSize(inst)));
|
||||
return;
|
||||
}
|
||||
case ir::Opcode::Store: {
|
||||
auto& store = static_cast<const ir::StoreInst&>(inst);
|
||||
auto dst = slots.find(store.GetPtr());
|
||||
if (dst == slots.end()) {
|
||||
throw std::runtime_error(
|
||||
FormatError("mir", "暂不支持对非栈变量地址进行写入"));
|
||||
|
||||
if (auto* alloca = dynamic_cast<const ir::Instruction*>(store.GetPtr())) {
|
||||
if (alloca->GetOpcode() == ir::Opcode::Alloca) {
|
||||
auto it = slots.find(alloca);
|
||||
if (it != slots.end()) {
|
||||
PhysReg val_reg = store.GetValue()->GetType()->IsFloat() ? PhysReg::S8 : PhysReg::W8;
|
||||
EmitValueToReg(store.GetValue(), val_reg, slots, block);
|
||||
block.Append(Opcode::StoreStack, {Operand::Reg(val_reg), Operand::FrameIndex(it->second)});
|
||||
return;
|
||||
}
|
||||
}
|
||||
}
|
||||
EmitValueToReg(store.GetValue(), PhysReg::W8, slots, block);
|
||||
block.Append(Opcode::StoreStack,
|
||||
{Operand::Reg(PhysReg::W8), Operand::FrameIndex(dst->second)});
|
||||
|
||||
// Dynamic store
|
||||
PhysReg val_reg = store.GetValue()->GetType()->IsFloat() ? PhysReg::S8 : PhysReg::W8;
|
||||
EmitValueToReg(store.GetValue(), val_reg, slots, block);
|
||||
EmitAddressToReg(store.GetPtr(), PhysReg::X9, slots, block);
|
||||
block.Append(Opcode::StrRegReg, {Operand::Reg(val_reg), Operand::Reg(PhysReg::X9)});
|
||||
return;
|
||||
}
|
||||
case ir::Opcode::Load: {
|
||||
auto& load = static_cast<const ir::LoadInst&>(inst);
|
||||
auto src = slots.find(load.GetPtr());
|
||||
if (src == slots.end()) {
|
||||
throw std::runtime_error(
|
||||
FormatError("mir", "暂不支持对非栈变量地址进行读取"));
|
||||
}
|
||||
int dst_slot = function.CreateFrameIndex();
|
||||
block.Append(Opcode::LoadStack,
|
||||
{Operand::Reg(PhysReg::W8), Operand::FrameIndex(src->second)});
|
||||
block.Append(Opcode::StoreStack,
|
||||
{Operand::Reg(PhysReg::W8), Operand::FrameIndex(dst_slot)});
|
||||
int dst_slot = function.CreateFrameIndex(GetTypeSize(load.GetType().get()));
|
||||
slots.emplace(&inst, dst_slot);
|
||||
|
||||
if (auto* alloca = dynamic_cast<const ir::Instruction*>(load.GetPtr())) {
|
||||
if (alloca->GetOpcode() == ir::Opcode::Alloca) {
|
||||
auto it = slots.find(alloca);
|
||||
if (it != slots.end()) {
|
||||
PhysReg val_reg = load.GetType()->IsFloat() ? PhysReg::S8 : PhysReg::W8;
|
||||
block.Append(Opcode::LoadStack, {Operand::Reg(val_reg), Operand::FrameIndex(it->second)});
|
||||
block.Append(Opcode::StoreStack, {Operand::Reg(val_reg), Operand::FrameIndex(dst_slot)});
|
||||
return;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Dynamic load
|
||||
PhysReg val_reg = load.GetType()->IsFloat() ? PhysReg::S8 : PhysReg::W8;
|
||||
EmitAddressToReg(load.GetPtr(), PhysReg::X9, slots, block);
|
||||
block.Append(Opcode::LdrRegReg, {Operand::Reg(val_reg), Operand::Reg(PhysReg::X9)});
|
||||
block.Append(Opcode::StoreStack, {Operand::Reg(val_reg), Operand::FrameIndex(dst_slot)});
|
||||
return;
|
||||
}
|
||||
case ir::Opcode::Add: {
|
||||
case ir::Opcode::Add:
|
||||
case ir::Opcode::Sub:
|
||||
case ir::Opcode::Mul:
|
||||
case ir::Opcode::Div:
|
||||
case ir::Opcode::Mod: {
|
||||
auto& bin = static_cast<const ir::BinaryInst&>(inst);
|
||||
int dst_slot = function.CreateFrameIndex();
|
||||
int dst_slot = function.CreateFrameIndex(4);
|
||||
slots.emplace(&inst, dst_slot);
|
||||
|
||||
EmitValueToReg(bin.GetLhs(), PhysReg::W8, slots, block);
|
||||
EmitValueToReg(bin.GetRhs(), PhysReg::W9, slots, block);
|
||||
block.Append(Opcode::AddRR, {Operand::Reg(PhysReg::W8),
|
||||
Operand::Reg(PhysReg::W8),
|
||||
Operand::Reg(PhysReg::W9)});
|
||||
block.Append(Opcode::StoreStack,
|
||||
{Operand::Reg(PhysReg::W8), Operand::FrameIndex(dst_slot)});
|
||||
|
||||
if (inst.GetOpcode() == ir::Opcode::Add) {
|
||||
block.Append(Opcode::AddRR, {Operand::Reg(PhysReg::W8), Operand::Reg(PhysReg::W8), Operand::Reg(PhysReg::W9)});
|
||||
} else if (inst.GetOpcode() == ir::Opcode::Sub) {
|
||||
block.Append(Opcode::SubRR, {Operand::Reg(PhysReg::W8), Operand::Reg(PhysReg::W8), Operand::Reg(PhysReg::W9)});
|
||||
} else if (inst.GetOpcode() == ir::Opcode::Mul) {
|
||||
block.Append(Opcode::MulRR, {Operand::Reg(PhysReg::W8), Operand::Reg(PhysReg::W8), Operand::Reg(PhysReg::W9)});
|
||||
} else if (inst.GetOpcode() == ir::Opcode::Div) {
|
||||
block.Append(Opcode::SDivRR, {Operand::Reg(PhysReg::W8), Operand::Reg(PhysReg::W8), Operand::Reg(PhysReg::W9)});
|
||||
} else if (inst.GetOpcode() == ir::Opcode::Mod) {
|
||||
block.Append(Opcode::SDivRR, {Operand::Reg(PhysReg::W10), Operand::Reg(PhysReg::W8), Operand::Reg(PhysReg::W9)});
|
||||
block.Append(Opcode::MSubRRRR, {Operand::Reg(PhysReg::W8), Operand::Reg(PhysReg::W10), Operand::Reg(PhysReg::W9), Operand::Reg(PhysReg::W8)});
|
||||
}
|
||||
|
||||
block.Append(Opcode::StoreStack, {Operand::Reg(PhysReg::W8), Operand::FrameIndex(dst_slot)});
|
||||
return;
|
||||
}
|
||||
case ir::Opcode::FAdd:
|
||||
case ir::Opcode::FSub:
|
||||
case ir::Opcode::FMul:
|
||||
case ir::Opcode::FDiv: {
|
||||
auto& bin = static_cast<const ir::BinaryInst&>(inst);
|
||||
int dst_slot = function.CreateFrameIndex(4);
|
||||
slots.emplace(&inst, dst_slot);
|
||||
|
||||
EmitValueToReg(bin.GetLhs(), PhysReg::S8, slots, block);
|
||||
EmitValueToReg(bin.GetRhs(), PhysReg::S9, slots, block);
|
||||
|
||||
if (inst.GetOpcode() == ir::Opcode::FAdd) {
|
||||
block.Append(Opcode::FAddRRR, {Operand::Reg(PhysReg::S8), Operand::Reg(PhysReg::S8), Operand::Reg(PhysReg::S9)});
|
||||
} else if (inst.GetOpcode() == ir::Opcode::FSub) {
|
||||
block.Append(Opcode::FSubRRR, {Operand::Reg(PhysReg::S8), Operand::Reg(PhysReg::S8), Operand::Reg(PhysReg::S9)});
|
||||
} else if (inst.GetOpcode() == ir::Opcode::FMul) {
|
||||
block.Append(Opcode::FMulRRR, {Operand::Reg(PhysReg::S8), Operand::Reg(PhysReg::S8), Operand::Reg(PhysReg::S9)});
|
||||
} else if (inst.GetOpcode() == ir::Opcode::FDiv) {
|
||||
block.Append(Opcode::FDivRRR, {Operand::Reg(PhysReg::S8), Operand::Reg(PhysReg::S8), Operand::Reg(PhysReg::S9)});
|
||||
}
|
||||
|
||||
block.Append(Opcode::StoreStack, {Operand::Reg(PhysReg::S8), Operand::FrameIndex(dst_slot)});
|
||||
return;
|
||||
}
|
||||
case ir::Opcode::ICmpEQ:
|
||||
case ir::Opcode::ICmpNE:
|
||||
case ir::Opcode::ICmpLT:
|
||||
case ir::Opcode::ICmpGT:
|
||||
case ir::Opcode::ICmpLE:
|
||||
case ir::Opcode::ICmpGE: {
|
||||
auto& cmp = static_cast<const ir::BinaryInst&>(inst);
|
||||
int dst_slot = function.CreateFrameIndex(4);
|
||||
slots.emplace(&inst, dst_slot);
|
||||
|
||||
EmitValueToReg(cmp.GetLhs(), PhysReg::W8, slots, block);
|
||||
EmitValueToReg(cmp.GetRhs(), PhysReg::W9, slots, block);
|
||||
block.Append(Opcode::CmpRR, {Operand::Reg(PhysReg::W8), Operand::Reg(PhysReg::W9)});
|
||||
|
||||
std::string cond;
|
||||
switch (inst.GetOpcode()) {
|
||||
case ir::Opcode::ICmpEQ: cond = "eq"; break;
|
||||
case ir::Opcode::ICmpNE: cond = "ne"; break;
|
||||
case ir::Opcode::ICmpLT: cond = "lt"; break;
|
||||
case ir::Opcode::ICmpGT: cond = "gt"; break;
|
||||
case ir::Opcode::ICmpLE: cond = "le"; break;
|
||||
case ir::Opcode::ICmpGE: cond = "ge"; break;
|
||||
default: break;
|
||||
}
|
||||
|
||||
block.Append(Opcode::Cset, {Operand::Reg(PhysReg::W8), Operand::Cond(cond)});
|
||||
block.Append(Opcode::StoreStack, {Operand::Reg(PhysReg::W8), Operand::FrameIndex(dst_slot)});
|
||||
return;
|
||||
}
|
||||
case ir::Opcode::FCmpEQ:
|
||||
case ir::Opcode::FCmpNE:
|
||||
case ir::Opcode::FCmpLT:
|
||||
case ir::Opcode::FCmpGT:
|
||||
case ir::Opcode::FCmpLE:
|
||||
case ir::Opcode::FCmpGE: {
|
||||
auto& cmp = static_cast<const ir::BinaryInst&>(inst);
|
||||
int dst_slot = function.CreateFrameIndex(4);
|
||||
slots.emplace(&inst, dst_slot);
|
||||
|
||||
EmitValueToReg(cmp.GetLhs(), PhysReg::S8, slots, block);
|
||||
EmitValueToReg(cmp.GetRhs(), PhysReg::S9, slots, block);
|
||||
block.Append(Opcode::FCmpRR, {Operand::Reg(PhysReg::S8), Operand::Reg(PhysReg::S9)});
|
||||
|
||||
std::string cond;
|
||||
switch (inst.GetOpcode()) {
|
||||
case ir::Opcode::FCmpEQ: cond = "eq"; break;
|
||||
case ir::Opcode::FCmpNE: cond = "ne"; break;
|
||||
case ir::Opcode::FCmpLT: cond = "mi"; break;
|
||||
case ir::Opcode::FCmpGT: cond = "gt"; break;
|
||||
case ir::Opcode::FCmpLE: cond = "ls"; break;
|
||||
case ir::Opcode::FCmpGE: cond = "ge"; break;
|
||||
default: break;
|
||||
}
|
||||
|
||||
block.Append(Opcode::Cset, {Operand::Reg(PhysReg::W8), Operand::Cond(cond)});
|
||||
block.Append(Opcode::StoreStack, {Operand::Reg(PhysReg::W8), Operand::FrameIndex(dst_slot)});
|
||||
return;
|
||||
}
|
||||
case ir::Opcode::ZExt: {
|
||||
auto& cast = static_cast<const ir::CastInst&>(inst);
|
||||
int dst_slot = function.CreateFrameIndex(4);
|
||||
slots.emplace(&inst, dst_slot);
|
||||
|
||||
EmitValueToReg(cast.GetValue(), PhysReg::W8, slots, block);
|
||||
block.Append(Opcode::ZExt, {Operand::Reg(PhysReg::W8), Operand::Reg(PhysReg::W8)});
|
||||
block.Append(Opcode::StoreStack, {Operand::Reg(PhysReg::W8), Operand::FrameIndex(dst_slot)});
|
||||
return;
|
||||
}
|
||||
case ir::Opcode::SIToFP: {
|
||||
auto& cast = static_cast<const ir::CastInst&>(inst);
|
||||
int dst_slot = function.CreateFrameIndex(4);
|
||||
slots.emplace(&inst, dst_slot);
|
||||
|
||||
EmitValueToReg(cast.GetValue(), PhysReg::W8, slots, block);
|
||||
block.Append(Opcode::SIToFP, {Operand::Reg(PhysReg::S8), Operand::Reg(PhysReg::W8)});
|
||||
block.Append(Opcode::StoreStack, {Operand::Reg(PhysReg::S8), Operand::FrameIndex(dst_slot)});
|
||||
return;
|
||||
}
|
||||
case ir::Opcode::FPToSI: {
|
||||
auto& cast = static_cast<const ir::CastInst&>(inst);
|
||||
int dst_slot = function.CreateFrameIndex(4);
|
||||
slots.emplace(&inst, dst_slot);
|
||||
|
||||
EmitValueToReg(cast.GetValue(), PhysReg::S8, slots, block);
|
||||
block.Append(Opcode::FPToSI, {Operand::Reg(PhysReg::W8), Operand::Reg(PhysReg::S8)});
|
||||
block.Append(Opcode::StoreStack, {Operand::Reg(PhysReg::W8), Operand::FrameIndex(dst_slot)});
|
||||
return;
|
||||
}
|
||||
case ir::Opcode::Br: {
|
||||
auto& br = static_cast<const ir::BranchInst&>(inst);
|
||||
std::cerr << "DEBUG: Br is_conditional=" << br.IsConditional() << std::endl;
|
||||
if (br.IsConditional()) {
|
||||
std::cerr << "DEBUG: Cond pointer=" << br.GetCondition() << std::endl;
|
||||
std::cerr << "DEBUG: True pointer=" << br.GetIfTrue() << " name=" << (br.GetIfTrue() ? br.GetIfTrue()->GetName() : "<null>") << std::endl;
|
||||
std::cerr << "DEBUG: False pointer=" << br.GetIfFalse() << " name=" << (br.GetIfFalse() ? br.GetIfFalse()->GetName() : "<null>") << std::endl;
|
||||
EmitValueToReg(br.GetCondition(), PhysReg::W8, slots, block);
|
||||
block.Append(Opcode::MovImm, {Operand::Reg(PhysReg::W9), Operand::Imm(0)});
|
||||
block.Append(Opcode::CmpRR, {Operand::Reg(PhysReg::W8), Operand::Reg(PhysReg::W9)});
|
||||
block.Append(Opcode::BCond, {Operand::Cond("ne"), Operand::Label(br.GetIfTrue()->GetName())});
|
||||
block.Append(Opcode::B, {Operand::Label(br.GetIfFalse()->GetName())});
|
||||
} else {
|
||||
std::cerr << "DEBUG: Dest pointer=" << br.GetDest() << " name=" << (br.GetDest() ? br.GetDest()->GetName() : "<null>") << std::endl;
|
||||
block.Append(Opcode::B, {Operand::Label(br.GetDest()->GetName())});
|
||||
}
|
||||
return;
|
||||
}
|
||||
case ir::Opcode::Ret: {
|
||||
auto& ret = static_cast<const ir::ReturnInst&>(inst);
|
||||
EmitValueToReg(ret.GetValue(), PhysReg::W0, slots, block);
|
||||
if (ret.GetValue()) {
|
||||
PhysReg ret_reg = ret.GetValue()->GetType()->IsFloat() ? PhysReg::S0 : PhysReg::W0;
|
||||
EmitValueToReg(ret.GetValue(), ret_reg, slots, block);
|
||||
}
|
||||
block.Append(Opcode::Ret);
|
||||
return;
|
||||
}
|
||||
case ir::Opcode::Sub:
|
||||
case ir::Opcode::Mul:
|
||||
throw std::runtime_error(FormatError("mir", "暂不支持该二元运算"));
|
||||
case ir::Opcode::Call: {
|
||||
auto& call = static_cast<const ir::CallInst&>(inst);
|
||||
int dst_slot = -1;
|
||||
if (!call.GetType()->IsVoid()) {
|
||||
dst_slot = function.CreateFrameIndex(GetTypeSize(call.GetType().get()));
|
||||
slots.emplace(&inst, dst_slot);
|
||||
}
|
||||
|
||||
int int_idx = 0;
|
||||
int float_idx = 0;
|
||||
for (size_t i = 1; i < call.GetNumOperands(); ++i) {
|
||||
auto* arg = call.GetOperand(i);
|
||||
if (arg->GetType()->IsFloat()) {
|
||||
PhysReg reg = static_cast<PhysReg>(static_cast<int>(PhysReg::S0) + float_idx);
|
||||
EmitValueToReg(arg, reg, slots, block);
|
||||
float_idx++;
|
||||
} else {
|
||||
PhysReg reg = (arg->GetType()->IsPtrInt32() || arg->GetType()->IsPtrFloat())
|
||||
? static_cast<PhysReg>(static_cast<int>(PhysReg::X0) + int_idx)
|
||||
: static_cast<PhysReg>(static_cast<int>(PhysReg::W0) + int_idx);
|
||||
EmitValueToReg(arg, reg, slots, block);
|
||||
int_idx++;
|
||||
}
|
||||
}
|
||||
|
||||
block.Append(Opcode::Call, {Operand::Global(call.GetFunction()->GetName())});
|
||||
|
||||
if (dst_slot != -1) {
|
||||
if (call.GetType()->IsFloat()) {
|
||||
block.Append(Opcode::StoreStack, {Operand::Reg(PhysReg::S0), Operand::FrameIndex(dst_slot)});
|
||||
} else {
|
||||
PhysReg ret_reg = (call.GetType()->IsPtrInt32() || call.GetType()->IsPtrFloat()) ? PhysReg::X0 : PhysReg::W0;
|
||||
block.Append(Opcode::StoreStack, {Operand::Reg(ret_reg), Operand::FrameIndex(dst_slot)});
|
||||
}
|
||||
}
|
||||
return;
|
||||
}
|
||||
case ir::Opcode::GEP: {
|
||||
auto& gep = static_cast<const ir::GetElementPtrInst&>(inst);
|
||||
int dst_slot = function.CreateFrameIndex(8);
|
||||
slots.emplace(&inst, dst_slot);
|
||||
|
||||
// Load base pointer address into X8
|
||||
if (dynamic_cast<const ir::AllocaInst*>(gep.GetPtr()) || gep.GetPtr()->IsGlobalValue()) {
|
||||
EmitAddressToReg(gep.GetPtr(), PhysReg::X8, slots, block);
|
||||
} else {
|
||||
EmitValueToReg(gep.GetPtr(), PhysReg::X8, slots, block);
|
||||
}
|
||||
|
||||
auto strides = GetGepStrides(gep);
|
||||
for (size_t i = 1; i < gep.GetNumOperands(); ++i) {
|
||||
auto* idx = gep.GetOperand(i);
|
||||
uint32_t stride = strides.at(i - 1);
|
||||
|
||||
// Skip if offset index is constant 0
|
||||
if (auto* ci = dynamic_cast<const ir::ConstantInt*>(idx)) {
|
||||
if (ci->GetValue() == 0) {
|
||||
continue;
|
||||
}
|
||||
}
|
||||
|
||||
EmitValueToReg(idx, PhysReg::W9, slots, block);
|
||||
if (stride > 1) {
|
||||
block.Append(Opcode::MovImm, {Operand::Reg(PhysReg::W10), Operand::Imm(stride)});
|
||||
block.Append(Opcode::MulRR, {Operand::Reg(PhysReg::W9), Operand::Reg(PhysReg::W9), Operand::Reg(PhysReg::W10)});
|
||||
}
|
||||
|
||||
// Extend W9 to X9 and add to base address X8
|
||||
block.Append(Opcode::ZExt, {Operand::Reg(PhysReg::X9), Operand::Reg(PhysReg::W9)});
|
||||
block.Append(Opcode::AddRR, {Operand::Reg(PhysReg::X8), Operand::Reg(PhysReg::X8), Operand::Reg(PhysReg::X9)});
|
||||
}
|
||||
|
||||
// Store address into GEP's stack slot
|
||||
block.Append(Opcode::StoreStack, {Operand::Reg(PhysReg::X8), Operand::FrameIndex(dst_slot)});
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
throw std::runtime_error(FormatError("mir", "暂不支持该 IR 指令"));
|
||||
throw std::runtime_error(FormatError("mir", "暂不支持该 IR 指令: " + std::to_string(static_cast<int>(inst.GetOpcode()))));
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
std::unique_ptr<MachineFunction> LowerToMIR(const ir::Module& module) {
|
||||
std::vector<std::unique_ptr<MachineFunction>> LowerToMIR(const ir::Module& module) {
|
||||
DefaultContext();
|
||||
std::vector<std::unique_ptr<MachineFunction>> mfuncs;
|
||||
|
||||
if (module.GetFunctions().size() != 1) {
|
||||
throw std::runtime_error(FormatError("mir", "暂不支持多个函数"));
|
||||
for (const auto& funcPtr : module.GetFunctions()) {
|
||||
const auto& func = *funcPtr;
|
||||
if (func.GetBlocks().empty()) continue; // skip declarations
|
||||
|
||||
auto machine_func = std::make_unique<MachineFunction>(func.GetName());
|
||||
ValueSlotMap slots;
|
||||
|
||||
// First, create all basic blocks in MachineFunction
|
||||
std::unordered_map<const ir::BasicBlock*, MachineBasicBlock*> bb_map;
|
||||
machine_func->GetBlocks().reserve(func.GetBlocks().size());
|
||||
for (const auto& bbPtr : func.GetBlocks()) {
|
||||
auto& mbb = machine_func->CreateBlock(bbPtr->GetName());
|
||||
bb_map[bbPtr.get()] = &mbb;
|
||||
}
|
||||
|
||||
auto& entry_block = *bb_map.at(func.GetEntry());
|
||||
|
||||
// Lower function arguments at the start of the entry block
|
||||
const auto& args = func.GetArguments();
|
||||
int int_idx = 0;
|
||||
int float_idx = 0;
|
||||
for (const auto& arg : args) {
|
||||
int slot = machine_func->CreateFrameIndex(GetTypeSize(arg->GetType().get()));
|
||||
slots.emplace(arg.get(), slot);
|
||||
|
||||
if (arg->GetType()->IsFloat()) {
|
||||
PhysReg reg = static_cast<PhysReg>(static_cast<int>(PhysReg::S0) + float_idx);
|
||||
entry_block.Append(Opcode::StoreStack, {Operand::Reg(reg), Operand::FrameIndex(slot)});
|
||||
float_idx++;
|
||||
} else {
|
||||
PhysReg reg = (arg->GetType()->IsPtrInt32() || arg->GetType()->IsPtrFloat())
|
||||
? static_cast<PhysReg>(static_cast<int>(PhysReg::X0) + int_idx)
|
||||
: static_cast<PhysReg>(static_cast<int>(PhysReg::W0) + int_idx);
|
||||
entry_block.Append(Opcode::StoreStack, {Operand::Reg(reg), Operand::FrameIndex(slot)});
|
||||
int_idx++;
|
||||
}
|
||||
}
|
||||
|
||||
// Now, lower all instructions block by block
|
||||
for (const auto& bbPtr : func.GetBlocks()) {
|
||||
auto& mbb = *bb_map.at(bbPtr.get());
|
||||
for (const auto& inst : bbPtr->GetInstructions()) {
|
||||
LowerInstruction(*inst, *machine_func, slots, mbb);
|
||||
}
|
||||
}
|
||||
|
||||
mfuncs.push_back(std::move(machine_func));
|
||||
}
|
||||
|
||||
const auto& func = *module.GetFunctions().front();
|
||||
if (func.GetName() != "main") {
|
||||
throw std::runtime_error(FormatError("mir", "暂不支持非 main 函数"));
|
||||
}
|
||||
|
||||
auto machine_func = std::make_unique<MachineFunction>(func.GetName());
|
||||
ValueSlotMap slots;
|
||||
const auto* entry = func.GetEntry();
|
||||
if (!entry) {
|
||||
throw std::runtime_error(FormatError("mir", "IR 函数缺少入口基本块"));
|
||||
}
|
||||
|
||||
for (const auto& inst : entry->GetInstructions()) {
|
||||
LowerInstruction(*inst, *machine_func, slots);
|
||||
}
|
||||
|
||||
return machine_func;
|
||||
return mfuncs;
|
||||
}
|
||||
|
||||
} // namespace mir
|
||||
|
||||
@@ -8,7 +8,12 @@
|
||||
namespace mir {
|
||||
|
||||
MachineFunction::MachineFunction(std::string name)
|
||||
: name_(std::move(name)), entry_("entry") {}
|
||||
: name_(std::move(name)) {}
|
||||
|
||||
MachineBasicBlock& MachineFunction::CreateBlock(std::string name) {
|
||||
blocks_.emplace_back(std::move(name));
|
||||
return blocks_.back();
|
||||
}
|
||||
|
||||
int MachineFunction::CreateFrameIndex(int size) {
|
||||
int index = static_cast<int>(frame_slots_.size());
|
||||
|
||||
@@ -4,10 +4,12 @@
|
||||
|
||||
namespace mir {
|
||||
|
||||
Operand::Operand(Kind kind, PhysReg reg, int imm)
|
||||
: kind_(kind), reg_(reg), imm_(imm) {}
|
||||
Operand::Operand(Kind kind, PhysReg reg, int imm, std::string str)
|
||||
: kind_(kind), reg_(reg), imm_(imm), str_(std::move(str)) {}
|
||||
|
||||
Operand Operand::Reg(PhysReg reg) { return Operand(Kind::Reg, reg, 0); }
|
||||
Operand Operand::Reg(PhysReg reg) {
|
||||
return Operand(Kind::Reg, reg, 0);
|
||||
}
|
||||
|
||||
Operand Operand::Imm(int value) {
|
||||
return Operand(Kind::Imm, PhysReg::W0, value);
|
||||
@@ -17,6 +19,18 @@ Operand Operand::FrameIndex(int index) {
|
||||
return Operand(Kind::FrameIndex, PhysReg::W0, index);
|
||||
}
|
||||
|
||||
Operand Operand::Global(std::string name) {
|
||||
return Operand(Kind::Global, PhysReg::W0, 0, std::move(name));
|
||||
}
|
||||
|
||||
Operand Operand::Label(std::string name) {
|
||||
return Operand(Kind::Label, PhysReg::W0, 0, std::move(name));
|
||||
}
|
||||
|
||||
Operand Operand::Cond(std::string cond) {
|
||||
return Operand(Kind::Cond, PhysReg::W0, 0, std::move(cond));
|
||||
}
|
||||
|
||||
MachineInstr::MachineInstr(Opcode opcode, std::vector<Operand> operands)
|
||||
: opcode_(opcode), operands_(std::move(operands)) {}
|
||||
|
||||
|
||||
@@ -8,26 +8,19 @@ namespace mir {
|
||||
namespace {
|
||||
|
||||
bool IsAllowedReg(PhysReg reg) {
|
||||
switch (reg) {
|
||||
case PhysReg::W0:
|
||||
case PhysReg::W8:
|
||||
case PhysReg::W9:
|
||||
case PhysReg::X29:
|
||||
case PhysReg::X30:
|
||||
case PhysReg::SP:
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
return true; // We allow all defined physical registers
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
void RunRegAlloc(MachineFunction& function) {
|
||||
for (const auto& inst : function.GetEntry().GetInstructions()) {
|
||||
for (const auto& operand : inst.GetOperands()) {
|
||||
if (operand.GetKind() == Operand::Kind::Reg &&
|
||||
!IsAllowedReg(operand.GetReg())) {
|
||||
throw std::runtime_error(FormatError("mir", "寄存器分配失败"));
|
||||
for (const auto& block : function.GetBlocks()) {
|
||||
for (const auto& inst : block.GetInstructions()) {
|
||||
for (const auto& operand : inst.GetOperands()) {
|
||||
if (operand.GetKind() == Operand::Kind::Reg &&
|
||||
!IsAllowedReg(operand.GetReg())) {
|
||||
throw std::runtime_error(FormatError("mir", "寄存器分配失败"));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
#include "mir/MIR.h"
|
||||
|
||||
#include <stdexcept>
|
||||
#include <string>
|
||||
|
||||
#include "utils/Log.h"
|
||||
|
||||
@@ -8,18 +9,77 @@ namespace mir {
|
||||
|
||||
const char* PhysRegName(PhysReg reg) {
|
||||
switch (reg) {
|
||||
case PhysReg::W0:
|
||||
return "w0";
|
||||
case PhysReg::W8:
|
||||
return "w8";
|
||||
case PhysReg::W9:
|
||||
return "w9";
|
||||
case PhysReg::X29:
|
||||
return "x29";
|
||||
case PhysReg::X30:
|
||||
return "x30";
|
||||
case PhysReg::SP:
|
||||
return "sp";
|
||||
case PhysReg::W0: return "w0";
|
||||
case PhysReg::W1: return "w1";
|
||||
case PhysReg::W2: return "w2";
|
||||
case PhysReg::W3: return "w3";
|
||||
case PhysReg::W4: return "w4";
|
||||
case PhysReg::W5: return "w5";
|
||||
case PhysReg::W6: return "w6";
|
||||
case PhysReg::W7: return "w7";
|
||||
case PhysReg::W8: return "w8";
|
||||
case PhysReg::W9: return "w9";
|
||||
case PhysReg::W10: return "w10";
|
||||
case PhysReg::W11: return "w11";
|
||||
case PhysReg::W12: return "w12";
|
||||
case PhysReg::W13: return "w13";
|
||||
case PhysReg::W14: return "w14";
|
||||
case PhysReg::W15: return "w15";
|
||||
case PhysReg::W19: return "w19";
|
||||
case PhysReg::W20: return "w20";
|
||||
case PhysReg::W21: return "w21";
|
||||
case PhysReg::W22: return "w22";
|
||||
case PhysReg::W23: return "w23";
|
||||
case PhysReg::W24: return "w24";
|
||||
case PhysReg::W25: return "w25";
|
||||
case PhysReg::W26: return "w26";
|
||||
case PhysReg::W27: return "w27";
|
||||
case PhysReg::W28: return "w28";
|
||||
case PhysReg::X0: return "x0";
|
||||
case PhysReg::X1: return "x1";
|
||||
case PhysReg::X2: return "x2";
|
||||
case PhysReg::X3: return "x3";
|
||||
case PhysReg::X4: return "x4";
|
||||
case PhysReg::X5: return "x5";
|
||||
case PhysReg::X6: return "x6";
|
||||
case PhysReg::X7: return "x7";
|
||||
case PhysReg::X8: return "x8";
|
||||
case PhysReg::X9: return "x9";
|
||||
case PhysReg::X10: return "x10";
|
||||
case PhysReg::X11: return "x11";
|
||||
case PhysReg::X12: return "x12";
|
||||
case PhysReg::X13: return "x13";
|
||||
case PhysReg::X14: return "x14";
|
||||
case PhysReg::X15: return "x15";
|
||||
case PhysReg::X19: return "x19";
|
||||
case PhysReg::X20: return "x20";
|
||||
case PhysReg::X21: return "x21";
|
||||
case PhysReg::X22: return "x22";
|
||||
case PhysReg::X23: return "x23";
|
||||
case PhysReg::X24: return "x24";
|
||||
case PhysReg::X25: return "x25";
|
||||
case PhysReg::X26: return "x26";
|
||||
case PhysReg::X27: return "x27";
|
||||
case PhysReg::X28: return "x28";
|
||||
case PhysReg::S0: return "s0";
|
||||
case PhysReg::S1: return "s1";
|
||||
case PhysReg::S2: return "s2";
|
||||
case PhysReg::S3: return "s3";
|
||||
case PhysReg::S4: return "s4";
|
||||
case PhysReg::S5: return "s5";
|
||||
case PhysReg::S6: return "s6";
|
||||
case PhysReg::S7: return "s7";
|
||||
case PhysReg::S8: return "s8";
|
||||
case PhysReg::S9: return "s9";
|
||||
case PhysReg::S10: return "s10";
|
||||
case PhysReg::S11: return "s11";
|
||||
case PhysReg::S12: return "s12";
|
||||
case PhysReg::S13: return "s13";
|
||||
case PhysReg::S14: return "s14";
|
||||
case PhysReg::S15: return "s15";
|
||||
case PhysReg::X29: return "x29";
|
||||
case PhysReg::X30: return "x30";
|
||||
case PhysReg::SP: return "sp";
|
||||
}
|
||||
throw std::runtime_error(FormatError("mir", "未知物理寄存器"));
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user