Merge branch 'array_add'

This commit is contained in:
lixuanwang
2025-06-22 14:24:00 +08:00
23 changed files with 5657 additions and 3162 deletions

87
TODO.md Normal file
View File

@ -0,0 +1,87 @@
要打通从SysY到RISC-V的完整编译流程以下是必须实现的核心模块和关键步骤按编译流程顺序。在你们当前IR生成阶段可以优先实现这些基础模块来快速获得可工作的RISC-V汇编输出
### 1. **前端必须模块**
- **词法/语法分析**(已完成):
- `SysYLexer`/`SysYParser`ANTLR生成的解析器
- **IR生成核心**
- `SysYIRGenerator`将AST转换为中间表示IR
- `IRBuilder`:构建指令和基本块的工具类(你们正在实现的部分)
### 2. **中端必要优化(最小集合)**
常量传播
| 优化阶段 | 关键作用 | 是否必须 |
|-------------------|----------------------------------|----------|
| `Mem2Reg` | 消除冗余内存访问转换为SSA形式 | ✅ 核心 |
| `DCE` (死代码消除) | 移除无用指令 | ✅ 必要 |
| `DFE` (死函数消除) | 移除未使用的函数 | ✅ 必要 |
| `FuncAnalysis` | 函数调用关系分析 | ✅ 基础 |
| `Global2Local` | 全局变量降级为局部变量 | ✅ 重要 |
### 3. **后端核心流程(必须实现)**
```mermaid
graph LR
A[IR指令选择] --> B[寄存器分配]
B --> C[指令调度]
C --> D[汇编生成]
```
1. **指令选择**(关键步骤):
- `DAGBuilder`将IR转换为有向无环图DAG
- `DAGCoverage`DAG到目标指令的映射
- `Mid2End`IR到机器指令的转换接口
2. **寄存器分配**
- `RegisterAlloc`:基础寄存器分配器(可先实现简单算法如线性扫描)
3. **汇编生成**
- `RiscvPrinter`将机器指令输出为RISC-V汇编
- 实现基础指令集:`add`/`sub`/`lw`/`sw`/`beq`/`jal`
### 4. **最小可工作流程**
```cpp
// 精简版编译流程(跳过复杂优化)
int main() {
// 1. 前端解析
auto module = sysy::SysYIRGenerator().genIR(input);
// 2. 关键中端优化
sysy::Mem2Reg(module).run(); // 必须
sysy::Global2Local(module).run(); // 必须
sysy::DCE(module).run(); // 推荐
// 3. 后端代码生成
auto backendModule = mid2end::CodeGenerater().run(module);
riscv::RiscvPrinter().print("output.s", backendModule);
}
```
### 5. **当前开发优先级建议**
1. **完成IR生成**
- 确保能构建基本块、函数、算术/内存/控制流指令
- 实现`createCall`/`createLoad`/`createStore`等核心方法
2. **实现Mem2Reg**
- 插入Phi节点
- 变量重命名(关键算法)
3. **构建基础后端**
- 指令选择实现IR到RISC-V的简单映射例如`IRAdd``add`
- 寄存器分配:使用无限寄存器方案(后期替换为真实分配)
- 汇编打印:支持基础指令输出
> **注意**:循环优化、函数内联、高级寄存器分配等可在基础流程打通后逐步添加。初期可跳过复杂优化。
### 6. 调试建议
- 添加IR打印模块`SysYPrinter`)验证前端输出
- 使用简化测试用例:
```c
int main() {
int a = 1;
int b = a + 2;
return b;
}
```
- 逐步扩展支持:
1. 算术运算 → 2. 条件分支 → 3. 函数调用 → 4. 数组访问
通过聚焦这些核心模块你们可以快速打通从SysY到RISC-V的基础编译流程后续再逐步添加优化传递提升代码质量。

62
olddef.h Normal file
View File

@ -0,0 +1,62 @@
class SymbolTable{
private:
enum Kind
{
kModule,
kFunction,
kBlock,
};
std::forward_list<std::pair<Kind, std::unordered_map<std::string, Value*>>> Scopes;
public:
struct ModuleScope {
SymbolTable& tables_ref;
ModuleScope(SymbolTable& tables) : tables_ref(tables) {
tables.enter(kModule);
}
~ModuleScope() { tables_ref.exit(); }
};
struct FunctionScope {
SymbolTable& tables_ref;
FunctionScope(SymbolTable& tables) : tables_ref(tables) {
tables.enter(kFunction);
}
~FunctionScope() { tables_ref.exit(); }
};
struct BlockScope {
SymbolTable& tables_ref;
BlockScope(SymbolTable& tables) : tables_ref(tables) {
tables.enter(kBlock);
}
~BlockScope() { tables_ref.exit(); }
};
SymbolTable() = default;
bool isModuleScope() const { return Scopes.front().first == kModule; }
bool isFunctionScope() const { return Scopes.front().first == kFunction; }
bool isBlockScope() const { return Scopes.front().first == kBlock; }
Value *lookup(const std::string &name) const {
for (auto &scope : Scopes) {
auto iter = scope.second.find(name);
if (iter != scope.second.end())
return iter->second;
}
return nullptr;
}
auto insert(const std::string &name, Value *value) {
assert(not Scopes.empty());
return Scopes.front().second.emplace(name, value);
}
private:
void enter(Kind kind) {
Scopes.emplace_front();
Scopes.front().first = kind;
}
void exit() {
Scopes.pop_front();
}
};

View File

@ -1,355 +0,0 @@
#include <algorithm>
#include <iostream>
using namespace std;
#include "ASTPrinter.h"
#include "SysYParser.h"
any ASTPrinter::visitCompUnit(SysYParser::CompUnitContext *ctx) {
if(ctx->decl().empty() && ctx->funcDef().empty())
return nullptr;
for (auto dcl : ctx->decl()) {dcl->accept(this);cout << '\n';}cout << '\n';
for (auto func : ctx->funcDef()) {func->accept(this);cout << "\n";}
return nullptr;
}
// std::any ASTPrinter::visitBType(SysYParser::BTypeContext *ctx);
// std::any ASTPrinter::visitDecl(SysYParser::DeclContext *ctx);
std::any ASTPrinter::visitConstDecl(SysYParser::ConstDeclContext *ctx) {
cout << getIndent() << ctx->CONST()->getText() << ' ' << ctx->bType()->getText() << ' ';
auto numConstDefs = ctx->constDef().size();
ctx->constDef(0)->accept(this);
for (int i = 1; i < numConstDefs; ++i) {
cout << ctx->COMMA(i - 1)->getText() << ' ';
ctx->constDef(i)->accept(this);
}
cout << ctx->SEMICOLON()->getText() << '\n';
return nullptr;
}
std::any ASTPrinter::visitConstDef(SysYParser::ConstDefContext *ctx) {
cout << ctx->Ident()->getText();
auto numConstExps = ctx->constExp().size();
for (int i = 0; i < numConstExps; ++i) {
cout << ctx->LBRACK(i)->getText();
ctx->constExp(i)->accept(this);
cout << ctx->RBRACK(i)->getText();
}
cout << ' ' << ctx->ASSIGN()->getText() << ' ';
ctx->constInitVal()->accept(this);
return nullptr;
}
// std::any ASTPrinter::visitConstInitVal(SysYParser::ConstInitValContext *ctx);
std::any ASTPrinter::visitVarDecl(SysYParser::VarDeclContext *ctx){
cout << getIndent() << ctx->bType()->getText() << ' ';
auto numVarDefs = ctx->varDef().size();
ctx->varDef(0)->accept(this);
for (int i = 1; i < numVarDefs; ++i) {
cout << ", ";
ctx->varDef(i)->accept(this);
}
cout << ctx->SEMICOLON()->getText() << '\n';
return nullptr;
}
std::any ASTPrinter::visitVarDef(SysYParser::VarDefContext *ctx){
cout << ctx->Ident()->getText();
auto numConstExps = ctx->constExp().size();
for (int i = 0; i < numConstExps; ++i) {
cout << ctx->LBRACK(i)->getText();
ctx->constExp(i)->accept(this);
cout << ctx->RBRACK(i)->getText();
}
if (ctx->initVal()) {
cout << ' ' << ctx->ASSIGN()->getText() << ' ';
ctx->initVal()->accept(this);
}
return nullptr;
}
std::any ASTPrinter::visitInitVal(SysYParser::InitValContext *ctx){
if (ctx->exp()) {
ctx->exp()->accept(this);
} else {
cout << ctx->LBRACE()->getText();
auto numInitVals = ctx->initVal().size();
ctx->initVal(0)->accept(this);
for (int i = 1; i < numInitVals; ++i) {
cout << ctx->COMMA(i - 1)->getText() << ' ';
ctx->initVal(i)->accept(this);
}
cout << ctx->RBRACE()->getText();
}
return nullptr;
}
std::any ASTPrinter::visitFuncDef(SysYParser::FuncDefContext *ctx){
cout << getIndent() << ctx->funcType()->getText() << ' ' << ctx->Ident()->getText();
cout << ctx->LPAREN()->getText();
if (ctx->funcFParams()) ctx->funcFParams()->accept(this);
if(ctx->RPAREN())
cout << ctx->RPAREN()->getText();
else
cout << "<missing \')\'?>";
ctx->blockStmt()->accept(this);
return nullptr;
}
// std::any ASTPrinter::visitFuncType(SysYParser::FuncTypeContext *ctx);
std::any ASTPrinter::visitFuncFParams(SysYParser::FuncFParamsContext *ctx){
auto numFuncFParams = ctx->funcFParam().size();
ctx->funcFParam(0)->accept(this);
for (int i = 1; i < numFuncFParams; ++i) {
cout << ctx->COMMA(i - 1)->getText() << ' ';
ctx->funcFParam(i)->accept(this);
}
return nullptr;
}
std::any ASTPrinter::visitFuncFParam(SysYParser::FuncFParamContext *ctx){
cout << ctx->bType()->getText() << ' ' << ctx->Ident()->getText();
if (!ctx->exp().empty()) {
cout << "[]";
for (auto exp : ctx->exp()) {
cout << '[';
exp->accept(this);
cout << ']';
}
}
return nullptr;
}
std::any ASTPrinter::visitBlockStmt(SysYParser::BlockStmtContext *ctx){
cout << ctx->LBRACE()->getText() << endl;
indentLevel++;
for (auto item : ctx->blockItem()) item->accept(this);
indentLevel--;
cout << getIndent() << ctx->RBRACE()->getText() << endl;
return nullptr;
}
// std::any ASTPrinter::visitBlockItem(SysYParser::BlockItemContext *ctx);
std::any ASTPrinter::visitAssignStmt(SysYParser::AssignStmtContext *ctx){
cout << getIndent();
ctx->lValue()->accept(this);
cout << ' ' << ctx->ASSIGN()->getText() << ' ';
ctx->exp()->accept(this);
cout << ctx->SEMICOLON()->getText() << '\n';
return nullptr;
}
std::any ASTPrinter::visitExpStmt(SysYParser::ExpStmtContext *ctx){
cout << getIndent();
if (ctx->exp()) {
ctx->exp()->accept(this);
}
cout << ctx->SEMICOLON()->getText() << '\n';
return nullptr;
}
std::any ASTPrinter::visitIfStmt(SysYParser::IfStmtContext *ctx){
cout << getIndent() << ctx->IF()->getText() << ' ' << ctx->LPAREN()->getText();
ctx->cond()->accept(this);
cout << ctx->RPAREN()->getText() << ' ';
//格式化有问题
if(ctx->stmt(0)) {
ctx->stmt(0)->accept(this);
}
else {
cout << '{' << endl;
indentLevel++;
ctx->stmt(0)->accept(this);
indentLevel--;
cout << getIndent() << '}' << endl;
}
if (ctx->ELSE()) {
cout << getIndent() << ctx->ELSE()->getText() << ' ';
ctx->stmt(1)->accept(this);
}
return nullptr;
}
std::any ASTPrinter::visitWhileStmt(SysYParser::WhileStmtContext *ctx){
cout << getIndent() << ctx->WHILE()->getText() << ' ' << ctx->LPAREN()->getText();
ctx->cond()->accept(this);
cout << ctx->RPAREN()->getText() << ' ';
ctx->stmt()->accept(this);
return nullptr;
}
std::any ASTPrinter::visitBreakStmt(SysYParser::BreakStmtContext *ctx){
cout << getIndent() << ctx->BREAK()->getText() << ctx->SEMICOLON()->getText() << '\n';
return nullptr;
}
std::any ASTPrinter::visitContinueStmt(SysYParser::ContinueStmtContext *ctx){
cout << getIndent() << ctx->CONTINUE()->getText() << ctx->SEMICOLON()->getText() << '\n';
return nullptr;
}
std::any ASTPrinter::visitReturnStmt(SysYParser::ReturnStmtContext *ctx){
cout << getIndent() << ctx->RETURN()->getText() << ' ';
if (ctx->exp()) {
ctx->exp()->accept(this);
}
cout << ctx->SEMICOLON()->getText() << '\n';
return nullptr;
}
// std::any ASTPrinter::visitExp(SysYParser::ExpContext *ctx);
// std::any ASTPrinter::visitCond(SysYParser::CondContext *ctx);
std::any ASTPrinter::visitLValue(SysYParser::LValueContext *ctx){
cout << ctx->Ident()->getText();
for (auto exp : ctx->exp()) {
cout << "[";
exp->accept(this);
cout << "]";
}
return nullptr;
}
// std::any ASTPrinter::visitPrimaryExp(SysYParser::PrimaryExpContext *ctx);
std::any ASTPrinter::visitParenExp(SysYParser::ParenExpContext *ctx){
cout << ctx->LPAREN()->getText();
ctx->exp()->accept(this);
cout << ctx->RPAREN()->getText();
return nullptr;
}
std::any ASTPrinter::visitNumber(SysYParser::NumberContext *ctx) {
if(ctx->ILITERAL())cout << ctx->ILITERAL()->getText();
if(ctx->FLITERAL())cout << ctx->FLITERAL()->getText();
return nullptr;
}
std::any ASTPrinter::visitString(SysYParser::StringContext *ctx) {
cout << ctx->STRING()->getText();
return nullptr;
}
// std::any visitUnaryExp(SysYParser::UnaryExpContext *ctx);
// std::any ASTPrinter::visitUnaryOp(SysYParser::UnaryOpContext *ctx);
std::any ASTPrinter::visitCall(SysYParser::CallContext *ctx){
cout << ctx->Ident()->getText() << ctx->LPAREN()->getText();
if(ctx->funcRParams())
ctx->funcRParams()->accept(this);
cout << ctx->RPAREN()->getText();
return nullptr;
}
any ASTPrinter::visitFuncRParams(SysYParser::FuncRParamsContext *ctx) {
if (ctx->exp().empty())
return nullptr;
auto numParams = ctx->exp().size();
ctx->exp(0)->accept(this);
for (int i = 1; i < numParams; ++i) {
cout << ctx->COMMA(i - 1)->getText() << ' ';
ctx->exp(i)->accept(this);
}
return nullptr;
}
std::any ASTPrinter::visitMulExp(SysYParser::MulExpContext *ctx){
auto unaryExps = ctx->unaryExp();
if (unaryExps.size() == 1) {
unaryExps[0]->accept(this);
} else {
for (size_t i = 0; i < unaryExps.size() - 1; ++i) {
auto opNode = dynamic_cast<antlr4::tree::TerminalNode *>(ctx->children[2 * i + 1]);
if (opNode) {
unaryExps[i]->accept(this);
cout << " " << opNode->getText() << " ";
}
}
unaryExps.back()->accept(this);
}
return nullptr;
}
std::any ASTPrinter::visitAddExp(SysYParser::AddExpContext *ctx){
auto mulExps = ctx->mulExp();
if (mulExps.size() == 1) {
mulExps[0]->accept(this);
} else {
for (size_t i = 0; i < mulExps.size() - 1; ++i) {
auto opNode = dynamic_cast<antlr4::tree::TerminalNode *>(ctx->children[2 * i + 1]);
if (opNode) {
mulExps[i]->accept(this);
cout << " " << opNode->getText() << " ";
}
}
mulExps.back()->accept(this);
}
return nullptr;
}
// 以下表达式待补全形式同addexp mulexp
std::any ASTPrinter::visitRelExp(SysYParser::RelExpContext *ctx){
auto relExps = ctx->addExp();
if (relExps.size() == 1) {
relExps[0]->accept(this);
} else {
for (size_t i = 0; i < relExps.size() - 1; ++i) {
auto opNode = dynamic_cast<antlr4::tree::TerminalNode *>(ctx->children[2 * i + 1]);
if (opNode) {
relExps[i]->accept(this);
cout << " " << opNode->getText() << " ";
}
}
relExps.back()->accept(this);
}
return nullptr;
}
std::any ASTPrinter::visitEqExp(SysYParser::EqExpContext *ctx){
auto eqExps = ctx->relExp();
if (eqExps.size() == 1) {
eqExps[0]->accept(this);
} else {
for (size_t i = 0; i < eqExps.size() - 1; ++i) {
auto opNode = dynamic_cast<antlr4::tree::TerminalNode *>(ctx->children[2 * i + 1]);
if (opNode) {
eqExps[i]->accept(this);
cout << " " << opNode->getText() << " ";
}
}
eqExps.back()->accept(this);
}
return nullptr;
}
std::any ASTPrinter::visitLAndExp(SysYParser::LAndExpContext *ctx){
auto lAndExps = ctx->eqExp();
if (lAndExps.size() == 1) {
lAndExps[0]->accept(this);
} else {
for (size_t i = 0; i < lAndExps.size() - 1; ++i) {
auto opNode = dynamic_cast<antlr4::tree::TerminalNode *>(ctx->children[2 * i + 1]);
if (opNode) {
lAndExps[i]->accept(this);
cout << " " << opNode->getText() << " ";
}
}
lAndExps.back()->accept(this);
}
return nullptr;
}
std::any ASTPrinter::visitLOrExp(SysYParser::LOrExpContext *ctx){
auto lOrExps = ctx->lAndExp();
if (lOrExps.size() == 1) {
lOrExps[0]->accept(this);
} else {
for (size_t i = 0; i < lOrExps.size() - 1; ++i) {
auto opNode = dynamic_cast<antlr4::tree::TerminalNode *>(ctx->children[2 * i + 1]);
if (opNode) {
lOrExps[i]->accept(this);
cout << " " << opNode->getText() << " ";
}
}
lOrExps.back()->accept(this);
}
return nullptr;
}
std::any ASTPrinter::visitConstExp(SysYParser::ConstExpContext *ctx){
ctx->addExp()->accept(this);
return nullptr;
}

View File

