Files
mysysy/src/SysYIRGenerator.cpp
2025-03-10 16:50:18 +08:00

239 lines
9.0 KiB
C++
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

#include "SysYIRGenerator.h"

#include <cstdint>
#include <cstdlib>
#include <cstring>
#include <iomanip>
#include <sstream>
#include <string>
#include <vector>
/// Lowers a whole SysY compilation unit and returns the accumulated IR text.
/// @param unit Root of the parse tree.
/// @return Everything written to irStream so far.
std::string SysYIRGenerator::generateIR(SysYParser::CompUnitContext* unit) {
    // Every visitor appends its output to irStream as a side effect.
    visitCompUnit(unit);
    // NOTE(review): nothing here resets irStream, so a second call would also
    // return the output of the first one.
    std::string module = irStream.str();
    return module;
}
/// Reserves the next unnamed SSA value name ("%<n>") and advances the counter.
std::string SysYIRGenerator::getNextTemp() {
    std::string name = "%";
    name += std::to_string(tempCounter);
    ++tempCounter;
    return name;
}
/// Maps a SysY type spelling to its LLVM IR type name.
///   "int"   -> "i32"
///   "float" -> "float"
///   "void"  -> "void"
///   spelling containing "[]" -> pointer to the type named by dropping the
///                               last two characters
///   anything else (callers pass "" to mean "the default scalar") -> "i32"
std::string SysYIRGenerator::getLLVMType(const std::string& type) {
    if (type == "int") return "i32";
    if (type == "float") return "float";
    // Without this case "void" fell through to "i32", so void functions were
    // emitted as `define i32`, and `return;` produced `ret void` inside them.
    if (type == "void") return "void";
    if (type.find("[]") != std::string::npos)
        return getLLVMType(type.substr(0, type.size() - 2)) + "*";
    return "i32";
}
/// Lowers a compilation unit: all top-level declarations first, then all
/// function definitions. The original source interleaving of the two is not
/// preserved.
/// @return nullptr (wrapped in std::any).
std::any SysYIRGenerator::visitCompUnit(SysYParser::CompUnitContext* ctx) {
    const auto declarations = ctx->decl();
    for (std::size_t k = 0; k < declarations.size(); ++k) {
        declarations[k]->accept(this);
    }

    const auto functions = ctx->funcDef();
    for (std::size_t k = 0; k < functions.size(); ++k) {
        functions[k]->accept(this);
    }
    return nullptr;
}
/// Constant declarations are not lowered yet (in LLVM IR, constants are
/// usually inlined at their use sites).
/// NOTE(review): names declared here are never entered into symbolTable, so a
/// later reference that goes through visitLValue would get a fresh,
/// uninitialised stack slot. Confirm whether consts are handled elsewhere.
/// @return nullptr (wrapped in std::any).
std::any SysYIRGenerator::visitConstDecl(SysYParser::ConstDeclContext* ctx) {
    static_cast<void>(ctx);  // intentionally unused until const lowering exists
    return nullptr;
}
/// Lowers a variable declaration by visiting each of its definitions in order.
/// NOTE(review): no visitVarDef appears in this file, so this presumably falls
/// back to the generated visitChildren. Stack slots appear to be created lazily
/// by visitLValue, and initialisers look unhandled. Confirm.
/// @return nullptr (wrapped in std::any).
std::any SysYIRGenerator::visitVarDecl(SysYParser::VarDeclContext* ctx) {
    const auto definitions = ctx->varDef();
    for (auto it = definitions.begin(); it != definitions.end(); ++it) {
        (*it)->accept(this);
    }
    return nullptr;
}
namespace {

// Returns true when the last non-blank line of `ir` starts with a block
// terminator (`ret ...` or `br ...`), i.e. the current basic block is closed.
bool irEndsWithTerminator(const std::string& ir) {
    const std::size_t last = ir.find_last_not_of(" \t\n");
    if (last == std::string::npos) {
        return false;
    }
    const std::size_t newline = ir.rfind('\n', last);
    const std::size_t lineStart = (newline == std::string::npos) ? 0 : newline + 1;
    // Never npos: position `last` on this line is a non-blank character.
    const std::size_t first = ir.find_first_not_of(" \t", lineStart);
    return ir.compare(first, 4, "ret ") == 0 || ir.compare(first, 3, "br ") == 0;
}

}  // namespace

