diff --git a/TODO.md b/TODO.md new file mode 100644 index 0000000..aae2f4e --- /dev/null +++ b/TODO.md @@ -0,0 +1,87 @@ +要打通从SysY到RISC-V的完整编译流程,以下是必须实现的核心模块和关键步骤(按编译流程顺序)。在你们当前IR生成阶段,可以优先实现这些基础模块来快速获得可工作的RISC-V汇编输出: + +### 1. **前端必须模块** +- **词法/语法分析**(已完成): + - `SysYLexer`/`SysYParser`:ANTLR生成的解析器 +- **IR生成核心**: + - `SysYIRGenerator`:将AST转换为中间表示(IR) + - `IRBuilder`:构建指令和基本块的工具类(你们正在实现的部分) + +### 2. **中端必要优化(最小集合)** +常量传播 +| 优化阶段 | 关键作用 | 是否必须 | +|-------------------|----------------------------------|----------| +| `Mem2Reg` | 消除冗余内存访问,转换为SSA形式 | ✅ 核心 | +| `DCE` (死代码消除) | 移除无用指令 | ✅ 必要 | +| `DFE` (死函数消除) | 移除未使用的函数 | ✅ 必要 | +| `FuncAnalysis` | 函数调用关系分析 | ✅ 基础 | +| `Global2Local` | 全局变量降级为局部变量 | ✅ 重要 | + +### 3. **后端核心流程(必须实现)** +```mermaid +graph LR + A[IR指令选择] --> B[寄存器分配] + B --> C[指令调度] + C --> D[汇编生成] +``` + +1. **指令选择**(关键步骤): + - `DAGBuilder`:将IR转换为有向无环图(DAG) + - `DAGCoverage`:DAG到目标指令的映射 + - `Mid2End`:IR到机器指令的转换接口 + +2. **寄存器分配**: + - `RegisterAlloc`:基础寄存器分配器(可先实现简单算法如线性扫描) + +3. **汇编生成**: + - `RiscvPrinter`:将机器指令输出为RISC-V汇编 + - 实现基础指令集:`add`/`sub`/`lw`/`sw`/`beq`/`jal`等 + +### 4. **最小可工作流程** +```cpp +// 精简版编译流程(跳过复杂优化) +int main() { + // 1. 前端解析 + auto module = sysy::SysYIRGenerator().genIR(input); + + // 2. 关键中端优化 + sysy::Mem2Reg(module).run(); // 必须 + sysy::Global2Local(module).run(); // 必须 + sysy::DCE(module).run(); // 推荐 + + // 3. 后端代码生成 + auto backendModule = mid2end::CodeGenerater().run(module); + riscv::RiscvPrinter().print("output.s", backendModule); +} +``` + +### 5. **当前开发优先级建议** +1. **完成IR生成**: + - 确保能构建基本块、函数、算术/内存/控制流指令 + - 实现`createCall`/`createLoad`/`createStore`等核心方法 + +2. **实现Mem2Reg**: + - 插入Phi节点 + - 变量重命名(关键算法) + +3. **构建基础后端**: + - 指令选择:实现IR到RISC-V的简单映射(例如:`IRAdd` → `add`) + - 寄存器分配:使用无限寄存器方案(后期替换为真实分配) + - 汇编打印:支持基础指令输出 + +> **注意**:循环优化、函数内联、高级寄存器分配等可在基础流程打通后逐步添加。初期可跳过复杂优化。 + +### 6. 调试建议 +- 添加IR打印模块(`SysYPrinter`)验证前端输出 +- 使用简化测试用例: + ```c + int main() { + int a = 1; + int b = a + 2; + return b; + } + ``` +- 逐步扩展支持: + 1. 算术运算 → 2. 条件分支 → 3. 函数调用 → 4. 数组访问 + +通过聚焦这些核心模块,你们可以快速打通从SysY到RISC-V的基础编译流程,后续再逐步添加优化传递提升代码质量。 \ No newline at end of file diff --git a/olddef.h b/olddef.h new file mode 100644 index 0000000..c7b9522 --- /dev/null +++ b/olddef.h @@ -0,0 +1,62 @@ + +class SymbolTable{ +private: + enum Kind + { + kModule, + kFunction, + kBlock, + }; + + std::forward_list>> Scopes; + +public: + struct ModuleScope { + SymbolTable& tables_ref; + ModuleScope(SymbolTable& tables) : tables_ref(tables) { + tables.enter(kModule); + } + ~ModuleScope() { tables_ref.exit(); } + }; + struct FunctionScope { + SymbolTable& tables_ref; + FunctionScope(SymbolTable& tables) : tables_ref(tables) { + tables.enter(kFunction); + } + ~FunctionScope() { tables_ref.exit(); } + }; + struct BlockScope { + SymbolTable& tables_ref; + BlockScope(SymbolTable& tables) : tables_ref(tables) { + tables.enter(kBlock); + } + ~BlockScope() { tables_ref.exit(); } + }; + + SymbolTable() = default; + + bool isModuleScope() const { return Scopes.front().first == kModule; } + bool isFunctionScope() const { return Scopes.front().first == kFunction; } + bool isBlockScope() const { return Scopes.front().first == kBlock; } + Value *lookup(const std::string &name) const { + for (auto &scope : Scopes) { + auto iter = scope.second.find(name); + if (iter != scope.second.end()) + return iter->second; + } + return nullptr; + } + auto insert(const std::string &name, Value *value) { + assert(not Scopes.empty()); + return Scopes.front().second.emplace(name, value); + } +private: + void enter(Kind kind) { + Scopes.emplace_front(); + Scopes.front().first = kind; + } + void exit() { + Scopes.pop_front(); + } + +}; \ No newline at end of file diff --git a/src/ASTPrinter.cpp b/src/ASTPrinter.cpp deleted file mode 100644 index a29383e..0000000 --- a/src/ASTPrinter.cpp +++ /dev/null @@ -1,355 +0,0 @@ -#include -#include -using namespace std; -#include "ASTPrinter.h" -#include "SysYParser.h" - - -any ASTPrinter::visitCompUnit(SysYParser::CompUnitContext *ctx) { - if(ctx->decl().empty() && ctx->funcDef().empty()) - return nullptr; - for (auto dcl : ctx->decl()) {dcl->accept(this);cout << '\n';}cout << '\n'; - for (auto func : ctx->funcDef()) {func->accept(this);cout << "\n";} - return nullptr; -} - -// std::any ASTPrinter::visitBType(SysYParser::BTypeContext *ctx); -// std::any ASTPrinter::visitDecl(SysYParser::DeclContext *ctx); - -std::any ASTPrinter::visitConstDecl(SysYParser::ConstDeclContext *ctx) { - cout << getIndent() << ctx->CONST()->getText() << ' ' << ctx->bType()->getText() << ' '; - auto numConstDefs = ctx->constDef().size(); - ctx->constDef(0)->accept(this); - for (int i = 1; i < numConstDefs; ++i) { - cout << ctx->COMMA(i - 1)->getText() << ' '; - ctx->constDef(i)->accept(this); - } - cout << ctx->SEMICOLON()->getText() << '\n'; - return nullptr; -} - -std::any ASTPrinter::visitConstDef(SysYParser::ConstDefContext *ctx) { - cout << ctx->Ident()->getText(); - auto numConstExps = ctx->constExp().size(); - for (int i = 0; i < numConstExps; ++i) { - cout << ctx->LBRACK(i)->getText(); - ctx->constExp(i)->accept(this); - cout << ctx->RBRACK(i)->getText(); - } - cout << ' ' << ctx->ASSIGN()->getText() << ' '; - ctx->constInitVal()->accept(this); - return nullptr; -} - -// std::any ASTPrinter::visitConstInitVal(SysYParser::ConstInitValContext *ctx); - -std::any ASTPrinter::visitVarDecl(SysYParser::VarDeclContext *ctx){ - cout << getIndent() << ctx->bType()->getText() << ' '; - auto numVarDefs = ctx->varDef().size(); - ctx->varDef(0)->accept(this); - for (int i = 1; i < numVarDefs; ++i) { - cout << ", "; - ctx->varDef(i)->accept(this); - } - cout << ctx->SEMICOLON()->getText() << '\n'; - return nullptr; -} - -std::any ASTPrinter::visitVarDef(SysYParser::VarDefContext *ctx){ - cout << ctx->Ident()->getText(); - auto numConstExps = ctx->constExp().size(); - for (int i = 0; i < numConstExps; ++i) { - cout << ctx->LBRACK(i)->getText(); - ctx->constExp(i)->accept(this); - cout << ctx->RBRACK(i)->getText(); - } - if (ctx->initVal()) { - cout << ' ' << ctx->ASSIGN()->getText() << ' '; - ctx->initVal()->accept(this); - } - return nullptr; -} - -std::any ASTPrinter::visitInitVal(SysYParser::InitValContext *ctx){ - if (ctx->exp()) { - ctx->exp()->accept(this); - } else { - cout << ctx->LBRACE()->getText(); - auto numInitVals = ctx->initVal().size(); - ctx->initVal(0)->accept(this); - for (int i = 1; i < numInitVals; ++i) { - cout << ctx->COMMA(i - 1)->getText() << ' '; - ctx->initVal(i)->accept(this); - } - cout << ctx->RBRACE()->getText(); - } - return nullptr; -} - -std::any ASTPrinter::visitFuncDef(SysYParser::FuncDefContext *ctx){ - cout << getIndent() << ctx->funcType()->getText() << ' ' << ctx->Ident()->getText(); - cout << ctx->LPAREN()->getText(); - if (ctx->funcFParams()) ctx->funcFParams()->accept(this); - if(ctx->RPAREN()) - cout << ctx->RPAREN()->getText(); - else - cout << ""; - ctx->blockStmt()->accept(this); - return nullptr; -} - -// std::any ASTPrinter::visitFuncType(SysYParser::FuncTypeContext *ctx); - -std::any ASTPrinter::visitFuncFParams(SysYParser::FuncFParamsContext *ctx){ - auto numFuncFParams = ctx->funcFParam().size(); - ctx->funcFParam(0)->accept(this); - for (int i = 1; i < numFuncFParams; ++i) { - cout << ctx->COMMA(i - 1)->getText() << ' '; - ctx->funcFParam(i)->accept(this); - } - return nullptr; -} - -std::any ASTPrinter::visitFuncFParam(SysYParser::FuncFParamContext *ctx){ - cout << ctx->bType()->getText() << ' ' << ctx->Ident()->getText(); - if (!ctx->exp().empty()) { - cout << "[]"; - for (auto exp : ctx->exp()) { - cout << '['; - exp->accept(this); - cout << ']'; - } - } - return nullptr; -} - -std::any ASTPrinter::visitBlockStmt(SysYParser::BlockStmtContext *ctx){ - cout << ctx->LBRACE()->getText() << endl; - indentLevel++; - for (auto item : ctx->blockItem()) item->accept(this); - indentLevel--; - cout << getIndent() << ctx->RBRACE()->getText() << endl; - return nullptr; -} -// std::any ASTPrinter::visitBlockItem(SysYParser::BlockItemContext *ctx); - -std::any ASTPrinter::visitAssignStmt(SysYParser::AssignStmtContext *ctx){ - cout << getIndent(); - ctx->lValue()->accept(this); - cout << ' ' << ctx->ASSIGN()->getText() << ' '; - ctx->exp()->accept(this); - cout << ctx->SEMICOLON()->getText() << '\n'; - return nullptr; -} - -std::any ASTPrinter::visitExpStmt(SysYParser::ExpStmtContext *ctx){ - cout << getIndent(); - if (ctx->exp()) { - ctx->exp()->accept(this); - } - cout << ctx->SEMICOLON()->getText() << '\n'; - return nullptr; -} - -std::any ASTPrinter::visitIfStmt(SysYParser::IfStmtContext *ctx){ - cout << getIndent() << ctx->IF()->getText() << ' ' << ctx->LPAREN()->getText(); - ctx->cond()->accept(this); - cout << ctx->RPAREN()->getText() << ' '; - //格式化有问题 - if(ctx->stmt(0)) { - ctx->stmt(0)->accept(this); - } - else { - cout << '{' << endl; - indentLevel++; - ctx->stmt(0)->accept(this); - indentLevel--; - cout << getIndent() << '}' << endl; - } - if (ctx->ELSE()) { - cout << getIndent() << ctx->ELSE()->getText() << ' '; - ctx->stmt(1)->accept(this); - } - return nullptr; -} - -std::any ASTPrinter::visitWhileStmt(SysYParser::WhileStmtContext *ctx){ - cout << getIndent() << ctx->WHILE()->getText() << ' ' << ctx->LPAREN()->getText(); - ctx->cond()->accept(this); - cout << ctx->RPAREN()->getText() << ' '; - ctx->stmt()->accept(this); - return nullptr; -} - -std::any ASTPrinter::visitBreakStmt(SysYParser::BreakStmtContext *ctx){ - cout << getIndent() << ctx->BREAK()->getText() << ctx->SEMICOLON()->getText() << '\n'; - return nullptr; -} - -std::any ASTPrinter::visitContinueStmt(SysYParser::ContinueStmtContext *ctx){ - cout << getIndent() << ctx->CONTINUE()->getText() << ctx->SEMICOLON()->getText() << '\n'; - return nullptr; -} - -std::any ASTPrinter::visitReturnStmt(SysYParser::ReturnStmtContext *ctx){ - cout << getIndent() << ctx->RETURN()->getText() << ' '; - if (ctx->exp()) { - ctx->exp()->accept(this); - } - cout << ctx->SEMICOLON()->getText() << '\n'; - return nullptr; -} - -// std::any ASTPrinter::visitExp(SysYParser::ExpContext *ctx); -// std::any ASTPrinter::visitCond(SysYParser::CondContext *ctx); -std::any ASTPrinter::visitLValue(SysYParser::LValueContext *ctx){ - cout << ctx->Ident()->getText(); - for (auto exp : ctx->exp()) { - cout << "["; - exp->accept(this); - cout << "]"; - } - return nullptr; -} -// std::any ASTPrinter::visitPrimaryExp(SysYParser::PrimaryExpContext *ctx); -std::any ASTPrinter::visitParenExp(SysYParser::ParenExpContext *ctx){ - cout << ctx->LPAREN()->getText(); - ctx->exp()->accept(this); - cout << ctx->RPAREN()->getText(); - return nullptr; -} - -std::any ASTPrinter::visitNumber(SysYParser::NumberContext *ctx) { - if(ctx->ILITERAL())cout << ctx->ILITERAL()->getText(); - if(ctx->FLITERAL())cout << ctx->FLITERAL()->getText(); - return nullptr; -} - -std::any ASTPrinter::visitString(SysYParser::StringContext *ctx) { - cout << ctx->STRING()->getText(); - return nullptr; -} -// std::any visitUnaryExp(SysYParser::UnaryExpContext *ctx); -// std::any ASTPrinter::visitUnaryOp(SysYParser::UnaryOpContext *ctx); -std::any ASTPrinter::visitCall(SysYParser::CallContext *ctx){ - cout << ctx->Ident()->getText() << ctx->LPAREN()->getText(); - if(ctx->funcRParams()) - ctx->funcRParams()->accept(this); - cout << ctx->RPAREN()->getText(); - return nullptr; -} - -any ASTPrinter::visitFuncRParams(SysYParser::FuncRParamsContext *ctx) { - if (ctx->exp().empty()) - return nullptr; - auto numParams = ctx->exp().size(); - ctx->exp(0)->accept(this); - for (int i = 1; i < numParams; ++i) { - cout << ctx->COMMA(i - 1)->getText() << ' '; - ctx->exp(i)->accept(this); - } - return nullptr; -} - -std::any ASTPrinter::visitMulExp(SysYParser::MulExpContext *ctx){ - auto unaryExps = ctx->unaryExp(); - if (unaryExps.size() == 1) { - unaryExps[0]->accept(this); - } else { - for (size_t i = 0; i < unaryExps.size() - 1; ++i) { - auto opNode = dynamic_cast(ctx->children[2 * i + 1]); - if (opNode) { - unaryExps[i]->accept(this); - cout << " " << opNode->getText() << " "; - } - } - unaryExps.back()->accept(this); - } - return nullptr; -} - -std::any ASTPrinter::visitAddExp(SysYParser::AddExpContext *ctx){ - auto mulExps = ctx->mulExp(); - if (mulExps.size() == 1) { - mulExps[0]->accept(this); - } else { - for (size_t i = 0; i < mulExps.size() - 1; ++i) { - auto opNode = dynamic_cast(ctx->children[2 * i + 1]); - if (opNode) { - mulExps[i]->accept(this); - cout << " " << opNode->getText() << " "; - } - } - mulExps.back()->accept(this); - } - return nullptr; -} -// 以下表达式待补全形式同addexp mulexp -std::any ASTPrinter::visitRelExp(SysYParser::RelExpContext *ctx){ - auto relExps = ctx->addExp(); - if (relExps.size() == 1) { - relExps[0]->accept(this); - } else { - for (size_t i = 0; i < relExps.size() - 1; ++i) { - auto opNode = dynamic_cast(ctx->children[2 * i + 1]); - if (opNode) { - relExps[i]->accept(this); - cout << " " << opNode->getText() << " "; - } - } - relExps.back()->accept(this); - } - return nullptr; -} -std::any ASTPrinter::visitEqExp(SysYParser::EqExpContext *ctx){ - auto eqExps = ctx->relExp(); - if (eqExps.size() == 1) { - eqExps[0]->accept(this); - } else { - for (size_t i = 0; i < eqExps.size() - 1; ++i) { - auto opNode = dynamic_cast(ctx->children[2 * i + 1]); - if (opNode) { - eqExps[i]->accept(this); - cout << " " << opNode->getText() << " "; - } - } - eqExps.back()->accept(this); - } - return nullptr; -} -std::any ASTPrinter::visitLAndExp(SysYParser::LAndExpContext *ctx){ - auto lAndExps = ctx->eqExp(); - if (lAndExps.size() == 1) { - lAndExps[0]->accept(this); - } else { - for (size_t i = 0; i < lAndExps.size() - 1; ++i) { - auto opNode = dynamic_cast(ctx->children[2 * i + 1]); - if (opNode) { - lAndExps[i]->accept(this); - cout << " " << opNode->getText() << " "; - } - } - lAndExps.back()->accept(this); - } - return nullptr; -} -std::any ASTPrinter::visitLOrExp(SysYParser::LOrExpContext *ctx){ - auto lOrExps = ctx->lAndExp(); - if (lOrExps.size() == 1) { - lOrExps[0]->accept(this); - } else { - for (size_t i = 0; i < lOrExps.size() - 1; ++i) { - auto opNode = dynamic_cast(ctx->children[2 * i + 1]); - if (opNode) { - lOrExps[i]->accept(this); - cout << " " << opNode->getText() << " "; - } - } - lOrExps.back()->accept(this); - } - return nullptr; -} -std::any ASTPrinter::visitConstExp(SysYParser::ConstExpContext *ctx){ - ctx->addExp()->accept(this); - return nullptr; -} \ No newline at end of file diff --git a/src/ASTPrinter.h b/src/ASTPrinter.h deleted file mode 100644 index 31b4863..0000000 --- a/src/ASTPrinter.h +++ /dev/null @@ -1,59 +0,0 @@ -#pragma once - -#include "SysYBaseVisitor.h" -#include "SysYParser.h" - -class ASTPrinter : public SysYBaseVisitor { -private: - int indentLevel = 0; - - std::string getIndent() { - return std::string(indentLevel * 4, ' '); - } -public: - std::any visitCompUnit(SysYParser::CompUnitContext *ctx) override; - // std::any visitBType(SysYParser::BTypeContext *ctx) override; - // std::any visitDecl(SysYParser::DeclContext *ctx) override; - std::any visitConstDecl(SysYParser::ConstDeclContext *ctx) override; - std::any visitConstDef(SysYParser::ConstDefContext *ctx) override; - // std::any visitConstInitVal(SysYParser::ConstInitValContext *ctx) override; - std::any visitVarDecl(SysYParser::VarDeclContext *ctx) override; - std::any visitVarDef(SysYParser::VarDefContext *ctx) override; - std::any visitInitVal(SysYParser::InitValContext *ctx) override; - std::any visitFuncDef(SysYParser::FuncDefContext *ctx) override; - // std::any visitFuncType(SysYParser::FuncTypeContext *ctx) override; - std::any visitFuncFParams(SysYParser::FuncFParamsContext *ctx) override; - std::any visitFuncFParam(SysYParser::FuncFParamContext *ctx) override; - std::any visitBlockStmt(SysYParser::BlockStmtContext *ctx) override; - - // std::any visitBlockItem(SysYParser::BlockItemContext *ctx) override; - // std::any visitStmt(SysYParser::StmtContext *ctx) override; - - std::any visitAssignStmt(SysYParser::AssignStmtContext *ctx) override; - std::any visitExpStmt(SysYParser::ExpStmtContext *ctx) override; - std::any visitIfStmt(SysYParser::IfStmtContext *ctx) override; - std::any visitWhileStmt(SysYParser::WhileStmtContext *ctx) override; - std::any visitBreakStmt(SysYParser::BreakStmtContext *ctx) override; - std::any visitContinueStmt(SysYParser::ContinueStmtContext *ctx) override; - std::any visitReturnStmt(SysYParser::ReturnStmtContext *ctx) override; - - // std::any visitExp(SysYParser::ExpContext *ctx) override; - // std::any visitCond(SysYParser::CondContext *ctx) override; - std::any visitLValue(SysYParser::LValueContext *ctx) override; - // std::any visitPrimaryExp(SysYParser::PrimaryExpContext *ctx) override; - std::any visitParenExp(SysYParser::ParenExpContext *ctx) override; - std::any visitNumber(SysYParser::NumberContext *ctx) override; - std::any visitString(SysYParser::StringContext *ctx) override; - // std::any visitUnaryExp(SysYParser::UnaryExpContext *ctx) override; - std::any visitCall(SysYParser::CallContext *ctx) override; - // std::any visitUnExpOp(SysYParser::UnExpContext *ctx) override; - // std::any visitUnaryOp(SysYParser::UnaryOpContext *ctx) override; - std::any visitFuncRParams(SysYParser::FuncRParamsContext *ctx) override; - std::any visitMulExp(SysYParser::MulExpContext *ctx) override; - std::any visitAddExp(SysYParser::AddExpContext *ctx) override; - std::any visitRelExp(SysYParser::RelExpContext *ctx) override; - std::any visitEqExp(SysYParser::EqExpContext *ctx) override; - std::any visitLAndExp(SysYParser::LAndExpContext *ctx) override; - std::any visitLOrExp(SysYParser::LOrExpContext *ctx) override; - std::any visitConstExp(SysYParser::ConstExpContext *ctx) override; -}; diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index 6594390..8249b81 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -13,13 +13,12 @@ target_link_libraries(SysYParser PUBLIC antlr4_shared) add_executable(sysyc sysyc.cpp - ASTPrinter.cpp IR.cpp SysYIRGenerator.cpp Backend.cpp RISCv32Backend.cpp ) -target_include_directories(sysyc PRIVATE ${CMAKE_CURRENT_BINARY_DIR}) +target_include_directories(sysyc PRIVATE ${CMAKE_CURRENT_BINARY_DIR} ${CMAKE_CURRENT_SOURCE_DIR}/include) target_compile_options(sysyc PRIVATE -frtti) target_link_libraries(sysyc PRIVATE SysYParser) diff --git a/src/IR.cpp b/src/IR.cpp index 318b8c8..ebd61c1 100644 --- a/src/IR.cpp +++ b/src/IR.cpp @@ -1,154 +1,94 @@ #include "IR.h" -#include "range.h" #include #include -#include -#include -#include -#include -#include #include -#include +#include #include -#include -#include +#include #include -using namespace std; +#include "IRBuilder.h" +/** + * @file IR.cpp + * + * @brief 定义IR相关类型与操作的源文件 + */ namespace sysy { -template -ostream &interleave(ostream &os, const T &container, const string sep = ", ") { - auto b = container.begin(), e = container.end(); - if (b == e) - return os; - os << *b; - for (b = next(b); b != e; b = next(b)) - os << sep << *b; - return os; -} -static inline ostream &printVarName(ostream &os, const Value *var) { - return os << (dyncast(var) ? '@' : '%') - << var->getName(); -} -static inline ostream &printBlockName(ostream &os, const BasicBlock *block) { - return os << '^' << block->getName(); -} -static inline ostream &printFunctionName(ostream &os, const Function *fn) { - return os << '@' << fn->getName(); -} -static inline ostream &printOperand(ostream &os, const Value *value) { - auto constant = dyncast(value); - if (constant) { - constant->print(os); - return os; - } - return printVarName(os, value); -} //===----------------------------------------------------------------------===// // Types //===----------------------------------------------------------------------===// -Type *Type::getIntType() { +auto Type::getIntType() -> Type * { static Type intType(kInt); return &intType; } -Type *Type::getFloatType() { +auto Type::getFloatType() -> Type * { static Type floatType(kFloat); return &floatType; } -Type *Type::getVoidType() { +auto Type::getVoidType() -> Type * { static Type voidType(kVoid); return &voidType; } -Type *Type::getLabelType() { +auto Type::getLabelType() -> Type * { static Type labelType(kLabel); return &labelType; } -Type *Type::getPointerType(Type *baseType) { +auto Type::getPointerType(Type *baseType) -> Type * { // forward to PointerType return PointerType::get(baseType); } -Type *Type::getFunctionType(Type *returnType, - const vector ¶mTypes) { +auto Type::getFunctionType(Type *returnType, const std::vector ¶mTypes) -> Type * { // forward to FunctionType return FunctionType::get(returnType, paramTypes); } -int Type::getSize() const { +auto Type::getSize() const -> unsigned { switch (kind) { - case kInt: - case kFloat: - return 4; - case kLabel: - case kPointer: - case kFunction: - return 8; - case kVoid: - return 0; + case kInt: + case kFloat: + return 4; + case kLabel: + case kPointer: + case kFunction: + return 8; + case kVoid: + return 0; } return 0; } -void Type::print(ostream &os) const { - auto kind = getKind(); - switch (kind) { - case kInt: - os << "int"; - break; - case kFloat: - os << "float"; - break; - case kVoid: - os << "void"; - break; - case kPointer: - static_cast(this)->getBaseType()->print(os); - os << "*"; - break; - case kFunction: - static_cast(this)->getReturnType()->print(os); - os << "("; - interleave(os, static_cast(this)->getParamTypes()); - os << ")"; - break; - case kLabel: - default: - cerr << "Unexpected type!\n"; - break; - } -} - -PointerType *PointerType::get(Type *baseType) { +PointerType* PointerType::get(Type *baseType) { static std::map> pointerTypes; auto iter = pointerTypes.find(baseType); - if (iter != pointerTypes.end()) + if (iter != pointerTypes.end()) { return iter->second.get(); + } auto type = new PointerType(baseType); assert(type); auto result = pointerTypes.emplace(baseType, type); return result.first->second.get(); } -FunctionType *FunctionType::get(Type *returnType, - const std::vector ¶mTypes) { +FunctionType*FunctionType::get(Type *returnType, const std::vector ¶mTypes) { static std::set> functionTypes; auto iter = - std::find_if(functionTypes.begin(), functionTypes.end(), - [&](const std::unique_ptr &type) -> bool { - if (returnType != type->getReturnType() or - paramTypes.size() != type->getParamTypes().size()) - return false; - return std::equal(paramTypes.begin(), paramTypes.end(), - type->getParamTypes().begin()); - }); - if (iter != functionTypes.end()) + std::find_if(functionTypes.begin(), functionTypes.end(), [&](const std::unique_ptr &type) -> bool { + if (returnType != type->getReturnType() || + paramTypes.size() != static_cast(type->getParamTypes().size())) { + return false; + } + return std::equal(paramTypes.begin(), paramTypes.end(), type->getParamTypes().begin()); + }); + if (iter != functionTypes.end()) { return iter->get(); + } auto type = new FunctionType(returnType, paramTypes); assert(type); auto result = functionTypes.emplace(type); @@ -156,376 +96,624 @@ FunctionType *FunctionType::get(Type *returnType, } void Value::replaceAllUsesWith(Value *value) { - for (auto &use : uses) + for (auto &use : uses) { use->getUser()->setOperand(use->getIndex(), value); + } uses.clear(); } -bool Value::isConstant() const { - if (dyncast(this)) - return true; - if (dyncast(this) or - dyncast(this)) - return true; - // if (auto array = dyncast(this)) { - // auto elements = array->getValues(); - // return all_of(elements.begin(), elements.end(), - // [](Value *v) -> bool { return v->isConstant(); }); - // } - return false; -} - -ConstantValue *ConstantValue::get(int value) { +ConstantValue* ConstantValue::get(int value) { static std::map> intConstants; auto iter = intConstants.find(value); - if (iter != intConstants.end()) + if (iter != intConstants.end()) { return iter->second.get(); - auto constant = new ConstantValue(value); - assert(constant); - auto result = intConstants.emplace(value, constant); + } + auto inst = new ConstantValue(value); + assert(inst); + auto result = intConstants.emplace(value, inst); return result.first->second.get(); } -ConstantValue *ConstantValue::get(float value) { +ConstantValue* ConstantValue::get(float value) { static std::map> floatConstants; auto iter = floatConstants.find(value); - if (iter != floatConstants.end()) + if (iter != floatConstants.end()) { return iter->second.get(); - auto constant = new ConstantValue(value); - assert(constant); - auto result = floatConstants.emplace(value, constant); + } + auto inst = new ConstantValue(value); + assert(inst); + auto result = floatConstants.emplace(value, inst); return result.first->second.get(); } -void ConstantValue::print(ostream &os) const { - if (isInt()) - os << getInt(); - else - os << getFloat(); -} - -Argument::Argument(Type *type, BasicBlock *block, int index, - const std::string &name) - : Value(kArgument, type, name), block(block), index(index) { - if (not hasName()) - setName(to_string(block->getParent()->allocateVariableID())); -} - -void Argument::print(std::ostream &os) const { - assert(hasName()); - printVarName(os, this) << ": " << *getType(); -} - -BasicBlock::BasicBlock(Function *parent, const std::string &name) - : Value(kBasicBlock, Type::getLabelType(), name), parent(parent), - instructions(), arguments(), successors(), predecessors() { - if (not hasName()) - setName("bb" + to_string(getParent()->allocateblockID())); -} - -void BasicBlock::print(std::ostream &os) const { - assert(hasName()); - os << " "; - printBlockName(os, this); - auto args = getArguments(); - auto b = args.begin(), e = args.end(); - if (b != e) { - os << '('; - printVarName(os, b->get()) << ": " << *b->get()->getType(); - for (auto &arg : make_range(std::next(b), e)) { - os << ", "; - printVarName(os, arg.get()) << ": " << *arg->getType(); - } - os << ')'; - } - os << ":\n"; - for (auto &inst : instructions) { - os << " " << *inst << '\n'; - } -} - -Instruction::Instruction(Kind kind, Type *type, BasicBlock *parent, - const std::string &name) - : User(kind, type, name), kind(kind), parent(parent) { - if (not type->isVoid() and not hasName()) - setName(to_string(getFunction()->allocateVariableID())); -} - -void CallInst::print(std::ostream &os) const { - if (not getType()->isVoid()) - printVarName(os, this) << " = call "; - printFunctionName(os, getCallee()) << '('; - auto args = getArguments(); - auto b = args.begin(), e = args.end(); - if (b != e) { - printOperand(os, *b); - for (auto arg : make_range(std::next(b), e)) { - os << ", "; - printOperand(os, arg); +auto Function::getCalleesWithNoExternalAndSelf() -> std::set { + std::set result; + for (auto callee : callees) { + if (parent->getExternalFunctions().count(callee->getName()) == 0U && callee != this) { + result.insert(callee); } } - os << ") : " << *getType(); + return result; } -void UnaryInst::print(std::ostream &os) const { - printVarName(os, this) << " = "; - switch (getKind()) { - case kNeg: - os << "neg"; - break; - case kNot: - os << "not"; - break; - case kFNeg: - os << "fneg"; - break; - case kFtoI: - os << "ftoi"; - break; - case kIToF: - os << "itof"; - break; - default: - assert(false); +Function * Function::clone(const std::string &suffix) const { + std::stringstream ss; + std::map oldNewBlockMap; + IRBuilder builder; + auto newFunction = new Function(parent, type, name); + newFunction->getEntryBlock()->setName(blocks.front()->getName()); + oldNewBlockMap.emplace(blocks.front().get(), newFunction->getEntryBlock()); + auto oldBlockListIter = std::next(blocks.begin()); + while (oldBlockListIter != blocks.end()) { + auto newBlock = newFunction->addBasicBlock(oldBlockListIter->get()->getName()); + oldNewBlockMap.emplace(oldBlockListIter->get(), newBlock); + oldBlockListIter++; } - printOperand(os, getOperand()) << " : " << *getType(); -} -void BinaryInst::print(std::ostream &os) const { - printVarName(os, this) << " = "; - switch (getKind()) { - case kAdd: - os << "add"; - break; - case kSub: - os << "sub"; - break; - case kMul: - os << "mul"; - break; - case kDiv: - os << "div"; - break; - case kRem: - os << "rem"; - break; - case kICmpEQ: - os << "icmpeq"; - break; - case kICmpNE: - os << "icmpne"; - break; - case kICmpLT: - os << "icmplt"; - break; - case kICmpGT: - os << "icmpgt"; - break; - case kICmpLE: - os << "icmple"; - break; - case kICmpGE: - os << "icmpge"; - break; - case kFAdd: - os << "fadd"; - break; - case kFSub: - os << "fsub"; - break; - case kFMul: - os << "fmul"; - break; - case kFDiv: - os << "fdiv"; - break; - case kFRem: - os << "frem"; - break; - case kFCmpEQ: - os << "fcmpeq"; - break; - case kFCmpNE: - os << "fcmpne"; - break; - case kFCmpLT: - os << "fcmplt"; - break; - case kFCmpGT: - os << "fcmpgt"; - break; - case kFCmpLE: - os << "fcmple"; - break; - case kFCmpGE: - os << "fcmpge"; - break; - default: - assert(false); - } - os << ' '; - printOperand(os, getLhs()) << ", "; - printOperand(os, getRhs()) << " : " << *getType(); -} - -void ReturnInst::print(std::ostream &os) const { - os << "return"; - if (auto value = getReturnValue()) { - os << ' '; - printOperand(os, value) << " : " << *value->getType(); - } -} - -void UncondBrInst::print(std::ostream &os) const { - os << "br "; - printBlockName(os, getBlock()); - auto args = getArguments(); - auto b = args.begin(), e = args.end(); - if (b != e) { - os << '('; - printOperand(os, *b); - for (auto arg : make_range(std::next(b), e)) { - os << ", "; - printOperand(os, arg); + for (const auto &oldNewBlockItem : oldNewBlockMap) { + auto oldBlock = oldNewBlockItem.first; + auto newBlock = oldNewBlockItem.second; + for (const auto &oldPred : oldBlock->getPredecessors()) { + newBlock->addPredecessor(oldNewBlockMap.at(oldPred)); + } + for (const auto &oldSucc : oldBlock->getSuccessors()) { + newBlock->addSuccessor(oldNewBlockMap.at(oldSucc)); } - os << ')'; } -} -void CondBrInst::print(std::ostream &os) const { - os << "condbr "; - printOperand(os, getCondition()) << ", "; - printBlockName(os, getThenBlock()); - { - auto args = getThenArguments(); - auto b = args.begin(), e = args.end(); - if (b != e) { - os << '('; - printOperand(os, *b); - for (auto arg : make_range(std::next(b), e)) { - os << ", "; - printOperand(os, arg); + std::map oldNewValueMap; + std::map isAddedToCreate; + std::map isCreated; + std::queue toCreate; + + for (const auto &oldBlock : blocks) { + for (const auto &inst : oldBlock->getInstructions()) { + isAddedToCreate.emplace(inst.get(), false); + isCreated.emplace(inst.get(), false); + } + } + for (const auto &oldBlock : blocks) { + for (const auto &inst : oldBlock->getInstructions()) { + for (const auto &valueUse : inst->getOperands()) { + auto value = valueUse->getValue(); + if (oldNewValueMap.find(value) == oldNewValueMap.end()) { + auto oldAllocInst = dynamic_cast(value); + if (oldAllocInst != nullptr) { + std::vector dims; + for (const auto &dim : oldAllocInst->getDims()) { + dims.emplace_back(dim->getValue()); + } + ss << oldAllocInst->getName() << suffix; + auto newAllocInst = + new AllocaInst(oldAllocInst->getType(), dims, oldNewBlockMap.at(oldAllocInst->getParent()), ss.str()); + ss.str(""); + oldNewValueMap.emplace(oldAllocInst, newAllocInst); + if (isAddedToCreate.find(oldAllocInst) == isAddedToCreate.end()) { + isAddedToCreate.emplace(oldAllocInst, true); + } else { + isAddedToCreate.at(oldAllocInst) = true; + } + if (isCreated.find(oldAllocInst) == isCreated.end()) { + isCreated.emplace(oldAllocInst, true); + } else { + isCreated.at(oldAllocInst) = true; + } + } + } } - os << ')'; - } - } - os << ", "; - printBlockName(os, getElseBlock()); - { - auto args = getElseArguments(); - auto b = args.begin(), e = args.end(); - if (b != e) { - os << '('; - printOperand(os, *b); - for (auto arg : make_range(std::next(b), e)) { - os << ", "; - printOperand(os, arg); + if (inst->getKind() == Instruction::kAlloca) { + if (oldNewValueMap.find(inst.get()) == oldNewValueMap.end()) { + auto oldAllocInst = dynamic_cast(inst.get()); + std::vector dims; + for (const auto &dim : oldAllocInst->getDims()) { + dims.emplace_back(dim->getValue()); + } + ss << oldAllocInst->getName() << suffix; + auto newAllocInst = + new AllocaInst(oldAllocInst->getType(), dims, oldNewBlockMap.at(oldAllocInst->getParent()), ss.str()); + ss.str(""); + oldNewValueMap.emplace(oldAllocInst, newAllocInst); + if (isAddedToCreate.find(oldAllocInst) == isAddedToCreate.end()) { + isAddedToCreate.emplace(oldAllocInst, true); + } else { + isAddedToCreate.at(oldAllocInst) = true; + } + if (isCreated.find(oldAllocInst) == isCreated.end()) { + isCreated.emplace(oldAllocInst, true); + } else { + isCreated.at(oldAllocInst) = true; + } + } } - os << ')'; } } -} - -void AllocaInst::print(std::ostream &os) const { - if (getNumDims()) - cerr << "not implemented yet\n"; - printVarName(os, this) << " = "; - os << "alloca " - << *static_cast(getType())->getBaseType(); - os << " : " << *getType(); -} - -void LoadInst::print(std::ostream &os) const { - if (getNumIndices()) - cerr << "not implemented yet\n"; - printVarName(os, this) << " = "; - os << "load "; - printOperand(os, getPointer()) << " : " << *getType(); -} - -void StoreInst::print(std::ostream &os) const { - if (getNumIndices()) - cerr << "not implemented yet\n"; - os << "store "; - printOperand(os, getValue()) << ", "; - printOperand(os, getPointer()) << " : " << *getValue()->getType(); -} - -void Function::print(std::ostream &os) const { - auto returnType = getReturnType(); - auto paramTypes = getParamTypes(); - os << *returnType << ' '; - printFunctionName(os, this) << '('; - auto b = paramTypes.begin(), e = paramTypes.end(); - if (b != e) { - os << *(*b); - for (auto type : make_range(std::next(b), e)) - os << ", " << *type; + for (const auto &oldBlock : blocks) { + for (const auto &inst : oldBlock->getInstructions()) { + for (const auto &valueUse : inst->getOperands()) { + auto value = valueUse->getValue(); + if (oldNewValueMap.find(value) == oldNewValueMap.end()) { + auto globalValue = dynamic_cast(value); + auto constVariable = dynamic_cast(value); + auto constantValue = dynamic_cast(value); + auto functionValue = dynamic_cast(value); + if (globalValue != nullptr || constantValue != nullptr || constVariable != nullptr || + functionValue != nullptr) { + if (functionValue == this) { + oldNewValueMap.emplace(value, newFunction); + } else { + oldNewValueMap.emplace(value, value); + } + isCreated.emplace(value, true); + isAddedToCreate.emplace(value, true); + } + } + } + } } - os << ") {\n"; - for (auto &bb : getBasicBlocks()) { - os << *bb << '\n'; + for (const auto &oldBlock : blocks) { + for (const auto &inst : oldBlock->getInstructions()) { + if (inst->getKind() != Instruction::kAlloca) { + bool isReady = true; + for (const auto &use : inst->getOperands()) { + auto value = use->getValue(); + if (dynamic_cast(value) == nullptr && !isCreated.at(value)) { + isReady = false; + break; + } + } + if (isReady) { + toCreate.push(inst.get()); + isAddedToCreate.at(inst.get()) = true; + } + } + } } - os << "}"; + + while (!toCreate.empty()) { + auto inst = dynamic_cast(toCreate.front()); + toCreate.pop(); + + bool isReady = true; + for (const auto &valueUse : inst->getOperands()) { + auto value = dynamic_cast(valueUse->getValue()); + if (value != nullptr && !isCreated.at(value)) { + isReady = false; + break; + } + } + + if (!isReady) { + toCreate.push(inst); + continue; + } + isCreated.at(inst) = true; + switch (inst->getKind()) { + case Instruction::kAdd: + case Instruction::kSub: + case Instruction::kMul: + case Instruction::kDiv: + case Instruction::kRem: + case Instruction::kICmpEQ: + case Instruction::kICmpNE: + case Instruction::kICmpLT: + case Instruction::kICmpGT: + case Instruction::kICmpLE: + case Instruction::kICmpGE: + case Instruction::kAnd: + case Instruction::kOr: + case Instruction::kFAdd: + case Instruction::kFSub: + case Instruction::kFMul: + case Instruction::kFDiv: + case Instruction::kFCmpEQ: + case Instruction::kFCmpNE: + case Instruction::kFCmpLT: + case Instruction::kFCmpGT: + case Instruction::kFCmpLE: + case Instruction::kFCmpGE: { + auto oldBinaryInst = dynamic_cast(inst); + auto lhs = oldBinaryInst->getLhs(); + auto rhs = oldBinaryInst->getRhs(); + Value *newLhs; + Value *newRhs; + newLhs = oldNewValueMap[lhs]; + newRhs = oldNewValueMap[rhs]; + ss << oldBinaryInst->getName() << suffix; + auto newBinaryInst = new BinaryInst(oldBinaryInst->getKind(), oldBinaryInst->getType(), newLhs, newRhs, + oldNewBlockMap.at(oldBinaryInst->getParent()), ss.str()); + ss.str(""); + oldNewValueMap.emplace(oldBinaryInst, newBinaryInst); + break; + } + + case Instruction::kNeg: + case Instruction::kNot: + case Instruction::kFNeg: + case Instruction::kFNot: + case Instruction::kItoF: + case Instruction::kFtoI: { + auto oldUnaryInst = dynamic_cast(inst); + auto hs = oldUnaryInst->getOperand(); + Value *newHs; + newHs = oldNewValueMap.at(hs); + ss << oldUnaryInst->getName() << suffix; + auto newUnaryInst = new UnaryInst(oldUnaryInst->getKind(), oldUnaryInst->getType(), newHs, + oldNewBlockMap.at(oldUnaryInst->getParent()), ss.str()); + ss.str(""); + oldNewValueMap.emplace(oldUnaryInst, newUnaryInst); + break; + } + + case Instruction::kCall: { + auto oldCallInst = dynamic_cast(inst); + std::vector newArgumnts; + for (const auto &arg : oldCallInst->getArguments()) { + newArgumnts.emplace_back(oldNewValueMap.at(arg->getValue())); + } + + ss << oldCallInst->getName() << suffix; + CallInst *newCallInst; + newCallInst = + new CallInst(oldCallInst->getCallee(), newArgumnts, oldNewBlockMap.at(oldCallInst->getParent()), ss.str()); + ss.str(""); + // if (oldCallInst->getCallee() != this) { + // newCallInst = new CallInst(oldCallInst->getCallee(), newArgumnts, + // oldNewBlockMap.at(oldCallInst->getParent()), + // oldCallInst->getName()); + // } else { + // newCallInst = new CallInst(newFunction, newArgumnts, oldNewBlockMap.at(oldCallInst->getParent()), + // oldCallInst->getName()); + // } + + oldNewValueMap.emplace(oldCallInst, newCallInst); + break; + } + + case Instruction::kCondBr: { + auto oldCondBrInst = dynamic_cast(inst); + auto oldCond = oldCondBrInst->getCondition(); + Value *newCond; + newCond = oldNewValueMap.at(oldCond); + auto newCondBrInst = new CondBrInst(newCond, oldNewBlockMap.at(oldCondBrInst->getThenBlock()), + oldNewBlockMap.at(oldCondBrInst->getElseBlock()), {}, {}, + oldNewBlockMap.at(oldCondBrInst->getParent())); + oldNewValueMap.emplace(oldCondBrInst, newCondBrInst); + break; + } + + case Instruction::kBr: { + auto oldBrInst = dynamic_cast(inst); + auto newBrInst = + new UncondBrInst(oldNewBlockMap.at(oldBrInst->getBlock()), {}, oldNewBlockMap.at(oldBrInst->getParent())); + oldNewValueMap.emplace(oldBrInst, newBrInst); + break; + } + + case Instruction::kReturn: { + auto oldReturnInst = dynamic_cast(inst); + auto oldRval = oldReturnInst->getReturnValue(); + Value *newRval = nullptr; + if (oldRval != nullptr) { + newRval = oldNewValueMap.at(oldRval); + } + auto newReturnInst = + new ReturnInst(newRval, oldNewBlockMap.at(oldReturnInst->getParent()), oldReturnInst->getName()); + oldNewValueMap.emplace(oldReturnInst, newReturnInst); + break; + } + + case Instruction::kAlloca: { + assert(false); + } + + case Instruction::kLoad: { + auto oldLoadInst = dynamic_cast(inst); + auto oldPointer = oldLoadInst->getPointer(); + Value *newPointer; + newPointer = oldNewValueMap.at(oldPointer); + + std::vector newIndices; + for (const auto &index : oldLoadInst->getIndices()) { + newIndices.emplace_back(oldNewValueMap.at(index->getValue())); + } + ss << oldLoadInst->getName() << suffix; + auto newLoadInst = new LoadInst(newPointer, newIndices, oldNewBlockMap.at(oldLoadInst->getParent()), ss.str()); + ss.str(""); + oldNewValueMap.emplace(oldLoadInst, newLoadInst); + break; + } + + case Instruction::kStore: { + auto oldStoreInst = dynamic_cast(inst); + auto oldPointer = oldStoreInst->getPointer(); + auto oldValue = oldStoreInst->getValue(); + Value *newPointer; + Value *newValue; + std::vector newIndices; + newPointer = oldNewValueMap.at(oldPointer); + newValue = oldNewValueMap.at(oldValue); + for (const auto &index : oldStoreInst->getIndices()) { + newIndices.emplace_back(oldNewValueMap.at(index->getValue())); + } + auto newStoreInst = new StoreInst(newValue, newPointer, newIndices, + oldNewBlockMap.at(oldStoreInst->getParent()), oldStoreInst->getName()); + oldNewValueMap.emplace(oldStoreInst, newStoreInst); + break; + } + + case Instruction::kLa: { + auto oldLaInst = dynamic_cast(inst); + auto oldPointer = oldLaInst->getPointer(); + Value *newPointer; + std::vector newIndices; + newPointer = oldNewValueMap.at(oldPointer); + + for (const auto &index : oldLaInst->getIndices()) { + newIndices.emplace_back(oldNewValueMap.at(index->getValue())); + } + ss << oldLaInst->getName() << suffix; + auto newLaInst = new LaInst(newPointer, newIndices, oldNewBlockMap.at(oldLaInst->getParent()), ss.str()); + ss.str(""); + oldNewValueMap.emplace(oldLaInst, newLaInst); + break; + } + + case Instruction::kGetSubArray: { + auto oldGetSubArrayInst = dynamic_cast(inst); + auto oldFather = oldGetSubArrayInst->getFatherArray(); + auto oldChild = oldGetSubArrayInst->getChildArray(); + Value *newFather; + Value *newChild; + std::vector newIndices; + newFather = oldNewValueMap.at(oldFather); + newChild = oldNewValueMap.at(oldChild); + + for (const auto &index : oldGetSubArrayInst->getIndices()) { + newIndices.emplace_back(oldNewValueMap.at(index->getValue())); + } + ss << oldGetSubArrayInst->getName() << suffix; + auto newGetSubArrayInst = + new GetSubArrayInst(dynamic_cast(newFather), dynamic_cast(newChild), newIndices, + oldNewBlockMap.at(oldGetSubArrayInst->getParent()), ss.str()); + ss.str(""); + oldNewValueMap.emplace(oldGetSubArrayInst, newGetSubArrayInst); + break; + } + + case Instruction::kMemset: { + auto oldMemsetInst = dynamic_cast(inst); + auto oldPointer = oldMemsetInst->getPointer(); + auto oldValue = oldMemsetInst->getValue(); + Value *newPointer; + Value *newValue; + newPointer = oldNewValueMap.at(oldPointer); + newValue = oldNewValueMap.at(oldValue); + + auto newMemsetInst = new MemsetInst(newPointer, oldMemsetInst->getBegin(), oldMemsetInst->getSize(), newValue, + oldNewBlockMap.at(oldMemsetInst->getParent()), oldMemsetInst->getName()); + oldNewValueMap.emplace(oldMemsetInst, newMemsetInst); + break; + } + + case Instruction::kInvalid: + case Instruction::kPhi: { + break; + } + + default: + assert(false); + } + for (const auto &userUse : inst->getUses()) { + auto user = userUse->getUser(); + if (!isAddedToCreate.at(user)) { + toCreate.push(user); + isAddedToCreate.at(user) = true; + } + } + } + + for (const auto &oldBlock : blocks) { + auto newBlock = oldNewBlockMap.at(oldBlock.get()); + builder.setPosition(newBlock, newBlock->end()); + for (const auto &inst : oldBlock->getInstructions()) { + builder.insertInst(dynamic_cast(oldNewValueMap.at(inst.get()))); + } + } + + for (const auto ¶m : blocks.front()->getArguments()) { + newFunction->getEntryBlock()->insertArgument(dynamic_cast(oldNewValueMap.at(param))); + } + + return newFunction; } - -void Module::print(std::ostream &os) const { - for (auto &value : children) - os << *value << '\n'; -} - -// ArrayValue *ArrayValue::get(Type *type, const vector &values) { -// static map, unique_ptr> arrayConstants; -// hash hasher; -// auto key = make_pair( -// type, hasher(string(reinterpret_cast(values.data()), -// values.size() * sizeof(Value *)))); - -// auto iter = arrayConstants.find(key); -// if (iter != arrayConstants.end()) -// return iter->second.get(); -// auto constant = new ArrayValue(type, values); -// assert(constant); -// auto result = arrayConstants.emplace(key, constant); -// return result.first->second.get(); -// } - -// ArrayValue *ArrayValue::get(const std::vector &values) { -// vector vals(values.size(), nullptr); -// std::transform(values.begin(), values.end(), vals.begin(), -// [](int v) { return ConstantValue::get(v); }); -// return get(Type::getIntType(), vals); -// } - -// ArrayValue *ArrayValue::get(const std::vector &values) { -// vector vals(values.size(), nullptr); -// std::transform(values.begin(), values.end(), vals.begin(), -// [](float v) { return ConstantValue::get(v); }); -// return get(Type::getFloatType(), vals); -// } - -void User::setOperand(int index, Value *value) { +/** + * @brief 设置操作数 + * + * @param [in] index 所要设置的操作数的位置 + * @param [in] value 所要设置成的value + * @return 无返回值 + */ +void User::setOperand(unsigned index, Value *value) { assert(index < getNumOperands()); - operands[index].setValue(value); + operands[index]->setValue(value); + value->addUse(operands[index]); } - -void User::replaceOperand(int index, Value *value) { +/** + * @brief 替换操作数 + * + * @param [in] index 所要替换的操作数的位置 + * @param [in] value 所要替换成的value + * @return 无返回值 + */ +void User::replaceOperand(unsigned index, Value *value) { assert(index < getNumOperands()); auto &use = operands[index]; - use.getValue()->removeUse(&use); - use.setValue(value); + use->getValue()->removeUse(use); + use->setValue(value); + value->addUse(use); } -CallInst::CallInst(Function *callee, const std::vector &args, - BasicBlock *parent, const std::string &name) +CallInst::CallInst(Function *callee, const std::vector &args, BasicBlock *parent, const std::string &name) : Instruction(kCall, callee->getReturnType(), parent, name) { addOperand(callee); - for (auto arg : args) + for (auto arg : args) { addOperand(arg); + } +} +/** + * @brief 获取被调用函数的指针 + * + * @return 被调用函数的指针 + */ +Function * CallInst::getCallee() const { return dynamic_cast(getOperand(0)); } + +/** + * @brief 获取变量指针 + * + * @param [in] name 变量名字 + * @return 变量指针 + */ +auto SymbolTable::getVariable(const std::string &name) const -> User * { + auto node = curNode; + while (node != nullptr) { + auto iter = node->varList.find(name); + if (iter != node->varList.end()) { + return iter->second; + } + node = node->pNode; + } + + return nullptr; +} +/** + * @brief 添加变量 + * + * @param [in] name 变量名字 + * @param [in] variable 变量指针 + * @return 变量指针 + */ +auto SymbolTable::addVariable(const std::string &name, User *variable) -> User * { + User *result = nullptr; + if (curNode != nullptr) { + std::stringstream ss; + auto iter = variableIndex.find(name); + if (iter != variableIndex.end()) { + ss << name << "(" << iter->second << ")"; + iter->second += 1; + } else { + variableIndex.emplace(name, 1); + ss << name << "(" << 0 << ")"; + } + + variable->setName(ss.str()); + curNode->varList.emplace(name, variable); + auto global = dynamic_cast(variable); + auto constvar = dynamic_cast(variable); + if (global != nullptr) { + globals.emplace_back(global); + } else if (constvar != nullptr) { + consts.emplace_back(constvar); + } + + result = variable; + } + + return result; +} +/** + * @brief 获取全局变量 + * + * @return 全局变量列表 + */ +auto SymbolTable::getGlobals() -> std::vector> & { return globals; } +/** + * @brief 获取常量 + * + * @return 常量列表 + */ +auto SymbolTable::getConsts() const -> const std::vector> & { return consts; } +/** + * @brief 进入新的作用域 + * + * @return 无返回值 + */ +void SymbolTable::enterNewScope() { + auto newNode = new SymbolTableNode; + nodeList.emplace_back(newNode); + if (curNode != nullptr) { + curNode->children.emplace_back(newNode); + } + newNode->pNode = curNode; + curNode = newNode; +} +/** + * @brief 进入全局作用域 + * + * @return 无返回值 + */ +void SymbolTable::enterGlobalScope() { curNode = nodeList.front().get(); } +/** + * @brief 离开作用域 + * + * @return 无返回值 + */ +void SymbolTable::leaveScope() { curNode = curNode->pNode; } +/** + * @brief 是否位于全局作用域 + * + * @return 布尔值 + */ +auto SymbolTable::isInGlobalScope() const -> bool { return curNode->pNode == nullptr; } + +/** + * @brief 判断是否为循环不变量 + * @param value: 要判断的value + * @return true: 是不变量 + * @return false: 不是 + */ +auto Loop::isSimpleLoopInvariant(Value *value) -> bool { + // auto constValue = dynamic_cast(value); + // if (constValue != nullptr) { + // return false; + // } + if (auto instr = dynamic_cast(value)) { + if (instr->isLoad()) { + auto loadinst = dynamic_cast(instr); + + auto loadvalue = dynamic_cast(loadinst->getOperand(0)); + if (loadvalue != nullptr) { + if (loadvalue->getParent() != nullptr) { + auto basicblock = loadvalue->getParent(); + return !this->isLoopContainsBasicBlock(basicblock); + } + } + auto globalvalue = dynamic_cast(loadinst->getOperand(0)); + if (globalvalue != nullptr) { + return true; + } + auto basicblock = instr->getParent(); + + return !this->isLoopContainsBasicBlock(basicblock); + } + auto basicblock = instr->getParent(); + return !this->isLoopContainsBasicBlock(basicblock); + } + return true; } -Function *CallInst::getCallee() const { - return dyncast(getOperand(0)); +/** + * @brief 移动指令 + * + * @param [in] sourcePos 源指令列表位置 + * @param [in] targetPos 目的指令列表位置 + * @param [in] block 目标基本块 + * @return 无返回值 + */ +auto BasicBlock::moveInst(iterator sourcePos, iterator targetPos, BasicBlock *block) -> iterator { + auto inst = sourcePos->release(); + inst->setParent(block); + block->instructions.emplace(targetPos, inst); + return instructions.erase(sourcePos); } -} // namespace sysy \ No newline at end of file +} // namespace sysy diff --git a/src/IR.h b/src/IR.h deleted file mode 100644 index 184c7f4..0000000 --- a/src/IR.h +++ /dev/null @@ -1,994 +0,0 @@ -#pragma once - -#include "range.h" -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include - -namespace sysy { - -/*! - * \defgroup type Types - * The SysY type system is quite simple. - * 1. The base class `Type` is used to represent all primitive scalar types, - * include `int`, `float`, `void`, and the label type representing branch - * targets. - * 2. `PointerType` and `FunctionType` derive from `Type` and represent pointer - * type and function type, respectively. - * - * NOTE `Type` and its derived classes have their ctors declared as 'protected'. - * Users must use Type::getXXXType() methods to obtain `Type` pointers. - * @{ - */ - -/*! - * `Type` is used to represent all primitive scalar types, - * include `int`, `float`, `void`, and the label type representing branch - * targets - */ -class Type { -public: - enum Kind { - kInt, - kFloat, - kVoid, - kLabel, - kPointer, - kFunction, - }; - Kind kind; - -protected: - Type(Kind kind) : kind(kind) {} - virtual ~Type() = default; - -public: - static Type *getIntType(); - static Type *getFloatType(); - static Type *getVoidType(); - static Type *getLabelType(); - static Type *getPointerType(Type *baseType); - static Type *getFunctionType(Type *returnType, - const std::vector ¶mTypes = {}); - -public: - Kind getKind() const { return kind; } - bool isInt() const { return kind == kInt; } - bool isFloat() const { return kind == kFloat; } - bool isVoid() const { return kind == kVoid; } - bool isLabel() const { return kind == kLabel; } - bool isPointer() const { return kind == kPointer; } - bool isFunction() const { return kind == kFunction; } - bool isIntOrFloat() const { return kind == kInt or kind == kFloat; } - int getSize() const; - template - std::enable_if_t, T *> as() const { - return dynamic_cast(const_cast(this)); - } - void print(std::ostream &os) const; -}; // class Type - -//! Pointer type -class PointerType : public Type { -protected: - Type *baseType; - -protected: - PointerType(Type *baseType) : Type(kPointer), baseType(baseType) {} - -public: - static PointerType *get(Type *baseType); - -public: - Type *getBaseType() const { return baseType; } -}; // class PointerType - -//! Function type -class FunctionType : public Type { -private: - Type *returnType; - std::vector paramTypes; - -protected: - FunctionType(Type *returnType, const std::vector ¶mTypes = {}) - : Type(kFunction), returnType(returnType), paramTypes(paramTypes) {} - -public: - static FunctionType *get(Type *returnType, - const std::vector ¶mTypes = {}); - -public: - Type *getReturnType() const { return returnType; } - auto getParamTypes() const { return make_range(paramTypes); } - int getNumParams() const { return paramTypes.size(); } -}; // class FunctionType - -/*! - * @} - */ - -/*! - * \defgroup ir IR - * - * The SysY IR is an instruction level language. The IR is orgnized - * as a four-level tree structure, as shown below - * - * \dotfile ir-4level.dot IR Structure - * - * - `Module` corresponds to the top level "CompUnit" syntax structure - * - `GlobalValue` corresponds to the "Decl" syntax structure - * - `Function` corresponds to the "FuncDef" syntax structure - * - `BasicBlock` is a sequence of instructions without branching. A `Function` - * made up by one or more `BasicBlock`s. - * - `Instruction` represents a primitive operation on values, e.g., add or sub. - * - * The fundamental data concept in SysY IR is `Value`. A `Value` is like - * a register and is used by `Instruction`s as input/output operand. Each value - * has an associated `Type` indicating the data type held by the value. - * - * Most `Instruction`s have a three-address signature, i.e., there are at most 2 - * input values and at most 1 output value. - * - * The SysY IR adots a Static-Single-Assignment (SSA) design. That is, `Value` - * is defined (as the output operand ) by some instruction, and used (as the - * input operand) by other instructions. While a value can be used by multiple - * instructions, the `definition` occurs only once. As a result, there is a - * one-to-one relation between a value and the instruction defining it. In other - * words, any instruction defines a value can be viewed as the defined value - * itself. So `Instruction` is also a `Value` in SysY IR. See `Value` for the - * type hierachy. - * - * @{ - */ - -class User; -class Value; - -//! `Use` represents the relation between a `Value` and its `User` -class Use { -private: - //! the position of value in the user's operands, i.e., - //! user->getOperands[index] == value - int index; - User *user; - Value *value; - -public: - Use() = default; - Use(int index, User *user, Value *value) - : index(index), user(user), value(value) {} - -public: - int getIndex() const { return index; } - User *getUser() const { return user; } - Value *getValue() const { return value; } - void setValue(Value *value) { value = value; } -}; // class Use - -template -inline std::enable_if_t, bool> -isa(const Value *value) { - return T::classof(value); -} - -template -inline std::enable_if_t, T *> -dyncast(Value *value) { - return isa(value) ? static_cast(value) : nullptr; -} - -template -inline std::enable_if_t, const T *> -dyncast(const Value *value) { - return isa(value) ? static_cast(value) : nullptr; -} - -//! The base class of all value types -class Value { -public: - enum Kind : uint64_t { - kInvalid, - // Instructions - // Binary - kAdd = 0x1UL << 0, - kSub = 0x1UL << 1, - kMul = 0x1UL << 2, - kDiv = 0x1UL << 3, - kRem = 0x1UL << 4, - kICmpEQ = 0x1UL << 5, - kICmpNE = 0x1UL << 6, - kICmpLT = 0x1UL << 7, - kICmpGT = 0x1UL << 8, - kICmpLE = 0x1UL << 9, - kICmpGE = 0x1UL << 10, - kFAdd = 0x1UL << 14, - kFSub = 0x1UL << 15, - kFMul = 0x1UL << 16, - kFDiv = 0x1UL << 17, - kFRem = 0x1UL << 18, - kFCmpEQ = 0x1UL << 19, - kFCmpNE = 0x1UL << 20, - kFCmpLT = 0x1UL << 21, - kFCmpGT = 0x1UL << 22, - kFCmpLE = 0x1UL << 23, - kFCmpGE = 0x1UL << 24, - // Unary - kNeg = 0x1UL << 25, - kNot = 0x1UL << 26, - kFNeg = 0x1UL << 27, - kFtoI = 0x1UL << 28, - kIToF = 0x1UL << 29, - // call - kCall = 0x1UL << 30, - // terminator - kCondBr = 0x1UL << 31, - kBr = 0x1UL << 32, - kReturn = 0x1UL << 33, - // mem op - kAlloca = 0x1UL << 34, - kLoad = 0x1UL << 35, - kStore = 0x1UL << 36, - kFirstInst = kAdd, - kLastInst = kStore, - // others - kArgument = 0x1UL << 37, - kBasicBlock = 0x1UL << 38, - kFunction = 0x1UL << 39, - kConstant = 0x1UL << 40, - kGlobal = 0x1UL << 41, - }; - -protected: - Kind kind; - Type *type; - std::string name; - std::list uses; - -protected: - Value(Kind kind, Type *type, const std::string &name = "") - : kind(kind), type(type), name(name), uses() {} - -public: - virtual ~Value() = default; - -public: - Kind getKind() const { return kind; } - static bool classof(const Value *) { return true; } - -public: - Type *getType() const { return type; } - const std::string &getName() const { return name; } - void setName(const std::string &n) { name = n; } - bool hasName() const { return not name.empty(); } - bool isInt() const { return type->isInt(); } - bool isFloat() const { return type->isFloat(); } - bool isPointer() const { return type->isPointer(); } - const std::list &getUses() { return uses; } - void addUse(Use *use) { uses.push_back(use); } - void replaceAllUsesWith(Value *value); - void removeUse(Use *use) { uses.remove(use); } - bool isConstant() const; - -public: - virtual void print(std::ostream &os) const {}; -}; // class Value - -/*! - * Static constants known at compile time. - * - * `ConstantValue`s are not defined by instructions, and do not use any other - * `Value`s. It's type is either `int` or `float`. - */ -class ConstantValue : public Value { -protected: - union { - int iScalar; - float fScalar; - }; - -protected: - ConstantValue(int value) - : Value(kConstant, Type::getIntType(), ""), iScalar(value) {} - ConstantValue(float value) - : Value(kConstant, Type::getFloatType(), ""), fScalar(value) {} - -public: - static ConstantValue *get(int value); - static ConstantValue *get(float value); - -public: - static bool classof(const Value *value) { - return value->getKind() == kConstant; - } - -public: - int getInt() const { - assert(isInt()); - return iScalar; - } - float getFloat() const { - assert(isFloat()); - return fScalar; - } - -public: - void print(std::ostream &os) const override; -}; // class ConstantValue - -class BasicBlock; -/*! - * Arguments of `BasicBlock`s. - * - * SysY IR is an SSA language, however, it does not use PHI instructions as in - * LLVM IR. `Value`s from different predecessor blocks are passed explicitly as - * block arguments. This is also the approach used by MLIR. - * NOTE that `Function` does not own `Argument`s, function arguments are - * implemented as its entry block's arguments. - */ - -class Argument : public Value { -protected: - BasicBlock *block; - int index; - -public: - Argument(Type *type, BasicBlock *block, int index, - const std::string &name = ""); - -public: - static bool classof(const Value *value) { - return value->getKind() == kConstant; - } - -public: - BasicBlock *getParent() const { return block; } - int getIndex() const { return index; } - -public: - void print(std::ostream &os) const override; -}; - -class Instruction; -class Function; -/*! - * The container for `Instruction` sequence. - * - * `BasicBlock` maintains a list of `Instruction`s, with the last one being - * a terminator (branch or return). Besides, `BasicBlock` stores its arguments - * and records its predecessor and successor `BasicBlock`s. - */ -class BasicBlock : public Value { - friend class Function; - -public: - using inst_list = std::list>; - using iterator = inst_list::iterator; - using arg_list = std::vector>; - using block_list = std::vector; - -protected: - Function *parent; - inst_list instructions; - arg_list arguments; - block_list successors; - block_list predecessors; - -protected: - explicit BasicBlock(Function *parent, const std::string &name = ""); - -public: - static bool classof(const Value *value) { - return value->getKind() == kBasicBlock; - } - -public: - int getNumInstructions() const { return instructions.size(); } - int getNumArguments() const { return arguments.size(); } - int getNumPredecessors() const { return predecessors.size(); } - int getNumSuccessors() const { return successors.size(); } - Function *getParent() const { return parent; } - inst_list &getInstructions() { return instructions; } - auto getArguments() const { return make_range(arguments); } - block_list &getPredecessors() { return predecessors; } - block_list &getSuccessors() { return successors; } - iterator begin() { return instructions.begin(); } - iterator end() { return instructions.end(); } - iterator terminator() { return std::prev(end()); } - Argument *createArgument(Type *type, const std::string &name = "") { - auto arg = new Argument(type, this, arguments.size(), name); - assert(arg); - arguments.emplace_back(arg); - return arguments.back().get(); - }; - -public: - void print(std::ostream &os) const override; -}; // class BasicBlock - -//! User is the abstract base type of `Value` types which use other `Value` as -//! operands. Currently, there are two kinds of `User`s, `Instruction` and -//! `GlobalValue`. -class User : public Value { -protected: - std::vector operands; - -protected: - User(Kind kind, Type *type, const std::string &name = "") - : Value(kind, type, name), operands() {} - -public: - using use_iterator = std::vector::const_iterator; - struct operand_iterator : public std::vector::const_iterator { - using Base = std::vector::const_iterator; - operand_iterator(const Base &iter) : Base(iter) {} - using value_type = Value *; - value_type operator->() { return Base::operator*().getValue(); } - value_type operator*() { return Base::operator*().getValue(); } - }; - // struct const_operand_iterator : std::vector::const_iterator { - // using Base = std::vector::const_iterator; - // const_operand_iterator(const Base &iter) : Base(iter) {} - // using value_type = Value *; - // value_type operator->() { return operator*().getValue(); } - // }; - -public: - int getNumOperands() const { return operands.size(); } - operand_iterator operand_begin() const { return operands.begin(); } - operand_iterator operand_end() const { return operands.end(); } - auto getOperands() const { - return make_range(operand_begin(), operand_end()); - } - Value *getOperand(int index) const { return operands[index].getValue(); } - void addOperand(Value *value) { - operands.emplace_back(operands.size(), this, value); - value->addUse(&operands.back()); - } - template void addOperands(const ContainerT &operands) { - for (auto value : operands) - addOperand(value); - } - void replaceOperand(int index, Value *value); - void setOperand(int index, Value *value); -}; // class User - -/*! - * Base of all concrete instruction types. - */ -class Instruction : public User { -public: - // enum Kind : uint64_t { - // kInvalid = 0x0UL, - // // Binary - // kAdd = 0x1UL << 0, - // kSub = 0x1UL << 1, - // kMul = 0x1UL << 2, - // kDiv = 0x1UL << 3, - // kRem = 0x1UL << 4, - // kICmpEQ = 0x1UL << 5, - // kICmpNE = 0x1UL << 6, - // kICmpLT = 0x1UL << 7, - // kICmpGT = 0x1UL << 8, - // kICmpLE = 0x1UL << 9, - // kICmpGE = 0x1UL << 10, - // kFAdd = 0x1UL << 14, - // kFSub = 0x1UL << 15, - // kFMul = 0x1UL << 16, - // kFDiv = 0x1UL << 17, - // kFRem = 0x1UL << 18, - // kFCmpEQ = 0x1UL << 19, - // kFCmpNE = 0x1UL << 20, - // kFCmpLT = 0x1UL << 21, - // kFCmpGT = 0x1UL << 22, - // kFCmpLE = 0x1UL << 23, - // kFCmpGE = 0x1UL << 24, - // // Unary - // kNeg = 0x1UL << 25, - // kNot = 0x1UL << 26, - // kFNeg = 0x1UL << 27, - // kFtoI = 0x1UL << 28, - // kIToF = 0x1UL << 29, - // // call - // kCall = 0x1UL << 30, - // // terminator - // kCondBr = 0x1UL << 31, - // kBr = 0x1UL << 32, - // kReturn = 0x1UL << 33, - // // mem op - // kAlloca = 0x1UL << 34, - // kLoad = 0x1UL << 35, - // kStore = 0x1UL << 36, - // // constant - // // kConstant = 0x1UL << 37, - // }; - -protected: - Kind kind; - BasicBlock *parent; - -protected: - Instruction(Kind kind, Type *type, BasicBlock *parent = nullptr, - const std::string &name = ""); - -public: - static bool classof(const Value *value) { - return value->getKind() >= kFirstInst and value->getKind() <= kLastInst; - } - -public: - Kind getKind() const { return kind; } - BasicBlock *getParent() const { return parent; } - Function *getFunction() const { return parent->getParent(); } - void setParent(BasicBlock *bb) { parent = bb; } - - bool isBinary() const { - static constexpr uint64_t BinaryOpMask = - (kAdd | kSub | kMul | kDiv | kRem) | - (kICmpEQ | kICmpNE | kICmpLT | kICmpGT | kICmpLE | kICmpGE) | - (kFAdd | kFSub | kFMul | kFDiv | kFRem) | - (kFCmpEQ | kFCmpNE | kFCmpLT | kFCmpGT | kFCmpLE | kFCmpGE); - return kind & BinaryOpMask; - } - bool isUnary() const { - static constexpr uint64_t UnaryOpMask = kNeg | kNot | kFNeg | kFtoI | kIToF; - return kind & UnaryOpMask; - } - bool isMemory() const { - static constexpr uint64_t MemoryOpMask = kAlloca | kLoad | kStore; - return kind & MemoryOpMask; - } - bool isTerminator() const { - static constexpr uint64_t TerminatorOpMask = kCondBr | kBr | kReturn; - return kind & TerminatorOpMask; - } - bool isCmp() const { - static constexpr uint64_t CmpOpMask = - (kICmpEQ | kICmpNE | kICmpLT | kICmpGT | kICmpLE | kICmpGE) | - (kFCmpEQ | kFCmpNE | kFCmpLT | kFCmpGT | kFCmpLE | kFCmpGE); - return kind & CmpOpMask; - } - bool isBranch() const { - static constexpr uint64_t BranchOpMask = kBr | kCondBr; - return kind & BranchOpMask; - } - bool isCommutative() const { - static constexpr uint64_t CommutativeOpMask = - kAdd | kMul | kICmpEQ | kICmpNE | kFAdd | kFMul | kFCmpEQ | kFCmpNE; - return kind & CommutativeOpMask; - } - bool isUnconditional() const { return kind == kBr; } - bool isConditional() const { return kind == kCondBr; } -}; // class Instruction - -class Function; -//! Function call. -class CallInst : public Instruction { - friend class IRBuilder; - -protected: - CallInst(Function *callee, const std::vector &args = {}, - BasicBlock *parent = nullptr, const std::string &name = ""); - -public: - static bool classof(const Value *value) { return value->getKind() == kCall; } - -public: - Function *getCallee() const; - auto getArguments() const { - return make_range(std::next(operand_begin()), operand_end()); - } - -public: - void print(std::ostream &os) const override; -}; // class CallInst - -//! Unary instruction, includes '!', '-' and type conversion. -class UnaryInst : public Instruction { - friend class IRBuilder; - -protected: - UnaryInst(Kind kind, Type *type, Value *operand, BasicBlock *parent = nullptr, - const std::string &name = "") - : Instruction(kind, type, parent, name) { - addOperand(operand); - } - -public: - static bool classof(const Value *value) { - return Instruction::classof(value) and - static_cast(value)->isUnary(); - } - -public: - Value *getOperand() const { return User::getOperand(0); } - -public: - void print(std::ostream &os) const override; -}; // class UnaryInst - -//! Binary instruction, e.g., arithmatic, relation, logic, etc. -class BinaryInst : public Instruction { - friend class IRBuilder; - -protected: - BinaryInst(Kind kind, Type *type, Value *lhs, Value *rhs, BasicBlock *parent, - const std::string &name = "") - : Instruction(kind, type, parent, name) { - addOperand(lhs); - addOperand(rhs); - } - -public: - static bool classof(const Value *value) { - return Instruction::classof(value) and - static_cast(value)->isBinary(); - } - -public: - Value *getLhs() const { return getOperand(0); } - Value *getRhs() const { return getOperand(1); } - -public: - void print(std::ostream &os) const override; -}; // class BinaryInst - -//! The return statement -class ReturnInst : public Instruction { - friend class IRBuilder; - -protected: - ReturnInst(Value *value = nullptr, BasicBlock *parent = nullptr) - : Instruction(kReturn, Type::getVoidType(), parent, "") { - if (value) - addOperand(value); - } - -public: - static bool classof(const Value *value) { - return value->getKind() == kReturn; - } - -public: - bool hasReturnValue() const { return not operands.empty(); } - Value *getReturnValue() const { - return hasReturnValue() ? getOperand(0) : nullptr; - } - -public: - void print(std::ostream &os) const override; -}; // class ReturnInst - -//! Unconditional branch -class UncondBrInst : public Instruction { - friend class IRBuilder; - -protected: - UncondBrInst(BasicBlock *block, std::vector args, - BasicBlock *parent = nullptr) - : Instruction(kCondBr, Type::getVoidType(), parent, "") { - assert(block->getNumArguments() == args.size()); - addOperand(block); - addOperands(args); - } - -public: - static bool classof(const Value *value) { return value->getKind() == kBr; } - -public: - BasicBlock *getBlock() const { return dyncast(getOperand(0)); } - auto getArguments() const { - return make_range(std::next(operand_begin()), operand_end()); - } - -public: - void print(std::ostream &os) const override; -}; // class UncondBrInst - -//! Conditional branch -class CondBrInst : public Instruction { - friend class IRBuilder; - -protected: - CondBrInst(Value *condition, BasicBlock *thenBlock, BasicBlock *elseBlock, - const std::vector &thenArgs, - const std::vector &elseArgs, BasicBlock *parent = nullptr) - : Instruction(kCondBr, Type::getVoidType(), parent, "") { - assert(thenBlock->getNumArguments() == thenArgs.size() and - elseBlock->getNumArguments() == elseArgs.size()); - addOperand(condition); - addOperand(thenBlock); - addOperand(elseBlock); - addOperands(thenArgs); - addOperands(elseArgs); - } - -public: - static bool classof(const Value *value) { - return value->getKind() == kCondBr; - } - -public: - Value *getCondition() const { return getOperand(0); } - BasicBlock *getThenBlock() const { - return dyncast(getOperand(1)); - } - BasicBlock *getElseBlock() const { - return dyncast(getOperand(2)); - } - auto getThenArguments() const { - auto begin = std::next(operand_begin(), 3); - auto end = std::next(begin, getThenBlock()->getNumArguments()); - return make_range(begin, end); - } - auto getElseArguments() const { - auto begin = - std::next(operand_begin(), 3 + getThenBlock()->getNumArguments()); - auto end = operand_end(); - return make_range(begin, end); - } - -public: - void print(std::ostream &os) const override; -}; // class CondBrInst - -//! Allocate memory for stack variables, used for non-global variable declartion -class AllocaInst : public Instruction { - friend class IRBuilder; - -protected: - AllocaInst(Type *type, const std::vector &dims = {}, - BasicBlock *parent = nullptr, const std::string &name = "") - : Instruction(kAlloca, type, parent, name) { - addOperands(dims); - } - -public: - static bool classof(const Value *value) { - return value->getKind() == kAlloca; - } - -public: - int getNumDims() const { return getNumOperands(); } - auto getDims() const { return getOperands(); } - Value *getDim(int index) { return getOperand(index); } - -public: - void print(std::ostream &os) const override; -}; // class AllocaInst - -//! Load a value from memory address specified by a pointer value -class LoadInst : public Instruction { - friend class IRBuilder; - -protected: - LoadInst(Value *pointer, const std::vector &indices = {}, - BasicBlock *parent = nullptr, const std::string &name = "") - : Instruction(kLoad, pointer->getType()->as()->getBaseType(), - parent, name) { - addOperand(pointer); - addOperands(indices); - } - -public: - static bool classof(const Value *value) { return value->getKind() == kLoad; } - -public: - int getNumIndices() const { return getNumOperands() - 1; } - Value *getPointer() const { return getOperand(0); } - auto getIndices() const { - return make_range(std::next(operand_begin()), operand_end()); - } - Value *getIndex(int index) const { return getOperand(index + 1); } - -public: - void print(std::ostream &os) const override; -}; // class LoadInst - -//! Store a value to memory address specified by a pointer value -class StoreInst : public Instruction { - friend class IRBuilder; - -protected: - StoreInst(Value *value, Value *pointer, - const std::vector &indices = {}, - BasicBlock *parent = nullptr, const std::string &name = "") - : Instruction(kStore, Type::getVoidType(), parent, name) { - addOperand(value); - addOperand(pointer); - addOperands(indices); - } - -public: - static bool classof(const Value *value) { return value->getKind() == kStore; } - -public: - int getNumIndices() const { return getNumOperands() - 2; } - Value *getValue() const { return getOperand(0); } - Value *getPointer() const { return getOperand(1); } - auto getIndices() const { - return make_range(std::next(operand_begin(), 2), operand_end()); - } - Value *getIndex(int index) const { return getOperand(index + 2); } - -public: - void print(std::ostream &os) const override; -}; // class StoreInst - -class Module; -//! Function definition -class Function : public Value { - friend class Module; - -protected: - Function(Module *parent, Type *type, const std::string &name) - : Value(kFunction, type, name), parent(parent), variableID(0), blocks() { - blocks.emplace_back(new BasicBlock(this, "entry")); - } - -public: - static bool classof(const Value *value) { - return value->getKind() == kFunction; - } - -public: - using block_list = std::list>; - -protected: - Module *parent; - int variableID; - int blockID; - block_list blocks; - -public: - Type *getReturnType() const { - return getType()->as()->getReturnType(); - } - auto getParamTypes() const { - return getType()->as()->getParamTypes(); - } - auto getBasicBlocks() const { return make_range(blocks); } - BasicBlock *getEntryBlock() const { return blocks.front().get(); } - BasicBlock *addBasicBlock(const std::string &name = "") { - blocks.emplace_back(new BasicBlock(this, name)); - return blocks.back().get(); - } - void removeBasicBlock(BasicBlock *block) { - blocks.remove_if([&](std::unique_ptr &b) -> bool { - return block == b.get(); - }); - } - int allocateVariableID() { return variableID++; } - int allocateblockID() { return blockID++; } - -public: - void print(std::ostream &os) const override; -}; // class Function - -// class ArrayValue : public User { -// protected: -// ArrayValue(Type *type, const std::vector &values = {}) -// : User(type, "") { -// addOperands(values); -// } - -// public: -// static ArrayValue *get(Type *type, const std::vector &values); -// static ArrayValue *get(const std::vector &values); -// static ArrayValue *get(const std::vector &values); - -// public: -// auto getValues() const { return getOperands(); } - -// public: -// void print(std::ostream &os) const override{}; -// }; // class ConstantArray - -//! Global value declared at file scope -class GlobalValue : public User { - friend class Module; - -protected: - Module *parent; - bool hasInit; - bool isConst; - -protected: - GlobalValue(Module *parent, Type *type, const std::string &name, - const std::vector &dims = {}, Value *init = nullptr) - : User(kGlobal, type, name), parent(parent), hasInit(init) { - assert(type->isPointer()); - addOperands(dims); - if (init) - addOperand(init); - } - -public: - static bool classof(const Value *value) { - return value->getKind() == kGlobal; - } - -public: - Value *init() const { return hasInit ? operands.back().getValue() : nullptr; } - int getNumDims() const { return getNumOperands() - (hasInit ? 1 : 0); } - Value *getDim(int index) { return getOperand(index); } - -public: - void print(std::ostream &os) const override{}; -}; // class GlobalValue - -//! IR unit for representing a SysY compile unit -class Module { -protected: - std::vector> children; - std::map functions; - std::map globals; - -public: - Module() = default; - -public: - Function *createFunction(const std::string &name, Type *type) { - if (functions.count(name)) - return nullptr; - auto func = new Function(this, type, name); - assert(func); - children.emplace_back(func); - functions.emplace(name, func); - return func; - }; - GlobalValue *createGlobalValue(const std::string &name, Type *type, - const std::vector &dims = {}, - Value *init = nullptr) { - if (globals.count(name)) - return nullptr; - auto global = new GlobalValue(this, type, name, dims, init); - assert(global); - children.emplace_back(global); - globals.emplace(name, global); - return global; - } - Function *getFunction(const std::string &name) const { - auto result = functions.find(name); - if (result == functions.end()) - return nullptr; - return result->second; - } - GlobalValue *getGlobalValue(const std::string &name) const { - auto result = globals.find(name); - if (result == globals.end()) - return nullptr; - return result->second; - } - - std::map *getFunctions(){ - return &functions; - } - std::map *getGlobalValues(){ - return &globals; - } - -public: - void print(std::ostream &os) const; -}; // class Module - -/*! - * @} - */ -inline std::ostream &operator<<(std::ostream &os, const Type &type) { - type.print(os); - return os; -} - -inline std::ostream &operator<<(std::ostream &os, const Value &value) { - value.print(os); - return os; -} - -} // namespace sysy \ No newline at end of file diff --git a/src/IRBuilder.h b/src/IRBuilder.h deleted file mode 100644 index 60cb092..0000000 --- a/src/IRBuilder.h +++ /dev/null @@ -1,232 +0,0 @@ -#pragma once - -#include "IR.h" -#include -#include - -namespace sysy { - -class IRBuilder { -private: - BasicBlock *block; - BasicBlock::iterator position; - -public: - IRBuilder() = default; - IRBuilder(BasicBlock *block) : block(block), position(block->end()) {} - IRBuilder(BasicBlock *block, BasicBlock::iterator position) - : block(block), position(position) {} - -public: - BasicBlock *getBasicBlock() const { return block; } - BasicBlock::iterator getPosition() const { return position; } - void setPosition(BasicBlock *block, BasicBlock::iterator position) { - this->block = block; - this->position = position; - } - void setPosition(BasicBlock::iterator position) { this->position = position; } - -public: - CallInst *createCallInst(Function *callee, - const std::vector &args = {}, - const std::string &name = "") { - auto inst = new CallInst(callee, args, block, name); - assert(inst); - block->getInstructions().emplace(position, inst); - return inst; - } - UnaryInst *createUnaryInst(Instruction::Kind kind, Type *type, Value *operand, - const std::string &name = "") { - - auto inst = new UnaryInst(kind, type, operand, block, name); - assert(inst); - block->getInstructions().emplace(position, inst); - return inst; - } - UnaryInst *createNegInst(Value *operand, const std::string &name = "") { - return createUnaryInst(Instruction::kNeg, Type::getIntType(), operand, - name); - } - UnaryInst *createNotInst(Value *operand, const std::string &name = "") { - return createUnaryInst(Instruction::kNot, Type::getIntType(), operand, - name); - } - UnaryInst *createFtoIInst(Value *operand, const std::string &name = "") { - return createUnaryInst(Instruction::kFtoI, Type::getIntType(), operand, - name); - } - UnaryInst *createFNegInst(Value *operand, const std::string &name = "") { - return createUnaryInst(Instruction::kFNeg, Type::getFloatType(), operand, - name); - } - UnaryInst *createIToFInst(Value *operand, const std::string &name = "") { - return createUnaryInst(Instruction::kIToF, Type::getFloatType(), operand, - name); - } - BinaryInst *createBinaryInst(Instruction::Kind kind, Type *type, Value *lhs, - Value *rhs, const std::string &name = "") { - auto inst = new BinaryInst(kind, type, lhs, rhs, block, name); - assert(inst); - block->getInstructions().emplace(position, inst); - return inst; - } - BinaryInst *createAddInst(Value *lhs, Value *rhs, - const std::string &name = "") { - return createBinaryInst(Instruction::kAdd, Type::getIntType(), lhs, rhs, - name); - } - BinaryInst *createSubInst(Value *lhs, Value *rhs, - const std::string &name = "") { - return createBinaryInst(Instruction::kSub, Type::getIntType(), lhs, rhs, - name); - } - BinaryInst *createMulInst(Value *lhs, Value *rhs, - const std::string &name = "") { - return createBinaryInst(Instruction::kMul, Type::getIntType(), lhs, rhs, - name); - } - BinaryInst *createDivInst(Value *lhs, Value *rhs, - const std::string &name = "") { - return createBinaryInst(Instruction::kDiv, Type::getIntType(), lhs, rhs, - name); - } - BinaryInst *createRemInst(Value *lhs, Value *rhs, - const std::string &name = "") { - return createBinaryInst(Instruction::kRem, Type::getIntType(), lhs, rhs, - name); - } - BinaryInst *createICmpEQInst(Value *lhs, Value *rhs, - const std::string &name = "") { - return createBinaryInst(Instruction::kICmpEQ, Type::getIntType(), lhs, rhs, - name); - } - BinaryInst *createICmpNEInst(Value *lhs, Value *rhs, - const std::string &name = "") { - return createBinaryInst(Instruction::kICmpNE, Type::getIntType(), lhs, rhs, - name); - } - BinaryInst *createICmpLTInst(Value *lhs, Value *rhs, - const std::string &name = "") { - return createBinaryInst(Instruction::kICmpLT, Type::getIntType(), lhs, rhs, - name); - } - BinaryInst *createICmpLEInst(Value *lhs, Value *rhs, - const std::string &name = "") { - return createBinaryInst(Instruction::kICmpLE, Type::getIntType(), lhs, rhs, - name); - } - BinaryInst *createICmpGTInst(Value *lhs, Value *rhs, - const std::string &name = "") { - return createBinaryInst(Instruction::kICmpGT, Type::getIntType(), lhs, rhs, - name); - } - BinaryInst *createICmpGEInst(Value *lhs, Value *rhs, - const std::string &name = "") { - return createBinaryInst(Instruction::kICmpGE, Type::getIntType(), lhs, rhs, - name); - } - BinaryInst *createFAddInst(Value *lhs, Value *rhs, - const std::string &name = "") { - return createBinaryInst(Instruction::kFAdd, Type::getFloatType(), lhs, rhs, - name); - } - BinaryInst *createFSubInst(Value *lhs, Value *rhs, - const std::string &name = "") { - return createBinaryInst(Instruction::kFSub, Type::getFloatType(), lhs, rhs, - name); - } - BinaryInst *createFMulInst(Value *lhs, Value *rhs, - const std::string &name = "") { - return createBinaryInst(Instruction::kFMul, Type::getFloatType(), lhs, rhs, - name); - } - BinaryInst *createFDivInst(Value *lhs, Value *rhs, - const std::string &name = "") { - return createBinaryInst(Instruction::kFDiv, Type::getFloatType(), lhs, rhs, - name); - } - BinaryInst *createFRemInst(Value *lhs, Value *rhs, - const std::string &name = "") { - return createBinaryInst(Instruction::kFRem, Type::getFloatType(), lhs, rhs, - name); - } - BinaryInst *createFCmpEQInst(Value *lhs, Value *rhs, - const std::string &name = "") { - return createBinaryInst(Instruction::kFCmpEQ, Type::getFloatType(), lhs, - rhs, name); - } - BinaryInst *createFCmpNEInst(Value *lhs, Value *rhs, - const std::string &name = "") { - return createBinaryInst(Instruction::kFCmpNE, Type::getFloatType(), lhs, - rhs, name); - } - BinaryInst *createFCmpLTInst(Value *lhs, Value *rhs, - const std::string &name = "") { - return createBinaryInst(Instruction::kFCmpLT, Type::getFloatType(), lhs, - rhs, name); - } - BinaryInst *createFCmpLEInst(Value *lhs, Value *rhs, - const std::string &name = "") { - return createBinaryInst(Instruction::kFCmpLE, Type::getFloatType(), lhs, - rhs, name); - } - BinaryInst *createFCmpGTInst(Value *lhs, Value *rhs, - const std::string &name = "") { - return createBinaryInst(Instruction::kFCmpGT, Type::getFloatType(), lhs, - rhs, name); - } - BinaryInst *createFCmpGEInst(Value *lhs, Value *rhs, - const std::string &name = "") { - return createBinaryInst(Instruction::kFCmpGE, Type::getFloatType(), lhs, - rhs, name); - } - ReturnInst *createReturnInst(Value *value = nullptr) { - auto inst = new ReturnInst(value); - assert(inst); - block->getInstructions().emplace(position, inst); - return inst; - } - UncondBrInst *createUncondBrInst(BasicBlock *block, - std::vector args) { - auto inst = new UncondBrInst(block, args, block); - assert(inst); - block->getInstructions().emplace(position, inst); - return inst; - } - CondBrInst *createCondBrInst(Value *condition, BasicBlock *thenBlock, - BasicBlock *elseBlock, - const std::vector &thenArgs, - const std::vector &elseArgs) { - auto inst = new CondBrInst(condition, thenBlock, elseBlock, thenArgs, - elseArgs, block); - assert(inst); - block->getInstructions().emplace(position, inst); - return inst; - } - AllocaInst *createAllocaInst(Type *type, - const std::vector &dims = {}, - const std::string &name = "") { - auto inst = new AllocaInst(type, dims, block, name); - assert(inst); - block->getInstructions().emplace(position, inst); - return inst; - } - LoadInst *createLoadInst(Value *pointer, - const std::vector &indices = {}, - const std::string &name = "") { - auto inst = new LoadInst(pointer, indices, block, name); - assert(inst); - block->getInstructions().emplace(position, inst); - return inst; - } - StoreInst *createStoreInst(Value *value, Value *pointer, - const std::vector &indices = {}, - const std::string &name = "") { - auto inst = new StoreInst(value, pointer, indices, block, name); - assert(inst); - block->getInstructions().emplace(position, inst); - return inst; - } -}; - -} // namespace sysy \ No newline at end of file diff --git a/src/LLVMIRGenerator.cpp b/src/LLVMIRGenerator.cpp new file mode 100644 index 0000000..0d42ce5 --- /dev/null +++ b/src/LLVMIRGenerator.cpp @@ -0,0 +1,674 @@ +// LLVMIRGenerator.cpp +// TODO:类型转换及其检查 +// TODO:sysy库函数处理 +// TODO:数组处理 +// TODO:对while、continue、break的测试 +#include "LLVMIRGenerator.h" +#include +using namespace std; +namespace sysy { +std::string LLVMIRGenerator::generateIR(SysYParser::CompUnitContext* unit) { + // 初始化自定义IR数据结构 + irModule = std::make_unique(); + irBuilder = sysy::IRBuilder(); // 初始化IR构建器 + tempCounter = 0; + symbolTable.clear(); + tmpTable.clear(); + globalVars.clear(); + inFunction = false; + + visitCompUnit(unit); + return irStream.str(); +} + +std::string LLVMIRGenerator::getNextTemp() { + std::string ret = "%." + std::to_string(tempCounter++); + tmpTable[ret] = "void"; + return ret; +} + +std::string LLVMIRGenerator::getLLVMType(const std::string& type) { + if (type == "int") return "i32"; + if (type == "float") return "float"; + if (type.find("[]") != std::string::npos) + return getLLVMType(type.substr(0, type.size()-2)) + "*"; + return "i32"; +} + +sysy::Type* LLVMIRGenerator::getSysYType(const std::string& typeStr) { + if (typeStr == "int") return sysy::Type::getIntType(); + if (typeStr == "float") return sysy::Type::getFloatType(); + if (typeStr == "void") return sysy::Type::getVoidType(); + // 处理指针类型等 + return sysy::Type::getIntType(); +} + +std::any LLVMIRGenerator::visitCompUnit(SysYParser::CompUnitContext* ctx) { + auto type_i32 = Type::getIntType(); + auto type_f32 = Type::getFloatType(); + auto type_void = Type::getVoidType(); + auto type_i32p = Type::getPointerType(type_i32); + auto type_f32p = Type::getPointerType(type_f32); + + // 创建运行时库函数 + irModule->createFunction("getint", sysy::FunctionType::get(type_i32, {})); + irModule->createFunction("getch", sysy::FunctionType::get(type_i32, {})); + irModule->createFunction("getfloat", sysy::FunctionType::get(type_f32, {})); + //TODO: 添加更多运行时库函数 + irStream << "declare i32 @getint()\n"; + irStream << "declare i32 @getch()\n"; + irStream << "declare float @getfloat()\n"; + //TODO: 添加更多运行时库函数的文本IR + + for (auto decl : ctx->decl()) { + decl->accept(this); + } + for (auto funcDef : ctx->funcDef()) { + inFunction = true; // 进入函数定义 + funcDef->accept(this); + inFunction = false; // 离开函数定义 + } + return nullptr; +} + +std::any LLVMIRGenerator::visitVarDecl(SysYParser::VarDeclContext* ctx) { + // TODO:数组初始化 + std::string type = ctx->bType()->getText(); + currentVarType = getLLVMType(type); + + for (auto varDef : ctx->varDef()) { + if (!inFunction) { + // 全局变量声明 + std::string varName = varDef->Ident()->getText(); + std::string llvmType = getLLVMType(type); + std::string value = "0"; // 默认值为 0 + + if (varDef->ASSIGN()) { + value = std::any_cast(varDef->initVal()->accept(this)); + } else { + std::cout << "[WR-Release-01]Warning: Global variable '" << varName + << "' is declared without initialization, defaulting to 0.\n"; + } + irStream << "@" << varName << " = dso_local global " << llvmType << " " << value << ", align 4\n"; + globalVars.push_back(varName); // 记录全局变量 + } else { + // 局部变量声明 + varDef->accept(this); + } + } + return nullptr; +} + +std::any LLVMIRGenerator::visitConstDecl(SysYParser::ConstDeclContext* ctx) { + // TODO:数组初始化 + std::string type = ctx->bType()->getText(); + for (auto constDef : ctx->constDef()) { + if (!inFunction) { + // 全局常量声明 + std::string varName = constDef->Ident()->getText(); + std::string llvmType = getLLVMType(type); + std::string value = "0"; // 默认值为 0 + + try { + value = std::any_cast(constDef->constInitVal()->accept(this)); + } catch (...) { + throw std::runtime_error("[ERR-Release-01]Const value must be initialized upon definition."); + } + // 如果是 float 类型,转换为十六进制表示 + if (llvmType == "float") { + try { + double floatValue = std::stod(value); + uint64_t hexValue = reinterpret_cast(floatValue); + std::stringstream ss; + ss << "0x" << std::hex << std::uppercase << hexValue; + value = ss.str(); + } catch (...) { + throw std::runtime_error("[ERR-Release-02]Invalid float literal: " + value); + } + } + + irStream << "@" << varName << " = dso_local constant " << llvmType << " " << value << ", align 4\n"; + globalVars.push_back(varName); // 记录全局变量 + } else { + // 局部常量声明 + std::string varName = constDef->Ident()->getText(); + std::string llvmType = getLLVMType(type); + std::string allocaName = getNextTemp(); + std::string value = "0"; // 默认值为 0 + + try { + value = std::any_cast(constDef->constInitVal()->accept(this)); + } catch (...) { + throw std::runtime_error("Const value must be initialized upon definition."); + } + + irStream << " " << allocaName << " = alloca " << llvmType << ", align 4\n"; + if (llvmType == "float") { + try { + double floatValue = std::stod(value); + uint64_t hexValue = reinterpret_cast(floatValue); + std::stringstream ss; + ss << "0x" << std::hex << std::uppercase << hexValue; + value = ss.str(); + } catch (...) { + throw std::runtime_error("Invalid float literal: " + value); + } + } + irStream << " store " << llvmType << " " << value << ", " << llvmType + << "* " << allocaName << ", align 4\n"; + + symbolTable[varName] = {allocaName, llvmType}; + tmpTable[allocaName] = llvmType; + } + } + return nullptr; +} + +std::any LLVMIRGenerator::visitVarDef(SysYParser::VarDefContext* ctx) { + // TODO:数组初始化 + std::string varName = ctx->Ident()->getText(); + std::string type = currentVarType; + std::string llvmType = getLLVMType(type); + std::string allocaName = getNextTemp(); + + + irStream << " " << allocaName << " = alloca " << llvmType << ", align 4\n"; + + if (ctx->ASSIGN()) { + std::string value = std::any_cast(ctx->initVal()->accept(this)); + + if (llvmType == "float") { + try { + double floatValue = std::stod(value); + uint64_t hexValue = reinterpret_cast(floatValue); + std::stringstream ss; + ss << "0x" << std::hex << std::uppercase << (hexValue & (0xffffffffUL << 32)); + value = ss.str(); + } catch (...) { + throw std::runtime_error("Invalid float literal: " + value); + } + } + irStream << " store " << llvmType << " " << value << ", " << llvmType + << "* " << allocaName << ", align 4\n"; + } + symbolTable[varName] = {allocaName, llvmType}; + tmpTable[allocaName] = llvmType; + return nullptr; +} + +std::any LLVMIRGenerator::visitFuncDef(SysYParser::FuncDefContext* ctx) { + currentFunction = ctx->Ident()->getText(); + currentReturnType = getLLVMType(ctx->funcType()->getText()); + symbolTable.clear(); + tmpTable.clear(); + tempCounter = 0; + hasReturn = false; + + irStream << "define dso_local " << currentReturnType << " @" << currentFunction << "("; + if (ctx->funcFParams()) { + auto params = ctx->funcFParams()->funcFParam(); + tempCounter += params.size(); + for (size_t i = 0; i < params.size(); ++i) { + if (i > 0) irStream << ", "; + std::string paramType = getLLVMType(params[i]->bType()->getText()); + irStream << paramType << " noundef %" << i; + symbolTable[params[i]->Ident()->getText()] = {"%" + std::to_string(i), paramType}; + tmpTable["%" + std::to_string(i)] = paramType; + } + } + tempCounter++; + irStream << ") #0 {\n"; + + if (ctx->funcFParams()) { + auto params = ctx->funcFParams()->funcFParam(); + for (size_t i = 0; i < params.size(); ++i) { + std::string varName = params[i]->Ident()->getText(); + std::string type = params[i]->bType()->getText(); + std::string llvmType = getLLVMType(type); + std::string allocaName = getNextTemp(); + tmpTable[allocaName] = llvmType; + + irStream << " " << allocaName << " = alloca " << llvmType << ", align 4\n"; + irStream << " store " << llvmType << " " << symbolTable[varName].first << ", " << llvmType + << "* " << allocaName << ", align 4\n"; + + symbolTable[varName] = {allocaName, llvmType}; + } + } + ctx->blockStmt()->accept(this); + + if (!hasReturn) { + if (currentReturnType == "void") { + irStream << " ret void\n"; + } else { + irStream << " ret " << currentReturnType << " 0\n"; + } + } + irStream << "}\n"; + return nullptr; +} + +std::any LLVMIRGenerator::visitBlockStmt(SysYParser::BlockStmtContext* ctx) { + for (auto item : ctx->blockItem()) { + item->accept(this); + } + return nullptr; +} +std::any LLVMIRGenerator::visitAssignStmt(SysYParser::AssignStmtContext *ctx) +{ + std::string lhsAlloca = std::any_cast(ctx->lValue()->accept(this)); + std::string lhsType = symbolTable[ctx->lValue()->Ident()->getText()].second; + std::string rhs = std::any_cast(ctx->exp()->accept(this)); + + if (lhsType == "float") { + try { + double floatValue = std::stod(rhs); + uint64_t hexValue = reinterpret_cast(floatValue); + std::stringstream ss; + ss << "0x" << std::hex << std::uppercase << (hexValue & (0xffffffffUL << 32)); + rhs = ss.str(); + } catch (...) { + throw std::runtime_error("Invalid float literal: " + rhs); + } + } + + irStream << " store " << lhsType << " " << rhs << ", " << lhsType + << "* " << lhsAlloca << ", align 4\n"; + return nullptr; +} + +std::any LLVMIRGenerator::visitIfStmt(SysYParser::IfStmtContext *ctx) +{ + std::string cond = std::any_cast(ctx->cond()->accept(this)); + std::string trueLabel = "if.then." + std::to_string(tempCounter); + std::string falseLabel = "if.else." + std::to_string(tempCounter); + std::string mergeLabel = "if.end." + std::to_string(tempCounter++); + + irStream << " br i1 " << cond << ", label %" << trueLabel << ", label %" << falseLabel << "\n"; + + irStream << trueLabel << ":\n"; + ctx->stmt(0)->accept(this); + irStream << " br label %" << mergeLabel << "\n"; + + irStream << falseLabel << ":\n"; + if (ctx->ELSE()) { + ctx->stmt(1)->accept(this); + } + irStream << " br label %" << mergeLabel << "\n"; + + irStream << mergeLabel << ":\n"; + return nullptr; +} + +std::any LLVMIRGenerator::visitWhileStmt(SysYParser::WhileStmtContext *ctx) +{ + std::string loop_cond = "while.cond." + std::to_string(tempCounter); + std::string loop_body = "while.body." + std::to_string(tempCounter); + std::string loop_end = "while.end." + std::to_string(tempCounter++); + + loopStack.push({loop_end, loop_cond}); + irStream << " br label %" << loop_cond << "\n"; + irStream << loop_cond << ":\n"; + + std::string cond = std::any_cast(ctx->cond()->accept(this)); + irStream << " br i1 " << cond << ", label %" << loop_body << ", label %" << loop_end << "\n"; + irStream << loop_body << ":\n"; + ctx->stmt()->accept(this); + irStream << " br label %" << loop_cond << "\n"; + irStream << loop_end << ":\n"; + + loopStack.pop(); + return nullptr; +} + +std::any LLVMIRGenerator::visitBreakStmt(SysYParser::BreakStmtContext *ctx) +{ + if (loopStack.empty()) { + throw std::runtime_error("Break statement outside of a loop."); + } + irStream << " br label %" << loopStack.top().breakLabel << "\n"; + return nullptr; +} + +std::any LLVMIRGenerator::visitContinueStmt(SysYParser::ContinueStmtContext *ctx) +{ + if (loopStack.empty()) { + throw std::runtime_error("Continue statement outside of a loop."); + } + irStream << " br label %" << loopStack.top().continueLabel << "\n"; + return nullptr; +} + +std::any LLVMIRGenerator::visitReturnStmt(SysYParser::ReturnStmtContext *ctx) +{ + hasReturn = true; + if (ctx->exp()) { + std::string value = std::any_cast(ctx->exp()->accept(this)); + irStream << " ret " << currentReturnType << " " << value << "\n"; + } else { + irStream << " ret void\n"; + } + return nullptr; +} + +// std::any LLVMIRGenerator::visitStmt(SysYParser::StmtContext* ctx) { +// if (ctx->lValue() && ctx->ASSIGN()) { +// std::string lhsAlloca = std::any_cast(ctx->lValue()->accept(this)); +// std::string lhsType = symbolTable[ctx->lValue()->Ident()->getText()].second; +// std::string rhs = std::any_cast(ctx->exp()->accept(this)); +// if (lhsType == "float") { +// try { +// double floatValue = std::stod(rhs); +// uint64_t hexValue = reinterpret_cast(floatValue); +// std::stringstream ss; +// ss << "0x" << std::hex << std::uppercase << (hexValue & (0xffffffffUL << 32)); +// rhs = ss.str(); +// } catch (...) { +// throw std::runtime_error("Invalid float literal: " + rhs); +// } +// } +// irStream << " store " << lhsType << " " << rhs << ", " << lhsType +// << "* " << lhsAlloca << ", align 4\n"; +// } else if (ctx->RETURN()) { +// hasReturn = true; +// if (ctx->exp()) { +// std::string value = std::any_cast(ctx->exp()->accept(this)); +// irStream << " ret " << currentReturnType << " " << value << "\n"; +// } else { +// irStream << " ret void\n"; +// } +// } else if (ctx->IF()) { +// std::string cond = std::any_cast(ctx->cond()->accept(this)); +// std::string trueLabel = "if.then." + std::to_string(tempCounter); +// std::string falseLabel = "if.else." + std::to_string(tempCounter); +// std::string mergeLabel = "if.end." + std::to_string(tempCounter++); + +// irStream << " br i1 " << cond << ", label %" << trueLabel << ", label %" << falseLabel << "\n"; + +// irStream << trueLabel << ":\n"; +// ctx->stmt(0)->accept(this); +// irStream << " br label %" << mergeLabel << "\n"; + +// irStream << falseLabel << ":\n"; +// if (ctx->ELSE()) { +// ctx->stmt(1)->accept(this); +// } +// irStream << " br label %" << mergeLabel << "\n"; + +// irStream << mergeLabel << ":\n"; +// } else if (ctx->WHILE()) { +// std::string loop_cond = "while.cond." + std::to_string(tempCounter); +// std::string loop_body = "while.body." + std::to_string(tempCounter); +// std::string loop_end = "while.end." + std::to_string(tempCounter++); + +// loopStack.push({loop_end, loop_cond}); +// irStream << " br label %" << loop_cond << "\n"; +// irStream << loop_cond << ":\n"; + +// std::string cond = std::any_cast(ctx->cond()->accept(this)); +// irStream << " br i1 " << cond << ", label %" << loop_body << ", label %" << loop_end << "\n"; +// irStream << loop_body << ":\n"; +// ctx->stmt(0)->accept(this); +// irStream << " br label %" << loop_cond << "\n"; +// irStream << loop_end << ":\n"; + +// loopStack.pop(); + +// } else if (ctx->BREAK()) { +// if (loopStack.empty()) { +// throw std::runtime_error("Break statement outside of a loop."); +// } +// irStream << " br label %" << loopStack.top().breakLabel << "\n"; +// } else if (ctx->CONTINUE()) { +// if (loopStack.empty()) { +// throw std::runtime_error("Continue statement outside of a loop."); +// } +// irStream << " br label %" << loopStack.top().continueLabel << "\n"; +// } else if (ctx->blockStmt()) { +// ctx->blockStmt()->accept(this); +// } else if (ctx->exp()) { +// ctx->exp()->accept(this); +// } +// return nullptr; +// } + +std::any LLVMIRGenerator::visitLValue(SysYParser::LValueContext* ctx) { + std::string varName = ctx->Ident()->getText(); + return symbolTable[varName].first; +} + +// std::any LLVMIRGenerator::visitPrimaryExp(SysYParser::PrimaryExpContext* ctx) { +// if (ctx->lValue()) { +// std::string allocaPtr = std::any_cast(ctx->lValue()->accept(this)); +// std::string varName = ctx->lValue()->Ident()->getText(); +// std::string type = symbolTable[varName].second; +// std::string temp = getNextTemp(); +// irStream << " " << temp << " = load " << type << ", " << type << "* " << allocaPtr << ", align 4\n"; +// tmpTable[temp] = type; +// return temp; +// } else if (ctx->exp()) { +// return ctx->exp()->accept(this); +// } else { +// return ctx->number()->accept(this); +// } +// } + + +std::any LLVMIRGenerator::visitPrimExp(SysYParser::PrimExpContext *ctx){ + // irStream << "visitPrimExp\n"; + // std::cout << "Type name: " << typeid(*(ctx->primaryExp())).name() << std::endl; + SysYParser::PrimaryExpContext* pExpCtx = ctx->primaryExp(); + if (auto* lvalCtx = dynamic_cast(pExpCtx)) { + std::string allocaPtr = std::any_cast(lvalCtx->lValue()->accept(this)); + std::string varName = lvalCtx->lValue()->Ident()->getText(); + std::string type = symbolTable[varName].second; + std::string temp = getNextTemp(); + irStream << " " << temp << " = load " << type << ", " << type << "* " << allocaPtr << ", align 4\n"; + tmpTable[temp] = type; + return temp; + } else if (auto* expCtx = dynamic_cast(pExpCtx)) { + return expCtx->exp()->accept(this); + } else if (auto* strCtx = dynamic_cast(pExpCtx)) { + return strCtx->string()->accept(this); + } else if (auto* numCtx = dynamic_cast(pExpCtx)) { + return numCtx->number()->accept(this); + } else { + // 没有成功转换,说明 ctx->primaryExp() 不是 NumContext 或其他已知类型 + // 可能是其他类型的表达式,或者是一个空的 PrimaryExpContext + std::cout << "Unknown primary expression type." << std::endl; + throw std::runtime_error("Unknown primary expression type."); + } + // return visitChildren(ctx); +} + +std::any LLVMIRGenerator::visitParenExp(SysYParser::ParenExpContext* ctx) { + return ctx->exp()->accept(this); +} + +std::any LLVMIRGenerator::visitNumber(SysYParser::NumberContext* ctx) { + if (ctx->ILITERAL()) { + return ctx->ILITERAL()->getText(); + } else if (ctx->FLITERAL()) { + return ctx->FLITERAL()->getText(); + } + return ""; +} + +std::any LLVMIRGenerator::visitString(SysYParser::StringContext *ctx) +{ + if (ctx->STRING()) { + // 处理字符串常量 + std::string str = ctx->STRING()->getText(); + // 去掉引号 + str = str.substr(1, str.size() - 2); + // 转义处理 + std::string escapedStr; + for (char c : str) { + if (c == '\\') { + escapedStr += "\\\\"; + } else if (c == '"') { + escapedStr += "\\\""; + } else { + escapedStr += c; + } + } + return "\"" + escapedStr + "\""; + } + return ctx->STRING()->getText(); +} + +std::any LLVMIRGenerator::visitUnExp(SysYParser::UnExpContext* ctx) { + if (ctx->unaryOp()) { + std::string operand = std::any_cast(ctx->unaryExp()->accept(this)); + std::string op = ctx->unaryOp()->getText(); + std::string temp = getNextTemp(); + std::string type = operand.substr(0, operand.find(' ')); + tmpTable[temp] = type; + if (op == "-") { + irStream << " " << temp << " = sub " << type << " 0, " << operand << "\n"; + } else if (op == "!") { + irStream << " " << temp << " = xor " << type << " " << operand << ", 1\n"; + } + return temp; + } + return ctx->unaryExp()->accept(this); +} + +std::any LLVMIRGenerator::visitCall(SysYParser::CallContext *ctx) +{ + std::string funcName = ctx->Ident()->getText(); + std::vector args; + if (ctx->funcRParams()) { + for (auto argCtx : ctx->funcRParams()->exp()) { + args.push_back(std::any_cast(argCtx->accept(this))); + } + } + std::string temp = getNextTemp(); + std::string argList = ""; + for (size_t i = 0; i < args.size(); ++i) { + if (i > 0) argList += ", "; + argList +=tmpTable[args[i]] + " noundef " + args[i]; + } + irStream << " " << temp << " = call " << currentReturnType << " @" << funcName << "(" << argList << ")\n"; + tmpTable[temp] = currentReturnType; + return temp; +} + +std::any LLVMIRGenerator::visitMulExp(SysYParser::MulExpContext* ctx) { + auto unaryExps = ctx->unaryExp(); + std::string left = std::any_cast(unaryExps[0]->accept(this)); + for (size_t i = 1; i < unaryExps.size(); ++i) { + std::string right = std::any_cast(unaryExps[i]->accept(this)); + std::string op = ctx->children[2*i-1]->getText(); + std::string temp = getNextTemp(); + std::string type = tmpTable[left]; + if (op == "*") { + irStream << " " << temp << " = mul nsw " << type << " " << left << ", " << right << "\n"; + } else if (op == "/") { + irStream << " " << temp << " = sdiv " << type << " " << left << ", " << right << "\n"; + } else if (op == "%") { + irStream << " " << temp << " = srem " << type << " " << left << ", " << right << "\n"; + } + left = temp; + tmpTable[temp] = type; + } + return left; +} + +std::any LLVMIRGenerator::visitAddExp(SysYParser::AddExpContext* ctx) { + auto mulExps = ctx->mulExp(); + std::string left = std::any_cast(mulExps[0]->accept(this)); + for (size_t i = 1; i < mulExps.size(); ++i) { + std::string right = std::any_cast(mulExps[i]->accept(this)); + std::string op = ctx->children[2*i-1]->getText(); + std::string temp = getNextTemp(); + std::string type = tmpTable[left]; + if (op == "+") { + irStream << " " << temp << " = add nsw " << type << " " << left << ", " << right << "\n"; + } else if (op == "-") { + irStream << " " << temp << " = sub nsw " << type << " " << left << ", " << right << "\n"; + } + left = temp; + tmpTable[temp] = type; + } + return left; +} + +std::any LLVMIRGenerator::visitRelExp(SysYParser::RelExpContext* ctx) { + auto addExps = ctx->addExp(); + std::string left = std::any_cast(addExps[0]->accept(this)); + for (size_t i = 1; i < addExps.size(); ++i) { + std::string right = std::any_cast(addExps[i]->accept(this)); + std::string op = ctx->children[2*i-1]->getText(); + std::string temp = getNextTemp(); + std::string type = tmpTable[left]; + if (op == "<") { + irStream << " " << temp << " = icmp slt " << type << " " << left << ", " << right << "\n"; + } else if (op == ">") { + irStream << " " << temp << " = icmp sgt " << type << " " << left << ", " << right << "\n"; + } else if (op == "<=") { + irStream << " " << temp << " = icmp sle " << type << " " << left << ", " << right << "\n"; + } else if (op == ">=") { + irStream << " " << temp << " = icmp sge " << type << " " << left << ", " << right << "\n"; + } + left = temp; + } + return left; +} + +std::any LLVMIRGenerator::visitEqExp(SysYParser::EqExpContext* ctx) { + auto relExps = ctx->relExp(); + std::string left = std::any_cast(relExps[0]->accept(this)); + for (size_t i = 1; i < relExps.size(); ++i) { + std::string right = std::any_cast(relExps[i]->accept(this)); + std::string op = ctx->children[2*i-1]->getText(); + std::string temp = getNextTemp(); + std::string type = tmpTable[left]; + if (op == "==") { + irStream << " " << temp << " = icmp eq " << type << " " << left << ", " << right << "\n"; + } else if (op == "!=") { + irStream << " " << temp << " = icmp ne " << type << " " << left << ", " << right << "\n"; + } + left = temp; + } + return left; +} + +std::any LLVMIRGenerator::visitLAndExp(SysYParser::LAndExpContext* ctx) { + auto eqExps = ctx->eqExp(); + std::string left = std::any_cast(eqExps[0]->accept(this)); + for (size_t i = 1; i < eqExps.size(); ++i) { + std::string falseLabel = "land.false." + std::to_string(tempCounter); + std::string endLabel = "land.end." + std::to_string(tempCounter++); + std::string temp = getNextTemp(); + + irStream << " br label %" << falseLabel << "\n"; + irStream << falseLabel << ":\n"; + std::string right = std::any_cast(eqExps[i]->accept(this)); + irStream << " " << temp << " = and i1 " << left << ", " << right << "\n"; + irStream << " br label %" << endLabel << "\n"; + irStream << endLabel << ":\n"; + left = temp; + } + return left; +} + +std::any LLVMIRGenerator::visitLOrExp(SysYParser::LOrExpContext* ctx) { + auto lAndExps = ctx->lAndExp(); + std::string left = std::any_cast(lAndExps[0]->accept(this)); + for (size_t i = 1; i < lAndExps.size(); ++i) { + std::string trueLabel = "lor.true." + std::to_string(tempCounter); + std::string endLabel = "lor.end." + std::to_string(tempCounter++); + std::string temp = getNextTemp(); + + irStream << " br label %" << trueLabel << "\n"; + irStream << trueLabel << ":\n"; + std::string right = std::any_cast(lAndExps[i]->accept(this)); + irStream << " " << temp << " = or i1 " << left << ", " << right << "\n"; + irStream << " br label %" << endLabel << "\n"; + irStream << endLabel << ":\n"; + left = temp; + } + return left; +} +} \ No newline at end of file diff --git a/src/LLVMIRGenerator_1.cpp b/src/LLVMIRGenerator_1.cpp new file mode 100644 index 0000000..515b5a2 --- /dev/null +++ b/src/LLVMIRGenerator_1.cpp @@ -0,0 +1,859 @@ +// LLVMIRGenerator.cpp +// TODO:类型转换及其检查 +// TODO:sysy库函数处理 +// TODO:数组处理 +// TODO:对while、continue、break的测试 +#include "LLVMIRGenerator_1.h" +#include +#include +#include + +// namespace sysy { + +std::string LLVMIRGenerator::generateIR(SysYParser::CompUnitContext* unit) { + // 初始化 SysY IR 模块 + module = std::make_unique(); + // 清空符号表和临时变量表 + symbolTable.clear(); + tmpTable.clear(); + irSymbolTable.clear(); + irTmpTable.clear(); + tempCounter = 0; + globalVars.clear(); + hasReturn = false; + loopStack = std::stack(); + inFunction = false; + + // 访问编译单元 + visitCompUnit(unit); + return irStream.str(); +} + +std::string LLVMIRGenerator::getNextTemp() { + std::string ret = "%." + std::to_string(tempCounter++); + tmpTable[ret] = "void"; + return ret; +} + +std::string LLVMIRGenerator::getIRTempName() { + return "%" + std::to_string(tempCounter++); +} + +std::string LLVMIRGenerator::getLLVMType(const std::string& type) { + if (type == "int") return "i32"; + if (type == "float") return "float"; + if (type.find("[]") != std::string::npos) + return getLLVMType(type.substr(0, type.size() - 2)) + "*"; + return "i32"; +} + +sysy::Type* LLVMIRGenerator::getIRType(const std::string& type) { + if (type == "int") return sysy::Type::getIntType(); + if (type == "float") return sysy::Type::getFloatType(); + if (type == "void") return sysy::Type::getVoidType(); + if (type.find("[]") != std::string::npos) { + std::string baseType = type.substr(0, type.size() - 2); + return sysy::Type::getPointerType(getIRType(baseType)); + } + return sysy::Type::getIntType(); // 默认 int +} + +void LLVMIRGenerator::setIRPosition(sysy::BasicBlock* block) { + currentIRBlock = block; +} + +std::any LLVMIRGenerator::visitCompUnit(SysYParser::CompUnitContext* ctx) { + for (auto decl : ctx->decl()) { + decl->accept(this); + } + for (auto funcDef : ctx->funcDef()) { + inFunction = true; + funcDef->accept(this); + inFunction = false; + } + return nullptr; +} + + +std::any LLVMIRGenerator::visitVarDecl(SysYParser::VarDeclContext* ctx) { + // TODO:数组初始化 + std::string type = ctx->bType()->getText(); + currentVarType = getLLVMType(type); + sysy::Type* irType = sysy::Type::getPointerType(getIRType(type)); + + for (auto varDef : ctx->varDef()) { + if (!inFunction) { + // 全局变量(文本 IR) + std::string varName = varDef->Ident()->getText(); + std::string llvmType = getLLVMType(type); + std::string value = "0"; + sysy::Value* initValue = nullptr; + + if (varDef->ASSIGN()) { + value = std::any_cast(varDef->initVal()->accept(this)); + if (irTmpTable.find(value) != irTmpTable.end() && isa(irTmpTable[value])) { + initValue = irTmpTable[value]; + } + } + + if (llvmType == "float" && initValue) { + try { + double floatValue = std::stod(value); + uint64_t hexValue = reinterpret_cast(floatValue); + std::stringstream ss; + ss << "0x" << std::hex << std::uppercase << hexValue; + value = ss.str(); + } catch (...) { + throw std::runtime_error("[ERR-Release-02]Invalid float literal: " + value); + } + } + irStream << "@" << varName << " = dso_local global " << llvmType << " " << value << ", align 4\n"; + globalVars.push_back(varName); + + // 全局变量(SysY IR) + auto globalValue = module->createGlobalValue(varName, irType, {}, initValue); + irSymbolTable[varName] = globalValue; + } else { + varDef->accept(this); + } + } + return nullptr; +} + +std::any LLVMIRGenerator::visitConstDecl(SysYParser::ConstDeclContext* ctx) { + // TODO:数组初始化 + std::string type = ctx->bType()->getText(); + currentVarType = getLLVMType(type); + sysy::Type* irType = sysy::Type::getPointerType(getIRType(type)); // 全局变量为指针类型 + + for (auto constDef : ctx->constDef()) { + std::string varName = constDef->Ident()->getText(); + std::string llvmType = getLLVMType(type); + std::string value = "0"; + sysy::Value* initValue = nullptr; + + try { + value = std::any_cast(constDef->constInitVal()->accept(this)); + if (isa(irTmpTable[value])) { + initValue = irTmpTable[value]; + } + } catch (...) { + throw std::runtime_error("Const value must be initialized upon definition."); + } + + if (!inFunction) { + // 全局常量(文本 IR) + if (llvmType == "float") { + try { + double floatValue = std::stod(value); + uint64_t hexValue = reinterpret_cast(floatValue); + std::stringstream ss; + ss << "0x" << std::hex << std::uppercase << hexValue; + value = ss.str(); + } catch (...) { + throw std::runtime_error("[ERR-Release-03]Invalid float literal: " + value); + } + } + irStream << "@" << varName << " = dso_local constant " << llvmType << " " << value << ", align 4\n"; + globalVars.push_back(varName); + + // 全局常量(SysY IR) + auto globalValue = module->createGlobalValue(varName, irType, {}, initValue); + irSymbolTable[varName] = globalValue; + } else { + // 局部常量(文本 IR) + std::string allocaName = getNextTemp(); + if (llvmType == "float") { + try { + double floatValue = std::stod(value); + uint64_t hexValue = reinterpret_cast(floatValue); + std::stringstream ss; + ss << "0x" << std::hex << std::uppercase << hexValue; + value = ss.str(); + } catch (...) { + throw std::runtime_error("Invalid float literal: " + value); + } + } + irStream << " " << allocaName << " = alloca " << llvmType << ", align 4\n"; + irStream << " store " << llvmType << " " << value << ", " << llvmType + << "* " << allocaName << ", align 4\n"; + symbolTable[varName] = {allocaName, llvmType}; + tmpTable[allocaName] = llvmType; + + // 局部常量(SysY IR)TODO:这里可能有bug,AI在犯蠢 + sysy::IRBuilder builder(currentIRBlock); + auto allocaInst = builder.createAllocaInst(irType, {}, varName); + builder.createStoreInst(initValue, allocaInst); + irSymbolTable[varName] = allocaInst; + irTmpTable[allocaName] = allocaInst; + } + } + return nullptr; +} + +std::any LLVMIRGenerator::visitVarDef(SysYParser::VarDefContext* ctx) { + // TODO:数组初始化 + std::string varName = ctx->Ident()->getText(); + std::string llvmType = currentVarType; + sysy::Type* irType = sysy::Type::getPointerType(getIRType(currentVarType == "i32" ? "int" : "float")); + std::string allocaName = getNextTemp(); + + // 局部变量(文本 IR) + irStream << " " << allocaName << " = alloca " << llvmType << ", align 4\n"; + + // 局部变量(SysY IR) + sysy::IRBuilder builder(currentIRBlock); + auto allocaInst = builder.createAllocaInst(irType, {}, varName); + sysy::Value* initValue = nullptr; + + if (ctx->ASSIGN()) { + std::string value = std::any_cast(ctx->initVal()->accept(this)); + if (llvmType == "float") { + try { + double floatValue = std::stod(value); + uint64_t hexValue = reinterpret_cast(floatValue); + std::stringstream ss; + ss << "0x" << std::hex << std::uppercase << (hexValue & (0xffffffffUL << 32)); + value = ss.str(); + } catch (...) { + throw std::runtime_error("Invalid float literal: " + value); + } + } + irStream << " store " << llvmType << " " << value << ", " << llvmType + << "* " << allocaName << ", align 4\n"; + + if (irTmpTable.find(value) != irTmpTable.end()) { + initValue = irTmpTable[value]; + } + builder.createStoreInst(initValue, allocaInst); + } + + symbolTable[varName] = {allocaName, llvmType}; + tmpTable[allocaName] = llvmType; + irSymbolTable[varName] = allocaInst;//TODO:这里没看懂在干嘛 + irTmpTable[allocaName] = allocaInst;//TODO:这里没看懂在干嘛 + builder.createStoreInst(initValue, allocaInst);//TODO:这里没看懂在干嘛 + return nullptr; +} + +std::any LLVMIRGenerator::visitFuncDef(SysYParser::FuncDefContext* ctx) { + currentFunction = ctx->Ident()->getText(); + currentReturnType = getLLVMType(ctx->funcType()->getText()); + sysy::Type* irReturnType = getIRType(ctx->funcType()->getText()); + std::vector paramTypes; + + // 清空符号表 + symbolTable.clear(); + tmpTable.clear(); + irSymbolTable.clear(); + irTmpTable.clear(); + tempCounter = 0; + hasReturn = false; + + // 处理函数参数(文本 IR 和 SysY IR) + if (ctx->funcFParams()) { + auto params = ctx->funcFParams()->funcFParam(); + for (size_t i = 0; i < params.size(); ++i) { + std::string paramType = getLLVMType(params[i]->bType()->getText()); + if (i > 0) irStream << ", "; + irStream << paramType << " noundef %" << i; + symbolTable[params[i]->Ident()->getText()] = {"%" + std::to_string(i), paramType}; + tmpTable["%" + std::to_string(i)] = paramType; + paramTypes.push_back(getIRType(params[i]->bType()->getText())); + } + tempCounter += params.size(); + } + tempCounter++; + + // 文本 IR 函数定义 + irStream << "define dso_local " << currentReturnType << " @" << currentFunction << "("; + irStream << ") #0 {\n"; + + // SysY IR 函数定义 + sysy::Type* funcType = sysy::Type::getFunctionType(irReturnType, paramTypes); + currentIRFunction = module->createFunction(currentFunction, funcType); + setIRPosition(currentIRFunction->getEntryBlock()); + + // 处理函数参数分配 + if (ctx->funcFParams()) { + auto params = ctx->funcFParams()->funcFParam(); + for (size_t i = 0; i < params.size(); ++i) { + std::string varName = params[i]->Ident()->getText(); + std::string llvmType = getLLVMType(params[i]->bType()->getText()); + sysy::Type* irType = getIRType(params[i]->bType()->getText()); + std::string allocaName = getNextTemp(); + tmpTable[allocaName] = llvmType; + + // 文本 IR 分配 + irStream << " " << allocaName << " = alloca " << llvmType << ", align 4\n"; + irStream << " store " << llvmType << " %" << i << ", " << llvmType + << "* " << allocaName << ", align 4\n"; + + // SysY IR 分配 + sysy::IRBuilder builder(currentIRBlock); + auto arg = currentIRBlock->createArgument(irType, varName); + auto allocaInst = builder.createAllocaInst(sysy::Type::getPointerType(irType), {}, varName); + builder.createStoreInst(arg, allocaInst); + symbolTable[varName] = {allocaName, llvmType}; + irSymbolTable[varName] = allocaInst; + irTmpTable[allocaName] = allocaInst; + } + } + + ctx->blockStmt()->accept(this); + + if (!hasReturn) { + if (currentReturnType == "void") { + irStream << " ret void\n"; + sysy::IRBuilder builder(currentIRBlock); + builder.createReturnInst(); + } else { + irStream << " ret " << currentReturnType << " 0\n"; + sysy::IRBuilder builder(currentIRBlock); + builder.createReturnInst(sysy::ConstantValue::get(0)); + } + } + irStream << "}\n"; + currentIRFunction = nullptr; + currentIRBlock = nullptr; + return nullptr; +} + +std::any LLVMIRGenerator::visitBlockStmt(SysYParser::BlockStmtContext* ctx) { + for (auto item : ctx->blockItem()) { + item->accept(this); + } + return nullptr; +} + +std::any LLVMIRGenerator::visitAssignStmt(SysYParser::AssignStmtContext* ctx) { + std::string lhsAlloca = std::any_cast(ctx->lValue()->accept(this)); + std::string lhsType = symbolTable[ctx->lValue()->Ident()->getText()].second; + std::string rhs = std::any_cast(ctx->exp()->accept(this)); + sysy::Value* rhsValue = irTmpTable[rhs]; + + // 文本 IR + if (lhsType == "float") { + try { + double floatValue = std::stod(rhs); + uint64_t hexValue = reinterpret_cast(floatValue); + std::stringstream ss; + ss << "0x" << std::hex << std::uppercase << (hexValue & (0xffffffffUL << 32)); + rhs = ss.str(); + } catch (...) { + // 如果 rhs 不是字面量,假设已正确处理 + throw std::runtime_error("Invalid float literal: " + rhs); + } + } + irStream << " store " << lhsType << " " << rhs << ", " << lhsType + << "* " << lhsAlloca << ", align 4\n"; + + // SysY IR + sysy::IRBuilder builder(currentIRBlock); + builder.createStoreInst(rhsValue, irSymbolTable[ctx->lValue()->Ident()->getText()]); + return nullptr; +} + +std::any LLVMIRGenerator::visitIfStmt(SysYParser::IfStmtContext* ctx) { + std::string cond = std::any_cast(ctx->cond()->accept(this)); + sysy::Value* condValue = irTmpTable[cond]; + std::string trueLabel = "if.then." + std::to_string(tempCounter); + std::string falseLabel = "if.else." + std::to_string(tempCounter); + std::string mergeLabel = "if.end." + std::to_string(tempCounter++); + + // SysY IR 基本块 + sysy::BasicBlock* thenBlock = currentIRFunction->addBasicBlock(trueLabel); + sysy::BasicBlock* elseBlock = ctx->ELSE() ? currentIRFunction->addBasicBlock(falseLabel) : nullptr; + sysy::BasicBlock* mergeBlock = currentIRFunction->addBasicBlock(mergeLabel); + + // 文本 IR + irStream << " br i1 " << cond << ", label %" << trueLabel << ", label %" + << (ctx->ELSE() ? falseLabel : mergeLabel) << "\n"; + + // SysY IR 条件分支 + sysy::IRBuilder builder(currentIRBlock); + builder.createCondBrInst(condValue, thenBlock, ctx->ELSE() ? elseBlock : mergeBlock, {}, {}); + + // 处理 then 分支 + setIRPosition(thenBlock); + irStream << trueLabel << ":\n"; + ctx->stmt(0)->accept(this); + irStream << " br label %" << mergeLabel << "\n"; + builder.setPosition(thenBlock, thenBlock->end()); + builder.createUncondBrInst(mergeBlock, {}); + + // 处理 else 分支 + if (ctx->ELSE()) { + setIRPosition(elseBlock); + irStream << falseLabel << ":\n"; + ctx->stmt(1)->accept(this); + irStream << " br label %" << mergeLabel << "\n"; + builder.setPosition(elseBlock, elseBlock->end()); + builder.createUncondBrInst(mergeBlock, {}); + } + + // 合并点 + setIRPosition(mergeBlock); + irStream << mergeLabel << ":\n"; + return nullptr; +} + +std::any LLVMIRGenerator::visitWhileStmt(SysYParser::WhileStmtContext* ctx) { + std::string loopCond = "while.cond." + std::to_string(tempCounter); + std::string loopBody = "while.body." + std::to_string(tempCounter); + std::string loopEnd = "while.end." + std::to_string(tempCounter++); + + // SysY IR 基本块 + sysy::BasicBlock* condBlock = currentIRFunction->addBasicBlock(loopCond); + sysy::BasicBlock* bodyBlock = currentIRFunction->addBasicBlock(loopBody); + sysy::BasicBlock* endBlock = currentIRFunction->addBasicBlock(loopEnd); + + loopStack.push({loopEnd, loopCond, endBlock, condBlock}); + + // 跳转到条件块 + sysy::IRBuilder builder(currentIRBlock); + builder.createUncondBrInst(condBlock, {}); + irStream << " br label %" << loopCond << "\n"; + + // 条件块 + setIRPosition(condBlock); + irStream << loopCond << ":\n"; + std::string cond = std::any_cast(ctx->cond()->accept(this)); + sysy::Value* condValue = irTmpTable[cond]; + irStream << " br i1 " << cond << ", label %" << loopBody << ", label %" << loopEnd << "\n"; + builder.setPosition(condBlock, condBlock->end()); + builder.createCondBrInst(condValue, bodyBlock, endBlock, {}, {}); + + // 循环体 + setIRPosition(bodyBlock); + irStream << loopBody << ":\n"; + ctx->stmt()->accept(this); + irStream << " br label %" << loopCond << "\n"; + builder.setPosition(bodyBlock, bodyBlock->end()); + builder.createUncondBrInst(condBlock, {}); + + // 结束块 + setIRPosition(endBlock); + irStream << loopEnd << ":\n"; + loopStack.pop(); + return nullptr; +} + +std::any LLVMIRGenerator::visitBreakStmt(SysYParser::BreakStmtContext* ctx) { + if (loopStack.empty()) { + throw std::runtime_error("Break statement outside of a loop."); + } + irStream << " br label %" << loopStack.top().breakLabel << "\n"; + sysy::IRBuilder builder(currentIRBlock); + builder.createUncondBrInst(loopStack.top().irBreakBlock, {}); + return nullptr; +} + +std::any LLVMIRGenerator::visitContinueStmt(SysYParser::ContinueStmtContext* ctx) { + if (loopStack.empty()) { + throw std::runtime_error("Continue statement outside of a loop."); + } + irStream << " br label %" << loopStack.top().continueLabel << "\n"; + sysy::IRBuilder builder(currentIRBlock); + builder.createUncondBrInst(loopStack.top().irContinueBlock, {}); + return nullptr; +} + +std::any LLVMIRGenerator::visitReturnStmt(SysYParser::ReturnStmtContext* ctx) { + hasReturn = true; + sysy::IRBuilder builder(currentIRBlock); + if (ctx->exp()) { + std::string value = std::any_cast(ctx->exp()->accept(this)); + sysy::Value* irValue = irTmpTable[value]; + irStream << " ret " << currentReturnType << " " << value << "\n"; + builder.createReturnInst(irValue); + } else { + irStream << " ret void\n"; + builder.createReturnInst(); + } + return nullptr; +} + +std::any LLVMIRGenerator::visitLValue(SysYParser::LValueContext* ctx) { + std::string varName = ctx->Ident()->getText(); + if (irSymbolTable.find(varName) == irSymbolTable.end()) { + throw std::runtime_error("Undefined variable: " + varName); + } + // 对于 LValue,返回分配的指针(文本 IR 和 SysY IR 一致) + return symbolTable[varName].first; +} + +std::any LLVMIRGenerator::visitPrimExp(SysYParser::PrimExpContext* ctx) { + SysYParser::PrimaryExpContext* pExpCtx = ctx->primaryExp(); + if (auto* lvalCtx = dynamic_cast(pExpCtx)) { + std::string allocaPtr = std::any_cast(lvalCtx->lValue()->accept(this)); + std::string varName = lvalCtx->lValue()->Ident()->getText(); + std::string type = symbolTable[varName].second; + std::string temp = getNextTemp(); + sysy::Type* irType = getIRType(type == "i32" ? "int" : "float"); + + // 文本 IR + irStream << " " << temp << " = load " << type << ", " << type << "* " << allocaPtr << ", align 4\n"; + tmpTable[temp] = type; + + // SysY IR + sysy::IRBuilder builder(currentIRBlock); + auto loadInst = builder.createLoadInst(irSymbolTable[varName], {}); + irTmpTable[temp] = loadInst; + return temp; + } else if (auto* expCtx = dynamic_cast(pExpCtx)) { + return expCtx->exp()->accept(this); + } else if (auto* strCtx = dynamic_cast(pExpCtx)) { + return strCtx->string()->accept(this); + } else if (auto* numCtx = dynamic_cast(pExpCtx)) { + return numCtx->number()->accept(this); + } else { + // 没有成功转换,说明 ctx->primaryExp() 不是 NumContext 或其他已知类型 + // 可能是其他类型的表达式,或者是一个空的 PrimaryExpContext + std::cout << "Unknown primary expression type." << std::endl; + throw std::runtime_error("Unknown primary expression type."); + } +} + +std::any LLVMIRGenerator::visitParenExp(SysYParser::ParenExpContext* ctx) { + return ctx->exp()->accept(this); +} + +std::any LLVMIRGenerator::visitNumber(SysYParser::NumberContext* ctx) { + std::string value; + sysy::Value* irValue = nullptr; + if (ctx->ILITERAL()) { + value = ctx->ILITERAL()->getText(); + irValue = sysy::ConstantValue::get(std::stoi(value)); + } else if (ctx->FLITERAL()) { + value = ctx->FLITERAL()->getText(); + irValue = sysy::ConstantValue::get(std::stof(value)); + } else { + value = ""; + } + std::string temp = getNextTemp(); + tmpTable[temp] = ctx->ILITERAL() ? "i32" : "float"; + irTmpTable[temp] = irValue; + return value; +} + +std::any LLVMIRGenerator::visitString(SysYParser::StringContext* ctx) { + if (ctx->STRING()) { + std::string str = ctx->STRING()->getText(); + str = str.substr(1, str.size() - 2); + std::string escapedStr; + for (char c : str) { + if (c == '\\') { + escapedStr += "\\\\"; + } else if (c == '"') { + escapedStr += "\\\""; + } else { + escapedStr += c; + } + } + // TODO: SysY IR 暂不支持字符串常量,返回文本 IR 结果 + return "\"" + escapedStr + "\""; + } + return ctx->STRING()->getText(); +} + + + +std::any LLVMIRGenerator::visitUnExp(SysYParser::UnExpContext* ctx) { + if (ctx->unaryOp()) { + std::string operand = std::any_cast(ctx->unaryExp()->accept(this)); + sysy::Value* irOperand = irTmpTable[operand]; + std::string op = ctx->unaryOp()->getText(); + std::string temp = getNextTemp(); + std::string type = tmpTable[operand]; + sysy::Type* irType = getIRType(type == "i32" ? "int" : "float"); + tmpTable[temp] = type; + + // 文本 IR + if (op == "-") { + irStream << " " << temp << " = sub " << type << " 0, " << operand << "\n"; + } else if (op == "!") { + irStream << " " << temp << " = xor " << type << " " << operand << ", 1\n"; + } + + // SysY IR + sysy::IRBuilder builder(currentIRBlock); + sysy::Instruction::Kind kind = (op == "-") ? (type == "i32" ? sysy::Instruction::kNeg : sysy::Instruction::kFNeg) + : sysy::Instruction::kNot; + auto unaryInst = builder.createUnaryInst(kind, irType, irOperand, temp); + irTmpTable[temp] = unaryInst; + return temp; + } + return ctx->unaryExp()->accept(this); +} + +std::any LLVMIRGenerator::visitCall(SysYParser::CallContext* ctx) { + std::string funcName = ctx->Ident()->getText(); + std::vector args; + std::vector irArgs; + if (ctx->funcRParams()) { + for (auto argCtx : ctx->funcRParams()->exp()) { + std::string arg = std::any_cast(argCtx->accept(this)); + args.push_back(arg); + irArgs.push_back(irTmpTable[arg]); + } + } + std::string temp = getNextTemp(); + std::string argList; + for (size_t i = 0; i < args.size(); ++i) { + if (i > 0) argList += ", "; + argList += tmpTable[args[i]] + " noundef " + args[i]; + } + + // 文本 IR + irStream << " " << temp << " = call " << currentReturnType << " @" << funcName << "(" << argList << ")\n"; + tmpTable[temp] = currentReturnType; + + // SysY IR + sysy::IRBuilder builder(currentIRBlock); + sysy::Function* callee = module->getFunction(funcName); + if (!callee) { + throw std::runtime_error("Undefined function: " + funcName); + } + auto callInst = builder.createCallInst(callee, irArgs, temp); + irTmpTable[temp] = callInst; + return temp; +} + +std::any LLVMIRGenerator::visitMulExp(SysYParser::MulExpContext* ctx) { + auto unaryExps = ctx->unaryExp(); + std::string left = std::any_cast(unaryExps[0]->accept(this)); + sysy::Value* irLeft = irTmpTable[left]; + sysy::Type* irType = irLeft->getType(); + + for (size_t i = 1; i < unaryExps.size(); ++i) { + std::string right = std::any_cast(unaryExps[i]->accept(this)); + sysy::Value* irRight = irTmpTable[right]; + std::string op = ctx->children[2 * i - 1]->getText(); + std::string temp = getNextTemp(); + std::string type = tmpTable[left]; + tmpTable[temp] = type; + + // 文本 IR + if (op == "*") { + irStream << " " << temp << " = mul nsw " << type << " " << left << ", " << right << "\n"; + } else if (op == "/") { + irStream << " " << temp << " = sdiv " << type << " " << left << ", " << right << "\n"; + } else if (op == "%") { + irStream << " " << temp << " = srem " << type << " " << left << ", " << right << "\n"; + } + + // SysY IR + sysy::IRBuilder builder(currentIRBlock); + sysy::Instruction::Kind kind; + if (type == "i32") { + if (op == "*") kind = sysy::Instruction::kMul; + else if (op == "/") kind = sysy::Instruction::kDiv; + else kind = sysy::Instruction::kRem; + } else { + if (op == "*") kind = sysy::Instruction::kFMul; + else if (op == "/") kind = sysy::Instruction::kFDiv; + else kind = sysy::Instruction::kFRem; + } + auto binaryInst = builder.createBinaryInst(kind, irType, irLeft, irRight, temp); + irTmpTable[temp] = binaryInst; + left = temp; + irLeft = binaryInst; + } + return left; +} + +std::any LLVMIRGenerator::visitAddExp(SysYParser::AddExpContext* ctx) { + auto mulExps = ctx->mulExp(); + std::string left = std::any_cast(mulExps[0]->accept(this)); + sysy::Value* irLeft = irTmpTable[left]; + sysy::Type* irType = irLeft->getType(); + + for (size_t i = 1; i < mulExps.size(); ++i) { + std::string right = std::any_cast(mulExps[i]->accept(this)); + sysy::Value* irRight = irTmpTable[right]; + std::string op = ctx->children[2 * i - 1]->getText(); + std::string temp = getNextTemp(); + std::string type = tmpTable[left]; + tmpTable[temp] = type; + + // 文本 IR + if (op == "+") { + irStream << " " << temp << " = add nsw " << type << " " << left << ", " << right << "\n"; + } else if (op == "-") { + irStream << " " << temp << " = sub nsw " << type << " " << left << ", " << right << "\n"; + } + + // SysY IR + sysy::IRBuilder builder(currentIRBlock); + sysy::Instruction::Kind kind = (type == "i32") ? (op == "+" ? sysy::Instruction::kAdd : sysy::Instruction::kSub) + : (op == "+" ? sysy::Instruction::kFAdd : sysy::Instruction::kFSub); + auto binaryInst = builder.createBinaryInst(kind, irType, irLeft, irRight, temp); + irTmpTable[temp] = binaryInst; + left = temp; + irLeft = binaryInst; + } + return left; +} + +std::any LLVMIRGenerator::visitRelExp(SysYParser::RelExpContext* ctx) { + auto addExps = ctx->addExp(); + std::string left = std::any_cast(addExps[0]->accept(this)); + sysy::Value* irLeft = irTmpTable[left]; + sysy::Type* irType = sysy::Type::getIntType(); // 比较结果为 i1 + + for (size_t i = 1; i < addExps.size(); ++i) { + std::string right = std::any_cast(addExps[i]->accept(this)); + sysy::Value* irRight = irTmpTable[right]; + std::string op = ctx->children[2 * i - 1]->getText(); + std::string temp = getNextTemp(); + std::string type = tmpTable[left]; + tmpTable[temp] = "i1"; + + // 文本 IR + if (op == "<") { + irStream << " " << temp << " = icmp slt " << type << " " << left << ", " << right << "\n"; + } else if (op == ">") { + irStream << " " << temp << " = icmp sgt " << type << " " << left << ", " << right << "\n"; + } else if (op == "<=") { + irStream << " " << temp << " = icmp sle " << type << " " << left << ", " << right << "\n"; + } else if (op == ">=") { + irStream << " " << temp << " = icmp sge " << type << " " << left << ", " << right << "\n"; + } + + // SysY IR + sysy::IRBuilder builder(currentIRBlock); + sysy::Instruction::Kind kind; + if (type == "i32") { + if (op == "<") kind = sysy::Instruction::kICmpLT; + else if (op == ">") kind = sysy::Instruction::kICmpGT; + else if (op == "<=") kind = sysy::Instruction::kICmpLE; + else kind = sysy::Instruction::kICmpGE; + } else { + if (op == "<") kind = sysy::Instruction::kFCmpLT; + else if (op == ">") kind = sysy::Instruction::kFCmpGT; + else if (op == "<=") kind = sysy::Instruction::kFCmpLE; + else kind = sysy::Instruction::kFCmpGE; + } + auto cmpInst = builder.createBinaryInst(kind, irType, irLeft, irRight, temp); + irTmpTable[temp] = cmpInst; + left = temp; + irLeft = cmpInst; + } + return left; +} + +std::any LLVMIRGenerator::visitEqExp(SysYParser::EqExpContext* ctx) { + auto relExps = ctx->relExp(); + std::string left = std::any_cast(relExps[0]->accept(this)); + sysy::Value* irLeft = irTmpTable[left]; + sysy::Type* irType = sysy::Type::getIntType(); // 比较结果为 i1 + + for (size_t i = 1; i < relExps.size(); ++i) { + std::string right = std::any_cast(relExps[i]->accept(this)); + sysy::Value* irRight = irTmpTable[right]; + std::string op = ctx->children[2 * i - 1]->getText(); + std::string temp = getNextTemp(); + std::string type = tmpTable[left]; + tmpTable[temp] = "i1"; + + // 文本 IR + if (op == "==") { + irStream << " " << temp << " = icmp eq " << type << " " << left << ", " << right << "\n"; + } else if (op == "!=") { + irStream << " " << temp << " = icmp ne " << type << " " << left << ", " << right << "\n"; + } + + // SysY IR + sysy::IRBuilder builder(currentIRBlock); + sysy::Instruction::Kind kind = (type == "i32") ? (op == "==" ? sysy::Instruction::kICmpEQ : sysy::Instruction::kICmpNE) + : (op == "==" ? sysy::Instruction::kFCmpEQ : sysy::Instruction::kFCmpNE); + auto cmpInst = builder.createBinaryInst(kind, irType, irLeft, irRight, temp); + irTmpTable[temp] = cmpInst; + left = temp; + irLeft = cmpInst; + } + return left; +} + +std::any LLVMIRGenerator::visitLAndExp(SysYParser::LAndExpContext* ctx) { + auto eqExps = ctx->eqExp(); + std::string left = std::any_cast(eqExps[0]->accept(this)); + sysy::Value* irLeft = irTmpTable[left]; + + for (size_t i = 1; i < eqExps.size(); ++i) { + std::string falseLabel = "land.false." + std::to_string(tempCounter); + std::string endLabel = "land.end." + std::to_string(tempCounter++); + sysy::BasicBlock* falseBlock = currentIRFunction->addBasicBlock(falseLabel); + sysy::BasicBlock* endBlock = currentIRFunction->addBasicBlock(endLabel); + std::string temp = getNextTemp(); + tmpTable[temp] = "i1"; + + // 文本 IR + irStream << " br i1 " << left << ", label %" << falseLabel << ", label %" << endLabel << "\n"; + irStream << falseLabel << ":\n"; + + // SysY IR + sysy::IRBuilder builder(currentIRBlock); + builder.createCondBrInst(irLeft, falseBlock, endBlock, {}, {}); + setIRPosition(falseBlock); + + std::string right = std::any_cast(eqExps[i]->accept(this)); + sysy::Value* irRight = irTmpTable[right]; + irStream << " " << temp << " = and i1 " << left << ", " << right << "\n"; + irStream << " br label %" << endLabel << "\n"; + irStream << endLabel << ":\n"; + + // SysY IR 逻辑与(通过基本块实现短路求值) + builder.setPosition(falseBlock, falseBlock->end()); + auto andInst = builder.createBinaryInst(sysy::Instruction::kICmpEQ, sysy::Type::getIntType(), irLeft, irRight, temp); + builder.createUncondBrInst(endBlock, {}); + irTmpTable[temp] = andInst; + left = temp; + irLeft = andInst; + setIRPosition(endBlock); + } + return left; +} + +std::any LLVMIRGenerator::visitLOrExp(SysYParser::LOrExpContext* ctx) { + auto lAndExps = ctx->lAndExp(); + std::string left = std::any_cast(lAndExps[0]->accept(this)); + sysy::Value* irLeft = irTmpTable[left]; + + for (size_t i = 1; i < lAndExps.size(); ++i) { + std::string trueLabel = "lor.true." + std::to_string(tempCounter); + std::string endLabel = "lor.end." + std::to_string(tempCounter++); + sysy::BasicBlock* trueBlock = currentIRFunction->addBasicBlock(trueLabel); + sysy::BasicBlock* endBlock = currentIRFunction->addBasicBlock(endLabel); + std::string temp = getNextTemp(); + tmpTable[temp] = "i1"; + + // 文本 IR + irStream << " br i1 " << left << ", label %" << trueLabel << ", label %" << endLabel << "\n"; + irStream << trueLabel << ":\n"; + + // SysY IR + sysy::IRBuilder builder(currentIRBlock); + builder.createCondBrInst(irLeft, trueBlock, endBlock, {}, {}); + setIRPosition(trueBlock); + + std::string right = std::any_cast(lAndExps[i]->accept(this)); + sysy::Value* irRight = irTmpTable[right]; + irStream << " " << temp << " = or i1 " << left << ", " << right << "\n"; + irStream << " br label %" << endLabel << "\n"; + irStream << endLabel << ":\n"; + + // SysY IR 逻辑或(通过基本块实现短路求值) + builder.setPosition(trueBlock, trueBlock->end()); + auto orInst = builder.createBinaryInst(sysy::Instruction::kICmpEQ, sysy::Type::getIntType(), irLeft, irRight, temp); + builder.createUncondBrInst(endBlock, {}); + irTmpTable[temp] = orInst; + left = temp; + irLeft = orInst; + setIRPosition(endBlock); + } + return left; +} + +// } // namespace sysy \ No newline at end of file diff --git a/src/SysY.g4 b/src/SysY.g4 index b3ed583..ad74a0c 100644 --- a/src/SysY.g4 +++ b/src/SysY.g4 @@ -101,7 +101,10 @@ BLOCKCOMMENT: '/*' .*? '*/' -> skip; // CompUnit: (CompUnit)? (decl |funcDef); -compUnit: (decl |funcDef)+; +compUnit: (globalDecl |funcDef)+; + +globalDecl: constDecl # globalConstDecl + | varDecl # globalVarDecl; decl: constDecl | varDecl; @@ -111,16 +114,16 @@ bType: INT | FLOAT; constDef: Ident (LBRACK constExp RBRACK)* ASSIGN constInitVal; -constInitVal: constExp - | LBRACE (constInitVal (COMMA constInitVal)*)? RBRACE; +constInitVal: constExp # constScalarInitValue + | LBRACE (constInitVal (COMMA constInitVal)*)? RBRACE # constArrayInitValue; varDecl: bType varDef (COMMA varDef)* SEMICOLON; varDef: Ident (LBRACK constExp RBRACK)* | Ident (LBRACK constExp RBRACK)* ASSIGN initVal; -initVal: exp - | LBRACE (initVal (COMMA initVal)*)? RBRACE; +initVal: exp # scalarInitValue + | LBRACE (initVal (COMMA initVal)*)? RBRACE # arrayInitValue; funcType: VOID | INT | FLOAT; @@ -150,15 +153,16 @@ cond: lOrExp; lValue: Ident (LBRACK exp RBRACK)*; // 为了方便测试 primaryExp 可以是一个string -primaryExp: LPAREN exp RPAREN #parenExp - | lValue #lVal - | number #num - | string #str; +primaryExp: LPAREN exp RPAREN + | lValue + | number + | string; number: ILITERAL | FLITERAL; -unaryExp: primaryExp #primExp - | Ident LPAREN (funcRParams)? RPAREN #call - | unaryOp unaryExp #unExp; +call: Ident LPAREN (funcRParams)? RPAREN; +unaryExp: primaryExp + | call + | unaryOp unaryExp; unaryOp: ADD|SUB|NOT; funcRParams: exp (COMMA exp)*; diff --git a/src/SysYIRAnalyser.cpp b/src/SysYIRAnalyser.cpp new file mode 100644 index 0000000..e69de29 diff --git a/src/SysYIRGenerator.cpp b/src/SysYIRGenerator.cpp index a2962c7..1718aba 100644 --- a/src/SysYIRGenerator.cpp +++ b/src/SysYIRGenerator.cpp @@ -6,15 +6,20 @@ #include "IR.h" #include #include +#include +#include +#include +#include using namespace std; #include "SysYIRGenerator.h" + namespace sysy { /* * @brief: visit compUnit * @details: - * compUnit: (decl | funcDef)+; + * compUnit: (globalDecl | funcDef)+; */ std::any SysYIRGenerator::visitCompUnit(SysYParser::CompUnitContext *ctx) { // create the IR module @@ -22,1016 +27,1153 @@ std::any SysYIRGenerator::visitCompUnit(SysYParser::CompUnitContext *ctx) { assert(pModule); module.reset(pModule); - SymbolTable::ModuleScope scope(symbols_table); + // SymbolTable::ModuleScope scope(symbols_table); - // 待添加运行时库函数getint等 - // generates globals and functions - auto type_i32 = Type::getIntType(); - auto type_f32 = Type::getFloatType(); - auto type_void = Type::getVoidType(); - auto type_i32p = Type::getPointerType(type_i32); - auto type_f32p = Type::getPointerType(type_f32); - - //runtime library - module->createFunction("getint", Type::getFunctionType(type_i32, {})); - module->createFunction("getch", Type::getFunctionType(type_i32, {})); - module->createFunction("getfloat", Type::getFunctionType(type_f32, {})); - symbols_table.insert("getint", module->getFunction("getint")); - symbols_table.insert("getch", module->getFunction("getch")); - symbols_table.insert("getfloat", module->getFunction("getfloat")); - - module->createFunction("getarray", Type::getFunctionType(type_i32, {type_i32p})); - module->createFunction("getfarray", Type::getFunctionType(type_i32, {type_f32p})); - symbols_table.insert("getarray", module->getFunction("getarray")); - symbols_table.insert("getfarray", module->getFunction("getfarray")); + Utils::initExternalFunction(pModule, &builder); - module->createFunction("putint", Type::getFunctionType(type_void, {type_i32})); - module->createFunction("putch", Type::getFunctionType(type_void, {type_i32})); - module->createFunction("putfloat", Type::getFunctionType(type_void, {type_f32})); - symbols_table.insert("putint", module->getFunction("putint")); - symbols_table.insert("putch", module->getFunction("putch")); - symbols_table.insert("putfloat", module->getFunction("putfloat")); - - module->createFunction("putarray", Type::getFunctionType(type_void, {type_i32, type_i32p})); - module->createFunction("putfarray", Type::getFunctionType(type_void, {type_i32, type_f32p})); - symbols_table.insert("putarray", module->getFunction("putarray")); - symbols_table.insert("putfarray", module->getFunction("putfarray")); - - module->createFunction("putf", Type::getFunctionType(type_void, {})); - symbols_table.insert("putf", module->getFunction("putf")); - - module->createFunction("starttime", Type::getFunctionType(type_void, {type_i32})); - module->createFunction("stoptime", Type::getFunctionType(type_void, {type_i32})); - symbols_table.insert("starttime", module->getFunction("starttime")); - symbols_table.insert("stoptime", module->getFunction("stoptime")); - - // visit all decls and funcDefs - for(auto decl:ctx->decl()){ - decl->accept(this); - } - for(auto funcDef:ctx->funcDef()){ - builder = IRBuilder(); - printf("entry funcDef\n"); - funcDef->accept(this); - } - // return the IR module ? + pModule->enterNewScope(); + visitChildren(ctx); + pModule->leaveScope(); return pModule; } -/* - * @brief: visit decl - * @details: - * decl: constDecl | varDecl; - * constDecl: CONST bType constDef (COMMA constDef)* SEMI; - * varDecl: bType varDef (COMMA varDef)* SEMI; - * constDecl and varDecl shares similar syntax structure - * we consider them together? not sure - */ -std::any SysYIRGenerator::visitDecl(SysYParser::DeclContext *ctx) { - if(ctx->constDecl()) - return visitConstDecl(ctx->constDecl()); - else if(ctx->varDecl()) - return visitVarDecl(ctx->varDecl()); - std::cerr << "error unkown decl" << ctx->getText() << std::endl; - return std::any(); -} +std::any SysYIRGenerator::visitGlobalConstDecl(SysYParser::GlobalConstDeclContext *ctx){ + auto constDecl = ctx->constDecl(); + Type* type = std::any_cast(visitBType(constDecl->bType())); + for (const auto &constDef : constDecl->constDef()) { + std::vector dims = {}; + std::string name = constDef->Ident()->getText(); + auto constExps = constDef->constExp(); + if (!constExps.empty()) { + for (const auto &constExp : constExps) { + dims.push_back(std::any_cast(visitConstExp(constExp))); + } + } -/* - * @brief: visit constdecl - * @details: - * constDecl: CONST bType constDef (COMMA constDef)* SEMI; - * constDef: Ident (LBRACK constExp RBRACK)* (ASSIGN constInitVal)?; - */ -std::any SysYIRGenerator::visitConstDecl(SysYParser::ConstDeclContext *ctx) { - cout << "visitconstDecl" << endl; - current_type = any_cast(ctx->bType()->accept(this)); - for(auto constDef:ctx->constDef()){ - constDef->accept(this); + ArrayValueTree* root = std::any_cast(constDef->constInitVal()->accept(this)); + ValueCounter values; + Utils::tree2Array(type, root, dims, dims.size(), values, &builder); + delete root; + // 创建全局常量变量,并更新符号表 + module->createConstVar(name, Type::getPointerType(type), values, dims); } return std::any(); } -/* - * @brief: visit btype - * @details: - * bType: INT | FLOAT; - */ + +std::any SysYIRGenerator::visitGlobalVarDecl(SysYParser::GlobalVarDeclContext *ctx) { + auto varDecl = ctx->varDecl(); + Type* type = std::any_cast(visitBType(varDecl->bType())); + for (const auto &varDef : varDecl->varDef()) { + std::vector dims = {}; + std::string name = varDef->Ident()->getText(); + auto constExps = varDef->constExp(); + if (!constExps.empty()) { + for (const auto &constExp : constExps) { + dims.push_back(std::any_cast(visitConstExp(constExp))); + } + } + + ArrayValueTree* root = std::any_cast(varDef->initVal()->accept(this)); + ValueCounter values; + Utils::tree2Array(type, root, dims, dims.size(), values, &builder); + delete root; + // 创建全局变量,并更新符号表 + module->createGlobalValue(name, Type::getPointerType(type), dims, values); + } + return std::any(); +} + +std::any SysYIRGenerator::visitConstDecl(SysYParser::ConstDeclContext *ctx){ + Type* type = std::any_cast(visitBType(ctx->bType())); + for (const auto constDef : ctx->constDef()) { + std::vector dims = {}; + std::string name = constDef->Ident()->getText(); + auto constExps = constDef->constExp(); + if (!constExps.empty()) { + for (const auto constExp : constExps) { + dims.push_back(std::any_cast(visitConstExp(constExp))); + } + } + + ArrayValueTree* root = std::any_cast(constDef->constInitVal()->accept(this)); + ValueCounter values; + Utils::tree2Array(type, root, dims, dims.size(), values, &builder); + delete root; + + module->createConstVar(name, Type::getPointerType(type), values, dims); + } + return 0; +} + +std::any SysYIRGenerator::visitVarDecl(SysYParser::VarDeclContext *ctx) { + Type* type = std::any_cast(visitBType(ctx->bType())); + for (const auto varDef : ctx->varDef()) { + std::vector dims = {}; + std::string name = varDef->Ident()->getText(); + auto constExps = varDef->constExp(); + if (!constExps.empty()) { + for (const auto &constExp : constExps) { + dims.push_back(std::any_cast(visitConstExp(constExp))); + } + } + + AllocaInst* alloca = + builder.createAllocaInst(Type::getPointerType(type), dims, name); + + if (varDef->initVal() != nullptr) { + ValueCounter values; + // 这里的varDef->initVal()可能是ScalarInitValue或ArrayInitValue + ArrayValueTree* root = std::any_cast(varDef->initVal()->accept(this)); + Utils::tree2Array(type, root, dims, dims.size(), values, &builder); + delete root; + if (dims.empty()) { + builder.createStoreInst(values.getValue(0), alloca); + } else { + // 对于多维数组,使用memset初始化 + // 计算每个维度的大小 + // 这里的values.getNumbers()返回的是每个维度的大小 + // 这里的values.getValues()返回的是每个维度对应的值 + // 例如:对于一个二维数组,values.getNumbers()可能是[3, 4],表示3行4列 + // values.getValues()可能是[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12] + // 对于每个维度,使用memset将对应的值填充到数组中 + // 这里的alloca是一个指向数组的指针 + const std::vector & counterNumbers = values.getNumbers(); + const std::vector & counterValues = values.getValues(); + unsigned begin = 0; + for (size_t i = 0; i < counterNumbers.size(); i++) { + + builder.createMemsetInst( + alloca, ConstantValue::get(static_cast(begin)), + ConstantValue::get(static_cast(counterNumbers[i])), + counterValues[i]); + begin += counterNumbers[i]; + } + } + } + module->addVariable(name, alloca); + } + return std::any(); +} + std::any SysYIRGenerator::visitBType(SysYParser::BTypeContext *ctx) { - if(ctx->INT()) - return Type::getPointerType(Type::getIntType()); - else if(ctx->FLOAT()) - return Type::getPointerType(Type::getFloatType()); - std::cerr << "error: unknown type" << ctx->getText() << std::endl; - return std::any(); + return ctx->INT() != nullptr ? Type::getIntType() : Type::getFloatType(); } -/* - * @brief: visit constDef - * @details: - * constDef: Ident (LBRACK constExp RBRACK)* (ASSIGN constInitVal)?; - * constInitVal: constExp | LBRACE (constInitVal (COMMA constInitVal)*)? RBRACE; - */ -std::any SysYIRGenerator::visitConstDef(SysYParser::ConstDefContext *ctx){ - auto name = ctx->Ident()->getText(); - Type* type = current_type; - Value* init = ctx->constInitVal() ? any_cast(ctx->constInitVal()->accept(this)) : (Value *)nullptr; - - if (ctx->constExp().empty()){ - //scalar - if(init){ - if(symbols_table.isModuleScope()){ - assert(init->isConstant() && "global must be initialized by constant"); - Value* global = module->createGlobalValue(name, type, {}, init); - symbols_table.insert(name, global); - cout << "add module const " << name ; - if(init){ - cout << " inited by " ; - init->print(cout); - } - cout << '\n'; - } - else{ - Value* alloca = builder.createAllocaInst(type, {}, name); - Value* store = builder.createStoreInst(init, alloca); - symbols_table.insert(name, alloca); - cout << "add local const " << name ; - if(init){ - cout << " inited by " ; - init->print(cout); - } - cout << '\n'; - } - } - else{ - assert(false && "const without initialization"); - } - } - else{ - //array - std::cerr << "array constDef not implemented yet" << std::endl; - } - printf("visitConstDef %s\n",name.c_str()); - return std::any(); +std::any SysYIRGenerator::visitScalarInitValue(SysYParser::ScalarInitValueContext *ctx) { + Value* value = std::any_cast(visitExp(ctx->exp())); + ArrayValueTree* result = new ArrayValueTree(); + result->setValue(value); + return result; } - -/* - * @brief: visit constInitVal - * @details: - * constInitVal: constExp - * | LBRACE (constInitVal (COMMA constInitVal)*)? RBRACE; - */ -std::any SysYIRGenerator::visitConstInitVal(SysYParser::ConstInitValContext *ctx){ - Value* initvalue; - if(ctx->constExp()) - initvalue = any_cast(ctx->constExp()->accept(this)); - else{ - //还未实现数组初始化等功能待验证 - std::cerr << "array initvalue not implemented yet" << std::endl; - // auto numConstInitVals = ctx->constInitVal().size(); - // vector initvalues; - // for(int i = 0; i < numConstInitVals; i++) - // initvalues.push_back(any_cast(ctx->constInitVal(i)->accept(this))); - // initvalue = ConstantValue::getArray(initvalues); - } - return initvalue; +std::any SysYIRGenerator::visitArrayInitValue(SysYParser::ArrayInitValueContext *ctx) { + std::vector children; + for (const auto &initVal : ctx->initVal()) + children.push_back(std::any_cast(initVal->accept(this))); + ArrayValueTree* result = new ArrayValueTree(); + result->addChildren(children); + return result; } -/* - * @brief: visit function type - * @details: - * funcType: VOID | INT | FLOAT; - */ -std::any SysYIRGenerator::visitFuncType(SysYParser::FuncTypeContext* ctx){ - if(ctx->INT()) +std::any SysYIRGenerator::visitConstScalarInitValue(SysYParser::ConstScalarInitValueContext *ctx) { + Value* value = std::any_cast(visitConstExp(ctx->constExp())); + ArrayValueTree* result = new ArrayValueTree(); + result->setValue(value); + return result; +} + +std::any SysYIRGenerator::visitConstArrayInitValue(SysYParser::ConstArrayInitValueContext *ctx) { + std::vector children; + for (const auto &constInitVal : ctx->constInitVal()) + children.push_back(std::any_cast(constInitVal->accept(this))); + ArrayValueTree* result = new ArrayValueTree(); + result->addChildren(children); + return result; +} + +std::any SysYIRGenerator::visitFuncType(SysYParser::FuncTypeContext *ctx) { + if (ctx->INT() != nullptr) return Type::getIntType(); - else if(ctx->FLOAT()) + if (ctx->FLOAT() != nullptr) return Type::getFloatType(); - else if(ctx->VOID()) - return Type::getVoidType(); - std::cerr << "invalid function type: " << ctx->getText() << std::endl; - return std::any(); + return Type::getVoidType(); } -/* - * @brief: visit function define - * @details: - * funcDef: funcType Ident LPAREN funcFParams? RPAREN blockStmt; - * funcFParams: funcFParam (COMMA funcFParam)*; - * funcFParam: bType Ident (LBRACK RBRACK (LBRACK exp RBRACK)*)?; - * entry -> next -> others -> exit - * entry: allocas, br - * next: retval, params, br - * other: blockStmt init block - * exit: load retval, ret - */ -std::any SysYIRGenerator::visitFuncDef(SysYParser::FuncDefContext* ctx){ +std::any SysYIRGenerator::visitFuncDef(SysYParser::FuncDefContext *ctx){ + // 更新作用域 + module->enterNewScope(); - auto funcName = ctx->Ident()->getText(); - auto returnType = any_cast(ctx->funcType()->accept(this)); - auto func = module->getFunction(funcName); - - cout << "func: "; - returnType->print(cout); - cout << ' '<< funcName.c_str() << endl; - - vector paramTypes; - vector paramNames; - - if(ctx->funcFParams()){ - for(auto funcParam:ctx->funcFParams()->funcFParam()){ - Type* paramType = any_cast(funcParam->bType()->accept(this)); - paramTypes.push_back(paramType); - paramNames.push_back(funcParam->Ident()->getText()); - } - } - - - auto funcType = FunctionType::get(returnType, paramTypes); - auto function = module->createFunction(funcName, funcType); - - SymbolTable::FunctionScope scope(symbols_table); - BasicBlock* entryblock = function->getEntryBlock(); - for(size_t i = 0; i < paramTypes.size(); i++) - entryblock->createArgument(paramTypes[i], paramNames[i]); - - for(auto& arg: entryblock->getArguments()) - symbols_table.insert(arg->getName(), (Value *)arg.get()); - - cout << "setposition entryblock" << endl; - builder.setPosition(entryblock, entryblock->end()); - - ctx->blockStmt()->accept(this); - - return std::any(); -} - -/* - * @brief: visit varDecl - * @details: - * varDecl: bType varDef (COMMA varDef)* SEMI; - */ -std::any SysYIRGenerator::visitVarDecl(SysYParser::VarDeclContext *ctx){ - cout << "visitVarDecl" << endl; - current_type = any_cast(ctx->bType()->accept(this)); - for(auto varDef:ctx->varDef()){ - varDef->accept(this); - } - return std::any(); -} - -/* - * @brief: visit varDef - * @details: - * varDef: Ident (LBRACK constExp RBRACK)* (ASSIGN initVal)?; - */ -std::any SysYIRGenerator::visitVarDef(SysYParser::VarDefContext *ctx){ - const std::string name = ctx->Ident()->getText(); - Type* type = current_type; - Value* init = ctx->initVal() ? any_cast(ctx->initVal()->accept(this)) : nullptr; - // const std::vector dims = {}; - - cout << "vardef: "; - current_type->print(cout); - cout << ' ' << name << endl; - - if(ctx->constExp().empty()){ - //scalar - if(symbols_table.isModuleScope()){ - - if(init) - assert(init->isConstant() && "global must be initialized by constant"); - Value* global = module->createGlobalValue(name, type, {}, init); - symbols_table.insert(name, global); - cout << "add module var " << name ; - // if(init){ - // cout << " inited by " ; - // init->print(cout); - // } - cout << '\n'; - } - else{ - - Value* alloca = builder.createAllocaInst(type, {}, name); - cout << "creatalloca" << endl; - alloca->print(cout); - Value* store = (StoreInst *)nullptr; - if(init != nullptr) - store = builder.createStoreInst(alloca, init, {}, name); - symbols_table.insert(name, alloca); - // alloca->setName(name); - cout << "add local var " ; - alloca->print(cout); - // if(init){ - // cout << " inited by " ; - // init->print(cout); - // } - cout << '\n'; - } - } - else{ - //array - std::cerr << "array varDef not implemented yet" << std::endl; - } - return std::any(); -} - -/* - * @brief: visit initVal - * @details: - * initVal: exp | LBRACE (initVal (COMMA initVal)*)? RBRACE; - */ -std::any SysYIRGenerator::visitInitVal(SysYParser::InitValContext *ctx) { - Value* initvalue = nullptr; - if(ctx->exp()) - initvalue = any_cast(ctx->exp()->accept(this)); - else{ - //还未实现数组初始化等功能待验证 - std::cerr << "array initvalue not implemented yet" << std::endl; - // auto numConstInitVals = ctx->constInitVal().size(); - // vector initvalues; - // for(int i = 0; i < numConstInitVals; i++) - // initvalues.push_back(any_cast(ctx->constInitVal(i)->accept(this))); - // initvalue = ConstantValue::getArray(initvalues); - } - return initvalue; -} - -// std::any SysYIRGenerator::visitFuncFParams(SysYParser::FuncFParamsContext* ctx){ -// return visitChildren(ctx); -// } -// std::any SysYIRGenerator::visitFuncFParam(SysYParser::FuncFParamContext *ctx) { -// return visitChildren(ctx); -// } - -/* - * @brief: visit blockStmt - * @details: - * blockStmt: LBRACE blockItem* RBRACE; - * blockItem: decl | stmt; - */ -std::any SysYIRGenerator::visitBlockStmt(SysYParser::BlockStmtContext* ctx){ - SymbolTable::BlockScope scope(symbols_table); - for (auto item : ctx->blockItem()){ - item->accept(this); - // if(builder.getBasicBlock()->isTerminal()){ - // break; - // } - } - return std::any(); -} - -/* - * @brief: visit ifstmt - * @details: - * ifStmt: IF LPAREN cond RPAREN stmt (ELSE stmt)?; - */ -std::any SysYIRGenerator::visitIfStmt(SysYParser::IfStmtContext *ctx) { - auto condition = any_cast(ctx->cond()->accept(this)); - - auto thenBlock = builder.getBasicBlock()->getParent()->addBasicBlock("then"); - auto elseBlock = builder.getBasicBlock()->getParent()->addBasicBlock("else"); - - auto condbr = builder.createCondBrInst(condition,thenBlock,elseBlock,{},{}); - - builder.setPosition(thenBlock, thenBlock->end()); - ctx->stmt(0)->accept(this); - - if(ctx->ELSE()){ - builder.setPosition(elseBlock, elseBlock->end()); - ctx->stmt(1)->accept(this); - } - //无条件跳转到下一个基本块 - // builder.createUncondBrInst(builder.getBasicBlock()->getParent()->addBasicBlock("next"),{}); - return std::any(); -} - -/* - * @brief: visit whilestmt - * @details: - * whileStmt: WHILE LPAREN cond RPAREN stmt; - */ - - std::any SysYIRGenerator::visitWhileStmt(SysYParser::WhileStmtContext* ctx) { - //需要解决一个函数多个循环的命名问题 - auto header = builder.getBasicBlock()->getParent()->addBasicBlock("header"); - auto body = builder.getBasicBlock()->getParent()->addBasicBlock("body"); - auto exit = builder.getBasicBlock()->getParent()->addBasicBlock("exit"); - - SymbolTable::BlockScope scope(symbols_table); - - { // visit header block - builder.setPosition(header, header->end()); - auto cond = any_cast(ctx->cond()->accept(this)); - auto condbr = builder.createCondBrInst(cond, body, exit, {}, {}); - } - - { // visit body block - builder.setPosition(body, body->end()); - ctx->stmt()->accept(this); - auto uncondbr = builder.createUncondBrInst(header, {}); - } - - // visit exit block - builder.setPosition(exit, exit->end()); - //无条件跳转到下一个基本块以及一些参数传递 - - return std::any(); -} - -/* - * @brief: visit breakstmt - * @details: - * breakStmt: BREAK SEMICOLON; - */ -std::any SysYIRGenerator::visitBreakStmt(SysYParser::BreakStmtContext* ctx) { - //如何获取break所在body对应的header块 - return std::any(); -} -/* - * @brief Visit ReturnStmt - * returnStmt: RETURN exp? SEMICOLON; - */ -std::any SysYIRGenerator::visitReturnStmt(SysYParser::ReturnStmtContext* ctx) { - cout << "visitReturnStmt" << endl; - // auto value = ctx->exp() ? any_cast_Value(visit(ctx->exp())) : nullptr; - Value* value = ctx->exp() ? any_cast(ctx->exp()->accept(this)) : nullptr; - - const auto func = builder.getBasicBlock()->getParent(); - - assert(func && "ret stmt block parent err!"); - - // 匹配 返回值类型 与 函数定义类型 - if (func->getReturnType()->isVoid()) { - if (ctx->exp()) - assert(false && "the returned value is not matching the function"); - - auto ret = builder.createReturnInst(); - return std::any(); - } - - assert(ctx->exp() && "the returned value is not matching the function"); - - auto ret = builder.createReturnInst(value); - - //需要增加无条件跳转吗 - return std::any(); -} - -/* - * @brief: visit continuestmt - * @details: - * continueStmt: CONTINUE SEMICOLON; - */ -std::any SysYIRGenerator::visitContinueStmt(SysYParser::ContinueStmtContext* ctx) { - //如何获取continue所在body对应的header块 - return std::any(); -} - - -/* - * @brief visit assign stmt - * @details: - * assignStmt: lValue ASSIGN exp SEMICOLON - */ -std::any SysYIRGenerator::visitAssignStmt(SysYParser::AssignStmtContext* ctx) { - cout << "visitassignstme :\n"; - auto lvalue = any_cast(ctx->lValue()->accept(this)); - cout << "getlval" << endl;//lvalue->print(cout);cout << ')'; - auto rvalue = any_cast(ctx->exp()->accept(this)); - //可能要考虑类型转换例如int a = 1.0 - cout << "getrval" << endl;//rvalue->print(cout);cout << ")\n"; - builder.createStoreInst(rvalue, lvalue, {}, {}); - return std::any(); -} - -/* - * @brief: visit lValue - * @details: - * lValue: Ident (LBRACK exp RBRACK)*; - */ -std::any SysYIRGenerator::visitLValue(SysYParser::LValueContext* ctx) { - cout << "visitLValue" << endl; auto name = ctx->Ident()->getText(); - Value* value = symbols_table.lookup(name); + std::vector paramTypes; + std::vector paramNames; + std::vector> paramDims; - assert(value && "lvalue not found"); + if (ctx->funcFParams() != nullptr) { + auto params = ctx->funcFParams()->funcFParam(); + for (const auto ¶m : params) { + paramTypes.push_back(std::any_cast(visitBType(param->bType()))); + paramNames.push_back(param->Ident()->getText()); + std::vector dims = {}; + if (!param->LBRACK().empty()) { + dims.push_back(ConstantValue::get(-1)); // 第一个维度不确定 + for (const auto &exp : param->exp()) { + dims.push_back(std::any_cast(visitExp(exp))); + } + } + paramDims.emplace_back(dims); + } + } - if(ctx->exp().size() == 0){ - //scalar - cout << "lvalue: " << name << endl; - return value; + Type* returnType = std::any_cast(visitFuncType(ctx->funcType())); + Type* funcType = Type::getFunctionType(returnType, paramTypes); + Function* function = module->createFunction(name, funcType); + BasicBlock* entry = function->getEntryBlock(); + builder.setPosition(entry, entry->end()); + + for (size_t i = 0; i < paramTypes.size(); ++i) { + AllocaInst* alloca = builder.createAllocaInst(Type::getPointerType(paramTypes[i]), + paramDims[i], paramNames[i]); + entry->insertArgument(alloca); + module->addVariable(paramNames[i], alloca); } - else{ - //array - std::cerr << "array lvalue not implemented yet" << std::endl; + + for (auto item : ctx->blockStmt()->blockItem()) { + visitBlockItem(item); } - std::cerr << "error lvalue" << ctx->getText() << std::endl; + + module->leaveScope(); + return std::any(); } -std::any SysYIRGenerator::visitPrimExp(SysYParser::PrimExpContext *ctx){ - cout << "visitPrimExp" << endl; - return visitChildren(ctx); +std::any SysYIRGenerator::visitBlockStmt(SysYParser::BlockStmtContext *ctx) { + module->enterNewScope(); + for (auto item : ctx->blockItem()) + visitBlockItem(item); + module->leaveScope(); + return 0; +} + +std::any SysYIRGenerator::visitAssignStmt(SysYParser::AssignStmtContext *ctx) { + auto lVal = ctx->lValue(); + std::string name = lVal->Ident()->getText(); + std::vector dims; + for (const auto &exp : lVal->exp()) { + dims.push_back(std::any_cast(visitExp(exp))); + } + + User* variable = module->getVariable(name); + Value* value = std::any_cast(visitExp(ctx->exp())); + Type* variableType = dynamic_cast(variable->getType())->getBaseType(); + + // 左值右值类型不同处理 + if (variableType != value->getType()) { + ConstantValue * constValue = dynamic_cast(value); + if (constValue != nullptr) { + if (variableType == Type::getFloatType()) { + value = ConstantValue::get(static_cast(constValue->getInt())); + } else { + value = ConstantValue::get(static_cast(constValue->getFloat())); + } + } else { + if (variableType == Type::getFloatType()) { + value = builder.createIToFInst(value); + } else { + value = builder.createFtoIInst(value); + } + } + } + builder.createStoreInst(value, variable, dims, variable->getName()); + + return std::any(); +} + + +std::any SysYIRGenerator::visitIfStmt(SysYParser::IfStmtContext *ctx) { + // labels string stream + + std::stringstream labelstring; + Function * function = builder.getBasicBlock()->getParent(); + + BasicBlock* thenBlock = new BasicBlock(function); + BasicBlock* exitBlock = new BasicBlock(function); + + if (ctx->stmt().size() > 1) { + BasicBlock* elseBlock = new BasicBlock(function); + + builder.pushTrueBlock(thenBlock); + builder.pushFalseBlock(elseBlock); + // 访问条件表达式 + visitCond(ctx->cond()); + builder.popTrueBlock(); + builder.popFalseBlock(); + + labelstring << "then.L" << builder.getLabelIndex(); + thenBlock->setName(labelstring.str()); + labelstring.str(""); + function->addBasicBlock(thenBlock); + builder.setPosition(thenBlock, thenBlock->end()); + + auto block = dynamic_cast(ctx->stmt(0)); + // 如果是块语句,直接访问 + // 否则访问语句 + if (block != nullptr) { + visitBlockStmt(block); + } else { + module->enterNewScope(); + ctx->stmt(0)->accept(this); + module->leaveScope(); + } + builder.createUncondBrInst(exitBlock, {}); + BasicBlock::conectBlocks(builder.getBasicBlock(), exitBlock); + + labelstring << "else.L" << builder.getLabelIndex(); + elseBlock->setName(labelstring.str()); + labelstring.str(""); + function->addBasicBlock(elseBlock); + builder.setPosition(elseBlock, elseBlock->end()); + + block = dynamic_cast(ctx->stmt(1)); + if (block != nullptr) { + visitBlockStmt(block); + } else { + module->enterNewScope(); + ctx->stmt(1)->accept(this); + module->leaveScope(); + } + BasicBlock::conectBlocks(builder.getBasicBlock(), exitBlock); + + labelstring << "exit.L" << builder.getLabelIndex(); + exitBlock->setName(labelstring.str()); + labelstring.str(""); + function->addBasicBlock(exitBlock); + builder.setPosition(exitBlock, exitBlock->end()); + + } else { + builder.pushTrueBlock(thenBlock); + builder.pushFalseBlock(exitBlock); + visitCond(ctx->cond()); + builder.popTrueBlock(); + builder.popFalseBlock(); + + labelstring << "then.L" << builder.getLabelIndex(); + thenBlock->setName(labelstring.str()); + labelstring.str(""); + function->addBasicBlock(thenBlock); + builder.setPosition(thenBlock, thenBlock->end()); + + auto block = dynamic_cast(ctx->stmt(0)); + if (block != nullptr) { + visitBlockStmt(block); + } else { + module->enterNewScope(); + ctx->stmt(0)->accept(this); + module->leaveScope(); + } + BasicBlock::conectBlocks(builder.getBasicBlock(), exitBlock); + + labelstring << "exit.L" << builder.getLabelIndex(); + exitBlock->setName(labelstring.str()); + labelstring.str(""); + function->addBasicBlock(exitBlock); + builder.setPosition(exitBlock, exitBlock->end()); + } + return std::any(); +} + +std::any SysYIRGenerator::visitWhileStmt(SysYParser::WhileStmtContext *ctx) { + // while structure: + // curblock -> headBlock -> bodyBlock -> exitBlock + BasicBlock* curBlock = builder.getBasicBlock(); + Function* function = builder.getBasicBlock()->getParent(); + + std::stringstream labelstring; + labelstring << "head.L" << builder.getLabelIndex(); + BasicBlock *headBlock = function->addBasicBlock(labelstring.str()); + labelstring.str(""); + BasicBlock::conectBlocks(curBlock, headBlock); + builder.setPosition(headBlock, headBlock->end()); + + BasicBlock* bodyBlock = new BasicBlock(function); + BasicBlock* exitBlock = new BasicBlock(function); + + builder.pushTrueBlock(bodyBlock); + builder.pushFalseBlock(exitBlock); + // 访问条件表达式 + visitCond(ctx->cond()); + builder.popTrueBlock(); + builder.popFalseBlock(); + + labelstring << "body.L" << builder.getLabelIndex(); + bodyBlock->setName(labelstring.str()); + labelstring.str(""); + function->addBasicBlock(bodyBlock); + builder.setPosition(bodyBlock, bodyBlock->end()); + + builder.pushBreakBlock(exitBlock); + builder.pushContinueBlock(headBlock); + + auto block = dynamic_cast(ctx->stmt()); + + if( block != nullptr) { + visitBlockStmt(block); + } else { + module->enterNewScope(); + ctx->stmt()->accept(this); + module->leaveScope(); + } + + builder.createUncondBrInst(headBlock, {}); + BasicBlock::conectBlocks(builder.getBasicBlock(), exitBlock); + builder.popBreakBlock(); + builder.popContinueBlock(); + + labelstring << "exit.L" << builder.getLabelIndex(); + exitBlock->setName(labelstring.str()); + labelstring.str(""); + function->addBasicBlock(exitBlock); + builder.setPosition(exitBlock, exitBlock->end()); + + return std::any(); +} + +std::any SysYIRGenerator::visitBreakStmt(SysYParser::BreakStmtContext *ctx) { + BasicBlock* breakBlock = builder.getBreakBlock(); + builder.pushBreakBlock(breakBlock); + BasicBlock::conectBlocks(builder.getBasicBlock(), breakBlock); + return std::any(); +} + +std::any SysYIRGenerator::visitContinueStmt(SysYParser::ContinueStmtContext *ctx) { + BasicBlock* continueBlock = builder.getContinueBlock(); + builder.createUncondBrInst(continueBlock, {}); + return std::any(); +} + +std::any SysYIRGenerator::visitReturnStmt(SysYParser::ReturnStmtContext *ctx) { + Value* returnValue = nullptr; + if (ctx->exp() != nullptr) { + returnValue = std::any_cast(visitExp(ctx->exp())); + } + + Type* funcType = builder.getBasicBlock()->getParent()->getType(); + if (funcType!= returnValue->getType() && returnValue != nullptr) { + ConstantValue * constValue = dynamic_cast(returnValue); + if (constValue != nullptr) { + if (funcType == Type::getFloatType()) { + returnValue = ConstantValue::get(static_cast(constValue->getInt())); + } else { + returnValue = ConstantValue::get(static_cast(constValue->getFloat())); + } + } else { + if (funcType == Type::getFloatType()) { + returnValue = builder.createIToFInst(returnValue); + } else { + returnValue = builder.createFtoIInst(returnValue); + } + } + } + builder.createReturnInst(returnValue); + return std::any(); +} + + +std::any SysYIRGenerator::visitLValue(SysYParser::LValueContext *ctx) { + std::string name = ctx->Ident()->getText(); + User* variable = module->getVariable(name); + + Value* value = nullptr; + std::vector dims; + for (const auto &exp : ctx->exp()) { + dims.push_back(std::any_cast(visitExp(exp))); + } + + if (variable == nullptr) { + throw std::runtime_error("Variable " + name + " not found."); + } + + bool indicesConstant = true; + for (const auto &dim : dims) { + if (dynamic_cast(dim) == nullptr) { + indicesConstant = false; + break; + } + } + + ConstantVariable* constVar = dynamic_cast(variable); + GlobalValue* globalVar = dynamic_cast(variable); + AllocaInst* localVar = dynamic_cast(variable); + if (constVar != nullptr && indicesConstant) { + // 如果是常量变量,且索引是常量,则直接获取子数组 + value = constVar->getByIndices(dims); + } else if (module->isInGlobalArea() && (globalVar != nullptr)) { + assert(indicesConstant); + value = globalVar->getByIndices(dims); + } else { + if ((globalVar != nullptr && globalVar->getNumDims() > dims.size()) || + (localVar != nullptr && localVar->getNumDims() > dims.size()) || + (constVar != nullptr && constVar->getNumDims() > dims.size())) { + // value = builder.createLaInst(variable, indices); + // 如果变量是全局变量或局部变量,且索引数量小于维度数量,则创建createGetSubArray获取子数组 + auto getArrayInst = + builder.createGetSubArray(dynamic_cast(variable), dims); + value = getArrayInst->getChildArray(); + } else { + value = builder.createLoadInst(variable, dims); + } + } + + return value; +} + +std::any SysYIRGenerator::visitPrimaryExp(SysYParser::PrimaryExpContext *ctx) { + if (ctx->exp() != nullptr) + return visitExp(ctx->exp()); + if (ctx->lValue() != nullptr) + return visitLValue(ctx->lValue()); + if (ctx->number() != nullptr) + return visitNumber(ctx->number()); + if (ctx->string() != nullptr) { + cout << "String literal not supported in SysYIRGenerator." << endl; + } + return visitNumber(ctx->number()); } -// std::any SysYIRGenerator::visitExp(SysYParser::ExpContext* ctx) { -// cout << "visitExp" << endl; -// return ctx->addExp()->accept(this); -// } std::any SysYIRGenerator::visitNumber(SysYParser::NumberContext *ctx) { - cout << "visitNumber" << endl; - Value* res = nullptr; - if (auto iLiteral = ctx->ILITERAL()) { - /* 基数 (8, 10, 16) */ - const auto text = iLiteral->getText(); - int base = 10; - if (text.find("0x") == 0 || text.find("0X") == 0) { - base = 16; - } else if (text.find("0b") == 0 || text.find("0B") == 0) { - base = 2; - } else if (text.find("0") == 0) { - base = 8; - } - res = ConstantValue::get((int)std::stol(text, 0, base)); - } else if (auto fLiteral = ctx->FLITERAL()) { - const auto text = fLiteral->getText(); - res = ConstantValue::get((float)std::stof(text)); + if (ctx->ILITERAL() != nullptr) { + int value = std::stol(ctx->ILITERAL()->getText(), nullptr, 0); + return static_cast(ConstantValue::get(value)); + } else if (ctx->FLITERAL() != nullptr) { + float value = std::stof(ctx->FLITERAL()->getText()); + return static_cast(ConstantValue::get(value)); } - cout << "number: "; - res->print(cout); - cout << endl; - - return res; + throw std::runtime_error("Unknown number type."); + return std::any(); // 不会到达这里 } -/* - * @brief: visit call - * @details: - * call: Ident LPAREN funcRParams? RPAREN; - */ -std::any SysYIRGenerator::visitCall(SysYParser::CallContext* ctx) { - cout << "visitCall" << endl; - auto funcName = ctx->Ident()->getText(); - auto func = module->getFunction(funcName); - assert(func && "function not found"); - - //需要做类型检查和转换 - std::vector args; - if(ctx->funcRParams()){ - for(auto exp:ctx->funcRParams()->exp()){ - args.push_back(any_cast(exp->accept(this))); +std::any SysYIRGenerator::visitCall(SysYParser::CallContext *ctx) { + std::string funcName = ctx->Ident()->getText(); + Function *function = module->getFunction(funcName); + if (function == nullptr) { + function = module->getExternalFunction(funcName); + if (function == nullptr) { + std::cout << "The function " << funcName << " no defined." << std::endl; + assert(function); } } - Value* call = builder.createCallInst(func, args); - return call; + + std::vector args = {}; + if (funcName == "starttime" || funcName == "stoptime") { + // 如果是starttime或stoptime函数 + // TODO: 这里需要处理starttime和stoptime函数的参数 + // args.emplace_back() + } else { + if (ctx->funcRParams() != nullptr) { + args = std::any_cast>(visitFuncRParams(ctx->funcRParams())); + } + + auto params = function->getEntryBlock()->getArguments(); + for (size_t i = 0; i < args.size(); i++) { + // 参数类型转换 + if (params[i]->getType() != args[i]->getType() && + (params[i]->getNumDims() != 0 || + params[i]->getType()->as()->getBaseType() != args[i]->getType())) { + ConstantValue * constValue = dynamic_cast(args[i]); + if (constValue != nullptr) { + if (params[i]->getType() == Type::getPointerType(Type::getFloatType())) { + args[i] = ConstantValue::get(static_cast(constValue->getInt())); + } else { + args[i] = ConstantValue::get(static_cast(constValue->getFloat())); + } + } else { + if (params[i]->getType() == Type::getPointerType(Type::getFloatType())) { + args[i] = builder.createIToFInst(args[i]); + } else { + args[i] = builder.createFtoIInst(args[i]); + } + } + } + } + } + + return static_cast(builder.createCallInst(function, args)); } -/* - * @brief: visit unexp - * @details: - * unExp: unaryOp unaryExp - */ -std::any SysYIRGenerator::visitUnExp(SysYParser::UnExpContext *ctx) { - cout << "visitUnExp" << endl; - Value* res = nullptr; - auto op = ctx->unaryOp()->getText(); - auto exp = any_cast(ctx->unaryExp()->accept(this)); - if(ctx->unaryOp()->ADD()){ - res = exp; +std::any SysYIRGenerator::visitUnaryExp(SysYParser::UnaryExpContext *ctx) { + if (ctx->primaryExp() != nullptr) + return visitPrimaryExp(ctx->primaryExp()); + if (ctx->call() != nullptr) + return visitCall(ctx->call()); + + Value* value = std::any_cast(visitUnaryExp(ctx->unaryExp())); + Value* result = value; + if (ctx->unaryOp()->SUB() != nullptr) { + ConstantValue * constValue = dynamic_cast(value); + if (constValue != nullptr) { + if (constValue->isFloat()) { + result = ConstantValue::get(-constValue->getFloat()); + } else { + result = ConstantValue::get(-constValue->getInt()); + } + } else if (value != nullptr) { + if (value->getType() == Type::getIntType()) { + result = builder.createNegInst(value); + } else { + result = builder.createFNegInst(value); + } + } else { + std::cout << "UnExp: value is nullptr." << std::endl; + assert(false); + } + } else if (ctx->unaryOp()->NOT() != nullptr) { + auto constValue = dynamic_cast(value); + if (constValue != nullptr) { + if (constValue->isFloat()) { + result = + ConstantValue::get(1 - (constValue->getFloat() != 0.0F ? 1 : 0)); + } else { + result = ConstantValue::get(1 - (constValue->getInt() != 0 ? 1 : 0)); + } + } else if (value != nullptr) { + if (value->getType() == Type::getIntType()) { + result = builder.createNotInst(value); + } else { + result = builder.createFNotInst(value); + } + } else { + std::cout << "UnExp: value is nullptr." << std::endl; + assert(false); + } } - else if(ctx->unaryOp()->SUB()){ - res = builder.createNegInst(exp, exp->getName()); - } - else if(ctx->unaryOp()->NOT()){ - //not将非零值转换为0,零值转换为1 - res = builder.createNotInst(exp, exp->getName()); - } - return res; + return result; } -/* - * @brief: visit mulexp - * @details: - * mulExp: unaryExp ((MUL | DIV | MOD) unaryExp)* - */ +std::any SysYIRGenerator::visitFuncRParams(SysYParser::FuncRParamsContext *ctx) { + std::vector params; + for (const auto &exp : ctx->exp()) + params.push_back(std::any_cast(visitExp(exp))); + return params; +} + + std::any SysYIRGenerator::visitMulExp(SysYParser::MulExpContext *ctx) { - cout << "visitMulExp" << endl; - Value* res = nullptr; - cout << "mulExplhsin\n"; - Value* lhs = any_cast(ctx->unaryExp(0)->accept(this)); - cout << "mulExplhsout\n"; - if(ctx->unaryExp().size() == 1){ - cout << "unaryExp().size() = 1\n"; - res = lhs; - } - else{ - cout << "unaryExp().size() > 1\n"; - for(size_t i = 1; i < ctx->unaryExp().size(); i++){ - Value* rhs = any_cast(ctx->unaryExp(i)->accept(this)); - auto opNode = dynamic_cast(ctx->children[2 * i - 1]); - - if(opNode->getText() == "*"){ - res = builder.createMulInst(lhs, rhs, lhs->getName() + "*" + rhs->getName()); + Value * result = std::any_cast(visitUnaryExp(ctx->unaryExp(0))); + + for (size_t i = 1; i < ctx->unaryExp().size(); i++) { + auto opNode = dynamic_cast(ctx->children[2*i-1]); + int opType = opNode->getSymbol()->getType(); + + Value* operand = std::any_cast(visitUnaryExp(ctx->unaryExp(i))); + + Type* resultType = result->getType(); + Type* operandType = operand->getType(); + Type* floatType = Type::getFloatType(); + + if (resultType == floatType || operandType == floatType) { + // 如果有一个操作数是浮点数,则将两个操作数都转换为浮点数 + if (operandType != floatType) { + ConstantValue * constValue = dynamic_cast(operand); + if (constValue != nullptr) + operand = ConstantValue::get(static_cast(constValue->getInt())); + else + operand = builder.createIToFInst(operand); + } else if (resultType != floatType) { + ConstantValue* constResult = dynamic_cast(result); + if (constResult != nullptr) + result = ConstantValue::get(static_cast(constResult->getInt())); + else + result = builder.createIToFInst(result); } - else if(opNode->getText() == "/"){ - res = builder.createDivInst(lhs, rhs, lhs->getName() + "/" + rhs->getName()); + + ConstantValue* constResult = dynamic_cast(result); + ConstantValue* constOperand = dynamic_cast(operand); + if (opType == SysYParser::MUL) { + if ((constOperand != nullptr) && (constResult != nullptr)) { + result = ConstantValue::get(constResult->getFloat() * + constOperand->getFloat()); + } else { + result = builder.createFMulInst(result, operand); + } + } else if (opType == SysYParser::DIV) { + if ((constOperand != nullptr) && (constResult != nullptr)) { + result = ConstantValue::get(constResult->getFloat() / + constOperand->getFloat()); + } else { + result = builder.createFDivInst(result, operand); + } + } else { + // float类型的取模操作不允许 + std::cout << "MulExp: float type mod operation is not allowed." << std::endl; + assert(false); + } + } else { + ConstantValue * constResult = dynamic_cast(result); + ConstantValue * constOperand = dynamic_cast(operand); + if (opType == SysYParser::MUL) { + if ((constOperand != nullptr) && (constResult != nullptr)) + result = ConstantValue::get(constResult->getInt() * constOperand->getInt()); + else + result = builder.createMulInst(result, operand); + } else if (opType == SysYParser::DIV) { + if ((constOperand != nullptr) && (constResult != nullptr)) + result = ConstantValue::get(constResult->getInt() / constOperand->getInt()); + else + result = builder.createDivInst(result, operand); + } else { + if ((constOperand != nullptr) && (constResult != nullptr)) + result = ConstantValue::get(constResult->getInt() % constOperand->getInt()); + else + result = builder.createRemInst(result, operand); } - else if(opNode->getText() == "%"){ - std::cerr << "mod not implemented yet" << std::endl; - // res = builder.createModInst(lhs, rhs, lhs->getName() + "%" + rhs->getName()); - } } } - return res; + + return result; } -/* - * @brief: visit addexp - * @details: - * addExp: mulExp ((ADD | SUB) mulExp)* - */ + std::any SysYIRGenerator::visitAddExp(SysYParser::AddExpContext *ctx) { - cout << "visitAddExp" << endl; - Value* res = nullptr; - Value* lhs = any_cast(ctx->mulExp(0)->accept(this)); - if(ctx->mulExp().size() == 1){ - cout << "ctx->mulExp().size() = 1\n"; - res = lhs; - } - else{ - for(size_t i = 1; i < ctx->mulExp().size(); i++){ - cout << "i = " << i << "\n"; - Value* rhs = any_cast(ctx->mulExp(i)->accept(this)); - auto opNode = dynamic_cast(ctx->children[2 * i - 1]); - - if(opNode->getText() == "+"){ - res = builder.createAddInst(lhs, rhs, lhs->getName() + "+" + rhs->getName()); + Value* result = std::any_cast(visitMulExp(ctx->mulExp(0))); + + for (size_t i = 1; i < ctx->mulExp().size(); i++) { + auto opNode = dynamic_cast(ctx->children[2*i-1]); + int opType = opNode->getSymbol()->getType(); + + Value* operand = std::any_cast(visitMulExp(ctx->mulExp(i))); + Type* resultType = result->getType(); + Type* operandType = operand->getType(); + Type* floatType = Type::getFloatType(); + + if (resultType == floatType || operandType == floatType) { + // 类型转换 + if (operandType != floatType) { + ConstantValue * constOperand = dynamic_cast(operand); + if (constOperand != nullptr) + operand = ConstantValue::get(static_cast(constOperand->getInt())); + else + operand = builder.createIToFInst(operand); + } else if (resultType != floatType) { + ConstantValue * constResult = dynamic_cast(result); + if (constResult != nullptr) + result = ConstantValue::get(static_cast(constResult->getInt())); + else + result = builder.createIToFInst(result); } - else if(opNode->getText() == "-"){ - res = builder.createSubInst(lhs, rhs, lhs->getName() + "-" + rhs->getName()); + + ConstantValue * constResult = dynamic_cast(result); + ConstantValue * constOperand = dynamic_cast(operand); + if (opType == SysYParser::ADD) { + if ((constResult != nullptr) && (constOperand != nullptr)) + result = ConstantValue::get(constResult->getFloat() + constOperand->getFloat()); + else + result = builder.createFAddInst(result, operand); + } else { + if ((constResult != nullptr) && (constOperand != nullptr)) + result = ConstantValue::get(constResult->getFloat() - constOperand->getFloat()); + else + result = builder.createFSubInst(result, operand); + } + } else { + ConstantValue * constResult = dynamic_cast(result); + ConstantValue * constOperand = dynamic_cast(operand); + if (opType == SysYParser::ADD) { + if ((constResult != nullptr) && (constOperand != nullptr)) + result = ConstantValue::get(constResult->getInt() + constOperand->getInt()); + else + result = builder.createAddInst(result, operand); + } else { + if ((constResult != nullptr) && (constOperand != nullptr)) + result = ConstantValue::get(constResult->getInt() - constOperand->getInt()); + else + result = builder.createSubInst(result, operand); } } - lhs = res; } - return res; + return result; } -/* - * @brief: visit relexp - * @details: - * relExp: addExp ((LT | GT | LE | GE) addExp)* - */ std::any SysYIRGenerator::visitRelExp(SysYParser::RelExpContext *ctx) { - cout << "visitRelExp" << endl; - Value* res = nullptr; - Value* lhs = any_cast(ctx->addExp(0)->accept(this)); - if(ctx->addExp().size() == 1){ - res = lhs; - } - else{ - for(size_t i = 1; i < ctx->addExp().size(); i++){ - Value* rhs = any_cast(ctx->addExp(i)->accept(this)); - auto opNode = dynamic_cast(ctx->children[2 * i - 1]); - if(lhs->getType() != rhs->getType()){ - std::cerr << "type mismatch:type check not implemented" << std::endl; - } - Type* type = lhs->getType(); - if(opNode->getText() == "<"){ - if(type->isInt()) - res = builder.createICmpLTInst(lhs, rhs, lhs->getName() + "<" + rhs->getName()); - else if(type->isFloat()) - res = builder.createFCmpLTInst(lhs, rhs, lhs->getName() + "<" + rhs->getName()); - } - else if(opNode->getText() == ">"){ - if(type->isInt()) - res = builder.createICmpGTInst(lhs, rhs, lhs->getName() + ">" + rhs->getName()); - else if(type->isFloat()) - res = builder.createFCmpGTInst(lhs, rhs, lhs->getName() + ">" + rhs->getName()); - } - else if(opNode->getText() == "<="){ - if(type->isInt()) - res = builder.createICmpLEInst(lhs, rhs, lhs->getName() + "<=" + rhs->getName()); - else if(type->isFloat()) - res = builder.createFCmpLEInst(lhs, rhs, lhs->getName() + "<=" + rhs->getName()); - } - else if(opNode->getText() == ">="){ - if(type->isInt()) - res = builder.createICmpGEInst(lhs, rhs, lhs->getName() + ">=" + rhs->getName()); - else if(type->isFloat()) - res = builder.createFCmpGEInst(lhs, rhs, lhs->getName() + ">=" + rhs->getName()); + Value* result = std::any_cast(visitAddExp(ctx->addExp(0))); + + for (size_t i = 1; i < ctx->addExp().size(); i++) { + auto opNode = dynamic_cast(ctx->children[2*i-1]); + int opType = opNode->getSymbol()->getType(); + + Value* operand = std::any_cast(visitAddExp(ctx->addExp(i))); + + Type* resultType = result->getType(); + Type* operandType = operand->getType(); + + ConstantValue* constResult = dynamic_cast(result); + ConstantValue* constOperand = dynamic_cast(operand); + + // 常量比较 + if ((constResult != nullptr) && (constOperand != nullptr)) { + auto operand1 = constResult->isFloat() ? constResult->getFloat() + : constResult->getInt(); + auto operand2 = constOperand->isFloat() ? constOperand->getFloat() + : constOperand->getInt(); + + if (opType == SysYParser::LT) result = ConstantValue::get(operand1 < operand2 ? 1 : 0); + else if (opType == SysYParser::GT) result = ConstantValue::get(operand1 > operand2 ? 1 : 0); + else if (opType == SysYParser::LE) result = ConstantValue::get(operand1 <= operand2 ? 1 : 0); + else if (opType == SysYParser::GE) result = ConstantValue::get(operand1 >= operand2 ? 1 : 0); + else assert(false); + + } else { + Type* resultType = result->getType(); + Type* operandType = operand->getType(); + Type* floatType = Type::getFloatType(); + + // 浮点数处理 + if (resultType == floatType || operandType == floatType) { + if (resultType != floatType) { + if (constResult != nullptr) + result = ConstantValue::get(static_cast(constResult->getInt())); + else + result = builder.createIToFInst(result); + + } + if (operandType != floatType) { + if (constOperand != nullptr) + operand = ConstantValue::get(static_cast(constOperand->getInt())); + else + operand = builder.createIToFInst(operand); + + } + + if (opType == SysYParser::LT) result = builder.createFCmpLTInst(result, operand); + else if (opType == SysYParser::GT) result = builder.createFCmpGTInst(result, operand); + else if (opType == SysYParser::LE) result = builder.createFCmpLEInst(result, operand); + else if (opType == SysYParser::GE) result = builder.createFCmpGEInst(result, operand); + else assert(false); + + } else { + // 整数处理 + if (opType == SysYParser::LT) result = builder.createICmpLTInst(result, operand); + else if (opType == SysYParser::GT) result = builder.createICmpGTInst(result, operand); + else if (opType == SysYParser::LE) result = builder.createICmpLEInst(result, operand); + else if (opType == SysYParser::GE) result = builder.createICmpGEInst(result, operand); + else assert(false); + } } } - return res; + + return result; } -/* - * @brief: visit eqexp - * @details: - * eqExp: relExp ((EQ | NEQ) relExp)* - */ -std::any SysYIRGenerator::visitEqExp(SysYParser::EqExpContext* ctx) { - cout << "visitEqExp" << endl; - Value* res = nullptr; - Value* lhs = any_cast(ctx->relExp(0)->accept(this)); - if(ctx->relExp().size() == 1){ - res = lhs; - } - else{ - for(size_t i = 1; i < ctx->relExp().size(); i++){ - Value* rhs = any_cast(ctx->relExp(i)->accept(this)); - auto opNode = dynamic_cast(ctx->children[2 * i - 1]); - if(lhs->getType() != rhs->getType()){ - std::cerr << "type mismatch:type check not implemented" << std::endl; - } - Type* type = lhs->getType(); - if(opNode->getText() == "=="){ - if(type->isInt()) - res = builder.createICmpEQInst(lhs, rhs, lhs->getName() + "==" + rhs->getName()); - else if(type->isFloat()) - res = builder.createFCmpEQInst(lhs, rhs, lhs->getName() + "==" + rhs->getName()); - } - else if(opNode->getText() == "!="){ - if(type->isInt()) - res = builder.createICmpNEInst(lhs, rhs, lhs->getName() + "!=" + rhs->getName()); - else if(type->isFloat()) - res = builder.createFCmpNEInst(lhs, rhs, lhs->getName() + "!=" + rhs->getName()); + +std::any SysYIRGenerator::visitEqExp(SysYParser::EqExpContext *ctx) { + Value * result = std::any_cast(visitRelExp(ctx->relExp(0))); + + for (size_t i = 1; i < ctx->relExp().size(); i++) { + auto opNode = dynamic_cast(ctx->children[2*i-1]); + int opType = opNode->getSymbol()->getType(); + + Value * operand = std::any_cast(visitRelExp(ctx->relExp(i))); + + ConstantValue* constResult = dynamic_cast(result); + ConstantValue* constOperand = dynamic_cast(operand); + + if ((constResult != nullptr) && (constOperand != nullptr)) { + auto operand1 = constResult->isFloat() ? constResult->getFloat() + : constResult->getInt(); + auto operand2 = constOperand->isFloat() ? constOperand->getFloat() + : constOperand->getInt(); + + if (opType == SysYParser::EQ) result = ConstantValue::get(operand1 == operand2 ? 1 : 0); + else if (opType == SysYParser::NE) result = ConstantValue::get(operand1 != operand2 ? 1 : 0); + else assert(false); + + } else { + Type* resultType = result->getType(); + Type* operandType = operand->getType(); + Type* floatType = Type::getFloatType(); + + if (resultType == floatType || operandType == floatType) { + if (resultType != floatType) { + if (constResult != nullptr) + result = ConstantValue::get(static_cast(constResult->getInt())); + else + result = builder.createIToFInst(result); + } + if (operandType != floatType) { + if (constOperand != nullptr) + operand = ConstantValue::get(static_cast(constOperand->getInt())); + else + operand = builder.createIToFInst(operand); + } + + if (opType == SysYParser::EQ) result = builder.createFCmpEQInst(result, operand); + else if (opType == SysYParser::NE) result = builder.createFCmpNEInst(result, operand); + else assert(false); + + } else { + + if (opType == SysYParser::EQ) result = builder.createICmpEQInst(result, operand); + else if (opType == SysYParser::NE) result = builder.createICmpNEInst(result, operand); + else assert(false); + } } } - return res; + + if (ctx->relExp().size() == 1) { + ConstantValue * constResult = dynamic_cast(result); + // 如果只有一个关系表达式,则将结果转换为0或1 + if (constResult != nullptr) { + if (constResult->isFloat()) + result = ConstantValue::get(constResult->getFloat() != 0.0F ? 1 : 0); + else + result = ConstantValue::get(constResult->getInt() != 0 ? 1 : 0); + } + } + + return result; } -/* - * @brief: visit lAndexp - * @details: - * lAndExp: eqExp (AND eqExp)* - */ -std::any SysYIRGenerator::visitLAndExp(SysYParser::LAndExpContext* ctx) { - cout << "visitLAndExp" << endl; - auto currentBlock = builder.getBasicBlock(); - // auto trueBlock = currentBlock->getParent()->addBasicBlock("trueland" + std::to_string(++trueBlockNum)); - auto falseBlock = currentBlock->getParent()->addBasicBlock("falseland" + std::to_string(++falseBlockNum)); - Value* value = any_cast(ctx->eqExp(0)->accept(this)); - auto trueBlock = currentBlock; - for(size_t i = 1; i < ctx->eqExp().size(); i++){ - trueBlock = trueBlock->getParent()->addBasicBlock("trueland" + std::to_string(++trueBlockNum)); - builder.createCondBrInst(value, currentBlock, falseBlock, {}, {}); - builder.setPosition(trueBlock, trueBlock->end()); - value = any_cast(ctx->eqExp(i)->accept(this)); +std::any SysYIRGenerator::visitLAndExp(SysYParser::LAndExpContext *ctx){ + std::stringstream labelstring; + BasicBlock *curBlock = builder.getBasicBlock(); + Function *function = builder.getBasicBlock()->getParent(); + BasicBlock *trueBlock = builder.getTrueBlock(); + BasicBlock *falseBlock = builder.getFalseBlock(); + auto conds = ctx->eqExp(); + + for (size_t i = 0; i < conds.size() - 1; i++) { + + labelstring << "AND.L" << builder.getLabelIndex(); + BasicBlock *newtrueBlock = function->addBasicBlock(labelstring.str()); + labelstring.str(""); + + auto cond = std::any_cast(visitEqExp(ctx->eqExp(i))); + builder.createCondBrInst(cond, newtrueBlock, falseBlock, {}, {}); + + BasicBlock::conectBlocks(curBlock, newtrueBlock); + BasicBlock::conectBlocks(curBlock, falseBlock); + + curBlock = newtrueBlock; + builder.setPosition(curBlock, curBlock->end()); } - //结构trueblk条件跳转到falseblk - // - trueBlock1->trueBlock2->trueBlock3->...->trueBlockn->nextblk - //entry-| - // -falseBlock->nextblk - //需要在最后一个trueblock的末尾加上无条件跳转到下一个基本块的指令 - // builder.createCondBrInst(value, trueBlock, falseBlock, {}, {}); + + auto cond = std::any_cast(visitEqExp(conds.back())); + builder.createCondBrInst(cond, trueBlock, falseBlock, {}, {}); + + BasicBlock::conectBlocks(curBlock, trueBlock); + BasicBlock::conectBlocks(curBlock, falseBlock); + return std::any(); } -/* - * @brief: visit lOrexp - * @details: - * lOrExp: lAndExp (OR lAndExp)* - */ -std::any SysYIRGenerator::visitLOrExp(SysYParser::LOrExpContext* ctx) { - cout << "visitLOrExp" << endl; - auto currentBlock = builder.getBasicBlock(); - auto trueBlock = currentBlock->getParent()->addBasicBlock("trueland" + std::to_string(++trueBlockNum)); - Value* value = any_cast(ctx->lAndExp(0)->accept(this)); - auto falseBlock = currentBlock; - for(size_t i = 1; i < ctx->lAndExp().size(); i++){ - falseBlock = currentBlock->getParent()->addBasicBlock("falseland" + std::to_string(++falseBlockNum)); - builder.createCondBrInst(value, trueBlock, falseBlock, {}, {}); - builder.setPosition(falseBlock, falseBlock->end()); - value = any_cast(ctx->lAndExp(i)->accept(this)); +auto SysYIRGenerator::visitLOrExp(SysYParser::LOrExpContext *ctx) -> std::any { + std::stringstream labelstring; + BasicBlock *curBlock = builder.getBasicBlock(); + Function *function = curBlock->getParent(); + auto conds = ctx->lAndExp(); + + for (size_t i = 0; i < conds.size() - 1; i++) { + labelstring << "OR.L" << builder.getLabelIndex(); + BasicBlock *newFalseBlock = function->addBasicBlock(labelstring.str()); + labelstring.str(""); + + builder.pushFalseBlock(newFalseBlock); + visitLAndExp(ctx->lAndExp(i)); + builder.popFalseBlock(); + + builder.setPosition(newFalseBlock, newFalseBlock->end()); } - //结构trueblk条件跳转到falseblk - // - falseBlock1->falseBlock2->falseBlock3->...->falseBlockn->nextblk - //entry-| - // -trueBlock->nextblk - //需要在最后一个falseblock的末尾加上无条件跳转到下一个基本块的指令 - // builder.createCondBrInst(value, trueBlock, falseBlock, {}, {}); + + visitLAndExp(conds.back()); + return std::any(); } -/* - * @brief: visit constexp - * @details: - * constExp: addExp; - */ -std::any SysYIRGenerator::visitConstExp(SysYParser::ConstExpContext* ctx) { - cout << "visitConstExp" << endl; - ConstantValue* res = nullptr; - Value* value = any_cast(ctx->addExp()->accept(this)); - if(isa(value)){ - res = dyncast(value); - } - else{ - std::cerr << "error constexp" << ctx->getText() << std::endl; - } - return res; -} - -/* begin -std::any SysYIRGenerator::visitConstGlobalDecl(SysYParser::ConstDeclContext *ctx, Type* type) { - std::vector values; - for (auto constDef : ctx->constDef()) { - - auto name = constDef->Ident()->getText(); - // get its dimensions - vector dims; - for (auto dim : constDef->constExp()) - dims.push_back(any_cast(dim->accept(this))); - - if (dims.size() == 0) { - auto init = constDef->ASSIGN() ? any_cast((constDef->constInitVal()->constExp()->accept(this))) - : nullptr; - if (init && isa(init)){ - Type *btype = type->as()->getBaseType(); - if (btype->isInt() && init->getType()->isFloat()) - init = ConstantValue::get((int)dynamic_cast(init)->getFloat()); - else if (btype->isFloat() && init->getType()->isInt()) - init = ConstantValue::get((float)dynamic_cast(init)->getInt()); - } - - auto global_value = module->createGlobalValue(name, type, dims, init); - - symbols_table.insert(name, global_value); - values.push_back(global_value); - } - else{ - auto init = constDef->ASSIGN() ? any_cast(dims[0]) - : nullptr; - auto global_value = module->createGlobalValue(name, type, dims, init); - if (constDef->ASSIGN()) { - d = 0; - n = 0; - path.clear(); - path = vector(dims.size(), 0); - isalloca = false; - current_type = global_value->getType()->as()->getBaseType(); - current_global = global_value; - numdims = global_value->getNumDims(); - for (auto init : constDef->constInitVal()->constInitVal()) - init->accept(this); - // visitConstInitValue(init); - } - symbols_table.insert(name, global_value); - values.push_back(global_value); - } - } - return values; -} - -std::any SysYIRGenerator::visitVarGlobalDecl(SysYParser::VarDeclContext *ctx, Type* type){ - std::vector values; - for (auto varDef : ctx->varDef()) { - - auto name = varDef->Ident()->getText(); - // get its dimensions - vector dims; - for (auto dim : varDef->constExp()) - dims.push_back(any_cast(dim->accept(this))); - - if (dims.size() == 0) { - auto init = varDef->ASSIGN() ? any_cast((varDef->initVal()->exp()->accept(this))) - : nullptr; - if (init && isa(init)){ - Type *btype = type->as()->getBaseType(); - if (btype->isInt() && init->getType()->isFloat()) - init = ConstantValue::get((int)dynamic_cast(init)->getFloat()); - else if (btype->isFloat() && init->getType()->isInt()) - init = ConstantValue::get((float)dynamic_cast(init)->getInt()); - } - - auto global_value = module->createGlobalValue(name, type, dims, init); - - symbols_table.insert(name, global_value); - values.push_back(global_value); - } - else{ - auto init = varDef->ASSIGN() ? any_cast(dims[0]) - : nullptr; - auto global_value = module->createGlobalValue(name, type, dims, init); - if (varDef->ASSIGN()) { - d = 0; - n = 0; - path.clear(); - path = vector(dims.size(), 0); - isalloca = false; - current_type = global_value->getType()->as()->getBaseType(); - current_global = global_value; - numdims = global_value->getNumDims(); - for (auto init : varDef->initVal()->initVal()) - init->accept(this); - // visitInitValue(init); - } - symbols_table.insert(name, global_value); - values.push_back(global_value); - } - } - return values; -} - -std::any SysYIRGenerator::visitConstLocalDecl(SysYParser::ConstDeclContext *ctx, Type* type){ - std::vector values; - // handle variables - for (auto constDef : ctx->constDef()) { - - auto name = constDef->Ident()->getText(); - vector dims; - for (auto dim : constDef->constExp()) - dims.push_back(any_cast(dim->accept(this))); - auto alloca = builder.createAllocaInst(type, dims, name); - symbols_table.insert(name, alloca); - - if (constDef->ASSIGN()) { - if (alloca->getNumDims() == 0) { - - auto value = any_cast(constDef->constInitVal()->constExp()->accept(this)); +void Utils::tree2Array(Type *type, ArrayValueTree *root, + const std::vector &dims, unsigned numDims, + ValueCounter &result, IRBuilder *builder) { + Value* value = root->getValue(); + auto &children = root->getChildren(); + if (value != nullptr) { + if (type == value->getType()) { + result.push_back(value); + } else { + if (type == Type::getFloatType()) { + ConstantValue* constValue = dynamic_cast(value); + if (constValue != nullptr) + result.push_back(ConstantValue::get(static_cast(constValue->getInt()))); + else + result.push_back(builder->createIToFInst(value)); - if (isa(value)) { - if (ctx->bType()->INT() && dynamic_cast(value)->isFloat()) - value = ConstantValue::get((int)dynamic_cast(value)->getFloat()); - else if (ctx->bType()->FLOAT() && dynamic_cast(value)->isInt()) - value = ConstantValue::get((float)dynamic_cast(value)->getInt()); - } - else if (alloca->getType()->as()->getBaseType()->isInt() && value->getType()->isFloat()) - value = builder.createFtoIInst(value); - else if (alloca->getType()->as()->getBaseType()->isFloat() && value->getType()->isInt()) - value = builder.createIToFInst(value); + } else { + ConstantValue* constValue = dynamic_cast(value); + if (constValue != nullptr) + result.push_back(ConstantValue::get(static_cast(constValue->getFloat()))); + else + result.push_back(builder->createFtoIInst(value)); - auto store = builder.createStoreInst(value, alloca); - } - else{ - d = 0; - n = 0; - path.clear(); - path = vector(alloca->getNumDims(), 0); - isalloca = true; - current_alloca = alloca; - current_type = alloca->getType()->as()->getBaseType(); - numdims = alloca->getNumDims(); - for (auto init : constDef->constInitVal()->constInitVal()) - init->accept(this); } } - - values.push_back(alloca); + return; } - return values; -} -std::any SysYIRGenerator::visitVarLocalDecl(SysYParser::VarDeclContext *ctx, Type* type){ - std::vector values; - for (auto varDef : ctx->varDef()) { - - auto name = varDef->Ident()->getText(); - vector dims; - for (auto dim : varDef->constExp()) - dims.push_back(any_cast(dim->accept(this))); - auto alloca = builder.createAllocaInst(type, dims, name); - symbols_table.insert(name, alloca); - - if (varDef->ASSIGN()) { - if (alloca->getNumDims() == 0) { - - auto value = any_cast(varDef->initVal()->exp()->accept(this)); - - if (isa(value)) { - if (ctx->bType()->INT() && dynamic_cast(value)->isFloat()) - value = ConstantValue::get((int)dynamic_cast(value)->getFloat()); - else if (ctx->bType()->FLOAT() && dynamic_cast(value)->isInt()) - value = ConstantValue::get((float)dynamic_cast(value)->getInt()); - } - else if (alloca->getType()->as()->getBaseType()->isInt() && value->getType()->isFloat()) - value = builder.createFtoIInst(value); - else if (alloca->getType()->as()->getBaseType()->isFloat() && value->getType()->isInt()) - value = builder.createIToFInst(value); - - auto store = builder.createStoreInst(value, alloca); - } - else{ - d = 0; - n = 0; - path.clear(); - path = vector(alloca->getNumDims(), 0); - isalloca = true; - current_alloca = alloca; - current_type = alloca->getType()->as()->getBaseType(); - numdims = alloca->getNumDims(); - for (auto init : varDef->initVal()->initVal()) - init->accept(this); + auto beforeSize = result.size(); + for (const auto &child : children) { + int begin = result.size(); + int newNumDims = 0; + for (unsigned i = 0; i < numDims - 1; i++) { + auto dim = dynamic_cast(*(dims.rbegin() + i))->getInt(); + if (begin % dim == 0) { + newNumDims += 1; + begin /= dim; + } else { + break; } } - - values.push_back(alloca); + tree2Array(type, child.get(), dims, newNumDims, result, builder); + } + auto afterSize = result.size(); + + int blockSize = 1; + for (unsigned i = 0; i < numDims; i++) { + blockSize *= dynamic_cast(*(dims.rbegin() + i))->getInt(); + } + + int num = blockSize - afterSize + beforeSize; + if (num > 0) { + if (type == Type::getFloatType()) + result.push_back(ConstantValue::get(0.0F), num); + else + result.push_back(ConstantValue::get(0), num); } - return values; } - end -*/ + +void Utils::createExternalFunction( + const std::vector ¶mTypes, + const std::vector ¶mNames, + const std::vector> ¶mDims, Type *returnType, + const std::string &funcName, Module *pModule, IRBuilder *pBuilder) { + auto funcType = Type::getFunctionType(returnType, paramTypes); + auto function = pModule->createExternalFunction(funcName, funcType); + auto entry = function->getEntryBlock(); + pBuilder->setPosition(entry, entry->end()); + + for (size_t i = 0; i < paramTypes.size(); ++i) { + auto alloca = pBuilder->createAllocaInst( + Type::getPointerType(paramTypes[i]), paramDims[i], paramNames[i]); + entry->insertArgument(alloca); + // pModule->addVariable(paramNames[i], alloca); + } +} + +void Utils::initExternalFunction(Module *pModule, IRBuilder *pBuilder) { + std::vector paramTypes; + std::vector paramNames; + std::vector> paramDims; + Type *returnType; + std::string funcName; + + returnType = Type::getIntType(); + funcName = "getint"; + Utils::createExternalFunction(paramTypes, paramNames, paramDims, returnType, + funcName, pModule, pBuilder); + funcName = "getch"; + Utils::createExternalFunction(paramTypes, paramNames, paramDims, returnType, + funcName, pModule, pBuilder); + paramTypes.push_back(Type::getIntType()); + paramNames.emplace_back("x"); + paramDims.push_back(std::vector{ConstantValue::get(-1)}); + funcName = "getarray"; + Utils::createExternalFunction(paramTypes, paramNames, paramDims, returnType, + funcName, pModule, pBuilder); + + returnType = Type::getFloatType(); + paramTypes.clear(); + paramNames.clear(); + paramDims.clear(); + funcName = "getfloat"; + Utils::createExternalFunction(paramTypes, paramNames, paramDims, returnType, + funcName, pModule, pBuilder); + + returnType = Type::getIntType(); + paramTypes.push_back(Type::getFloatType()); + paramNames.emplace_back("x"); + paramDims.push_back(std::vector{ConstantValue::get(-1)}); + funcName = "getfarray"; + Utils::createExternalFunction(paramTypes, paramNames, paramDims, returnType, + funcName, pModule, pBuilder); + + returnType = Type::getVoidType(); + + paramTypes.clear(); + paramTypes.push_back(Type::getIntType()); + paramDims.clear(); + paramDims.emplace_back(); + funcName = "putint"; + Utils::createExternalFunction(paramTypes, paramNames, paramDims, returnType, + funcName, pModule, pBuilder); + + funcName = "putch"; + Utils::createExternalFunction(paramTypes, paramNames, paramDims, returnType, + funcName, pModule, pBuilder); + + paramTypes.clear(); + paramTypes.push_back(Type::getIntType()); + paramTypes.push_back(Type::getIntType()); + paramDims.clear(); + paramDims.emplace_back(); + paramDims.push_back(std::vector{ConstantValue::get(-1)}); + paramNames.clear(); + paramNames.emplace_back("n"); + paramNames.emplace_back("a"); + funcName = "putarray"; + Utils::createExternalFunction(paramTypes, paramNames, paramDims, returnType, + funcName, pModule, pBuilder); + + paramTypes.clear(); + paramTypes.push_back(Type::getFloatType()); + paramDims.clear(); + paramDims.emplace_back(); + paramNames.clear(); + paramNames.emplace_back("a"); + funcName = "putfloat"; + Utils::createExternalFunction(paramTypes, paramNames, paramDims, returnType, + funcName, pModule, pBuilder); + + paramTypes.clear(); + paramTypes.push_back(Type::getIntType()); + paramTypes.push_back(Type::getFloatType()); + paramDims.clear(); + paramDims.emplace_back(); + paramDims.push_back(std::vector{ConstantValue::get(-1)}); + paramNames.clear(); + paramNames.emplace_back("n"); + paramNames.emplace_back("a"); + funcName = "putfarray"; + Utils::createExternalFunction(paramTypes, paramNames, paramDims, returnType, + funcName, pModule, pBuilder); + + paramTypes.clear(); + paramTypes.push_back(Type::getIntType()); + paramDims.clear(); + paramDims.emplace_back(); + paramNames.clear(); + paramNames.emplace_back("__LINE__"); + funcName = "starttime"; + Utils::createExternalFunction(paramTypes, paramNames, paramDims, returnType, + funcName, pModule, pBuilder); + + paramTypes.clear(); + paramTypes.push_back(Type::getIntType()); + paramDims.clear(); + paramDims.emplace_back(); + paramNames.clear(); + paramNames.emplace_back("__LINE__"); + funcName = "stoptime"; + Utils::createExternalFunction(paramTypes, paramNames, paramDims, returnType, + funcName, pModule, pBuilder); + +} } // namespace sysy \ No newline at end of file diff --git a/src/SysYIRGenerator.h b/src/SysYIRGenerator.h deleted file mode 100644 index 3c89ce0..0000000 --- a/src/SysYIRGenerator.h +++ /dev/null @@ -1,149 +0,0 @@ -#pragma once -#include "IR.h" -#include "IRBuilder.h" -#include "SysYBaseVisitor.h" -#include "SysYParser.h" -#include -#include -#include - -namespace sysy { - -class SymbolTable{ -private: - enum Kind - { - kModule, - kFunction, - kBlock, - }; - - std::forward_list>> Scopes; - -public: - struct ModuleScope { - SymbolTable& tables_ref; - ModuleScope(SymbolTable& tables) : tables_ref(tables) { - tables.enter(kModule); - } - ~ModuleScope() { tables_ref.exit(); } - }; - struct FunctionScope { - SymbolTable& tables_ref; - FunctionScope(SymbolTable& tables) : tables_ref(tables) { - tables.enter(kFunction); - } - ~FunctionScope() { tables_ref.exit(); } - }; - struct BlockScope { - SymbolTable& tables_ref; - BlockScope(SymbolTable& tables) : tables_ref(tables) { - tables.enter(kBlock); - } - ~BlockScope() { tables_ref.exit(); } - }; - - SymbolTable() = default; - - bool isModuleScope() const { return Scopes.front().first == kModule; } - bool isFunctionScope() const { return Scopes.front().first == kFunction; } - bool isBlockScope() const { return Scopes.front().first == kBlock; } - Value *lookup(const std::string &name) const { - for (auto &scope : Scopes) { - auto iter = scope.second.find(name); - if (iter != scope.second.end()) - return iter->second; - } - return nullptr; - } - auto insert(const std::string &name, Value *value) { - assert(not Scopes.empty()); - return Scopes.front().second.emplace(name, value); - } -private: - void enter(Kind kind) { - Scopes.emplace_front(); - Scopes.front().first = kind; - } - void exit() { - Scopes.pop_front(); - } - -}; - -class SysYIRGenerator : public SysYBaseVisitor { -private: - std::unique_ptr module; - IRBuilder builder; - SymbolTable symbols_table; - - int trueBlockNum = 0, falseBlockNum = 0; - - int d = 0, n = 0; - vector path; - bool isalloca; - AllocaInst *current_alloca; - GlobalValue *current_global; - Type* current_type; - int numdims = 0; - -public: - SysYIRGenerator() = default; - -public: - Module *get() const { return module.get(); } - -public: - std::any visitCompUnit(SysYParser::CompUnitContext *ctx) override; - std::any visitDecl(SysYParser::DeclContext *ctx) override; - std::any visitConstDecl(SysYParser::ConstDeclContext *ctx) override; - std::any visitBType(SysYParser::BTypeContext *ctx) override; - std::any visitConstDef(SysYParser::ConstDefContext *ctx) override; - std::any visitConstInitVal(SysYParser::ConstInitValContext *ctx) override; - std::any visitFuncType(SysYParser::FuncTypeContext* ctx) override; - std::any visitFuncDef(SysYParser::FuncDefContext* ctx) override; - std::any visitVarDecl(SysYParser::VarDeclContext *ctx) override; - std::any visitVarDef(SysYParser::VarDefContext *ctx) override; - std::any visitInitVal(SysYParser::InitValContext *ctx) override; - // std::any visitFuncFParam(SysYParser::FuncFParamContext *ctx) override; - // std::any visitFuncFParams(SysYParser::FuncFParamsContext *ctx) override; - std::any visitBlockStmt(SysYParser::BlockStmtContext* ctx) override; - // std::any visitStmt(SysYParser::StmtContext *ctx) override; - std::any visitAssignStmt(SysYParser::AssignStmtContext *ctx) override; - std::any visitIfStmt(SysYParser::IfStmtContext *ctx) override; - std::any visitWhileStmt(SysYParser::WhileStmtContext *ctx) override; - std::any visitBreakStmt(SysYParser::BreakStmtContext *ctx) override; - std::any visitContinueStmt(SysYParser::ContinueStmtContext *ctx) override; - std::any visitReturnStmt(SysYParser::ReturnStmtContext *ctx) override; - // std::any visitExp(SysYParser::ExpContext *ctx) override; - std::any visitLValue(SysYParser::LValueContext *ctx) override; - std::any visitPrimExp(SysYParser::PrimExpContext *ctx) override; - // std::any visitParenExp(SysYParser::ParenExpContext *ctx) override; - std::any visitNumber(SysYParser::NumberContext *ctx) override; - // std::any visitString(SysYParser::StringContext *ctx) override; - std::any visitCall(SysYParser::CallContext *ctx) override; - // std::any visitUnaryExp(SysYParser::UnaryExpContext *ctx) override; - // std::any visitUnaryOp(SysYParser::UnaryOpContext *ctx) override; - std::any visitUnExp(SysYParser::UnExpContext *ctx) override; - // std::any visitFuncRParams(SysYParser::FuncRParamsContext *ctx) override; - std::any visitMulExp(SysYParser::MulExpContext *ctx) override; - std::any visitAddExp(SysYParser::AddExpContext *ctx) override; - std::any visitRelExp(SysYParser::RelExpContext *ctx) override; - std::any visitEqExp(SysYParser::EqExpContext *ctx) override; - std::any visitLAndExp(SysYParser::LAndExpContext *ctx) override; - std::any visitLOrExp(SysYParser::LOrExpContext *ctx) override; - std::any visitConstExp(SysYParser::ConstExpContext *ctx) override; - -private: - std::any visitConstGlobalDecl(SysYParser::ConstDeclContext *ctx, Type* type); - std::any visitVarGlobalDecl(SysYParser::VarDeclContext *ctx, Type* type); - std::any visitConstLocalDecl(SysYParser::ConstDeclContext *ctx, Type* type); - std::any visitVarLocalDecl(SysYParser::VarDeclContext *ctx, Type* type); - Type *getArithmeticResultType(Type *lhs, Type *rhs) { - assert(lhs->isIntOrFloat() and rhs->isIntOrFloat()); - return lhs == rhs ? lhs : Type::getFloatType(); - } - -}; // class SysYIRGenerator - -} // namespace sysy \ No newline at end of file diff --git a/src/Backend.h b/src/include/Backend.h similarity index 100% rename from src/Backend.h rename to src/include/Backend.h diff --git a/src/include/IR.h b/src/include/IR.h new file mode 100644 index 0000000..3182a9a --- /dev/null +++ b/src/include/IR.h @@ -0,0 +1,1711 @@ +#pragma once + +#include "range.h" +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace sysy { +/** + * \defgroup type Types + * @brief Sysy的类型系统 + * + * 1. 基类`Type` 用来表示所有的原始标量类型, + * 包括 `int`, `float`, `void`, 和表示跳转目标的标签类型。 + * 2. `PointerType` 和 `FunctionType` 派生自`Type` 并且分别表示指针和函数类型。 + * + * \note `Type`和它的派生类的构造函数声明为'protected'. + * 用户必须使用Type::getXXXType()获得`Type` 指针。 + * @{ + */ + +/** + * + * `Type`用来表示所有的原始标量类型, + * 包括`int`, `float`, `void`, 和表示跳转目标的标签类型。 + */ + +class Type { + public: + /// 定义了原始标量类型种类的枚举类型 + enum Kind { + kInt, + kFloat, + kVoid, + kLabel, + kPointer, + kFunction, + }; + + Kind kind; ///< 表示具体类型的变量 + + protected: + explicit Type(Kind kind) : kind(kind) {} + virtual ~Type() = default; + + public: + static Type* getIntType(); ///< 返回表示Int类型的Type指针 + static Type* getFloatType(); ///< 返回表示Float类型的Type指针 + static Type* getVoidType(); ///< 返回表示Void类型的Type指针 + static Type* getLabelType(); ///< 返回表示Label类型的Type指针 + static Type* getPointerType(Type *baseType); ///< 返回表示指向baseType类型的Pointer类型的Type指针 + static Type* getFunctionType(Type *returnType, const std::vector ¶mTypes = {}); + ///< 返回表示返回类型为returnType,形参类型列表为paramTypes的函数类型的Type指针 + + public: + Kind getKind() const { return kind; } ///< 返回Type对象代表原始标量类型 + bool isInt() const { return kind == kInt; } ///< 判定是否为Int类型 + bool isFloat() const { return kind == kFloat; } ///< 判定是否为Float类型 + bool isVoid() const { return kind == kVoid; } ///< 判定是否为Void类型 + bool isLabel() const { return kind == kLabel; } ///< 判定是否为Label类型 + bool isPointer() const { return kind == kPointer; } ///< 判定是否为Pointer类型 + bool isFunction() const { return kind == kFunction; } ///< 判定是否为Function类型 + unsigned getSize() const; ///< 返回类型所占的空间大小(字节) + /// 尝试将一个变量转换为给定的Type及其派生类类型的变量 + template + auto as() const -> std::enable_if_t, T *> { + return dynamic_cast(const_cast(this)); + } +}; + +class PointerType : public Type { + protected: + Type *baseType; ///< 所指向的类型 + + protected: + explicit PointerType(Type *baseType) : Type(kPointer), baseType(baseType) {} + + public: + static PointerType* get(Type *baseType); ///< 获取指向baseType的Pointer类型 + + public: + Type* getBaseType() const { return baseType; } ///< 获取指向的类型 +}; + +class FunctionType : public Type { + private: + Type *returnType; ///< 返回值类型 + std::vector paramTypes; ///< 形参类型列表 + + protected: + explicit FunctionType(Type *returnType, std::vector paramTypes = {}) + : Type(kFunction), returnType(returnType), paramTypes(std::move(paramTypes)) {} + + public: + /// 获取返回值类型为returnType, 形参类型列表为paramTypes的Function类型 + static FunctionType* get(Type *returnType, const std::vector ¶mTypes = {}); + + public: + Type* getReturnType() const { return returnType; } ///< 获取返回值类信息 + auto getParamTypes() const { return make_range(paramTypes); } ///< 获取形参类型列表 + unsigned getNumParams() const { return paramTypes.size(); } ///< 获取形参数量 +}; + +/*! + * @} + */ + +/** + * \defgroup ir IR + * + * TSysY IR 是一种指令级别的语言. 它被组织为四层树型结构,如下所示: + * + * \dot IR Structure + * digraph IRStructure{ + * node [shape="box"] + * a [label="Module"] + * b [label="GlobalValue"] + * c [label="Function"] + * d [label="BasicBlock"] + * e [label="Instruction"] + * a->{b,c} + * c->d->e + * } + * + * \enddot + * + * - `Module` 对应顶层"CompUnit"语法结构 + * - `GlobalValue`对应"globalDecl"语法结构 + * - `Function`对应"FuncDef"语法结构 + * - `BasicBlock` 是一连串没有分支指令的指令。一个 `Function` + * 由一个或多个`BasicBlock`组成 + * - `Instruction` 表示一个原始指令,例如, add 或 sub + * + * SysY IR中基础的数据概念是`Value`。一个 `Value` 像 + * 一个寄存器。它充当`Instruction`的输入输出操作数。每个value + * 都有一个与之相联系的`Type`,以此说明Value所拥有值的类型。 + * + * 大多数`Instruction`具有三地址代码的结构, 例如, 最多拥有两个输入操作数和一个输出操作数。 + * + * SysY IR采用了Static-Single-Assignment (SSA)设计。`Value`作为一个输出操作数 + * 被一些指令所定义, 并被另一些指令当作输入操作数使用。尽管一个Value可以被多个指令使用,其 + * 定义只能发生一次。这导致一个value个定义它的指令存在一一对应关系。换句话说,任何定义一个Value的指令 + * 都可以被看作被定义的指令本身。故在SysY IR中,`Instruction` 也是一个`Value`。查看 `Value` 以获取其继承 + * 关系。 + * + * @{ + */ + + +class User; +class Value; +class AllocaInst; + +//! `Use` 表示`Value`和它的`User`之间的使用关系。 + +class Use { + private: + /** + * value在User操作数中的位置,例如, + * user->getOperands[index] == value + */ + unsigned index; + User *user; ///< 使用者 + Value *value; ///< 被使用的值 + + public: + Use() = default; + Use(unsigned index, User *user, Value *value) : index(index), user(user), value(value) {} + + public: + unsigned getIndex() const { return index; } ///< 返回value在User操作数中的位置 + User* getUser() const { return user; } ///< 返回使用者 + Value* getValue() const { return value; } ///< 返回被使用的值 + void setValue(Value *newValue) { value = newValue; } ///< 将被使用的值设置为newValue +}; + +//! The base class of all value types + +class Value { + protected: + Type *type; ///< 值的类型 + std::string name; ///< 值的名字 + std::list> uses; ///< 值的使用关系列表 + + protected: + explicit Value(Type *type, std::string name = "") : type(type), name(std::move(name)) {} + virtual ~Value() = default; + + public: + void setName(const std::string &newName) { name = newName; } ///< 设置名字 + const std::string& getName() const { return name; } ///< 获取名字 + Type* getType() const { return type; } ///< 返回值的类型 + bool isInt() const { return type->isInt(); } ///< 判定是否为Int类型 + bool isFloat() const { return type->isFloat(); } ///< 判定是否为Float类型 + bool isPointer() const { return type->isPointer(); } ///< 判定是否为Pointer类型 + std::list>& getUses() { return uses; } ///< 获取使用关系列表 + void addUse(const std::shared_ptr &use) { uses.push_back(use); } ///< 添加使用关系 + void replaceAllUsesWith(Value *value); ///< 将原来使用该value的使用者全变为使用给定参数value并修改相应use关系 + void removeUse(const std::shared_ptr &use) { uses.remove(use); } ///< 删除使用关系use +}; + +/** + * ValueCounter 需要理解为一个Value *的计数器。 + * 它的主要目的是为了节省存储空间和方便Memset指令的创建。 + * ValueCounter记录了一列Value *的互异元素和每个元素的重复数量。 + * 例如,假设有一列Value *为{v1, v1, v2, v3, v3, v3, v4}, + * 那么ValueCounter将记录为: + * - __counterValues: {v1, v2, v3, v4} + * - __counterNumbers: {2, 1, 3, 1} + * - __size: 7 + * 使得存储空间得到节省,方便Memset指令的创建。 + */ +class ValueCounter { + private: + unsigned __size{}; ///< 总的Value数量 + std::vector __counterValues; ///< 记录的Value *列表(无重复元素) + std::vector __counterNumbers; ///< 记录的Value *重复数量列表 + + public: + ValueCounter() = default; + + public: + unsigned size() const { return __size; } ///< 返回总的Value数量 + Value* getValue(unsigned index) const { + if (index >= __size) { + return nullptr; + } + + unsigned num = 0; + for (size_t i = 0; i < __counterNumbers.size(); i++) { + if (num <= index && index < num + __counterNumbers[i]) { + return __counterValues[i]; + } + num += __counterNumbers[i]; + } + + return nullptr; + } ///< 根据位置index获取Value * + const std::vector& getValues() const { return __counterValues; } ///< 获取互异Value *列表 + const std::vector& getNumbers() const { return __counterNumbers; } ///< 获取Value *重复数量列表 + void push_back(Value *value, unsigned num = 1) { + if (__size != 0 && __counterValues.back() == value) { + *(__counterNumbers.end() - 1) += num; + } else { + __counterValues.push_back(value); + __counterNumbers.push_back(num); + } + __size += num; + } ///< 向后插入num个value + void clear() { + __size = 0; + __counterValues.clear(); + __counterNumbers.clear(); + } ///< 清空ValueCounter +}; + +/*! + * Static constants known at compile time. + * + * `ConstantValue`s are not defined by instructions, and do not use any other + * `Value`s. It's type is either `int` or `float`. + * `ConstantValue`并不由指令定义, 也不使用任何Value。它的类型为int/float。 + */ + + +class ConstantValue : public Value { + protected: + /// 定义字面量类型的聚合类型 + union { + int iScalar; + float fScalar; + }; + + protected: + explicit ConstantValue(int value, const std::string &name = "") : Value(Type::getIntType(), name), iScalar(value) {} + explicit ConstantValue(float value, const std::string &name = "") + : Value(Type::getFloatType(), name), fScalar(value) {} + + public: + static ConstantValue* get(int value); ///< 获取一个int类型的ConstValue *,其值为value + static ConstantValue* get(float value); ///< 获取一个float类型的ConstValue *,其值为value + + public: + int getInt() const { + assert(isInt()); + return iScalar; + } ///< 返回int类型的值 + float getFloat() const { + assert(isFloat()); + return fScalar; + } ///< 返回float类型的值 + template + T getValue() const { + if (std::is_same::value && isInt()) { + return getInt(); + } + if (std::is_same::value && isFloat()) { + return getFloat(); + } + throw std::bad_cast(); // 或者其他适当的异常处理 + } ///< 返回值,getInt和getFloat统一化,整数返回整形,浮点返回浮点型 +}; + +class Instruction; +class Function; +class Loop; +class BasicBlock; + +/*! + * The container for `Instruction` sequence. + * + * `BasicBlock` maintains a list of `Instruction`s, with the last one being + * a terminator (branch or return). Besides, `BasicBlock` stores its arguments + * and records its predecessor and successor `BasicBlock`s. + */ + + class BasicBlock : public Value { + friend class Function; + + public: + using inst_list = std::list>; + using iterator = inst_list::iterator; + using arg_list = std::vector; + using block_list = std::vector; + using block_set = std::unordered_set; + + protected: + Function *parent; ///< 从属的函数 + inst_list instructions; ///< 拥有的指令序列 + arg_list arguments; ///< 分配空间后的形式参数列表 + block_list successors; ///< 前驱列表 + block_list predecessors; ///< 后继列表 + BasicBlock *idom = nullptr; ///< 直接支配结点,即支配树前驱,唯一,默认nullptr + block_list sdoms; ///< 支配树后继,可以有多个 + block_set dominants; ///< 必经结点集合 + block_set dominant_frontiers; ///< 支配边界 + bool reachable = false; ///< 用于表示该节点是否可达,默认不可达 + Loop *loopbelong = nullptr; ///< 用来表示该块属于哪个循环,唯一,默认nullptr + int loopdepth = 0; /// < 用来表示其归属循环的深度,默认0 + + public: + explicit BasicBlock(Function *parent, const std::string &name = "") + : Value(Type::getLabelType(), name), parent(parent) {} + + ~BasicBlock() override { + for (auto pre : predecessors) { + pre->removeSuccessor(this); + } + + for (auto suc : successors) { + suc->removePredecessor(this); + } + } ///< 基本块的析构函数,同时删除其前驱后继关系 + + public: + unsigned getNumInstructions() const { return instructions.size(); } ///< 获取指令数量 + unsigned getNumArguments() const { return arguments.size(); } ///< 获取形式参数数量 + unsigned getNumPredecessors() const { return predecessors.size(); } ///< 获取前驱数量 + unsigned getNumSuccessors() const { return successors.size(); } ///< 获取后继数量 + Function* getParent() const { return parent; } ///< 获取父函数 + void setParent(Function *func) { parent = func; } ///< 设置父函数 + inst_list& getInstructions() { return instructions; } ///< 获取指令列表 + arg_list& getArguments() { return arguments; } ///< 获取分配空间后的形式参数列表 + const block_list& getPredecessors() const { return predecessors; } ///< 获取前驱列表 + block_list& getSuccessors() { return successors; } ///< 获取后继列表 + block_set& getDominants() { return dominants; } + BasicBlock* getIdom() { return idom; } + block_list& getSdoms() { return sdoms; } + block_set& getDFs() { return dominant_frontiers; } + iterator begin() { return instructions.begin(); } ///< 返回指向指令列表开头的迭代器 + iterator end() { return instructions.end(); } ///< 返回指向指令列表末尾的迭代器 + iterator terminator() { return std::prev(end()); } ///< 基本块最后的IR + void insertArgument(AllocaInst *inst) { arguments.push_back(inst); } ///< 插入分配空间后的形式参数 + void addPredecessor(BasicBlock *block) { + if (std::find(predecessors.begin(), predecessors.end(), block) == predecessors.end()) { + predecessors.push_back(block); + } + } ///< 添加前驱 + void addSuccessor(BasicBlock *block) { + if (std::find(successors.begin(), successors.end(), block) == successors.end()) { + successors.push_back(block); + } + } ///< 添加后继 + void addPredecessor(const block_list &blocks) { + for (auto block : blocks) { + addPredecessor(block); + } + } ///< 添加多个前驱 + void addSuccessor(const block_list &blocks) { + for (auto block : blocks) { + addSuccessor(block); + } + } ///< 添加多个后继 + void setIdom(BasicBlock *block) { idom = block; } + void addSdoms(BasicBlock *block) { sdoms.push_back(block); } + void clearSdoms() { sdoms.clear(); } + // 重载1,参数为 BasicBlock* + void addDominants(BasicBlock *block) { dominants.emplace(block); } + // 重载2,参数为 block_set + void addDominants(const block_set &blocks) { dominants.insert(blocks.begin(), blocks.end()); } + void setDominants(BasicBlock *block) { + dominants.clear(); + addDominants(block); + } + void setDominants(const block_set &doms) { + dominants.clear(); + addDominants(doms); + } + void setDFs(const block_set &df) { + dominant_frontiers.clear(); + for (auto elem : df) { + dominant_frontiers.emplace(elem); + } + } + void removePredecessor(BasicBlock *block) { + auto iter = std::find(predecessors.begin(), predecessors.end(), block); + if (iter != predecessors.end()) { + predecessors.erase(iter); + } else { + assert(false); + } + } ///< 删除前驱 + void removeSuccessor(BasicBlock *block) { + auto iter = std::find(successors.begin(), successors.end(), block); + if (iter != successors.end()) { + successors.erase(iter); + } else { + assert(false); + } + } ///< 删除后继 + void replacePredecessor(BasicBlock *oldBlock, BasicBlock *newBlock) { + for (auto &predecessor : predecessors) { + if (predecessor == oldBlock) { + predecessor = newBlock; + break; + } + } + } ///< 替换前驱 + // 获取支配树中该块的所有子节点,包括子节点的子节点等,迭代实现 + block_list getChildren() { + std::queue q; + block_list children; + for (auto sdom : sdoms) { + q.push(sdom); + children.push_back(sdom); + } + while (!q.empty()) { + auto block = q.front(); + q.pop(); + for (auto sdom : block->sdoms) { + q.push(sdom); + children.push_back(sdom); + } + } + + return children; + } + + void setreachableTrue() { reachable = true; } ///< 设置可达 + void setreachableFalse() { reachable = false; } ///< 设置不可达 + bool getreachable() { return reachable; } ///< 返回可达状态 + + static void conectBlocks(BasicBlock *prev, BasicBlock *next) { + prev->addSuccessor(next); + next->addPredecessor(prev); + } ///< 连接两个块,即设置两个基本块的前驱后继关系 + void setLoop(Loop *loop2set) { loopbelong = loop2set; } ///< 设置所属循环 + Loop* getLoop() { return loopbelong; } ///< 获得所属循环 + void setLoopDepth(int loopdepth2set) { loopdepth = loopdepth2set; } ///< 设置循环深度 + int getLoopDepth() { return loopdepth; } ///< 获得其在循环的深度 + void removeInst(iterator pos) { instructions.erase(pos); } ///< 删除指令 + iterator moveInst(iterator sourcePos, iterator targetPos, BasicBlock *block); ///< 移动指令 +}; + +//! User is the abstract base type of `Value` types which use other `Value` as +//! operands. Currently, there are two kinds of `User`s, `Instruction` and +//! `GlobalValue`. +// User是`Value`的派生类型,其使用其他的`Value`作为操作数 + + +class User : public Value { + protected: + std::vector> operands; ///< 操作数/使用关系 + + protected: + explicit User(Type *type, const std::string &name = "") : Value(type, name) {} + + public: + unsigned getNumOperands() const { return operands.size(); } ///< 获取操作数数量 + auto operand_begin() const { return operands.begin(); } ///< 返回操作数列表的开头迭代器 + auto operand_end() const { return operands.end(); } ///< 返回操作数列表的结尾迭代器 + auto getOperands() const { return make_range(operand_begin(), operand_end()); } ///< 获取操作数列表 + Value* getOperand(unsigned index) const { return operands[index]->getValue(); } ///< 获取位置为index的操作数 + void addOperand(Value *value) { + operands.emplace_back(std::make_shared(operands.size(), this, value)); + value->addUse(operands.back()); + } ///< 增加操作数 + void removeOperand(unsigned index) { + auto value = getOperand(index); + value->removeUse(operands[index]); + operands.erase(operands.begin() + index); + } ///< 移除操作数 + template + void addOperands(const ContainerT &newoperands) { + for (auto value : newoperands) { + addOperand(value); + } + } ///< 增加多个操作数 + void replaceOperand(unsigned index, Value *value); ///< 替换操作数 + void setOperand(unsigned index, Value *value); ///< 设置操作数 +}; + +class GetSubArrayInst; +/** + * 左值 具有地址的对象 + */ +class LVal { + friend class GetSubArrayInst; + + protected: + LVal *fatherLVal{}; ///< 父左值 + std::list> childrenLVals; ///< 子左值 + GetSubArrayInst *defineInst{}; /// 定义该左值的GetSubArray指令 + + protected: + LVal() = default; + + public: + virtual ~LVal() = default; + virtual std::vector getLValDims() const = 0; ///< 获取左值的维度 + virtual unsigned getLValNumDims() const = 0; ///< 获取左值的维度数量 + + public: + LVal* getFatherLVal() const { return fatherLVal; } ///< 获取父左值 + const std::list>& getChildrenLVals() const { + return childrenLVals; + } ///< 获取子左值列表 + LVal* getAncestorLVal() const { + auto curLVal = const_cast(this); + while (curLVal->getFatherLVal() != nullptr) { + curLVal = curLVal->getFatherLVal(); + } + return curLVal; + } ///< 获取祖先左值 + void setFatherLVal(LVal *father) { fatherLVal = father; } ///< 设置父左值 + void setDefineInst(GetSubArrayInst *inst) { defineInst = inst; } ///< 设置定义指令 + void addChild(LVal *child) { childrenLVals.emplace_back(child); } ///< 添加子左值 + void removeChild(LVal *child) { + auto iter = std::find_if(childrenLVals.begin(), childrenLVals.end(), + [child](const std::unique_ptr &ptr) { return ptr.get() == child; }); + childrenLVals.erase(iter); + } ///< 移除子左值 + GetSubArrayInst* getDefineInst() const { return defineInst; } ///< 获取定义指令 +}; + +/*! + * Base of all concrete instruction types. + */ +class Instruction : public User { + public: + /// 指令标识码 + enum Kind : uint64_t { + kInvalid = 0x0UL, + // Binary + kAdd = 0x1UL << 0, + kSub = 0x1UL << 1, + kMul = 0x1UL << 2, + kDiv = 0x1UL << 3, + kRem = 0x1UL << 4, + kICmpEQ = 0x1UL << 5, + kICmpNE = 0x1UL << 6, + kICmpLT = 0x1UL << 7, + kICmpGT = 0x1UL << 8, + kICmpLE = 0x1UL << 9, + kICmpGE = 0x1UL << 10, + kFAdd = 0x1UL << 11, + kFSub = 0x1UL << 12, + kFMul = 0x1UL << 13, + kFDiv = 0x1UL << 14, + kFCmpEQ = 0x1UL << 15, + kFCmpNE = 0x1UL << 16, + kFCmpLT = 0x1UL << 17, + kFCmpGT = 0x1UL << 18, + kFCmpLE = 0x1UL << 19, + kFCmpGE = 0x1UL << 20, + kAnd = 0x1UL << 21, + kOr = 0x1UL << 22, + // Unary + kNeg = 0x1UL << 23, + kNot = 0x1UL << 24, + kFNeg = 0x1UL << 25, + kFNot = 0x1UL << 26, + kFtoI = 0x1UL << 27, + kItoF = 0x1UL << 28, + // call + kCall = 0x1UL << 29, + // terminator + kCondBr = 0x1UL << 30, + kBr = 0x1UL << 31, + kReturn = 0x1UL << 32, + // mem op + kAlloca = 0x1UL << 33, + kLoad = 0x1UL << 34, + kStore = 0x1UL << 35, + kLa = 0x1UL << 36, + kMemset = 0x1UL << 37, + kGetSubArray = 0x1UL << 38, + // constant + kConstant = 0x1UL << 37, + // phi + kPhi = 0x1UL << 39, + kBitItoF = 0x1UL << 40, + kBitFtoI = 0x1UL << 41 + }; + +protected: + Kind kind; + BasicBlock *parent; + +protected: + Instruction(Kind kind, Type *type, BasicBlock *parent = nullptr, const std::string &name = "") + : User(type, name), kind(kind), parent(parent) {} + +public: + +public: + Kind getKind() const { return kind; } + std::string getKindString() const{ + switch (kind) { + case kInvalid: + return "Invalid"; + case kAdd: + return "Add"; + case kSub: + return "Sub"; + case kMul: + return "Mul"; + case kDiv: + return "Div"; + case kRem: + return "Rem"; + case kICmpEQ: + return "ICmpEQ"; + case kICmpNE: + return "ICmpNE"; + case kICmpLT: + return "ICmpLT"; + case kICmpGT: + return "ICmpGT"; + case kICmpLE: + return "ICmpLE"; + case kICmpGE: + return "ICmpGE"; + case kFAdd: + return "FAdd"; + case kFSub: + return "FSub"; + case kFMul: + return "FMul"; + case kFDiv: + return "FDiv"; + case kFCmpEQ: + return "FCmpEQ"; + case kFCmpNE: + return "FCmpNE"; + case kFCmpLT: + return "FCmpLT"; + case kFCmpGT: + return "FCmpGT"; + case kFCmpLE: + return "FCmpLE"; + case kFCmpGE: + return "FCmpGE"; + case kAnd: + return "And"; + case kOr: + return "Or"; + case kNeg: + return "Neg"; + case kNot: + return "Not"; + case kFNeg: + return "FNeg"; + case kFNot: + return "FNot"; + case kFtoI: + return "FtoI"; + case kItoF: + return "IToF"; + case kCall: + return "Call"; + case kCondBr: + return "CondBr"; + case kBr: + return "Br"; + case kReturn: + return "Return"; + case kAlloca: + return "Alloca"; + case kLoad: + return "Load"; + case kStore: + return "Store"; + case kLa: + return "La"; + case kMemset: + return "Memset"; + case kPhi: + return "Phi"; + case kGetSubArray: + return "GetSubArray"; + default: + return "Unknown"; + } + } ///< 根据指令标识码获取字符串 + + BasicBlock* getParent() const { return parent; } + Function* getFunction() const { return parent->getParent(); } + void setParent(BasicBlock *bb) { parent = bb; } + + bool isBinary() const { + static constexpr uint64_t BinaryOpMask = + (kAdd | kSub | kMul | kDiv | kRem | kAnd | kOr) | + (kICmpEQ | kICmpNE | kICmpLT | kICmpGT | kICmpLE | kICmpGE) | + (kFAdd | kFSub | kFMul | kFDiv) | + (kFCmpEQ | kFCmpNE | kFCmpLT | kFCmpGT | kFCmpLE | kFCmpGE); + return kind & BinaryOpMask; + } + bool isUnary() const { + static constexpr uint64_t UnaryOpMask = + kNeg | kNot | kFNeg | kFNot | kFtoI | kItoF | kBitFtoI | kBitItoF; + return kind & UnaryOpMask; + } + bool isMemory() const { + static constexpr uint64_t MemoryOpMask = + kAlloca | kLoad | kStore; + return kind & MemoryOpMask; + } + bool isTerminator() const { + static constexpr uint64_t TerminatorOpMask = kCondBr | kBr | kReturn; + return kind & TerminatorOpMask; + } + bool isCmp() const { + static constexpr uint64_t CmpOpMask = + (kICmpEQ | kICmpNE | kICmpLT | kICmpGT | kICmpLE | kICmpGE) | + (kFCmpEQ | kFCmpNE | kFCmpLT | kFCmpGT | kFCmpLE | kFCmpGE); + return kind & CmpOpMask; + } + bool isBranch() const { + static constexpr uint64_t BranchOpMask = kBr | kCondBr; + return kind & BranchOpMask; + } + bool isCommutative() const { + static constexpr uint64_t CommutativeOpMask = + kAdd | kMul | kICmpEQ | kICmpNE | kFAdd | kFMul | kFCmpEQ | kFCmpNE | kAnd | kOr; + return kind & CommutativeOpMask; + } + bool isUnconditional() const { return kind == kBr; } + bool isConditional() const { return kind == kCondBr; } + bool isPhi() const { return kind == kPhi; } + bool isAlloca() const { return kind == kAlloca; } + bool isLoad() const { return kind == kLoad; } + bool isStore() const { return kind == kStore; } + bool isLa() const { return kind == kLa; } + bool isMemset() const { return kind == kMemset; } + bool isGetSubArray() const { return kind == kGetSubArray; } + bool isCall() const { return kind == kCall; } + bool isReturn() const { return kind == kReturn; } + bool isDefine() const { + static constexpr uint64_t DefineOpMask = kAlloca | kStore | kPhi; + return (kind & DefineOpMask) != 0U; + } +}; // class Instruction + +class Function; +//! Function call. + +class LaInst : public Instruction { + friend class Function; + friend class IRBuilder; + + protected: + explicit LaInst(Value *pointer, const std::vector &indices = {}, BasicBlock *parent = nullptr, + const std::string &name = "") + : Instruction(Kind::kLa, pointer->getType(), parent, name) { + assert(pointer); + addOperand(pointer); + addOperands(indices); + } + + public: + unsigned getNumIndices() const { return getNumOperands() - 1; } ///< 获取索引长度 + Value* getPointer() const { return getOperand(0); } ///< 获取目标变量的Value指针 + auto getIndices() const { return make_range(std::next(operand_begin()), operand_end()); } ///< 获取索引列表 + Value* getIndex(unsigned index) const { return getOperand(index + 1); } ///< 获取位置为index的索引分量 +}; + +class PhiInst : public Instruction { + friend class IRBuilder; + friend class Function; + friend class SysySSA; + + protected: + Value *map_val; // Phi的旧值 + + PhiInst(Type *type, Value *lhs, const std::vector &rhs, Value *mval, BasicBlock *parent, + const std::string &name = "") + : Instruction(Kind::kPhi, type, parent, name) { + map_val = mval; + addOperand(lhs); + addOperands(rhs); + } + + public: + Value* getMapVal() { return map_val; } + Value* getPointer() const { return getOperand(0); } + auto getValues() { return make_range(std::next(operand_begin()), operand_end()); } + Value* getValue(unsigned index) const { return getOperand(index + 1); } +}; + + +class CallInst : public Instruction { + friend class Function; + friend class IRBuilder; + +protected: + CallInst(Function *callee, const std::vector &args = {}, + BasicBlock *parent = nullptr, const std::string &name = ""); + + +public: + Function* getCallee() const; + auto getArguments() const { + return make_range(std::next(operand_begin()), operand_end()); + } + +}; // class CallInst + +//! Unary instruction, includes '!', '-' and type conversion. +class UnaryInst : public Instruction { + friend class Function; + friend class IRBuilder; + +protected: + UnaryInst(Kind kind, Type *type, Value *operand, BasicBlock *parent = nullptr, + const std::string &name = "") + : Instruction(kind, type, parent, name) { + addOperand(operand); + } + + +public: + Value* getOperand() const { return User::getOperand(0); } + +}; // class UnaryInst + +//! Binary instruction, e.g., arithmatic, relation, logic, etc. +class BinaryInst : public Instruction { + friend class IRBuilder; + friend class Function; + + protected: + BinaryInst(Kind kind, Type *type, Value *lhs, Value *rhs, BasicBlock *parent, const std::string &name = "") + : Instruction(kind, type, parent, name) { + addOperand(lhs); + addOperand(rhs); + } + +public: + Value* getLhs() const { return getOperand(0); } + Value* getRhs() const { return getOperand(1); } + template + T eval(T lhs, T rhs) { + switch (getKind()) { + case kAdd: + return lhs + rhs; + case kSub: + return lhs - rhs; + case kMul: + return lhs * rhs; + case kDiv: + return lhs / rhs; + case kRem: + if constexpr (std::is_floating_point::value) { + throw std::runtime_error("Remainder operation not supported for floating point types."); + } else { + return lhs % rhs; + } + case kICmpEQ: + return lhs == rhs; + case kICmpNE: + return lhs != rhs; + case kICmpLT: + return lhs < rhs; + case kICmpGT: + return lhs > rhs; + case kICmpLE: + return lhs <= rhs; + case kICmpGE: + return lhs >= rhs; + case kFAdd: + return lhs + rhs; + case kFSub: + return lhs - rhs; + case kFMul: + return lhs * rhs; + case kFDiv: + return lhs / rhs; + case kFCmpEQ: + return lhs == rhs; + case kFCmpNE: + return lhs != rhs; + case kFCmpLT: + return lhs < rhs; + case kFCmpGT: + return lhs > rhs; + case kFCmpLE: + return lhs <= rhs; + case kFCmpGE: + return lhs >= rhs; + case kAnd: + return lhs && rhs; + case kOr: + return lhs || rhs; + default: + assert(false); + } + } ///< 根据指令类型进行二元计算,eval template模板实现 +}; // class BinaryInst + +//! The return statement +class ReturnInst : public Instruction { + friend class IRBuilder; + friend class Function; + + protected: + explicit ReturnInst(Value *value = nullptr, BasicBlock *parent = nullptr, const std::string &name = "") + : Instruction(kReturn, Type::getVoidType(), parent, name) { + if (value != nullptr) { + addOperand(value); + } + } + + public: + bool hasReturnValue() const { return not operands.empty(); } + Value* getReturnValue() const { + return hasReturnValue() ? getOperand(0) : nullptr; + } +}; + +//! Unconditional branch +class UncondBrInst : public Instruction { + friend class IRBuilder; + friend class Function; + +protected: + UncondBrInst(BasicBlock *block, std::vector args, + BasicBlock *parent = nullptr) + : Instruction(kBr, Type::getVoidType(), parent, "") { + // assert(block->getNumArguments() == args.size()); + addOperand(block); + addOperands(args); + } + +public: + BasicBlock* getBlock() const { return dynamic_cast(getOperand(0)); } + auto getArguments() const { + return make_range(std::next(operand_begin()), operand_end()); + } + +}; // class UncondBrInst + +//! Conditional branch +class CondBrInst : public Instruction { + friend class IRBuilder; + friend class Function; + +protected: + CondBrInst(Value *condition, BasicBlock *thenBlock, BasicBlock *elseBlock, + const std::vector &thenArgs, + const std::vector &elseArgs, BasicBlock *parent = nullptr) + : Instruction(kCondBr, Type::getVoidType(), parent, "") { + // assert(thenBlock->getNumArguments() == thenArgs.size() and + // elseBlock->getNumArguments() == elseArgs.size()); + addOperand(condition); + addOperand(thenBlock); + addOperand(elseBlock); + addOperands(thenArgs); + addOperands(elseArgs); + } +public: + Value* getCondition() const { return getOperand(0); } + BasicBlock* getThenBlock() const { + return dynamic_cast(getOperand(1)); + } + BasicBlock* getElseBlock() const { + return dynamic_cast(getOperand(2)); + } + auto getThenArguments() const { + auto begin = std::next(operand_begin(), 3); + auto end = std::next(begin, getThenBlock()->getNumArguments()); + return make_range(begin, end); + } + auto getElseArguments() const { + auto begin = + std::next(operand_begin(), 3 + getThenBlock()->getNumArguments()); + auto end = operand_end(); + return make_range(begin, end); + } + +}; // class CondBrInst + +//! Allocate memory for stack variables, used for non-global variable declartion +class AllocaInst : public Instruction , public LVal { + friend class IRBuilder; + friend class Function; +protected: + AllocaInst(Type *type, const std::vector &dims = {}, + BasicBlock *parent = nullptr, const std::string &name = "") + : Instruction(kAlloca, type, parent, name) { + addOperands(dims); + } + +public: + std::vector getLValDims() const override { + std::vector dims; + for (const auto &dim : getOperands()) { + dims.emplace_back(dim->getValue()); + } + return dims; + } ///< 获取作为左值的维度数组 + unsigned getLValNumDims() const override { return getNumOperands(); } + + int getNumDims() const { return getNumOperands(); } + auto getDims() const { return getOperands(); } + Value* getDim(int index) { return getOperand(index); } + +}; // class AllocaInst + + +class GetSubArrayInst : public Instruction { + friend class IRBuilder; + friend class Function; + + public: + GetSubArrayInst(LVal *fatherArray, LVal *childArray, const std::vector &indices, + BasicBlock *parent = nullptr, const std::string &name = "") + : Instruction(Kind::kGetSubArray, Type::getVoidType(), parent, name) { + auto predicate = [childArray](const std::unique_ptr &child) -> bool { return child.get() == childArray; }; + if (std::find_if(fatherArray->childrenLVals.begin(), fatherArray->childrenLVals.end(), predicate) == + fatherArray->childrenLVals.end()) { + fatherArray->childrenLVals.emplace_back(childArray); + } + childArray->fatherLVal = fatherArray; + childArray->defineInst = this; + auto fatherArrayValue = dynamic_cast(fatherArray); + auto childArrayValue = dynamic_cast(childArray); + assert(fatherArrayValue); + assert(childArrayValue); + addOperand(fatherArrayValue); + addOperand(childArrayValue); + addOperands(indices); + } + + public: + Value* getFatherArray() const { return getOperand(0); } ///< 获取父数组 + Value* getChildArray() const { return getOperand(1); } ///< 获取子数组 + LVal* getFatherLVal() const { return dynamic_cast(getOperand(0)); } ///< 获取父左值 + LVal* getChildLVal() const { return dynamic_cast(getOperand(1)); } ///< 获取子左值 + auto getIndices() const { return make_range(std::next(operand_begin(), 2), operand_end()); } ///< 获取索引 + unsigned getNumIndices() const { return getNumOperands() - 2; } ///< 获取索引数量 +}; + +//! Load a value from memory address specified by a pointer value +class LoadInst : public Instruction { + friend class IRBuilder; + friend class Function; + +protected: + LoadInst(Value *pointer, const std::vector &indices = {}, + BasicBlock *parent = nullptr, const std::string &name = "") + : Instruction(kLoad, pointer->getType()->as()->getBaseType(), + parent, name) { + addOperand(pointer); + addOperands(indices); + } + +public: + int getNumIndices() const { return getNumOperands() - 1; } + Value* getPointer() const { return getOperand(0); } + auto getIndices() const { + return make_range(std::next(operand_begin()), operand_end()); + } + Value* getIndex(int index) const { return getOperand(index + 1); } + std::list getAncestorIndices() const { + std::list indices; + for (const auto &index : getIndices()) { + indices.emplace_back(index->getValue()); + } + auto curPointer = dynamic_cast(getPointer()); + while (curPointer->getFatherLVal() != nullptr) { + auto inserter = std::next(indices.begin()); + for (const auto &index : curPointer->getDefineInst()->getIndices()) { + indices.insert(inserter, index->getValue()); + } + curPointer = curPointer->getFatherLVal(); + } + + return indices; + } ///< 获取相对于祖先数组的索引列表 +}; // class LoadInst + +//! Store a value to memory address specified by a pointer value +class StoreInst : public Instruction { + friend class IRBuilder; + friend class Function; + +protected: + StoreInst(Value *value, Value *pointer, + const std::vector &indices = {}, + BasicBlock *parent = nullptr, const std::string &name = "") + : Instruction(kStore, Type::getVoidType(), parent, name) { + addOperand(value); + addOperand(pointer); + addOperands(indices); + } + +public: + int getNumIndices() const { return getNumOperands() - 2; } + Value* getValue() const { return getOperand(0); } + Value* getPointer() const { return getOperand(1); } + auto getIndices() const { + return make_range(std::next(operand_begin(), 2), operand_end()); + } + Value* getIndex(int index) const { return getOperand(index + 2); } + std::list getAncestorIndices() const { + std::list indices; + for (const auto &index : getIndices()) { + indices.emplace_back(index->getValue()); + } + auto curPointer = dynamic_cast(getPointer()); + while (curPointer->getFatherLVal() != nullptr) { + auto inserter = std::next(indices.begin()); + for (const auto &index : curPointer->getDefineInst()->getIndices()) { + indices.insert(inserter, index->getValue()); + } + curPointer = curPointer->getFatherLVal(); + } + + return indices; + } ///< 获取相对于祖先数组的索引列表 + +}; // class StoreInst + +//! Memset instruction +class MemsetInst : public Instruction { + friend class IRBuilder; + friend class Function; + +protected: + //! Create a memset instruction. + //! \param pointer The pointer to the memory location to be set. + //! \param begin The starting address of the memory region to be set. + //! \param size The size of the memory region to be set. + //! \param value The value to set the memory region to. + //! \param parent The parent basic block of this instruction. + //! \param name The name of this instruction. + MemsetInst(Value *pointer, Value *begin, Value *size, Value *value, + BasicBlock *parent = nullptr, const std::string &name = "") + : Instruction(kMemset, Type::getVoidType(), parent, name) { + addOperand(pointer); + addOperand(begin); + addOperand(size); + addOperand(value); + } + +public: + Value* getPointer() const { return getOperand(0); } + Value* getBegin() const { return getOperand(1); } + Value* getSize() const { return getOperand(2); } + Value* getValue() const { return getOperand(3); } + +}; + +class GlobalValue; + +// 循环类 +class Loop { +public: + using block_list = std::vector; + using block_set = std::unordered_set; + using Loop_list = std::vector; + +protected: + Function *parent; // 所属函数 + block_list blocksInLoop; // 循环内的基本块 + BasicBlock *preheaderBlock = nullptr; // 前驱块 + BasicBlock *headerBlock = nullptr; // 循环头 + block_list latchBlock; // 回边块 + block_set exitingBlocks; // 退出块 + block_set exitBlocks; // 退出目标块 + Loop *parentloop = nullptr; // 父循环 + Loop_list subLoops; // 子循环 + size_t loopID; // 循环ID + unsigned loopDepth; // 循环深度 + + Instruction *indCondVar = nullptr; // 循环条件变量 + Instruction::Kind IcmpKind; // 比较类型 + Value *indEnd = nullptr; // 循环结束值 + AllocaInst *IndPhi = nullptr; // 循环变量 + + ConstantValue *indBegin = nullptr; // 循环起始值 + ConstantValue *indStep = nullptr; // 循环步长 + + std::set GlobalValuechange; // 循环内改变的全局变量 + + int StepType = 0; // 循环步长类型 + bool parallelable = false; // 是否可并行 + +public: + explicit Loop(BasicBlock *header, const std::string &name = "") + : headerBlock(header) { + blocksInLoop.push_back(header); + } + + void setloopID() { + static unsigned loopCount = 0; + loopCount = loopCount + 1; + loopID = loopCount; + } + ConstantValue* getindBegin() { return indBegin; } ///< 获得循环开始值 + ConstantValue* getindStep() { return indStep; } ///< 获得循环步长 + void setindBegin(ConstantValue *indBegin2set) { indBegin = indBegin2set; } ///< 设置循环开始值 + void setindStep(ConstantValue *indStep2set) { indStep = indStep2set; } ///< 设置循环步长 + void setStepType(int StepType2Set) { StepType = StepType2Set; } ///< 设置循环变量规则 + int getStepType() { return StepType; } ///< 获得循环变量规则 + size_t getLoopID() { return loopID; } + + BasicBlock* getHeader() const { return headerBlock; } + BasicBlock* getPreheaderBlock() const { return preheaderBlock; } + block_list& getLatchBlocks() { return latchBlock; } + block_set& getExitingBlocks() { return exitingBlocks; } + block_set& getExitBlocks() { return exitBlocks; } + Loop* getParentLoop() const { return parentloop; } + void setParentLoop(Loop *parent) { parentloop = parent; } + void addBasicBlock(BasicBlock *bb) { blocksInLoop.push_back(bb); } + void addSubLoop(Loop *loop) { subLoops.push_back(loop); } + void setLoopDepth(unsigned depth) { loopDepth = depth; } + block_list& getBasicBlocks() { return blocksInLoop; } + Loop_list& getSubLoops() { return subLoops; } + unsigned getLoopDepth() const { return loopDepth; } + + bool isLoopContainsBasicBlock(BasicBlock *bb) const { + return std::find(blocksInLoop.begin(), blocksInLoop.end(), bb) != blocksInLoop.end(); + } ///< 判断输入块是否在该循环内 + + void addExitingBlock(BasicBlock *bb) { exitingBlocks.insert(bb); } + void addExitBlock(BasicBlock *bb) { exitBlocks.insert(bb); } + void addLatchBlock(BasicBlock *bb) { latchBlock.push_back(bb); } + void setPreheaderBlock(BasicBlock *bb) { preheaderBlock = bb; } + + void setIndexCondInstr(Instruction *instr) { indCondVar = instr; } + void setIcmpKind(Instruction::Kind kind) { IcmpKind = kind; } + Instruction::Kind getIcmpKind() const { return IcmpKind; } + + bool isSimpleLoopInvariant(Value *value) ; ///< 判断是否为简单循环不变量,若其在loop中,则不是。 + + void setIndEnd(Value *value) { indEnd = value; } + void setIndPhi(AllocaInst *phi) { IndPhi = phi; } + Value* getIndEnd() const { return indEnd; } + AllocaInst* getIndPhi() const { return IndPhi; } + Instruction* getIndCondVar() const { return indCondVar; } + + void addGlobalValuechange(GlobalValue *globalvaluechange2add) { + GlobalValuechange.insert(globalvaluechange2add); + } ///<添加在循环中改变的全局变量 + std::set& getGlobalValuechange() { + return GlobalValuechange; + } ///<获得在循环中改变的所有全局变量 + + void setParallelable(bool flag) { parallelable = flag; } + bool isParallelable() const { return parallelable; } +}; + +class Module; +//! Function definition +class Function : public Value { + friend class Module; + +protected: + Function(Module *parent, Type *type, const std::string &name) : Value(type, name), parent(parent) { + blocks.emplace_back(new BasicBlock(this)); + } + +public: + using block_list = std::list>; + using Loop_list = std::list>; + + // 函数优化属性标识符 + enum FunctionAttribute : uint64_t { + PlaceHolder = 0x0UL, + Pure = 0x1UL << 0, + SelfRecursive = 0x1UL << 1, + SideEffect = 0x1UL << 2, + NoPureCauseMemRead = 0x1UL << 3 + }; + +protected: + Module *parent; ///< 函数的父模块 + block_list blocks; ///< 函数包含的基本块列表 + Loop_list loops; ///< 函数包含的循环列表 + Loop_list topLoops; ///< 函数所包含的顶层循环; + std::list> indirectAllocas; ///< 函数中mem2reg引入的间接分配的内存 + + FunctionAttribute attribute = PlaceHolder; ///< 函数属性 + std::set callees; ///< 函数调用的函数集合 + + std::unordered_map basicblock2Loop; + std::unordered_map value2AllocBlocks; ///< value -- alloc block mapping + std::unordered_map> + value2DefBlocks; //< value -- define blocks mapping + std::unordered_map> value2UseBlocks; //< value -- use blocks mapping + + public: + static unsigned getcloneIndex() { + static unsigned cloneIndex = 0; + cloneIndex += 1; + return cloneIndex - 1; + } + Function* clone(const std::string &suffix = "_" + std::to_string(getcloneIndex()) + "@") const; ///< 复制函数 + const std::set& getCallees() { return callees; } + void addCallee(Function *callee) { callees.insert(callee); } + void removeCallee(Function *callee) { callees.erase(callee); } + void clearCallees() { callees.clear(); } + std::set getCalleesWithNoExternalAndSelf(); + FunctionAttribute getAttribute() const { return attribute; } ///< 获取函数属性 + void setAttribute(FunctionAttribute attr) { + attribute = static_cast(attribute | attr); + } ///< 设置函数属性 + void clearAttribute() { attribute = PlaceHolder; } ///< 清楚所有函数属性,只保留PlaceHolder + Loop* getLoopOfBasicBlock(BasicBlock *bb) { + return basicblock2Loop.count(bb) != 0 ? basicblock2Loop[bb] : nullptr; + } ///< 获得块所在循环 + unsigned getLoopDepthByBlock(BasicBlock *basicblock2Check) { + if (getLoopOfBasicBlock(basicblock2Check) != nullptr) { + auto loop = getLoopOfBasicBlock(basicblock2Check); + return loop->getLoopDepth(); + } + return static_cast(0); + } ///< 通过块,获得其所在循环深度 + void addBBToLoop(BasicBlock *bb, Loop *LoopToadd) { basicblock2Loop[bb] = LoopToadd; } ///< 添加块与循环的映射 + std::unordered_map& getBBToLoopRef() { + return basicblock2Loop; + } ///< 获得块-循环映射表 + // auto getNewLoopPtr(BasicBlock *header) -> Loop * { return new Loop(header); } + Type* getReturnType() const { return getType()->as()->getReturnType(); } ///< 获取返回值类型 + auto getParamTypes() const { return getType()->as()->getParamTypes(); } ///< 获取形式参数类型列表 + auto getBasicBlocks() { return make_range(blocks); } ///< 获取基本块列表 + block_list& getBasicBlocks_NoRange() { return blocks; } + BasicBlock* getEntryBlock() { return blocks.front().get(); } ///< 获取入口块 + void removeBasicBlock(BasicBlock *blockToRemove) { + auto is_same_ptr = [blockToRemove](const std::unique_ptr &ptr) { return ptr.get() == blockToRemove; }; + blocks.remove_if(is_same_ptr); + // blocks.erase(std::remove_if(blocks.begin(), blocks.end(), is_same_ptr), blocks.end()); + } ///< 将该块从function的blocks中删除 + // auto getBasicBlocks_NoRange() -> block_list & { return blocks; } + BasicBlock* addBasicBlock(const std::string &name = "") { + blocks.emplace_back(new BasicBlock(this, name)); + return blocks.back().get(); + } ///< 添加新的基本块 + BasicBlock* addBasicBlock(BasicBlock *block) { + blocks.emplace_back(block); + return block; + } ///< 添加基本块到blocks中 + BasicBlock* addBasicBlockFront(BasicBlock *block) { + blocks.emplace_front(block); + return block; + } // 从前端插入新的基本块 + /** value -- alloc blocks mapping */ + void addValue2AllocBlocks(Value *value, BasicBlock *block) { + value2AllocBlocks[value] = block; + } ///< 添加value -- alloc block mapping + BasicBlock* getAllocBlockByValue(Value *value) { + if (value2AllocBlocks.count(value) > 0) { + return value2AllocBlocks[value]; + } + return nullptr; + } ///< 通过value获取alloc block + std::unordered_map& getValue2AllocBlocks() { + return value2AllocBlocks; + } ///< 获取所有value -- alloc block mappings + void removeValue2AllocBlock(Value *value) { + value2AllocBlocks.erase(value); + } ///< 删除value -- alloc block mapping + /** value -- define blocks mapping */ + void addValue2DefBlocks(Value *value, BasicBlock *block) { + ++value2DefBlocks[value][block]; + } ///< 添加value -- define block mapping + // keep in mind that the return is not a reference. + std::unordered_set getDefBlocksByValue(Value *value) { + std::unordered_set blocks; + if (value2DefBlocks.count(value) > 0) { + for (const auto &pair : value2DefBlocks[value]) { + blocks.insert(pair.first); + } + } + return blocks; + } ///< 通过value获取define blocks + std::unordered_map>& getValue2DefBlocks() { + return value2DefBlocks; + } ///< 获取所有value -- define blocks mappings + bool removeValue2DefBlock(Value *value, BasicBlock *block) { + bool changed = false; + if (--value2DefBlocks[value][block] == 0) { + value2DefBlocks[value].erase(block); + if (value2DefBlocks[value].empty()) { + value2DefBlocks.erase(value); + changed = true; + } + } + return changed; + } ///< 删除value -- define block mapping + std::unordered_set getValuesOfDefBlock() { + std::unordered_set values; + for (const auto &pair : value2DefBlocks) { + values.insert(pair.first); + } + return values; + } ///< 获取所有定义过的value + /** value -- use blocks mapping */ + void addValue2UseBlocks(Value *value, BasicBlock *block) { + ++value2UseBlocks[value][block]; + } ///< 添加value -- use block mapping + // keep in mind that the return is not a reference. + std::unordered_set getUseBlocksByValue(Value *value) { + std::unordered_set blocks; + if (value2UseBlocks.count(value) > 0) { + for (const auto &pair : value2UseBlocks[value]) { + blocks.insert(pair.first); + } + } + return blocks; + } ///< 通过value获取use blocks + std::unordered_map>& getValue2UseBlocks() { + return value2UseBlocks; + } ///< 获取所有value -- use blocks mappings + bool removeValue2UseBlock(Value *value, BasicBlock *block) { + bool changed = false; + if (--value2UseBlocks[value][block] == 0) { + value2UseBlocks[value].erase(block); + if (value2UseBlocks[value].empty()) { + value2UseBlocks.erase(value); + changed = true; + } + } + return changed; + } ///< 删除value -- use block mapping + void addIndirectAlloca(AllocaInst *alloca) { indirectAllocas.emplace_back(alloca); } ///< 添加间接分配 + std::list>& getIndirectAllocas() { + return indirectAllocas; + } ///< 获取间接分配列表 + + /** loop -- begin */ + + void addLoop(Loop *loop) { loops.emplace_back(loop); } ///< 添加循环(非顶层) + void addTopLoop(Loop *loop) { topLoops.emplace_back(loop); } ///< 添加顶层循环 + Loop_list& getLoops() { return loops; } ///< 获得循环(非顶层) + Loop_list& getTopLoops() { return topLoops; } ///< 获得顶层循环 + /** loop -- end */ + +}; // class Function + +//! Global value declared at file scope +class GlobalValue : public User, public LVal { + friend class Module; + +protected: + Module *parent; ///< 父模块 + unsigned numDims; ///< 维度数量 + ValueCounter initValues; ///< 初值 + +protected: + GlobalValue(Module *parent, Type *type, const std::string &name, + const std::vector &dims = {}, + ValueCounter init = {}) + : User(type, name), parent(parent) { + assert(type->isPointer()); + addOperands(dims); + numDims = dims.size(); + if (init.size() == 0) { + unsigned num = 1; + for (unsigned i = 0; i < numDims; i++) { + num *= dynamic_cast(dims[i])->getInt(); + } + if (dynamic_cast(type)->getBaseType() == Type::getFloatType()) { + init.push_back(ConstantValue::get(0.0F), num); + } else { + init.push_back(ConstantValue::get(0), num); + } + } + initValues = init; + } + +public: + unsigned getLValNumDims() const override { return numDims; } ///< 获取作为左值的维度数量 + std::vector getLValDims() const override { + std::vector dims; + for (const auto &dim : getOperands()) { + dims.emplace_back(dim->getValue()); + } + + return dims; + } ///< 获取作为左值的维度列表 + + unsigned getNumDims() const { return numDims; } ///< 获取维度数量 + Value* getDim(unsigned index) const { return getOperand(index); } ///< 获取位置为index的维度 + auto getDims() const { return getOperands(); } ///< 获取维度列表 + Value* getByIndex(unsigned index) const { + return initValues.getValue(index); + } ///< 通过一维偏移量index获取初始值 + Value* getByIndices(const std::vector &indices) const { + int index = 0; + for (size_t i = 0; i < indices.size(); i++) { + index = dynamic_cast(getDim(i))->getInt() * index + + dynamic_cast(indices[i])->getInt(); + } + return getByIndex(index); + } ///< 通过多维索引indices获取初始值 + const ValueCounter& getInitValues() const { return initValues; } +}; // class GlobalValue + + +class ConstantVariable : public User, public LVal { + friend class Module; + + protected: + Module *parent; ///< 父模块 + unsigned numDims; ///< 维度数量 + ValueCounter initValues; ///< 值 + + protected: + ConstantVariable(Module *parent, Type *type, const std::string &name, const ValueCounter &init, + const std::vector &dims = {}) + : User(type, name), parent(parent) { + assert(type->isPointer()); + numDims = dims.size(); + initValues = init; + addOperands(dims); + } + + public: + unsigned getLValNumDims() const override { return numDims; } ///< 获取作为左值的维度数量 + std::vector getLValDims() const override { + std::vector dims; + for (const auto &dim : getOperands()) { + dims.emplace_back(dim->getValue()); + } + + return dims; + } ///< 获取作为左值的维度列表 + Value* getByIndex(unsigned index) const { return initValues.getValue(index); } ///< 通过一维位置index获取值 + Value* getByIndices(const std::vector &indices) const { + int index = 0; + // 计算偏移量 + for (size_t i = 0; i < indices.size(); i++) { + index = dynamic_cast(getDim(i))->getInt() * index + + dynamic_cast(indices[i])->getInt(); + } + + return getByIndex(index); + } ///< 通过多维索引indices获取初始值 + unsigned getNumDims() const { return numDims; } ///< 获取维度数量 + Value* getDim(unsigned index) const { return getOperand(index); } ///< 获取位置为index的维度 + auto getDims() const { return getOperands(); } ///< 获取维度列表 + const ValueCounter& getInitValues() const { return initValues; } ///< 获取初始值 +}; + +using SymbolTableNode = struct SymbolTableNode { + SymbolTableNode *pNode; ///< 父节点 + std::vector children; ///< 子节点列表 + std::map varList; ///< 变量列表 +}; + + +class SymbolTable { + private: + SymbolTableNode *curNode{}; ///< 当前所在的作用域(符号表节点) + std::map variableIndex; ///< 变量命名索引表 + std::vector> globals; ///< 全局变量列表 + std::vector> consts; ///< 常量列表 + std::vector> nodeList; ///< 符号表节点列表 + + public: + SymbolTable() = default; + + User* getVariable(const std::string &name) const; ///< 根据名字name以及当前作用域获取变量 + User* addVariable(const std::string &name, User *variable); ///< 添加变量 + std::vector>& getGlobals(); ///< 获取全局变量列表 + const std::vector>& getConsts() const; ///< 获取常量列表 + void enterNewScope(); ///< 进入新的作用域 + void leaveScope(); ///< 离开作用域 + bool isInGlobalScope() const; ///< 是否位于全局作用域 + void enterGlobalScope(); ///< 进入全局作用域 + bool isCurNodeNull() { return curNode == nullptr; } +}; + +//! IR unit for representing a SysY compile unit +class Module { + protected: + std::map> externalFunctions; ///< 外部函数表 + std::map> functions; ///< 函数表 + SymbolTable variableTable; ///< 符号表 + + public: + Module() = default; + + public: + Function* createFunction(const std::string &name, Type *type) { + auto result = functions.try_emplace(name, new Function(this, type, name)); + if (!result.second) { + return nullptr; + } + return result.first->second.get(); + } ///< 创建函数 + Function* createExternalFunction(const std::string &name, Type *type) { + auto result = externalFunctions.try_emplace(name, new Function(this, type, name)); + if (!result.second) { + return nullptr; + } + return result.first->second.get(); + } ///< 创建外部函数 + ///< 变量创建伴随着符号表的更新 + GlobalValue* createGlobalValue(const std::string &name, Type *type, const std::vector &dims = {}, + const ValueCounter &init = {}) { + bool isFinished = variableTable.isCurNodeNull(); + if (isFinished) { + variableTable.enterGlobalScope(); + } + auto result = variableTable.addVariable(name, new GlobalValue(this, type, name, dims, init)); + if (isFinished) { + variableTable.leaveScope(); + } + if (result == nullptr) { + return nullptr; + } + return dynamic_cast(result); + } ///< 创建全局变量 + ConstantVariable* createConstVar(const std::string &name, Type *type, const ValueCounter &init, + const std::vector &dims = {}) { + auto result = variableTable.addVariable(name, new ConstantVariable(this, type, name, init, dims)); + if (result == nullptr) { + return nullptr; + } + return dynamic_cast(result); + } ///< 创建常量 + void addVariable(const std::string &name, AllocaInst *variable) { + variableTable.addVariable(name, variable); + } ///< 添加变量 + User* getVariable(const std::string &name) { + return variableTable.getVariable(name); + } ///< 根据名字name和当前作用域获取变量 + Function* getFunction(const std::string &name) const { + auto result = functions.find(name); + if (result == functions.end()) { + return nullptr; + } + return result->second.get(); + } ///< 获取函数 + Function* getExternalFunction(const std::string &name) const { + auto result = externalFunctions.find(name); + if (result == functions.end()) { + return nullptr; + } + return result->second.get(); + } ///< 获取外部函数 + std::map>& getFunctions() { return functions; } ///< 获取函数列表 + const std::map>& getExternalFunctions() const { + return externalFunctions; + } ///< 获取外部函数列表 + std::vector>& getGlobals() { + return variableTable.getGlobals(); + } ///< 获取全局变量列表 + const std::vector>& getConsts() const { + return variableTable.getConsts(); + } ///< 获取常量列表 + void enterNewScope() { variableTable.enterNewScope(); } ///< 进入新的作用域 + + void leaveScope() { variableTable.leaveScope(); } ///< 离开作用域 + + bool isInGlobalArea() const { return variableTable.isInGlobalScope(); } ///< 是否位于全局作用域 +}; + +/*! + * @} + */ + +} // namespace sysy diff --git a/src/include/IRBuilder.h b/src/include/IRBuilder.h new file mode 100644 index 0000000..9189cba --- /dev/null +++ b/src/include/IRBuilder.h @@ -0,0 +1,349 @@ +#pragma once + +#include +#include +#include +#include +#include +#include "IR.h" + +/** + * @file IRBuilder.h + * + * @brief 定义IR构建器的头文件 + */ +namespace sysy { + +/** + * @brief 中间IR的构建器 + * + */ +class IRBuilder { + private: + unsigned labelIndex; ///< 基本块标签编号 + unsigned tmpIndex; ///< 临时变量编号 + + BasicBlock *block; ///< 当前基本块 + BasicBlock::iterator position; ///< 当前基本块指令列表位置的迭代器 + + std::vector trueBlocks; ///< true分支基本块列表 + std::vector falseBlocks; ///< false分支基本块列表 + + std::vector breakBlocks; ///< break目标块列表 + std::vector continueBlocks; ///< continue目标块列表 + + public: + IRBuilder() : labelIndex(0), tmpIndex(0), block(nullptr) {} + explicit IRBuilder(BasicBlock *block) : labelIndex(0), tmpIndex(0), block(block), position(block->end()) {} + IRBuilder(BasicBlock *block, BasicBlock::iterator position) + : labelIndex(0), tmpIndex(0), block(block), position(position) {} + + public: + unsigned getLabelIndex() { + labelIndex += 1; + return labelIndex - 1; + } ///< 获取基本块标签编号 + unsigned getTmpIndex() { + tmpIndex += 1; + return tmpIndex - 1; + } ///< 获取临时变量编号 + BasicBlock * getBasicBlock() const { return block; } ///< 获取当前基本块 + BasicBlock * getBreakBlock() const { return breakBlocks.back(); } ///< 获取break目标块 + BasicBlock * popBreakBlock() { + auto result = breakBlocks.back(); + breakBlocks.pop_back(); + return result; + } ///< 弹出break目标块 + BasicBlock * getContinueBlock() const { return continueBlocks.back(); } ///< 获取continue目标块 + BasicBlock * popContinueBlock() { + auto result = continueBlocks.back(); + continueBlocks.pop_back(); + return result; + } ///< 弹出continue目标块 + + BasicBlock * getTrueBlock() const { return trueBlocks.back(); } ///< 获取true分支基本块 + BasicBlock * getFalseBlock() const { return falseBlocks.back(); } ///< 获取false分支基本块 + BasicBlock * popTrueBlock() { + auto result = trueBlocks.back(); + trueBlocks.pop_back(); + return result; + } ///< 弹出true分支基本块 + BasicBlock * popFalseBlock() { + auto result = falseBlocks.back(); + falseBlocks.pop_back(); + return result; + } ///< 弹出false分支基本块 + BasicBlock::iterator getPosition() const { return position; } ///< 获取当前基本块指令列表位置的迭代器 + void setPosition(BasicBlock *block, BasicBlock::iterator position) { + this->block = block; + this->position = position; + } ///< 设置基本块和基本块指令列表位置的迭代器 + void setPosition(BasicBlock::iterator position) { + this->position = position; + } ///< 设置当前基本块指令列表位置的迭代器 + void pushBreakBlock(BasicBlock *block) { breakBlocks.push_back(block); } ///< 压入break目标基本块 + void pushContinueBlock(BasicBlock *block) { continueBlocks.push_back(block); } ///< 压入continue目标基本块 + void pushTrueBlock(BasicBlock *block) { trueBlocks.push_back(block); } ///< 压入true分支基本块 + void pushFalseBlock(BasicBlock *block) { falseBlocks.push_back(block); } ///< 压入false分支基本块 + + public: + Instruction * insertInst(Instruction *inst) { + assert(inst); + block->getInstructions().emplace(position, inst); + return inst; + } ///< 插入指令 + UnaryInst * createUnaryInst(Instruction::Kind kind, Type *type, Value *operand, const std::string &name = "") { + std::string newName; + if (name.empty()) { + std::stringstream ss; + ss << "%" << tmpIndex; + newName = ss.str(); + tmpIndex++; + } else { + newName = name; + } + + auto inst = new UnaryInst(kind, type, operand, block, newName); + assert(inst); + block->getInstructions().emplace(position, inst); + return inst; + } ///< 创建一元指令 + UnaryInst * createNegInst(Value *operand, const std::string &name = "") { + return createUnaryInst(Instruction::kNeg, Type::getIntType(), operand, name); + } ///< 创建取反指令 + UnaryInst * createNotInst(Value *operand, const std::string &name = "") { + return createUnaryInst(Instruction::kNot, Type::getIntType(), operand, name); + } ///< 创建取非指令 + UnaryInst * createFtoIInst(Value *operand, const std::string &name = "") { + return createUnaryInst(Instruction::kFtoI, Type::getIntType(), operand, name); + } ///< 创建浮点转整型指令 + UnaryInst * createBitFtoIInst(Value *operand, const std::string &name = "") { + return createUnaryInst(Instruction::kBitFtoI, Type::getIntType(), operand, name); + } ///< 创建按位浮点转整型指令 + UnaryInst * createFNegInst(Value *operand, const std::string &name = "") { + return createUnaryInst(Instruction::kFNeg, Type::getFloatType(), operand, name); + } ///< 创建浮点取反指令 + UnaryInst * createFNotInst(Value *operand, const std::string &name = "") { + return createUnaryInst(Instruction::kFNot, Type::getIntType(), operand, name); + } ///< 创建浮点取非指令 + UnaryInst * createIToFInst(Value *operand, const std::string &name = "") { + return createUnaryInst(Instruction::kItoF, Type::getFloatType(), operand, name); + } ///< 创建整型转浮点指令 + UnaryInst * createBitItoFInst(Value *operand, const std::string &name = "") { + return createUnaryInst(Instruction::kBitItoF, Type::getFloatType(), operand, name); + } ///< 创建按位整型转浮点指令 + BinaryInst * createBinaryInst(Instruction::Kind kind, Type *type, Value *lhs, Value *rhs, const std::string &name = "") { + std::string newName; + if (name.empty()) { + std::stringstream ss; + ss << "%" << tmpIndex; + newName = ss.str(); + tmpIndex++; + } else { + newName = name; + } + + auto inst = new BinaryInst(kind, type, lhs, rhs, block, newName); + assert(inst); + block->getInstructions().emplace(position, inst); + return inst; + } ///< 创建二元指令 + BinaryInst * createAddInst(Value *lhs, Value *rhs, const std::string &name = "") { + return createBinaryInst(Instruction::kAdd, Type::getIntType(), lhs, rhs, name); + } ///< 创建加法指令 + BinaryInst * createSubInst(Value *lhs, Value *rhs, const std::string &name = "") { + return createBinaryInst(Instruction::kSub, Type::getIntType(), lhs, rhs, name); + } ///< 创建减法指令 + BinaryInst * createMulInst(Value *lhs, Value *rhs, const std::string &name = "") { + return createBinaryInst(Instruction::kMul, Type::getIntType(), lhs, rhs, name); + } ///< 创建乘法指令 + BinaryInst * createDivInst(Value *lhs, Value *rhs, const std::string &name = "") { + return createBinaryInst(Instruction::kDiv, Type::getIntType(), lhs, rhs, name); + } ///< 创建除法指令 + BinaryInst * createRemInst(Value *lhs, Value *rhs, const std::string &name = "") { + return createBinaryInst(Instruction::kRem, Type::getIntType(), lhs, rhs, name); + } ///< 创建取余指令 + BinaryInst * createICmpEQInst(Value *lhs, Value *rhs, const std::string &name = "") { + return createBinaryInst(Instruction::kICmpEQ, Type::getIntType(), lhs, rhs, name); + } ///< 创建相等设置指令 + BinaryInst * createICmpNEInst(Value *lhs, Value *rhs, const std::string &name = "") { + return createBinaryInst(Instruction::kICmpNE, Type::getIntType(), lhs, rhs, name); + } ///< 创建不相等设置指令 + BinaryInst * createICmpLTInst(Value *lhs, Value *rhs, const std::string &name = "") { + return createBinaryInst(Instruction::kICmpLT, Type::getIntType(), lhs, rhs, name); + } ///< 创建小于设置指令 + BinaryInst * createICmpLEInst(Value *lhs, Value *rhs, const std::string &name = "") { + return createBinaryInst(Instruction::kICmpLE, Type::getIntType(), lhs, rhs, name); + } ///< 创建小于等于设置指令 + BinaryInst * createICmpGTInst(Value *lhs, Value *rhs, const std::string &name = "") { + return createBinaryInst(Instruction::kICmpGT, Type::getIntType(), lhs, rhs, name); + } ///< 创建大于设置指令 + BinaryInst * createICmpGEInst(Value *lhs, Value *rhs, const std::string &name = "") { + return createBinaryInst(Instruction::kICmpGE, Type::getIntType(), lhs, rhs, name); + } ///< 创建大于等于设置指令 + BinaryInst * createFAddInst(Value *lhs, Value *rhs, const std::string &name = "") { + return createBinaryInst(Instruction::kFAdd, Type::getFloatType(), lhs, rhs, name); + } ///< 创建浮点加法指令 + BinaryInst * createFSubInst(Value *lhs, Value *rhs, const std::string &name = "") { + return createBinaryInst(Instruction::kFSub, Type::getFloatType(), lhs, rhs, name); + } ///< 创建浮点减法指令 + BinaryInst * createFMulInst(Value *lhs, Value *rhs, const std::string &name = "") { + return createBinaryInst(Instruction::kFMul, Type::getFloatType(), lhs, rhs, name); + } ///< 创建浮点乘法指令 + BinaryInst * createFDivInst(Value *lhs, Value *rhs, const std::string &name = "") { + return createBinaryInst(Instruction::kFDiv, Type::getFloatType(), lhs, rhs, name); + } ///< 创建浮点除法指令 + BinaryInst * createFCmpEQInst(Value *lhs, Value *rhs, const std::string &name = "") { + return createBinaryInst(Instruction::kFCmpEQ, Type::getIntType(), lhs, rhs, name); + } ///< 创建浮点相等设置指令 + BinaryInst * createFCmpNEInst(Value *lhs, Value *rhs, const std::string &name = "") { + return createBinaryInst(Instruction::kFCmpNE, Type::getIntType(), lhs, rhs, name); + } ///< 创建浮点不相等设置指令 + BinaryInst * createFCmpLTInst(Value *lhs, Value *rhs, const std::string &name = "") { + return createBinaryInst(Instruction::kFCmpLT, Type::getIntType(), lhs, rhs, name); + } ///< 创建浮点小于设置指令 + BinaryInst * createFCmpLEInst(Value *lhs, Value *rhs, const std::string &name = "") { + return createBinaryInst(Instruction::kFCmpLE, Type::getIntType(), lhs, rhs, name); + } ///< 创建浮点小于等于设置指令 + BinaryInst * createFCmpGTInst(Value *lhs, Value *rhs, const std::string &name = "") { + return createBinaryInst(Instruction::kFCmpGT, Type::getIntType(), lhs, rhs, name); + } ///< 创建浮点大于设置指令 + BinaryInst * createFCmpGEInst(Value *lhs, Value *rhs, const std::string &name = "") { + return createBinaryInst(Instruction::kFCmpGE, Type::getIntType(), lhs, rhs, name); + } ///< 创建浮点相大于等于设置指令 + BinaryInst * createAndInst(Value *lhs, Value *rhs, const std::string &name = "") { + return createBinaryInst(Instruction::kAnd, Type::getIntType(), lhs, rhs, name); + } ///< 创建按位且指令 + BinaryInst * createOrInst(Value *lhs, Value *rhs, const std::string &name = "") { + return createBinaryInst(Instruction::kOr, Type::getIntType(), lhs, rhs, name); + } ///< 创建按位或指令 + CallInst * createCallInst(Function *callee, const std::vector &args, const std::string &name = "") { + std::string newName; + if (name.empty() && callee->getReturnType() != Type::getVoidType()) { + std::stringstream ss; + ss << "%" << tmpIndex; + newName = ss.str(); + tmpIndex++; + } else { + newName = name; + } + + auto inst = new CallInst(callee, args, block, newName); + assert(inst); + block->getInstructions().emplace(position, inst); + return inst; + } ///< 创建Call指令 + ReturnInst * createReturnInst(Value *value = nullptr, const std::string &name = "") { + auto inst = new ReturnInst(value, block, name); + assert(inst); + block->getInstructions().emplace(position, inst); + return inst; + } ///< 创建return指令 + UncondBrInst * createUncondBrInst(BasicBlock *thenBlock, const std::vector &args) { + auto inst = new UncondBrInst(thenBlock, args, block); + assert(inst); + block->getInstructions().emplace(position, inst); + return inst; + } ///< 创建无条件指令 + CondBrInst * createCondBrInst(Value *condition, BasicBlock *thenBlock, BasicBlock *elseBlock, + const std::vector &thenArgs, const std::vector &elseArgs) { + auto inst = new CondBrInst(condition, thenBlock, elseBlock, thenArgs, elseArgs, block); + assert(inst); + block->getInstructions().emplace(position, inst); + return inst; + } ///< 创建条件跳转指令 + AllocaInst * createAllocaInst(Type *type, const std::vector &dims = {}, const std::string &name = "") { + auto inst = new AllocaInst(type, dims, block, name); + assert(inst); + block->getInstructions().emplace(position, inst); + return inst; + } ///< 创建分配指令 + AllocaInst * createAllocaInstWithoutInsert(Type *type, const std::vector &dims = {}, BasicBlock *parent = nullptr, + const std::string &name = "") { + auto inst = new AllocaInst(type, dims, parent, name); + assert(inst); + return inst; + } ///< 创建不插入指令列表的分配指令 + LoadInst * createLoadInst(Value *pointer, const std::vector &indices = {}, const std::string &name = "") { + std::string newName; + if (name.empty()) { + std::stringstream ss; + ss << "%" << tmpIndex; + newName = ss.str(); + tmpIndex++; + } else { + newName = name; + } + + auto inst = new LoadInst(pointer, indices, block, newName); + assert(inst); + block->getInstructions().emplace(position, inst); + return inst; + } ///< 创建load指令 + LaInst * createLaInst(Value *pointer, const std::vector &indices = {}, const std::string &name = "") { + std::string newName; + if (name.empty()) { + std::stringstream ss; + ss << "%" << tmpIndex; + newName = ss.str(); + tmpIndex++; + } else { + newName = name; + } + + auto inst = new LaInst(pointer, indices, block, newName); + assert(inst); + block->getInstructions().emplace(position, inst); + return inst; + } ///< 创建la指令 + GetSubArrayInst * createGetSubArray(LVal *fatherArray, const std::vector &indices, const std::string &name = "") { + assert(fatherArray->getLValNumDims() > indices.size()); + std::vector subDims; + auto dims = fatherArray->getLValDims(); + auto iter = std::next(dims.begin(), indices.size()); + while (iter != dims.end()) { + subDims.emplace_back(*iter); + iter++; + } + + std::string childArrayName; + std::stringstream ss; + ss << "A" + << "%" << tmpIndex; + childArrayName = ss.str(); + tmpIndex++; + + auto fatherArrayValue = dynamic_cast(fatherArray); + auto childArray = new AllocaInst(fatherArrayValue->getType(), subDims, block, childArrayName); + auto inst = new GetSubArrayInst(fatherArray, childArray, indices, block, name); + assert(inst); + block->getInstructions().emplace(position, inst); + return inst; + } ///< 创建获取部分数组指令 + MemsetInst * createMemsetInst(Value *pointer, Value *begin, Value *size, Value *value, const std::string &name = "") { + auto inst = new MemsetInst(pointer, begin, size, value, block, name); + assert(inst); + block->getInstructions().emplace(position, inst); + return inst; + } ///< 创建memset指令 + StoreInst * createStoreInst(Value *value, Value *pointer, const std::vector &indices = {}, + const std::string &name = "") { + auto inst = new StoreInst(value, pointer, indices, block, name); + assert(inst); + block->getInstructions().emplace(position, inst); + return inst; + } ///< 创建store指令 + PhiInst * createPhiInst(Type *type, Value *lhs, BasicBlock *parent, const std::string &name = "") { + auto predNum = parent->getNumPredecessors(); + std::vector rhs; + for (size_t i = 0; i < predNum; i++) { + rhs.push_back(lhs); + } + auto inst = new PhiInst(type, lhs, rhs, lhs, parent, name); + assert(inst); + parent->getInstructions().emplace(parent->begin(), inst); + return inst; + } ///< 创建Phi指令 +}; + +} // namespace sysy diff --git a/src/include/LLVMIRGenerator.h b/src/include/LLVMIRGenerator.h new file mode 100644 index 0000000..e330a4f --- /dev/null +++ b/src/include/LLVMIRGenerator.h @@ -0,0 +1,78 @@ +#pragma once +#include "SysYBaseVisitor.h" +#include "SysYParser.h" +#include "IR.h" +#include "IRBuilder.h" +#include +#include +#include +#include + +class LLVMIRGenerator : public SysYBaseVisitor { +public: + sysy::Module* getIRModule() const { return irModule.get(); } + + std::string generateIR(SysYParser::CompUnitContext* unit); + std::string getIR() const { return irStream.str(); } + +private: + std::unique_ptr irModule; // IR数据结构 + std::stringstream irStream; // 文本输出流 + sysy::IRBuilder irBuilder; // IR构建器 + int tempCounter = 0; + std::string currentVarType; + // std::map symbolTable; + std::map> symbolTable; + std::map tmpTable; + std::vector globalVars; + std::string currentFunction; + std::string currentReturnType; + std::vector breakStack; + std::vector continueStack; + bool hasReturn = false; + + struct LoopLabels { + std::string breakLabel; // break跳转的目标标签 + std::string continueLabel; // continue跳转的目标标签 + }; + std::stack loopStack; // 用于管理循环的break和continue标签 + std::string getNextTemp(); + std::string getLLVMType(const std::string&); + sysy::Type* getSysYType(const std::string&); + + bool inFunction = false; // 标识当前是否处于函数内部 + + // 访问方法 + std::any visitCompUnit(SysYParser::CompUnitContext* ctx); + std::any visitConstDecl(SysYParser::ConstDeclContext* ctx); + std::any visitVarDecl(SysYParser::VarDeclContext* ctx); + std::any visitVarDef(SysYParser::VarDefContext* ctx); + std::any visitFuncDef(SysYParser::FuncDefContext* ctx); + std::any visitBlockStmt(SysYParser::BlockStmtContext* ctx); + // std::any visitStmt(SysYParser::StmtContext* ctx); + std::any visitLValue(SysYParser::LValueContext* ctx); + std::any visitPrimaryExp(SysYParser::PrimaryExpContext* ctx); + std::any visitPrimExp(SysYParser::PrimExpContext* ctx); + std::any visitParenExp(SysYParser::ParenExpContext* ctx); + std::any visitNumber(SysYParser::NumberContext* ctx); + std::any visitString(SysYParser::StringContext* ctx); + std::any visitCall(SysYParser::CallContext *ctx); + std::any visitUnExp(SysYParser::UnExpContext* ctx); + std::any visitMulExp(SysYParser::MulExpContext* ctx); + std::any visitAddExp(SysYParser::AddExpContext* ctx); + std::any visitRelExp(SysYParser::RelExpContext* ctx); + std::any visitEqExp(SysYParser::EqExpContext* ctx); + std::any visitLAndExp(SysYParser::LAndExpContext* ctx); + std::any visitLOrExp(SysYParser::LOrExpContext* ctx); + std::any visitAssignStmt(SysYParser::AssignStmtContext *ctx) override; + std::any visitIfStmt(SysYParser::IfStmtContext *ctx) override; + std::any visitWhileStmt(SysYParser::WhileStmtContext *ctx) override; + std::any visitBreakStmt(SysYParser::BreakStmtContext *ctx) override; + std::any visitContinueStmt(SysYParser::ContinueStmtContext *ctx) override; + std::any visitReturnStmt(SysYParser::ReturnStmtContext *ctx) override; + + // 统一创建二元操作(同时生成数据结构和文本) + sysy::Value* createBinaryOp(SysYParser::ExpContext* lhs, + SysYParser::ExpContext* rhs, + sysy::Instruction::Kind opKind); +}; \ No newline at end of file diff --git a/src/SysYFormatter.h b/src/include/SysYFormatter.h similarity index 100% rename from src/SysYFormatter.h rename to src/include/SysYFormatter.h diff --git a/src/include/SysYIRAnalyser.h b/src/include/SysYIRAnalyser.h new file mode 100644 index 0000000..e69de29 diff --git a/src/include/SysYIRGenerator.h b/src/include/SysYIRGenerator.h new file mode 100644 index 0000000..445a856 --- /dev/null +++ b/src/include/SysYIRGenerator.h @@ -0,0 +1,138 @@ +#pragma once +#include "IR.h" +#include "IRBuilder.h" +#include "SysYBaseVisitor.h" +#include "SysYParser.h" +#include +#include +#include + +namespace sysy { + + +// @brief 用于存储数组值的树结构 +// 多位数组本质上是一维数组的嵌套可以用树来表示。 +class ArrayValueTree { +private: + Value *value = nullptr; /// 该节点存储的value + std::vector> children; /// 子节点列表 + +public: + ArrayValueTree() = default; + +public: + auto getValue() const -> Value * { return value; } + auto getChildren() const + -> const std::vector> & { + return children; + } + + void setValue(Value *newValue) { value = newValue; } + void addChild(ArrayValueTree *newChild) { children.emplace_back(newChild); } + void addChildren(const std::vector &newChildren) { + for (const auto &child : newChildren) { + children.emplace_back(child); + } + } +}; + + +class Utils { +public: + // transform a tree of ArrayValueTree to a ValueCounter + static void tree2Array(Type *type, ArrayValueTree *root, + const std::vector &dims, unsigned numDims, + ValueCounter &result, IRBuilder *builder); + static void + createExternalFunction(const std::vector ¶mTypes, + const std::vector ¶mNames, + const std::vector> ¶mDims, + Type *returnType, const std::string &funcName, + Module *pModule, IRBuilder *pBuilder); + + static void initExternalFunction(Module *pModule, IRBuilder *pBuilder); +}; + +class SysYIRGenerator : public SysYBaseVisitor { + +private: + std::unique_ptr module; + IRBuilder builder; + +public: + SysYIRGenerator() = default; + +public: + Module *get() const { return module.get(); } + IRBuilder *getBuilder(){ return &builder; } +public: + std::any visitCompUnit(SysYParser::CompUnitContext *ctx) override; + + std::any visitGlobalConstDecl(SysYParser::GlobalConstDeclContext *ctx) override; + std::any visitGlobalVarDecl(SysYParser::GlobalVarDeclContext *ctx) override; + + // std::any visitDecl(SysYParser::DeclContext *ctx) override; + std::any visitConstDecl(SysYParser::ConstDeclContext *ctx) override; + std::any visitVarDecl(SysYParser::VarDeclContext *ctx) override; + + std::any visitBType(SysYParser::BTypeContext *ctx) override; + + // std::any visitConstDef(SysYParser::ConstDefContext *ctx) override; + // std::any visitVarDef(SysYParser::VarDefContext *ctx) override; + + std::any visitScalarInitValue(SysYParser::ScalarInitValueContext *ctx) override; + std::any visitArrayInitValue(SysYParser::ArrayInitValueContext *ctx) override; + + std::any visitConstScalarInitValue(SysYParser::ConstScalarInitValueContext *ctx) override; + std::any visitConstArrayInitValue(SysYParser::ConstArrayInitValueContext *ctx) override; + + // std::any visitConstInitVal(SysYParser::ConstInitValContext *ctx) override; + std::any visitFuncType(SysYParser::FuncTypeContext* ctx) override; + std::any visitFuncDef(SysYParser::FuncDefContext* ctx) override; + // std::any visitInitVal(SysYParser::InitValContext *ctx) override; + // std::any visitFuncFParam(SysYParser::FuncFParamContext *ctx) override; + // std::any visitFuncFParams(SysYParser::FuncFParamsContext *ctx) override; + + std::any visitBlockStmt(SysYParser::BlockStmtContext* ctx) override; + // std::any visitStmt(SysYParser::StmtContext *ctx) override; + std::any visitAssignStmt(SysYParser::AssignStmtContext *ctx) override; + // std::any visitExpStmt(SysYParser::ExpStmtContext *ctx) override; + // std::any visitBlkStmt(SysYParser::BlkStmtContext *ctx) override; + std::any visitIfStmt(SysYParser::IfStmtContext *ctx) override; + std::any visitWhileStmt(SysYParser::WhileStmtContext *ctx) override; + std::any visitBreakStmt(SysYParser::BreakStmtContext *ctx) override; + std::any visitContinueStmt(SysYParser::ContinueStmtContext *ctx) override; + std::any visitReturnStmt(SysYParser::ReturnStmtContext *ctx) override; + + // std::any visitExp(SysYParser::ExpContext *ctx) override; + // std::any visitCond(SysYParser::CondContext *ctx) override; + + std::any visitLValue(SysYParser::LValueContext *ctx) override; + + std::any visitPrimaryExp(SysYParser::PrimaryExpContext *ctx) override; + + // std::any visitParenExp(SysYParser::ParenExpContext *ctx) override; + std::any visitNumber(SysYParser::NumberContext *ctx) override; + // std::any visitString(SysYParser::StringContext *ctx) override; + + std::any visitCall(SysYParser::CallContext *ctx) override; + + std::any visitUnaryExp(SysYParser::UnaryExpContext *ctx) override; + // std::any visitUnaryOp(SysYParser::UnaryOpContext *ctx) override; + + // std::any visitUnExp(SysYParser::UnExpContext *ctx) override; + + std::any visitFuncRParams(SysYParser::FuncRParamsContext *ctx) override; + std::any visitMulExp(SysYParser::MulExpContext *ctx) override; + std::any visitAddExp(SysYParser::AddExpContext *ctx) override; + std::any visitRelExp(SysYParser::RelExpContext *ctx) override; + std::any visitEqExp(SysYParser::EqExpContext *ctx) override; + std::any visitLAndExp(SysYParser::LAndExpContext *ctx) override; + std::any visitLOrExp(SysYParser::LOrExpContext *ctx) override; + + // std::any visitConstExp(SysYParser::ConstExpContext *ctx) override; + + +}; // class SysYIRGenerator + +} // namespace sysy \ No newline at end of file diff --git a/src/range.h b/src/include/range.h similarity index 100% rename from src/range.h rename to src/include/range.h diff --git a/src/sysyc.cpp b/src/sysyc.cpp index 6222499..085960e 100644 --- a/src/sysyc.cpp +++ b/src/sysyc.cpp @@ -6,8 +6,7 @@ using namespace std; #include "SysYLexer.h" #include "SysYParser.h" using namespace antlr4; -#include "ASTPrinter.h" -#include "Backend.h" +// #include "Backend.h" #include "SysYIRGenerator.h" #include "RISCv32Backend.h" using namespace sysy; @@ -70,12 +69,6 @@ int main(int argc, char **argv) { return EXIT_SUCCESS; } - // pretty format the input file - if (argFormat) { - ASTPrinter printer; - printer.visitCompUnit(moduleAST); - return EXIT_SUCCESS; - } // visit AST to generate IR SysYIRGenerator generator;