@ -1,59 +0,0 @@
#pragma once
#include "SysYBaseVisitor.h"
#include "SysYParser.h"
class ASTPrinter : public SysYBaseVisitor {
private:
int indentLevel = 0;
std::string getIndent() {
return std::string(indentLevel * 4, ' ');
}
public:
std::any visitCompUnit(SysYParser::CompUnitContext *ctx) override;
// std::any visitBType(SysYParser::BTypeContext *ctx) override;
// std::any visitDecl(SysYParser::DeclContext *ctx) override;
std::any visitConstDecl(SysYParser::ConstDeclContext *ctx) override;
std::any visitConstDef(SysYParser::ConstDefContext *ctx) override;
// std::any visitConstInitVal(SysYParser::ConstInitValContext *ctx) override;
std::any visitVarDecl(SysYParser::VarDeclContext *ctx) override;
std::any visitVarDef(SysYParser::VarDefContext *ctx) override;
std::any visitInitVal(SysYParser::InitValContext *ctx) override;
std::any visitFuncDef(SysYParser::FuncDefContext *ctx) override;
// std::any visitFuncType(SysYParser::FuncTypeContext *ctx) override;
std::any visitFuncFParams(SysYParser::FuncFParamsContext *ctx) override;
std::any visitFuncFParam(SysYParser::FuncFParamContext *ctx) override;
std::any visitBlockStmt(SysYParser::BlockStmtContext *ctx) override;
// std::any visitBlockItem(SysYParser::BlockItemContext *ctx) override;
// std::any visitStmt(SysYParser::StmtContext *ctx) override;
std::any visitAssignStmt(SysYParser::AssignStmtContext *ctx) override;
std::any visitExpStmt(SysYParser::ExpStmtContext *ctx) override;
std::any visitIfStmt(SysYParser::IfStmtContext *ctx) override;
std::any visitWhileStmt(SysYParser::WhileStmtContext *ctx) override;
std::any visitBreakStmt(SysYParser::BreakStmtContext *ctx) override;
std::any visitContinueStmt(SysYParser::ContinueStmtContext *ctx) override;
std::any visitReturnStmt(SysYParser::ReturnStmtContext *ctx) override;
// std::any visitExp(SysYParser::ExpContext *ctx) override;
// std::any visitCond(SysYParser::CondContext *ctx) override;
std::any visitLValue(SysYParser::LValueContext *ctx) override;
// std::any visitPrimaryExp(SysYParser::PrimaryExpContext *ctx) override;
std::any visitParenExp(SysYParser::ParenExpContext *ctx) override;
std::any visitNumber(SysYParser::NumberContext *ctx) override;
std::any visitString(SysYParser::StringContext *ctx) override;
// std::any visitUnaryExp(SysYParser::UnaryExpContext *ctx) override;
std::any visitCall(SysYParser::CallContext *ctx) override;
// std::any visitUnExpOp(SysYParser::UnExpContext *ctx) override;
// std::any visitUnaryOp(SysYParser::UnaryOpContext *ctx) override;
std::any visitFuncRParams(SysYParser::FuncRParamsContext *ctx) override;
std::any visitMulExp(SysYParser::MulExpContext *ctx) override;
std::any visitAddExp(SysYParser::AddExpContext *ctx) override;
std::any visitRelExp(SysYParser::RelExpContext *ctx) override;
std::any visitEqExp(SysYParser::EqExpContext *ctx) override;
std::any visitLAndExp(SysYParser::LAndExpContext *ctx) override;
std::any visitLOrExp(SysYParser::LOrExpContext *ctx) override;
std::any visitConstExp(SysYParser::ConstExpContext *ctx) override;
};

View File

@ -13,13 +13,12 @@ target_link_libraries(SysYParser PUBLIC antlr4_shared)
add_executable(sysyc
sysyc.cpp
ASTPrinter.cpp
IR.cpp
SysYIRGenerator.cpp
Backend.cpp
RISCv32Backend.cpp
)
target_include_directories(sysyc PRIVATE ${CMAKE_CURRENT_BINARY_DIR})
target_include_directories(sysyc PRIVATE ${CMAKE_CURRENT_BINARY_DIR} ${CMAKE_CURRENT_SOURCE_DIR}/include)
target_compile_options(sysyc PRIVATE -frtti)
target_link_libraries(sysyc PRIVATE SysYParser)

1020
src/IR.cpp

File diff suppressed because it is too large Load Diff

994
src/IR.h
View File

@ -1,994 +0,0 @@
#pragma once
#include "range.h"
#include <cassert>
#include <cstdint>
#include <iterator>
#include <list>
#include <map>
#include <memory>
#include <ostream>
#include <string>
#include <type_traits>
#include <vector>
namespace sysy {
/*!
* \defgroup type Types
* The SysY type system is quite simple.
* 1. The base class `Type` is used to represent all primitive scalar types,
* include `int`, `float`, `void`, and the label type representing branch
* targets.
* 2. `PointerType` and `FunctionType` derive from `Type` and represent pointer
* type and function type, respectively.
*
* NOTE `Type` and its derived classes have their ctors declared as 'protected'.
* Users must use Type::getXXXType() methods to obtain `Type` pointers.
* @{
*/
/*!
* `Type` is used to represent all primitive scalar types,
* include `int`, `float`, `void`, and the label type representing branch
* targets
*/
class Type {
public:
enum Kind {
kInt,
kFloat,
kVoid,
kLabel,
kPointer,
kFunction,
};
Kind kind;
protected:
Type(Kind kind) : kind(kind) {}
virtual ~Type() = default;
public:
static Type *getIntType();
static Type *getFloatType();
static Type *getVoidType();
static Type *getLabelType();
static Type *getPointerType(Type *baseType);
static Type *getFunctionType(Type *returnType,
const std::vector<Type *> &paramTypes = {});
public:
Kind getKind() const { return kind; }
bool isInt() const { return kind == kInt; }
bool isFloat() const { return kind == kFloat; }
bool isVoid() const { return kind == kVoid; }
bool isLabel() const { return kind == kLabel; }
bool isPointer() const { return kind == kPointer; }
bool isFunction() const { return kind == kFunction; }
bool isIntOrFloat() const { return kind == kInt or kind == kFloat; }
int getSize() const;
template <typename T>
std::enable_if_t<std::is_base_of_v<Type, T>, T *> as() const {
return dynamic_cast<T *>(const_cast<Type *>(this));
}
void print(std::ostream &os) const;
}; // class Type
//! Pointer type
class PointerType : public Type {
protected:
Type *baseType;
protected:
PointerType(Type *baseType) : Type(kPointer), baseType(baseType) {}
public:
static PointerType *get(Type *baseType);
public:
Type *getBaseType() const { return baseType; }
}; // class PointerType
//! Function type
class FunctionType : public Type {
private:
Type *returnType;
std::vector<Type *> paramTypes;
protected:
FunctionType(Type *returnType, const std::vector<Type *> &paramTypes = {})
: Type(kFunction), returnType(returnType), paramTypes(paramTypes) {}
public:
static FunctionType *get(Type *returnType,
const std::vector<Type *> &paramTypes = {});
public:
Type *getReturnType() const { return returnType; }
auto getParamTypes() const { return make_range(paramTypes); }
int getNumParams() const { return paramTypes.size(); }
}; // class FunctionType
/*!
* @}
*/
/*!
* \defgroup ir IR
*
* The SysY IR is an instruction level language. The IR is orgnized
* as a four-level tree structure, as shown below
*
* \dotfile ir-4level.dot IR Structure
*
* - `Module` corresponds to the top level "CompUnit" syntax structure
* - `GlobalValue` corresponds to the "Decl" syntax structure
* - `Function` corresponds to the "FuncDef" syntax structure
* - `BasicBlock` is a sequence of instructions without branching. A `Function`
* made up by one or more `BasicBlock`s.
* - `Instruction` represents a primitive operation on values, e.g., add or sub.
*
* The fundamental data concept in SysY IR is `Value`. A `Value` is like
* a register and is used by `Instruction`s as input/output operand. Each value
* has an associated `Type` indicating the data type held by the value.
*
* Most `Instruction`s have a three-address signature, i.e., there are at most 2
* input values and at most 1 output value.
*
* The SysY IR adots a Static-Single-Assignment (SSA) design. That is, `Value`
* is defined (as the output operand ) by some instruction, and used (as the
* input operand) by other instructions. While a value can be used by multiple
* instructions, the `definition` occurs only once. As a result, there is a
* one-to-one relation between a value and the instruction defining it. In other
* words, any instruction defines a value can be viewed as the defined value
* itself. So `Instruction` is also a `Value` in SysY IR. See `Value` for the
* type hierachy.
*
* @{
*/
class User;
class Value;
//! `Use` represents the relation between a `Value` and its `User`
class Use {
private:
//! the position of value in the user's operands, i.e.,
//! user->getOperands[index] == value
int index;
User *user;
Value *value;
public:
Use() = default;
Use(int index, User *user, Value *value)
: index(index), user(user), value(value) {}
public:
int getIndex() const { return index; }
User *getUser() const { return user; }
Value *getValue() const { return value; }
void setValue(Value *value) { value = value; }
}; // class Use
template <typename T>
inline std::enable_if_t<std::is_base_of_v<Value, T>, bool>
isa(const Value *value) {
return T::classof(value);
}
template <typename T>
inline std::enable_if_t<std::is_base_of_v<Value, T>, T *>
dyncast(Value *value) {
return isa<T>(value) ? static_cast<T *>(value) : nullptr;
}
template <typename T>
inline std::enable_if_t<std::is_base_of_v<Value, T>, const T *>
dyncast(const Value *value) {
return isa<T>(value) ? static_cast<const T *>(value) : nullptr;
}
//! The base class of all value types
class Value {
public:
enum Kind : uint64_t {
kInvalid,
// Instructions
// Binary
kAdd = 0x1UL << 0,
kSub = 0x1UL << 1,
kMul = 0x1UL << 2,
kDiv = 0x1UL << 3,
kRem = 0x1UL << 4,
kICmpEQ = 0x1UL << 5,
kICmpNE = 0x1UL << 6,
kICmpLT = 0x1UL << 7,
kICmpGT = 0x1UL << 8,
kICmpLE = 0x1UL << 9,
kICmpGE = 0x1UL << 10,
kFAdd = 0x1UL << 14,
kFSub = 0x1UL << 15,
kFMul = 0x1UL << 16,
kFDiv = 0x1UL << 17,
kFRem = 0x1UL << 18,
kFCmpEQ = 0x1UL << 19,
kFCmpNE = 0x1UL << 20,
kFCmpLT = 0x1UL << 21,
kFCmpGT = 0x1UL << 22,
kFCmpLE = 0x1UL << 23,
kFCmpGE = 0x1UL << 24,
// Unary
kNeg = 0x1UL << 25,
kNot = 0x1UL << 26,
kFNeg = 0x1UL << 27,
kFtoI = 0x1UL << 28,
kIToF = 0x1UL << 29,
// call
kCall = 0x1UL << 30,
// terminator
kCondBr = 0x1UL << 31,
kBr = 0x1UL << 32,
kReturn = 0x1UL << 33,
// mem op
kAlloca = 0x1UL << 34,
kLoad = 0x1UL << 35,
kStore = 0x1UL << 36,
kFirstInst = kAdd,
kLastInst = kStore,
// others
kArgument = 0x1UL << 37,
kBasicBlock = 0x1UL << 38,
kFunction = 0x1UL << 39,
kConstant = 0x1UL << 40,
kGlobal = 0x1UL << 41,
};
protected:
Kind kind;
Type *type;
std::string name;
std::list<Use *> uses;
protected:
Value(Kind kind, Type *type, const std::string &name = "")
: kind(kind), type(type), name(name), uses() {}
public:
virtual ~Value() = default;
public:
Kind getKind() const { return kind; }
static bool classof(const Value *) { return true; }
public:
Type *getType() const { return type; }
const std::string &getName() const { return name; }
void setName(const std::string &n) { name = n; }
bool hasName() const { return not name.empty(); }
bool isInt() const { return type->isInt(); }
bool isFloat() const { return type->isFloat(); }
bool isPointer() const { return type->isPointer(); }
const std::list<Use *> &getUses() { return uses; }
void addUse(Use *use) { uses.push_back(use); }
void replaceAllUsesWith(Value *value);
void removeUse(Use *use) { uses.remove(use); }
bool isConstant() const;
public:
virtual void print(std::ostream &os) const {};
}; // class Value
/*!
* Static constants known at compile time.
*
* `ConstantValue`s are not defined by instructions, and do not use any other
* `Value`s. It's type is either `int` or `float`.
*/
class ConstantValue : public Value {
protected:
union {
int iScalar;
float fScalar;
};
protected:
ConstantValue(int value)
: Value(kConstant, Type::getIntType(), ""), iScalar(value) {}
ConstantValue(float value)
: Value(kConstant, Type::getFloatType(), ""), fScalar(value) {}
public:
static ConstantValue *get(int value);
static ConstantValue *get(float value);
public:
static bool classof(const Value *value) {
return value->getKind() == kConstant;
}
public:
int getInt() const {
assert(isInt());
return iScalar;
}
float getFloat() const {
assert(isFloat());
return fScalar;
}
public:
void print(std::ostream &os) const override;
}; // class ConstantValue
class BasicBlock;
/*!
* Arguments of `BasicBlock`s.
*
* SysY IR is an SSA language, however, it does not use PHI instructions as in
* LLVM IR. `Value`s from different predecessor blocks are passed explicitly as
* block arguments. This is also the approach used by MLIR.
* NOTE that `Function` does not own `Argument`s, function arguments are
* implemented as its entry block's arguments.
*/
class Argument : public Value {
protected:
BasicBlock *block;
int index;
public:
Argument(Type *type, BasicBlock *block, int index,
const std::string &name = "");
public:
static bool classof(const Value *value) {
return value->getKind() == kConstant;
}
public:
BasicBlock *getParent() const { return block; }
int getIndex() const { return index; }
public:
void print(std::ostream &os) const override;
};
class Instruction;
class Function;
/*!
* The container for `Instruction` sequence.
*
* `BasicBlock` maintains a list of `Instruction`s, with the last one being
* a terminator (branch or return). Besides, `BasicBlock` stores its arguments
* and records its predecessor and successor `BasicBlock`s.
*/
class BasicBlock : public Value {
friend class Function;
public:
using inst_list = std::list<std::unique_ptr<Instruction>>;
using iterator = inst_list::iterator;
using arg_list = std::vector<std::unique_ptr<Argument>>;
using block_list = std::vector<BasicBlock *>;
protected:
Function *parent;
inst_list instructions;
arg_list arguments;
block_list successors;
block_list predecessors;
protected:
explicit BasicBlock(Function *parent, const std::string &name = "");
public:
static bool classof(const Value *value) {
return value->getKind() == kBasicBlock;
}
public:
int getNumInstructions() const { return instructions.size(); }
int getNumArguments() const { return arguments.size(); }
int getNumPredecessors() const { return predecessors.size(); }
int getNumSuccessors() const { return successors.size(); }
Function *getParent() const { return parent; }
inst_list &getInstructions() { return instructions; }
auto getArguments() const { return make_range(arguments); }
block_list &getPredecessors() { return predecessors; }
block_list &getSuccessors() { return successors; }
iterator begin() { return instructions.begin(); }
iterator end() { return instructions.end(); }
iterator terminator() { return std::prev(end()); }
Argument *createArgument(Type *type, const std::string &name = "") {
auto arg = new Argument(type, this, arguments.size(), name);
assert(arg);
arguments.emplace_back(arg);
return arguments.back().get();
};
public:
void print(std::ostream &os) const override;
}; // class BasicBlock
//! User is the abstract base type of `Value` types which use other `Value` as
//! operands. Currently, there are two kinds of `User`s, `Instruction` and
//! `GlobalValue`.
class User : public Value {
protected:
std::vector<Use> operands;
protected:
User(Kind kind, Type *type, const std::string &name = "")
: Value(kind, type, name), operands() {}
public:
using use_iterator = std::vector<Use>::const_iterator;
struct operand_iterator : public std::vector<Use>::const_iterator {
using Base = std::vector<Use>::const_iterator;
operand_iterator(const Base &iter) : Base(iter) {}
using value_type = Value *;
value_type operator->() { return Base::operator*().getValue(); }
value_type operator*() { return Base::operator*().getValue(); }
};
// struct const_operand_iterator : std::vector<Use>::const_iterator {
// using Base = std::vector<Use>::const_iterator;
// const_operand_iterator(const Base &iter) : Base(iter) {}
// using value_type = Value *;
// value_type operator->() { return operator*().getValue(); }
// };
public:
int getNumOperands() const { return operands.size(); }
operand_iterator operand_begin() const { return operands.begin(); }
operand_iterator operand_end() const { return operands.end(); }
auto getOperands() const {
return make_range(operand_begin(), operand_end());
}
Value *getOperand(int index) const { return operands[index].getValue(); }
void addOperand(Value *value) {
operands.emplace_back(operands.size(), this, value);
value->addUse(&operands.back());
}
template <typename ContainerT> void addOperands(const ContainerT &operands) {
for (auto value : operands)
addOperand(value);
}
void replaceOperand(int index, Value *value);
void setOperand(int index, Value *value);
}; // class User
/*!
* Base of all concrete instruction types.
*/
class Instruction : public User {
public:
// enum Kind : uint64_t {
// kInvalid = 0x0UL,
// // Binary
// kAdd = 0x1UL << 0,
// kSub = 0x1UL << 1,
// kMul = 0x1UL << 2,
// kDiv = 0x1UL << 3,
// kRem = 0x1UL << 4,
// kICmpEQ = 0x1UL << 5,
// kICmpNE = 0x1UL << 6,
// kICmpLT = 0x1UL << 7,
// kICmpGT = 0x1UL << 8,
// kICmpLE = 0x1UL << 9,
// kICmpGE = 0x1UL << 10,
// kFAdd = 0x1UL << 14,
// kFSub = 0x1UL << 15,
// kFMul = 0x1UL << 16,
// kFDiv = 0x1UL << 17,
// kFRem = 0x1UL << 18,
// kFCmpEQ = 0x1UL << 19,
// kFCmpNE = 0x1UL << 20,
// kFCmpLT = 0x1UL << 21,
// kFCmpGT = 0x1UL << 22,
// kFCmpLE = 0x1UL << 23,
// kFCmpGE = 0x1UL << 24,
// // Unary
// kNeg = 0x1UL << 25,
// kNot = 0x1UL << 26,
// kFNeg = 0x1UL << 27,
// kFtoI = 0x1UL << 28,
// kIToF = 0x1UL << 29,
// // call
// kCall = 0x1UL << 30,
// // terminator
// kCondBr = 0x1UL << 31,
// kBr = 0x1UL << 32,
// kReturn = 0x1UL << 33,
// // mem op
// kAlloca = 0x1UL << 34,
// kLoad = 0x1UL << 35,
// kStore = 0x1UL << 36,
// // constant
// // kConstant = 0x1UL << 37,
// };
protected:
Kind kind;
BasicBlock *parent;
protected:
Instruction(Kind kind, Type *type, BasicBlock *parent = nullptr,
const std::string &name = "");
public:
static bool classof(const Value *value) {
return value->getKind() >= kFirstInst and value->getKind() <= kLastInst;
}
public:
Kind getKind() const { return kind; }
BasicBlock *getParent() const { return parent; }
Function *getFunction() const { return parent->getParent(); }
void setParent(BasicBlock *bb) { parent = bb; }
bool isBinary() const {
static constexpr uint64_t BinaryOpMask =
(kAdd | kSub | kMul | kDiv | kRem) |
(kICmpEQ | kICmpNE | kICmpLT | kICmpGT | kICmpLE | kICmpGE) |
(kFAdd | kFSub | kFMul | kFDiv | kFRem) |
(kFCmpEQ | kFCmpNE | kFCmpLT | kFCmpGT | kFCmpLE | kFCmpGE);
return kind & BinaryOpMask;
}
bool isUnary() const {
static constexpr uint64_t UnaryOpMask = kNeg | kNot | kFNeg | kFtoI | kIToF;
return kind & UnaryOpMask;
}
bool isMemory() const {
static constexpr uint64_t MemoryOpMask = kAlloca | kLoad | kStore;
return kind & MemoryOpMask;
}
bool isTerminator() const {
static constexpr uint64_t TerminatorOpMask = kCondBr | kBr | kReturn;
return kind & TerminatorOpMask;
}
bool isCmp() const {
static constexpr uint64_t CmpOpMask =
(kICmpEQ | kICmpNE | kICmpLT | kICmpGT | kICmpLE | kICmpGE) |
(kFCmpEQ | kFCmpNE | kFCmpLT | kFCmpGT | kFCmpLE | kFCmpGE);
return kind & CmpOpMask;
}
bool isBranch() const {
static constexpr uint64_t BranchOpMask = kBr | kCondBr;
return kind & BranchOpMask;
}
bool isCommutative() const {
static constexpr uint64_t CommutativeOpMask =
kAdd | kMul | kICmpEQ | kICmpNE | kFAdd | kFMul | kFCmpEQ | kFCmpNE;
return kind & CommutativeOpMask;
}
bool isUnconditional() const { return kind == kBr; }
bool isConditional() const { return kind == kCondBr; }
}; // class Instruction
class Function;
//! Function call.
class CallInst : public Instruction {
friend class IRBuilder;
protected:
CallInst(Function *callee, const std::vector<Value *> &args = {},
BasicBlock *parent = nullptr, const std::string &name = "");
public:
static bool classof(const Value *value) { return value->getKind() == kCall; }
public:
Function *getCallee() const;
auto getArguments() const {
return make_range(std::next(operand_begin()), operand_end());
}
public:
void print(std::ostream &os) const override;
}; // class CallInst
//! Unary instruction, includes '!', '-' and type conversion.
class UnaryInst : public Instruction {
friend class IRBuilder;
protected:
UnaryInst(Kind kind, Type *type, Value *operand, BasicBlock *parent = nullptr,
const std::string &name = "")
: Instruction(kind, type, parent, name) {
addOperand(operand);
}
public:
static bool classof(const Value *value) {
return Instruction::classof(value) and
static_cast<const Instruction *>(value)->isUnary();
}
public:
Value *getOperand() const { return User::getOperand(0); }
public:
void print(std::ostream &os) const override;
}; // class UnaryInst
//! Binary instruction, e.g., arithmatic, relation, logic, etc.
class BinaryInst : public Instruction {
friend class IRBuilder;
protected:
BinaryInst(Kind kind, Type *type, Value *lhs, Value *rhs, BasicBlock *parent,
const std::string &name = "")
: Instruction(kind, type, parent, name) {
addOperand(lhs);
addOperand(rhs);
}
public:
static bool classof(const Value *value) {
return Instruction::classof(value) and
static_cast<const Instruction *>(value)->isBinary();
}
public:
Value *getLhs() const { return getOperand(0); }
Value *getRhs() const { return getOperand(1); }
public:
void print(std::ostream &os) const override;
}; // class BinaryInst
//! The return statement
class ReturnInst : public Instruction {
friend class IRBuilder;
protected:
ReturnInst(Value *value = nullptr, BasicBlock *parent = nullptr)
: Instruction(kReturn, Type::getVoidType(), parent, "") {
if (value)
addOperand(value);
}
public:
static bool classof(const Value *value) {
return value->getKind() == kReturn;
}
public:
bool hasReturnValue() const { return not operands.empty(); }
Value *getReturnValue() const {
return hasReturnValue() ? getOperand(0) : nullptr;
}
public:
void print(std::ostream &os) const override;
}; // class ReturnInst
//! Unconditional branch
class UncondBrInst : public Instruction {
friend class IRBuilder;
protected:
UncondBrInst(BasicBlock *block, std::vector<Value *> args,
BasicBlock *parent = nullptr)
: Instruction(kCondBr, Type::getVoidType(), parent, "") {
assert(block->getNumArguments() == args.size());
addOperand(block);
addOperands(args);
}
public:
static bool classof(const Value *value) { return value->getKind() == kBr; }
public:
BasicBlock *getBlock() const { return dyncast<BasicBlock>(getOperand(0)); }
auto getArguments() const {
return make_range(std::next(operand_begin()), operand_end());
}
public:
void print(std::ostream &os) const override;
}; // class UncondBrInst
//! Conditional branch
class CondBrInst : public Instruction {
friend class IRBuilder;
protected:
CondBrInst(Value *condition, BasicBlock *thenBlock, BasicBlock *elseBlock,
const std::vector<Value *> &thenArgs,
const std::vector<Value *> &elseArgs, BasicBlock *parent = nullptr)
: Instruction(kCondBr, Type::getVoidType(), parent, "") {
assert(thenBlock->getNumArguments() == thenArgs.size() and
elseBlock->getNumArguments() == elseArgs.size());
addOperand(condition);
addOperand(thenBlock);
addOperand(elseBlock);
addOperands(thenArgs);
addOperands(elseArgs);
}
public:
static bool classof(const Value *value) {
return value->getKind() == kCondBr;
}
public:
Value *getCondition() const { return getOperand(0); }
BasicBlock *getThenBlock() const {
return dyncast<BasicBlock>(getOperand(1));
}
BasicBlock *getElseBlock() const {
return dyncast<BasicBlock>(getOperand(2));
}
auto getThenArguments() const {
auto begin = std::next(operand_begin(), 3);
auto end = std::next(begin, getThenBlock()->getNumArguments());
return make_range(begin, end);
}
auto getElseArguments() const {
auto begin =
std::next(operand_begin(), 3 + getThenBlock()->getNumArguments());
auto end = operand_end();
return make_range(begin, end);
}
public:
void print(std::ostream &os) const override;
}; // class CondBrInst
//! Allocate memory for stack variables, used for non-global variable declartion
class AllocaInst : public Instruction {
friend class IRBuilder;
protected:
AllocaInst(Type *type, const std::vector<Value *> &dims = {},
BasicBlock *parent = nullptr, const std::string &name = "")
: Instruction(kAlloca, type, parent, name) {
addOperands(dims);
}
public:
static bool classof(const Value *value) {
return value->getKind() == kAlloca;
}
public:
int getNumDims() const { return getNumOperands(); }
auto getDims() const { return getOperands(); }
Value *getDim(int index) { return getOperand(index); }
public:
void print(std::ostream &os) const override;
}; // class AllocaInst
//! Load a value from memory address specified by a pointer value
class LoadInst : public Instruction {
friend class IRBuilder;
protected:
LoadInst(Value *pointer, const std::vector<Value *> &indices = {},
BasicBlock *parent = nullptr, const std::string &name = "")
: Instruction(kLoad, pointer->getType()->as<PointerType>()->getBaseType(),
parent, name) {
addOperand(pointer);
addOperands(indices);
}
public:
static bool classof(const Value *value) { return value->getKind() == kLoad; }
public:
int getNumIndices() const { return getNumOperands() - 1; }
Value *getPointer() const { return getOperand(0); }
auto getIndices() const {
return make_range(std::next(operand_begin()), operand_end());
}
Value *getIndex(int index) const { return getOperand(index + 1); }
public:
void print(std::ostream &os) const override;
}; // class LoadInst
//! Store a value to memory address specified by a pointer value
class StoreInst : public Instruction {
friend class IRBuilder;
protected:
StoreInst(Value *value, Value *pointer,
const std::vector<Value *> &indices = {},
BasicBlock *parent = nullptr, const std::string &name = "")
: Instruction(kStore, Type::getVoidType(), parent, name) {
addOperand(value);
addOperand(pointer);
addOperands(indices);
}
public:
static bool classof(const Value *value) { return value->getKind() == kStore; }
public:
int getNumIndices() const { return getNumOperands() - 2; }
Value *getValue() const { return getOperand(0); }
Value *getPointer() const { return getOperand(1); }
auto getIndices() const {
return make_range(std::next(operand_begin(), 2), operand_end());
}
Value *getIndex(int index) const { return getOperand(index + 2); }
public:
void print(std::ostream &os) const override;
}; // class StoreInst
class Module;
//! Function definition
class Function : public Value {
friend class Module;
protected:
Function(Module *parent, Type *type, const std::string &name)
: Value(kFunction, type, name), parent(parent), variableID(0), blocks() {
blocks.emplace_back(new BasicBlock(this, "entry"));
}
public:
static bool classof(const Value *value) {
return value->getKind() == kFunction;
}
public:
using block_list = std::list<std::unique_ptr<BasicBlock>>;
protected:
Module *parent;
int variableID;
int blockID;
block_list blocks;
public:
Type *getReturnType() const {
return getType()->as<FunctionType>()->getReturnType();
}
auto getParamTypes() const {
return getType()->as<FunctionType>()->getParamTypes();
}
auto getBasicBlocks() const { return make_range(blocks); }
BasicBlock *getEntryBlock() const { return blocks.front().get(); }
BasicBlock *addBasicBlock(const std::string &name = "") {
blocks.emplace_back(new BasicBlock(this, name));
return blocks.back().get();
}
void removeBasicBlock(BasicBlock *block) {
blocks.remove_if([&](std::unique_ptr<BasicBlock> &b) -> bool {
return block == b.get();
});
}
int allocateVariableID() { return variableID++; }
int allocateblockID() { return blockID++; }
public:
void print(std::ostream &os) const override;
}; // class Function
// class ArrayValue : public User {
// protected:
// ArrayValue(Type *type, const std::vector<Value *> &values = {})
// : User(type, "") {
// addOperands(values);
// }
// public:
// static ArrayValue *get(Type *type, const std::vector<Value *> &values);
// static ArrayValue *get(const std::vector<int> &values);
// static ArrayValue *get(const std::vector<float> &values);
// public:
// auto getValues() const { return getOperands(); }
// public:
// void print(std::ostream &os) const override{};
// }; // class ConstantArray
//! Global value declared at file scope
class GlobalValue : public User {
friend class Module;
protected:
Module *parent;
bool hasInit;
bool isConst;
protected:
GlobalValue(Module *parent, Type *type, const std::string &name,
const std::vector<Value *> &dims = {}, Value *init = nullptr)
: User(kGlobal, type, name), parent(parent), hasInit(init) {
assert(type->isPointer());
addOperands(dims);
if (init)
addOperand(init);
}
public:
static bool classof(const Value *value) {
return value->getKind() == kGlobal;
}
public:
Value *init() const { return hasInit ? operands.back().getValue() : nullptr; }
int getNumDims() const { return getNumOperands() - (hasInit ? 1 : 0); }
Value *getDim(int index) { return getOperand(index); }
public:
void print(std::ostream &os) const override{};
}; // class GlobalValue
//! IR unit for representing a SysY compile unit
class Module {
protected:
std::vector<std::unique_ptr<Value>> children;
std::map<std::string, Function *> functions;
std::map<std::string, GlobalValue *> globals;
public:
Module() = default;
public:
Function *createFunction(const std::string &name, Type *type) {
if (functions.count(name))
return nullptr;
auto func = new Function(this, type, name);
assert(func);
children.emplace_back(func);
functions.emplace(name, func);
return func;
};
GlobalValue *createGlobalValue(const std::string &name, Type *type,
const std::vector<Value *> &dims = {},
Value *init = nullptr) {
if (globals.count(name))
return nullptr;
auto global = new GlobalValue(this, type, name, dims, init);
assert(global);
children.emplace_back(global);
globals.emplace(name, global);
return global;
}
Function *getFunction(const std::string &name) const {
auto result = functions.find(name);
if (result == functions.end())
return nullptr;
return result->second;
}
GlobalValue *getGlobalValue(const std::string &name) const {
auto result = globals.find(name);
if (result == globals.end())
return nullptr;
return result->second;
}
std::map<std::string, Function *> *getFunctions(){
return &functions;
}
std::map<std::string, GlobalValue *> *getGlobalValues(){
return &globals;
}
public:
void print(std::ostream &os) const;
}; // class Module
/*!
* @}
*/
inline std::ostream &operator<<(std::ostream &os, const Type &type) {
type.print(os);
return os;
}
inline std::ostream &operator<<(std::ostream &os, const Value &value) {
value.print(os);
return os;
}
} // namespace sysy