/// Emits one LLVM `define` for a SysY function definition.
///
/// Layout produced:
///   define <ret> @<name>(<ty> %0, <ty> %1, ...) {
///   entry:
///     <alloca + store for each parameter>
///     <body>
///     <fallback ret, only if the body did not already end with a terminator>
///   }
///
/// Each parameter is spilled to a stack slot, and the slot name is recorded in
/// symbolTable under the parameter's identifier.
/// @return nullptr (wrapped in std::any).
std::any SysYIRGenerator::visitFuncDef(SysYParser::FuncDefContext* ctx) {
    currentFunction = ctx->Ident()->getText();
    symbolTable.clear();
    // LLVM numbers unnamed values per function starting at %0, and the
    // parameters take the first numbers, so restart the counter here.
    tempCounter = 0;

    // Recognise "void" directly so the checks below do not depend on
    // getLLVMType knowing about it.
    const std::string funcType = ctx->funcType()->getText();
    const std::string returnType =
        (funcType == "void") ? std::string("void") : getLLVMType(funcType);

    irStream << "define " << returnType << " @" << currentFunction << "(";

    // The spill instructions must live inside the entry block, not inside the
    // parameter list, so collect them here and emit them after the label.
    std::string prologue;
    if (auto paramsCtx = ctx->funcFParams()) {
        const auto params = paramsCtx->funcFParam();
        // Reserve %0 .. %(n-1) for the incoming arguments first, so the
        // stack slots allocated below get the numbers after them.
        std::vector<std::string> argValues;
        argValues.reserve(params.size());
        for (std::size_t i = 0; i < params.size(); ++i) {
            argValues.push_back(getNextTemp());
        }
        for (std::size_t i = 0; i < params.size(); ++i) {
            auto param = params[i];
            // NOTE(review): only bType is consulted, so array parameters
            // presumably get a scalar type here. Confirm against the grammar.
            const std::string paramType = getLLVMType(param->bType()->getText());
            if (i > 0) irStream << ", ";
            irStream << paramType << " " << argValues[i];

            const std::string slot = getNextTemp();
            symbolTable[param->Ident()->getText()] = slot;
            prologue += " " + slot + " = alloca " + paramType + "\n";
            prologue += " store " + paramType + " " + argValues[i] + ", " +
                        paramType + "* " + slot + "\n";
        }
    }
    irStream << ") {\nentry:\n" << prologue;

    // Function body.
    ctx->blockStmt()->accept(this);

    // A basic block must end in exactly one terminator. Only add the fallback
    // return when the body did not already end with one (e.g. `return x;`).
    if (!irEndsWithTerminator(irStream.str())) {
        if (returnType == "void") {
            irStream << " ret void\n";
        } else {
            // LLVM rejects integer literals for floating-point types.
            irStream << " ret " << returnType << (returnType == "float" ? " 0.0\n" : " 0\n");
        }
    }
    irStream << "}\n\n";
    return nullptr;
}
/// Lowers each declaration or statement of a block, in source order.
/// NOTE(review): symbolTable is cleared only per function and not scoped per
/// block, so inner-block shadowing looks unsupported. Confirm.
/// @return nullptr (wrapped in std::any).
std::any SysYIRGenerator::visitBlockStmt(SysYParser::BlockStmtContext* ctx) {
    const auto items = ctx->blockItem();
    for (auto it = items.begin(); it != items.end(); ++it) {
        (*it)->accept(this);
    }
    return nullptr;
}
/// Lowers the statement forms this generator currently supports:
///   lValue = exp;   -> store of the value into the variable's slot
///   return exp;     -> ret <i32> value
///   return;         -> ret void
/// All other statement forms are silently skipped.
/// @return nullptr (wrapped in std::any).
std::any SysYIRGenerator::visitStmt(SysYParser::StmtContext* ctx) {
    auto* target = ctx->lValue();
    auto* expr = ctx->exp();

    if (target != nullptr && expr != nullptr) {
        // Address first: visitLValue may emit an alloca and consume a temp number.
        const auto address = std::any_cast<std::string>(target->accept(this));
        const auto value = std::any_cast<std::string>(expr->accept(this));
        const std::string ty = getLLVMType("");
        irStream << " store " << ty << " " << value << ", " << ty << "* " << address << "\n";
        return nullptr;
    }

    if (ctx->RETURN() == nullptr) {
        return nullptr;
    }

    if (expr == nullptr) {
        irStream << " ret void\n";
        return nullptr;
    }

    const auto result = std::any_cast<std::string>(expr->accept(this));
    irStream << " ret " << getLLVMType("") << " " << result << "\n";
    return nullptr;
}
/// Resolves a variable reference to the name of its stack slot.
/// An unknown name gets a fresh `alloca` at the current emission point, and
/// the slot is recorded in symbolTable.
/// NOTE(review): this returns the slot address, not a loaded value. Uses in
/// expression context presumably need a `load`. The primaryExp visitor is not
/// in this file; confirm.
/// @return std::string slot name (wrapped in std::any).
std::any SysYIRGenerator::visitLValue(SysYParser::LValueContext* ctx) {
    const std::string name = ctx->Ident()->getText();
    auto found = symbolTable.find(name);
    if (found != symbolTable.end()) {
        return found->second;
    }
    const std::string slot = getNextTemp();
    symbolTable[name] = slot;
    irStream << " " << slot << " = alloca " << getLLVMType("") << "\n";
    return slot;
}
/// Lowers a numeric literal to bare LLVM constant text, without a type prefix.
/// Every caller prints the type itself (e.g. `add i32 <lhs>, <rhs>`); the old
/// "i32 " prefix produced `add i32 i32 5, ...`.
///
/// - Integer literals are normalised to decimal. SysY allows 0x/0X hex and
///   leading-0 octal, but LLVM only reads decimal here (it would treat "017"
///   as 17 and reject "0x1F").
/// - Float literals are emitted as LLVM's 64-bit hexadecimal form of the value
///   rounded to float and widened to double. LLVM rejects decimal float
///   constants that are not exactly representable.
/// @return std::string constant text (wrapped in std::any).
std::any SysYIRGenerator::visitNumber(SysYParser::NumberContext* ctx) {
    if (auto* intLit = ctx->ILITERAL()) {
        // Base 0 selects decimal, hex or octal from the literal's prefix.
        const long long value = std::strtoll(intLit->getText().c_str(), nullptr, 0);
        return std::to_string(value);
    }
    if (auto* floatLit = ctx->FLITERAL()) {
        const float asFloat = std::strtof(floatLit->getText().c_str(), nullptr);
        const double widened = static_cast<double>(asFloat);
        std::uint64_t bits = 0;
        std::memcpy(&bits, &widened, sizeof bits);
        std::ostringstream hex;
        hex << "0x" << std::uppercase << std::hex << std::setw(16) << std::setfill('0') << bits;
        return hex.str();
    }
    // Must hold a std::string (not a const char*) so callers' any_cast succeeds.
    return std::string();
}
/// Lowers a unary expression.
///   no operator -> forwarded to primaryExp
///   -x          -> sub i32 0, x
///   !x          -> (x == 0) widened back to i32, i.e. C's logical not
///   +x          -> x itself; no instruction is emitted
/// @return std::string value name or constant (wrapped in std::any).
std::any SysYIRGenerator::visitUnaryExp(SysYParser::UnaryExpContext* ctx) {
    if (!ctx->unaryOp()) {
        return ctx->primaryExp()->accept(this);
    }

    const std::string operand = std::any_cast<std::string>(ctx->unaryExp()->accept(this));
    const std::string op = ctx->unaryOp()->getText();
    const std::string type = getLLVMType("");

    if (op == "-") {
        const std::string temp = getNextTemp();
        irStream << " " << temp << " = sub " << type << " 0, " << operand << "\n";
        return temp;
    }
    if (op == "!") {
        // `xor x, 1` is only correct for 0/1. C semantics are !x == (x == 0),
        // so compare against zero and widen the i1 back to int.
        const std::string isZero = getNextTemp();
        irStream << " " << isZero << " = icmp eq " << type << " " << operand << ", 0\n";
        const std::string widened = getNextTemp();
        irStream << " " << widened << " = zext i1 " << isZero << " to " << type << "\n";
        return widened;
    }
    // Unary '+' is the identity. Reserving a temp here without a defining
    // instruction would leave a gap in LLVM's sequential value numbering.
    return operand;
}
/// Lowers a left-associative chain `u0 op u1 op u2 ...` with op in {*, /, %}.
/// Operands are evaluated left to right. Each step emits one i32 instruction
/// (mul / sdiv / srem) whose result feeds the next step.
/// @return std::string name of the final value (wrapped in std::any).
std::any SysYIRGenerator::visitMulExp(SysYParser::MulExpContext* ctx) {
    const auto operands = ctx->unaryExp();
    std::string acc = std::any_cast<std::string>(operands[0]->accept(this));
    for (std::size_t idx = 1; idx < operands.size(); ++idx) {
        const std::string rhs = std::any_cast<std::string>(operands[idx]->accept(this));
        // Children alternate operand, operator, operand, ... so operator k is at 2k-1.
        const std::string opText = ctx->children[2 * idx - 1]->getText();
        const std::string dest = getNextTemp();
        const char* mnemonic = opText == "*"   ? "mul"
                               : opText == "/" ? "sdiv"
                               : opText == "%" ? "srem"
                                               : nullptr;
        if (mnemonic != nullptr) {
            irStream << " " << dest << " = " << mnemonic << " " << getLLVMType("") << " " << acc
                     << ", " << rhs << "\n";
        }
        acc = dest;
    }
    return acc;
}
/// Lowers a left-associative chain `m0 op m1 op m2 ...` with op in {+, -}.
/// Operands are evaluated left to right. Each step emits one i32 add/sub whose
/// result feeds the next step.
/// @return std::string name of the final value (wrapped in std::any).
std::any SysYIRGenerator::visitAddExp(SysYParser::AddExpContext* ctx) {
    const auto operands = ctx->mulExp();
    std::string acc = std::any_cast<std::string>(operands[0]->accept(this));
    for (std::size_t idx = 1; idx < operands.size(); ++idx) {
        const std::string rhs = std::any_cast<std::string>(operands[idx]->accept(this));
        // Children alternate operand, operator, operand, ... so operator k is at 2k-1.
        const std::string opText = ctx->children[2 * idx - 1]->getText();
        const std::string dest = getNextTemp();
        const char* mnemonic = opText == "+"   ? "add"
                               : opText == "-" ? "sub"
                                               : nullptr;
        if (mnemonic != nullptr) {
            irStream << " " << dest << " = " << mnemonic << " " << getLLVMType("") << " " << acc
                     << ", " << rhs << "\n";
        }
        acc = dest;
    }
    return acc;
}
/// Lowers a left-associative relational chain `a0 op a1 op a2 ...` with op in
/// {<, >, <=, >=} using signed i32 comparisons.
/// Each comparison's i1 result is widened to i32 (0/1), as in C. This keeps
/// chains like `a < b < c` and callers that treat every value as i32 type-correct.
/// @return std::string name of the final i32 value (wrapped in std::any).
std::any SysYIRGenerator::visitRelExp(SysYParser::RelExpContext* ctx) {
    const auto addExps = ctx->addExp();
    std::string left = std::any_cast<std::string>(addExps[0]->accept(this));
    for (std::size_t i = 1; i < addExps.size(); ++i) {
        const std::string right = std::any_cast<std::string>(addExps[i]->accept(this));
        const std::string op = ctx->children[2 * i - 1]->getText();
        std::string predicate;
        if (op == "<") predicate = "slt";
        else if (op == ">") predicate = "sgt";
        else if (op == "<=") predicate = "sle";
        else if (op == ">=") predicate = "sge";
        if (predicate.empty()) {
            // The grammar only yields the four operators above. Never reserve
            // an SSA number that no instruction defines.
            continue;
        }
        const std::string type = getLLVMType("");
        const std::string cmp = getNextTemp();
        irStream << " " << cmp << " = icmp " << predicate << " " << type << " " << left << ", "
                 << right << "\n";
        const std::string widened = getNextTemp();
        irStream << " " << widened << " = zext i1 " << cmp << " to " << type << "\n";
        left = widened;
    }
    return left;
}
/// Lowers a left-associative equality chain `r0 op r1 op r2 ...` with op in
/// {==, !=}.
/// Each comparison's i1 result is widened to i32 (0/1), as in C. This keeps
/// chains like `a == b == c` and i32-typed callers type-correct.
/// @return std::string name of the final i32 value (wrapped in std::any).
std::any SysYIRGenerator::visitEqExp(SysYParser::EqExpContext* ctx) {
    const auto relExps = ctx->relExp();
    std::string left = std::any_cast<std::string>(relExps[0]->accept(this));
    for (std::size_t i = 1; i < relExps.size(); ++i) {
        const std::string right = std::any_cast<std::string>(relExps[i]->accept(this));
        const std::string op = ctx->children[2 * i - 1]->getText();
        std::string predicate;
        if (op == "==") predicate = "eq";
        else if (op == "!=") predicate = "ne";
        if (predicate.empty()) {
            // The grammar only yields == and !=. Never reserve an SSA number
            // that no instruction defines.
            continue;
        }
        const std::string type = getLLVMType("");
        const std::string cmp = getNextTemp();
        irStream << " " << cmp << " = icmp " << predicate << " " << type << " " << left << ", "
                 << right << "\n";
        const std::string widened = getNextTemp();
        irStream << " " << widened << " = zext i1 " << cmp << " to " << type << "\n";
        left = widened;
    }
    return left;
}
/// Lowers `e0 && e1 && ...` to an i32 that is 1 when every operand is non-zero,
/// and 0 otherwise. A single operand is returned unchanged.
///
/// Each operand is normalised with `icmp ne 0`, the i1 values are combined
/// with `and i1`, and the result is widened back to i32. The old bitwise
/// `and i32` gave wrong answers, e.g. `2 && 1` evaluated to 0.
/// TODO: all operands are still evaluated eagerly. C short-circuit semantics
/// need conditional branches, which this generator does not emit yet.
/// @return std::string name of the resulting value (wrapped in std::any).
std::any SysYIRGenerator::visitLAndExp(SysYParser::LAndExpContext* ctx) {
    const auto eqExps = ctx->eqExp();
    const std::string first = std::any_cast<std::string>(eqExps[0]->accept(this));
    if (eqExps.size() == 1) {
        return first;
    }

    const std::string type = getLLVMType("");
    // Emits `x != 0` and returns the name of the i1 result.
    auto toBool = [this, &type](const std::string& value) {
        const std::string flag = getNextTemp();
        irStream << " " << flag << " = icmp ne " << type << " " << value << ", 0\n";
        return flag;
    };

    std::string acc = toBool(first);
    for (std::size_t i = 1; i < eqExps.size(); ++i) {
        const std::string right = std::any_cast<std::string>(eqExps[i]->accept(this));
        const std::string rightFlag = toBool(right);
        const std::string conj = getNextTemp();
        irStream << " " << conj << " = and i1 " << acc << ", " << rightFlag << "\n";
        acc = conj;
    }
    const std::string result = getNextTemp();
    irStream << " " << result << " = zext i1 " << acc << " to " << type << "\n";
    return result;
}
/// Lowers `e0 || e1 || ...` to an i32 that is 1 when any operand is non-zero,
/// and 0 otherwise. A single operand is returned unchanged.
///
/// Each operand is normalised with `icmp ne 0`, the i1 values are combined
/// with `or i1`, and the result is widened back to i32. The old bitwise
/// `or i32` leaked non-0/1 values, e.g. `2 || 4` evaluated to 6.
/// TODO: all operands are still evaluated eagerly. C short-circuit semantics
/// need conditional branches, which this generator does not emit yet.
/// @return std::string name of the resulting value (wrapped in std::any).
std::any SysYIRGenerator::visitLOrExp(SysYParser::LOrExpContext* ctx) {
    const auto lAndExps = ctx->lAndExp();
    const std::string first = std::any_cast<std::string>(lAndExps[0]->accept(this));
    if (lAndExps.size() == 1) {
        return first;
    }

    const std::string type = getLLVMType("");
    // Emits `x != 0` and returns the name of the i1 result.
    auto toBool = [this, &type](const std::string& value) {
        const std::string flag = getNextTemp();
        irStream << " " << flag << " = icmp ne " << type << " " << value << ", 0\n";
        return flag;
    };

    std::string acc = toBool(first);
    for (std::size_t i = 1; i < lAndExps.size(); ++i) {
        const std::string right = std::any_cast<std::string>(lAndExps[i]->accept(this));
        const std::string rightFlag = toBool(right);
        const std::string disj = getNextTemp();
        irStream << " " << disj << " = or i1 " << acc << ", " << rightFlag << "\n";
        acc = disj;
    }
    const std::string result = getNextTemp();
    irStream << " " << result << " = zext i1 " << acc << " to " << type << "\n";
    return result;
}