View File

@ -1,232 +0,0 @@
#pragma once
#include "IR.h"
#include <cassert>
#include <memory>
namespace sysy {
class IRBuilder {
private:
BasicBlock *block;
BasicBlock::iterator position;
public:
IRBuilder() = default;
IRBuilder(BasicBlock *block) : block(block), position(block->end()) {}
IRBuilder(BasicBlock *block, BasicBlock::iterator position)
: block(block), position(position) {}
public:
BasicBlock *getBasicBlock() const { return block; }
BasicBlock::iterator getPosition() const { return position; }
void setPosition(BasicBlock *block, BasicBlock::iterator position) {
this->block = block;
this->position = position;
}
void setPosition(BasicBlock::iterator position) { this->position = position; }
public:
CallInst *createCallInst(Function *callee,
const std::vector<Value *> &args = {},
const std::string &name = "") {
auto inst = new CallInst(callee, args, block, name);
assert(inst);
block->getInstructions().emplace(position, inst);
return inst;
}
UnaryInst *createUnaryInst(Instruction::Kind kind, Type *type, Value *operand,
const std::string &name = "") {
auto inst = new UnaryInst(kind, type, operand, block, name);
assert(inst);
block->getInstructions().emplace(position, inst);
return inst;
}
UnaryInst *createNegInst(Value *operand, const std::string &name = "") {
return createUnaryInst(Instruction::kNeg, Type::getIntType(), operand,
name);
}
UnaryInst *createNotInst(Value *operand, const std::string &name = "") {
return createUnaryInst(Instruction::kNot, Type::getIntType(), operand,
name);
}
UnaryInst *createFtoIInst(Value *operand, const std::string &name = "") {
return createUnaryInst(Instruction::kFtoI, Type::getIntType(), operand,
name);
}
UnaryInst *createFNegInst(Value *operand, const std::string &name = "") {
return createUnaryInst(Instruction::kFNeg, Type::getFloatType(), operand,
name);
}
UnaryInst *createIToFInst(Value *operand, const std::string &name = "") {
return createUnaryInst(Instruction::kIToF, Type::getFloatType(), operand,
name);
}
BinaryInst *createBinaryInst(Instruction::Kind kind, Type *type, Value *lhs,
Value *rhs, const std::string &name = "") {
auto inst = new BinaryInst(kind, type, lhs, rhs, block, name);
assert(inst);
block->getInstructions().emplace(position, inst);
return inst;
}
BinaryInst *createAddInst(Value *lhs, Value *rhs,
const std::string &name = "") {
return createBinaryInst(Instruction::kAdd, Type::getIntType(), lhs, rhs,
name);
}
BinaryInst *createSubInst(Value *lhs, Value *rhs,
const std::string &name = "") {
return createBinaryInst(Instruction::kSub, Type::getIntType(), lhs, rhs,
name);
}
BinaryInst *createMulInst(Value *lhs, Value *rhs,
const std::string &name = "") {
return createBinaryInst(Instruction::kMul, Type::getIntType(), lhs, rhs,
name);
}
BinaryInst *createDivInst(Value *lhs, Value *rhs,
const std::string &name = "") {
return createBinaryInst(Instruction::kDiv, Type::getIntType(), lhs, rhs,
name);
}
BinaryInst *createRemInst(Value *lhs, Value *rhs,
const std::string &name = "") {
return createBinaryInst(Instruction::kRem, Type::getIntType(), lhs, rhs,
name);
}
BinaryInst *createICmpEQInst(Value *lhs, Value *rhs,
const std::string &name = "") {
return createBinaryInst(Instruction::kICmpEQ, Type::getIntType(), lhs, rhs,
name);
}
BinaryInst *createICmpNEInst(Value *lhs, Value *rhs,
const std::string &name = "") {
return createBinaryInst(Instruction::kICmpNE, Type::getIntType(), lhs, rhs,
name);
}
BinaryInst *createICmpLTInst(Value *lhs, Value *rhs,
const std::string &name = "") {
return createBinaryInst(Instruction::kICmpLT, Type::getIntType(), lhs, rhs,
name);
}
BinaryInst *createICmpLEInst(Value *lhs, Value *rhs,
const std::string &name = "") {
return createBinaryInst(Instruction::kICmpLE, Type::getIntType(), lhs, rhs,
name);
}
BinaryInst *createICmpGTInst(Value *lhs, Value *rhs,
const std::string &name = "") {
return createBinaryInst(Instruction::kICmpGT, Type::getIntType(), lhs, rhs,
name);
}
BinaryInst *createICmpGEInst(Value *lhs, Value *rhs,
const std::string &name = "") {
return createBinaryInst(Instruction::kICmpGE, Type::getIntType(), lhs, rhs,
name);
}
BinaryInst *createFAddInst(Value *lhs, Value *rhs,
const std::string &name = "") {
return createBinaryInst(Instruction::kFAdd, Type::getFloatType(), lhs, rhs,
name);
}
BinaryInst *createFSubInst(Value *lhs, Value *rhs,
const std::string &name = "") {
return createBinaryInst(Instruction::kFSub, Type::getFloatType(), lhs, rhs,
name);
}
BinaryInst *createFMulInst(Value *lhs, Value *rhs,
const std::string &name = "") {
return createBinaryInst(Instruction::kFMul, Type::getFloatType(), lhs, rhs,
name);
}
BinaryInst *createFDivInst(Value *lhs, Value *rhs,
const std::string &name = "") {
return createBinaryInst(Instruction::kFDiv, Type::getFloatType(), lhs, rhs,
name);
}
BinaryInst *createFRemInst(Value *lhs, Value *rhs,
const std::string &name = "") {
return createBinaryInst(Instruction::kFRem, Type::getFloatType(), lhs, rhs,
name);
}
BinaryInst *createFCmpEQInst(Value *lhs, Value *rhs,
const std::string &name = "") {
return createBinaryInst(Instruction::kFCmpEQ, Type::getFloatType(), lhs,
rhs, name);
}
BinaryInst *createFCmpNEInst(Value *lhs, Value *rhs,
const std::string &name = "") {
return createBinaryInst(Instruction::kFCmpNE, Type::getFloatType(), lhs,
rhs, name);
}
BinaryInst *createFCmpLTInst(Value *lhs, Value *rhs,
const std::string &name = "") {
return createBinaryInst(Instruction::kFCmpLT, Type::getFloatType(), lhs,
rhs, name);
}
BinaryInst *createFCmpLEInst(Value *lhs, Value *rhs,
const std::string &name = "") {
return createBinaryInst(Instruction::kFCmpLE, Type::getFloatType(), lhs,
rhs, name);
}
BinaryInst *createFCmpGTInst(Value *lhs, Value *rhs,
const std::string &name = "") {
return createBinaryInst(Instruction::kFCmpGT, Type::getFloatType(), lhs,
rhs, name);
}
BinaryInst *createFCmpGEInst(Value *lhs, Value *rhs,
const std::string &name = "") {
return createBinaryInst(Instruction::kFCmpGE, Type::getFloatType(), lhs,
rhs, name);
}
ReturnInst *createReturnInst(Value *value = nullptr) {
auto inst = new ReturnInst(value);
assert(inst);
block->getInstructions().emplace(position, inst);
return inst;
}
UncondBrInst *createUncondBrInst(BasicBlock *block,
std::vector<Value *> args) {
auto inst = new UncondBrInst(block, args, block);
assert(inst);
block->getInstructions().emplace(position, inst);
return inst;
}
CondBrInst *createCondBrInst(Value *condition, BasicBlock *thenBlock,
BasicBlock *elseBlock,
const std::vector<Value *> &thenArgs,
const std::vector<Value *> &elseArgs) {
auto inst = new CondBrInst(condition, thenBlock, elseBlock, thenArgs,
elseArgs, block);
assert(inst);
block->getInstructions().emplace(position, inst);
return inst;
}
AllocaInst *createAllocaInst(Type *type,
const std::vector<Value *> &dims = {},
const std::string &name = "") {
auto inst = new AllocaInst(type, dims, block, name);
assert(inst);
block->getInstructions().emplace(position, inst);
return inst;
}
LoadInst *createLoadInst(Value *pointer,
const std::vector<Value *> &indices = {},
const std::string &name = "") {
auto inst = new LoadInst(pointer, indices, block, name);
assert(inst);
block->getInstructions().emplace(position, inst);
return inst;
}
StoreInst *createStoreInst(Value *value, Value *pointer,
const std::vector<Value *> &indices = {},
const std::string &name = "") {
auto inst = new StoreInst(value, pointer, indices, block, name);
assert(inst);
block->getInstructions().emplace(position, inst);
return inst;
}
};
} // namespace sysy

674
src/LLVMIRGenerator.cpp Normal file
View File

@ -0,0 +1,674 @@
// LLVMIRGenerator.cpp
// TODO类型转换及其检查
// TODOsysy库函数处理
// TODO数组处理
// TODO对while、continue、break的测试
#include "LLVMIRGenerator.h"
#include <iomanip>
using namespace std;
namespace sysy {
std::string LLVMIRGenerator::generateIR(SysYParser::CompUnitContext* unit) {
// 初始化自定义IR数据结构
irModule = std::make_unique<sysy::Module>();
irBuilder = sysy::IRBuilder(); // 初始化IR构建器
tempCounter = 0;
symbolTable.clear();
tmpTable.clear();
globalVars.clear();
inFunction = false;
visitCompUnit(unit);
return irStream.str();
}
std::string LLVMIRGenerator::getNextTemp() {
std::string ret = "%." + std::to_string(tempCounter++);
tmpTable[ret] = "void";
return ret;
}
std::string LLVMIRGenerator::getLLVMType(const std::string& type) {
if (type == "int") return "i32";
if (type == "float") return "float";
if (type.find("[]") != std::string::npos)
return getLLVMType(type.substr(0, type.size()-2)) + "*";
return "i32";
}
sysy::Type* LLVMIRGenerator::getSysYType(const std::string& typeStr) {
if (typeStr == "int") return sysy::Type::getIntType();
if (typeStr == "float") return sysy::Type::getFloatType();
if (typeStr == "void") return sysy::Type::getVoidType();
// 处理指针类型等
return sysy::Type::getIntType();
}
std::any LLVMIRGenerator::visitCompUnit(SysYParser::CompUnitContext* ctx) {
auto type_i32 = Type::getIntType();
auto type_f32 = Type::getFloatType();
auto type_void = Type::getVoidType();
auto type_i32p = Type::getPointerType(type_i32);
auto type_f32p = Type::getPointerType(type_f32);
// 创建运行时库函数
irModule->createFunction("getint", sysy::FunctionType::get(type_i32, {}));
irModule->createFunction("getch", sysy::FunctionType::get(type_i32, {}));
irModule->createFunction("getfloat", sysy::FunctionType::get(type_f32, {}));
//TODO: 添加更多运行时库函数
irStream << "declare i32 @getint()\n";
irStream << "declare i32 @getch()\n";
irStream << "declare float @getfloat()\n";
//TODO: 添加更多运行时库函数的文本IR
for (auto decl : ctx->decl()) {
decl->accept(this);
}
for (auto funcDef : ctx->funcDef()) {
inFunction = true; // 进入函数定义
funcDef->accept(this);
inFunction = false; // 离开函数定义
}
return nullptr;
}
std::any LLVMIRGenerator::visitVarDecl(SysYParser::VarDeclContext* ctx) {
// TODO数组初始化
std::string type = ctx->bType()->getText();
currentVarType = getLLVMType(type);
for (auto varDef : ctx->varDef()) {
if (!inFunction) {
// 全局变量声明
std::string varName = varDef->Ident()->getText();
std::string llvmType = getLLVMType(type);
std::string value = "0"; // 默认值为 0
if (varDef->ASSIGN()) {
value = std::any_cast<std::string>(varDef->initVal()->accept(this));
} else {
std::cout << "[WR-Release-01]Warning: Global variable '" << varName
<< "' is declared without initialization, defaulting to 0.\n";
}
irStream << "@" << varName << " = dso_local global " << llvmType << " " << value << ", align 4\n";
globalVars.push_back(varName); // 记录全局变量
} else {
// 局部变量声明
varDef->accept(this);
}
}
return nullptr;
}
std::any LLVMIRGenerator::visitConstDecl(SysYParser::ConstDeclContext* ctx) {
// TODO数组初始化
std::string type = ctx->bType()->getText();
for (auto constDef : ctx->constDef()) {
if (!inFunction) {
// 全局常量声明
std::string varName = constDef->Ident()->getText();
std::string llvmType = getLLVMType(type);
std::string value = "0"; // 默认值为 0
try {
value = std::any_cast<std::string>(constDef->constInitVal()->accept(this));
} catch (...) {
throw std::runtime_error("[ERR-Release-01]Const value must be initialized upon definition.");
}
// 如果是 float 类型,转换为十六进制表示
if (llvmType == "float") {
try {
double floatValue = std::stod(value);
uint64_t hexValue = reinterpret_cast<uint64_t&>(floatValue);
std::stringstream ss;
ss << "0x" << std::hex << std::uppercase << hexValue;
value = ss.str();
} catch (...) {
throw std::runtime_error("[ERR-Release-02]Invalid float literal: " + value);
}
}
irStream << "@" << varName << " = dso_local constant " << llvmType << " " << value << ", align 4\n";
globalVars.push_back(varName); // 记录全局变量
} else {
// 局部常量声明
std::string varName = constDef->Ident()->getText();
std::string llvmType = getLLVMType(type);
std::string allocaName = getNextTemp();
std::string value = "0"; // 默认值为 0
try {
value = std::any_cast<std::string>(constDef->constInitVal()->accept(this));
} catch (...) {
throw std::runtime_error("Const value must be initialized upon definition.");
}
irStream << " " << allocaName << " = alloca " << llvmType << ", align 4\n";
if (llvmType == "float") {
try {
double floatValue = std::stod(value);
uint64_t hexValue = reinterpret_cast<uint64_t&>(floatValue);
std::stringstream ss;
ss << "0x" << std::hex << std::uppercase << hexValue;
value = ss.str();
} catch (...) {
throw std::runtime_error("Invalid float literal: " + value);
}
}
irStream << " store " << llvmType << " " << value << ", " << llvmType
<< "* " << allocaName << ", align 4\n";
symbolTable[varName] = {allocaName, llvmType};
tmpTable[allocaName] = llvmType;
}
}
return nullptr;
}
std::any LLVMIRGenerator::visitVarDef(SysYParser::VarDefContext* ctx) {
// TODO数组初始化
std::string varName = ctx->Ident()->getText();
std::string type = currentVarType;
std::string llvmType = getLLVMType(type);
std::string allocaName = getNextTemp();
irStream << " " << allocaName << " = alloca " << llvmType << ", align 4\n";
if (ctx->ASSIGN()) {
std::string value = std::any_cast<std::string>(ctx->initVal()->accept(this));
if (llvmType == "float") {
try {
double floatValue = std::stod(value);
uint64_t hexValue = reinterpret_cast<uint64_t&>(floatValue);
std::stringstream ss;
ss << "0x" << std::hex << std::uppercase << (hexValue & (0xffffffffUL << 32));
value = ss.str();
} catch (...) {
throw std::runtime_error("Invalid float literal: " + value);
}
}
irStream << " store " << llvmType << " " << value << ", " << llvmType
<< "* " << allocaName << ", align 4\n";
}
symbolTable[varName] = {allocaName, llvmType};
tmpTable[allocaName] = llvmType;
return nullptr;
}
std::any LLVMIRGenerator::visitFuncDef(SysYParser::FuncDefContext* ctx) {
currentFunction = ctx->Ident()->getText();
currentReturnType = getLLVMType(ctx->funcType()->getText());
symbolTable.clear();
tmpTable.clear();
tempCounter = 0;
hasReturn = false;
irStream << "define dso_local " << currentReturnType << " @" << currentFunction << "(";
if (ctx->funcFParams()) {
auto params = ctx->funcFParams()->funcFParam();
tempCounter += params.size();
for (size_t i = 0; i < params.size(); ++i) {
if (i > 0) irStream << ", ";
std::string paramType = getLLVMType(params[i]->bType()->getText());
irStream << paramType << " noundef %" << i;
symbolTable[params[i]->Ident()->getText()] = {"%" + std::to_string(i), paramType};
tmpTable["%" + std::to_string(i)] = paramType;
}
}
tempCounter++;
irStream << ") #0 {\n";
if (ctx->funcFParams()) {
auto params = ctx->funcFParams()->funcFParam();
for (size_t i = 0; i < params.size(); ++i) {
std::string varName = params[i]->Ident()->getText();
std::string type = params[i]->bType()->getText();
std::string llvmType = getLLVMType(type);
std::string allocaName = getNextTemp();
tmpTable[allocaName] = llvmType;
irStream << " " << allocaName << " = alloca " << llvmType << ", align 4\n";
irStream << " store " << llvmType << " " << symbolTable[varName].first << ", " << llvmType
<< "* " << allocaName << ", align 4\n";
symbolTable[varName] = {allocaName, llvmType};
}
}
ctx->blockStmt()->accept(this);
if (!hasReturn) {
if (currentReturnType == "void") {
irStream << " ret void\n";
} else {
irStream << " ret " << currentReturnType << " 0\n";
}
}
irStream << "}\n";
return nullptr;
}
std::any LLVMIRGenerator::visitBlockStmt(SysYParser::BlockStmtContext* ctx) {
for (auto item : ctx->blockItem()) {
item->accept(this);
}
return nullptr;
}
std::any LLVMIRGenerator::visitAssignStmt(SysYParser::AssignStmtContext *ctx)
{
std::string lhsAlloca = std::any_cast<std::string>(ctx->lValue()->accept(this));
std::string lhsType = symbolTable[ctx->lValue()->Ident()->getText()].second;
std::string rhs = std::any_cast<std::string>(ctx->exp()->accept(this));
if (lhsType == "float") {
try {
double floatValue = std::stod(rhs);
uint64_t hexValue = reinterpret_cast<uint64_t&>(floatValue);
std::stringstream ss;
ss << "0x" << std::hex << std::uppercase << (hexValue & (0xffffffffUL << 32));
rhs = ss.str();
} catch (...) {
throw std::runtime_error("Invalid float literal: " + rhs);
}
}
irStream << " store " << lhsType << " " << rhs << ", " << lhsType
<< "* " << lhsAlloca << ", align 4\n";
return nullptr;
}
std::any LLVMIRGenerator::visitIfStmt(SysYParser::IfStmtContext *ctx)
{
std::string cond = std::any_cast<std::string>(ctx->cond()->accept(this));
std::string trueLabel = "if.then." + std::to_string(tempCounter);
std::string falseLabel = "if.else." + std::to_string(tempCounter);
std::string mergeLabel = "if.end." + std::to_string(tempCounter++);
irStream << " br i1 " << cond << ", label %" << trueLabel << ", label %" << falseLabel << "\n";
irStream << trueLabel << ":\n";
ctx->stmt(0)->accept(this);
irStream << " br label %" << mergeLabel << "\n";
irStream << falseLabel << ":\n";
if (ctx->ELSE()) {
ctx->stmt(1)->accept(this);
}
irStream << " br label %" << mergeLabel << "\n";
irStream << mergeLabel << ":\n";
return nullptr;
}
std::any LLVMIRGenerator::visitWhileStmt(SysYParser::WhileStmtContext *ctx)
{
std::string loop_cond = "while.cond." + std::to_string(tempCounter);
std::string loop_body = "while.body." + std::to_string(tempCounter);
std::string loop_end = "while.end." + std::to_string(tempCounter++);
loopStack.push({loop_end, loop_cond});
irStream << " br label %" << loop_cond << "\n";
irStream << loop_cond << ":\n";
std::string cond = std::any_cast<std::string>(ctx->cond()->accept(this));
irStream << " br i1 " << cond << ", label %" << loop_body << ", label %" << loop_end << "\n";
irStream << loop_body << ":\n";
ctx->stmt()->accept(this);
irStream << " br label %" << loop_cond << "\n";
irStream << loop_end << ":\n";
loopStack.pop();
return nullptr;
}
std::any LLVMIRGenerator::visitBreakStmt(SysYParser::BreakStmtContext *ctx)
{
if (loopStack.empty()) {
throw std::runtime_error("Break statement outside of a loop.");
}
irStream << " br label %" << loopStack.top().breakLabel << "\n";
return nullptr;
}
std::any LLVMIRGenerator::visitContinueStmt(SysYParser::ContinueStmtContext *ctx)
{
if (loopStack.empty()) {
throw std::runtime_error("Continue statement outside of a loop.");
}
irStream << " br label %" << loopStack.top().continueLabel << "\n";
return nullptr;
}
std::any LLVMIRGenerator::visitReturnStmt(SysYParser::ReturnStmtContext *ctx)
{
hasReturn = true;
if (ctx->exp()) {
std::string value = std::any_cast<std::string>(ctx->exp()->accept(this));
irStream << " ret " << currentReturnType << " " << value << "\n";
} else {
irStream << " ret void\n";
}
return nullptr;
}
// std::any LLVMIRGenerator::visitStmt(SysYParser::StmtContext* ctx) {
// if (ctx->lValue() && ctx->ASSIGN()) {
// std::string lhsAlloca = std::any_cast<std::string>(ctx->lValue()->accept(this));
// std::string lhsType = symbolTable[ctx->lValue()->Ident()->getText()].second;
// std::string rhs = std::any_cast<std::string>(ctx->exp()->accept(this));
// if (lhsType == "float") {
// try {
// double floatValue = std::stod(rhs);
// uint64_t hexValue = reinterpret_cast<uint64_t&>(floatValue);
// std::stringstream ss;
// ss << "0x" << std::hex << std::uppercase << (hexValue & (0xffffffffUL << 32));
// rhs = ss.str();
// } catch (...) {
// throw std::runtime_error("Invalid float literal: " + rhs);
// }
// }
// irStream << " store " << lhsType << " " << rhs << ", " << lhsType
// << "* " << lhsAlloca << ", align 4\n";
// } else if (ctx->RETURN()) {
// hasReturn = true;
// if (ctx->exp()) {
// std::string value = std::any_cast<std::string>(ctx->exp()->accept(this));
// irStream << " ret " << currentReturnType << " " << value << "\n";
// } else {
// irStream << " ret void\n";
// }
// } else if (ctx->IF()) {
// std::string cond = std::any_cast<std::string>(ctx->cond()->accept(this));
// std::string trueLabel = "if.then." + std::to_string(tempCounter);
// std::string falseLabel = "if.else." + std::to_string(tempCounter);
// std::string mergeLabel = "if.end." + std::to_string(tempCounter++);
// irStream << " br i1 " << cond << ", label %" << trueLabel << ", label %" << falseLabel << "\n";
// irStream << trueLabel << ":\n";
// ctx->stmt(0)->accept(this);
// irStream << " br label %" << mergeLabel << "\n";
// irStream << falseLabel << ":\n";
// if (ctx->ELSE()) {
// ctx->stmt(1)->accept(this);
// }
// irStream << " br label %" << mergeLabel << "\n";
// irStream << mergeLabel << ":\n";
// } else if (ctx->WHILE()) {
// std::string loop_cond = "while.cond." + std::to_string(tempCounter);
// std::string loop_body = "while.body." + std::to_string(tempCounter);
// std::string loop_end = "while.end." + std::to_string(tempCounter++);
// loopStack.push({loop_end, loop_cond});
// irStream << " br label %" << loop_cond << "\n";
// irStream << loop_cond << ":\n";
// std::string cond = std::any_cast<std::string>(ctx->cond()->accept(this));
// irStream << " br i1 " << cond << ", label %" << loop_body << ", label %" << loop_end << "\n";
// irStream << loop_body << ":\n";
// ctx->stmt(0)->accept(this);
// irStream << " br label %" << loop_cond << "\n";
// irStream << loop_end << ":\n";
// loopStack.pop();
// } else if (ctx->BREAK()) {
// if (loopStack.empty()) {
// throw std::runtime_error("Break statement outside of a loop.");
// }
// irStream << " br label %" << loopStack.top().breakLabel << "\n";
// } else if (ctx->CONTINUE()) {
// if (loopStack.empty()) {
// throw std::runtime_error("Continue statement outside of a loop.");
// }
// irStream << " br label %" << loopStack.top().continueLabel << "\n";
// } else if (ctx->blockStmt()) {
// ctx->blockStmt()->accept(this);
// } else if (ctx->exp()) {
// ctx->exp()->accept(this);
// }
// return nullptr;
// }
std::any LLVMIRGenerator::visitLValue(SysYParser::LValueContext* ctx) {
std::string varName = ctx->Ident()->getText();
return symbolTable[varName].first;
}
// std::any LLVMIRGenerator::visitPrimaryExp(SysYParser::PrimaryExpContext* ctx) {
// if (ctx->lValue()) {
// std::string allocaPtr = std::any_cast<std::string>(ctx->lValue()->accept(this));
// std::string varName = ctx->lValue()->Ident()->getText();
// std::string type = symbolTable[varName].second;
// std::string temp = getNextTemp();
// irStream << " " << temp << " = load " << type << ", " << type << "* " << allocaPtr << ", align 4\n";
// tmpTable[temp] = type;
// return temp;
// } else if (ctx->exp()) {
// return ctx->exp()->accept(this);
// } else {
// return ctx->number()->accept(this);
// }
// }
std::any LLVMIRGenerator::visitPrimExp(SysYParser::PrimExpContext *ctx){
// irStream << "visitPrimExp\n";
// std::cout << "Type name: " << typeid(*(ctx->primaryExp())).name() << std::endl;
SysYParser::PrimaryExpContext* pExpCtx = ctx->primaryExp();
if (auto* lvalCtx = dynamic_cast<SysYParser::LValContext*>(pExpCtx)) {
std::string allocaPtr = std::any_cast<std::string>(lvalCtx->lValue()->accept(this));
std::string varName = lvalCtx->lValue()->Ident()->getText();
std::string type = symbolTable[varName].second;
std::string temp = getNextTemp();
irStream << " " << temp << " = load " << type << ", " << type << "* " << allocaPtr << ", align 4\n";
tmpTable[temp] = type;
return temp;
} else if (auto* expCtx = dynamic_cast<SysYParser::ParenExpContext*>(pExpCtx)) {
return expCtx->exp()->accept(this);
} else if (auto* strCtx = dynamic_cast<SysYParser::StrContext*>(pExpCtx)) {
return strCtx->string()->accept(this);
} else if (auto* numCtx = dynamic_cast<SysYParser::NumContext*>(pExpCtx)) {
return numCtx->number()->accept(this);
} else {
// 没有成功转换,说明 ctx->primaryExp() 不是 NumContext 或其他已知类型
// 可能是其他类型的表达式,或者是一个空的 PrimaryExpContext
std::cout << "Unknown primary expression type." << std::endl;
throw std::runtime_error("Unknown primary expression type.");
}
// return visitChildren(ctx);
}
std::any LLVMIRGenerator::visitParenExp(SysYParser::ParenExpContext* ctx) {
return ctx->exp()->accept(this);
}
std::any LLVMIRGenerator::visitNumber(SysYParser::NumberContext* ctx) {
if (ctx->ILITERAL()) {
return ctx->ILITERAL()->getText();
} else if (ctx->FLITERAL()) {
return ctx->FLITERAL()->getText();
}
return "";
}
std::any LLVMIRGenerator::visitString(SysYParser::StringContext *ctx)
{
if (ctx->STRING()) {
// 处理字符串常量
std::string str = ctx->STRING()->getText();
// 去掉引号
str = str.substr(1, str.size() - 2);
// 转义处理
std::string escapedStr;
for (char c : str) {
if (c == '\\') {
escapedStr += "\\\\";
} else if (c == '"') {
escapedStr += "\\\"";
} else {
escapedStr += c;
}
}
return "\"" + escapedStr + "\"";
}
return ctx->STRING()->getText();
}
std::any LLVMIRGenerator::visitUnExp(SysYParser::UnExpContext* ctx) {
if (ctx->unaryOp()) {
std::string operand = std::any_cast<std::string>(ctx->unaryExp()->accept(this));
std::string op = ctx->unaryOp()->getText();
std::string temp = getNextTemp();
std::string type = operand.substr(0, operand.find(' '));
tmpTable[temp] = type;
if (op == "-") {
irStream << " " << temp << " = sub " << type << " 0, " << operand << "\n";
} else if (op == "!") {
irStream << " " << temp << " = xor " << type << " " << operand << ", 1\n";
}
return temp;
}
return ctx->unaryExp()->accept(this);
}
std::any LLVMIRGenerator::visitCall(SysYParser::CallContext *ctx)
{
std::string funcName = ctx->Ident()->getText();
std::vector<std::string> args;
if (ctx->funcRParams()) {
for (auto argCtx : ctx->funcRParams()->exp()) {
args.push_back(std::any_cast<std::string>(argCtx->accept(this)));
}
}
std::string temp = getNextTemp();
std::string argList = "";
for (size_t i = 0; i < args.size(); ++i) {
if (i > 0) argList += ", ";
argList +=tmpTable[args[i]] + " noundef " + args[i];
}
irStream << " " << temp << " = call " << currentReturnType << " @" << funcName << "(" << argList << ")\n";
tmpTable[temp] = currentReturnType;
return temp;
}
std::any LLVMIRGenerator::visitMulExp(SysYParser::MulExpContext* ctx) {
auto unaryExps = ctx->unaryExp();
std::string left = std::any_cast<std::string>(unaryExps[0]->accept(this));
for (size_t i = 1; i < unaryExps.size(); ++i) {
std::string right = std::any_cast<std::string>(unaryExps[i]->accept(this));
std::string op = ctx->children[2*i-1]->getText();
std::string temp = getNextTemp();
std::string type = tmpTable[left];
if (op == "*") {
irStream << " " << temp << " = mul nsw " << type << " " << left << ", " << right << "\n";
} else if (op == "/") {
irStream << " " << temp << " = sdiv " << type << " " << left << ", " << right << "\n";
} else if (op == "%") {
irStream << " " << temp << " = srem " << type << " " << left << ", " << right << "\n";
}
left = temp;
tmpTable[temp] = type;
}
return left;
}
std::any LLVMIRGenerator::visitAddExp(SysYParser::AddExpContext* ctx) {
auto mulExps = ctx->mulExp();
std::string left = std::any_cast<std::string>(mulExps[0]->accept(this));
for (size_t i = 1; i < mulExps.size(); ++i) {
std::string right = std::any_cast<std::string>(mulExps[i]->accept(this));
std::string op = ctx->children[2*i-1]->getText();
std::string temp = getNextTemp();
std::string type = tmpTable[left];
if (op == "+") {
irStream << " " << temp << " = add nsw " << type << " " << left << ", " << right << "\n";
} else if (op == "-") {
irStream << " " << temp << " = sub nsw " << type << " " << left << ", " << right << "\n";
}
left = temp;
tmpTable[temp] = type;
}
return left;
}
std::any LLVMIRGenerator::visitRelExp(SysYParser::RelExpContext* ctx) {
auto addExps = ctx->addExp();
std::string left = std::any_cast<std::string>(addExps[0]->accept(this));
for (size_t i = 1; i < addExps.size(); ++i) {
std::string right = std::any_cast<std::string>(addExps[i]->accept(this));
std::string op = ctx->children[2*i-1]->getText();
std::string temp = getNextTemp();
std::string type = tmpTable[left];
if (op == "<") {
irStream << " " << temp << " = icmp slt " << type << " " << left << ", " << right << "\n";
} else if (op == ">") {
irStream << " " << temp << " = icmp sgt " << type << " " << left << ", " << right << "\n";
} else if (op == "<=") {
irStream << " " << temp << " = icmp sle " << type << " " << left << ", " << right << "\n";
} else if (op == ">=") {
irStream << " " << temp << " = icmp sge " << type << " " << left << ", " << right << "\n";
}
left = temp;
}
return left;
}
std::any LLVMIRGenerator::visitEqExp(SysYParser::EqExpContext* ctx) {
auto relExps = ctx->relExp();
std::string left = std::any_cast<std::string>(relExps[0]->accept(this));
for (size_t i = 1; i < relExps.size(); ++i) {
std::string right = std::any_cast<std::string>(relExps[i]->accept(this));
std::string op = ctx->children[2*i-1]->getText();
std::string temp = getNextTemp();
std::string type = tmpTable[left];
if (op == "==") {
irStream << " " << temp << " = icmp eq " << type << " " << left << ", " << right << "\n";
} else if (op == "!=") {
irStream << " " << temp << " = icmp ne " << type << " " << left << ", " << right << "\n";
}
left = temp;
}
return left;
}
std::any LLVMIRGenerator::visitLAndExp(SysYParser::LAndExpContext* ctx) {
auto eqExps = ctx->eqExp();
std::string left = std::any_cast<std::string>(eqExps[0]->accept(this));
for (size_t i = 1; i < eqExps.size(); ++i) {
std::string falseLabel = "land.false." + std::to_string(tempCounter);
std::string endLabel = "land.end." + std::to_string(tempCounter++);
std::string temp = getNextTemp();
irStream << " br label %" << falseLabel << "\n";
irStream << falseLabel << ":\n";
std::string right = std::any_cast<std::string>(eqExps[i]->accept(this));
irStream << " " << temp << " = and i1 " << left << ", " << right << "\n";
irStream << " br label %" << endLabel << "\n";
irStream << endLabel << ":\n";
left = temp;
}
return left;
}
std::any LLVMIRGenerator::visitLOrExp(SysYParser::LOrExpContext* ctx) {
auto lAndExps = ctx->lAndExp();
std::string left = std::any_cast<std::string>(lAndExps[0]->accept(this));
for (size_t i = 1; i < lAndExps.size(); ++i) {
std::string trueLabel = "lor.true." + std::to_string(tempCounter);
std::string endLabel = "lor.end." + std::to_string(tempCounter++);
std::string temp = getNextTemp();
irStream << " br label %" << trueLabel << "\n";
irStream << trueLabel << ":\n";
std::string right = std::any_cast<std::string>(lAndExps[i]->accept(this));
irStream << " " << temp << " = or i1 " << left << ", " << right << "\n";
irStream << " br label %" << endLabel << "\n";
irStream << endLabel << ":\n";
left = temp;
}
return left;
}
}

859
src/LLVMIRGenerator_1.cpp Normal file
View File

@ -0,0 +1,859 @@
// LLVMIRGenerator.cpp
// TODO类型转换及其检查
// TODOsysy库函数处理
// TODO数组处理
// TODO对while、continue、break的测试
#include "LLVMIRGenerator_1.h"
#include <iomanip>
#include <stdexcept>
#include <sstream>
// namespace sysy {
std::string LLVMIRGenerator::generateIR(SysYParser::CompUnitContext* unit) {
// 初始化 SysY IR 模块
module = std::make_unique<sysy::Module>();
// 清空符号表和临时变量表
symbolTable.clear();
tmpTable.clear();
irSymbolTable.clear();
irTmpTable.clear();
tempCounter = 0;
globalVars.clear();
hasReturn = false;
loopStack = std::stack<LoopLabels>();
inFunction = false;
// 访问编译单元
visitCompUnit(unit);
return irStream.str();
}
std::string LLVMIRGenerator::getNextTemp() {
std::string ret = "%." + std::to_string(tempCounter++);
tmpTable[ret] = "void";
return ret;
}
std::string LLVMIRGenerator::getIRTempName() {
return "%" + std::to_string(tempCounter++);
}
std::string LLVMIRGenerator::getLLVMType(const std::string& type) {
if (type == "int") return "i32";
if (type == "float") return "float";
if (type.find("[]") != std::string::npos)
return getLLVMType(type.substr(0, type.size() - 2)) + "*";
return "i32";
}
sysy::Type* LLVMIRGenerator::getIRType(const std::string& type) {
if (type == "int") return sysy::Type::getIntType();
if (type == "float") return sysy::Type::getFloatType();
if (type == "void") return sysy::Type::getVoidType();
if (type.find("[]") != std::string::npos) {
std::string baseType = type.substr(0, type.size() - 2);
return sysy::Type::getPointerType(getIRType(baseType));
}
return sysy::Type::getIntType(); // 默认 int
}
void LLVMIRGenerator::setIRPosition(sysy::BasicBlock* block) {
currentIRBlock = block;
}
std::any LLVMIRGenerator::visitCompUnit(SysYParser::CompUnitContext* ctx) {
for (auto decl : ctx->decl()) {
decl->accept(this);
}
for (auto funcDef : ctx->funcDef()) {
inFunction = true;
funcDef->accept(this);
inFunction = false;
}
return nullptr;
}
std::any LLVMIRGenerator::visitVarDecl(SysYParser::VarDeclContext* ctx) {
// TODO数组初始化
std::string type = ctx->bType()->getText();
currentVarType = getLLVMType(type);
sysy::Type* irType = sysy::Type::getPointerType(getIRType(type));
for (auto varDef : ctx->varDef()) {
if (!inFunction) {
// 全局变量(文本 IR
std::string varName = varDef->Ident()->getText();
std::string llvmType = getLLVMType(type);
std::string value = "0";
sysy::Value* initValue = nullptr;
if (varDef->ASSIGN()) {
value = std::any_cast<std::string>(varDef->initVal()->accept(this));
if (irTmpTable.find(value) != irTmpTable.end() && isa<sysy::ConstantValue>(irTmpTable[value])) {
initValue = irTmpTable[value];
}
}
if (llvmType == "float" && initValue) {
try {
double floatValue = std::stod(value);
uint64_t hexValue = reinterpret_cast<uint64_t&>(floatValue);
std::stringstream ss;
ss << "0x" << std::hex << std::uppercase << hexValue;
value = ss.str();
} catch (...) {
throw std::runtime_error("[ERR-Release-02]Invalid float literal: " + value);
}
}
irStream << "@" << varName << " = dso_local global " << llvmType << " " << value << ", align 4\n";
globalVars.push_back(varName);
// 全局变量SysY IR
auto globalValue = module->createGlobalValue(varName, irType, {}, initValue);
irSymbolTable[varName] = globalValue;
} else {
varDef->accept(this);
}
}
return nullptr;
}
std::any LLVMIRGenerator::visitConstDecl(SysYParser::ConstDeclContext* ctx) {
// TODO数组初始化
std::string type = ctx->bType()->getText();
currentVarType = getLLVMType(type);
sysy::Type* irType = sysy::Type::getPointerType(getIRType(type)); // 全局变量为指针类型
for (auto constDef : ctx->constDef()) {
std::string varName = constDef->Ident()->getText();
std::string llvmType = getLLVMType(type);
std::string value = "0";
sysy::Value* initValue = nullptr;
try {
value = std::any_cast<std::string>(constDef->constInitVal()->accept(this));
if (isa<sysy::ConstantValue>(irTmpTable[value])) {
initValue = irTmpTable[value];
}
} catch (...) {
throw std::runtime_error("Const value must be initialized upon definition.");
}
if (!inFunction) {
// 全局常量(文本 IR
if (llvmType == "float") {
try {
double floatValue = std::stod(value);
uint64_t hexValue = reinterpret_cast<uint64_t&>(floatValue);
std::stringstream ss;
ss << "0x" << std::hex << std::uppercase << hexValue;
value = ss.str();
} catch (...) {
throw std::runtime_error("[ERR-Release-03]Invalid float literal: " + value);
}
}
irStream << "@" << varName << " = dso_local constant " << llvmType << " " << value << ", align 4\n";
globalVars.push_back(varName);
// 全局常量SysY IR
auto globalValue = module->createGlobalValue(varName, irType, {}, initValue);
irSymbolTable[varName] = globalValue;
} else {
// 局部常量(文本 IR
std::string allocaName = getNextTemp();
if (llvmType == "float") {
try {
double floatValue = std::stod(value);
uint64_t hexValue = reinterpret_cast<uint64_t&>(floatValue);
std::stringstream ss;
ss << "0x" << std::hex << std::uppercase << hexValue;
value = ss.str();
} catch (...) {
throw std::runtime_error("Invalid float literal: " + value);
}
}
irStream << " " << allocaName << " = alloca " << llvmType << ", align 4\n";
irStream << " store " << llvmType << " " << value << ", " << llvmType
<< "* " << allocaName << ", align 4\n";
symbolTable[varName] = {allocaName, llvmType};
tmpTable[allocaName] = llvmType;
// 局部常量SysY IRTODO:这里可能有bugAI在犯蠢
sysy::IRBuilder builder(currentIRBlock);
auto allocaInst = builder.createAllocaInst(irType, {}, varName);
builder.createStoreInst(initValue, allocaInst);
irSymbolTable[varName] = allocaInst;
irTmpTable[allocaName] = allocaInst;
}
}
return nullptr;
}
std::any LLVMIRGenerator::visitVarDef(SysYParser::VarDefContext* ctx) {
// TODO数组初始化
std::string varName = ctx->Ident()->getText();
std::string llvmType = currentVarType;
sysy::Type* irType = sysy::Type::getPointerType(getIRType(currentVarType == "i32" ? "int" : "float"));
std::string allocaName = getNextTemp();
// 局部变量(文本 IR
irStream << " " << allocaName << " = alloca " << llvmType << ", align 4\n";
// 局部变量SysY IR
sysy::IRBuilder builder(currentIRBlock);
auto allocaInst = builder.createAllocaInst(irType, {}, varName);
sysy::Value* initValue = nullptr;
if (ctx->ASSIGN()) {
std::string value = std::any_cast<std::string>(ctx->initVal()->accept(this));
if (llvmType == "float") {
try {
double floatValue = std::stod(value);
uint64_t hexValue = reinterpret_cast<uint64_t&>(floatValue);
std::stringstream ss;
ss << "0x" << std::hex << std::uppercase << (hexValue & (0xffffffffUL << 32));
value = ss.str();
} catch (...) {
throw std::runtime_error("Invalid float literal: " + value);
}
}
irStream << " store " << llvmType << " " << value << ", " << llvmType
<< "* " << allocaName << ", align 4\n";
if (irTmpTable.find(value) != irTmpTable.end()) {
initValue = irTmpTable[value];
}
builder.createStoreInst(initValue, allocaInst);
}
symbolTable[varName] = {allocaName, llvmType};
tmpTable[allocaName] = llvmType;
irSymbolTable[varName] = allocaInst;//TODO:这里没看懂在干嘛
irTmpTable[allocaName] = allocaInst;//TODO:这里没看懂在干嘛
builder.createStoreInst(initValue, allocaInst);//TODO:这里没看懂在干嘛
return nullptr;
}
std::any LLVMIRGenerator::visitFuncDef(SysYParser::FuncDefContext* ctx) {
currentFunction = ctx->Ident()->getText();
currentReturnType = getLLVMType(ctx->funcType()->getText());
sysy::Type* irReturnType = getIRType(ctx->funcType()->getText());
std::vector<sysy::Type*> paramTypes;
// 清空符号表
symbolTable.clear();
tmpTable.clear();
irSymbolTable.clear();
irTmpTable.clear();
tempCounter = 0;
hasReturn = false;
// 处理函数参数(文本 IR 和 SysY IR
if (ctx->funcFParams()) {
auto params = ctx->funcFParams()->funcFParam();
for (size_t i = 0; i < params.size(); ++i) {
std::string paramType = getLLVMType(params[i]->bType()->getText());
if (i > 0) irStream << ", ";
irStream << paramType << " noundef %" << i;
symbolTable[params[i]->Ident()->getText()] = {"%" + std::to_string(i), paramType};
tmpTable["%" + std::to_string(i)] = paramType;
paramTypes.push_back(getIRType(params[i]->bType()->getText()));
}
tempCounter += params.size();
}
tempCounter++;
// 文本 IR 函数定义
irStream << "define dso_local " << currentReturnType << " @" << currentFunction << "(";
irStream << ") #0 {\n";
// SysY IR 函数定义
sysy::Type* funcType = sysy::Type::getFunctionType(irReturnType, paramTypes);
currentIRFunction = module->createFunction(currentFunction, funcType);
setIRPosition(currentIRFunction->getEntryBlock());
// 处理函数参数分配
if (ctx->funcFParams()) {
auto params = ctx->funcFParams()->funcFParam();
for (size_t i = 0; i < params.size(); ++i) {
std::string varName = params[i]->Ident()->getText();
std::string llvmType = getLLVMType(params[i]->bType()->getText());
sysy::Type* irType = getIRType(params[i]->bType()->getText());
std::string allocaName = getNextTemp();
tmpTable[allocaName] = llvmType;
// 文本 IR 分配
irStream << " " << allocaName << " = alloca " << llvmType << ", align 4\n";
irStream << " store " << llvmType << " %" << i << ", " << llvmType
<< "* " << allocaName << ", align 4\n";
// SysY IR 分配
sysy::IRBuilder builder(currentIRBlock);
auto arg = currentIRBlock->createArgument(irType, varName);
auto allocaInst = builder.createAllocaInst(sysy::Type::getPointerType(irType), {}, varName);
builder.createStoreInst(arg, allocaInst);
symbolTable[varName] = {allocaName, llvmType};
irSymbolTable[varName] = allocaInst;
irTmpTable[allocaName] = allocaInst;
}
}
ctx->blockStmt()->accept(this);
if (!hasReturn) {
if (currentReturnType == "void") {
irStream << " ret void\n";
sysy::IRBuilder builder(currentIRBlock);
builder.createReturnInst();
} else {
irStream << " ret " << currentReturnType << " 0\n";
sysy::IRBuilder builder(currentIRBlock);
builder.createReturnInst(sysy::ConstantValue::get(0));
}
}
irStream << "}\n";
currentIRFunction = nullptr;
currentIRBlock = nullptr;
return nullptr;
}
std::any LLVMIRGenerator::visitBlockStmt(SysYParser::BlockStmtContext* ctx) {
for (auto item : ctx->blockItem()) {
item->accept(this);
}
return nullptr;
}
std::any LLVMIRGenerator::visitAssignStmt(SysYParser::AssignStmtContext* ctx) {
std::string lhsAlloca = std::any_cast<std::string>(ctx->lValue()->accept(this));
std::string lhsType = symbolTable[ctx->lValue()->Ident()->getText()].second;
std::string rhs = std::any_cast<std::string>(ctx->exp()->accept(this));
sysy::Value* rhsValue = irTmpTable[rhs];
// 文本 IR
if (lhsType == "float") {
try {
double floatValue = std::stod(rhs);
uint64_t hexValue = reinterpret_cast<uint64_t&>(floatValue);
std::stringstream ss;
ss << "0x" << std::hex << std::uppercase << (hexValue & (0xffffffffUL << 32));
rhs = ss.str();
} catch (...) {
// 如果 rhs 不是字面量,假设已正确处理
throw std::runtime_error("Invalid float literal: " + rhs);
}
}
irStream << " store " << lhsType << " " << rhs << ", " << lhsType
<< "* " << lhsAlloca << ", align 4\n";
// SysY IR
sysy::IRBuilder builder(currentIRBlock);
builder.createStoreInst(rhsValue, irSymbolTable[ctx->lValue()->Ident()->getText()]);
return nullptr;
}
std::any LLVMIRGenerator::visitIfStmt(SysYParser::IfStmtContext* ctx) {
std::string cond = std::any_cast<std::string>(ctx->cond()->accept(this));
sysy::Value* condValue = irTmpTable[cond];
std::string trueLabel = "if.then." + std::to_string(tempCounter);
std::string falseLabel = "if.else." + std::to_string(tempCounter);
std::string mergeLabel = "if.end." + std::to_string(tempCounter++);
// SysY IR 基本块
sysy::BasicBlock* thenBlock = currentIRFunction->addBasicBlock(trueLabel);
sysy::BasicBlock* elseBlock = ctx->ELSE() ? currentIRFunction->addBasicBlock(falseLabel) : nullptr;
sysy::BasicBlock* mergeBlock = currentIRFunction->addBasicBlock(mergeLabel);
// 文本 IR
irStream << " br i1 " << cond << ", label %" << trueLabel << ", label %"
<< (ctx->ELSE() ? falseLabel : mergeLabel) << "\n";
// SysY IR 条件分支
sysy::IRBuilder builder(currentIRBlock);
builder.createCondBrInst(condValue, thenBlock, ctx->ELSE() ? elseBlock : mergeBlock, {}, {});
// 处理 then 分支
setIRPosition(thenBlock);
irStream << trueLabel << ":\n";
ctx->stmt(0)->accept(this);
irStream << " br label %" << mergeLabel << "\n";
builder.setPosition(thenBlock, thenBlock->end());
builder.createUncondBrInst(mergeBlock, {});
// 处理 else 分支
if (ctx->ELSE()) {
setIRPosition(elseBlock);
irStream << falseLabel << ":\n";
ctx->stmt(1)->accept(this);
irStream << " br label %" << mergeLabel << "\n";
builder.setPosition(elseBlock, elseBlock->end());
builder.createUncondBrInst(mergeBlock, {});
}
// 合并点
setIRPosition(mergeBlock);
irStream << mergeLabel << ":\n";
return nullptr;
}
std::any LLVMIRGenerator::visitWhileStmt(SysYParser::WhileStmtContext* ctx) {
std::string loopCond = "while.cond." + std::to_string(tempCounter);
std::string loopBody = "while.body." + std::to_string(tempCounter);
std::string loopEnd = "while.end." + std::to_string(tempCounter++);
// SysY IR 基本块
sysy::BasicBlock* condBlock = currentIRFunction->addBasicBlock(loopCond);
sysy::BasicBlock* bodyBlock = currentIRFunction->addBasicBlock(loopBody);
sysy::BasicBlock* endBlock = currentIRFunction->addBasicBlock(loopEnd);
loopStack.push({loopEnd, loopCond, endBlock, condBlock});
// 跳转到条件块
sysy::IRBuilder builder(currentIRBlock);
builder.createUncondBrInst(condBlock, {});
irStream << " br label %" << loopCond << "\n";
// 条件块
setIRPosition(condBlock);
irStream << loopCond << ":\n";
std::string cond = std::any_cast<std::string>(ctx->cond()->accept(this));
sysy::Value* condValue = irTmpTable[cond];
irStream << " br i1 " << cond << ", label %" << loopBody << ", label %" << loopEnd << "\n";
builder.setPosition(condBlock, condBlock->end());
builder.createCondBrInst(condValue, bodyBlock, endBlock, {}, {});
// 循环体
setIRPosition(bodyBlock);
irStream << loopBody << ":\n";
ctx->stmt()->accept(this);
irStream << " br label %" << loopCond << "\n";
builder.setPosition(bodyBlock, bodyBlock->end());
builder.createUncondBrInst(condBlock, {});
// 结束块
setIRPosition(endBlock);
irStream << loopEnd << ":\n";
loopStack.pop();
return nullptr;
}
std::any LLVMIRGenerator::visitBreakStmt(SysYParser::BreakStmtContext* ctx) {
if (loopStack.empty()) {
throw std::runtime_error("Break statement outside of a loop.");
}
irStream << " br label %" << loopStack.top().breakLabel << "\n";
sysy::IRBuilder builder(currentIRBlock);
builder.createUncondBrInst(loopStack.top().irBreakBlock, {});
return nullptr;
}
std::any LLVMIRGenerator::visitContinueStmt(SysYParser::ContinueStmtContext* ctx) {
if (loopStack.empty()) {
throw std::runtime_error("Continue statement outside of a loop.");
}
irStream << " br label %" << loopStack.top().continueLabel << "\n";
sysy::IRBuilder builder(currentIRBlock);
builder.createUncondBrInst(loopStack.top().irContinueBlock, {});
return nullptr;
}
std::any LLVMIRGenerator::visitReturnStmt(SysYParser::ReturnStmtContext* ctx) {
hasReturn = true;
sysy::IRBuilder builder(currentIRBlock);
if (ctx->exp()) {
std::string value = std::any_cast<std::string>(ctx->exp()->accept(this));
sysy::Value* irValue = irTmpTable[value];
irStream << " ret " << currentReturnType << " " << value << "\n";
builder.createReturnInst(irValue);
} else {
irStream << " ret void\n";
builder.createReturnInst();
}
return nullptr;
}
std::any LLVMIRGenerator::visitLValue(SysYParser::LValueContext* ctx) {
std::string varName = ctx->Ident()->getText();
if (irSymbolTable.find(varName) == irSymbolTable.end()) {
throw std::runtime_error("Undefined variable: " + varName);
}
// 对于 LValue返回分配的指针文本 IR 和 SysY IR 一致)
return symbolTable[varName].first;
}
std::any LLVMIRGenerator::visitPrimExp(SysYParser::PrimExpContext* ctx) {
SysYParser::PrimaryExpContext* pExpCtx = ctx->primaryExp();
if (auto* lvalCtx = dynamic_cast<SysYParser::LValContext*>(pExpCtx)) {
std::string allocaPtr = std::any_cast<std::string>(lvalCtx->lValue()->accept(this));
std::string varName = lvalCtx->lValue()->Ident()->getText();
std::string type = symbolTable[varName].second;
std::string temp = getNextTemp();
sysy::Type* irType = getIRType(type == "i32" ? "int" : "float");
// 文本 IR
irStream << " " << temp << " = load " << type << ", " << type << "* " << allocaPtr << ", align 4\n";
tmpTable[temp] = type;
// SysY IR
sysy::IRBuilder builder(currentIRBlock);
auto loadInst = builder.createLoadInst(irSymbolTable[varName], {});
irTmpTable[temp] = loadInst;
return temp;
} else if (auto* expCtx = dynamic_cast<SysYParser::ParenExpContext*>(pExpCtx)) {
return expCtx->exp()->accept(this);
} else if (auto* strCtx = dynamic_cast<SysYParser::StrContext*>(pExpCtx)) {
return strCtx->string()->accept(this);
} else if (auto* numCtx = dynamic_cast<SysYParser::NumContext*>(pExpCtx)) {
return numCtx->number()->accept(this);
} else {
// 没有成功转换,说明 ctx->primaryExp() 不是 NumContext 或其他已知类型
// 可能是其他类型的表达式,或者是一个空的 PrimaryExpContext
std::cout << "Unknown primary expression type." << std::endl;
throw std::runtime_error("Unknown primary expression type.");
}
}
std::any LLVMIRGenerator::visitParenExp(SysYParser::ParenExpContext* ctx) {
return ctx->exp()->accept(this);
}
std::any LLVMIRGenerator::visitNumber(SysYParser::NumberContext* ctx) {
std::string value;
sysy::Value* irValue = nullptr;
if (ctx->ILITERAL()) {
value = ctx->ILITERAL()->getText();
irValue = sysy::ConstantValue::get(std::stoi(value));
} else if (ctx->FLITERAL()) {
value = ctx->FLITERAL()->getText();
irValue = sysy::ConstantValue::get(std::stof(value));
} else {
value = "";
}
std::string temp = getNextTemp();
tmpTable[temp] = ctx->ILITERAL() ? "i32" : "float";
irTmpTable[temp] = irValue;
return value;
}
std::any LLVMIRGenerator::visitString(SysYParser::StringContext* ctx) {
if (ctx->STRING()) {
std::string str = ctx->STRING()->getText();
str = str.substr(1, str.size() - 2);
std::string escapedStr;
for (char c : str) {
if (c == '\\') {
escapedStr += "\\\\";
} else if (c == '"') {
escapedStr += "\\\"";
} else {
escapedStr += c;
}
}
// TODO: SysY IR 暂不支持字符串常量,返回文本 IR 结果
return "\"" + escapedStr + "\"";
}
return ctx->STRING()->getText();
}
std::any LLVMIRGenerator::visitUnExp(SysYParser::UnExpContext* ctx) {
if (ctx->unaryOp()) {
std::string operand = std::any_cast<std::string>(ctx->unaryExp()->accept(this));
sysy::Value* irOperand = irTmpTable[operand];
std::string op = ctx->unaryOp()->getText();
std::string temp = getNextTemp();
std::string type = tmpTable[operand];
sysy::Type* irType = getIRType(type == "i32" ? "int" : "float");
tmpTable[temp] = type;
// 文本 IR
if (op == "-") {
irStream << " " << temp << " = sub " << type << " 0, " << operand << "\n";
} else if (op == "!") {
irStream << " " << temp << " = xor " << type << " " << operand << ", 1\n";
}
// SysY IR
sysy::IRBuilder builder(currentIRBlock);
sysy::Instruction::Kind kind = (op == "-") ? (type == "i32" ? sysy::Instruction::kNeg : sysy::Instruction::kFNeg)
: sysy::Instruction::kNot;
auto unaryInst = builder.createUnaryInst(kind, irType, irOperand, temp);
irTmpTable[temp] = unaryInst;
return temp;
}
return ctx->unaryExp()->accept(this);
}
std::any LLVMIRGenerator::visitCall(SysYParser::CallContext* ctx) {
std::string funcName = ctx->Ident()->getText();
std::vector<std::string> args;
std::vector<sysy::Value*> irArgs;
if (ctx->funcRParams()) {
for (auto argCtx : ctx->funcRParams()->exp()) {
std::string arg = std::any_cast<std::string>(argCtx->accept(this));
args.push_back(arg);
irArgs.push_back(irTmpTable[arg]);
}
}
std::string temp = getNextTemp();
std::string argList;
for (size_t i = 0; i < args.size(); ++i) {
if (i > 0) argList += ", ";
argList += tmpTable[args[i]] + " noundef " + args[i];
}
// 文本 IR
irStream << " " << temp << " = call " << currentReturnType << " @" << funcName << "(" << argList << ")\n";
tmpTable[temp] = currentReturnType;
// SysY IR
sysy::IRBuilder builder(currentIRBlock);
sysy::Function* callee = module->getFunction(funcName);
if (!callee) {
throw std::runtime_error("Undefined function: " + funcName);
}
auto callInst = builder.createCallInst(callee, irArgs, temp);
irTmpTable[temp] = callInst;
return temp;
}
std::any LLVMIRGenerator::visitMulExp(SysYParser::MulExpContext* ctx) {
auto unaryExps = ctx->unaryExp();
std::string left = std::any_cast<std::string>(unaryExps[0]->accept(this));
sysy::Value* irLeft = irTmpTable[left];
sysy::Type* irType = irLeft->getType();
for (size_t i = 1; i < unaryExps.size(); ++i) {
std::string right = std::any_cast<std::string>(unaryExps[i]->accept(this));
sysy::Value* irRight = irTmpTable[right];
std::string op = ctx->children[2 * i - 1]->getText();
std::string temp = getNextTemp();
std::string type = tmpTable[left];
tmpTable[temp] = type;
// 文本 IR
if (op == "*") {
irStream << " " << temp << " = mul nsw " << type << " " << left << ", " << right << "\n";
} else if (op == "/") {
irStream << " " << temp << " = sdiv " << type << " " << left << ", " << right << "\n";
} else if (op == "%") {
irStream << " " << temp << " = srem " << type << " " << left << ", " << right << "\n";
}
// SysY IR
sysy::IRBuilder builder(currentIRBlock);
sysy::Instruction::Kind kind;
if (type == "i32") {
if (op == "*") kind = sysy::Instruction::kMul;
else if (op == "/") kind = sysy::Instruction::kDiv;
else kind = sysy::Instruction::kRem;
} else {
if (op == "*") kind = sysy::Instruction::kFMul;
else if (op == "/") kind = sysy::Instruction::kFDiv;
else kind = sysy::Instruction::kFRem;
}
auto binaryInst = builder.createBinaryInst(kind, irType, irLeft, irRight, temp);
irTmpTable[temp] = binaryInst;
left = temp;
irLeft = binaryInst;
}
return left;
}
std::any LLVMIRGenerator::visitAddExp(SysYParser::AddExpContext* ctx) {
auto mulExps = ctx->mulExp();
std::string left = std::any_cast<std::string>(mulExps[0]->accept(this));
sysy::Value* irLeft = irTmpTable[left];
sysy::Type* irType = irLeft->getType();
for (size_t i = 1; i < mulExps.size(); ++i) {
std::string right = std::any_cast<std::string>(mulExps[i]->accept(this));
sysy::Value* irRight = irTmpTable[right];
std::string op = ctx->children[2 * i - 1]->getText();
std::string temp = getNextTemp();
std::string type = tmpTable[left];
tmpTable[temp] = type;
// 文本 IR
if (op == "+") {
irStream << " " << temp << " = add nsw " << type << " " << left << ", " << right << "\n";
} else if (op == "-") {
irStream << " " << temp << " = sub nsw " << type << " " << left << ", " << right << "\n";
}
// SysY IR
sysy::IRBuilder builder(currentIRBlock);
sysy::Instruction::Kind kind = (type == "i32") ? (op == "+" ? sysy::Instruction::kAdd : sysy::Instruction::kSub)
: (op == "+" ? sysy::Instruction::kFAdd : sysy::Instruction::kFSub);
auto binaryInst = builder.createBinaryInst(kind, irType, irLeft, irRight, temp);
irTmpTable[temp] = binaryInst;
left = temp;
irLeft = binaryInst;
}
return left;
}
std::any LLVMIRGenerator::visitRelExp(SysYParser::RelExpContext* ctx) {
auto addExps = ctx->addExp();
std::string left = std::any_cast<std::string>(addExps[0]->accept(this));
sysy::Value* irLeft = irTmpTable[left];
sysy::Type* irType = sysy::Type::getIntType(); // 比较结果为 i1
for (size_t i = 1; i < addExps.size(); ++i) {
std::string right = std::any_cast<std::string>(addExps[i]->accept(this));
sysy::Value* irRight = irTmpTable[right];
std::string op = ctx->children[2 * i - 1]->getText();
std::string temp = getNextTemp();
std::string type = tmpTable[left];
tmpTable[temp] = "i1";
// 文本 IR
if (op == "<") {
irStream << " " << temp << " = icmp slt " << type << " " << left << ", " << right << "\n";
} else if (op == ">") {
irStream << " " << temp << " = icmp sgt " << type << " " << left << ", " << right << "\n";
} else if (op == "<=") {
irStream << " " << temp << " = icmp sle " << type << " " << left << ", " << right << "\n";
} else if (op == ">=") {
irStream << " " << temp << " = icmp sge " << type << " " << left << ", " << right << "\n";
}
// SysY IR
sysy::IRBuilder builder(currentIRBlock);
sysy::Instruction::Kind kind;
if (type == "i32") {
if (op == "<") kind = sysy::Instruction::kICmpLT;
else if (op == ">") kind = sysy::Instruction::kICmpGT;
else if (op == "<=") kind = sysy::Instruction::kICmpLE;
else kind = sysy::Instruction::kICmpGE;
} else {
if (op == "<") kind = sysy::Instruction::kFCmpLT;
else if (op == ">") kind = sysy::Instruction::kFCmpGT;
else if (op == "<=") kind = sysy::Instruction::kFCmpLE;
else kind = sysy::Instruction::kFCmpGE;
}
auto cmpInst = builder.createBinaryInst(kind, irType, irLeft, irRight, temp);
irTmpTable[temp] = cmpInst;
left = temp;
irLeft = cmpInst;
}
return left;
}
std::any LLVMIRGenerator::visitEqExp(SysYParser::EqExpContext* ctx) {
auto relExps = ctx->relExp();
std::string left = std::any_cast<std::string>(relExps[0]->accept(this));
sysy::Value* irLeft = irTmpTable[left];
sysy::Type* irType = sysy::Type::getIntType(); // 比较结果为 i1
for (size_t i = 1; i < relExps.size(); ++i) {
std::string right = std::any_cast<std::string>(relExps[i]->accept(this));
sysy::Value* irRight = irTmpTable[right];
std::string op = ctx->children[2 * i - 1]->getText();
std::string temp = getNextTemp();
std::string type = tmpTable[left];
tmpTable[temp] = "i1";
// 文本 IR
if (op == "==") {
irStream << " " << temp << " = icmp eq " << type << " " << left << ", " << right << "\n";
} else if (op == "!=") {
irStream << " " << temp << " = icmp ne " << type << " " << left << ", " << right << "\n";
}
// SysY IR
sysy::IRBuilder builder(currentIRBlock);
sysy::Instruction::Kind kind = (type == "i32") ? (op == "==" ? sysy::Instruction::kICmpEQ : sysy::Instruction::kICmpNE)
: (op == "==" ? sysy::Instruction::kFCmpEQ : sysy::Instruction::kFCmpNE);
auto cmpInst = builder.createBinaryInst(kind, irType, irLeft, irRight, temp);
irTmpTable[temp] = cmpInst;
left = temp;
irLeft = cmpInst;
}
return left;
}
std::any LLVMIRGenerator::visitLAndExp(SysYParser::LAndExpContext* ctx) {
auto eqExps = ctx->eqExp();
std::string left = std::any_cast<std::string>(eqExps[0]->accept(this));
sysy::Value* irLeft = irTmpTable[left];
for (size_t i = 1; i < eqExps.size(); ++i) {
std::string falseLabel = "land.false." + std::to_string(tempCounter);
std::string endLabel = "land.end." + std::to_string(tempCounter++);
sysy::BasicBlock* falseBlock = currentIRFunction->addBasicBlock(falseLabel);
sysy::BasicBlock* endBlock = currentIRFunction->addBasicBlock(endLabel);
std::string temp = getNextTemp();
tmpTable[temp] = "i1";
// 文本 IR
irStream << " br i1 " << left << ", label %" << falseLabel << ", label %" << endLabel << "\n";
irStream << falseLabel << ":\n";
// SysY IR
sysy::IRBuilder builder(currentIRBlock);
builder.createCondBrInst(irLeft, falseBlock, endBlock, {}, {});
setIRPosition(falseBlock);
std::string right = std::any_cast<std::string>(eqExps[i]->accept(this));
sysy::Value* irRight = irTmpTable[right];
irStream << " " << temp << " = and i1 " << left << ", " << right << "\n";
irStream << " br label %" << endLabel << "\n";
irStream << endLabel << ":\n";
// SysY IR 逻辑与(通过基本块实现短路求值)
builder.setPosition(falseBlock, falseBlock->end());
auto andInst = builder.createBinaryInst(sysy::Instruction::kICmpEQ, sysy::Type::getIntType(), irLeft, irRight, temp);
builder.createUncondBrInst(endBlock, {});
irTmpTable[temp] = andInst;
left = temp;
irLeft = andInst;
setIRPosition(endBlock);
}
return left;
}
std::any LLVMIRGenerator::visitLOrExp(SysYParser::LOrExpContext* ctx) {
auto lAndExps = ctx->lAndExp();
std::string left = std::any_cast<std::string>(lAndExps[0]->accept(this));
sysy::Value* irLeft = irTmpTable[left];
for (size_t i = 1; i < lAndExps.size(); ++i) {
std::string trueLabel = "lor.true." + std::to_string(tempCounter);
std::string endLabel = "lor.end." + std::to_string(tempCounter++);
sysy::BasicBlock* trueBlock = currentIRFunction->addBasicBlock(trueLabel);
sysy::BasicBlock* endBlock = currentIRFunction->addBasicBlock(endLabel);
std::string temp = getNextTemp();
tmpTable[temp] = "i1";
// 文本 IR
irStream << " br i1 " << left << ", label %" << trueLabel << ", label %" << endLabel << "\n";
irStream << trueLabel << ":\n";
// SysY IR
sysy::IRBuilder builder(currentIRBlock);
builder.createCondBrInst(irLeft, trueBlock, endBlock, {}, {});
setIRPosition(trueBlock);
std::string right = std::any_cast<std::string>(lAndExps[i]->accept(this));
sysy::Value* irRight = irTmpTable[right];
irStream << " " << temp << " = or i1 " << left << ", " << right << "\n";
irStream << " br label %" << endLabel << "\n";
irStream << endLabel << ":\n";
// SysY IR 逻辑或(通过基本块实现短路求值)
builder.setPosition(trueBlock, trueBlock->end());
auto orInst = builder.createBinaryInst(sysy::Instruction::kICmpEQ, sysy::Type::getIntType(), irLeft, irRight, temp);
builder.createUncondBrInst(endBlock, {});
irTmpTable[temp] = orInst;
left = temp;
irLeft = orInst;
setIRPosition(endBlock);
}
return left;
}
// } // namespace sysy

View File

@ -101,7 +101,10 @@ BLOCKCOMMENT: '/*' .*? '*/' -> skip;
// CompUnit: (CompUnit)? (decl |funcDef);
compUnit: (decl |funcDef)+;
compUnit: (globalDecl |funcDef)+;
globalDecl: constDecl # globalConstDecl
| varDecl # globalVarDecl;
decl: constDecl | varDecl;
@ -111,16 +114,16 @@ bType: INT | FLOAT;
constDef: Ident (LBRACK constExp RBRACK)* ASSIGN constInitVal;
constInitVal: constExp
| LBRACE (constInitVal (COMMA constInitVal)*)? RBRACE;
constInitVal: constExp # constScalarInitValue
| LBRACE (constInitVal (COMMA constInitVal)*)? RBRACE # constArrayInitValue;
varDecl: bType varDef (COMMA varDef)* SEMICOLON;
varDef: Ident (LBRACK constExp RBRACK)*
| Ident (LBRACK constExp RBRACK)* ASSIGN initVal;
initVal: exp
| LBRACE (initVal (COMMA initVal)*)? RBRACE;
initVal: exp # scalarInitValue
| LBRACE (initVal (COMMA initVal)*)? RBRACE # arrayInitValue;
funcType: VOID | INT | FLOAT;
@ -150,15 +153,16 @@ cond: lOrExp;
lValue: Ident (LBRACK exp RBRACK)*;
// 为了方便测试 primaryExp 可以是一个string
primaryExp: LPAREN exp RPAREN #parenExp
| lValue #lVal
| number #num
| string #str;
primaryExp: LPAREN exp RPAREN
| lValue
| number
| string;
number: ILITERAL | FLITERAL;
unaryExp: primaryExp #primExp
| Ident LPAREN (funcRParams)? RPAREN #call
| unaryOp unaryExp #unExp;
call: Ident LPAREN (funcRParams)? RPAREN;
unaryExp: primaryExp
| call
| unaryOp unaryExp;
unaryOp: ADD|SUB|NOT;
funcRParams: exp (COMMA exp)*;

0
src/SysYIRAnalyser.cpp Normal file
View File

File diff suppressed because it is too large Load Diff

View File

@ -1,149 +0,0 @@
#pragma once
#include "IR.h"
#include "IRBuilder.h"
#include "SysYBaseVisitor.h"
#include "SysYParser.h"
#include <memory>
#include <unordered_map>
#include <forward_list>
namespace sysy {
class SymbolTable{
private:
enum Kind
{
kModule,
kFunction,
kBlock,
};
std::forward_list<std::pair<Kind, std::unordered_map<std::string, Value*>>> Scopes;
public:
struct ModuleScope {
SymbolTable& tables_ref;
ModuleScope(SymbolTable& tables) : tables_ref(tables) {
tables.enter(kModule);
}
~ModuleScope() { tables_ref.exit(); }
};
struct FunctionScope {
SymbolTable& tables_ref;
FunctionScope(SymbolTable& tables) : tables_ref(tables) {
tables.enter(kFunction);
}
~FunctionScope() { tables_ref.exit(); }
};
struct BlockScope {
SymbolTable& tables_ref;
BlockScope(SymbolTable& tables) : tables_ref(tables) {
tables.enter(kBlock);
}
~BlockScope() { tables_ref.exit(); }
};
SymbolTable() = default;
bool isModuleScope() const { return Scopes.front().first == kModule; }
bool isFunctionScope() const { return Scopes.front().first == kFunction; }
bool isBlockScope() const { return Scopes.front().first == kBlock; }
Value *lookup(const std::string &name) const {
for (auto &scope : Scopes) {
auto iter = scope.second.find(name);
if (iter != scope.second.end())
return iter->second;
}
return nullptr;
}
auto insert(const std::string &name, Value *value) {
assert(not Scopes.empty());
return Scopes.front().second.emplace(name, value);
}
private:
void enter(Kind kind) {
Scopes.emplace_front();
Scopes.front().first = kind;
}
void exit() {
Scopes.pop_front();
}
};
class SysYIRGenerator : public SysYBaseVisitor {
private:
std::unique_ptr<Module> module;
IRBuilder builder;
SymbolTable symbols_table;
int trueBlockNum = 0, falseBlockNum = 0;
int d = 0, n = 0;
vector<int> path;
bool isalloca;
AllocaInst *current_alloca;
GlobalValue *current_global;
Type* current_type;
int numdims = 0;
public:
SysYIRGenerator() = default;
public:
Module *get() const { return module.get(); }
public:
std::any visitCompUnit(SysYParser::CompUnitContext *ctx) override;
std::any visitDecl(SysYParser::DeclContext *ctx) override;
std::any visitConstDecl(SysYParser::ConstDeclContext *ctx) override;
std::any visitBType(SysYParser::BTypeContext *ctx) override;
std::any visitConstDef(SysYParser::ConstDefContext *ctx) override;
std::any visitConstInitVal(SysYParser::ConstInitValContext *ctx) override;
std::any visitFuncType(SysYParser::FuncTypeContext* ctx) override;
std::any visitFuncDef(SysYParser::FuncDefContext* ctx) override;
std::any visitVarDecl(SysYParser::VarDeclContext *ctx) override;
std::any visitVarDef(SysYParser::VarDefContext *ctx) override;
std::any visitInitVal(SysYParser::InitValContext *ctx) override;
// std::any visitFuncFParam(SysYParser::FuncFParamContext *ctx) override;
// std::any visitFuncFParams(SysYParser::FuncFParamsContext *ctx) override;
std::any visitBlockStmt(SysYParser::BlockStmtContext* ctx) override;
// std::any visitStmt(SysYParser::StmtContext *ctx) override;
std::any visitAssignStmt(SysYParser::AssignStmtContext *ctx) override;
std::any visitIfStmt(SysYParser::IfStmtContext *ctx) override;
std::any visitWhileStmt(SysYParser::WhileStmtContext *ctx) override;
std::any visitBreakStmt(SysYParser::BreakStmtContext *ctx) override;
std::any visitContinueStmt(SysYParser::ContinueStmtContext *ctx) override;
std::any visitReturnStmt(SysYParser::ReturnStmtContext *ctx) override;
// std::any visitExp(SysYParser::ExpContext *ctx) override;
std::any visitLValue(SysYParser::LValueContext *ctx) override;
std::any visitPrimExp(SysYParser::PrimExpContext *ctx) override;
// std::any visitParenExp(SysYParser::ParenExpContext *ctx) override;
std::any visitNumber(SysYParser::NumberContext *ctx) override;
// std::any visitString(SysYParser::StringContext *ctx) override;
std::any visitCall(SysYParser::CallContext *ctx) override;
// std::any visitUnaryExp(SysYParser::UnaryExpContext *ctx) override;
// std::any visitUnaryOp(SysYParser::UnaryOpContext *ctx) override;
std::any visitUnExp(SysYParser::UnExpContext *ctx) override;
// std::any visitFuncRParams(SysYParser::FuncRParamsContext *ctx) override;
std::any visitMulExp(SysYParser::MulExpContext *ctx) override;
std::any visitAddExp(SysYParser::AddExpContext *ctx) override;
std::any visitRelExp(SysYParser::RelExpContext *ctx) override;
std::any visitEqExp(SysYParser::EqExpContext *ctx) override;
std::any visitLAndExp(SysYParser::LAndExpContext *ctx) override;
std::any visitLOrExp(SysYParser::LOrExpContext *ctx) override;
std::any visitConstExp(SysYParser::ConstExpContext *ctx) override;
private:
std::any visitConstGlobalDecl(SysYParser::ConstDeclContext *ctx, Type* type);
std::any visitVarGlobalDecl(SysYParser::VarDeclContext *ctx, Type* type);
std::any visitConstLocalDecl(SysYParser::ConstDeclContext *ctx, Type* type);
std::any visitVarLocalDecl(SysYParser::VarDeclContext *ctx, Type* type);
Type *getArithmeticResultType(Type *lhs, Type *rhs) {
assert(lhs->isIntOrFloat() and rhs->isIntOrFloat());
return lhs == rhs ? lhs : Type::getFloatType();
}
}; // class SysYIRGenerator
} // namespace sysy

1711
src/include/IR.h Normal file

File diff suppressed because it is too large Load Diff

349
src/include/IRBuilder.h Normal file
View File

@ -0,0 +1,349 @@
#pragma once
#include <cassert>
#include <iostream>
#include <sstream>
#include <string>
#include <vector>
#include "IR.h"
/**
* @file IRBuilder.h
*
* @brief 定义IR构建器的头文件
*/
namespace sysy {
/**
* @brief 中间IR的构建器
*
*/
class IRBuilder {
private:
unsigned labelIndex; ///< 基本块标签编号
unsigned tmpIndex; ///< 临时变量编号
BasicBlock *block; ///< 当前基本块
BasicBlock::iterator position; ///< 当前基本块指令列表位置的迭代器
std::vector<BasicBlock *> trueBlocks; ///< true分支基本块列表
std::vector<BasicBlock *> falseBlocks; ///< false分支基本块列表
std::vector<BasicBlock *> breakBlocks; ///< break目标块列表
std::vector<BasicBlock *> continueBlocks; ///< continue目标块列表
public:
IRBuilder() : labelIndex(0), tmpIndex(0), block(nullptr) {}
explicit IRBuilder(BasicBlock *block) : labelIndex(0), tmpIndex(0), block(block), position(block->end()) {}
IRBuilder(BasicBlock *block, BasicBlock::iterator position)
: labelIndex(0), tmpIndex(0), block(block), position(position) {}
public:
unsigned getLabelIndex() {
labelIndex += 1;
return labelIndex - 1;
} ///< 获取基本块标签编号
unsigned getTmpIndex() {
tmpIndex += 1;
return tmpIndex - 1;
} ///< 获取临时变量编号
BasicBlock * getBasicBlock() const { return block; } ///< 获取当前基本块
BasicBlock * getBreakBlock() const { return breakBlocks.back(); } ///< 获取break目标块
BasicBlock * popBreakBlock() {
auto result = breakBlocks.back();
breakBlocks.pop_back();
return result;
} ///< 弹出break目标块
BasicBlock * getContinueBlock() const { return continueBlocks.back(); } ///< 获取continue目标块
BasicBlock * popContinueBlock() {
auto result = continueBlocks.back();
continueBlocks.pop_back();
return result;
} ///< 弹出continue目标块
BasicBlock * getTrueBlock() const { return trueBlocks.back(); } ///< 获取true分支基本块
BasicBlock * getFalseBlock() const { return falseBlocks.back(); } ///< 获取false分支基本块
BasicBlock * popTrueBlock() {
auto result = trueBlocks.back();
trueBlocks.pop_back();
return result;
} ///< 弹出true分支基本块
BasicBlock * popFalseBlock() {
auto result = falseBlocks.back();
falseBlocks.pop_back();
return result;
} ///< 弹出false分支基本块
BasicBlock::iterator getPosition() const { return position; } ///< 获取当前基本块指令列表位置的迭代器
void setPosition(BasicBlock *block, BasicBlock::iterator position) {
this->block = block;
this->position = position;
} ///< 设置基本块和基本块指令列表位置的迭代器
void setPosition(BasicBlock::iterator position) {
this->position = position;
} ///< 设置当前基本块指令列表位置的迭代器
void pushBreakBlock(BasicBlock *block) { breakBlocks.push_back(block); } ///< 压入break目标基本块
void pushContinueBlock(BasicBlock *block) { continueBlocks.push_back(block); } ///< 压入continue目标基本块
void pushTrueBlock(BasicBlock *block) { trueBlocks.push_back(block); } ///< 压入true分支基本块
void pushFalseBlock(BasicBlock *block) { falseBlocks.push_back(block); } ///< 压入false分支基本块
public:
Instruction * insertInst(Instruction *inst) {
assert(inst);
block->getInstructions().emplace(position, inst);
return inst;
} ///< 插入指令
UnaryInst * createUnaryInst(Instruction::Kind kind, Type *type, Value *operand, const std::string &name = "") {
std::string newName;
if (name.empty()) {
std::stringstream ss;
ss << "%" << tmpIndex;
newName = ss.str();
tmpIndex++;
} else {
newName = name;
}
auto inst = new UnaryInst(kind, type, operand, block, newName);
assert(inst);
block->getInstructions().emplace(position, inst);
return inst;
} ///< 创建一元指令
UnaryInst * createNegInst(Value *operand, const std::string &name = "") {
return createUnaryInst(Instruction::kNeg, Type::getIntType(), operand, name);
} ///< 创建取反指令
UnaryInst * createNotInst(Value *operand, const std::string &name = "") {
return createUnaryInst(Instruction::kNot, Type::getIntType(), operand, name);
} ///< 创建取非指令
UnaryInst * createFtoIInst(Value *operand, const std::string &name = "") {
return createUnaryInst(Instruction::kFtoI, Type::getIntType(), operand, name);
} ///< 创建浮点转整型指令
UnaryInst * createBitFtoIInst(Value *operand, const std::string &name = "") {
return createUnaryInst(Instruction::kBitFtoI, Type::getIntType(), operand, name);
} ///< 创建按位浮点转整型指令
UnaryInst * createFNegInst(Value *operand, const std::string &name = "") {
return createUnaryInst(Instruction::kFNeg, Type::getFloatType(), operand, name);
} ///< 创建浮点取反指令
UnaryInst * createFNotInst(Value *operand, const std::string &name = "") {
return createUnaryInst(Instruction::kFNot, Type::getIntType(), operand, name);
} ///< 创建浮点取非指令
UnaryInst * createIToFInst(Value *operand, const std::string &name = "") {
return createUnaryInst(Instruction::kItoF, Type::getFloatType(), operand, name);
} ///< 创建整型转浮点指令
UnaryInst * createBitItoFInst(Value *operand, const std::string &name = "") {
return createUnaryInst(Instruction::kBitItoF, Type::getFloatType(), operand, name);
} ///< 创建按位整型转浮点指令
BinaryInst * createBinaryInst(Instruction::Kind kind, Type *type, Value *lhs, Value *rhs, const std::string &name = "") {
std::string newName;
if (name.empty()) {
std::stringstream ss;
ss << "%" << tmpIndex;
newName = ss.str();
tmpIndex++;
} else {
newName = name;
}
auto inst = new BinaryInst(kind, type, lhs, rhs, block, newName);
assert(inst);
block->getInstructions().emplace(position, inst);
return inst;
} ///< 创建二元指令
BinaryInst * createAddInst(Value *lhs, Value *rhs, const std::string &name = "") {
return createBinaryInst(Instruction::kAdd, Type::getIntType(), lhs, rhs, name);
} ///< 创建加法指令
BinaryInst * createSubInst(Value *lhs, Value *rhs, const std::string &name = "") {
return createBinaryInst(Instruction::kSub, Type::getIntType(), lhs, rhs, name);
} ///< 创建减法指令
BinaryInst * createMulInst(Value *lhs, Value *rhs, const std::string &name = "") {
return createBinaryInst(Instruction::kMul, Type::getIntType(), lhs, rhs, name);
} ///< 创建乘法指令
BinaryInst * createDivInst(Value *lhs, Value *rhs, const std::string &name = "") {
return createBinaryInst(Instruction::kDiv, Type::getIntType(), lhs, rhs, name);
} ///< 创建除法指令
BinaryInst * createRemInst(Value *lhs, Value *rhs, const std::string &name = "") {
return createBinaryInst(Instruction::kRem, Type::getIntType(), lhs, rhs, name);
} ///< 创建取余指令
BinaryInst * createICmpEQInst(Value *lhs, Value *rhs, const std::string &name = "") {
return createBinaryInst(Instruction::kICmpEQ, Type::getIntType(), lhs, rhs, name);
} ///< 创建相等设置指令
BinaryInst * createICmpNEInst(Value *lhs, Value *rhs, const std::string &name = "") {
return createBinaryInst(Instruction::kICmpNE, Type::getIntType(), lhs, rhs, name);
} ///< 创建不相等设置指令
BinaryInst * createICmpLTInst(Value *lhs, Value *rhs, const std::string &name = "") {
return createBinaryInst(Instruction::kICmpLT, Type::getIntType(), lhs, rhs, name);
} ///< 创建小于设置指令
BinaryInst * createICmpLEInst(Value *lhs, Value *rhs, const std::string &name = "") {
return createBinaryInst(Instruction::kICmpLE, Type::getIntType(), lhs, rhs, name);
} ///< 创建小于等于设置指令
BinaryInst * createICmpGTInst(Value *lhs, Value *rhs, const std::string &name = "") {
return createBinaryInst(Instruction::kICmpGT, Type::getIntType(), lhs, rhs, name);
} ///< 创建大于设置指令
BinaryInst * createICmpGEInst(Value *lhs, Value *rhs, const std::string &name = "") {
return createBinaryInst(Instruction::kICmpGE, Type::getIntType(), lhs, rhs, name);
} ///< 创建大于等于设置指令
BinaryInst * createFAddInst(Value *lhs, Value *rhs, const std::string &name = "") {
return createBinaryInst(Instruction::kFAdd, Type::getFloatType(), lhs, rhs, name);
} ///< 创建浮点加法指令
BinaryInst * createFSubInst(Value *lhs, Value *rhs, const std::string &name = "") {
return createBinaryInst(Instruction::kFSub, Type::getFloatType(), lhs, rhs, name);
} ///< 创建浮点减法指令
BinaryInst * createFMulInst(Value *lhs, Value *rhs, const std::string &name = "") {
return createBinaryInst(Instruction::kFMul, Type::getFloatType(), lhs, rhs, name);
} ///< 创建浮点乘法指令
BinaryInst * createFDivInst(Value *lhs, Value *rhs, const std::string &name = "") {
return createBinaryInst(Instruction::kFDiv, Type::getFloatType(), lhs, rhs, name);
} ///< 创建浮点除法指令
BinaryInst * createFCmpEQInst(Value *lhs, Value *rhs, const std::string &name = "") {
return createBinaryInst(Instruction::kFCmpEQ, Type::getIntType(), lhs, rhs, name);
} ///< 创建浮点相等设置指令
BinaryInst * createFCmpNEInst(Value *lhs, Value *rhs, const std::string &name = "") {
return createBinaryInst(Instruction::kFCmpNE, Type::getIntType(), lhs, rhs, name);
} ///< 创建浮点不相等设置指令
BinaryInst * createFCmpLTInst(Value *lhs, Value *rhs, const std::string &name = "") {
return createBinaryInst(Instruction::kFCmpLT, Type::getIntType(), lhs, rhs, name);
} ///< 创建浮点小于设置指令
BinaryInst * createFCmpLEInst(Value *lhs, Value *rhs, const std::string &name = "") {
return createBinaryInst(Instruction::kFCmpLE, Type::getIntType(), lhs, rhs, name);
} ///< 创建浮点小于等于设置指令
BinaryInst * createFCmpGTInst(Value *lhs, Value *rhs, const std::string &name = "") {
return createBinaryInst(Instruction::kFCmpGT, Type::getIntType(), lhs, rhs, name);
} ///< 创建浮点大于设置指令
BinaryInst * createFCmpGEInst(Value *lhs, Value *rhs, const std::string &name = "") {
return createBinaryInst(Instruction::kFCmpGE, Type::getIntType(), lhs, rhs, name);
} ///< 创建浮点相大于等于设置指令
BinaryInst * createAndInst(Value *lhs, Value *rhs, const std::string &name = "") {
return createBinaryInst(Instruction::kAnd, Type::getIntType(), lhs, rhs, name);
} ///< 创建按位且指令
BinaryInst * createOrInst(Value *lhs, Value *rhs, const std::string &name = "") {
return createBinaryInst(Instruction::kOr, Type::getIntType(), lhs, rhs, name);
} ///< 创建按位或指令
CallInst * createCallInst(Function *callee, const std::vector<Value *> &args, const std::string &name = "") {
std::string newName;
if (name.empty() && callee->getReturnType() != Type::getVoidType()) {
std::stringstream ss;
ss << "%" << tmpIndex;
newName = ss.str();
tmpIndex++;
} else {
newName = name;
}
auto inst = new CallInst(callee, args, block, newName);
assert(inst);
block->getInstructions().emplace(position, inst);
return inst;
} ///< 创建Call指令
ReturnInst * createReturnInst(Value *value = nullptr, const std::string &name = "") {
auto inst = new ReturnInst(value, block, name);
assert(inst);
block->getInstructions().emplace(position, inst);
return inst;
} ///< 创建return指令
UncondBrInst * createUncondBrInst(BasicBlock *thenBlock, const std::vector<Value *> &args) {
auto inst = new UncondBrInst(thenBlock, args, block);
assert(inst);
block->getInstructions().emplace(position, inst);
return inst;
} ///< 创建无条件指令
CondBrInst * createCondBrInst(Value *condition, BasicBlock *thenBlock, BasicBlock *elseBlock,
const std::vector<Value *> &thenArgs, const std::vector<Value *> &elseArgs) {
auto inst = new CondBrInst(condition, thenBlock, elseBlock, thenArgs, elseArgs, block);
assert(inst);
block->getInstructions().emplace(position, inst);
return inst;
} ///< 创建条件跳转指令
AllocaInst * createAllocaInst(Type *type, const std::vector<Value *> &dims = {}, const std::string &name = "") {
auto inst = new AllocaInst(type, dims, block, name);
assert(inst);
block->getInstructions().emplace(position, inst);
return inst;
} ///< 创建分配指令
AllocaInst * createAllocaInstWithoutInsert(Type *type, const std::vector<Value *> &dims = {}, BasicBlock *parent = nullptr,
const std::string &name = "") {
auto inst = new AllocaInst(type, dims, parent, name);
assert(inst);
return inst;
} ///< 创建不插入指令列表的分配指令
LoadInst * createLoadInst(Value *pointer, const std::vector<Value *> &indices = {}, const std::string &name = "") {
std::string newName;
if (name.empty()) {
std::stringstream ss;
ss << "%" << tmpIndex;
newName = ss.str();
tmpIndex++;
} else {
newName = name;
}
auto inst = new LoadInst(pointer, indices, block, newName);
assert(inst);
block->getInstructions().emplace(position, inst);
return inst;
} ///< 创建load指令
LaInst * createLaInst(Value *pointer, const std::vector<Value *> &indices = {}, const std::string &name = "") {
std::string newName;
if (name.empty()) {
std::stringstream ss;
ss << "%" << tmpIndex;
newName = ss.str();
tmpIndex++;
} else {
newName = name;
}
auto inst = new LaInst(pointer, indices, block, newName);
assert(inst);
block->getInstructions().emplace(position, inst);
return inst;
} ///< 创建la指令
GetSubArrayInst * createGetSubArray(LVal *fatherArray, const std::vector<Value *> &indices, const std::string &name = "") {
assert(fatherArray->getLValNumDims() > indices.size());
std::vector<Value *> subDims;
auto dims = fatherArray->getLValDims();
auto iter = std::next(dims.begin(), indices.size());
while (iter != dims.end()) {
subDims.emplace_back(*iter);
iter++;
}
std::string childArrayName;
std::stringstream ss;
ss << "A"
<< "%" << tmpIndex;
childArrayName = ss.str();
tmpIndex++;
auto fatherArrayValue = dynamic_cast<Value *>(fatherArray);
auto childArray = new AllocaInst(fatherArrayValue->getType(), subDims, block, childArrayName);
auto inst = new GetSubArrayInst(fatherArray, childArray, indices, block, name);
assert(inst);
block->getInstructions().emplace(position, inst);
return inst;
} ///< 创建获取部分数组指令
MemsetInst * createMemsetInst(Value *pointer, Value *begin, Value *size, Value *value, const std::string &name = "") {
auto inst = new MemsetInst(pointer, begin, size, value, block, name);
assert(inst);
block->getInstructions().emplace(position, inst);
return inst;
} ///< 创建memset指令
StoreInst * createStoreInst(Value *value, Value *pointer, const std::vector<Value *> &indices = {},
const std::string &name = "") {
auto inst = new StoreInst(value, pointer, indices, block, name);
assert(inst);
block->getInstructions().emplace(position, inst);
return inst;
} ///< 创建store指令
PhiInst * createPhiInst(Type *type, Value *lhs, BasicBlock *parent, const std::string &name = "") {
auto predNum = parent->getNumPredecessors();
std::vector<Value *> rhs;
for (size_t i = 0; i < predNum; i++) {
rhs.push_back(lhs);
}
auto inst = new PhiInst(type, lhs, rhs, lhs, parent, name);
assert(inst);
parent->getInstructions().emplace(parent->begin(), inst);
return inst;
} ///< 创建Phi指令
};
} // namespace sysy

View File

@ -0,0 +1,78 @@
#pragma once
#include "SysYBaseVisitor.h"
#include "SysYParser.h"
#include "IR.h"
#include "IRBuilder.h"
#include <sstream>
#include <map>
#include <vector>
#include <stack>
class LLVMIRGenerator : public SysYBaseVisitor {
public:
sysy::Module* getIRModule() const { return irModule.get(); }
std::string generateIR(SysYParser::CompUnitContext* unit);
std::string getIR() const { return irStream.str(); }
private:
std::unique_ptr<sysy::Module> irModule; // IR数据结构
std::stringstream irStream; // 文本输出流
sysy::IRBuilder irBuilder; // IR构建器
int tempCounter = 0;
std::string currentVarType;
// std::map<std::string, sysy::Value*> symbolTable;
std::map<std::string, std::pair<std::string, std::string>> symbolTable;
std::map<std::string, std::string> tmpTable;
std::vector<std::string> globalVars;
std::string currentFunction;
std::string currentReturnType;
std::vector<std::string> breakStack;
std::vector<std::string> continueStack;
bool hasReturn = false;
struct LoopLabels {
std::string breakLabel; // break跳转的目标标签
std::string continueLabel; // continue跳转的目标标签
};
std::stack<LoopLabels> loopStack; // 用于管理循环的break和continue标签
std::string getNextTemp();
std::string getLLVMType(const std::string&);
sysy::Type* getSysYType(const std::string&);
bool inFunction = false; // 标识当前是否处于函数内部
// 访问方法
std::any visitCompUnit(SysYParser::CompUnitContext* ctx);
std::any visitConstDecl(SysYParser::ConstDeclContext* ctx);
std::any visitVarDecl(SysYParser::VarDeclContext* ctx);
std::any visitVarDef(SysYParser::VarDefContext* ctx);
std::any visitFuncDef(SysYParser::FuncDefContext* ctx);
std::any visitBlockStmt(SysYParser::BlockStmtContext* ctx);
// std::any visitStmt(SysYParser::StmtContext* ctx);
std::any visitLValue(SysYParser::LValueContext* ctx);
std::any visitPrimaryExp(SysYParser::PrimaryExpContext* ctx);
std::any visitPrimExp(SysYParser::PrimExpContext* ctx);
std::any visitParenExp(SysYParser::ParenExpContext* ctx);
std::any visitNumber(SysYParser::NumberContext* ctx);
std::any visitString(SysYParser::StringContext* ctx);
std::any visitCall(SysYParser::CallContext *ctx);
std::any visitUnExp(SysYParser::UnExpContext* ctx);
std::any visitMulExp(SysYParser::MulExpContext* ctx);
std::any visitAddExp(SysYParser::AddExpContext* ctx);
std::any visitRelExp(SysYParser::RelExpContext* ctx);
std::any visitEqExp(SysYParser::EqExpContext* ctx);
std::any visitLAndExp(SysYParser::LAndExpContext* ctx);
std::any visitLOrExp(SysYParser::LOrExpContext* ctx);
std::any visitAssignStmt(SysYParser::AssignStmtContext *ctx) override;
std::any visitIfStmt(SysYParser::IfStmtContext *ctx) override;
std::any visitWhileStmt(SysYParser::WhileStmtContext *ctx) override;
std::any visitBreakStmt(SysYParser::BreakStmtContext *ctx) override;
std::any visitContinueStmt(SysYParser::ContinueStmtContext *ctx) override;
std::any visitReturnStmt(SysYParser::ReturnStmtContext *ctx) override;
// 统一创建二元操作(同时生成数据结构和文本)
sysy::Value* createBinaryOp(SysYParser::ExpContext* lhs,
SysYParser::ExpContext* rhs,
sysy::Instruction::Kind opKind);
};

View File

View File

@ -0,0 +1,138 @@
#pragma once
#include "IR.h"
#include "IRBuilder.h"
#include "SysYBaseVisitor.h"
#include "SysYParser.h"
#include <memory>
#include <unordered_map>
#include <forward_list>
namespace sysy {
// @brief 用于存储数组值的树结构
// 多位数组本质上是一维数组的嵌套可以用树来表示。
class ArrayValueTree {
private:
Value *value = nullptr; /// 该节点存储的value
std::vector<std::unique_ptr<ArrayValueTree>> children; /// 子节点列表
public:
ArrayValueTree() = default;
public:
auto getValue() const -> Value * { return value; }
auto getChildren() const
-> const std::vector<std::unique_ptr<ArrayValueTree>> & {
return children;
}
void setValue(Value *newValue) { value = newValue; }
void addChild(ArrayValueTree *newChild) { children.emplace_back(newChild); }
void addChildren(const std::vector<ArrayValueTree *> &newChildren) {
for (const auto &child : newChildren) {
children.emplace_back(child);
}
}
};
class Utils {
public:
// transform a tree of ArrayValueTree to a ValueCounter
static void tree2Array(Type *type, ArrayValueTree *root,
const std::vector<Value *> &dims, unsigned numDims,
ValueCounter &result, IRBuilder *builder);
static void
createExternalFunction(const std::vector<Type *> &paramTypes,
const std::vector<std::string> &paramNames,
const std::vector<std::vector<Value *>> &paramDims,
Type *returnType, const std::string &funcName,
Module *pModule, IRBuilder *pBuilder);
static void initExternalFunction(Module *pModule, IRBuilder *pBuilder);
};
class SysYIRGenerator : public SysYBaseVisitor {
private:
std::unique_ptr<Module> module;
IRBuilder builder;
public:
SysYIRGenerator() = default;
public:
Module *get() const { return module.get(); }
IRBuilder *getBuilder(){ return &builder; }
public:
std::any visitCompUnit(SysYParser::CompUnitContext *ctx) override;
std::any visitGlobalConstDecl(SysYParser::GlobalConstDeclContext *ctx) override;
std::any visitGlobalVarDecl(SysYParser::GlobalVarDeclContext *ctx) override;
// std::any visitDecl(SysYParser::DeclContext *ctx) override;
std::any visitConstDecl(SysYParser::ConstDeclContext *ctx) override;
std::any visitVarDecl(SysYParser::VarDeclContext *ctx) override;
std::any visitBType(SysYParser::BTypeContext *ctx) override;
// std::any visitConstDef(SysYParser::ConstDefContext *ctx) override;
// std::any visitVarDef(SysYParser::VarDefContext *ctx) override;
std::any visitScalarInitValue(SysYParser::ScalarInitValueContext *ctx) override;
std::any visitArrayInitValue(SysYParser::ArrayInitValueContext *ctx) override;
std::any visitConstScalarInitValue(SysYParser::ConstScalarInitValueContext *ctx) override;
std::any visitConstArrayInitValue(SysYParser::ConstArrayInitValueContext *ctx) override;
// std::any visitConstInitVal(SysYParser::ConstInitValContext *ctx) override;
std::any visitFuncType(SysYParser::FuncTypeContext* ctx) override;
std::any visitFuncDef(SysYParser::FuncDefContext* ctx) override;
// std::any visitInitVal(SysYParser::InitValContext *ctx) override;
// std::any visitFuncFParam(SysYParser::FuncFParamContext *ctx) override;
// std::any visitFuncFParams(SysYParser::FuncFParamsContext *ctx) override;
std::any visitBlockStmt(SysYParser::BlockStmtContext* ctx) override;
// std::any visitStmt(SysYParser::StmtContext *ctx) override;
std::any visitAssignStmt(SysYParser::AssignStmtContext *ctx) override;
// std::any visitExpStmt(SysYParser::ExpStmtContext *ctx) override;
// std::any visitBlkStmt(SysYParser::BlkStmtContext *ctx) override;
std::any visitIfStmt(SysYParser::IfStmtContext *ctx) override;
std::any visitWhileStmt(SysYParser::WhileStmtContext *ctx) override;
std::any visitBreakStmt(SysYParser::BreakStmtContext *ctx) override;
std::any visitContinueStmt(SysYParser::ContinueStmtContext *ctx) override;
std::any visitReturnStmt(SysYParser::ReturnStmtContext *ctx) override;
// std::any visitExp(SysYParser::ExpContext *ctx) override;
// std::any visitCond(SysYParser::CondContext *ctx) override;
std::any visitLValue(SysYParser::LValueContext *ctx) override;
std::any visitPrimaryExp(SysYParser::PrimaryExpContext *ctx) override;
// std::any visitParenExp(SysYParser::ParenExpContext *ctx) override;
std::any visitNumber(SysYParser::NumberContext *ctx) override;
// std::any visitString(SysYParser::StringContext *ctx) override;
std::any visitCall(SysYParser::CallContext *ctx) override;
std::any visitUnaryExp(SysYParser::UnaryExpContext *ctx) override;
// std::any visitUnaryOp(SysYParser::UnaryOpContext *ctx) override;
// std::any visitUnExp(SysYParser::UnExpContext *ctx) override;
std::any visitFuncRParams(SysYParser::FuncRParamsContext *ctx) override;
std::any visitMulExp(SysYParser::MulExpContext *ctx) override;
std::any visitAddExp(SysYParser::AddExpContext *ctx) override;
std::any visitRelExp(SysYParser::RelExpContext *ctx) override;
std::any visitEqExp(SysYParser::EqExpContext *ctx) override;
std::any visitLAndExp(SysYParser::LAndExpContext *ctx) override;
std::any visitLOrExp(SysYParser::LOrExpContext *ctx) override;
// std::any visitConstExp(SysYParser::ConstExpContext *ctx) override;
}; // class SysYIRGenerator
} // namespace sysy

View File

@ -6,8 +6,7 @@ using namespace std;
#include "SysYLexer.h"
#include "SysYParser.h"
using namespace antlr4;
#include "ASTPrinter.h"
#include "Backend.h"
// #include "Backend.h"
#include "SysYIRGenerator.h"
#include "RISCv32Backend.h"
using namespace sysy;
@ -70,12 +69,6 @@ int main(int argc, char **argv) {
return EXIT_SUCCESS;
}
// pretty format the input file
if (argFormat) {
ASTPrinter printer;
printer.visitCompUnit(moduleAST);
return EXIT_SUCCESS;
}
// visit AST to generate IR
SysYIRGenerator generator;