From 1aa785efc3210f4cc51c43fab8e5f4ed0fc3d7e9 Mon Sep 17 00:00:00 2001 From: rain2133 <1370973498@qq.com> Date: Mon, 16 Jun 2025 20:56:32 +0800 Subject: [PATCH 01/13] add arraytype def --- src/IR.cpp | 2 ++ src/IR.h | 49 +++++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 51 insertions(+) diff --git a/src/IR.cpp b/src/IR.cpp index 318b8c8..529e903 100644 --- a/src/IR.cpp +++ b/src/IR.cpp @@ -91,6 +91,8 @@ int Type::getSize() const { return 8; case kVoid: return 0; + case kArray: + return asArrayType()->getArraySize(); } return 0; } diff --git a/src/IR.h b/src/IR.h index 184c7f4..1483395 100644 --- a/src/IR.h +++ b/src/IR.h @@ -42,6 +42,7 @@ public: kLabel, kPointer, kFunction, + kArray, }; Kind kind; @@ -57,6 +58,9 @@ public: static Type *getPointerType(Type *baseType); static Type *getFunctionType(Type *returnType, const std::vector ¶mTypes = {}); + static Type *getArrayType(Type *elementType, const std::vector &dims = {}) { + return ArrayType::get(elementType, dims); + } public: Kind getKind() const { return kind; } @@ -66,8 +70,12 @@ public: bool isLabel() const { return kind == kLabel; } bool isPointer() const { return kind == kPointer; } bool isFunction() const { return kind == kFunction; } + bool isArray() const { return kind == kArray; } bool isIntOrFloat() const { return kind == kInt or kind == kFloat; } int getSize() const; + ArrayType *asArrayType() const { + return isArray() ? static_cast(const_cast(this)) : nullptr; + } template std::enable_if_t, T *> as() const { return dynamic_cast(const_cast(this)); @@ -110,6 +118,47 @@ public: int getNumParams() const { return paramTypes.size(); } }; // class FunctionType +class ArrayType : public Type { +private: + Type *elementType; // 数组元素类型 + std::vector dimensions; // 维度信息(空向量表示未知大小) + +protected: + ArrayType(Type *elemType, const std::vector &dims = {}) + : Type(kArray), elementType(elemType), dimensions(dims) { + // 确保元素类型有效 + assert(elemType && "Array element type cannot be null"); + assert(!elemType->isVoid() && "Cannot have array of void"); + assert(!elemType->isLabel() && "Cannot have array of labels"); + } + +public: + // 获取数组类型(带缓存机制) + static ArrayType *get(Type *elemType, const std::vector &dims = {}) { + // 实现类型缓存池(避免重复创建) + static std::map>, ArrayType*> cache; + + auto key = std::make_pair(elemType, dims); + if (cache.find(key) == cache.end()) { + cache[key] = new ArrayType(elemType, dims); + } + return cache[key]; + } + + Type *getElementType() const { return elementType; } + const std::vector& getDimensions() const { return dimensions; } + size_t getNumDimensions() const { return dimensions.size(); } + + int getArraySize() const { + int size = elementType->getSize(); + for (int dim : dimensions) { + size *= dim; + } + return size; + } + +};//class ArrayType + /*! * @} */ From 1de8c0e7d79710acf3ee9ac6952e403e2a4db108 Mon Sep 17 00:00:00 2001 From: rain2133 <1370973498@qq.com> Date: Thu, 19 Jun 2025 00:18:58 +0800 Subject: [PATCH 02/13] =?UTF-8?q?=E5=BC=95=E5=85=A5=E4=BA=86=E5=B8=B8?= =?UTF-8?q?=E9=87=8F=E6=B1=A0=E4=BC=98=E5=8C=96=EF=BC=8C=E4=BF=AE=E6=94=B9?= =?UTF-8?q?constvalue=E7=B1=BB=E5=B9=B6=E5=AF=B9IR=E7=94=9F=E6=88=90?= =?UTF-8?q?=E4=BF=AE=E5=A4=8D=EF=BC=8C=E8=83=BD=E5=A4=9F=E7=BC=96=E8=AF=91?= =?UTF-8?q?=E9=80=9A=E8=BF=87?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/IR.cpp | 99 ++++++++++++++++++++------ src/IR.h | 145 +++++++++++++++++++++++++++++--------- src/LLVMIRGenerator_1.cpp | 10 +-- src/SysYIRGenerator.cpp | 4 +- 4 files changed, 194 insertions(+), 64 deletions(-) diff --git a/src/IR.cpp b/src/IR.cpp index 529e903..f961d53 100644 --- a/src/IR.cpp +++ b/src/IR.cpp @@ -13,6 +13,8 @@ #include #include #include +#include +#include using namespace std; namespace sysy { @@ -80,6 +82,15 @@ Type *Type::getFunctionType(Type *returnType, return FunctionType::get(returnType, paramTypes); } +Type *Type::getArrayType(Type *elementType, const vector &dims) { + // forward to ArrayType + return ArrayType::get(elementType, dims); +} + +ArrayType* Type::asArrayType() const { + return isArray() ? dynamic_cast(const_cast(this)) : nullptr; +} + int Type::getSize() const { switch (kind) { case kInt: @@ -177,33 +188,75 @@ bool Value::isConstant() const { return false; } -ConstantValue *ConstantValue::get(int value) { - static std::map> intConstants; - auto iter = intConstants.find(value); - if (iter != intConstants.end()) - return iter->second.get(); - auto constant = new ConstantValue(value); - assert(constant); - auto result = intConstants.emplace(value, constant); - return result.first->second.get(); + +// 定义静态常量池 +std::unordered_map ConstantValue::constantPool; + +// 常量池实现 +ConstantValue* ConstantValue::get(Type* type, int32_t value) { + ConstantValueKey key = {type, ConstantValVariant(value)}; + + if (auto it = constantPool.find(key); it != constantPool.end()) { + return it->second; + } + + ConstantValue* constant = new ConstantInt(type, value); + constantPool[key] = constant; + return constant; } -ConstantValue *ConstantValue::get(float value) { - static std::map> floatConstants; - auto iter = floatConstants.find(value); - if (iter != floatConstants.end()) - return iter->second.get(); - auto constant = new ConstantValue(value); - assert(constant); - auto result = floatConstants.emplace(value, constant); - return result.first->second.get(); +ConstantValue* ConstantValue::get(Type* type, float value) { + ConstantValueKey key = {type, ConstantValVariant(value)}; + + if (auto it = constantPool.find(key); it != constantPool.end()) { + return it->second; + } + + ConstantValue* constant = new ConstantFloat(type, value); + constantPool[key] = constant; + return constant; } -void ConstantValue::print(ostream &os) const { - if (isInt()) - os << getInt(); - else - os << getFloat(); +ConstantValue* ConstantValue::getInt32(int32_t value) { + return get(Type::getIntType(), value); +} + +ConstantValue* ConstantValue::getFloat32(float value) { + return get(Type::getFloatType(), value); +} + +ConstantValue* ConstantValue::getTrue() { + return get(Type::getIntType(), 1); +} + +ConstantValue* ConstantValue::getFalse() { + return get(Type::getIntType(), 0); +} + + + +void ConstantValue::print(std::ostream &os) const { + // 根据类型调用相应的打印实现 + if (auto intConst = dynamic_cast(this)) { + intConst->print(os); + } + else if (auto floatConst = dynamic_cast(this)) { + floatConst->print(os); + } + else { + os << "???"; // 未知常量类型 + } +} + +void ConstantInt::print(std::ostream &os) const { + os << value; +} +void ConstantFloat::print(std::ostream &os) const { + if (value == static_cast(value)) { + os << value << ".0"; // 确保输出带小数点 + } else { + os << std::fixed << std::setprecision(6) << value; + } } Argument::Argument(Type *type, BasicBlock *block, int index, diff --git a/src/IR.h b/src/IR.h index 1483395..7d228ad 100644 --- a/src/IR.h +++ b/src/IR.h @@ -11,6 +11,9 @@ #include #include #include +#include +#include +#include namespace sysy { @@ -33,6 +36,9 @@ namespace sysy { * include `int`, `float`, `void`, and the label type representing branch * targets */ + +class ArrayType; + class Type { public: enum Kind { @@ -58,9 +64,7 @@ public: static Type *getPointerType(Type *baseType); static Type *getFunctionType(Type *returnType, const std::vector ¶mTypes = {}); - static Type *getArrayType(Type *elementType, const std::vector &dims = {}) { - return ArrayType::get(elementType, dims); - } + static Type *getArrayType(Type *elementType, const std::vector &dims = {}); public: Kind getKind() const { return kind; } @@ -73,9 +77,9 @@ public: bool isArray() const { return kind == kArray; } bool isIntOrFloat() const { return kind == kInt or kind == kFloat; } int getSize() const; - ArrayType *asArrayType() const { - return isArray() ? static_cast(const_cast(this)) : nullptr; - } + + ArrayType* asArrayType() const; + template std::enable_if_t, T *> as() const { return dynamic_cast(const_cast(this)); @@ -335,41 +339,114 @@ public: * `ConstantValue`s are not defined by instructions, and do not use any other * `Value`s. It's type is either `int` or `float`. */ + +class ConstantInt; +class ConstantFloat; +//常量池优化 + +using ConstantValVariant = std::variant; +using ConstantValueKey = std::pair; + class ConstantValue : public Value { protected: - union { - int iScalar; - float fScalar; - }; - -protected: - ConstantValue(int value) - : Value(kConstant, Type::getIntType(), ""), iScalar(value) {} - ConstantValue(float value) - : Value(kConstant, Type::getFloatType(), ""), fScalar(value) {} - + ConstantValue(Type* type) + : Value(kConstant, type, "") {} public: - static ConstantValue *get(int value); - static ConstantValue *get(float value); - -public: - static bool classof(const Value *value) { + struct ConstantValueHash; + struct ConstantValueEqual; + + static std::unordered_map constantPool; + + virtual ~ConstantValue() = default; + + static ConstantValue* get(Type* type, int32_t value); + static ConstantValue* get(Type* type, float value); + + static bool classof(const Value* value) { return value->getKind() == kConstant; } + + virtual int32_t getInt() const = 0; + virtual float getFloat() const = 0; + virtual bool isZero() const = 0; + virtual bool isOne() const = 0; + + + static ConstantValue* getInt32(int32_t value); + static ConstantValue* getFloat32(float value); + static ConstantValue* getTrue() ; + static ConstantValue* getFalse(); -public: - int getInt() const { - assert(isInt()); - return iScalar; - } - float getFloat() const { - assert(isFloat()); - return fScalar; - } - -public: void print(std::ostream &os) const override; -}; // class ConstantValue +}; + +struct ConstantValue::ConstantValueHash { + std::size_t operator()(const ConstantValueKey& key) const { + std::size_t typeHash = std::hash{}(key.first); + std::size_t valHash = 0; + if (key.first->isInt()) { + valHash = std::hash{}(std::get(key.second)); + } else if (key.first->isFloat()) { + // 修复5: 确保float哈希正确 + valHash = std::hash{}(std::get(key.second)); + } + return typeHash ^ (valHash << 1); + } +}; + +struct ConstantValue::ConstantValueEqual { + bool operator()(const ConstantValueKey& lhs, const ConstantValueKey& rhs) const { + if (lhs.first != rhs.first) return false; + if (lhs.first->isInt()) { + return std::get(lhs.second) == std::get(rhs.second); + } else if (lhs.first->isFloat()) { + // 修复6: 使用浮点比较容差 + const float eps = 1e-6; + return fabs(std::get(lhs.second) - std::get(rhs.second)) < eps; + } + return false; + } +}; + +class ConstantInt : public ConstantValue { + int32_t value; + friend class ConstantValue; + +protected: + ConstantInt(Type* type, int32_t value) + : ConstantValue(type), value(value) { + assert(type->isInt() && "Invalid type for ConstantInt"); + } +public: + static ConstantInt* get(Type* type, int32_t value); + + int32_t getInt() const override { return value; } + float getFloat() const override { return static_cast(value); } + bool isZero() const override { return value == 0; } + bool isOne() const override { return value == 1; } + + void print(std::ostream& os) const override ; +}; + +class ConstantFloat : public ConstantValue { + float value; + friend class ConstantValue; + +protected: + ConstantFloat(Type* type, float value) + : ConstantValue(type), value(value) { + assert(type->isFloat() && "Invalid type for ConstantFloat"); + } +public: + static ConstantFloat* get(Type* type, float value); + + int32_t getInt() const override { return static_cast(value); } + float getFloat() const override { return value; } + bool isZero() const override { return value == 0.0f; } + bool isOne() const override { return value == 1.0f; } + + void print(std::ostream& os) const override; +}; class BasicBlock; /*! diff --git a/src/LLVMIRGenerator_1.cpp b/src/LLVMIRGenerator_1.cpp index 515b5a2..965eb82 100644 --- a/src/LLVMIRGenerator_1.cpp +++ b/src/LLVMIRGenerator_1.cpp @@ -91,7 +91,7 @@ std::any LLVMIRGenerator::visitVarDecl(SysYParser::VarDeclContext* ctx) { if (varDef->ASSIGN()) { value = std::any_cast(varDef->initVal()->accept(this)); - if (irTmpTable.find(value) != irTmpTable.end() && isa(irTmpTable[value])) { + if (irTmpTable.find(value) != irTmpTable.end() && sysy::isa(irTmpTable[value])) { initValue = irTmpTable[value]; } } @@ -134,7 +134,7 @@ std::any LLVMIRGenerator::visitConstDecl(SysYParser::ConstDeclContext* ctx) { try { value = std::any_cast(constDef->constInitVal()->accept(this)); - if (isa(irTmpTable[value])) { + if (sysy::isa(irTmpTable[value])) { initValue = irTmpTable[value]; } } catch (...) { @@ -310,7 +310,7 @@ std::any LLVMIRGenerator::visitFuncDef(SysYParser::FuncDefContext* ctx) { } else { irStream << " ret " << currentReturnType << " 0\n"; sysy::IRBuilder builder(currentIRBlock); - builder.createReturnInst(sysy::ConstantValue::get(0)); + builder.createReturnInst(sysy::ConstantValue::get(getIRType("int"),0)); } } irStream << "}\n"; @@ -524,10 +524,10 @@ std::any LLVMIRGenerator::visitNumber(SysYParser::NumberContext* ctx) { sysy::Value* irValue = nullptr; if (ctx->ILITERAL()) { value = ctx->ILITERAL()->getText(); - irValue = sysy::ConstantValue::get(std::stoi(value)); + irValue = sysy::ConstantValue::get(getIRType("int"), std::stoi(value)); } else if (ctx->FLITERAL()) { value = ctx->FLITERAL()->getText(); - irValue = sysy::ConstantValue::get(std::stof(value)); + irValue = sysy::ConstantValue::get(getIRType("float"), std::stof(value)); } else { value = ""; } diff --git a/src/SysYIRGenerator.cpp b/src/SysYIRGenerator.cpp index a2962c7..1844a53 100644 --- a/src/SysYIRGenerator.cpp +++ b/src/SysYIRGenerator.cpp @@ -552,10 +552,10 @@ std::any SysYIRGenerator::visitNumber(SysYParser::NumberContext *ctx) { } else if (text.find("0") == 0) { base = 8; } - res = ConstantValue::get((int)std::stol(text, 0, base)); + res = ConstantValue::get(Type::getIntType() ,(int)std::stol(text, 0, base)); } else if (auto fLiteral = ctx->FLITERAL()) { const auto text = fLiteral->getText(); - res = ConstantValue::get((float)std::stof(text)); + res = ConstantValue::get(Type::getFloatType(), (float)std::stof(text)); } cout << "number: "; res->print(cout); From c54543bff3b33b16700c8a20d2045ae85469c31c Mon Sep 17 00:00:00 2001 From: rain2133 <1370973498@qq.com> Date: Fri, 20 Jun 2025 22:46:04 +0800 Subject: [PATCH 03/13] =?UTF-8?q?=E6=9B=B4=E6=96=B0=E7=9B=AE=E5=BD=95?= =?UTF-8?q?=E7=BB=93=E6=9E=84=EF=BC=8C=E4=BF=AE=E6=94=B9IR=E7=BB=93?= =?UTF-8?q?=E6=9E=84=EF=BC=8C=E9=83=A8=E5=88=86=E4=BF=AE=E5=A4=8DIR?= =?UTF-8?q?=E7=94=9F=E6=88=90?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- TODO.md | 86 ++ src/CMakeLists.txt | 2 +- src/IR.cpp | 832 ++++++---------- src/IR.h | 1120 --------------------- src/SysYIRAnalyser.cpp | 0 src/SysYIRGenerator.cpp | 200 ---- src/{ => include}/ASTPrinter.h | 0 src/{ => include}/Backend.h | 0 src/include/IR.h | 1319 +++++++++++++++++++++++++ src/{ => include}/IRBuilder.h | 21 +- src/{ => include}/LLVMIRGenerator.h | 0 src/{ => include}/LLVMIRGenerator_1.h | 0 src/{ => include}/SysYFormatter.h | 0 src/include/SysYIRAnalyser.h | 0 src/{ => include}/SysYIRGenerator.h | 9 - src/{ => include}/range.h | 0 16 files changed, 1699 insertions(+), 1890 deletions(-) create mode 100644 TODO.md delete mode 100644 src/IR.h create mode 100644 src/SysYIRAnalyser.cpp rename src/{ => include}/ASTPrinter.h (100%) rename src/{ => include}/Backend.h (100%) create mode 100644 src/include/IR.h rename src/{ => include}/IRBuilder.h (91%) rename src/{ => include}/LLVMIRGenerator.h (100%) rename src/{ => include}/LLVMIRGenerator_1.h (100%) rename src/{ => include}/SysYFormatter.h (100%) create mode 100644 src/include/SysYIRAnalyser.h rename src/{ => include}/SysYIRGenerator.h (90%) rename src/{ => include}/range.h (100%) diff --git a/TODO.md b/TODO.md new file mode 100644 index 0000000..c386b0d --- /dev/null +++ b/TODO.md @@ -0,0 +1,86 @@ +要打通从SysY到RISC-V的完整编译流程,以下是必须实现的核心模块和关键步骤(按编译流程顺序)。在你们当前IR生成阶段,可以优先实现这些基础模块来快速获得可工作的RISC-V汇编输出: + +### 1. **前端必须模块** +- **词法/语法分析**(已完成): + - `SysYLexer`/`SysYParser`:ANTLR生成的解析器 +- **IR生成核心**: + - `SysYIRGenerator`:将AST转换为中间表示(IR) + - `IRBuilder`:构建指令和基本块的工具类(你们正在实现的部分) + +### 2. **中端必要优化(最小集合)** +| 优化阶段 | 关键作用 | 是否必须 | +|-------------------|----------------------------------|----------| +| `Mem2Reg` | 消除冗余内存访问,转换为SSA形式 | ✅ 核心 | +| `DCE` (死代码消除) | 移除无用指令 | ✅ 必要 | +| `DFE` (死函数消除) | 移除未使用的函数 | ✅ 必要 | +| `FuncAnalysis` | 函数调用关系分析 | ✅ 基础 | +| `Global2Local` | 全局变量降级为局部变量 | ✅ 重要 | + +### 3. **后端核心流程(必须实现)** +```mermaid +graph LR + A[IR指令选择] --> B[寄存器分配] + B --> C[指令调度] + C --> D[汇编生成] +``` + +1. **指令选择**(关键步骤): + - `DAGBuilder`:将IR转换为有向无环图(DAG) + - `DAGCoverage`:DAG到目标指令的映射 + - `Mid2End`:IR到机器指令的转换接口 + +2. **寄存器分配**: + - `RegisterAlloc`:基础寄存器分配器(可先实现简单算法如线性扫描) + +3. **汇编生成**: + - `RiscvPrinter`:将机器指令输出为RISC-V汇编 + - 实现基础指令集:`add`/`sub`/`lw`/`sw`/`beq`/`jal`等 + +### 4. **最小可工作流程** +```cpp +// 精简版编译流程(跳过复杂优化) +int main() { + // 1. 前端解析 + auto module = sysy::SysYIRGenerator().genIR(input); + + // 2. 关键中端优化 + sysy::Mem2Reg(module).run(); // 必须 + sysy::Global2Local(module).run(); // 必须 + sysy::DCE(module).run(); // 推荐 + + // 3. 后端代码生成 + auto backendModule = mid2end::CodeGenerater().run(module); + riscv::RiscvPrinter().print("output.s", backendModule); +} +``` + +### 5. **当前开发优先级建议** +1. **完成IR生成**: + - 确保能构建基本块、函数、算术/内存/控制流指令 + - 实现`createCall`/`createLoad`/`createStore`等核心方法 + +2. **实现Mem2Reg**: + - 插入Phi节点 + - 变量重命名(关键算法) + +3. **构建基础后端**: + - 指令选择:实现IR到RISC-V的简单映射(例如:`IRAdd` → `add`) + - 寄存器分配:使用无限寄存器方案(后期替换为真实分配) + - 汇编打印:支持基础指令输出 + +> **注意**:循环优化、函数内联、高级寄存器分配等可在基础流程打通后逐步添加。初期可跳过复杂优化。 + +### 6. 调试建议 +- 添加IR打印模块(`SysYPrinter`)验证前端输出 +- 使用简化测试用例: + ```c + int main() { + int a = 1; + int b = a + 2; + return b; + } + ``` +- 逐步扩展支持: + 1. 算术运算 → 2. 条件分支 → 3. 函数调用 → 4. 数组访问 + +通过聚焦这些核心模块,你们可以快速打通从SysY到RISC-V的基础编译流程,后续再逐步添加优化传递提升代码质量。 \ No newline at end of file diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index d130bd8..7e8572c 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -20,7 +20,7 @@ add_executable(sysyc # LLVMIRGenerator.cpp LLVMIRGenerator_1.cpp ) -target_include_directories(sysyc PRIVATE ${CMAKE_CURRENT_BINARY_DIR}) +target_include_directories(sysyc PRIVATE ${CMAKE_CURRENT_BINARY_DIR} ${CMAKE_CURRENT_SOURCE_DIR}/include) target_compile_options(sysyc PRIVATE -frtti) target_link_libraries(sysyc PRIVATE SysYParser) diff --git a/src/IR.cpp b/src/IR.cpp index f961d53..1564415 100644 --- a/src/IR.cpp +++ b/src/IR.cpp @@ -1,586 +1,310 @@ +#pragma once + #include "IR.h" -#include "range.h" -#include #include -#include -#include -#include -#include -#include #include -#include -#include +#include #include -#include #include -#include -#include -using namespace std; namespace sysy { -template -ostream &interleave(ostream &os, const T &container, const string sep = ", ") { - auto b = container.begin(), e = container.end(); - if (b == e) - return os; - os << *b; - for (b = next(b); b != e; b = next(b)) - os << sep << *b; - return os; -} -static inline ostream &printVarName(ostream &os, const Value *var) { - return os << (dyncast(var) ? '@' : '%') - << var->getName(); -} -static inline ostream &printBlockName(ostream &os, const BasicBlock *block) { - return os << '^' << block->getName(); -} -static inline ostream &printFunctionName(ostream &os, const Function *fn) { - return os << '@' << fn->getName(); -} -static inline ostream &printOperand(ostream &os, const Value *value) { - auto constant = dyncast(value); - if (constant) { - constant->print(os); - return os; - } - return printVarName(os, value); -} -//===----------------------------------------------------------------------===// -// Types -//===----------------------------------------------------------------------===// +class IRBuilder { +private: + unsigned labelIndex; ///< 基本块标签编号 + unsigned tmpIndex; ///< 临时变量编号 -Type *Type::getIntType() { - static Type intType(kInt); - return &intType; -} + BasicBlock *block; ///< 当前基本块 + BasicBlock::iterator position; ///< 当前基本块指令列表位置的迭代器 -Type *Type::getFloatType() { - static Type floatType(kFloat); - return &floatType; -} + std::vector trueBlocks; ///< true分支基本块列表 + std::vector falseBlocks; ///< false分支基本块列表 -Type *Type::getVoidType() { - static Type voidType(kVoid); - return &voidType; -} + std::vector breakBlocks; ///< break目标块列表 + std::vector continueBlocks; ///< continue目标块列表 -Type *Type::getLabelType() { - static Type labelType(kLabel); - return &labelType; -} +public: + IRBuilder() : labelIndex(0), tmpIndex(0), block(nullptr) {} + explicit IRBuilder(BasicBlock *block) : labelIndex(0), tmpIndex(0), block(block), position(block->end()) {} + IRBuilder(BasicBlock *block, BasicBlock::iterator position) + : labelIndex(0), tmpIndex(0), block(block), position(position) {} -Type *Type::getPointerType(Type *baseType) { - // forward to PointerType - return PointerType::get(baseType); -} - -Type *Type::getFunctionType(Type *returnType, - const vector ¶mTypes) { - // forward to FunctionType - return FunctionType::get(returnType, paramTypes); -} - -Type *Type::getArrayType(Type *elementType, const vector &dims) { - // forward to ArrayType - return ArrayType::get(elementType, dims); -} - -ArrayType* Type::asArrayType() const { - return isArray() ? dynamic_cast(const_cast(this)) : nullptr; -} - -int Type::getSize() const { - switch (kind) { - case kInt: - case kFloat: - return 4; - case kLabel: - case kPointer: - case kFunction: - return 8; - case kVoid: - return 0; - case kArray: - return asArrayType()->getArraySize(); - } - return 0; -} - -void Type::print(ostream &os) const { - auto kind = getKind(); - switch (kind) { - case kInt: - os << "int"; - break; - case kFloat: - os << "float"; - break; - case kVoid: - os << "void"; - break; - case kPointer: - static_cast(this)->getBaseType()->print(os); - os << "*"; - break; - case kFunction: - static_cast(this)->getReturnType()->print(os); - os << "("; - interleave(os, static_cast(this)->getParamTypes()); - os << ")"; - break; - case kLabel: - default: - cerr << "Unexpected type!\n"; - break; - } -} - -PointerType *PointerType::get(Type *baseType) { - static std::map> pointerTypes; - auto iter = pointerTypes.find(baseType); - if (iter != pointerTypes.end()) - return iter->second.get(); - auto type = new PointerType(baseType); - assert(type); - auto result = pointerTypes.emplace(baseType, type); - return result.first->second.get(); -} - -FunctionType *FunctionType::get(Type *returnType, - const std::vector ¶mTypes) { - static std::set> functionTypes; - auto iter = - std::find_if(functionTypes.begin(), functionTypes.end(), - [&](const std::unique_ptr &type) -> bool { - if (returnType != type->getReturnType() or - paramTypes.size() != type->getParamTypes().size()) - return false; - return std::equal(paramTypes.begin(), paramTypes.end(), - type->getParamTypes().begin()); - }); - if (iter != functionTypes.end()) - return iter->get(); - auto type = new FunctionType(returnType, paramTypes); - assert(type); - auto result = functionTypes.emplace(type); - return result.first->get(); -} - -void Value::replaceAllUsesWith(Value *value) { - for (auto &use : uses) - use->getUser()->setOperand(use->getIndex(), value); - uses.clear(); -} - -bool Value::isConstant() const { - if (dyncast(this)) - return true; - if (dyncast(this) or - dyncast(this)) - return true; - // if (auto array = dyncast(this)) { - // auto elements = array->getValues(); - // return all_of(elements.begin(), elements.end(), - // [](Value *v) -> bool { return v->isConstant(); }); - // } - return false; -} - - -// 定义静态常量池 -std::unordered_map ConstantValue::constantPool; - -// 常量池实现 -ConstantValue* ConstantValue::get(Type* type, int32_t value) { - ConstantValueKey key = {type, ConstantValVariant(value)}; +public: + unsigned getLabelIndex() { return labelIndex++; } + unsigned getTmpIndex() { return tmpIndex++; } - if (auto it = constantPool.find(key); it != constantPool.end()) { - return it->second; - } + BasicBlock *getBasicBlock() const { return block; } + BasicBlock::iterator getPosition() const { return position; } - ConstantValue* constant = new ConstantInt(type, value); - constantPool[key] = constant; - return constant; -} - -ConstantValue* ConstantValue::get(Type* type, float value) { - ConstantValueKey key = {type, ConstantValVariant(value)}; - - if (auto it = constantPool.find(key); it != constantPool.end()) { - return it->second; + void setPosition(BasicBlock *block, BasicBlock::iterator position) { + this->block = block; + this->position = position; } - - ConstantValue* constant = new ConstantFloat(type, value); - constantPool[key] = constant; - return constant; -} + void setPosition(BasicBlock::iterator position) { this->position = position; } -ConstantValue* ConstantValue::getInt32(int32_t value) { - return get(Type::getIntType(), value); -} - -ConstantValue* ConstantValue::getFloat32(float value) { - return get(Type::getFloatType(), value); -} - -ConstantValue* ConstantValue::getTrue() { - return get(Type::getIntType(), 1); -} - -ConstantValue* ConstantValue::getFalse() { - return get(Type::getIntType(), 0); -} - - - -void ConstantValue::print(std::ostream &os) const { - // 根据类型调用相应的打印实现 - if (auto intConst = dynamic_cast(this)) { - intConst->print(os); - } - else if (auto floatConst = dynamic_cast(this)) { - floatConst->print(os); + // 控制流管理函数 + BasicBlock *getBreakBlock() const { return breakBlocks.back(); } + BasicBlock *popBreakBlock() { + auto result = breakBlocks.back(); + breakBlocks.pop_back(); + return result; } - else { - os << "???"; // 未知常量类型 + BasicBlock *getContinueBlock() const { return continueBlocks.back(); } + BasicBlock *popContinueBlock() { + auto result = continueBlocks.back(); + continueBlocks.pop_back(); + return result; } -} - -void ConstantInt::print(std::ostream &os) const { - os << value; -} -void ConstantFloat::print(std::ostream &os) const { - if (value == static_cast(value)) { - os << value << ".0"; // 确保输出带小数点 - } else { - os << std::fixed << std::setprecision(6) << value; + BasicBlock *getTrueBlock() const { return trueBlocks.back(); } + BasicBlock *getFalseBlock() const { return falseBlocks.back(); } + BasicBlock *popTrueBlock() { + auto result = trueBlocks.back(); + trueBlocks.pop_back(); + return result; } -} + BasicBlock *popFalseBlock() { + auto result = falseBlocks.back(); + falseBlocks.pop_back(); + return result; + } + void pushBreakBlock(BasicBlock *block) { breakBlocks.push_back(block); } + void pushContinueBlock(BasicBlock *block) { continueBlocks.push_back(block); } + void pushTrueBlock(BasicBlock *block) { trueBlocks.push_back(block); } + void pushFalseBlock(BasicBlock *block) { falseBlocks.push_back(block); } -Argument::Argument(Type *type, BasicBlock *block, int index, - const std::string &name) - : Value(kArgument, type, name), block(block), index(index) { - if (not hasName()) - setName(to_string(block->getParent()->allocateVariableID())); -} +public: + // 指令创建函数 + Instruction *insertInst(Instruction *inst) { + assert(inst); + block->getInstructions().emplace(position, inst); + return inst; + } -void Argument::print(std::ostream &os) const { - assert(hasName()); - printVarName(os, this) << ": " << *getType(); -} + UnaryInst *createUnaryInst(Instruction::Kind kind, Type *type, Value *operand, + const std::string &name = "") { + auto inst = new UnaryInst(kind, type, operand, block, name); + return static_cast(insertInst(inst)); + } -BasicBlock::BasicBlock(Function *parent, const std::string &name) - : Value(kBasicBlock, Type::getLabelType(), name), parent(parent), - instructions(), arguments(), successors(), predecessors() { - if (not hasName()) - setName("bb" + to_string(getParent()->allocateblockID())); -} + UnaryInst *createNegInst(Value *operand, const std::string &name = "") { + return createUnaryInst(Instruction::kNeg, Type::getIntType(), operand, name); + } -void BasicBlock::print(std::ostream &os) const { - assert(hasName()); - os << " "; - printBlockName(os, this); - auto args = getArguments(); - auto b = args.begin(), e = args.end(); - if (b != e) { - os << '('; - printVarName(os, b->get()) << ": " << *b->get()->getType(); - for (auto &arg : make_range(std::next(b), e)) { - os << ", "; - printVarName(os, arg.get()) << ": " << *arg->getType(); + UnaryInst *createNotInst(Value *operand, const std::string &name = "") { + return createUnaryInst(Instruction::kNot, Type::getIntType(), operand, name); + } + + UnaryInst *createFtoIInst(Value *operand, const std::string &name = "") { + return createUnaryInst(Instruction::kFtoI, Type::getIntType(), operand, name); + } + + UnaryInst *createBitFtoIInst(Value *operand, const std::string &name = "") { + return createUnaryInst(Instruction::kBitFtoI, Type::getIntType(), operand, name); + } + + UnaryInst *createFNegInst(Value *operand, const std::string &name = "") { + return createUnaryInst(Instruction::kFNeg, Type::getFloatType(), operand, name); + } + + UnaryInst *createFNotInst(Value *operand, const std::string &name = "") { + return createUnaryInst(Instruction::kFNot, Type::getIntType(), operand, name); + } + + UnaryInst *createIToFInst(Value *operand, const std::string &name = "") { + return createUnaryInst(Instruction::kItoF, Type::getFloatType(), operand, name); + } + + UnaryInst *createBitItoFInst(Value *operand, const std::string &name = "") { + return createUnaryInst(Instruction::kBitItoF, Type::getFloatType(), operand, name); + } + + BinaryInst *createBinaryInst(Instruction::Kind kind, Type *type, Value *lhs, + Value *rhs, const std::string &name = "") { + auto inst = new BinaryInst(kind, type, lhs, rhs, block, name); + return static_cast(insertInst(inst)); + } + + BinaryInst *createAddInst(Value *lhs, Value *rhs, const std::string &name = "") { + return createBinaryInst(Instruction::kAdd, Type::getIntType(), lhs, rhs, name); + } + + BinaryInst *createSubInst(Value *lhs, Value *rhs, const std::string &name = "") { + return createBinaryInst(Instruction::kSub, Type::getIntType(), lhs, rhs, name); + } + + BinaryInst *createMulInst(Value *lhs, Value *rhs, const std::string &name = "") { + return createBinaryInst(Instruction::kMul, Type::getIntType(), lhs, rhs, name); + } + + BinaryInst *createDivInst(Value *lhs, Value *rhs, const std::string &name = "") { + return createBinaryInst(Instruction::kDiv, Type::getIntType(), lhs, rhs, name); + } + + BinaryInst *createRemInst(Value *lhs, Value *rhs, const std::string &name = "") { + return createBinaryInst(Instruction::kRem, Type::getIntType(), lhs, rhs, name); + } + + BinaryInst *createICmpEQInst(Value *lhs, Value *rhs, const std::string &name = "") { + return createBinaryInst(Instruction::kICmpEQ, Type::getIntType(), lhs, rhs, name); + } + + BinaryInst *createICmpNEInst(Value *lhs, Value *rhs, const std::string &name = "") { + return createBinaryInst(Instruction::kICmpNE, Type::getIntType(), lhs, rhs, name); + } + + BinaryInst *createICmpLTInst(Value *lhs, Value *rhs, const std::string &name = "") { + return createBinaryInst(Instruction::kICmpLT, Type::getIntType(), lhs, rhs, name); + } + + BinaryInst *createICmpLEInst(Value *lhs, Value *rhs, const std::string &name = "") { + return createBinaryInst(Instruction::kICmpLE, Type::getIntType(), lhs, rhs, name); + } + + BinaryInst *createICmpGTInst(Value *lhs, Value *rhs, const std::string &name = "") { + return createBinaryInst(Instruction::kICmpGT, Type::getIntType(), lhs, rhs, name); + } + + BinaryInst *createICmpGEInst(Value *lhs, Value *rhs, const std::string &name = "") { + return createBinaryInst(Instruction::kICmpGE, Type::getIntType(), lhs, rhs, name); + } + + BinaryInst *createFAddInst(Value *lhs, Value *rhs, const std::string &name = "") { + return createBinaryInst(Instruction::kFAdd, Type::getFloatType(), lhs, rhs, name); + } + + BinaryInst *createFSubInst(Value *lhs, Value *rhs, const std::string &name = "") { + return createBinaryInst(Instruction::kFSub, Type::getFloatType(), lhs, rhs, name); + } + + BinaryInst *createFMulInst(Value *lhs, Value *rhs, const std::string &name = "") { + return createBinaryInst(Instruction::kFMul, Type::getFloatType(), lhs, rhs, name); + } + + BinaryInst *createFDivInst(Value *lhs, Value *rhs, const std::string &name = "") { + return createBinaryInst(Instruction::kFDiv, Type::getFloatType(), lhs, rhs, name); + } + + BinaryInst *createFCmpEQInst(Value *lhs, Value *rhs, const std::string &name = "") { + return createBinaryInst(Instruction::kFCmpEQ, Type::getIntType(), lhs, rhs, name); + } + + BinaryInst *createFCmpNEInst(Value *lhs, Value *rhs, const std::string &name = "") { + return createBinaryInst(Instruction::kFCmpNE, Type::getIntType(), lhs, rhs, name); + } + + BinaryInst *createFCmpLTInst(Value *lhs, Value *rhs, const std::string &name = "") { + return createBinaryInst(Instruction::kFCmpLT, Type::getIntType(), lhs, rhs, name); + } + + BinaryInst *createFCmpLEInst(Value *lhs, Value *rhs, const std::string &name = "") { + return createBinaryInst(Instruction::kFCmpLE, Type::getIntType(), lhs, rhs, name); + } + + BinaryInst *createFCmpGTInst(Value *lhs, Value *rhs, const std::string &name = "") { + return createBinaryInst(Instruction::kFCmpGT, Type::getIntType(), lhs, rhs, name); + } + + BinaryInst *createFCmpGEInst(Value *lhs, Value *rhs, const std::string &name = "") { + return createBinaryInst(Instruction::kFCmpGE, Type::getIntType(), lhs, rhs, name); + } + + BinaryInst *createAndInst(Value *lhs, Value *rhs, const std::string &name = "") { + return createBinaryInst(Instruction::kAnd, Type::getIntType(), lhs, rhs, name); + } + + BinaryInst *createOrInst(Value *lhs, Value *rhs, const std::string &name = "") { + return createBinaryInst(Instruction::kOr, Type::getIntType(), lhs, rhs, name); + } + + CallInst *createCallInst(Function *callee, const std::vector &args, + const std::string &name = "") { + auto inst = new CallInst(callee, args, block, name); + return static_cast(insertInst(inst)); + } + + ReturnInst *createReturnInst(Value *value = nullptr) { + auto inst = new ReturnInst(value, block); + return static_cast(insertInst(inst)); + } + + UncondBrInst *createUncondBrInst(BasicBlock *thenBlock, + const std::vector &args) { + auto inst = new UncondBrInst(thenBlock, args, block); + return static_cast(insertInst(inst)); + } + + CondBrInst *createCondBrInst(Value *condition, BasicBlock *thenBlock, + BasicBlock *elseBlock, + const std::vector &thenArgs, + const std::vector &elseArgs) { + auto inst = new CondBrInst(condition, thenBlock, elseBlock, thenArgs, elseArgs, block); + return static_cast(insertInst(inst)); + } + + AllocaInst *createAllocaInst(Type *type, const std::vector &dims = {}, + const std::string &name = "") { + auto inst = new AllocaInst(type, dims, block, name); + return static_cast(insertInst(inst)); + } + + AllocaInst *createAllocaInstWithoutInsert(Type *type, + const std::vector &dims = {}, + BasicBlock *parent = nullptr, + const std::string &name = "") { + return new AllocaInst(type, dims, parent, name); + } + + LoadInst *createLoadInst(Value *pointer, const std::vector &indices = {}, + const std::string &name = "") { + auto inst = new LoadInst(pointer, indices, block, name); + return static_cast(insertInst(inst)); + } + + LaInst *createLaInst(Value *pointer, const std::vector &indices = {}, + const std::string &name = "") { + auto inst = new LaInst(pointer, indices, block, name); + return static_cast(insertInst(inst)); + } + + GetSubArrayInst *createGetSubArray(LVal *fatherArray, + const std::vector &indices, + const std::string &name = "") { + assert(fatherArray->getLValNumDims() > indices.size()); + std::vector subDims; + auto dims = fatherArray->getLValDims(); + auto iter = std::next(dims.begin(), indices.size()); + while (iter != dims.end()) { + subDims.emplace_back(*iter); + iter++; } - os << ')'; + + auto fatherArrayValue = dynamic_cast(fatherArray); + AllocaInst * childArray = new AllocaInst(fatherArrayValue->getType(), subDims, block); + auto inst = new GetSubArrayInst(fatherArray, childArray, indices, block ,name); + return static_cast(insertInst(inst)); } - os << ":\n"; - for (auto &inst : instructions) { - os << " " << *inst << '\n'; + + MemsetInst *createMemsetInst(Value *pointer, Value *begin, Value *size, + Value *value, const std::string &name = "") { + auto inst = new MemsetInst(pointer, begin, size, value, block, name); + return static_cast(insertInst(inst)); } -} -Instruction::Instruction(Kind kind, Type *type, BasicBlock *parent, - const std::string &name) - : User(kind, type, name), kind(kind), parent(parent) { - if (not type->isVoid() and not hasName()) - setName(to_string(getFunction()->allocateVariableID())); -} - -void CallInst::print(std::ostream &os) const { - if (not getType()->isVoid()) - printVarName(os, this) << " = call "; - printFunctionName(os, getCallee()) << '('; - auto args = getArguments(); - auto b = args.begin(), e = args.end(); - if (b != e) { - printOperand(os, *b); - for (auto arg : make_range(std::next(b), e)) { - os << ", "; - printOperand(os, arg); - } + StoreInst *createStoreInst(Value *value, Value *pointer, + const std::vector &indices = {}, + const std::string &name = "") { + auto inst = new StoreInst(value, pointer, indices, block, name); + return static_cast(insertInst(inst)); } - os << ") : " << *getType(); -} -void UnaryInst::print(std::ostream &os) const { - printVarName(os, this) << " = "; - switch (getKind()) { - case kNeg: - os << "neg"; - break; - case kNot: - os << "not"; - break; - case kFNeg: - os << "fneg"; - break; - case kFtoI: - os << "ftoi"; - break; - case kIToF: - os << "itof"; - break; - default: - assert(false); + PhiInst *createPhiInst(Type *type, Value *lhs, BasicBlock *parent, + const std::string &name = "") { + auto predNum = parent->getNumPredecessors(); + std::vector rhs(predNum, lhs); + auto inst = new PhiInst(type, lhs, rhs, lhs, parent, name); + parent->getInstructions().emplace(parent->begin(), inst); + return inst; } - printOperand(os, getOperand()) << " : " << *getType(); -} +}; -void BinaryInst::print(std::ostream &os) const { - printVarName(os, this) << " = "; - switch (getKind()) { - case kAdd: - os << "add"; - break; - case kSub: - os << "sub"; - break; - case kMul: - os << "mul"; - break; - case kDiv: - os << "div"; - break; - case kRem: - os << "rem"; - break; - case kICmpEQ: - os << "icmpeq"; - break; - case kICmpNE: - os << "icmpne"; - break; - case kICmpLT: - os << "icmplt"; - break; - case kICmpGT: - os << "icmpgt"; - break; - case kICmpLE: - os << "icmple"; - break; - case kICmpGE: - os << "icmpge"; - break; - case kFAdd: - os << "fadd"; - break; - case kFSub: - os << "fsub"; - break; - case kFMul: - os << "fmul"; - break; - case kFDiv: - os << "fdiv"; - break; - case kFRem: - os << "frem"; - break; - case kFCmpEQ: - os << "fcmpeq"; - break; - case kFCmpNE: - os << "fcmpne"; - break; - case kFCmpLT: - os << "fcmplt"; - break; - case kFCmpGT: - os << "fcmpgt"; - break; - case kFCmpLE: - os << "fcmple"; - break; - case kFCmpGE: - os << "fcmpge"; - break; - default: - assert(false); - } - os << ' '; - printOperand(os, getLhs()) << ", "; - printOperand(os, getRhs()) << " : " << *getType(); -} - -void ReturnInst::print(std::ostream &os) const { - os << "return"; - if (auto value = getReturnValue()) { - os << ' '; - printOperand(os, value) << " : " << *value->getType(); - } -} - -void UncondBrInst::print(std::ostream &os) const { - os << "br "; - printBlockName(os, getBlock()); - auto args = getArguments(); - auto b = args.begin(), e = args.end(); - if (b != e) { - os << '('; - printOperand(os, *b); - for (auto arg : make_range(std::next(b), e)) { - os << ", "; - printOperand(os, arg); - } - os << ')'; - } -} - -void CondBrInst::print(std::ostream &os) const { - os << "condbr "; - printOperand(os, getCondition()) << ", "; - printBlockName(os, getThenBlock()); - { - auto args = getThenArguments(); - auto b = args.begin(), e = args.end(); - if (b != e) { - os << '('; - printOperand(os, *b); - for (auto arg : make_range(std::next(b), e)) { - os << ", "; - printOperand(os, arg); - } - os << ')'; - } - } - os << ", "; - printBlockName(os, getElseBlock()); - { - auto args = getElseArguments(); - auto b = args.begin(), e = args.end(); - if (b != e) { - os << '('; - printOperand(os, *b); - for (auto arg : make_range(std::next(b), e)) { - os << ", "; - printOperand(os, arg); - } - os << ')'; - } - } -} - -void AllocaInst::print(std::ostream &os) const { - if (getNumDims()) - cerr << "not implemented yet\n"; - printVarName(os, this) << " = "; - os << "alloca " - << *static_cast(getType())->getBaseType(); - os << " : " << *getType(); -} - -void LoadInst::print(std::ostream &os) const { - if (getNumIndices()) - cerr << "not implemented yet\n"; - printVarName(os, this) << " = "; - os << "load "; - printOperand(os, getPointer()) << " : " << *getType(); -} - -void StoreInst::print(std::ostream &os) const { - if (getNumIndices()) - cerr << "not implemented yet\n"; - os << "store "; - printOperand(os, getValue()) << ", "; - printOperand(os, getPointer()) << " : " << *getValue()->getType(); -} - -void Function::print(std::ostream &os) const { - auto returnType = getReturnType(); - auto paramTypes = getParamTypes(); - os << *returnType << ' '; - printFunctionName(os, this) << '('; - auto b = paramTypes.begin(), e = paramTypes.end(); - if (b != e) { - os << *(*b); - for (auto type : make_range(std::next(b), e)) - os << ", " << *type; - } - os << ") {\n"; - for (auto &bb : getBasicBlocks()) { - os << *bb << '\n'; - } - os << "}"; -} - -void Module::print(std::ostream &os) const { - for (auto &value : children) - os << *value << '\n'; -} - -// ArrayValue *ArrayValue::get(Type *type, const vector &values) { -// static map, unique_ptr> arrayConstants; -// hash hasher; -// auto key = make_pair( -// type, hasher(string(reinterpret_cast(values.data()), -// values.size() * sizeof(Value *)))); - -// auto iter = arrayConstants.find(key); -// if (iter != arrayConstants.end()) -// return iter->second.get(); -// auto constant = new ArrayValue(type, values); -// assert(constant); -// auto result = arrayConstants.emplace(key, constant); -// return result.first->second.get(); -// } - -// ArrayValue *ArrayValue::get(const std::vector &values) { -// vector vals(values.size(), nullptr); -// std::transform(values.begin(), values.end(), vals.begin(), -// [](int v) { return ConstantValue::get(v); }); -// return get(Type::getIntType(), vals); -// } - -// ArrayValue *ArrayValue::get(const std::vector &values) { -// vector vals(values.size(), nullptr); -// std::transform(values.begin(), values.end(), vals.begin(), -// [](float v) { return ConstantValue::get(v); }); -// return get(Type::getFloatType(), vals); -// } - -void User::setOperand(int index, Value *value) { - assert(index < getNumOperands()); - operands[index].setValue(value); -} - -void User::replaceOperand(int index, Value *value) { - assert(index < getNumOperands()); - auto &use = operands[index]; - use.getValue()->removeUse(&use); - use.setValue(value); -} - -CallInst::CallInst(Function *callee, const std::vector &args, - BasicBlock *parent, const std::string &name) - : Instruction(kCall, callee->getReturnType(), parent, name) { - addOperand(callee); - for (auto arg : args) - addOperand(arg); -} - -Function *CallInst::getCallee() const { - return dyncast(getOperand(0)); -} - -} // namespace sysy \ No newline at end of file +} // namespace sysy diff --git a/src/IR.h b/src/IR.h deleted file mode 100644 index 7d228ad..0000000 --- a/src/IR.h +++ /dev/null @@ -1,1120 +0,0 @@ -#pragma once - -#include "range.h" -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include - -namespace sysy { - -/*! - * \defgroup type Types - * The SysY type system is quite simple. - * 1. The base class `Type` is used to represent all primitive scalar types, - * include `int`, `float`, `void`, and the label type representing branch - * targets. - * 2. `PointerType` and `FunctionType` derive from `Type` and represent pointer - * type and function type, respectively. - * - * NOTE `Type` and its derived classes have their ctors declared as 'protected'. - * Users must use Type::getXXXType() methods to obtain `Type` pointers. - * @{ - */ - -/*! - * `Type` is used to represent all primitive scalar types, - * include `int`, `float`, `void`, and the label type representing branch - * targets - */ - -class ArrayType; - -class Type { -public: - enum Kind { - kInt, - kFloat, - kVoid, - kLabel, - kPointer, - kFunction, - kArray, - }; - Kind kind; - -protected: - Type(Kind kind) : kind(kind) {} - virtual ~Type() = default; - -public: - static Type *getIntType(); - static Type *getFloatType(); - static Type *getVoidType(); - static Type *getLabelType(); - static Type *getPointerType(Type *baseType); - static Type *getFunctionType(Type *returnType, - const std::vector ¶mTypes = {}); - static Type *getArrayType(Type *elementType, const std::vector &dims = {}); - -public: - Kind getKind() const { return kind; } - bool isInt() const { return kind == kInt; } - bool isFloat() const { return kind == kFloat; } - bool isVoid() const { return kind == kVoid; } - bool isLabel() const { return kind == kLabel; } - bool isPointer() const { return kind == kPointer; } - bool isFunction() const { return kind == kFunction; } - bool isArray() const { return kind == kArray; } - bool isIntOrFloat() const { return kind == kInt or kind == kFloat; } - int getSize() const; - - ArrayType* asArrayType() const; - - template - std::enable_if_t, T *> as() const { - return dynamic_cast(const_cast(this)); - } - void print(std::ostream &os) const; -}; // class Type - -//! Pointer type -class PointerType : public Type { -protected: - Type *baseType; - -protected: - PointerType(Type *baseType) : Type(kPointer), baseType(baseType) {} - -public: - static PointerType *get(Type *baseType); - -public: - Type *getBaseType() const { return baseType; } -}; // class PointerType - -//! Function type -class FunctionType : public Type { -private: - Type *returnType; - std::vector paramTypes; - -protected: - FunctionType(Type *returnType, const std::vector ¶mTypes = {}) - : Type(kFunction), returnType(returnType), paramTypes(paramTypes) {} - -public: - static FunctionType *get(Type *returnType, - const std::vector ¶mTypes = {}); - -public: - Type *getReturnType() const { return returnType; } - auto getParamTypes() const { return make_range(paramTypes); } - int getNumParams() const { return paramTypes.size(); } -}; // class FunctionType - -class ArrayType : public Type { -private: - Type *elementType; // 数组元素类型 - std::vector dimensions; // 维度信息(空向量表示未知大小) - -protected: - ArrayType(Type *elemType, const std::vector &dims = {}) - : Type(kArray), elementType(elemType), dimensions(dims) { - // 确保元素类型有效 - assert(elemType && "Array element type cannot be null"); - assert(!elemType->isVoid() && "Cannot have array of void"); - assert(!elemType->isLabel() && "Cannot have array of labels"); - } - -public: - // 获取数组类型(带缓存机制) - static ArrayType *get(Type *elemType, const std::vector &dims = {}) { - // 实现类型缓存池(避免重复创建) - static std::map>, ArrayType*> cache; - - auto key = std::make_pair(elemType, dims); - if (cache.find(key) == cache.end()) { - cache[key] = new ArrayType(elemType, dims); - } - return cache[key]; - } - - Type *getElementType() const { return elementType; } - const std::vector& getDimensions() const { return dimensions; } - size_t getNumDimensions() const { return dimensions.size(); } - - int getArraySize() const { - int size = elementType->getSize(); - for (int dim : dimensions) { - size *= dim; - } - return size; - } - -};//class ArrayType - -/*! - * @} - */ - -/*! - * \defgroup ir IR - * - * The SysY IR is an instruction level language. The IR is orgnized - * as a four-level tree structure, as shown below - * - * \dotfile ir-4level.dot IR Structure - * - * - `Module` corresponds to the top level "CompUnit" syntax structure - * - `GlobalValue` corresponds to the "Decl" syntax structure - * - `Function` corresponds to the "FuncDef" syntax structure - * - `BasicBlock` is a sequence of instructions without branching. A `Function` - * made up by one or more `BasicBlock`s. - * - `Instruction` represents a primitive operation on values, e.g., add or sub. - * - * The fundamental data concept in SysY IR is `Value`. A `Value` is like - * a register and is used by `Instruction`s as input/output operand. Each value - * has an associated `Type` indicating the data type held by the value. - * - * Most `Instruction`s have a three-address signature, i.e., there are at most 2 - * input values and at most 1 output value. - * - * The SysY IR adots a Static-Single-Assignment (SSA) design. That is, `Value` - * is defined (as the output operand ) by some instruction, and used (as the - * input operand) by other instructions. While a value can be used by multiple - * instructions, the `definition` occurs only once. As a result, there is a - * one-to-one relation between a value and the instruction defining it. In other - * words, any instruction defines a value can be viewed as the defined value - * itself. So `Instruction` is also a `Value` in SysY IR. See `Value` for the - * type hierachy. - * - * @{ - */ - -class User; -class Value; - -//! `Use` represents the relation between a `Value` and its `User` -class Use { -private: - //! the position of value in the user's operands, i.e., - //! user->getOperands[index] == value - int index; - User *user; - Value *value; - -public: - Use() = default; - Use(int index, User *user, Value *value) - : index(index), user(user), value(value) {} - -public: - int getIndex() const { return index; } - User *getUser() const { return user; } - Value *getValue() const { return value; } - void setValue(Value *value) { value = value; } -}; // class Use - -template -inline std::enable_if_t, bool> -isa(const Value *value) { - return T::classof(value); -} - -template -inline std::enable_if_t, T *> -dyncast(Value *value) { - return isa(value) ? static_cast(value) : nullptr; -} - -template -inline std::enable_if_t, const T *> -dyncast(const Value *value) { - return isa(value) ? static_cast(value) : nullptr; -} - -//! The base class of all value types -class Value { -public: - enum Kind : uint64_t { - kInvalid, - // Instructions - // Binary - kAdd = 0x1UL << 0, - kSub = 0x1UL << 1, - kMul = 0x1UL << 2, - kDiv = 0x1UL << 3, - kRem = 0x1UL << 4, - kICmpEQ = 0x1UL << 5, - kICmpNE = 0x1UL << 6, - kICmpLT = 0x1UL << 7, - kICmpGT = 0x1UL << 8, - kICmpLE = 0x1UL << 9, - kICmpGE = 0x1UL << 10, - kFAdd = 0x1UL << 14, - kFSub = 0x1UL << 15, - kFMul = 0x1UL << 16, - kFDiv = 0x1UL << 17, - kFRem = 0x1UL << 18, - kFCmpEQ = 0x1UL << 19, - kFCmpNE = 0x1UL << 20, - kFCmpLT = 0x1UL << 21, - kFCmpGT = 0x1UL << 22, - kFCmpLE = 0x1UL << 23, - kFCmpGE = 0x1UL << 24, - // Unary - kNeg = 0x1UL << 25, - kNot = 0x1UL << 26, - kFNeg = 0x1UL << 27, - kFtoI = 0x1UL << 28, - kIToF = 0x1UL << 29, - // call - kCall = 0x1UL << 30, - // terminator - kCondBr = 0x1UL << 31, - kBr = 0x1UL << 32, - kReturn = 0x1UL << 33, - // mem op - kAlloca = 0x1UL << 34, - kLoad = 0x1UL << 35, - kStore = 0x1UL << 36, - kFirstInst = kAdd, - kLastInst = kStore, - // others - kArgument = 0x1UL << 37, - kBasicBlock = 0x1UL << 38, - kFunction = 0x1UL << 39, - kConstant = 0x1UL << 40, - kGlobal = 0x1UL << 41, - }; - -protected: - Kind kind; - Type *type; - std::string name; - std::list uses; - -protected: - Value(Kind kind, Type *type, const std::string &name = "") - : kind(kind), type(type), name(name), uses() {} - -public: - virtual ~Value() = default; - -public: - Kind getKind() const { return kind; } - static bool classof(const Value *) { return true; } - -public: - Type *getType() const { return type; } - const std::string &getName() const { return name; } - void setName(const std::string &n) { name = n; } - bool hasName() const { return not name.empty(); } - bool isInt() const { return type->isInt(); } - bool isFloat() const { return type->isFloat(); } - bool isPointer() const { return type->isPointer(); } - const std::list &getUses() { return uses; } - void addUse(Use *use) { uses.push_back(use); } - void replaceAllUsesWith(Value *value); - void removeUse(Use *use) { uses.remove(use); } - bool isConstant() const; - -public: - virtual void print(std::ostream &os) const {}; -}; // class Value - -/*! - * Static constants known at compile time. - * - * `ConstantValue`s are not defined by instructions, and do not use any other - * `Value`s. It's type is either `int` or `float`. - */ - -class ConstantInt; -class ConstantFloat; -//常量池优化 - -using ConstantValVariant = std::variant; -using ConstantValueKey = std::pair; - -class ConstantValue : public Value { -protected: - ConstantValue(Type* type) - : Value(kConstant, type, "") {} -public: - struct ConstantValueHash; - struct ConstantValueEqual; - - static std::unordered_map constantPool; - - virtual ~ConstantValue() = default; - - static ConstantValue* get(Type* type, int32_t value); - static ConstantValue* get(Type* type, float value); - - static bool classof(const Value* value) { - return value->getKind() == kConstant; - } - - virtual int32_t getInt() const = 0; - virtual float getFloat() const = 0; - virtual bool isZero() const = 0; - virtual bool isOne() const = 0; - - - static ConstantValue* getInt32(int32_t value); - static ConstantValue* getFloat32(float value); - static ConstantValue* getTrue() ; - static ConstantValue* getFalse(); - - void print(std::ostream &os) const override; -}; - -struct ConstantValue::ConstantValueHash { - std::size_t operator()(const ConstantValueKey& key) const { - std::size_t typeHash = std::hash{}(key.first); - std::size_t valHash = 0; - if (key.first->isInt()) { - valHash = std::hash{}(std::get(key.second)); - } else if (key.first->isFloat()) { - // 修复5: 确保float哈希正确 - valHash = std::hash{}(std::get(key.second)); - } - return typeHash ^ (valHash << 1); - } -}; - -struct ConstantValue::ConstantValueEqual { - bool operator()(const ConstantValueKey& lhs, const ConstantValueKey& rhs) const { - if (lhs.first != rhs.first) return false; - if (lhs.first->isInt()) { - return std::get(lhs.second) == std::get(rhs.second); - } else if (lhs.first->isFloat()) { - // 修复6: 使用浮点比较容差 - const float eps = 1e-6; - return fabs(std::get(lhs.second) - std::get(rhs.second)) < eps; - } - return false; - } -}; - -class ConstantInt : public ConstantValue { - int32_t value; - friend class ConstantValue; - -protected: - ConstantInt(Type* type, int32_t value) - : ConstantValue(type), value(value) { - assert(type->isInt() && "Invalid type for ConstantInt"); - } -public: - static ConstantInt* get(Type* type, int32_t value); - - int32_t getInt() const override { return value; } - float getFloat() const override { return static_cast(value); } - bool isZero() const override { return value == 0; } - bool isOne() const override { return value == 1; } - - void print(std::ostream& os) const override ; -}; - -class ConstantFloat : public ConstantValue { - float value; - friend class ConstantValue; - -protected: - ConstantFloat(Type* type, float value) - : ConstantValue(type), value(value) { - assert(type->isFloat() && "Invalid type for ConstantFloat"); - } -public: - static ConstantFloat* get(Type* type, float value); - - int32_t getInt() const override { return static_cast(value); } - float getFloat() const override { return value; } - bool isZero() const override { return value == 0.0f; } - bool isOne() const override { return value == 1.0f; } - - void print(std::ostream& os) const override; -}; - -class BasicBlock; -/*! - * Arguments of `BasicBlock`s. - * - * SysY IR is an SSA language, however, it does not use PHI instructions as in - * LLVM IR. `Value`s from different predecessor blocks are passed explicitly as - * block arguments. This is also the approach used by MLIR. - * NOTE that `Function` does not own `Argument`s, function arguments are - * implemented as its entry block's arguments. - */ - -class Argument : public Value { -protected: - BasicBlock *block; - int index; - -public: - Argument(Type *type, BasicBlock *block, int index, - const std::string &name = ""); - -public: - static bool classof(const Value *value) { - return value->getKind() == kConstant; - } - -public: - BasicBlock *getParent() const { return block; } - int getIndex() const { return index; } - -public: - void print(std::ostream &os) const override; -}; - -class Instruction; -class Function; -/*! - * The container for `Instruction` sequence. - * - * `BasicBlock` maintains a list of `Instruction`s, with the last one being - * a terminator (branch or return). Besides, `BasicBlock` stores its arguments - * and records its predecessor and successor `BasicBlock`s. - */ -class BasicBlock : public Value { - friend class Function; - -public: - using inst_list = std::list>; - using iterator = inst_list::iterator; - using arg_list = std::vector>; - using block_list = std::vector; - -protected: - Function *parent; - inst_list instructions; - arg_list arguments; - block_list successors; - block_list predecessors; - -protected: - explicit BasicBlock(Function *parent, const std::string &name = ""); - -public: - static bool classof(const Value *value) { - return value->getKind() == kBasicBlock; - } - -public: - int getNumInstructions() const { return instructions.size(); } - int getNumArguments() const { return arguments.size(); } - int getNumPredecessors() const { return predecessors.size(); } - int getNumSuccessors() const { return successors.size(); } - Function *getParent() const { return parent; } - inst_list &getInstructions() { return instructions; } - auto getArguments() const { return make_range(arguments); } - block_list &getPredecessors() { return predecessors; } - block_list &getSuccessors() { return successors; } - iterator begin() { return instructions.begin(); } - iterator end() { return instructions.end(); } - iterator terminator() { return std::prev(end()); } - Argument *createArgument(Type *type, const std::string &name = "") { - auto arg = new Argument(type, this, arguments.size(), name); - assert(arg); - arguments.emplace_back(arg); - return arguments.back().get(); - }; - -public: - void print(std::ostream &os) const override; -}; // class BasicBlock - -//! User is the abstract base type of `Value` types which use other `Value` as -//! operands. Currently, there are two kinds of `User`s, `Instruction` and -//! `GlobalValue`. -class User : public Value { -protected: - std::vector operands; - -protected: - User(Kind kind, Type *type, const std::string &name = "") - : Value(kind, type, name), operands() {} - -public: - using use_iterator = std::vector::const_iterator; - struct operand_iterator : public std::vector::const_iterator { - using Base = std::vector::const_iterator; - operand_iterator(const Base &iter) : Base(iter) {} - using value_type = Value *; - value_type operator->() { return Base::operator*().getValue(); } - value_type operator*() { return Base::operator*().getValue(); } - }; - // struct const_operand_iterator : std::vector::const_iterator { - // using Base = std::vector::const_iterator; - // const_operand_iterator(const Base &iter) : Base(iter) {} - // using value_type = Value *; - // value_type operator->() { return operator*().getValue(); } - // }; - -public: - int getNumOperands() const { return operands.size(); } - operand_iterator operand_begin() const { return operands.begin(); } - operand_iterator operand_end() const { return operands.end(); } - auto getOperands() const { - return make_range(operand_begin(), operand_end()); - } - Value *getOperand(int index) const { return operands[index].getValue(); } - void addOperand(Value *value) { - operands.emplace_back(operands.size(), this, value); - value->addUse(&operands.back()); - } - template void addOperands(const ContainerT &operands) { - for (auto value : operands) - addOperand(value); - } - void replaceOperand(int index, Value *value); - void setOperand(int index, Value *value); -}; // class User - -/*! - * Base of all concrete instruction types. - */ -class Instruction : public User { -public: - // enum Kind : uint64_t { - // kInvalid = 0x0UL, - // // Binary - // kAdd = 0x1UL << 0, - // kSub = 0x1UL << 1, - // kMul = 0x1UL << 2, - // kDiv = 0x1UL << 3, - // kRem = 0x1UL << 4, - // kICmpEQ = 0x1UL << 5, - // kICmpNE = 0x1UL << 6, - // kICmpLT = 0x1UL << 7, - // kICmpGT = 0x1UL << 8, - // kICmpLE = 0x1UL << 9, - // kICmpGE = 0x1UL << 10, - // kFAdd = 0x1UL << 14, - // kFSub = 0x1UL << 15, - // kFMul = 0x1UL << 16, - // kFDiv = 0x1UL << 17, - // kFRem = 0x1UL << 18, - // kFCmpEQ = 0x1UL << 19, - // kFCmpNE = 0x1UL << 20, - // kFCmpLT = 0x1UL << 21, - // kFCmpGT = 0x1UL << 22, - // kFCmpLE = 0x1UL << 23, - // kFCmpGE = 0x1UL << 24, - // // Unary - // kNeg = 0x1UL << 25, - // kNot = 0x1UL << 26, - // kFNeg = 0x1UL << 27, - // kFtoI = 0x1UL << 28, - // kIToF = 0x1UL << 29, - // // call - // kCall = 0x1UL << 30, - // // terminator - // kCondBr = 0x1UL << 31, - // kBr = 0x1UL << 32, - // kReturn = 0x1UL << 33, - // // mem op - // kAlloca = 0x1UL << 34, - // kLoad = 0x1UL << 35, - // kStore = 0x1UL << 36, - // // constant - // // kConstant = 0x1UL << 37, - // }; - -protected: - Kind kind; - BasicBlock *parent; - -protected: - Instruction(Kind kind, Type *type, BasicBlock *parent = nullptr, - const std::string &name = ""); - -public: - static bool classof(const Value *value) { - return value->getKind() >= kFirstInst and value->getKind() <= kLastInst; - } - -public: - Kind getKind() const { return kind; } - BasicBlock *getParent() const { return parent; } - Function *getFunction() const { return parent->getParent(); } - void setParent(BasicBlock *bb) { parent = bb; } - - bool isBinary() const { - static constexpr uint64_t BinaryOpMask = - (kAdd | kSub | kMul | kDiv | kRem) | - (kICmpEQ | kICmpNE | kICmpLT | kICmpGT | kICmpLE | kICmpGE) | - (kFAdd | kFSub | kFMul | kFDiv | kFRem) | - (kFCmpEQ | kFCmpNE | kFCmpLT | kFCmpGT | kFCmpLE | kFCmpGE); - return kind & BinaryOpMask; - } - bool isUnary() const { - static constexpr uint64_t UnaryOpMask = kNeg | kNot | kFNeg | kFtoI | kIToF; - return kind & UnaryOpMask; - } - bool isMemory() const { - static constexpr uint64_t MemoryOpMask = kAlloca | kLoad | kStore; - return kind & MemoryOpMask; - } - bool isTerminator() const { - static constexpr uint64_t TerminatorOpMask = kCondBr | kBr | kReturn; - return kind & TerminatorOpMask; - } - bool isCmp() const { - static constexpr uint64_t CmpOpMask = - (kICmpEQ | kICmpNE | kICmpLT | kICmpGT | kICmpLE | kICmpGE) | - (kFCmpEQ | kFCmpNE | kFCmpLT | kFCmpGT | kFCmpLE | kFCmpGE); - return kind & CmpOpMask; - } - bool isBranch() const { - static constexpr uint64_t BranchOpMask = kBr | kCondBr; - return kind & BranchOpMask; - } - bool isCommutative() const { - static constexpr uint64_t CommutativeOpMask = - kAdd | kMul | kICmpEQ | kICmpNE | kFAdd | kFMul | kFCmpEQ | kFCmpNE; - return kind & CommutativeOpMask; - } - bool isUnconditional() const { return kind == kBr; } - bool isConditional() const { return kind == kCondBr; } -}; // class Instruction - -class Function; -//! Function call. -class CallInst : public Instruction { - friend class IRBuilder; - -protected: - CallInst(Function *callee, const std::vector &args = {}, - BasicBlock *parent = nullptr, const std::string &name = ""); - -public: - static bool classof(const Value *value) { return value->getKind() == kCall; } - -public: - Function *getCallee() const; - auto getArguments() const { - return make_range(std::next(operand_begin()), operand_end()); - } - -public: - void print(std::ostream &os) const override; -}; // class CallInst - -//! Unary instruction, includes '!', '-' and type conversion. -class UnaryInst : public Instruction { - friend class IRBuilder; - -protected: - UnaryInst(Kind kind, Type *type, Value *operand, BasicBlock *parent = nullptr, - const std::string &name = "") - : Instruction(kind, type, parent, name) { - addOperand(operand); - } - -public: - static bool classof(const Value *value) { - return Instruction::classof(value) and - static_cast(value)->isUnary(); - } - -public: - Value *getOperand() const { return User::getOperand(0); } - -public: - void print(std::ostream &os) const override; -}; // class UnaryInst - -//! Binary instruction, e.g., arithmatic, relation, logic, etc. -class BinaryInst : public Instruction { - friend class IRBuilder; - -protected: - BinaryInst(Kind kind, Type *type, Value *lhs, Value *rhs, BasicBlock *parent, - const std::string &name = "") - : Instruction(kind, type, parent, name) { - addOperand(lhs); - addOperand(rhs); - } - -public: - static bool classof(const Value *value) { - return Instruction::classof(value) and - static_cast(value)->isBinary(); - } - -public: - Value *getLhs() const { return getOperand(0); } - Value *getRhs() const { return getOperand(1); } - -public: - void print(std::ostream &os) const override; -}; // class BinaryInst - -//! The return statement -class ReturnInst : public Instruction { - friend class IRBuilder; - -protected: - ReturnInst(Value *value = nullptr, BasicBlock *parent = nullptr) - : Instruction(kReturn, Type::getVoidType(), parent, "") { - if (value) - addOperand(value); - } - -public: - static bool classof(const Value *value) { - return value->getKind() == kReturn; - } - -public: - bool hasReturnValue() const { return not operands.empty(); } - Value *getReturnValue() const { - return hasReturnValue() ? getOperand(0) : nullptr; - } - -public: - void print(std::ostream &os) const override; -}; // class ReturnInst - -//! Unconditional branch -class UncondBrInst : public Instruction { - friend class IRBuilder; - -protected: - UncondBrInst(BasicBlock *block, std::vector args, - BasicBlock *parent = nullptr) - : Instruction(kCondBr, Type::getVoidType(), parent, "") { - assert(block->getNumArguments() == args.size()); - addOperand(block); - addOperands(args); - } - -public: - static bool classof(const Value *value) { return value->getKind() == kBr; } - -public: - BasicBlock *getBlock() const { return dyncast(getOperand(0)); } - auto getArguments() const { - return make_range(std::next(operand_begin()), operand_end()); - } - -public: - void print(std::ostream &os) const override; -}; // class UncondBrInst - -//! Conditional branch -class CondBrInst : public Instruction { - friend class IRBuilder; - -protected: - CondBrInst(Value *condition, BasicBlock *thenBlock, BasicBlock *elseBlock, - const std::vector &thenArgs, - const std::vector &elseArgs, BasicBlock *parent = nullptr) - : Instruction(kCondBr, Type::getVoidType(), parent, "") { - assert(thenBlock->getNumArguments() == thenArgs.size() and - elseBlock->getNumArguments() == elseArgs.size()); - addOperand(condition); - addOperand(thenBlock); - addOperand(elseBlock); - addOperands(thenArgs); - addOperands(elseArgs); - } - -public: - static bool classof(const Value *value) { - return value->getKind() == kCondBr; - } - -public: - Value *getCondition() const { return getOperand(0); } - BasicBlock *getThenBlock() const { - return dyncast(getOperand(1)); - } - BasicBlock *getElseBlock() const { - return dyncast(getOperand(2)); - } - auto getThenArguments() const { - auto begin = std::next(operand_begin(), 3); - auto end = std::next(begin, getThenBlock()->getNumArguments()); - return make_range(begin, end); - } - auto getElseArguments() const { - auto begin = - std::next(operand_begin(), 3 + getThenBlock()->getNumArguments()); - auto end = operand_end(); - return make_range(begin, end); - } - -public: - void print(std::ostream &os) const override; -}; // class CondBrInst - -//! Allocate memory for stack variables, used for non-global variable declartion -class AllocaInst : public Instruction { - friend class IRBuilder; - -protected: - AllocaInst(Type *type, const std::vector &dims = {}, - BasicBlock *parent = nullptr, const std::string &name = "") - : Instruction(kAlloca, type, parent, name) { - addOperands(dims); - } - -public: - static bool classof(const Value *value) { - return value->getKind() == kAlloca; - } - -public: - int getNumDims() const { return getNumOperands(); } - auto getDims() const { return getOperands(); } - Value *getDim(int index) { return getOperand(index); } - -public: - void print(std::ostream &os) const override; -}; // class AllocaInst - -//! Load a value from memory address specified by a pointer value -class LoadInst : public Instruction { - friend class IRBuilder; - -protected: - LoadInst(Value *pointer, const std::vector &indices = {}, - BasicBlock *parent = nullptr, const std::string &name = "") - : Instruction(kLoad, pointer->getType()->as()->getBaseType(), - parent, name) { - addOperand(pointer); - addOperands(indices); - } - -public: - static bool classof(const Value *value) { return value->getKind() == kLoad; } - -public: - int getNumIndices() const { return getNumOperands() - 1; } - Value *getPointer() const { return getOperand(0); } - auto getIndices() const { - return make_range(std::next(operand_begin()), operand_end()); - } - Value *getIndex(int index) const { return getOperand(index + 1); } - -public: - void print(std::ostream &os) const override; -}; // class LoadInst - -//! Store a value to memory address specified by a pointer value -class StoreInst : public Instruction { - friend class IRBuilder; - -protected: - StoreInst(Value *value, Value *pointer, - const std::vector &indices = {}, - BasicBlock *parent = nullptr, const std::string &name = "") - : Instruction(kStore, Type::getVoidType(), parent, name) { - addOperand(value); - addOperand(pointer); - addOperands(indices); - } - -public: - static bool classof(const Value *value) { return value->getKind() == kStore; } - -public: - int getNumIndices() const { return getNumOperands() - 2; } - Value *getValue() const { return getOperand(0); } - Value *getPointer() const { return getOperand(1); } - auto getIndices() const { - return make_range(std::next(operand_begin(), 2), operand_end()); - } - Value *getIndex(int index) const { return getOperand(index + 2); } - -public: - void print(std::ostream &os) const override; -}; // class StoreInst - -class Module; -//! Function definition -class Function : public Value { - friend class Module; - -protected: - Function(Module *parent, Type *type, const std::string &name) - : Value(kFunction, type, name), parent(parent), variableID(0), blocks() { - blocks.emplace_back(new BasicBlock(this, "entry")); - } - -public: - static bool classof(const Value *value) { - return value->getKind() == kFunction; - } - -public: - using block_list = std::list>; - -protected: - Module *parent; - int variableID; - int blockID; - block_list blocks; - -public: - Type *getReturnType() const { - return getType()->as()->getReturnType(); - } - auto getParamTypes() const { - return getType()->as()->getParamTypes(); - } - auto getBasicBlocks() const { return make_range(blocks); } - BasicBlock *getEntryBlock() const { return blocks.front().get(); } - BasicBlock *addBasicBlock(const std::string &name = "") { - blocks.emplace_back(new BasicBlock(this, name)); - return blocks.back().get(); - } - void removeBasicBlock(BasicBlock *block) { - blocks.remove_if([&](std::unique_ptr &b) -> bool { - return block == b.get(); - }); - } - int allocateVariableID() { return variableID++; } - int allocateblockID() { return blockID++; } - -public: - void print(std::ostream &os) const override; -}; // class Function - -// class ArrayValue : public User { -// protected: -// ArrayValue(Type *type, const std::vector &values = {}) -// : User(type, "") { -// addOperands(values); -// } - -// public: -// static ArrayValue *get(Type *type, const std::vector &values); -// static ArrayValue *get(const std::vector &values); -// static ArrayValue *get(const std::vector &values); - -// public: -// auto getValues() const { return getOperands(); } - -// public: -// void print(std::ostream &os) const override{}; -// }; // class ConstantArray - -//! Global value declared at file scope -class GlobalValue : public User { - friend class Module; - -protected: - Module *parent; - bool hasInit; - bool isConst; - -protected: - GlobalValue(Module *parent, Type *type, const std::string &name, - const std::vector &dims = {}, Value *init = nullptr) - : User(kGlobal, type, name), parent(parent), hasInit(init) { - assert(type->isPointer()); - addOperands(dims); - if (init) - addOperand(init); - } - -public: - static bool classof(const Value *value) { - return value->getKind() == kGlobal; - } - -public: - Value *init() const { return hasInit ? operands.back().getValue() : nullptr; } - int getNumDims() const { return getNumOperands() - (hasInit ? 1 : 0); } - Value *getDim(int index) { return getOperand(index); } - -public: - void print(std::ostream &os) const override{}; -}; // class GlobalValue - -//! IR unit for representing a SysY compile unit -class Module { -protected: - std::vector> children; - std::map functions; - std::map globals; - -public: - Module() = default; - -public: - Function *createFunction(const std::string &name, Type *type) { - if (functions.count(name)) - return nullptr; - auto func = new Function(this, type, name); - assert(func); - children.emplace_back(func); - functions.emplace(name, func); - return func; - }; - GlobalValue *createGlobalValue(const std::string &name, Type *type, - const std::vector &dims = {}, - Value *init = nullptr) { - if (globals.count(name)) - return nullptr; - auto global = new GlobalValue(this, type, name, dims, init); - assert(global); - children.emplace_back(global); - globals.emplace(name, global); - return global; - } - Function *getFunction(const std::string &name) const { - auto result = functions.find(name); - if (result == functions.end()) - return nullptr; - return result->second; - } - GlobalValue *getGlobalValue(const std::string &name) const { - auto result = globals.find(name); - if (result == globals.end()) - return nullptr; - return result->second; - } - - std::map *getFunctions(){ - return &functions; - } - std::map *getGlobalValues(){ - return &globals; - } - -public: - void print(std::ostream &os) const; -}; // class Module - -/*! - * @} - */ -inline std::ostream &operator<<(std::ostream &os, const Type &type) { - type.print(os); - return os; -} - -inline std::ostream &operator<<(std::ostream &os, const Value &value) { - value.print(os); - return os; -} - -} // namespace sysy \ No newline at end of file diff --git a/src/SysYIRAnalyser.cpp b/src/SysYIRAnalyser.cpp new file mode 100644 index 0000000..e69de29 diff --git a/src/SysYIRGenerator.cpp b/src/SysYIRGenerator.cpp index 1844a53..5a7ccc2 100644 --- a/src/SysYIRGenerator.cpp +++ b/src/SysYIRGenerator.cpp @@ -834,204 +834,4 @@ std::any SysYIRGenerator::visitConstExp(SysYParser::ConstExpContext* ctx) { return res; } -/* begin -std::any SysYIRGenerator::visitConstGlobalDecl(SysYParser::ConstDeclContext *ctx, Type* type) { - std::vector values; - for (auto constDef : ctx->constDef()) { - - auto name = constDef->Ident()->getText(); - // get its dimensions - vector dims; - for (auto dim : constDef->constExp()) - dims.push_back(any_cast(dim->accept(this))); - - if (dims.size() == 0) { - auto init = constDef->ASSIGN() ? any_cast((constDef->constInitVal()->constExp()->accept(this))) - : nullptr; - if (init && isa(init)){ - Type *btype = type->as()->getBaseType(); - if (btype->isInt() && init->getType()->isFloat()) - init = ConstantValue::get((int)dynamic_cast(init)->getFloat()); - else if (btype->isFloat() && init->getType()->isInt()) - init = ConstantValue::get((float)dynamic_cast(init)->getInt()); - } - - auto global_value = module->createGlobalValue(name, type, dims, init); - - symbols_table.insert(name, global_value); - values.push_back(global_value); - } - else{ - auto init = constDef->ASSIGN() ? any_cast(dims[0]) - : nullptr; - auto global_value = module->createGlobalValue(name, type, dims, init); - if (constDef->ASSIGN()) { - d = 0; - n = 0; - path.clear(); - path = vector(dims.size(), 0); - isalloca = false; - current_type = global_value->getType()->as()->getBaseType(); - current_global = global_value; - numdims = global_value->getNumDims(); - for (auto init : constDef->constInitVal()->constInitVal()) - init->accept(this); - // visitConstInitValue(init); - } - symbols_table.insert(name, global_value); - values.push_back(global_value); - } - } - return values; -} - -std::any SysYIRGenerator::visitVarGlobalDecl(SysYParser::VarDeclContext *ctx, Type* type){ - std::vector values; - for (auto varDef : ctx->varDef()) { - - auto name = varDef->Ident()->getText(); - // get its dimensions - vector dims; - for (auto dim : varDef->constExp()) - dims.push_back(any_cast(dim->accept(this))); - - if (dims.size() == 0) { - auto init = varDef->ASSIGN() ? any_cast((varDef->initVal()->exp()->accept(this))) - : nullptr; - if (init && isa(init)){ - Type *btype = type->as()->getBaseType(); - if (btype->isInt() && init->getType()->isFloat()) - init = ConstantValue::get((int)dynamic_cast(init)->getFloat()); - else if (btype->isFloat() && init->getType()->isInt()) - init = ConstantValue::get((float)dynamic_cast(init)->getInt()); - } - - auto global_value = module->createGlobalValue(name, type, dims, init); - - symbols_table.insert(name, global_value); - values.push_back(global_value); - } - else{ - auto init = varDef->ASSIGN() ? any_cast(dims[0]) - : nullptr; - auto global_value = module->createGlobalValue(name, type, dims, init); - if (varDef->ASSIGN()) { - d = 0; - n = 0; - path.clear(); - path = vector(dims.size(), 0); - isalloca = false; - current_type = global_value->getType()->as()->getBaseType(); - current_global = global_value; - numdims = global_value->getNumDims(); - for (auto init : varDef->initVal()->initVal()) - init->accept(this); - // visitInitValue(init); - } - symbols_table.insert(name, global_value); - values.push_back(global_value); - } - } - return values; -} - -std::any SysYIRGenerator::visitConstLocalDecl(SysYParser::ConstDeclContext *ctx, Type* type){ - std::vector values; - // handle variables - for (auto constDef : ctx->constDef()) { - - auto name = constDef->Ident()->getText(); - vector dims; - for (auto dim : constDef->constExp()) - dims.push_back(any_cast(dim->accept(this))); - auto alloca = builder.createAllocaInst(type, dims, name); - symbols_table.insert(name, alloca); - - if (constDef->ASSIGN()) { - if (alloca->getNumDims() == 0) { - - auto value = any_cast(constDef->constInitVal()->constExp()->accept(this)); - - if (isa(value)) { - if (ctx->bType()->INT() && dynamic_cast(value)->isFloat()) - value = ConstantValue::get((int)dynamic_cast(value)->getFloat()); - else if (ctx->bType()->FLOAT() && dynamic_cast(value)->isInt()) - value = ConstantValue::get((float)dynamic_cast(value)->getInt()); - } - else if (alloca->getType()->as()->getBaseType()->isInt() && value->getType()->isFloat()) - value = builder.createFtoIInst(value); - else if (alloca->getType()->as()->getBaseType()->isFloat() && value->getType()->isInt()) - value = builder.createIToFInst(value); - - auto store = builder.createStoreInst(value, alloca); - } - else{ - d = 0; - n = 0; - path.clear(); - path = vector(alloca->getNumDims(), 0); - isalloca = true; - current_alloca = alloca; - current_type = alloca->getType()->as()->getBaseType(); - numdims = alloca->getNumDims(); - for (auto init : constDef->constInitVal()->constInitVal()) - init->accept(this); - } - } - - values.push_back(alloca); - } - return values; -} - -std::any SysYIRGenerator::visitVarLocalDecl(SysYParser::VarDeclContext *ctx, Type* type){ - std::vector values; - for (auto varDef : ctx->varDef()) { - - auto name = varDef->Ident()->getText(); - vector dims; - for (auto dim : varDef->constExp()) - dims.push_back(any_cast(dim->accept(this))); - auto alloca = builder.createAllocaInst(type, dims, name); - symbols_table.insert(name, alloca); - - if (varDef->ASSIGN()) { - if (alloca->getNumDims() == 0) { - - auto value = any_cast(varDef->initVal()->exp()->accept(this)); - - if (isa(value)) { - if (ctx->bType()->INT() && dynamic_cast(value)->isFloat()) - value = ConstantValue::get((int)dynamic_cast(value)->getFloat()); - else if (ctx->bType()->FLOAT() && dynamic_cast(value)->isInt()) - value = ConstantValue::get((float)dynamic_cast(value)->getInt()); - } - else if (alloca->getType()->as()->getBaseType()->isInt() && value->getType()->isFloat()) - value = builder.createFtoIInst(value); - else if (alloca->getType()->as()->getBaseType()->isFloat() && value->getType()->isInt()) - value = builder.createIToFInst(value); - - auto store = builder.createStoreInst(value, alloca); - } - else{ - d = 0; - n = 0; - path.clear(); - path = vector(alloca->getNumDims(), 0); - isalloca = true; - current_alloca = alloca; - current_type = alloca->getType()->as()->getBaseType(); - numdims = alloca->getNumDims(); - for (auto init : varDef->initVal()->initVal()) - init->accept(this); - } - } - - values.push_back(alloca); - } - return values; -} - end -*/ - } // namespace sysy \ No newline at end of file diff --git a/src/ASTPrinter.h b/src/include/ASTPrinter.h similarity index 100% rename from src/ASTPrinter.h rename to src/include/ASTPrinter.h diff --git a/src/Backend.h b/src/include/Backend.h similarity index 100% rename from src/Backend.h rename to src/include/Backend.h diff --git a/src/include/IR.h b/src/include/IR.h new file mode 100644 index 0000000..1f3885b --- /dev/null +++ b/src/include/IR.h @@ -0,0 +1,1319 @@ +#pragma once + +#include "range.h" +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace sysy { +/** + * \defgroup type Types + * @brief Sysy的类型系统 + * + * 1. 基类`Type` 用来表示所有的原始标量类型, + * 包括 `int`, `float`, `void`, 和表示跳转目标的标签类型。 + * 2. `PointerType` 和 `FunctionType` 派生自`Type` 并且分别表示指针和函数类型。 + * + * \note `Type`和它的派生类的构造函数声明为'protected'. + * 用户必须使用Type::getXXXType()获得`Type` 指针。 + * @{ + */ + +/** + * + * `Type`用来表示所有的原始标量类型, + * 包括`int`, `float`, `void`, 和表示跳转目标的标签类型。 + */ + +class Type { + public: + /// 定义了原始标量类型种类的枚举类型 + enum Kind { + kInt, + kFloat, + kVoid, + kLabel, + kPointer, + kFunction, + }; + + Kind kind; ///< 表示具体类型的变量 + + protected: + explicit Type(Kind kind) : kind(kind) {} + virtual ~Type() = default; + + public: + static auto getIntType() -> Type *; ///< 返回表示Int类型的Type指针 + static auto getFloatType() -> Type *; ///< 返回表示Float类型的Type指针 + static auto getVoidType() -> Type *; ///< 返回表示Void类型的Type指针 + static auto getLabelType() -> Type *; ///< 返回表示Label类型的Type指针 + static auto getPointerType(Type *baseType) -> Type *; ///< 返回表示指向baseType类型的Pointer类型的Type指针 + static auto getFunctionType(Type *returnType, const std::vector ¶mTypes = {}) -> Type *; + ///< 返回表示返回类型为returnType,形参类型列表为paramTypes的函数类型的Type指针 + + public: + auto getKind() const -> Kind { return kind; } ///< 返回Type对象代表原始标量类型 + auto isInt() const -> bool { return kind == kInt; } ///< 判定是否为Int类型 + auto isFloat() const -> bool { return kind == kFloat; } ///< 判定是否为Float类型 + auto isVoid() const -> bool { return kind == kVoid; } ///< 判定是否为Void类型 + auto isLabel() const -> bool { return kind == kLabel; } ///< 判定是否为Label类型 + auto isPointer() const -> bool { return kind == kPointer; } ///< 判定是否为Pointer类型 + auto isFunction() const -> bool { return kind == kFunction; } ///< 判定是否为Function类型 + auto getSize() const -> unsigned; ///< 返回类型所占的空间大小(字节) + /// 尝试将一个变量转换为给定的Type及其派生类类型的变量 + template + auto as() const -> std::enable_if_t, T *> { + return dynamic_cast(const_cast(this)); + } +}; + +class PointerType : public Type { + protected: + Type *baseType; ///< 所指向的类型 + + protected: + explicit PointerType(Type *baseType) : Type(kPointer), baseType(baseType) {} + + public: + static auto get(Type *baseType) -> PointerType *; ///< 获取指向baseType的Pointer类型 + + public: + auto getBaseType() const -> Type * { return baseType; } ///< 获取指向的类型 +}; + +class FunctionType : public Type { + private: + Type *returnType; ///< 返回值类型 + std::vector paramTypes; ///< 形参类型列表 + + protected: + explicit FunctionType(Type *returnType, std::vector paramTypes = {}) + : Type(kFunction), returnType(returnType), paramTypes(std::move(paramTypes)) {} + + public: + /// 获取返回值类型为returnType, 形参类型列表为paramTypes的Function类型 + static auto get(Type *returnType, const std::vector ¶mTypes = {}) -> FunctionType *; + + public: + auto getReturnType() const -> Type * { return returnType; } ///< 获取返回值类信息 + auto getParamTypes() const { return make_range(paramTypes); } ///< 获取形参类型列表 + auto getNumParams() const -> unsigned { return paramTypes.size(); } ///< 获取形参数量 +}; + +/*! + * @} + */ + +/** + * \defgroup ir IR + * + * TSysY IR 是一种指令级别的语言. 它被组织为四层树型结构,如下所示: + * + * \dot IR Structure + * digraph IRStructure{ + * node [shape="box"] + * a [label="Module"] + * b [label="GlobalValue"] + * c [label="Function"] + * d [label="BasicBlock"] + * e [label="Instruction"] + * a->{b,c} + * c->d->e + * } + * + * \enddot + * + * - `Module` 对应顶层"CompUnit"语法结构 + * - `GlobalValue`对应"globalDecl"语法结构 + * - `Function`对应"FuncDef"语法结构 + * - `BasicBlock` 是一连串没有分支指令的指令。一个 `Function` + * 由一个或多个`BasicBlock`组成 + * - `Instruction` 表示一个原始指令,例如, add 或 sub + * + * SysY IR中基础的数据概念是`Value`。一个 `Value` 像 + * 一个寄存器。它充当`Instruction`的输入输出操作数。每个value + * 都有一个与之相联系的`Type`,以此说明Value所拥有值的类型。 + * + * 大多数`Instruction`具有三地址代码的结构, 例如, 最多拥有两个输入操作数和一个输出操作数。 + * + * SysY IR采用了Static-Single-Assignment (SSA)设计。`Value`作为一个输出操作数 + * 被一些指令所定义, 并被另一些指令当作输入操作数使用。尽管一个Value可以被多个指令使用,其 + * 定义只能发生一次。这导致一个value个定义它的指令存在一一对应关系。换句话说,任何定义一个Value的指令 + * 都可以被看作被定义的指令本身。故在SysY IR中,`Instruction` 也是一个`Value`。查看 `Value` 以获取其继承 + * 关系。 + * + * @{ + */ + + +class User; +class Value; +class AllocaInst; + +//! `Use` 表示`Value`和它的`User`之间的使用关系。 + +class Use { + private: + /** + * value在User操作数中的位置,例如, + * user->getOperands[index] == value + */ + unsigned index; + User *user; ///< 使用者 + Value *value; ///< 被使用的值 + + public: + Use() = default; + Use(unsigned index, User *user, Value *value) : index(index), user(user), value(value) {} + + public: + auto getIndex() const -> unsigned { return index; } ///< 返回value在User操作数中的位置 + auto getUser() const -> User * { return user; } ///< 返回使用者 + auto getValue() const -> Value * { return value; } ///< 返回被使用的值 + void setValue(Value *newValue) { value = newValue; } ///< 将被使用的值设置为newValue +}; + +template +inline std::enable_if_t, bool> +isa(const Value *value) { + return T::classof(value); +} + +template +inline std::enable_if_t, T *> +dyncast(Value *value) { + return isa(value) ? static_cast(value) : nullptr; +} + +template +inline std::enable_if_t, const T *> +dyncast(const Value *value) { + return isa(value) ? static_cast(value) : nullptr; +} + +//! The base class of all value types + +class Value { + protected: + Type *type; ///< 值的类型 + std::string name; ///< 值的名字 + std::list> uses; ///< 值的使用关系列表 + + protected: + explicit Value(Type *type, std::string name = "") : type(type), name(std::move(name)) {} + virtual ~Value() = default; + + public: + void setName(const std::string &newName) { name = newName; } ///< 设置名字 + auto getName() const -> const std::string & { return name; } ///< 获取名字 + auto getType() const -> Type * { return type; } ///< 返回值的类型 + auto isInt() const -> bool { return type->isInt(); } ///< 判定是否为Int类型 + auto isFloat() const -> bool { return type->isFloat(); } ///< 判定是否为Float类型 + auto isPointer() const -> bool { return type->isPointer(); } ///< 判定是否为Pointer类型 + auto getUses() -> std::list> & { return uses; } ///< 获取使用关系列表 + void addUse(const std::shared_ptr &use) { uses.push_back(use); } ///< 添加使用关系 + void replaceAllUsesWith(Value *value); ///< 将原来使用该value的使用者全变为使用给定参数value并修改相应use关系 + void removeUse(const std::shared_ptr &use) { uses.remove(use); } ///< 删除使用关系use +}; + + + +/** + * ValueCounter 需要理解为一个Value *的计数器。 + * 它的主要目的是为了节省存储空间和方便Memset指令的创建。 + * ValueCounter记录了一列Value *的互异元素和每个元素的重复数量。 + * 例如,假设有一列Value *为{v1, v1, v2, v3, v3, v3, v4}, + * 那么ValueCounter将记录为: + * - __counterValues: {v1, v2, v3, v4} + * - __counterNumbers: {2, 1, 3, 1} + * - __size: 7 + * 使得存储空间得到节省,方便Memset指令的创建。 + */ +class ValueCounter { + private: + unsigned __size{}; ///< 总的Value数量 + std::vector __counterValues; ///< 记录的Value *列表(无重复元素) + std::vector __counterNumbers; ///< 记录的Value *重复数量列表 + + public: + ValueCounter() = default; + + public: + auto size() const -> unsigned { return __size; } ///< 返回总的Value数量 + auto getValue(unsigned index) const -> Value * { + if (index >= __size) { + return nullptr; + } + + unsigned num = 0; + for (size_t i = 0; i < __counterNumbers.size(); i++) { + if (num <= index && index < num + __counterNumbers[i]) { + return __counterValues[i]; + } + num += __counterNumbers[i]; + } + + return nullptr; + } ///< 根据位置index获取Value * + auto getValues() const -> const std::vector & { return __counterValues; } ///< 获取互异Value *列表 + auto getNumbers() const -> const std::vector & { return __counterNumbers; } ///< 获取Value *重复数量列表 + void push_back(Value *value, unsigned num = 1) { + if (__size != 0 && __counterValues.back() == value) { + *(__counterNumbers.end() - 1) += num; + } else { + __counterValues.push_back(value); + __counterNumbers.push_back(num); + } + __size += num; + } ///< 向后插入num个value + void clear() { + __size = 0; + __counterValues.clear(); + __counterNumbers.clear(); + } ///< 清空ValueCounter +}; + +/*! + * Static constants known at compile time. + * + * `ConstantValue`s are not defined by instructions, and do not use any other + * `Value`s. It's type is either `int` or `float`. + * `ConstantValue`并不由指令定义, 也不使用任何Value。它的类型为int/float。 + */ + + +class ConstantValue : public Value { + protected: + /// 定义字面量类型的聚合类型 + union { + int iScalar; + float fScalar; + }; + + protected: + explicit ConstantValue(int value, const std::string &name = "") : Value(Type::getIntType(), name), iScalar(value) {} + explicit ConstantValue(float value, const std::string &name = "") + : Value(Type::getFloatType(), name), fScalar(value) {} + + public: + static auto get(int value) -> ConstantValue *; ///< 获取一个int类型的ConstValue *,其值为value + static auto get(float value) -> ConstantValue *; ///< 获取一个float类型的ConstValue *,其值为value + + public: + auto getInt() const -> int { + assert(isInt()); + return iScalar; + } ///< 返回int类型的值 + auto getFloat() const -> float { + assert(isFloat()); + return fScalar; + } ///< 返回float类型的值 + template + auto getValue() const -> T { + if (std::is_same::value && isInt()) { + return getInt(); + } + if (std::is_same::value && isFloat()) { + return getFloat(); + } + throw std::bad_cast(); // 或者其他适当的异常处理 + } ///< 返回值,getInt和getFloat统一化,整数返回整形,浮点返回浮点型 +}; + +class Instruction; +class Function; +class Loop; +class BasicBlock; + +/*! + * Arguments of `BasicBlock`s. + * + * SysY IR is an SSA language, however, it does not use PHI instructions as in + * LLVM IR. `Value`s from different predecessor blocks are passed explicitly as + * block arguments. This is also the approach used by MLIR. + * NOTE that `Function` does not own `Argument`s, function arguments are + * implemented as its entry block's arguments. + */ + +class Argument : public Value { +protected: + BasicBlock *block; + int index; + +public: + Argument(Type *type, BasicBlock *block, int index, + const std::string &name = ""); + +public: + static bool classof(const Value *value) { + return value->getKind() == kArgument; + } + +public: + BasicBlock *getParent() const { return block; } + int getIndex() const { return index; } + +public: + void print(std::ostream &os) const override; +}; + +class Instruction; +class Function; +class Loop; +/*! + * The container for `Instruction` sequence. + * + * `BasicBlock` maintains a list of `Instruction`s, with the last one being + * a terminator (branch or return). Besides, `BasicBlock` stores its arguments + * and records its predecessor and successor `BasicBlock`s. + */ +class BasicBlock : public Value { + friend class Function; + +public: + using inst_list = std::list>; + using iterator = inst_list::iterator; + using arg_list = std::vector>; + using block_list = std::vector; + using block_set = std::unordered_set; + +protected: + Function *parent; + inst_list instructions; + arg_list arguments; + block_list successors; + block_list predecessors; + BasicBlock *idom = nullptr; // 直接支配节点 + block_list sdoms; // 支配树后继 + block_set dominants; // 必经节点集合 + block_set dominant_frontiers; // 支配边界 + bool reachable = false; // 是否可达 + Loop *loopbelong = nullptr; // 所属循环 + int loopdepth = 0; // 循环深度 + +protected: + explicit BasicBlock(Function *parent, const std::string &name = ""); + +public: + static bool classof(const Value *value) { + return value->getKind() == kBasicBlock; + } + +public: + int getNumInstructions() const { return instructions.size(); } + int getNumArguments() const { return arguments.size(); } + int getNumPredecessors() const { return predecessors.size(); } + int getNumSuccessors() const { return successors.size(); } + Function *getParent() const { return parent; } + inst_list &getInstructions() { return instructions; } + auto getArguments() const { return make_range(arguments); } + const block_list &getPredecessors() const { return predecessors; } + block_list &getPredecessors() { return predecessors; } + const block_list &getSuccessors() const { return successors; } + block_list &getSuccessors() { return successors; } + iterator begin() { return instructions.begin(); } + iterator end() { return instructions.end(); } + iterator terminator() { return std::prev(end()); } + Argument *createArgument(Type *type, const std::string &name = "") { + auto arg = new Argument(type, this, arguments.size(), name); + assert(arg); + arguments.emplace_back(arg); + return arguments.back().get(); + }; + + // 控制流分析相关 + BasicBlock *getIdom() const { return idom; } + void setIdom(BasicBlock *dom) { idom = dom; } + const block_list &getSdoms() const { return sdoms; } + void addSdom(BasicBlock *bb) { sdoms.push_back(bb); } + void clearSdoms() { sdoms.clear(); } + const block_set &getDominants() const { return dominants; } + void addDominant(BasicBlock *bb) { dominants.insert(bb); } + void setDominants(const block_set &doms) { dominants = doms; } + const block_set &getDominantFrontiers() const { return dominant_frontiers; } + void setDominantFrontiers(const block_set &df) { dominant_frontiers = df; } + bool isReachable() const { return reachable; } + void setReachable(bool r) { reachable = r; } + + // 循环分析相关 + Loop *getLoop() const { return loopbelong; } + void setLoop(Loop *loop) { loopbelong = loop; } + int getLoopDepth() const { return loopdepth; } + void setLoopDepth(int depth) { loopdepth = depth; } + + void addPredecessor(BasicBlock *bb) { + if (std::find(predecessors.begin(), predecessors.end(), bb) == predecessors.end()) + predecessors.push_back(bb); + } + + void addSuccessor(BasicBlock *bb) { + if (std::find(successors.begin(), successors.end(), bb) == successors.end()) + successors.push_back(bb); + } + + void removePredecessor(BasicBlock *bb) { + auto it = std::find(predecessors.begin(), predecessors.end(), bb); + if (it != predecessors.end()) + predecessors.erase(it); + } + + void removeSuccessor(BasicBlock *bb) { + auto it = std::find(successors.begin(), successors.end(), bb); + if (it != successors.end()) + successors.erase(it); + } + + // 获取支配树中所有子节点 + block_list getChildren() { + std::queue q; + block_list children; + for (auto sdom : sdoms) { + q.push(sdom); + children.push_back(sdom); + } + while (!q.empty()) { + auto block = q.front(); + q.pop(); + for (auto sdom : block->sdoms) { + q.push(sdom); + children.push_back(sdom); + } + } + return children; + } + +public: + void print(std::ostream &os) const override; +}; // class BasicBlock + +//! User is the abstract base type of `Value` types which use other `Value` as +//! operands. Currently, there are two kinds of `User`s, `Instruction` and +//! `GlobalValue`. +class User : public Value { +protected: + std::vector operands; + +protected: + User(Kind kind, Type *type, const std::string &name = "") + : Value(kind, type, name), operands() {} + +public: + using use_iterator = std::vector::const_iterator; + struct operand_iterator : public std::vector::const_iterator { + using Base = std::vector::const_iterator; + operand_iterator(const Base &iter) : Base(iter) {} + using value_type = Value *; + value_type operator->() { return Base::operator*().getValue(); } + value_type operator*() { return Base::operator*().getValue(); } + }; + +public: + int getNumOperands() const { return operands.size(); } + operand_iterator operand_begin() const { return operands.begin(); } + operand_iterator operand_end() const { return operands.end(); } + auto getOperands() const { + return make_range(operand_begin(), operand_end()); + } + Value *getOperand(int index) const { return operands[index].getValue(); } + void addOperand(Value *value) { + operands.emplace_back(operands.size(), this, value); + value->addUse(&operands.back()); + } + template void addOperands(const ContainerT &operands) { + for (auto value : operands) + addOperand(value); + } + void replaceOperand(int index, Value *value); + void setOperand(int index, Value *value); +}; // class User + + +class GetSubArrayInst; +/** + * 左值 具有地址的对象 + */ +class LVal : public User { + friend class GetSubArrayInst; + + protected: + LVal *fatherLVal{}; ///< 父左值 + std::list> childrenLVals; ///< 子左值 + GetSubArrayInst *defineInst{}; /// 定义该左值的GetSubArray指令 + + protected: + LVal() = default; + + public: + virtual ~LVal() = default; + virtual auto getLValDims() const -> std::vector = 0; ///< 获取左值的维度 + virtual auto getLValNumDims() const -> unsigned = 0; ///< 获取左值的维度数量 + + public: + auto getFatherLVal() const -> LVal * { return fatherLVal; } ///< 获取父左值 + auto getChildrenLVals() const -> const std::list> & { + return childrenLVals; + } ///< 获取子左值列表 + auto getAncestorLVal() const -> LVal * { + auto curLVal = const_cast(this); + while (curLVal->getFatherLVal() != nullptr) { + curLVal = curLVal->getFatherLVal(); + } + return curLVal; + } ///< 获取祖先左值 + void setFatherLVal(LVal *father) { fatherLVal = father; } ///< 设置父左值 + void setDefineInst(GetSubArrayInst *inst) { defineInst = inst; } ///< 设置定义指令 + void addChild(LVal *child) { childrenLVals.emplace_back(child); } ///< 添加子左值 + void removeChild(LVal *child) { + auto iter = std::find_if(childrenLVals.begin(), childrenLVals.end(), + [child](const std::unique_ptr &ptr) { return ptr.get() == child; }); + childrenLVals.erase(iter); + } ///< 移除子左值 + auto getDefineInst() const -> GetSubArrayInst * { return defineInst; } ///< 获取定义指令 +}; + +/*! + * Base of all concrete instruction types. + */ +class Instruction : public User { +public: + // 指令种类定义已移至Value::Kind + +protected: + BasicBlock *parent; + +protected: + Instruction(Kind kind, Type *type, BasicBlock *parent = nullptr, + const std::string &name = ""); + +public: + static bool classof(const Value *value) { + return value->getKind() >= kFirstInst and value->getKind() <= kLastInst; + } + +public: + Kind getKind() const { return kind; } + BasicBlock *getParent() const { return parent; } + Function *getFunction() const { return parent->getParent(); } + void setParent(BasicBlock *bb) { parent = bb; } + + bool isBinary() const { + static constexpr uint64_t BinaryOpMask = + (kAdd | kSub | kMul | kDiv | kRem | kAnd | kOr) | + (kICmpEQ | kICmpNE | kICmpLT | kICmpGT | kICmpLE | kICmpGE) | + (kFAdd | kFSub | kFMul | kFDiv | kFRem) | + (kFCmpEQ | kFCmpNE | kFCmpLT | kFCmpGT | kFCmpLE | kFCmpGE); + return kind & BinaryOpMask; + } + bool isUnary() const { + static constexpr uint64_t UnaryOpMask = + kNeg | kNot | kFNeg | kFNot | kFtoI | kItoF | kBitFtoI | kBitItoF; + return kind & UnaryOpMask; + } + bool isMemory() const { + static constexpr uint64_t MemoryOpMask = + kAlloca | kLoad | kStore | kLa | kMemset | kGetSubArray; + return kind & MemoryOpMask; + } + bool isTerminator() const { + static constexpr uint64_t TerminatorOpMask = kCondBr | kBr | kReturn; + return kind & TerminatorOpMask; + } + bool isCmp() const { + static constexpr uint64_t CmpOpMask = + (kICmpEQ | kICmpNE | kICmpLT | kICmpGT | kICmpLE | kICmpGE) | + (kFCmpEQ | kFCmpNE | kFCmpLT | kFCmpGT | kFCmpLE | kFCmpGE); + return kind & CmpOpMask; + } + bool isBranch() const { + static constexpr uint64_t BranchOpMask = kBr | kCondBr; + return kind & BranchOpMask; + } + bool isCommutative() const { + static constexpr uint64_t CommutativeOpMask = + kAdd | kMul | kICmpEQ | kICmpNE | kFAdd | kFMul | kFCmpEQ | kFCmpNE | kAnd | kOr; + return kind & CommutativeOpMask; + } + bool isUnconditional() const { return kind == kBr; } + bool isConditional() const { return kind == kCondBr; } + bool isPhi() const { return kind == kPhi; } + bool isAlloca() const { return kind == kAlloca; } + bool isLoad() const { return kind == kLoad; } + bool isStore() const { return kind == kStore; } + bool isLa() const { return kind == kLa; } + bool isMemset() const { return kind == kMemset; } + bool isGetSubArray() const { return kind == kGetSubArray; } + bool isCall() const { return kind == kCall; } + bool isReturn() const { return kind == kReturn; } +}; // class Instruction + +class Function; +//! Function call. +class CallInst : public Instruction { + friend class IRBuilder; + +protected: + CallInst(Function *callee, const std::vector &args = {}, + BasicBlock *parent = nullptr, const std::string &name = ""); + +public: + static bool classof(const Value *value) { return value->getKind() == kCall; } + +public: + Function *getCallee() const; + auto getArguments() const { + return make_range(std::next(operand_begin()), operand_end()); + } + +public: + void print(std::ostream &os) const override; +}; // class CallInst + +//! Unary instruction, includes '!', '-' and type conversion. +class UnaryInst : public Instruction { + friend class IRBuilder; + +protected: + UnaryInst(Kind kind, Type *type, Value *operand, BasicBlock *parent = nullptr, + const std::string &name = "") + : Instruction(kind, type, parent, name) { + addOperand(operand); + } + +public: + static bool classof(const Value *value) { + return Instruction::classof(value) and + static_cast(value)->isUnary(); + } + +public: + Value *getOperand() const { return User::getOperand(0); } + +public: + void print(std::ostream &os) const override; +}; // class UnaryInst + +//! Binary instruction, e.g., arithmatic, relation, logic, etc. +class BinaryInst : public Instruction { + friend class IRBuilder; + +protected: + BinaryInst(Kind kind, Type *type, Value *lhs, Value *rhs, BasicBlock *parent, + const std::string &name = "") + : Instruction(kind, type, parent, name) { + addOperand(lhs); + addOperand(rhs); + } + +public: + static bool classof(const Value *value) { + return Instruction::classof(value) and + static_cast(value)->isBinary(); + } + +public: + Value *getLhs() const { return getOperand(0); } + Value *getRhs() const { return getOperand(1); } + +public: + void print(std::ostream &os) const override; +}; // class BinaryInst + +//! The return statement +class ReturnInst : public Instruction { + friend class IRBuilder; + +protected: + ReturnInst(Value *value = nullptr, BasicBlock *parent = nullptr) + : Instruction(kReturn, Type::getVoidType(), parent, "") { + if (value) + addOperand(value); + } + +public: + static bool classof(const Value *value) { + return value->getKind() == kReturn; + } + +public: + bool hasReturnValue() const { return not operands.empty(); } + Value *getReturnValue() const { + return hasReturnValue() ? getOperand(0) : nullptr; + } + +public: + void print(std::ostream &os) const override; +}; // class ReturnInst + +//! Unconditional branch +class UncondBrInst : public Instruction { + friend class IRBuilder; + +protected: + UncondBrInst(BasicBlock *block, std::vector args, + BasicBlock *parent = nullptr) + : Instruction(kBr, Type::getVoidType(), parent, "") { + // assert(block->getNumArguments() == args.size()); + addOperand(block); + addOperands(args); + } + +public: + static bool classof(const Value *value) { return value->getKind() == kBr; } + +public: + BasicBlock *getBlock() const { return dyncast(getOperand(0)); } + auto getArguments() const { + return make_range(std::next(operand_begin()), operand_end()); + } + +public: + void print(std::ostream &os) const override; +}; // class UncondBrInst + +//! Conditional branch +class CondBrInst : public Instruction { + friend class IRBuilder; + +protected: + CondBrInst(Value *condition, BasicBlock *thenBlock, BasicBlock *elseBlock, + const std::vector &thenArgs, + const std::vector &elseArgs, BasicBlock *parent = nullptr) + : Instruction(kCondBr, Type::getVoidType(), parent, "") { + // assert(thenBlock->getNumArguments() == thenArgs.size() and + // elseBlock->getNumArguments() == elseArgs.size()); + addOperand(condition); + addOperand(thenBlock); + addOperand(elseBlock); + addOperands(thenArgs); + addOperands(elseArgs); + } + +public: + static bool classof(const Value *value) { + return value->getKind() == kCondBr; + } + +public: + Value *getCondition() const { return getOperand(0); } + BasicBlock *getThenBlock() const { + return dyncast(getOperand(1)); + } + BasicBlock *getElseBlock() const { + return dyncast(getOperand(2)); + } + auto getThenArguments() const { + auto begin = std::next(operand_begin(), 3); + auto end = std::next(begin, getThenBlock()->getNumArguments()); + return make_range(begin, end); + } + auto getElseArguments() const { + auto begin = + std::next(operand_begin(), 3 + getThenBlock()->getNumArguments()); + auto end = operand_end(); + return make_range(begin, end); + } + +public: + void print(std::ostream &os) const override; +}; // class CondBrInst + +//! Allocate memory for stack variables, used for non-global variable declartion +class AllocaInst : public Instruction { + friend class IRBuilder; + +protected: + AllocaInst(Type *type, const std::vector &dims = {}, + BasicBlock *parent = nullptr, const std::string &name = "") + : Instruction(kAlloca, type, parent, name) { + addOperands(dims); + } + +public: + static bool classof(const Value *value) { + return value->getKind() == kAlloca; + } + +public: + int getNumDims() const { return getNumOperands(); } + auto getDims() const { return getOperands(); } + Value *getDim(int index) { return getOperand(index); } + +public: + void print(std::ostream &os) const override; +}; // class AllocaInst + +//! Load a value from memory address specified by a pointer value +class LoadInst : public Instruction { + friend class IRBuilder; + +protected: + LoadInst(Value *pointer, const std::vector &indices = {}, + BasicBlock *parent = nullptr, const std::string &name = "") + : Instruction(kLoad, pointer->getType()->as()->getBaseType(), + parent, name) { + addOperand(pointer); + addOperands(indices); + } + +public: + static bool classof(const Value *value) { return value->getKind() == kLoad; } + +public: + int getNumIndices() const { return getNumOperands() - 1; } + Value *getPointer() const { return getOperand(0); } + auto getIndices() const { + return make_range(std::next(operand_begin()), operand_end()); + } + Value *getIndex(int index) const { return getOperand(index + 1); } + +public: + void print(std::ostream &os) const override; +}; // class LoadInst + +//! Store a value to memory address specified by a pointer value +class StoreInst : public Instruction { + friend class IRBuilder; + +protected: + StoreInst(Value *value, Value *pointer, + const std::vector &indices = {}, + BasicBlock *parent = nullptr, const std::string &name = "") + : Instruction(kStore, Type::getVoidType(), parent, name) { + addOperand(value); + addOperand(pointer); + addOperands(indices); + } + +public: + static bool classof(const Value *value) { return value->getKind() == kStore; } + +public: + int getNumIndices() const { return getNumOperands() - 2; } + Value *getValue() const { return getOperand(0); } + Value *getPointer() const { return getOperand(1); } + auto getIndices() const { + return make_range(std::next(operand_begin(), 2), operand_end()); + } + Value *getIndex(int index) const { return getOperand(index + 2); } + +public: + void print(std::ostream &os) const override; +}; // class StoreInst + +//! Get address instruction +class LaInst : public Instruction { + friend class IRBuilder; + +protected: + LaInst(Value *pointer, const std::vector &indices = {}, + BasicBlock *parent = nullptr, const std::string &name = "") + : Instruction(kLa, pointer->getType(), parent, name) { + assert(pointer); + addOperand(pointer); + addOperands(indices); + } + +public: + static bool classof(const Value *value) { return value->getKind() == kLa; } + +public: + int getNumIndices() const { return getNumOperands() - 1; } + Value *getPointer() const { return getOperand(0); } + auto getIndices() const { + return make_range(std::next(operand_begin()), operand_end()); + } + Value *getIndex(int index) const { return getOperand(index + 1); } + +public: + void print(std::ostream &os) const override; +}; + +//! Memset instruction +class MemsetInst : public Instruction { + friend class IRBuilder; + +protected: + MemsetInst(Value *pointer, Value *begin, Value *size, Value *value, + BasicBlock *parent = nullptr, const std::string &name = "") + : Instruction(kMemset, Type::getVoidType(), parent, name) { + addOperand(pointer); + addOperand(begin); + addOperand(size); + addOperand(value); + } + +public: + static bool classof(const Value *value) { return value->getKind() == kMemset; } + +public: + Value *getPointer() const { return getOperand(0); } + Value *getBegin() const { return getOperand(1); } + Value *getSize() const { return getOperand(2); } + Value *getValue() const { return getOperand(3); } + +public: + void print(std::ostream &os) const override; +}; + +//! Get subarray instruction +class GetSubArrayInst : public Instruction { + friend class IRBuilder; + +protected: + GetSubArrayInst(Value *fatherArray, Value *childArray, + const std::vector &indices, + BasicBlock *parent = nullptr, const std::string &name = "") + : Instruction(kGetSubArray, Type::getVoidType(), parent, name) { + addOperand(fatherArray); + addOperand(childArray); + addOperands(indices); + } + +public: + static bool classof(const Value *value) { + return value->getKind() == kGetSubArray; + } + +public: + Value *getFatherArray() const { return getOperand(0); } + Value *getChildArray() const { return getOperand(1); } + int getNumIndices() const { return getNumOperands() - 2; } + auto getIndices() const { + return make_range(std::next(operand_begin(), 2), operand_end()); + } + Value *getIndex(int index) const { return getOperand(index + 2); } + +public: + void print(std::ostream &os) const override; +}; + +//! Phi instruction for SSA form +class PhiInst : public Instruction { + friend class IRBuilder; + +protected: + Value *map_val; // 旧的映射关系 + + PhiInst(Type *type, Value *lhs, const std::vector &rhs, + Value *mval, BasicBlock *parent, const std::string &name = "") + : Instruction(kPhi, type, parent, name), map_val(mval) { + addOperand(lhs); + addOperands(rhs); + } + +public: + static bool classof(const Value *value) { return value->getKind() == kPhi; } + +public: + Value *getMapVal() const { return map_val; } + Value *getPointer() const { return getOperand(0); } + auto getValues() const { + return make_range(std::next(operand_begin()), operand_end()); + } + Value *getValue(unsigned index) const { return getOperand(index + 1); } + +public: + void print(std::ostream &os) const override; +}; + +class Loop { +public: + using block_list = std::vector; + using block_set = std::unordered_set; + using Loop_list = std::vector; + +protected: + Function *parent; // 所属函数 + block_list blocksInLoop; // 循环内的基本块 + BasicBlock *preheaderBlock = nullptr; // 前驱块 + BasicBlock *headerBlock = nullptr; // 循环头 + block_list latchBlock; // 回边块 + block_set exitingBlocks; // 退出块 + block_set exitBlocks; // 退出目标块 + Loop *parentloop = nullptr; // 父循环 + Loop_list subLoops; // 子循环 + size_t loopID; // 循环ID + unsigned loopDepth; // 循环深度 + + Instruction *indCondVar = nullptr; // 循环条件变量 + Instruction::Kind IcmpKind; // 比较类型 + Value *indEnd = nullptr; // 循环结束值 + AllocaInst *IndPhi = nullptr; // 循环变量 + + ConstantValue *indBegin = nullptr; // 循环起始值 + ConstantValue *indStep = nullptr; // 循环步长 + + int StepType = 0; // 循环步长类型 + bool parallelable = false; // 是否可并行 + +public: + explicit Loop(BasicBlock *header, const std::string &name = "") + : headerBlock(header) { + blocksInLoop.push_back(header); + } + + void setloopID() { + static unsigned loopCount = 0; + loopCount = loopCount + 1; + loopID = loopCount; + } + + BasicBlock *getHeader() const { return headerBlock; } + BasicBlock *getPreheaderBlock() const { return preheaderBlock; } + block_list &getLatchBlocks() { return latchBlock; } + block_set &getExitingBlocks() { return exitingBlocks; } + block_set &getExitBlocks() { return exitBlocks; } + Loop *getParentLoop() const { return parentloop; } + void setParentLoop(Loop *parent) { parentloop = parent; } + void addBasicBlock(BasicBlock *bb) { blocksInLoop.push_back(bb); } + void addSubLoop(Loop *loop) { subLoops.push_back(loop); } + void setLoopDepth(unsigned depth) { loopDepth = depth; } + block_list &getBasicBlocks() { return blocksInLoop; } + Loop_list &getSubLoops() { return subLoops; } + unsigned getLoopDepth() const { return loopDepth; } + + bool contains(BasicBlock *bb) const { + return std::find(blocksInLoop.begin(), blocksInLoop.end(), bb) != blocksInLoop.end(); + } + + void addExitingBlock(BasicBlock *bb) { exitingBlocks.insert(bb); } + void addExitBlock(BasicBlock *bb) { exitBlocks.insert(bb); } + void addLatchBlock(BasicBlock *bb) { latchBlock.push_back(bb); } + void setPreheaderBlock(BasicBlock *bb) { preheaderBlock = bb; } + + void setIndexCondInstr(Instruction *instr) { indCondVar = instr; } + void setIcmpKind(Instruction::Kind kind) { IcmpKind = kind; } + Instruction::Kind getIcmpKind() const { return IcmpKind; } + + void setIndEnd(Value *value) { indEnd = value; } + void setIndPhi(AllocaInst *phi) { IndPhi = phi; } + Value *getIndEnd() const { return indEnd; } + AllocaInst *getIndPhi() const { return IndPhi; } + Instruction *getIndCondVar() const { return indCondVar; } + + void setParallelable(bool flag) { parallelable = flag; } + bool isParallelable() const { return parallelable; } +}; + +class Module; +//! Function definition +class Function : public Value { + friend class Module; + +protected: + Function(Module *parent, Type *type, const std::string &name) + : Value(kFunction, type, name), parent(parent), variableID(0), blockID(0), blocks() { + blocks.emplace_back(new BasicBlock(this, "entry")); + } + +public: + static bool classof(const Value *value) { + return value->getKind() == kFunction; + } + +public: + using block_list = std::list>; + using Loop_list = std::list>; + +protected: + Module *parent; + int variableID; + int blockID; + block_list blocks; + /*是放在module中还是新建分析器呢?*/ + Loop_list loops; // 循环列表 + Loop_list topLoops; // 顶层循环 + std::unordered_map basicblock2Loop; // 基本块到循环的映射 + + // 数据流分析相关 + std::unordered_map value2AllocBlocks; + std::unordered_map> value2DefBlocks; + std::unordered_map> value2UseBlocks; + +public: + Type *getReturnType() const { + return getType()->as()->getReturnType(); + } + auto getParamTypes() const { + return getType()->as()->getParamTypes(); + } + auto getBasicBlocks() const { return make_range(blocks); } + BasicBlock *getEntryBlock() const { return blocks.front().get(); } + BasicBlock *addBasicBlock(const std::string &name = "") { + blocks.emplace_back(new BasicBlock(this, name)); + return blocks.back().get(); + } + void removeBasicBlock(BasicBlock *block) { + blocks.remove_if([&](std::unique_ptr &b) -> bool { + return block == b.get(); + }); + } + int allocateVariableID() { return variableID++; } + int allocateblockID() { return blockID++; } + + // 循环分析 + void addLoop(Loop *loop) { loops.emplace_back(loop); } + void addTopLoop(Loop *loop) { topLoops.emplace_back(loop); } + Loop_list &getLoops() { return loops; } + Loop_list &getTopLoops() { return topLoops; } + Loop *getLoopOfBasicBlock(BasicBlock *bb) { + return basicblock2Loop.count(bb) ? basicblock2Loop[bb] : nullptr; + } + void addBBToLoop(BasicBlock *bb, Loop *loop) { basicblock2Loop[bb] = loop; } + + // 数据流分析 + void addValue2AllocBlocks(Value *value, BasicBlock *block) { + value2AllocBlocks[value] = block; + } + BasicBlock *getAllocBlockByValue(Value *value) { + return value2AllocBlocks.count(value) ? value2AllocBlocks[value] : nullptr; + } + void addValue2DefBlocks(Value *value, BasicBlock *block) { + ++value2DefBlocks[value][block]; + } + void addValue2UseBlocks(Value *value, BasicBlock *block) { + ++value2UseBlocks[value][block]; + } + +public: + void print(std::ostream &os) const override; +}; // class Function + +//! Global value declared at file scope +class GlobalValue : public User { + friend class Module; + +protected: + Module *parent; + std::vector initValues; // 初始值列表 + bool isConst; + +protected: + GlobalValue(Module *parent, Type *type, const std::string &name, + const std::vector &dims = {}, + const std::vector &initValues = {}, + bool isConst = false) + : User(kGlobal, type, name), parent(parent), + initValues(initValues), isConst(isConst) { + assert(type->isPointer()); + addOperands(dims); + } + +public: + static bool classof(const Value *value) { + return value->getKind() == kGlobal; + } + +public: + const std::vector& getInitValues() const { return initValues; } + int getNumDims() const { return getNumOperands(); } + Value *getDim(int index) { return getOperand(index); } + bool isConstant() const { return isConst; } + +public: + void print(std::ostream &os) const override{}; +}; // class GlobalValue + +//! IR unit for representing a SysY compile unit +class Module { +protected: + std::vector> children; + std::map functions; + std::map globals; + std::map externalFunctions; // 外部函数声明 + +public: + Module() = default; + +public: + Function *createFunction(const std::string &name, Type *type) { + if (functions.count(name)) + return nullptr; + auto func = new Function(this, type, name); + assert(func); + children.emplace_back(func); + functions.emplace(name, func); + return func; + }; + + Function *createExternalFunction(const std::string &name, Type *type) { + if (externalFunctions.count(name)) + return nullptr; + auto func = new Function(this, type, name); + assert(func); + children.emplace_back(func); + externalFunctions.emplace(name, func); + return func; + } + + GlobalValue *createGlobalValue(const std::string &name, Type *type, + const std::vector &dims = {}, + const std::vector &initValues = {}, + bool isConst = false) { + if (globals.count(name)) + return nullptr; + auto global = new GlobalValue(this, type, name, dims, initValues, isConst); + assert(global); + children.emplace_back(global); + globals.emplace(name, global); + return global; + } + + Function *getFunction(const std::string &name) const { + auto result = functions.find(name); + if (result == functions.end()) + return nullptr; + return result->second; + } + + Function *getExternalFunction(const std::string &name) const { + auto result = externalFunctions.find(name); + if (result == externalFunctions.end()) + return nullptr; + return result->second; + } + + GlobalValue *getGlobalValue(const std::string &name) const { + auto result = globals.find(name); + if (result == globals.end()) + return nullptr; + return result->second; + } + + std::map *getFunctions() { return &functions; } + std::map *getGlobalValues() { return &globals; } + std::map *getExternalFunctions() { return &externalFunctions; } + +public: + void print(std::ostream &os) const; +}; // class Module + +/*! + * @} + */ +inline std::ostream &operator<<(std::ostream &os, const Type &type) { + type.print(os); + return os; +} + +inline std::ostream &operator<<(std::ostream &os, const Value &value) { + value.print(os); + return os; +} + +} // namespace sysy diff --git a/src/IRBuilder.h b/src/include/IRBuilder.h similarity index 91% rename from src/IRBuilder.h rename to src/include/IRBuilder.h index 60cb092..c6d12da 100644 --- a/src/IRBuilder.h +++ b/src/include/IRBuilder.h @@ -8,14 +8,23 @@ namespace sysy { class IRBuilder { private: - BasicBlock *block; - BasicBlock::iterator position; + unsigned labelIndex; ///< 基本块标签编号 + unsigned tmpIndex; ///< 临时变量编号 + + BasicBlock *block; ///< 当前基本块 + BasicBlock::iterator position; ///< 当前基本块指令列表位置的迭代器 + + std::vector trueBlocks; ///< true分支基本块列表 + std::vector falseBlocks; ///< false分支基本块列表 + + std::vector breakBlocks; ///< break目标块列表 + std::vector continueBlocks; ///< continue目标块列表 public: - IRBuilder() = default; - IRBuilder(BasicBlock *block) : block(block), position(block->end()) {} + IRBuilder() : labelIndex(0), tmpIndex(0), block(nullptr) {} + explicit IRBuilder(BasicBlock *block) : labelIndex(0), tmpIndex(0), block(block), position(block->end()) {} IRBuilder(BasicBlock *block, BasicBlock::iterator position) - : block(block), position(position) {} + : labelIndex(0), tmpIndex(0), block(block), position(position) {} public: BasicBlock *getBasicBlock() const { return block; } @@ -60,7 +69,7 @@ public: name); } UnaryInst *createIToFInst(Value *operand, const std::string &name = "") { - return createUnaryInst(Instruction::kIToF, Type::getFloatType(), operand, + return createUnaryInst(Instruction::kItoF, Type::getFloatType(), operand, name); } BinaryInst *createBinaryInst(Instruction::Kind kind, Type *type, Value *lhs, diff --git a/src/LLVMIRGenerator.h b/src/include/LLVMIRGenerator.h similarity index 100% rename from src/LLVMIRGenerator.h rename to src/include/LLVMIRGenerator.h diff --git a/src/LLVMIRGenerator_1.h b/src/include/LLVMIRGenerator_1.h similarity index 100% rename from src/LLVMIRGenerator_1.h rename to src/include/LLVMIRGenerator_1.h diff --git a/src/SysYFormatter.h b/src/include/SysYFormatter.h similarity index 100% rename from src/SysYFormatter.h rename to src/include/SysYFormatter.h diff --git a/src/include/SysYIRAnalyser.h b/src/include/SysYIRAnalyser.h new file mode 100644 index 0000000..e69de29 diff --git a/src/SysYIRGenerator.h b/src/include/SysYIRGenerator.h similarity index 90% rename from src/SysYIRGenerator.h rename to src/include/SysYIRGenerator.h index 3c89ce0..638986a 100644 --- a/src/SysYIRGenerator.h +++ b/src/include/SysYIRGenerator.h @@ -134,15 +134,6 @@ public: std::any visitLOrExp(SysYParser::LOrExpContext *ctx) override; std::any visitConstExp(SysYParser::ConstExpContext *ctx) override; -private: - std::any visitConstGlobalDecl(SysYParser::ConstDeclContext *ctx, Type* type); - std::any visitVarGlobalDecl(SysYParser::VarDeclContext *ctx, Type* type); - std::any visitConstLocalDecl(SysYParser::ConstDeclContext *ctx, Type* type); - std::any visitVarLocalDecl(SysYParser::VarDeclContext *ctx, Type* type); - Type *getArithmeticResultType(Type *lhs, Type *rhs) { - assert(lhs->isIntOrFloat() and rhs->isIntOrFloat()); - return lhs == rhs ? lhs : Type::getFloatType(); - } }; // class SysYIRGenerator diff --git a/src/range.h b/src/include/range.h similarity index 100% rename from src/range.h rename to src/include/range.h From 30f89bba23defec8f53b577458af45f738fec1f9 Mon Sep 17 00:00:00 2001 From: rain2133 <1370973498@qq.com> Date: Sat, 21 Jun 2025 12:53:41 +0800 Subject: [PATCH 04/13] =?UTF-8?q?=E6=9B=B4=E6=96=B0IR=E7=BB=93=E6=9E=84?= =?UTF-8?q?=EF=BC=8C=E9=87=8D=E5=86=99IRBuilder?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/IR.cpp | 963 +++++++++++++++++++-------- src/include/IR.h | 1368 +++++++++++++++++++++++++-------------- src/include/IRBuilder.h | 477 ++++++++------ 3 files changed, 1869 insertions(+), 939 deletions(-) diff --git a/src/IR.cpp b/src/IR.cpp index 1564415..ebd61c1 100644 --- a/src/IR.cpp +++ b/src/IR.cpp @@ -1,310 +1,719 @@ -#pragma once - #include "IR.h" +#include #include #include +#include +#include #include -#include #include +#include "IRBuilder.h" +/** + * @file IR.cpp + * + * @brief 定义IR相关类型与操作的源文件 + */ namespace sysy { -class IRBuilder { -private: - unsigned labelIndex; ///< 基本块标签编号 - unsigned tmpIndex; ///< 临时变量编号 +//===----------------------------------------------------------------------===// +// Types +//===----------------------------------------------------------------------===// - BasicBlock *block; ///< 当前基本块 - BasicBlock::iterator position; ///< 当前基本块指令列表位置的迭代器 +auto Type::getIntType() -> Type * { + static Type intType(kInt); + return &intType; +} - std::vector trueBlocks; ///< true分支基本块列表 - std::vector falseBlocks; ///< false分支基本块列表 +auto Type::getFloatType() -> Type * { + static Type floatType(kFloat); + return &floatType; +} - std::vector breakBlocks; ///< break目标块列表 - std::vector continueBlocks; ///< continue目标块列表 +auto Type::getVoidType() -> Type * { + static Type voidType(kVoid); + return &voidType; +} -public: - IRBuilder() : labelIndex(0), tmpIndex(0), block(nullptr) {} - explicit IRBuilder(BasicBlock *block) : labelIndex(0), tmpIndex(0), block(block), position(block->end()) {} - IRBuilder(BasicBlock *block, BasicBlock::iterator position) - : labelIndex(0), tmpIndex(0), block(block), position(position) {} +auto Type::getLabelType() -> Type * { + static Type labelType(kLabel); + return &labelType; +} -public: - unsigned getLabelIndex() { return labelIndex++; } - unsigned getTmpIndex() { return tmpIndex++; } - - BasicBlock *getBasicBlock() const { return block; } - BasicBlock::iterator getPosition() const { return position; } - - void setPosition(BasicBlock *block, BasicBlock::iterator position) { - this->block = block; - this->position = position; +auto Type::getPointerType(Type *baseType) -> Type * { + // forward to PointerType + return PointerType::get(baseType); +} + +auto Type::getFunctionType(Type *returnType, const std::vector ¶mTypes) -> Type * { + // forward to FunctionType + return FunctionType::get(returnType, paramTypes); +} + +auto Type::getSize() const -> unsigned { + switch (kind) { + case kInt: + case kFloat: + return 4; + case kLabel: + case kPointer: + case kFunction: + return 8; + case kVoid: + return 0; } - void setPosition(BasicBlock::iterator position) { this->position = position; } + return 0; +} - // 控制流管理函数 - BasicBlock *getBreakBlock() const { return breakBlocks.back(); } - BasicBlock *popBreakBlock() { - auto result = breakBlocks.back(); - breakBlocks.pop_back(); - return result; +PointerType* PointerType::get(Type *baseType) { + static std::map> pointerTypes; + auto iter = pointerTypes.find(baseType); + if (iter != pointerTypes.end()) { + return iter->second.get(); } - BasicBlock *getContinueBlock() const { return continueBlocks.back(); } - BasicBlock *popContinueBlock() { - auto result = continueBlocks.back(); - continueBlocks.pop_back(); - return result; - } - BasicBlock *getTrueBlock() const { return trueBlocks.back(); } - BasicBlock *getFalseBlock() const { return falseBlocks.back(); } - BasicBlock *popTrueBlock() { - auto result = trueBlocks.back(); - trueBlocks.pop_back(); - return result; - } - BasicBlock *popFalseBlock() { - auto result = falseBlocks.back(); - falseBlocks.pop_back(); - return result; - } - void pushBreakBlock(BasicBlock *block) { breakBlocks.push_back(block); } - void pushContinueBlock(BasicBlock *block) { continueBlocks.push_back(block); } - void pushTrueBlock(BasicBlock *block) { trueBlocks.push_back(block); } - void pushFalseBlock(BasicBlock *block) { falseBlocks.push_back(block); } + auto type = new PointerType(baseType); + assert(type); + auto result = pointerTypes.emplace(baseType, type); + return result.first->second.get(); +} -public: - // 指令创建函数 - Instruction *insertInst(Instruction *inst) { - assert(inst); - block->getInstructions().emplace(position, inst); - return inst; +FunctionType*FunctionType::get(Type *returnType, const std::vector ¶mTypes) { + static std::set> functionTypes; + auto iter = + std::find_if(functionTypes.begin(), functionTypes.end(), [&](const std::unique_ptr &type) -> bool { + if (returnType != type->getReturnType() || + paramTypes.size() != static_cast(type->getParamTypes().size())) { + return false; + } + return std::equal(paramTypes.begin(), paramTypes.end(), type->getParamTypes().begin()); + }); + if (iter != functionTypes.end()) { + return iter->get(); + } + auto type = new FunctionType(returnType, paramTypes); + assert(type); + auto result = functionTypes.emplace(type); + return result.first->get(); +} + +void Value::replaceAllUsesWith(Value *value) { + for (auto &use : uses) { + use->getUser()->setOperand(use->getIndex(), value); + } + uses.clear(); +} + +ConstantValue* ConstantValue::get(int value) { + static std::map> intConstants; + auto iter = intConstants.find(value); + if (iter != intConstants.end()) { + return iter->second.get(); + } + auto inst = new ConstantValue(value); + assert(inst); + auto result = intConstants.emplace(value, inst); + return result.first->second.get(); +} + +ConstantValue* ConstantValue::get(float value) { + static std::map> floatConstants; + auto iter = floatConstants.find(value); + if (iter != floatConstants.end()) { + return iter->second.get(); + } + auto inst = new ConstantValue(value); + assert(inst); + auto result = floatConstants.emplace(value, inst); + return result.first->second.get(); +} + +auto Function::getCalleesWithNoExternalAndSelf() -> std::set { + std::set result; + for (auto callee : callees) { + if (parent->getExternalFunctions().count(callee->getName()) == 0U && callee != this) { + result.insert(callee); + } + } + return result; +} + +Function * Function::clone(const std::string &suffix) const { + std::stringstream ss; + std::map oldNewBlockMap; + IRBuilder builder; + auto newFunction = new Function(parent, type, name); + newFunction->getEntryBlock()->setName(blocks.front()->getName()); + oldNewBlockMap.emplace(blocks.front().get(), newFunction->getEntryBlock()); + auto oldBlockListIter = std::next(blocks.begin()); + while (oldBlockListIter != blocks.end()) { + auto newBlock = newFunction->addBasicBlock(oldBlockListIter->get()->getName()); + oldNewBlockMap.emplace(oldBlockListIter->get(), newBlock); + oldBlockListIter++; } - UnaryInst *createUnaryInst(Instruction::Kind kind, Type *type, Value *operand, - const std::string &name = "") { - auto inst = new UnaryInst(kind, type, operand, block, name); - return static_cast(insertInst(inst)); + for (const auto &oldNewBlockItem : oldNewBlockMap) { + auto oldBlock = oldNewBlockItem.first; + auto newBlock = oldNewBlockItem.second; + for (const auto &oldPred : oldBlock->getPredecessors()) { + newBlock->addPredecessor(oldNewBlockMap.at(oldPred)); + } + for (const auto &oldSucc : oldBlock->getSuccessors()) { + newBlock->addSuccessor(oldNewBlockMap.at(oldSucc)); + } } - UnaryInst *createNegInst(Value *operand, const std::string &name = "") { - return createUnaryInst(Instruction::kNeg, Type::getIntType(), operand, name); + std::map oldNewValueMap; + std::map isAddedToCreate; + std::map isCreated; + std::queue toCreate; + + for (const auto &oldBlock : blocks) { + for (const auto &inst : oldBlock->getInstructions()) { + isAddedToCreate.emplace(inst.get(), false); + isCreated.emplace(inst.get(), false); + } + } + for (const auto &oldBlock : blocks) { + for (const auto &inst : oldBlock->getInstructions()) { + for (const auto &valueUse : inst->getOperands()) { + auto value = valueUse->getValue(); + if (oldNewValueMap.find(value) == oldNewValueMap.end()) { + auto oldAllocInst = dynamic_cast(value); + if (oldAllocInst != nullptr) { + std::vector dims; + for (const auto &dim : oldAllocInst->getDims()) { + dims.emplace_back(dim->getValue()); + } + ss << oldAllocInst->getName() << suffix; + auto newAllocInst = + new AllocaInst(oldAllocInst->getType(), dims, oldNewBlockMap.at(oldAllocInst->getParent()), ss.str()); + ss.str(""); + oldNewValueMap.emplace(oldAllocInst, newAllocInst); + if (isAddedToCreate.find(oldAllocInst) == isAddedToCreate.end()) { + isAddedToCreate.emplace(oldAllocInst, true); + } else { + isAddedToCreate.at(oldAllocInst) = true; + } + if (isCreated.find(oldAllocInst) == isCreated.end()) { + isCreated.emplace(oldAllocInst, true); + } else { + isCreated.at(oldAllocInst) = true; + } + } + } + } + if (inst->getKind() == Instruction::kAlloca) { + if (oldNewValueMap.find(inst.get()) == oldNewValueMap.end()) { + auto oldAllocInst = dynamic_cast(inst.get()); + std::vector dims; + for (const auto &dim : oldAllocInst->getDims()) { + dims.emplace_back(dim->getValue()); + } + ss << oldAllocInst->getName() << suffix; + auto newAllocInst = + new AllocaInst(oldAllocInst->getType(), dims, oldNewBlockMap.at(oldAllocInst->getParent()), ss.str()); + ss.str(""); + oldNewValueMap.emplace(oldAllocInst, newAllocInst); + if (isAddedToCreate.find(oldAllocInst) == isAddedToCreate.end()) { + isAddedToCreate.emplace(oldAllocInst, true); + } else { + isAddedToCreate.at(oldAllocInst) = true; + } + if (isCreated.find(oldAllocInst) == isCreated.end()) { + isCreated.emplace(oldAllocInst, true); + } else { + isCreated.at(oldAllocInst) = true; + } + } + } + } + } + for (const auto &oldBlock : blocks) { + for (const auto &inst : oldBlock->getInstructions()) { + for (const auto &valueUse : inst->getOperands()) { + auto value = valueUse->getValue(); + if (oldNewValueMap.find(value) == oldNewValueMap.end()) { + auto globalValue = dynamic_cast(value); + auto constVariable = dynamic_cast(value); + auto constantValue = dynamic_cast(value); + auto functionValue = dynamic_cast(value); + if (globalValue != nullptr || constantValue != nullptr || constVariable != nullptr || + functionValue != nullptr) { + if (functionValue == this) { + oldNewValueMap.emplace(value, newFunction); + } else { + oldNewValueMap.emplace(value, value); + } + isCreated.emplace(value, true); + isAddedToCreate.emplace(value, true); + } + } + } + } + } + for (const auto &oldBlock : blocks) { + for (const auto &inst : oldBlock->getInstructions()) { + if (inst->getKind() != Instruction::kAlloca) { + bool isReady = true; + for (const auto &use : inst->getOperands()) { + auto value = use->getValue(); + if (dynamic_cast(value) == nullptr && !isCreated.at(value)) { + isReady = false; + break; + } + } + if (isReady) { + toCreate.push(inst.get()); + isAddedToCreate.at(inst.get()) = true; + } + } + } } - UnaryInst *createNotInst(Value *operand, const std::string &name = "") { - return createUnaryInst(Instruction::kNot, Type::getIntType(), operand, name); - } + while (!toCreate.empty()) { + auto inst = dynamic_cast(toCreate.front()); + toCreate.pop(); - UnaryInst *createFtoIInst(Value *operand, const std::string &name = "") { - return createUnaryInst(Instruction::kFtoI, Type::getIntType(), operand, name); - } - - UnaryInst *createBitFtoIInst(Value *operand, const std::string &name = "") { - return createUnaryInst(Instruction::kBitFtoI, Type::getIntType(), operand, name); - } - - UnaryInst *createFNegInst(Value *operand, const std::string &name = "") { - return createUnaryInst(Instruction::kFNeg, Type::getFloatType(), operand, name); - } - - UnaryInst *createFNotInst(Value *operand, const std::string &name = "") { - return createUnaryInst(Instruction::kFNot, Type::getIntType(), operand, name); - } - - UnaryInst *createIToFInst(Value *operand, const std::string &name = "") { - return createUnaryInst(Instruction::kItoF, Type::getFloatType(), operand, name); - } - - UnaryInst *createBitItoFInst(Value *operand, const std::string &name = "") { - return createUnaryInst(Instruction::kBitItoF, Type::getFloatType(), operand, name); - } - - BinaryInst *createBinaryInst(Instruction::Kind kind, Type *type, Value *lhs, - Value *rhs, const std::string &name = "") { - auto inst = new BinaryInst(kind, type, lhs, rhs, block, name); - return static_cast(insertInst(inst)); - } - - BinaryInst *createAddInst(Value *lhs, Value *rhs, const std::string &name = "") { - return createBinaryInst(Instruction::kAdd, Type::getIntType(), lhs, rhs, name); - } - - BinaryInst *createSubInst(Value *lhs, Value *rhs, const std::string &name = "") { - return createBinaryInst(Instruction::kSub, Type::getIntType(), lhs, rhs, name); - } - - BinaryInst *createMulInst(Value *lhs, Value *rhs, const std::string &name = "") { - return createBinaryInst(Instruction::kMul, Type::getIntType(), lhs, rhs, name); - } - - BinaryInst *createDivInst(Value *lhs, Value *rhs, const std::string &name = "") { - return createBinaryInst(Instruction::kDiv, Type::getIntType(), lhs, rhs, name); - } - - BinaryInst *createRemInst(Value *lhs, Value *rhs, const std::string &name = "") { - return createBinaryInst(Instruction::kRem, Type::getIntType(), lhs, rhs, name); - } - - BinaryInst *createICmpEQInst(Value *lhs, Value *rhs, const std::string &name = "") { - return createBinaryInst(Instruction::kICmpEQ, Type::getIntType(), lhs, rhs, name); - } - - BinaryInst *createICmpNEInst(Value *lhs, Value *rhs, const std::string &name = "") { - return createBinaryInst(Instruction::kICmpNE, Type::getIntType(), lhs, rhs, name); - } - - BinaryInst *createICmpLTInst(Value *lhs, Value *rhs, const std::string &name = "") { - return createBinaryInst(Instruction::kICmpLT, Type::getIntType(), lhs, rhs, name); - } - - BinaryInst *createICmpLEInst(Value *lhs, Value *rhs, const std::string &name = "") { - return createBinaryInst(Instruction::kICmpLE, Type::getIntType(), lhs, rhs, name); - } - - BinaryInst *createICmpGTInst(Value *lhs, Value *rhs, const std::string &name = "") { - return createBinaryInst(Instruction::kICmpGT, Type::getIntType(), lhs, rhs, name); - } - - BinaryInst *createICmpGEInst(Value *lhs, Value *rhs, const std::string &name = "") { - return createBinaryInst(Instruction::kICmpGE, Type::getIntType(), lhs, rhs, name); - } - - BinaryInst *createFAddInst(Value *lhs, Value *rhs, const std::string &name = "") { - return createBinaryInst(Instruction::kFAdd, Type::getFloatType(), lhs, rhs, name); - } - - BinaryInst *createFSubInst(Value *lhs, Value *rhs, const std::string &name = "") { - return createBinaryInst(Instruction::kFSub, Type::getFloatType(), lhs, rhs, name); - } - - BinaryInst *createFMulInst(Value *lhs, Value *rhs, const std::string &name = "") { - return createBinaryInst(Instruction::kFMul, Type::getFloatType(), lhs, rhs, name); - } - - BinaryInst *createFDivInst(Value *lhs, Value *rhs, const std::string &name = "") { - return createBinaryInst(Instruction::kFDiv, Type::getFloatType(), lhs, rhs, name); - } - - BinaryInst *createFCmpEQInst(Value *lhs, Value *rhs, const std::string &name = "") { - return createBinaryInst(Instruction::kFCmpEQ, Type::getIntType(), lhs, rhs, name); - } - - BinaryInst *createFCmpNEInst(Value *lhs, Value *rhs, const std::string &name = "") { - return createBinaryInst(Instruction::kFCmpNE, Type::getIntType(), lhs, rhs, name); - } - - BinaryInst *createFCmpLTInst(Value *lhs, Value *rhs, const std::string &name = "") { - return createBinaryInst(Instruction::kFCmpLT, Type::getIntType(), lhs, rhs, name); - } - - BinaryInst *createFCmpLEInst(Value *lhs, Value *rhs, const std::string &name = "") { - return createBinaryInst(Instruction::kFCmpLE, Type::getIntType(), lhs, rhs, name); - } - - BinaryInst *createFCmpGTInst(Value *lhs, Value *rhs, const std::string &name = "") { - return createBinaryInst(Instruction::kFCmpGT, Type::getIntType(), lhs, rhs, name); - } - - BinaryInst *createFCmpGEInst(Value *lhs, Value *rhs, const std::string &name = "") { - return createBinaryInst(Instruction::kFCmpGE, Type::getIntType(), lhs, rhs, name); - } - - BinaryInst *createAndInst(Value *lhs, Value *rhs, const std::string &name = "") { - return createBinaryInst(Instruction::kAnd, Type::getIntType(), lhs, rhs, name); - } - - BinaryInst *createOrInst(Value *lhs, Value *rhs, const std::string &name = "") { - return createBinaryInst(Instruction::kOr, Type::getIntType(), lhs, rhs, name); - } - - CallInst *createCallInst(Function *callee, const std::vector &args, - const std::string &name = "") { - auto inst = new CallInst(callee, args, block, name); - return static_cast(insertInst(inst)); - } - - ReturnInst *createReturnInst(Value *value = nullptr) { - auto inst = new ReturnInst(value, block); - return static_cast(insertInst(inst)); - } - - UncondBrInst *createUncondBrInst(BasicBlock *thenBlock, - const std::vector &args) { - auto inst = new UncondBrInst(thenBlock, args, block); - return static_cast(insertInst(inst)); - } - - CondBrInst *createCondBrInst(Value *condition, BasicBlock *thenBlock, - BasicBlock *elseBlock, - const std::vector &thenArgs, - const std::vector &elseArgs) { - auto inst = new CondBrInst(condition, thenBlock, elseBlock, thenArgs, elseArgs, block); - return static_cast(insertInst(inst)); - } - - AllocaInst *createAllocaInst(Type *type, const std::vector &dims = {}, - const std::string &name = "") { - auto inst = new AllocaInst(type, dims, block, name); - return static_cast(insertInst(inst)); - } - - AllocaInst *createAllocaInstWithoutInsert(Type *type, - const std::vector &dims = {}, - BasicBlock *parent = nullptr, - const std::string &name = "") { - return new AllocaInst(type, dims, parent, name); - } - - LoadInst *createLoadInst(Value *pointer, const std::vector &indices = {}, - const std::string &name = "") { - auto inst = new LoadInst(pointer, indices, block, name); - return static_cast(insertInst(inst)); - } - - LaInst *createLaInst(Value *pointer, const std::vector &indices = {}, - const std::string &name = "") { - auto inst = new LaInst(pointer, indices, block, name); - return static_cast(insertInst(inst)); - } - - GetSubArrayInst *createGetSubArray(LVal *fatherArray, - const std::vector &indices, - const std::string &name = "") { - assert(fatherArray->getLValNumDims() > indices.size()); - std::vector subDims; - auto dims = fatherArray->getLValDims(); - auto iter = std::next(dims.begin(), indices.size()); - while (iter != dims.end()) { - subDims.emplace_back(*iter); - iter++; + bool isReady = true; + for (const auto &valueUse : inst->getOperands()) { + auto value = dynamic_cast(valueUse->getValue()); + if (value != nullptr && !isCreated.at(value)) { + isReady = false; + break; + } } - auto fatherArrayValue = dynamic_cast(fatherArray); - AllocaInst * childArray = new AllocaInst(fatherArrayValue->getType(), subDims, block); - auto inst = new GetSubArrayInst(fatherArray, childArray, indices, block ,name); - return static_cast(insertInst(inst)); + if (!isReady) { + toCreate.push(inst); + continue; + } + isCreated.at(inst) = true; + switch (inst->getKind()) { + case Instruction::kAdd: + case Instruction::kSub: + case Instruction::kMul: + case Instruction::kDiv: + case Instruction::kRem: + case Instruction::kICmpEQ: + case Instruction::kICmpNE: + case Instruction::kICmpLT: + case Instruction::kICmpGT: + case Instruction::kICmpLE: + case Instruction::kICmpGE: + case Instruction::kAnd: + case Instruction::kOr: + case Instruction::kFAdd: + case Instruction::kFSub: + case Instruction::kFMul: + case Instruction::kFDiv: + case Instruction::kFCmpEQ: + case Instruction::kFCmpNE: + case Instruction::kFCmpLT: + case Instruction::kFCmpGT: + case Instruction::kFCmpLE: + case Instruction::kFCmpGE: { + auto oldBinaryInst = dynamic_cast(inst); + auto lhs = oldBinaryInst->getLhs(); + auto rhs = oldBinaryInst->getRhs(); + Value *newLhs; + Value *newRhs; + newLhs = oldNewValueMap[lhs]; + newRhs = oldNewValueMap[rhs]; + ss << oldBinaryInst->getName() << suffix; + auto newBinaryInst = new BinaryInst(oldBinaryInst->getKind(), oldBinaryInst->getType(), newLhs, newRhs, + oldNewBlockMap.at(oldBinaryInst->getParent()), ss.str()); + ss.str(""); + oldNewValueMap.emplace(oldBinaryInst, newBinaryInst); + break; + } + + case Instruction::kNeg: + case Instruction::kNot: + case Instruction::kFNeg: + case Instruction::kFNot: + case Instruction::kItoF: + case Instruction::kFtoI: { + auto oldUnaryInst = dynamic_cast(inst); + auto hs = oldUnaryInst->getOperand(); + Value *newHs; + newHs = oldNewValueMap.at(hs); + ss << oldUnaryInst->getName() << suffix; + auto newUnaryInst = new UnaryInst(oldUnaryInst->getKind(), oldUnaryInst->getType(), newHs, + oldNewBlockMap.at(oldUnaryInst->getParent()), ss.str()); + ss.str(""); + oldNewValueMap.emplace(oldUnaryInst, newUnaryInst); + break; + } + + case Instruction::kCall: { + auto oldCallInst = dynamic_cast(inst); + std::vector newArgumnts; + for (const auto &arg : oldCallInst->getArguments()) { + newArgumnts.emplace_back(oldNewValueMap.at(arg->getValue())); + } + + ss << oldCallInst->getName() << suffix; + CallInst *newCallInst; + newCallInst = + new CallInst(oldCallInst->getCallee(), newArgumnts, oldNewBlockMap.at(oldCallInst->getParent()), ss.str()); + ss.str(""); + // if (oldCallInst->getCallee() != this) { + // newCallInst = new CallInst(oldCallInst->getCallee(), newArgumnts, + // oldNewBlockMap.at(oldCallInst->getParent()), + // oldCallInst->getName()); + // } else { + // newCallInst = new CallInst(newFunction, newArgumnts, oldNewBlockMap.at(oldCallInst->getParent()), + // oldCallInst->getName()); + // } + + oldNewValueMap.emplace(oldCallInst, newCallInst); + break; + } + + case Instruction::kCondBr: { + auto oldCondBrInst = dynamic_cast(inst); + auto oldCond = oldCondBrInst->getCondition(); + Value *newCond; + newCond = oldNewValueMap.at(oldCond); + auto newCondBrInst = new CondBrInst(newCond, oldNewBlockMap.at(oldCondBrInst->getThenBlock()), + oldNewBlockMap.at(oldCondBrInst->getElseBlock()), {}, {}, + oldNewBlockMap.at(oldCondBrInst->getParent())); + oldNewValueMap.emplace(oldCondBrInst, newCondBrInst); + break; + } + + case Instruction::kBr: { + auto oldBrInst = dynamic_cast(inst); + auto newBrInst = + new UncondBrInst(oldNewBlockMap.at(oldBrInst->getBlock()), {}, oldNewBlockMap.at(oldBrInst->getParent())); + oldNewValueMap.emplace(oldBrInst, newBrInst); + break; + } + + case Instruction::kReturn: { + auto oldReturnInst = dynamic_cast(inst); + auto oldRval = oldReturnInst->getReturnValue(); + Value *newRval = nullptr; + if (oldRval != nullptr) { + newRval = oldNewValueMap.at(oldRval); + } + auto newReturnInst = + new ReturnInst(newRval, oldNewBlockMap.at(oldReturnInst->getParent()), oldReturnInst->getName()); + oldNewValueMap.emplace(oldReturnInst, newReturnInst); + break; + } + + case Instruction::kAlloca: { + assert(false); + } + + case Instruction::kLoad: { + auto oldLoadInst = dynamic_cast(inst); + auto oldPointer = oldLoadInst->getPointer(); + Value *newPointer; + newPointer = oldNewValueMap.at(oldPointer); + + std::vector newIndices; + for (const auto &index : oldLoadInst->getIndices()) { + newIndices.emplace_back(oldNewValueMap.at(index->getValue())); + } + ss << oldLoadInst->getName() << suffix; + auto newLoadInst = new LoadInst(newPointer, newIndices, oldNewBlockMap.at(oldLoadInst->getParent()), ss.str()); + ss.str(""); + oldNewValueMap.emplace(oldLoadInst, newLoadInst); + break; + } + + case Instruction::kStore: { + auto oldStoreInst = dynamic_cast(inst); + auto oldPointer = oldStoreInst->getPointer(); + auto oldValue = oldStoreInst->getValue(); + Value *newPointer; + Value *newValue; + std::vector newIndices; + newPointer = oldNewValueMap.at(oldPointer); + newValue = oldNewValueMap.at(oldValue); + for (const auto &index : oldStoreInst->getIndices()) { + newIndices.emplace_back(oldNewValueMap.at(index->getValue())); + } + auto newStoreInst = new StoreInst(newValue, newPointer, newIndices, + oldNewBlockMap.at(oldStoreInst->getParent()), oldStoreInst->getName()); + oldNewValueMap.emplace(oldStoreInst, newStoreInst); + break; + } + + case Instruction::kLa: { + auto oldLaInst = dynamic_cast(inst); + auto oldPointer = oldLaInst->getPointer(); + Value *newPointer; + std::vector newIndices; + newPointer = oldNewValueMap.at(oldPointer); + + for (const auto &index : oldLaInst->getIndices()) { + newIndices.emplace_back(oldNewValueMap.at(index->getValue())); + } + ss << oldLaInst->getName() << suffix; + auto newLaInst = new LaInst(newPointer, newIndices, oldNewBlockMap.at(oldLaInst->getParent()), ss.str()); + ss.str(""); + oldNewValueMap.emplace(oldLaInst, newLaInst); + break; + } + + case Instruction::kGetSubArray: { + auto oldGetSubArrayInst = dynamic_cast(inst); + auto oldFather = oldGetSubArrayInst->getFatherArray(); + auto oldChild = oldGetSubArrayInst->getChildArray(); + Value *newFather; + Value *newChild; + std::vector newIndices; + newFather = oldNewValueMap.at(oldFather); + newChild = oldNewValueMap.at(oldChild); + + for (const auto &index : oldGetSubArrayInst->getIndices()) { + newIndices.emplace_back(oldNewValueMap.at(index->getValue())); + } + ss << oldGetSubArrayInst->getName() << suffix; + auto newGetSubArrayInst = + new GetSubArrayInst(dynamic_cast(newFather), dynamic_cast(newChild), newIndices, + oldNewBlockMap.at(oldGetSubArrayInst->getParent()), ss.str()); + ss.str(""); + oldNewValueMap.emplace(oldGetSubArrayInst, newGetSubArrayInst); + break; + } + + case Instruction::kMemset: { + auto oldMemsetInst = dynamic_cast(inst); + auto oldPointer = oldMemsetInst->getPointer(); + auto oldValue = oldMemsetInst->getValue(); + Value *newPointer; + Value *newValue; + newPointer = oldNewValueMap.at(oldPointer); + newValue = oldNewValueMap.at(oldValue); + + auto newMemsetInst = new MemsetInst(newPointer, oldMemsetInst->getBegin(), oldMemsetInst->getSize(), newValue, + oldNewBlockMap.at(oldMemsetInst->getParent()), oldMemsetInst->getName()); + oldNewValueMap.emplace(oldMemsetInst, newMemsetInst); + break; + } + + case Instruction::kInvalid: + case Instruction::kPhi: { + break; + } + + default: + assert(false); + } + for (const auto &userUse : inst->getUses()) { + auto user = userUse->getUser(); + if (!isAddedToCreate.at(user)) { + toCreate.push(user); + isAddedToCreate.at(user) = true; + } + } } - MemsetInst *createMemsetInst(Value *pointer, Value *begin, Value *size, - Value *value, const std::string &name = "") { - auto inst = new MemsetInst(pointer, begin, size, value, block, name); - return static_cast(insertInst(inst)); + for (const auto &oldBlock : blocks) { + auto newBlock = oldNewBlockMap.at(oldBlock.get()); + builder.setPosition(newBlock, newBlock->end()); + for (const auto &inst : oldBlock->getInstructions()) { + builder.insertInst(dynamic_cast(oldNewValueMap.at(inst.get()))); + } } - StoreInst *createStoreInst(Value *value, Value *pointer, - const std::vector &indices = {}, - const std::string &name = "") { - auto inst = new StoreInst(value, pointer, indices, block, name); - return static_cast(insertInst(inst)); + for (const auto ¶m : blocks.front()->getArguments()) { + newFunction->getEntryBlock()->insertArgument(dynamic_cast(oldNewValueMap.at(param))); } - PhiInst *createPhiInst(Type *type, Value *lhs, BasicBlock *parent, - const std::string &name = "") { - auto predNum = parent->getNumPredecessors(); - std::vector rhs(predNum, lhs); - auto inst = new PhiInst(type, lhs, rhs, lhs, parent, name); - parent->getInstructions().emplace(parent->begin(), inst); - return inst; - } -}; + return newFunction; +} +/** + * @brief 设置操作数 + * + * @param [in] index 所要设置的操作数的位置 + * @param [in] value 所要设置成的value + * @return 无返回值 + */ +void User::setOperand(unsigned index, Value *value) { + assert(index < getNumOperands()); + operands[index]->setValue(value); + value->addUse(operands[index]); +} +/** + * @brief 替换操作数 + * + * @param [in] index 所要替换的操作数的位置 + * @param [in] value 所要替换成的value + * @return 无返回值 + */ +void User::replaceOperand(unsigned index, Value *value) { + assert(index < getNumOperands()); + auto &use = operands[index]; + use->getValue()->removeUse(use); + use->setValue(value); + value->addUse(use); +} -} // namespace sysy +CallInst::CallInst(Function *callee, const std::vector &args, BasicBlock *parent, const std::string &name) + : Instruction(kCall, callee->getReturnType(), parent, name) { + addOperand(callee); + for (auto arg : args) { + addOperand(arg); + } +} +/** + * @brief 获取被调用函数的指针 + * + * @return 被调用函数的指针 + */ +Function * CallInst::getCallee() const { return dynamic_cast(getOperand(0)); } + +/** + * @brief 获取变量指针 + * + * @param [in] name 变量名字 + * @return 变量指针 + */ +auto SymbolTable::getVariable(const std::string &name) const -> User * { + auto node = curNode; + while (node != nullptr) { + auto iter = node->varList.find(name); + if (iter != node->varList.end()) { + return iter->second; + } + node = node->pNode; + } + + return nullptr; +} +/** + * @brief 添加变量 + * + * @param [in] name 变量名字 + * @param [in] variable 变量指针 + * @return 变量指针 + */ +auto SymbolTable::addVariable(const std::string &name, User *variable) -> User * { + User *result = nullptr; + if (curNode != nullptr) { + std::stringstream ss; + auto iter = variableIndex.find(name); + if (iter != variableIndex.end()) { + ss << name << "(" << iter->second << ")"; + iter->second += 1; + } else { + variableIndex.emplace(name, 1); + ss << name << "(" << 0 << ")"; + } + + variable->setName(ss.str()); + curNode->varList.emplace(name, variable); + auto global = dynamic_cast(variable); + auto constvar = dynamic_cast(variable); + if (global != nullptr) { + globals.emplace_back(global); + } else if (constvar != nullptr) { + consts.emplace_back(constvar); + } + + result = variable; + } + + return result; +} +/** + * @brief 获取全局变量 + * + * @return 全局变量列表 + */ +auto SymbolTable::getGlobals() -> std::vector> & { return globals; } +/** + * @brief 获取常量 + * + * @return 常量列表 + */ +auto SymbolTable::getConsts() const -> const std::vector> & { return consts; } +/** + * @brief 进入新的作用域 + * + * @return 无返回值 + */ +void SymbolTable::enterNewScope() { + auto newNode = new SymbolTableNode; + nodeList.emplace_back(newNode); + if (curNode != nullptr) { + curNode->children.emplace_back(newNode); + } + newNode->pNode = curNode; + curNode = newNode; +} +/** + * @brief 进入全局作用域 + * + * @return 无返回值 + */ +void SymbolTable::enterGlobalScope() { curNode = nodeList.front().get(); } +/** + * @brief 离开作用域 + * + * @return 无返回值 + */ +void SymbolTable::leaveScope() { curNode = curNode->pNode; } +/** + * @brief 是否位于全局作用域 + * + * @return 布尔值 + */ +auto SymbolTable::isInGlobalScope() const -> bool { return curNode->pNode == nullptr; } + +/** + * @brief 判断是否为循环不变量 + * @param value: 要判断的value + * @return true: 是不变量 + * @return false: 不是 + */ +auto Loop::isSimpleLoopInvariant(Value *value) -> bool { + // auto constValue = dynamic_cast(value); + // if (constValue != nullptr) { + // return false; + // } + if (auto instr = dynamic_cast(value)) { + if (instr->isLoad()) { + auto loadinst = dynamic_cast(instr); + + auto loadvalue = dynamic_cast(loadinst->getOperand(0)); + if (loadvalue != nullptr) { + if (loadvalue->getParent() != nullptr) { + auto basicblock = loadvalue->getParent(); + return !this->isLoopContainsBasicBlock(basicblock); + } + } + auto globalvalue = dynamic_cast(loadinst->getOperand(0)); + if (globalvalue != nullptr) { + return true; + } + auto basicblock = instr->getParent(); + + return !this->isLoopContainsBasicBlock(basicblock); + } + auto basicblock = instr->getParent(); + return !this->isLoopContainsBasicBlock(basicblock); + } + return true; +} + +/** + * @brief 移动指令 + * + * @param [in] sourcePos 源指令列表位置 + * @param [in] targetPos 目的指令列表位置 + * @param [in] block 目标基本块 + * @return 无返回值 + */ +auto BasicBlock::moveInst(iterator sourcePos, iterator targetPos, BasicBlock *block) -> iterator { + auto inst = sourcePos->release(); + inst->setParent(block); + block->instructions.emplace(targetPos, inst); + return instructions.erase(sourcePos); +} + +} // namespace sysy diff --git a/src/include/IR.h b/src/include/IR.h index 1f3885b..f1759cd 100644 --- a/src/include/IR.h +++ b/src/include/IR.h @@ -340,41 +340,6 @@ class Function; class Loop; class BasicBlock; -/*! - * Arguments of `BasicBlock`s. - * - * SysY IR is an SSA language, however, it does not use PHI instructions as in - * LLVM IR. `Value`s from different predecessor blocks are passed explicitly as - * block arguments. This is also the approach used by MLIR. - * NOTE that `Function` does not own `Argument`s, function arguments are - * implemented as its entry block's arguments. - */ - -class Argument : public Value { -protected: - BasicBlock *block; - int index; - -public: - Argument(Type *type, BasicBlock *block, int index, - const std::string &name = ""); - -public: - static bool classof(const Value *value) { - return value->getKind() == kArgument; - } - -public: - BasicBlock *getParent() const { return block; } - int getIndex() const { return index; } - -public: - void print(std::ostream &os) const override; -}; - -class Instruction; -class Function; -class Loop; /*! * The container for `Instruction` sequence. * @@ -382,104 +347,131 @@ class Loop; * a terminator (branch or return). Besides, `BasicBlock` stores its arguments * and records its predecessor and successor `BasicBlock`s. */ -class BasicBlock : public Value { + + class BasicBlock : public Value { friend class Function; -public: + public: using inst_list = std::list>; using iterator = inst_list::iterator; - using arg_list = std::vector>; + using arg_list = std::vector; using block_list = std::vector; using block_set = std::unordered_set; -protected: - Function *parent; - inst_list instructions; - arg_list arguments; - block_list successors; - block_list predecessors; - BasicBlock *idom = nullptr; // 直接支配节点 - block_list sdoms; // 支配树后继 - block_set dominants; // 必经节点集合 - block_set dominant_frontiers; // 支配边界 - bool reachable = false; // 是否可达 - Loop *loopbelong = nullptr; // 所属循环 - int loopdepth = 0; // 循环深度 + protected: + Function *parent; ///< 从属的函数 + inst_list instructions; ///< 拥有的指令序列 + arg_list arguments; ///< 分配空间后的形式参数列表 + block_list successors; ///< 前驱列表 + block_list predecessors; ///< 后继列表 + BasicBlock *idom = nullptr; ///< 直接支配结点,即支配树前驱,唯一,默认nullptr + block_list sdoms; ///< 支配树后继,可以有多个 + block_set dominants; ///< 必经结点集合 + block_set dominant_frontiers; ///< 支配边界 + bool reachable = false; ///< 用于表示该节点是否可达,默认不可达 + Loop *loopbelong = nullptr; ///< 用来表示该块属于哪个循环,唯一,默认nullptr + int loopdepth = 0; /// < 用来表示其归属循环的深度,默认0 -protected: - explicit BasicBlock(Function *parent, const std::string &name = ""); + public: + explicit BasicBlock(Function *parent, const std::string &name = "") + : Value(Type::getLabelType(), name), parent(parent) {} -public: - static bool classof(const Value *value) { - return value->getKind() == kBasicBlock; - } + ~BasicBlock() override { + for (auto pre : predecessors) { + pre->removeSuccessor(this); + } -public: - int getNumInstructions() const { return instructions.size(); } - int getNumArguments() const { return arguments.size(); } - int getNumPredecessors() const { return predecessors.size(); } - int getNumSuccessors() const { return successors.size(); } - Function *getParent() const { return parent; } - inst_list &getInstructions() { return instructions; } - auto getArguments() const { return make_range(arguments); } - const block_list &getPredecessors() const { return predecessors; } - block_list &getPredecessors() { return predecessors; } - const block_list &getSuccessors() const { return successors; } - block_list &getSuccessors() { return successors; } - iterator begin() { return instructions.begin(); } - iterator end() { return instructions.end(); } - iterator terminator() { return std::prev(end()); } - Argument *createArgument(Type *type, const std::string &name = "") { - auto arg = new Argument(type, this, arguments.size(), name); - assert(arg); - arguments.emplace_back(arg); - return arguments.back().get(); - }; + for (auto suc : successors) { + suc->removePredecessor(this); + } + } ///< 基本块的析构函数,同时删除其前驱后继关系 - // 控制流分析相关 - BasicBlock *getIdom() const { return idom; } - void setIdom(BasicBlock *dom) { idom = dom; } - const block_list &getSdoms() const { return sdoms; } - void addSdom(BasicBlock *bb) { sdoms.push_back(bb); } + public: + auto getNumInstructions() const -> unsigned { return instructions.size(); } ///< 获取指令数量 + auto getNumArguments() const -> unsigned { return arguments.size(); } ///< 获取形式参数数量 + auto getNumPredecessors() const -> unsigned { return predecessors.size(); } ///< 获取前驱数量 + auto getNumSuccessors() const -> unsigned { return successors.size(); } ///< 获取后继数量 + auto getParent() const -> Function * { return parent; } ///< 获取父函数 + void setParent(Function *func) { parent = func; } ///< 设置父函数 + auto getInstructions() -> inst_list & { return instructions; } ///< 获取指令列表 + auto getArguments() -> arg_list & { return arguments; } ///< 获取分配空间后的形式参数列表 + auto getPredecessors() const -> const block_list & { return predecessors; } ///< 获取前驱列表 + auto getSuccessors() -> block_list & { return successors; } ///< 获取后继列表 + auto getDominants() -> block_set & { return dominants; } + auto getIdom() -> BasicBlock * { return idom; } + auto getSdoms() -> block_list & { return sdoms; } + auto getDFs() -> block_set & { return dominant_frontiers; } + auto begin() -> iterator { return instructions.begin(); } ///< 返回指向指令列表开头的迭代器 + auto end() -> iterator { return instructions.end(); } ///< 返回指向指令列表末尾的迭代器 + auto terminator() -> iterator { return std::prev(end()); } ///< 基本块最后的IR + void insertArgument(AllocaInst *inst) { arguments.push_back(inst); } ///< 插入分配空间后的形式参数 + void addPredecessor(BasicBlock *block) { + if (std::find(predecessors.begin(), predecessors.end(), block) == predecessors.end()) { + predecessors.push_back(block); + } + } ///< 添加前驱 + void addSuccessor(BasicBlock *block) { + if (std::find(successors.begin(), successors.end(), block) == successors.end()) { + successors.push_back(block); + } + } ///< 添加后继 + void addPredecessor(const block_list &blocks) { + for (auto block : blocks) { + addPredecessor(block); + } + } ///< 添加多个前驱 + void addSuccessor(const block_list &blocks) { + for (auto block : blocks) { + addSuccessor(block); + } + } ///< 添加多个后继 + void setIdom(BasicBlock *block) { idom = block; } + void addSdoms(BasicBlock *block) { sdoms.push_back(block); } void clearSdoms() { sdoms.clear(); } - const block_set &getDominants() const { return dominants; } - void addDominant(BasicBlock *bb) { dominants.insert(bb); } - void setDominants(const block_set &doms) { dominants = doms; } - const block_set &getDominantFrontiers() const { return dominant_frontiers; } - void setDominantFrontiers(const block_set &df) { dominant_frontiers = df; } - bool isReachable() const { return reachable; } - void setReachable(bool r) { reachable = r; } - - // 循环分析相关 - Loop *getLoop() const { return loopbelong; } - void setLoop(Loop *loop) { loopbelong = loop; } - int getLoopDepth() const { return loopdepth; } - void setLoopDepth(int depth) { loopdepth = depth; } - - void addPredecessor(BasicBlock *bb) { - if (std::find(predecessors.begin(), predecessors.end(), bb) == predecessors.end()) - predecessors.push_back(bb); + // 重载1,参数为 BasicBlock* + auto addDominants(BasicBlock *block) -> void { dominants.emplace(block); } + // 重载2,参数为 block_set + auto addDominants(const block_set &blocks) -> void { dominants.insert(blocks.begin(), blocks.end()); } + auto setDominants(BasicBlock *block) -> void { + dominants.clear(); + addDominants(block); } - - void addSuccessor(BasicBlock *bb) { - if (std::find(successors.begin(), successors.end(), bb) == successors.end()) - successors.push_back(bb); + auto setDominants(const block_set &doms) -> void { + dominants.clear(); + addDominants(doms); } - - void removePredecessor(BasicBlock *bb) { - auto it = std::find(predecessors.begin(), predecessors.end(), bb); - if (it != predecessors.end()) - predecessors.erase(it); + auto setDFs(const block_set &df) -> void { + dominant_frontiers.clear(); + for (auto elem : df) { + dominant_frontiers.emplace(elem); + } } - - void removeSuccessor(BasicBlock *bb) { - auto it = std::find(successors.begin(), successors.end(), bb); - if (it != successors.end()) - successors.erase(it); - } - - // 获取支配树中所有子节点 - block_list getChildren() { + void removePredecessor(BasicBlock *block) { + auto iter = std::find(predecessors.begin(), predecessors.end(), block); + if (iter != predecessors.end()) { + predecessors.erase(iter); + } else { + assert(false); + } + } ///< 删除前驱 + void removeSuccessor(BasicBlock *block) { + auto iter = std::find(successors.begin(), successors.end(), block); + if (iter != successors.end()) { + successors.erase(iter); + } else { + assert(false); + } + } ///< 删除后继 + void replacePredecessor(BasicBlock *oldBlock, BasicBlock *newBlock) { + for (auto &predecessor : predecessors) { + if (predecessor == oldBlock) { + predecessor = newBlock; + break; + } + } + } ///< 替换前驱 + // 获取支配树中该块的所有子节点,包括子节点的子节点等,迭代实现 + auto getChildren() -> block_list { std::queue q; block_list children; for (auto sdom : sdoms) { @@ -494,60 +486,73 @@ public: children.push_back(sdom); } } + return children; } -public: - void print(std::ostream &os) const override; -}; // class BasicBlock + auto setreachableTrue() -> void { reachable = true; } ///< 设置可达 + + auto setreachableFalse() -> void { reachable = false; } ///< 设置不可达 + + auto getreachable() -> bool { return reachable; } ///< 返回可达状态 + + static void conectBlocks(BasicBlock *prev, BasicBlock *next) { + prev->addSuccessor(next); + next->addPredecessor(prev); + } ///< 连接两个块,即设置两个基本块的前驱后继关系 + void setLoop(Loop *loop2set) { loopbelong = loop2set; } ///< 设置所属循环 + auto getLoop() { return loopbelong; } ///< 获得所属循环 + void setLoopDepth(int loopdepth2set) { loopdepth = loopdepth2set; } ///< 设置循环深度 + auto getLoopDepth() { return loopdepth; } ///< 获得其在循环的深度 + void removeInst(iterator pos) { instructions.erase(pos); } ///< 删除指令 + auto moveInst(iterator sourcePos, iterator targetPos, BasicBlock *block) -> iterator; ///< 移动指令 +}; //! User is the abstract base type of `Value` types which use other `Value` as //! operands. Currently, there are two kinds of `User`s, `Instruction` and //! `GlobalValue`. +// User是`Value`的派生类型,其使用其他的`Value`作为操作数 + + class User : public Value { -protected: - std::vector operands; + protected: + std::vector> operands; ///< 操作数/使用关系 -protected: - User(Kind kind, Type *type, const std::string &name = "") - : Value(kind, type, name), operands() {} + protected: + explicit User(Type *type, const std::string &name = "") : Value(type, name) {} -public: - using use_iterator = std::vector::const_iterator; - struct operand_iterator : public std::vector::const_iterator { - using Base = std::vector::const_iterator; - operand_iterator(const Base &iter) : Base(iter) {} - using value_type = Value *; - value_type operator->() { return Base::operator*().getValue(); } - value_type operator*() { return Base::operator*().getValue(); } - }; - -public: - int getNumOperands() const { return operands.size(); } - operand_iterator operand_begin() const { return operands.begin(); } - operand_iterator operand_end() const { return operands.end(); } - auto getOperands() const { - return make_range(operand_begin(), operand_end()); - } - Value *getOperand(int index) const { return operands[index].getValue(); } + public: + auto getNumOperands() const -> unsigned { return operands.size(); } ///< 获取操作数数量 + auto operand_begin() const { return operands.begin(); } ///< 返回操作数列表的开头迭代器 + auto operand_end() const { return operands.end(); } ///< 返回操作数列表的结尾迭代器 + auto getOperands() const { return make_range(operand_begin(), operand_end()); } ///< 获取操作数列表 + auto getOperand(unsigned index) const -> Value * { return operands[index]->getValue(); } ///< 获取位置为index的操作数 void addOperand(Value *value) { - operands.emplace_back(operands.size(), this, value); - value->addUse(&operands.back()); - } - template void addOperands(const ContainerT &operands) { - for (auto value : operands) + operands.emplace_back(std::make_shared(operands.size(), this, value)); + value->addUse(operands.back()); + } ///< 增加操作数 + void removeOperand(unsigned index) { + auto value = getOperand(index); + value->removeUse(operands[index]); + operands.erase(operands.begin() + index); + } ///< 移除操作数 + template + void addOperands(const ContainerT &newoperands) { + for (auto value : newoperands) { addOperand(value); - } - void replaceOperand(int index, Value *value); - void setOperand(int index, Value *value); -}; // class User + } + } ///< 增加多个操作数 + void replaceOperand(unsigned index, Value *value); ///< 替换操作数 + void setOperand(unsigned index, Value *value); ///< 设置操作数 +}; + class GetSubArrayInst; /** * 左值 具有地址的对象 */ -class LVal : public User { +class LVal { friend class GetSubArrayInst; protected: @@ -560,8 +565,8 @@ class LVal : public User { public: virtual ~LVal() = default; - virtual auto getLValDims() const -> std::vector = 0; ///< 获取左值的维度 - virtual auto getLValNumDims() const -> unsigned = 0; ///< 获取左值的维度数量 + virtual std::vector getLValDims() const = 0; ///< 获取左值的维度 + virtual unsigned getLValNumDims() const = 0; ///< 获取左值的维度数量 public: auto getFatherLVal() const -> LVal * { return fatherLVal; } ///< 获取父左值 @@ -590,23 +595,163 @@ class LVal : public User { * Base of all concrete instruction types. */ class Instruction : public User { -public: - // 指令种类定义已移至Value::Kind + public: + /// 指令标识码 + enum Kind : uint64_t { + kInvalid = 0x0UL, + // Binary + kAdd = 0x1UL << 0, + kSub = 0x1UL << 1, + kMul = 0x1UL << 2, + kDiv = 0x1UL << 3, + kRem = 0x1UL << 4, + kICmpEQ = 0x1UL << 5, + kICmpNE = 0x1UL << 6, + kICmpLT = 0x1UL << 7, + kICmpGT = 0x1UL << 8, + kICmpLE = 0x1UL << 9, + kICmpGE = 0x1UL << 10, + kFAdd = 0x1UL << 11, + kFSub = 0x1UL << 12, + kFMul = 0x1UL << 13, + kFDiv = 0x1UL << 14, + kFCmpEQ = 0x1UL << 15, + kFCmpNE = 0x1UL << 16, + kFCmpLT = 0x1UL << 17, + kFCmpGT = 0x1UL << 18, + kFCmpLE = 0x1UL << 19, + kFCmpGE = 0x1UL << 20, + kAnd = 0x1UL << 21, + kOr = 0x1UL << 22, + // Unary + kNeg = 0x1UL << 23, + kNot = 0x1UL << 24, + kFNeg = 0x1UL << 25, + kFNot = 0x1UL << 26, + kFtoI = 0x1UL << 27, + kItoF = 0x1UL << 28, + // call + kCall = 0x1UL << 29, + // terminator + kCondBr = 0x1UL << 30, + kBr = 0x1UL << 31, + kReturn = 0x1UL << 32, + // mem op + kAlloca = 0x1UL << 33, + kLoad = 0x1UL << 34, + kStore = 0x1UL << 35, + kLa = 0x1UL << 36, + kMemset = 0x1UL << 37, + kGetSubArray = 0x1UL << 38, + // constant + kConstant = 0x1UL << 37, + // phi + kPhi = 0x1UL << 39, + kBitItoF = 0x1UL << 40, + kBitFtoI = 0x1UL << 41 + }; protected: + Kind kind; BasicBlock *parent; protected: - Instruction(Kind kind, Type *type, BasicBlock *parent = nullptr, - const std::string &name = ""); + Instruction(Kind kind, Type *type, BasicBlock *parent = nullptr, const std::string &name = "") + : User(type, name), kind(kind), parent(parent) {} public: - static bool classof(const Value *value) { - return value->getKind() >= kFirstInst and value->getKind() <= kLastInst; - } public: Kind getKind() const { return kind; } + std::string getKindString() const{ + switch (kind) { + case kInvalid: + return "Invalid"; + case kAdd: + return "Add"; + case kSub: + return "Sub"; + case kMul: + return "Mul"; + case kDiv: + return "Div"; + case kRem: + return "Rem"; + case kICmpEQ: + return "ICmpEQ"; + case kICmpNE: + return "ICmpNE"; + case kICmpLT: + return "ICmpLT"; + case kICmpGT: + return "ICmpGT"; + case kICmpLE: + return "ICmpLE"; + case kICmpGE: + return "ICmpGE"; + case kFAdd: + return "FAdd"; + case kFSub: + return "FSub"; + case kFMul: + return "FMul"; + case kFDiv: + return "FDiv"; + case kFCmpEQ: + return "FCmpEQ"; + case kFCmpNE: + return "FCmpNE"; + case kFCmpLT: + return "FCmpLT"; + case kFCmpGT: + return "FCmpGT"; + case kFCmpLE: + return "FCmpLE"; + case kFCmpGE: + return "FCmpGE"; + case kAnd: + return "And"; + case kOr: + return "Or"; + case kNeg: + return "Neg"; + case kNot: + return "Not"; + case kFNeg: + return "FNeg"; + case kFNot: + return "FNot"; + case kFtoI: + return "FtoI"; + case kItoF: + return "IToF"; + case kCall: + return "Call"; + case kCondBr: + return "CondBr"; + case kBr: + return "Br"; + case kReturn: + return "Return"; + case kAlloca: + return "Alloca"; + case kLoad: + return "Load"; + case kStore: + return "Store"; + case kLa: + return "La"; + case kMemset: + return "Memset"; + case kPhi: + return "Phi"; + case kGetSubArray: + return "GetSubArray"; + default: + return "Unknown"; + } + } ///< 根据指令标识码获取字符串 + BasicBlock *getParent() const { return parent; } Function *getFunction() const { return parent->getParent(); } void setParent(BasicBlock *bb) { parent = bb; } @@ -615,7 +760,7 @@ public: static constexpr uint64_t BinaryOpMask = (kAdd | kSub | kMul | kDiv | kRem | kAnd | kOr) | (kICmpEQ | kICmpNE | kICmpLT | kICmpGT | kICmpLE | kICmpGE) | - (kFAdd | kFSub | kFMul | kFDiv | kFRem) | + (kFAdd | kFSub | kFMul | kFDiv) | (kFCmpEQ | kFCmpNE | kFCmpLT | kFCmpGT | kFCmpLE | kFCmpGE); return kind & BinaryOpMask; } @@ -626,7 +771,7 @@ public: } bool isMemory() const { static constexpr uint64_t MemoryOpMask = - kAlloca | kLoad | kStore | kLa | kMemset | kGetSubArray; + kAlloca | kLoad | kStore; return kind & MemoryOpMask; } bool isTerminator() const { @@ -659,19 +804,67 @@ public: bool isGetSubArray() const { return kind == kGetSubArray; } bool isCall() const { return kind == kCall; } bool isReturn() const { return kind == kReturn; } + bool isDefine() const { + static constexpr uint64_t DefineOpMask = kAlloca | kStore | kPhi; + return (kind & DefineOpMask) != 0U; + } }; // class Instruction class Function; //! Function call. + +class LaInst : public Instruction { + friend class Function; + friend class IRBuilder; + + protected: + explicit LaInst(Value *pointer, const std::vector &indices = {}, BasicBlock *parent = nullptr, + const std::string &name = "") + : Instruction(Kind::kLa, pointer->getType(), parent, name) { + assert(pointer); + addOperand(pointer); + addOperands(indices); + } + + public: + auto getNumIndices() const -> unsigned { return getNumOperands() - 1; } ///< 获取索引长度 + auto getPointer() const -> Value * { return getOperand(0); } ///< 获取目标变量的Value指针 + auto getIndices() const { return make_range(std::next(operand_begin()), operand_end()); } ///< 获取索引列表 + auto getIndex(unsigned index) const -> Value * { return getOperand(index + 1); } ///< 获取位置为index的索引分量 +}; + +class PhiInst : public Instruction { + friend class IRBuilder; + friend class Function; + friend class SysySSA; + + protected: + Value *map_val; // Phi的旧值 + + PhiInst(Type *type, Value *lhs, const std::vector &rhs, Value *mval, BasicBlock *parent, + const std::string &name = "") + : Instruction(Kind::kPhi, type, parent, name) { + map_val = mval; + addOperand(lhs); + addOperands(rhs); + } + + public: + Value * getMapVal() { return map_val; } + Value * getPointer() const { return getOperand(0); } + auto getValues() { return make_range(std::next(operand_begin()), operand_end()); } + Value * getValue(unsigned index) const { return getOperand(index + 1); } +}; + + class CallInst : public Instruction { + friend class Function; friend class IRBuilder; protected: CallInst(Function *callee, const std::vector &args = {}, BasicBlock *parent = nullptr, const std::string &name = ""); -public: - static bool classof(const Value *value) { return value->getKind() == kCall; } public: Function *getCallee() const; @@ -679,12 +872,11 @@ public: return make_range(std::next(operand_begin()), operand_end()); } -public: - void print(std::ostream &os) const override; }; // class CallInst //! Unary instruction, includes '!', '-' and type conversion. class UnaryInst : public Instruction { + friend class Function; friend class IRBuilder; protected: @@ -694,74 +886,110 @@ protected: addOperand(operand); } -public: - static bool classof(const Value *value) { - return Instruction::classof(value) and - static_cast(value)->isUnary(); - } public: Value *getOperand() const { return User::getOperand(0); } -public: - void print(std::ostream &os) const override; }; // class UnaryInst //! Binary instruction, e.g., arithmatic, relation, logic, etc. class BinaryInst : public Instruction { friend class IRBuilder; + friend class Function; -protected: - BinaryInst(Kind kind, Type *type, Value *lhs, Value *rhs, BasicBlock *parent, - const std::string &name = "") + protected: + BinaryInst(Kind kind, Type *type, Value *lhs, Value *rhs, BasicBlock *parent, const std::string &name = "") : Instruction(kind, type, parent, name) { addOperand(lhs); addOperand(rhs); } -public: - static bool classof(const Value *value) { - return Instruction::classof(value) and - static_cast(value)->isBinary(); - } - public: Value *getLhs() const { return getOperand(0); } Value *getRhs() const { return getOperand(1); } - -public: - void print(std::ostream &os) const override; + template + auto eval(T lhs, T rhs) -> T { + switch (getKind()) { + case kAdd: + return lhs + rhs; + case kSub: + return lhs - rhs; + case kMul: + return lhs * rhs; + case kDiv: + return lhs / rhs; + case kRem: + if constexpr (std::is_floating_point::value) { + throw std::runtime_error("Remainder operation not supported for floating point types."); + } else { + return lhs % rhs; + } + case kICmpEQ: + return lhs == rhs; + case kICmpNE: + return lhs != rhs; + case kICmpLT: + return lhs < rhs; + case kICmpGT: + return lhs > rhs; + case kICmpLE: + return lhs <= rhs; + case kICmpGE: + return lhs >= rhs; + case kFAdd: + return lhs + rhs; + case kFSub: + return lhs - rhs; + case kFMul: + return lhs * rhs; + case kFDiv: + return lhs / rhs; + case kFCmpEQ: + return lhs == rhs; + case kFCmpNE: + return lhs != rhs; + case kFCmpLT: + return lhs < rhs; + case kFCmpGT: + return lhs > rhs; + case kFCmpLE: + return lhs <= rhs; + case kFCmpGE: + return lhs >= rhs; + case kAnd: + return lhs && rhs; + case kOr: + return lhs || rhs; + default: + assert(false); + } + } ///< 根据指令类型进行二元计算,eval template模板实现 }; // class BinaryInst //! The return statement class ReturnInst : public Instruction { friend class IRBuilder; + friend class Function; -protected: - ReturnInst(Value *value = nullptr, BasicBlock *parent = nullptr) - : Instruction(kReturn, Type::getVoidType(), parent, "") { - if (value) + protected: + explicit ReturnInst(Value *value = nullptr, BasicBlock *parent = nullptr, const std::string &name = "") + : Instruction(kReturn, Type::getVoidType(), parent, name) { + if (value != nullptr) { addOperand(value); + } } -public: - static bool classof(const Value *value) { - return value->getKind() == kReturn; - } - -public: + public: bool hasReturnValue() const { return not operands.empty(); } Value *getReturnValue() const { return hasReturnValue() ? getOperand(0) : nullptr; } - -public: - void print(std::ostream &os) const override; -}; // class ReturnInst +}; //! Unconditional branch class UncondBrInst : public Instruction { friend class IRBuilder; + friend class Function; protected: UncondBrInst(BasicBlock *block, std::vector args, @@ -772,23 +1000,19 @@ protected: addOperands(args); } -public: - static bool classof(const Value *value) { return value->getKind() == kBr; } - public: BasicBlock *getBlock() const { return dyncast(getOperand(0)); } auto getArguments() const { return make_range(std::next(operand_begin()), operand_end()); } -public: - void print(std::ostream &os) const override; }; // class UncondBrInst //! Conditional branch class CondBrInst : public Instruction { friend class IRBuilder; - + friend class Function; + protected: CondBrInst(Value *condition, BasicBlock *thenBlock, BasicBlock *elseBlock, const std::vector &thenArgs, @@ -802,12 +1026,6 @@ protected: addOperands(thenArgs); addOperands(elseArgs); } - -public: - static bool classof(const Value *value) { - return value->getKind() == kCondBr; - } - public: Value *getCondition() const { return getOperand(0); } BasicBlock *getThenBlock() const { @@ -828,14 +1046,12 @@ public: return make_range(begin, end); } -public: - void print(std::ostream &os) const override; }; // class CondBrInst //! Allocate memory for stack variables, used for non-global variable declartion -class AllocaInst : public Instruction { +class AllocaInst : public Instruction , public LVal { friend class IRBuilder; - + friend class Function; protected: AllocaInst(Type *type, const std::vector &dims = {}, BasicBlock *parent = nullptr, const std::string &name = "") @@ -844,22 +1060,59 @@ protected: } public: - static bool classof(const Value *value) { - return value->getKind() == kAlloca; - } - -public: + std::vector getLValDims() const override { + std::vector dims; + for (const auto &dim : getOperands()) { + dims.emplace_back(dim->getValue()); + } + return dims; + } ///< 获取作为左值的维度数组 + unsigned getLValNumDims() const override { return getNumOperands(); } + int getNumDims() const { return getNumOperands(); } auto getDims() const { return getOperands(); } Value *getDim(int index) { return getOperand(index); } -public: - void print(std::ostream &os) const override; }; // class AllocaInst + +class GetSubArrayInst : public Instruction { + friend class IRBuilder; + friend class Function; + + public: + GetSubArrayInst(LVal *fatherArray, LVal *childArray, const std::vector &indices, + BasicBlock *parent = nullptr, const std::string &name = "") + : Instruction(Kind::kGetSubArray, Type::getVoidType(), parent, name) { + auto predicate = [childArray](const std::unique_ptr &child) -> bool { return child.get() == childArray; }; + if (std::find_if(fatherArray->childrenLVals.begin(), fatherArray->childrenLVals.end(), predicate) == + fatherArray->childrenLVals.end()) { + fatherArray->childrenLVals.emplace_back(childArray); + } + childArray->fatherLVal = fatherArray; + childArray->defineInst = this; + auto fatherArrayValue = dynamic_cast(fatherArray); + auto childArrayValue = dynamic_cast(childArray); + assert(fatherArrayValue); + assert(childArrayValue); + addOperand(fatherArrayValue); + addOperand(childArrayValue); + addOperands(indices); + } + + public: + auto getFatherArray() const -> Value * { return getOperand(0); } ///< 获取父数组 + auto getChildArray() const -> Value * { return getOperand(1); } ///< 获取子数组 + auto getFatherLVal() const -> LVal * { return dynamic_cast(getOperand(0)); } ///< 获取父左值 + auto getChildLVal() const -> LVal * { return dynamic_cast(getOperand(1)); } ///< 获取子左值 + auto getIndices() const { return make_range(std::next(operand_begin(), 2), operand_end()); } ///< 获取索引 + auto getNumIndices() const -> unsigned { return getNumOperands() - 2; } ///< 获取索引数量 +}; + //! Load a value from memory address specified by a pointer value class LoadInst : public Instruction { friend class IRBuilder; + friend class Function; protected: LoadInst(Value *pointer, const std::vector &indices = {}, @@ -870,9 +1123,6 @@ protected: addOperands(indices); } -public: - static bool classof(const Value *value) { return value->getKind() == kLoad; } - public: int getNumIndices() const { return getNumOperands() - 1; } Value *getPointer() const { return getOperand(0); } @@ -880,14 +1130,28 @@ public: return make_range(std::next(operand_begin()), operand_end()); } Value *getIndex(int index) const { return getOperand(index + 1); } + std::list getAncestorIndices() const { + std::list indices; + for (const auto &index : getIndices()) { + indices.emplace_back(index->getValue()); + } + auto curPointer = dynamic_cast(getPointer()); + while (curPointer->getFatherLVal() != nullptr) { + auto inserter = std::next(indices.begin()); + for (const auto &index : curPointer->getDefineInst()->getIndices()) { + indices.insert(inserter, index->getValue()); + } + curPointer = curPointer->getFatherLVal(); + } -public: - void print(std::ostream &os) const override; + return indices; + } ///< 获取相对于祖先数组的索引列表 }; // class LoadInst //! Store a value to memory address specified by a pointer value class StoreInst : public Instruction { friend class IRBuilder; + friend class Function; protected: StoreInst(Value *value, Value *pointer, @@ -898,9 +1162,6 @@ protected: addOperand(pointer); addOperands(indices); } - -public: - static bool classof(const Value *value) { return value->getKind() == kStore; } public: int getNumIndices() const { return getNumOperands() - 2; } @@ -910,42 +1171,29 @@ public: return make_range(std::next(operand_begin(), 2), operand_end()); } Value *getIndex(int index) const { return getOperand(index + 2); } + std::list getAncestorIndices() const { + std::list indices; + for (const auto &index : getIndices()) { + indices.emplace_back(index->getValue()); + } + auto curPointer = dynamic_cast(getPointer()); + while (curPointer->getFatherLVal() != nullptr) { + auto inserter = std::next(indices.begin()); + for (const auto &index : curPointer->getDefineInst()->getIndices()) { + indices.insert(inserter, index->getValue()); + } + curPointer = curPointer->getFatherLVal(); + } + + return indices; + } ///< 获取相对于祖先数组的索引列表 -public: - void print(std::ostream &os) const override; }; // class StoreInst -//! Get address instruction -class LaInst : public Instruction { - friend class IRBuilder; - -protected: - LaInst(Value *pointer, const std::vector &indices = {}, - BasicBlock *parent = nullptr, const std::string &name = "") - : Instruction(kLa, pointer->getType(), parent, name) { - assert(pointer); - addOperand(pointer); - addOperands(indices); - } - -public: - static bool classof(const Value *value) { return value->getKind() == kLa; } - -public: - int getNumIndices() const { return getNumOperands() - 1; } - Value *getPointer() const { return getOperand(0); } - auto getIndices() const { - return make_range(std::next(operand_begin()), operand_end()); - } - Value *getIndex(int index) const { return getOperand(index + 1); } - -public: - void print(std::ostream &os) const override; -}; - //! Memset instruction class MemsetInst : public Instruction { friend class IRBuilder; + friend class Function; protected: MemsetInst(Value *pointer, Value *begin, Value *size, Value *value, @@ -957,80 +1205,15 @@ protected: addOperand(value); } -public: - static bool classof(const Value *value) { return value->getKind() == kMemset; } - public: Value *getPointer() const { return getOperand(0); } Value *getBegin() const { return getOperand(1); } Value *getSize() const { return getOperand(2); } Value *getValue() const { return getOperand(3); } -public: - void print(std::ostream &os) const override; -}; - -//! Get subarray instruction -class GetSubArrayInst : public Instruction { - friend class IRBuilder; - -protected: - GetSubArrayInst(Value *fatherArray, Value *childArray, - const std::vector &indices, - BasicBlock *parent = nullptr, const std::string &name = "") - : Instruction(kGetSubArray, Type::getVoidType(), parent, name) { - addOperand(fatherArray); - addOperand(childArray); - addOperands(indices); - } - -public: - static bool classof(const Value *value) { - return value->getKind() == kGetSubArray; - } - -public: - Value *getFatherArray() const { return getOperand(0); } - Value *getChildArray() const { return getOperand(1); } - int getNumIndices() const { return getNumOperands() - 2; } - auto getIndices() const { - return make_range(std::next(operand_begin(), 2), operand_end()); - } - Value *getIndex(int index) const { return getOperand(index + 2); } - -public: - void print(std::ostream &os) const override; -}; - -//! Phi instruction for SSA form -class PhiInst : public Instruction { - friend class IRBuilder; - -protected: - Value *map_val; // 旧的映射关系 - - PhiInst(Type *type, Value *lhs, const std::vector &rhs, - Value *mval, BasicBlock *parent, const std::string &name = "") - : Instruction(kPhi, type, parent, name), map_val(mval) { - addOperand(lhs); - addOperands(rhs); - } - -public: - static bool classof(const Value *value) { return value->getKind() == kPhi; } - -public: - Value *getMapVal() const { return map_val; } - Value *getPointer() const { return getOperand(0); } - auto getValues() const { - return make_range(std::next(operand_begin()), operand_end()); - } - Value *getValue(unsigned index) const { return getOperand(index + 1); } - -public: - void print(std::ostream &os) const override; }; +// 循环类 class Loop { public: using block_list = std::vector; @@ -1057,7 +1240,9 @@ protected: ConstantValue *indBegin = nullptr; // 循环起始值 ConstantValue *indStep = nullptr; // 循环步长 - + + std::set GlobalValuechange; // 循环内改变的全局变量 + int StepType = 0; // 循环步长类型 bool parallelable = false; // 是否可并行 @@ -1072,7 +1257,14 @@ public: loopCount = loopCount + 1; loopID = loopCount; } - + auto getindBegin() { return indBegin; } ///< 获得循环开始值 + auto getindStep() { return indStep; } ///< 获得循环步长 + auto setindBegin(ConstantValue *indBegin2set) { indBegin = indBegin2set; } ///< 设置循环开始值 + auto setindStep(ConstantValue *indStep2set) { indStep = indStep2set; } ///< 设置循环步长 + auto setStepType(int StepType2Set) { StepType = StepType2Set; } ///< 设置循环变量规则 + auto getStepType() { return StepType; } ///< 获得循环变量规则 + auto getLoopID() -> size_t { return loopID; } + BasicBlock *getHeader() const { return headerBlock; } BasicBlock *getPreheaderBlock() const { return preheaderBlock; } block_list &getLatchBlocks() { return latchBlock; } @@ -1087,9 +1279,9 @@ public: Loop_list &getSubLoops() { return subLoops; } unsigned getLoopDepth() const { return loopDepth; } - bool contains(BasicBlock *bb) const { + bool isLoopContainsBasicBlock(BasicBlock *bb) const { return std::find(blocksInLoop.begin(), blocksInLoop.end(), bb) != blocksInLoop.end(); - } + } ///< 判断输入块是否在该循环内 void addExitingBlock(BasicBlock *bb) { exitingBlocks.insert(bb); } void addExitBlock(BasicBlock *bb) { exitBlocks.insert(bb); } @@ -1100,12 +1292,21 @@ public: void setIcmpKind(Instruction::Kind kind) { IcmpKind = kind; } Instruction::Kind getIcmpKind() const { return IcmpKind; } + bool isSimpleLoopInvariant(Value *value) ; ///< 判断是否为简单循环不变量,若其在loop中,则不是。 + void setIndEnd(Value *value) { indEnd = value; } void setIndPhi(AllocaInst *phi) { IndPhi = phi; } Value *getIndEnd() const { return indEnd; } AllocaInst *getIndPhi() const { return IndPhi; } Instruction *getIndCondVar() const { return indCondVar; } + auto addGlobalValuechange(GlobalValue *globalvaluechange2add) { + GlobalValuechange.insert(globalvaluechange2add); + } ///<添加在循环中改变的全局变量 + auto getGlobalValuechange() -> std::set & { + return GlobalValuechange; + } ///<获得在循环中改变的所有全局变量 + void setParallelable(bool flag) { parallelable = flag; } bool isParallelable() const { return parallelable; } }; @@ -1116,204 +1317,409 @@ class Function : public Value { friend class Module; protected: - Function(Module *parent, Type *type, const std::string &name) - : Value(kFunction, type, name), parent(parent), variableID(0), blockID(0), blocks() { - blocks.emplace_back(new BasicBlock(this, "entry")); - } - -public: - static bool classof(const Value *value) { - return value->getKind() == kFunction; + Function(Module *parent, Type *type, const std::string &name) : Value(type, name), parent(parent) { + blocks.emplace_back(new BasicBlock(this)); } public: using block_list = std::list>; using Loop_list = std::list>; + // 函数优化属性标识符 + enum FunctionAttribute : uint64_t { + PlaceHolder = 0x0UL, + Pure = 0x1UL << 0, + SelfRecursive = 0x1UL << 1, + SideEffect = 0x1UL << 2, + NoPureCauseMemRead = 0x1UL << 3 + }; + protected: - Module *parent; - int variableID; - int blockID; - block_list blocks; - /*是放在module中还是新建分析器呢?*/ - Loop_list loops; // 循环列表 - Loop_list topLoops; // 顶层循环 - std::unordered_map basicblock2Loop; // 基本块到循环的映射 + Module *parent; ///< 函数的父模块 + block_list blocks; ///< 函数包含的基本块列表 + Loop_list loops; ///< 函数包含的循环列表 + Loop_list topLoops; ///< 函数所包含的顶层循环; + std::list> indirectAllocas; ///< 函数中mem2reg引入的间接分配的内存 - // 数据流分析相关 - std::unordered_map value2AllocBlocks; - std::unordered_map> value2DefBlocks; - std::unordered_map> value2UseBlocks; + FunctionAttribute attribute = PlaceHolder; ///< 函数属性 + std::set callees; ///< 函数调用的函数集合 -public: - Type *getReturnType() const { - return getType()->as()->getReturnType(); + std::unordered_map basicblock2Loop; + std::unordered_map value2AllocBlocks; ///< value -- alloc block mapping + std::unordered_map> + value2DefBlocks; //< value -- define blocks mapping + std::unordered_map> value2UseBlocks; //< value -- use blocks mapping + + public: + static auto getcloneIndex() -> unsigned { + static unsigned cloneIndex = 0; + cloneIndex += 1; + return cloneIndex - 1; } - auto getParamTypes() const { - return getType()->as()->getParamTypes(); - } - auto getBasicBlocks() const { return make_range(blocks); } - BasicBlock *getEntryBlock() const { return blocks.front().get(); } - BasicBlock *addBasicBlock(const std::string &name = "") { + auto clone(const std::string &suffix = "_" + std::to_string(getcloneIndex()) + "@") const + -> Function *; ///< 复制函数 + auto getCallees() -> const std::set & { return callees; } + auto addCallee(Function *callee) -> void { callees.insert(callee); } + auto removeCallee(Function *callee) -> void { callees.erase(callee); } + auto clearCallees() -> void { callees.clear(); } + auto getCalleesWithNoExternalAndSelf() -> std::set; + auto getAttribute() const -> FunctionAttribute { return attribute; } ///< 获取函数属性 + auto setAttribute(FunctionAttribute attr) -> void { + attribute = static_cast(attribute | attr); + } ///< 设置函数属性 + auto clearAttribute() -> void { attribute = PlaceHolder; } ///< 清楚所有函数属性,只保留PlaceHolder + auto getLoopOfBasicBlock(BasicBlock *bb) -> Loop * { + return basicblock2Loop.count(bb) != 0 ? basicblock2Loop[bb] : nullptr; + } ///< 获得块所在循环 + auto getLoopDepthByBlock(BasicBlock *basicblock2Check) { + if (getLoopOfBasicBlock(basicblock2Check) != nullptr) { + auto loop = getLoopOfBasicBlock(basicblock2Check); + return loop->getLoopDepth(); + } + return static_cast(0); + } ///< 通过块,获得其所在循环深度 + auto addBBToLoop(BasicBlock *bb, Loop *LoopToadd) { basicblock2Loop[bb] = LoopToadd; } ///< 添加块与循环的映射 + auto getBBToLoopRef() -> std::unordered_map & { + return basicblock2Loop; + } ///< 获得块-循环映射表 + // auto getNewLoopPtr(BasicBlock *header) -> Loop * { return new Loop(header); } + auto getReturnType() const -> Type * { return getType()->as()->getReturnType(); } ///< 获取返回值类型 + auto getParamTypes() const { return getType()->as()->getParamTypes(); } ///< 获取形式参数类型列表 + auto getBasicBlocks() { return make_range(blocks); } ///< 获取基本块列表 + auto getBasicBlocks_NoRange() -> block_list & { return blocks; } + auto getEntryBlock() -> BasicBlock * { return blocks.front().get(); } ///< 获取入口块 + auto removeBasicBlock(BasicBlock *blockToRemove) { + auto is_same_ptr = [blockToRemove](const std::unique_ptr &ptr) { return ptr.get() == blockToRemove; }; + blocks.remove_if(is_same_ptr); + // blocks.erase(std::remove_if(blocks.begin(), blocks.end(), is_same_ptr), blocks.end()); + } ///< 将该块从function的blocks中删除 + // auto getBasicBlocks_NoRange() -> block_list & { return blocks; } + auto addBasicBlock(const std::string &name = "") -> BasicBlock * { blocks.emplace_back(new BasicBlock(this, name)); return blocks.back().get(); - } - void removeBasicBlock(BasicBlock *block) { - blocks.remove_if([&](std::unique_ptr &b) -> bool { - return block == b.get(); - }); - } - int allocateVariableID() { return variableID++; } - int allocateblockID() { return blockID++; } - - // 循环分析 - void addLoop(Loop *loop) { loops.emplace_back(loop); } - void addTopLoop(Loop *loop) { topLoops.emplace_back(loop); } - Loop_list &getLoops() { return loops; } - Loop_list &getTopLoops() { return topLoops; } - Loop *getLoopOfBasicBlock(BasicBlock *bb) { - return basicblock2Loop.count(bb) ? basicblock2Loop[bb] : nullptr; - } - void addBBToLoop(BasicBlock *bb, Loop *loop) { basicblock2Loop[bb] = loop; } - - // 数据流分析 - void addValue2AllocBlocks(Value *value, BasicBlock *block) { + } ///< 添加新的基本块 + auto addBasicBlock(BasicBlock *block) -> BasicBlock * { + blocks.emplace_back(block); + return block; + } ///< 添加基本块到blocks中 + auto addBasicBlockFront(BasicBlock *block) -> BasicBlock * { + blocks.emplace_front(block); + return block; + } // 从前端插入新的基本块 + /** value -- alloc blocks mapping */ + auto addValue2AllocBlocks(Value *value, BasicBlock *block) -> void { value2AllocBlocks[value] = block; - } - BasicBlock *getAllocBlockByValue(Value *value) { - return value2AllocBlocks.count(value) ? value2AllocBlocks[value] : nullptr; - } - void addValue2DefBlocks(Value *value, BasicBlock *block) { + } ///< 添加value -- alloc block mapping + auto getAllocBlockByValue(Value *value) -> BasicBlock * { + if (value2AllocBlocks.count(value) > 0) { + return value2AllocBlocks[value]; + } + return nullptr; + } ///< 通过value获取alloc block + auto getValue2AllocBlocks() -> std::unordered_map & { + return value2AllocBlocks; + } ///< 获取所有value -- alloc block mappings + auto removeValue2AllocBlock(Value *value) -> void { + value2AllocBlocks.erase(value); + } ///< 删除value -- alloc block mapping + /** value -- define blocks mapping */ + auto addValue2DefBlocks(Value *value, BasicBlock *block) -> void { ++value2DefBlocks[value][block]; - } - void addValue2UseBlocks(Value *value, BasicBlock *block) { + } ///< 添加value -- define block mapping + // keep in mind that the return is not a reference. + auto getDefBlocksByValue(Value *value) -> std::unordered_set { + std::unordered_set blocks; + if (value2DefBlocks.count(value) > 0) { + for (const auto &pair : value2DefBlocks[value]) { + blocks.insert(pair.first); + } + } + return blocks; + } ///< 通过value获取define blocks + auto getValue2DefBlocks() -> std::unordered_map> & { + return value2DefBlocks; + } ///< 获取所有value -- define blocks mappings + auto removeValue2DefBlock(Value *value, BasicBlock *block) -> bool { + bool changed = false; + if (--value2DefBlocks[value][block] == 0) { + value2DefBlocks[value].erase(block); + if (value2DefBlocks[value].empty()) { + value2DefBlocks.erase(value); + changed = true; + } + } + return changed; + } ///< 删除value -- define block mapping + auto getValuesOfDefBlock() -> std::unordered_set { + std::unordered_set values; + for (const auto &pair : value2DefBlocks) { + values.insert(pair.first); + } + return values; + } ///< 获取所有定义过的value + /** value -- use blocks mapping */ + auto addValue2UseBlocks(Value *value, BasicBlock *block) -> void { ++value2UseBlocks[value][block]; - } + } ///< 添加value -- use block mapping + // keep in mind that the return is not a reference. + auto getUseBlocksByValue(Value *value) -> std::unordered_set { + std::unordered_set blocks; + if (value2UseBlocks.count(value) > 0) { + for (const auto &pair : value2UseBlocks[value]) { + blocks.insert(pair.first); + } + } + return blocks; + } ///< 通过value获取use blocks + auto getValue2UseBlocks() -> std::unordered_map> & { + return value2UseBlocks; + } ///< 获取所有value -- use blocks mappings + auto removeValue2UseBlock(Value *value, BasicBlock *block) -> bool { + bool changed = false; + if (--value2UseBlocks[value][block] == 0) { + value2UseBlocks[value].erase(block); + if (value2UseBlocks[value].empty()) { + value2UseBlocks.erase(value); + changed = true; + } + } + return changed; + } ///< 删除value -- use block mapping + auto addIndirectAlloca(AllocaInst *alloca) { indirectAllocas.emplace_back(alloca); } ///< 添加间接分配 + auto getIndirectAllocas() -> std::list> & { + return indirectAllocas; + } ///< 获取间接分配列表 + + /** loop -- begin */ + + void addLoop(Loop *loop) { loops.emplace_back(loop); } ///< 添加循环(非顶层) + void addTopLoop(Loop *loop) { topLoops.emplace_back(loop); } ///< 添加顶层循环 + auto getLoops() -> Loop_list & { return loops; } ///< 获得循环(非顶层) + auto getTopLoops() -> Loop_list & { return topLoops; } ///< 获得顶层循环 + /** loop -- end */ -public: - void print(std::ostream &os) const override; }; // class Function //! Global value declared at file scope -class GlobalValue : public User { +class GlobalValue : public User, public LVal { friend class Module; protected: - Module *parent; - std::vector initValues; // 初始值列表 - bool isConst; + Module *parent; ///< 父模块 + unsigned numDims; ///< 维度数量 + ValueCounter initValues; ///< 初值 protected: GlobalValue(Module *parent, Type *type, const std::string &name, const std::vector &dims = {}, - const std::vector &initValues = {}, - bool isConst = false) - : User(kGlobal, type, name), parent(parent), - initValues(initValues), isConst(isConst) { + ValueCounter init = {}) + : User(type, name), parent(parent) { assert(type->isPointer()); addOperands(dims); + numDims = dims.size(); + if (init.size() == 0) { + unsigned num = 1; + for (unsigned i = 0; i < numDims; i++) { + num *= dynamic_cast(dims[i])->getInt(); + } + if (dynamic_cast(type)->getBaseType() == Type::getFloatType()) { + init.push_back(ConstantValue::get(0.0F), num); + } else { + init.push_back(ConstantValue::get(0), num); + } + } + initValues = init; + } + +public: + unsigned getLValNumDims() const override { return numDims; } ///< 获取作为左值的维度数量 + std::vector getLValDims() const override { + std::vector dims; + for (const auto &dim : getOperands()) { + dims.emplace_back(dim->getValue()); + } + + return dims; + } ///< 获取作为左值的维度列表 + + unsigned getNumDims() const { return numDims; } ///< 获取维度数量 + Value* getDim(unsigned index) const { return getOperand(index); } ///< 获取位置为index的维度 + auto getDims() const { return getOperands(); } ///< 获取维度列表 + Value* getByIndex(unsigned index) const { + return initValues.getValue(index); + } ///< 通过一维偏移量index获取初始值 + Value * getByIndices(const std::vector &indices) const { + int index = 0; + for (size_t i = 0; i < indices.size(); i++) { + index = dynamic_cast(getDim(i))->getInt() * index + + dynamic_cast(indices[i])->getInt(); + } + return getByIndex(index); + } ///< 通过多维索引indices获取初始值 + const ValueCounter &getInitValues() const { return initValues; } +}; // class GlobalValue + + +class ConstantVariable : public User, public LVal { + friend class Module; + + protected: + Module *parent; ///< 父模块 + unsigned numDims; ///< 维度数量 + ValueCounter initValues; ///< 值 + + protected: + ConstantVariable(Module *parent, Type *type, const std::string &name, const ValueCounter &init, + const std::vector &dims = {}) + : User(type, name), parent(parent) { + assert(type->isPointer()); + numDims = dims.size(); + initValues = init; + addOperands(dims); } -public: - static bool classof(const Value *value) { - return value->getKind() == kGlobal; - } + public: + unsigned getLValNumDims() const override { return numDims; } ///< 获取作为左值的维度数量 + std::vector getLValDims() const override { + std::vector dims; + for (const auto &dim : getOperands()) { + dims.emplace_back(dim->getValue()); + } -public: - const std::vector& getInitValues() const { return initValues; } - int getNumDims() const { return getNumOperands(); } - Value *getDim(int index) { return getOperand(index); } - bool isConstant() const { return isConst; } + return dims; + } ///< 获取作为左值的维度列表 + Value* getByIndex(unsigned index) const { return initValues.getValue(index); } ///< 通过一维位置index获取值 + Value* getByIndices(const std::vector &indices) const { + int index = 0; + for (size_t i = 0; i < indices.size(); i++) { + index = dynamic_cast(getDim(i))->getInt() * index + + dynamic_cast(indices[i])->getInt(); + } -public: - void print(std::ostream &os) const override{}; -}; // class GlobalValue + return getByIndex(index); + } ///< 通过多维索引indices获取初始值 + unsigned getNumDims() const { return numDims; } ///< 获取维度数量 + Value* getDim(unsigned index) const { return getOperand(index); } ///< 获取位置为index的维度 + auto getDims() const { return getOperands(); } ///< 获取维度列表 + const ValueCounter& getInitValues() const { return initValues; } ///< 获取初始值 +}; + +using SymbolTableNode = struct SymbolTableNode { + struct SymbolTableNode *pNode; ///< 父节点 + std::vector children; ///< 子节点列表 + std::map varList; ///< 变量列表 +}; + + +class SymbolTable { + private: + SymbolTableNode *curNode{}; ///< 当前所在的作用域(符号表节点) + std::map variableIndex; ///< 变量命名索引表 + std::vector> globals; ///< 全局变量列表 + std::vector> consts; ///< 常量列表 + std::vector> nodeList; ///< 符号表节点列表 + + public: + SymbolTable() = default; + + auto getVariable(const std::string &name) const -> User *; ///< 根据名字name以及当前作用域获取变量 + auto addVariable(const std::string &name, User *variable) -> User *; ///< 添加变量 + auto getGlobals() -> std::vector> &; ///< 获取全局变量列表 + auto getConsts() const -> const std::vector> &; ///< 获取常量列表 + void enterNewScope(); ///< 进入新的作用域 + void leaveScope(); ///< 离开作用域 + auto isInGlobalScope() const -> bool; ///< 是否位于全局作用域 + void enterGlobalScope(); ///< 进入全局作用域 + auto isCurNodeNull() -> bool { return curNode == nullptr; } +}; //! IR unit for representing a SysY compile unit class Module { -protected: - std::vector> children; - std::map functions; - std::map globals; - std::map externalFunctions; // 外部函数声明 + protected: + std::map> externalFunctions; ///< 外部函数表 + std::map> functions; ///< 函数表 + SymbolTable variableTable; ///< 符号表 -public: + public: Module() = default; -public: - Function *createFunction(const std::string &name, Type *type) { - if (functions.count(name)) + public: + auto createFunction(const std::string &name, Type *type) -> Function * { + auto result = functions.try_emplace(name, new Function(this, type, name)); + if (!result.second) { return nullptr; - auto func = new Function(this, type, name); - assert(func); - children.emplace_back(func); - functions.emplace(name, func); - return func; - }; - - Function *createExternalFunction(const std::string &name, Type *type) { - if (externalFunctions.count(name)) + } + return result.first->second.get(); + } ///< 创建函数 + auto createExternalFunction(const std::string &name, Type *type) -> Function * { + auto result = externalFunctions.try_emplace(name, new Function(this, type, name)); + if (!result.second) { return nullptr; - auto func = new Function(this, type, name); - assert(func); - children.emplace_back(func); - externalFunctions.emplace(name, func); - return func; - } - - GlobalValue *createGlobalValue(const std::string &name, Type *type, - const std::vector &dims = {}, - const std::vector &initValues = {}, - bool isConst = false) { - if (globals.count(name)) + } + return result.first->second.get(); + } ///< 创建外部函数 + auto createGlobalValue(const std::string &name, Type *type, const std::vector &dims = {}, + const ValueCounter &init = {}) -> GlobalValue * { + bool isFinished = variableTable.isCurNodeNull(); + if (isFinished) { + variableTable.enterGlobalScope(); + } + auto result = variableTable.addVariable(name, new GlobalValue(this, type, name, dims, init)); + if (isFinished) { + variableTable.leaveScope(); + } + if (result == nullptr) { return nullptr; - auto global = new GlobalValue(this, type, name, dims, initValues, isConst); - assert(global); - children.emplace_back(global); - globals.emplace(name, global); - return global; - } - - Function *getFunction(const std::string &name) const { + } + return dynamic_cast(result); + } ///< 创建全局变量 + auto createConstVar(const std::string &name, Type *type, const ValueCounter &init, + const std::vector &dims = {}) -> ConstantVariable * { + auto result = variableTable.addVariable(name, new ConstantVariable(this, type, name, init, dims)); + if (result == nullptr) { + return nullptr; + } + return dynamic_cast(result); + } ///< 创建常量 + void addVariable(const std::string &name, AllocaInst *variable) { + variableTable.addVariable(name, variable); + } ///< 添加变量 + auto getVariable(const std::string &name) -> User * { + return variableTable.getVariable(name); + } ///< 根据名字name和当前作用域获取变量 + auto getFunction(const std::string &name) const -> Function * { auto result = functions.find(name); - if (result == functions.end()) + if (result == functions.end()) { return nullptr; - return result->second; - } - - Function *getExternalFunction(const std::string &name) const { + } + return result->second.get(); + } ///< 获取函数 + auto getExternalFunction(const std::string &name) const -> Function * { auto result = externalFunctions.find(name); - if (result == externalFunctions.end()) + if (result == functions.end()) { return nullptr; - return result->second; - } - - GlobalValue *getGlobalValue(const std::string &name) const { - auto result = globals.find(name); - if (result == globals.end()) - return nullptr; - return result->second; - } + } + return result->second.get(); + } ///< 获取外部函数 + auto getFunctions() -> std::map> & { return functions; } ///< 获取函数列表 + auto getExternalFunctions() const -> const std::map> & { + return externalFunctions; + } ///< 获取外部函数列表 + auto getGlobals() -> std::vector> & { + return variableTable.getGlobals(); + } ///< 获取全局变量列表 + auto getConsts() const -> const std::vector> & { + return variableTable.getConsts(); + } ///< 获取常量列表 + void enterNewScope() { variableTable.enterNewScope(); } ///< 进入新的作用域 - std::map *getFunctions() { return &functions; } - std::map *getGlobalValues() { return &globals; } - std::map *getExternalFunctions() { return &externalFunctions; } + void leaveScope() { variableTable.leaveScope(); } ///< 离开作用域 -public: - void print(std::ostream &os) const; -}; // class Module + auto isInGlobalArea() const -> bool { return variableTable.isInGlobalScope(); } ///< 是否位于全局作用域 +}; /*! * @} */ -inline std::ostream &operator<<(std::ostream &os, const Type &type) { - type.print(os); - return os; -} - -inline std::ostream &operator<<(std::ostream &os, const Value &value) { - value.print(os); - return os; -} } // namespace sysy diff --git a/src/include/IRBuilder.h b/src/include/IRBuilder.h index c6d12da..7335ea8 100644 --- a/src/include/IRBuilder.h +++ b/src/include/IRBuilder.h @@ -1,13 +1,25 @@ #pragma once -#include "IR.h" #include -#include +#include +#include +#include +#include +#include "IR.h" +/** + * @file IRBuilder.h + * + * @brief 定义IR构建器的头文件 + */ namespace sysy { +/** + * @brief 中间IR的构建器 + * + */ class IRBuilder { -private: + private: unsigned labelIndex; ///< 基本块标签编号 unsigned tmpIndex; ///< 临时变量编号 @@ -20,222 +32,325 @@ private: std::vector breakBlocks; ///< break目标块列表 std::vector continueBlocks; ///< continue目标块列表 -public: + public: IRBuilder() : labelIndex(0), tmpIndex(0), block(nullptr) {} explicit IRBuilder(BasicBlock *block) : labelIndex(0), tmpIndex(0), block(block), position(block->end()) {} IRBuilder(BasicBlock *block, BasicBlock::iterator position) : labelIndex(0), tmpIndex(0), block(block), position(position) {} -public: - BasicBlock *getBasicBlock() const { return block; } - BasicBlock::iterator getPosition() const { return position; } + public: + auto getLabelIndex() -> unsigned { + labelIndex += 1; + return labelIndex - 1; + } ///< 获取基本块标签编号 + auto getTmpIndex() -> unsigned { + tmpIndex += 1; + return tmpIndex - 1; + } ///< 获取临时变量编号 + auto getBasicBlock() const -> BasicBlock * { return block; } ///< 获取当前基本块 + auto getBreakBlock() const -> BasicBlock * { return breakBlocks.back(); } ///< 获取break目标块 + auto popBreakBlock() -> BasicBlock * { + auto result = breakBlocks.back(); + breakBlocks.pop_back(); + return result; + } ///< 弹出break目标块 + auto getContinueBlock() const -> BasicBlock * { return continueBlocks.back(); } ///< 获取continue目标块 + auto popContinueBlock() -> BasicBlock * { + auto result = continueBlocks.back(); + continueBlocks.pop_back(); + return result; + } ///< 弹出continue目标块 + + auto getTrueBlock() const -> BasicBlock * { return trueBlocks.back(); } ///< 获取true分支基本块 + auto getFalseBlock() const -> BasicBlock * { return falseBlocks.back(); } ///< 获取false分支基本块 + auto popTrueBlock() -> BasicBlock * { + auto result = trueBlocks.back(); + trueBlocks.pop_back(); + return result; + } ///< 弹出true分支基本块 + auto popFalseBlock() -> BasicBlock * { + auto result = falseBlocks.back(); + falseBlocks.pop_back(); + return result; + } ///< 弹出false分支基本块 + auto getPosition() const -> BasicBlock::iterator { return position; } ///< 获取当前基本块指令列表位置的迭代器 void setPosition(BasicBlock *block, BasicBlock::iterator position) { this->block = block; this->position = position; - } - void setPosition(BasicBlock::iterator position) { this->position = position; } + } ///< 设置基本块和基本块指令列表位置的迭代器 + void setPosition(BasicBlock::iterator position) { + this->position = position; + } ///< 设置当前基本块指令列表位置的迭代器 + void pushBreakBlock(BasicBlock *block) { breakBlocks.push_back(block); } ///< 压入break目标基本块 + void pushContinueBlock(BasicBlock *block) { continueBlocks.push_back(block); } ///< 压入continue目标基本块 + void pushTrueBlock(BasicBlock *block) { trueBlocks.push_back(block); } ///< 压入true分支基本块 + void pushFalseBlock(BasicBlock *block) { falseBlocks.push_back(block); } ///< 压入false分支基本块 -public: - CallInst *createCallInst(Function *callee, - const std::vector &args = {}, - const std::string &name = "") { - auto inst = new CallInst(callee, args, block, name); + public: + auto insertInst(Instruction *inst) -> Instruction * { assert(inst); block->getInstructions().emplace(position, inst); return inst; - } - UnaryInst *createUnaryInst(Instruction::Kind kind, Type *type, Value *operand, - const std::string &name = "") { + } ///< 插入指令 + auto createUnaryInst(Instruction::Kind kind, Type *type, Value *operand, const std::string &name = "") + -> UnaryInst * { + std::string newName; + if (name.empty()) { + std::stringstream ss; + ss << "%" << tmpIndex; + newName = ss.str(); + tmpIndex++; + } else { + newName = name; + } - auto inst = new UnaryInst(kind, type, operand, block, name); + auto inst = new UnaryInst(kind, type, operand, block, newName); assert(inst); block->getInstructions().emplace(position, inst); return inst; - } - UnaryInst *createNegInst(Value *operand, const std::string &name = "") { - return createUnaryInst(Instruction::kNeg, Type::getIntType(), operand, - name); - } - UnaryInst *createNotInst(Value *operand, const std::string &name = "") { - return createUnaryInst(Instruction::kNot, Type::getIntType(), operand, - name); - } - UnaryInst *createFtoIInst(Value *operand, const std::string &name = "") { - return createUnaryInst(Instruction::kFtoI, Type::getIntType(), operand, - name); - } - UnaryInst *createFNegInst(Value *operand, const std::string &name = "") { - return createUnaryInst(Instruction::kFNeg, Type::getFloatType(), operand, - name); - } - UnaryInst *createIToFInst(Value *operand, const std::string &name = "") { - return createUnaryInst(Instruction::kItoF, Type::getFloatType(), operand, - name); - } - BinaryInst *createBinaryInst(Instruction::Kind kind, Type *type, Value *lhs, - Value *rhs, const std::string &name = "") { - auto inst = new BinaryInst(kind, type, lhs, rhs, block, name); + } ///< 创建一元指令 + auto createNegInst(Value *operand, const std::string &name = "") -> UnaryInst * { + return createUnaryInst(Instruction::kNeg, Type::getIntType(), operand, name); + } ///< 创建取反指令 + auto createNotInst(Value *operand, const std::string &name = "") -> UnaryInst * { + return createUnaryInst(Instruction::kNot, Type::getIntType(), operand, name); + } ///< 创建取非指令 + auto createFtoIInst(Value *operand, const std::string &name = "") -> UnaryInst * { + return createUnaryInst(Instruction::kFtoI, Type::getIntType(), operand, name); + } ///< 创建浮点转整型指令 + auto createBitFtoIInst(Value *operand, const std::string &name = "") -> UnaryInst * { + return createUnaryInst(Instruction::kBitFtoI, Type::getIntType(), operand, name); + } ///< 创建按位浮点转整型指令 + auto createFNegInst(Value *operand, const std::string &name = "") -> UnaryInst * { + return createUnaryInst(Instruction::kFNeg, Type::getFloatType(), operand, name); + } ///< 创建浮点取反指令 + auto createFNotInst(Value *operand, const std::string &name = "") -> UnaryInst * { + return createUnaryInst(Instruction::kFNot, Type::getIntType(), operand, name); + } ///< 创建浮点取非指令 + auto createIToFInst(Value *operand, const std::string &name = "") -> UnaryInst * { + return createUnaryInst(Instruction::kItoF, Type::getFloatType(), operand, name); + } ///< 创建整型转浮点指令 + auto createBitItoFInst(Value *operand, const std::string &name = "") -> UnaryInst * { + return createUnaryInst(Instruction::kBitItoF, Type::getFloatType(), operand, name); + } ///< 创建按位整型转浮点指令 + auto createBinaryInst(Instruction::Kind kind, Type *type, Value *lhs, Value *rhs, const std::string &name = "") + -> BinaryInst * { + std::string newName; + if (name.empty()) { + std::stringstream ss; + ss << "%" << tmpIndex; + newName = ss.str(); + tmpIndex++; + } else { + newName = name; + } + + auto inst = new BinaryInst(kind, type, lhs, rhs, block, newName); assert(inst); block->getInstructions().emplace(position, inst); return inst; - } - BinaryInst *createAddInst(Value *lhs, Value *rhs, - const std::string &name = "") { - return createBinaryInst(Instruction::kAdd, Type::getIntType(), lhs, rhs, - name); - } - BinaryInst *createSubInst(Value *lhs, Value *rhs, - const std::string &name = "") { - return createBinaryInst(Instruction::kSub, Type::getIntType(), lhs, rhs, - name); - } - BinaryInst *createMulInst(Value *lhs, Value *rhs, - const std::string &name = "") { - return createBinaryInst(Instruction::kMul, Type::getIntType(), lhs, rhs, - name); - } - BinaryInst *createDivInst(Value *lhs, Value *rhs, - const std::string &name = "") { - return createBinaryInst(Instruction::kDiv, Type::getIntType(), lhs, rhs, - name); - } - BinaryInst *createRemInst(Value *lhs, Value *rhs, - const std::string &name = "") { - return createBinaryInst(Instruction::kRem, Type::getIntType(), lhs, rhs, - name); - } - BinaryInst *createICmpEQInst(Value *lhs, Value *rhs, - const std::string &name = "") { - return createBinaryInst(Instruction::kICmpEQ, Type::getIntType(), lhs, rhs, - name); - } - BinaryInst *createICmpNEInst(Value *lhs, Value *rhs, - const std::string &name = "") { - return createBinaryInst(Instruction::kICmpNE, Type::getIntType(), lhs, rhs, - name); - } - BinaryInst *createICmpLTInst(Value *lhs, Value *rhs, - const std::string &name = "") { - return createBinaryInst(Instruction::kICmpLT, Type::getIntType(), lhs, rhs, - name); - } - BinaryInst *createICmpLEInst(Value *lhs, Value *rhs, - const std::string &name = "") { - return createBinaryInst(Instruction::kICmpLE, Type::getIntType(), lhs, rhs, - name); - } - BinaryInst *createICmpGTInst(Value *lhs, Value *rhs, - const std::string &name = "") { - return createBinaryInst(Instruction::kICmpGT, Type::getIntType(), lhs, rhs, - name); - } - BinaryInst *createICmpGEInst(Value *lhs, Value *rhs, - const std::string &name = "") { - return createBinaryInst(Instruction::kICmpGE, Type::getIntType(), lhs, rhs, - name); - } - BinaryInst *createFAddInst(Value *lhs, Value *rhs, - const std::string &name = "") { - return createBinaryInst(Instruction::kFAdd, Type::getFloatType(), lhs, rhs, - name); - } - BinaryInst *createFSubInst(Value *lhs, Value *rhs, - const std::string &name = "") { - return createBinaryInst(Instruction::kFSub, Type::getFloatType(), lhs, rhs, - name); - } - BinaryInst *createFMulInst(Value *lhs, Value *rhs, - const std::string &name = "") { - return createBinaryInst(Instruction::kFMul, Type::getFloatType(), lhs, rhs, - name); - } - BinaryInst *createFDivInst(Value *lhs, Value *rhs, - const std::string &name = "") { - return createBinaryInst(Instruction::kFDiv, Type::getFloatType(), lhs, rhs, - name); - } - BinaryInst *createFRemInst(Value *lhs, Value *rhs, - const std::string &name = "") { - return createBinaryInst(Instruction::kFRem, Type::getFloatType(), lhs, rhs, - name); - } - BinaryInst *createFCmpEQInst(Value *lhs, Value *rhs, - const std::string &name = "") { - return createBinaryInst(Instruction::kFCmpEQ, Type::getFloatType(), lhs, - rhs, name); - } - BinaryInst *createFCmpNEInst(Value *lhs, Value *rhs, - const std::string &name = "") { - return createBinaryInst(Instruction::kFCmpNE, Type::getFloatType(), lhs, - rhs, name); - } - BinaryInst *createFCmpLTInst(Value *lhs, Value *rhs, - const std::string &name = "") { - return createBinaryInst(Instruction::kFCmpLT, Type::getFloatType(), lhs, - rhs, name); - } - BinaryInst *createFCmpLEInst(Value *lhs, Value *rhs, - const std::string &name = "") { - return createBinaryInst(Instruction::kFCmpLE, Type::getFloatType(), lhs, - rhs, name); - } - BinaryInst *createFCmpGTInst(Value *lhs, Value *rhs, - const std::string &name = "") { - return createBinaryInst(Instruction::kFCmpGT, Type::getFloatType(), lhs, - rhs, name); - } - BinaryInst *createFCmpGEInst(Value *lhs, Value *rhs, - const std::string &name = "") { - return createBinaryInst(Instruction::kFCmpGE, Type::getFloatType(), lhs, - rhs, name); - } - ReturnInst *createReturnInst(Value *value = nullptr) { - auto inst = new ReturnInst(value); + } ///< 创建二元指令 + auto createAddInst(Value *lhs, Value *rhs, const std::string &name = "") -> BinaryInst * { + return createBinaryInst(Instruction::kAdd, Type::getIntType(), lhs, rhs, name); + } ///< 创建加法指令 + auto createSubInst(Value *lhs, Value *rhs, const std::string &name = "") -> BinaryInst * { + return createBinaryInst(Instruction::kSub, Type::getIntType(), lhs, rhs, name); + } ///< 创建减法指令 + auto createMulInst(Value *lhs, Value *rhs, const std::string &name = "") -> BinaryInst * { + return createBinaryInst(Instruction::kMul, Type::getIntType(), lhs, rhs, name); + } ///< 创建乘法指令 + auto createDivInst(Value *lhs, Value *rhs, const std::string &name = "") -> BinaryInst * { + return createBinaryInst(Instruction::kDiv, Type::getIntType(), lhs, rhs, name); + } ///< 创建除法指令 + auto createRemInst(Value *lhs, Value *rhs, const std::string &name = "") -> BinaryInst * { + return createBinaryInst(Instruction::kRem, Type::getIntType(), lhs, rhs, name); + } ///< 创建取余指令 + auto createICmpEQInst(Value *lhs, Value *rhs, const std::string &name = "") -> BinaryInst * { + return createBinaryInst(Instruction::kICmpEQ, Type::getIntType(), lhs, rhs, name); + } ///< 创建相等设置指令 + auto createICmpNEInst(Value *lhs, Value *rhs, const std::string &name = "") -> BinaryInst * { + return createBinaryInst(Instruction::kICmpNE, Type::getIntType(), lhs, rhs, name); + } ///< 创建不相等设置指令 + auto createICmpLTInst(Value *lhs, Value *rhs, const std::string &name = "") -> BinaryInst * { + return createBinaryInst(Instruction::kICmpLT, Type::getIntType(), lhs, rhs, name); + } ///< 创建小于设置指令 + auto createICmpLEInst(Value *lhs, Value *rhs, const std::string &name = "") -> BinaryInst * { + return createBinaryInst(Instruction::kICmpLE, Type::getIntType(), lhs, rhs, name); + } ///< 创建小于等于设置指令 + auto createICmpGTInst(Value *lhs, Value *rhs, const std::string &name = "") -> BinaryInst * { + return createBinaryInst(Instruction::kICmpGT, Type::getIntType(), lhs, rhs, name); + } ///< 创建大于设置指令 + auto createICmpGEInst(Value *lhs, Value *rhs, const std::string &name = "") -> BinaryInst * { + return createBinaryInst(Instruction::kICmpGE, Type::getIntType(), lhs, rhs, name); + } ///< 创建大于等于设置指令 + auto createFAddInst(Value *lhs, Value *rhs, const std::string &name = "") -> BinaryInst * { + return createBinaryInst(Instruction::kFAdd, Type::getFloatType(), lhs, rhs, name); + } ///< 创建浮点加法指令 + auto createFSubInst(Value *lhs, Value *rhs, const std::string &name = "") -> BinaryInst * { + return createBinaryInst(Instruction::kFSub, Type::getFloatType(), lhs, rhs, name); + } ///< 创建浮点减法指令 + auto createFMulInst(Value *lhs, Value *rhs, const std::string &name = "") -> BinaryInst * { + return createBinaryInst(Instruction::kFMul, Type::getFloatType(), lhs, rhs, name); + } ///< 创建浮点乘法指令 + auto createFDivInst(Value *lhs, Value *rhs, const std::string &name = "") -> BinaryInst * { + return createBinaryInst(Instruction::kFDiv, Type::getFloatType(), lhs, rhs, name); + } ///< 创建浮点除法指令 + auto createFCmpEQInst(Value *lhs, Value *rhs, const std::string &name = "") -> BinaryInst * { + return createBinaryInst(Instruction::kFCmpEQ, Type::getIntType(), lhs, rhs, name); + } ///< 创建浮点相等设置指令 + auto createFCmpNEInst(Value *lhs, Value *rhs, const std::string &name = "") -> BinaryInst * { + return createBinaryInst(Instruction::kFCmpNE, Type::getIntType(), lhs, rhs, name); + } ///< 创建浮点不相等设置指令 + auto createFCmpLTInst(Value *lhs, Value *rhs, const std::string &name = "") -> BinaryInst * { + return createBinaryInst(Instruction::kFCmpLT, Type::getIntType(), lhs, rhs, name); + } ///< 创建浮点小于设置指令 + auto createFCmpLEInst(Value *lhs, Value *rhs, const std::string &name = "") -> BinaryInst * { + return createBinaryInst(Instruction::kFCmpLE, Type::getIntType(), lhs, rhs, name); + } ///< 创建浮点小于等于设置指令 + auto createFCmpGTInst(Value *lhs, Value *rhs, const std::string &name = "") -> BinaryInst * { + return createBinaryInst(Instruction::kFCmpGT, Type::getIntType(), lhs, rhs, name); + } ///< 创建浮点大于设置指令 + auto createFCmpGEInst(Value *lhs, Value *rhs, const std::string &name = "") -> BinaryInst * { + return createBinaryInst(Instruction::kFCmpGE, Type::getIntType(), lhs, rhs, name); + } ///< 创建浮点相大于等于设置指令 + auto createAndInst(Value *lhs, Value *rhs, const std::string &name = "") -> BinaryInst * { + return createBinaryInst(Instruction::kAnd, Type::getIntType(), lhs, rhs, name); + } ///< 创建按位且指令 + auto createOrInst(Value *lhs, Value *rhs, const std::string &name = "") -> BinaryInst * { + return createBinaryInst(Instruction::kOr, Type::getIntType(), lhs, rhs, name); + } ///< 创建按位或指令 + auto createCallInst(Function *callee, const std::vector &args, const std::string &name = "") -> CallInst * { + std::string newName; + if (name.empty() && callee->getReturnType() != Type::getVoidType()) { + std::stringstream ss; + ss << "%" << tmpIndex; + newName = ss.str(); + tmpIndex++; + } else { + newName = name; + } + + auto inst = new CallInst(callee, args, block, newName); assert(inst); block->getInstructions().emplace(position, inst); return inst; - } - UncondBrInst *createUncondBrInst(BasicBlock *block, - std::vector args) { - auto inst = new UncondBrInst(block, args, block); + } ///< 创建Call指令 + auto createReturnInst(Value *value = nullptr, const std::string &name = "") -> ReturnInst * { + auto inst = new ReturnInst(value, block, name); assert(inst); block->getInstructions().emplace(position, inst); return inst; - } - CondBrInst *createCondBrInst(Value *condition, BasicBlock *thenBlock, - BasicBlock *elseBlock, - const std::vector &thenArgs, - const std::vector &elseArgs) { - auto inst = new CondBrInst(condition, thenBlock, elseBlock, thenArgs, - elseArgs, block); + } ///< 创建return指令 + auto createUncondBrInst(BasicBlock *thenBlock, const std::vector &args) -> UncondBrInst * { + auto inst = new UncondBrInst(thenBlock, args, block); assert(inst); block->getInstructions().emplace(position, inst); return inst; - } - AllocaInst *createAllocaInst(Type *type, - const std::vector &dims = {}, - const std::string &name = "") { + } ///< 创建无条件指令 + auto createCondBrInst(Value *condition, BasicBlock *thenBlock, BasicBlock *elseBlock, + const std::vector &thenArgs, const std::vector &elseArgs) -> CondBrInst * { + auto inst = new CondBrInst(condition, thenBlock, elseBlock, thenArgs, elseArgs, block); + assert(inst); + block->getInstructions().emplace(position, inst); + return inst; + } ///< 创建条件跳转指令 + auto createAllocaInst(Type *type, const std::vector &dims = {}, const std::string &name = "") + -> AllocaInst * { auto inst = new AllocaInst(type, dims, block, name); assert(inst); block->getInstructions().emplace(position, inst); return inst; - } - LoadInst *createLoadInst(Value *pointer, - const std::vector &indices = {}, - const std::string &name = "") { - auto inst = new LoadInst(pointer, indices, block, name); + } ///< 创建分配指令 + auto createAllocaInstWithoutInsert(Type *type, const std::vector &dims = {}, BasicBlock *parent = nullptr, + const std::string &name = "") -> AllocaInst * { + auto inst = new AllocaInst(type, dims, parent, name); + assert(inst); + return inst; + } ///< 创建不插入指令列表的分配指令 + auto createLoadInst(Value *pointer, const std::vector &indices = {}, const std::string &name = "") + -> LoadInst * { + std::string newName; + if (name.empty()) { + std::stringstream ss; + ss << "%" << tmpIndex; + newName = ss.str(); + tmpIndex++; + } else { + newName = name; + } + + auto inst = new LoadInst(pointer, indices, block, newName); assert(inst); block->getInstructions().emplace(position, inst); return inst; - } - StoreInst *createStoreInst(Value *value, Value *pointer, - const std::vector &indices = {}, - const std::string &name = "") { + } ///< 创建load指令 + auto createLaInst(Value *pointer, const std::vector &indices = {}, const std::string &name = "") + -> LaInst * { + std::string newName; + if (name.empty()) { + std::stringstream ss; + ss << "%" << tmpIndex; + newName = ss.str(); + tmpIndex++; + } else { + newName = name; + } + + auto inst = new LaInst(pointer, indices, block, newName); + assert(inst); + block->getInstructions().emplace(position, inst); + return inst; + } ///< 创建la指令 + auto createGetSubArray(LVal *fatherArray, const std::vector &indices, const std::string &name = "") + -> GetSubArrayInst * { + assert(fatherArray->getLValNumDims() > indices.size()); + std::vector subDims; + auto dims = fatherArray->getLValDims(); + auto iter = std::next(dims.begin(), indices.size()); + while (iter != dims.end()) { + subDims.emplace_back(*iter); + iter++; + } + + std::string childArrayName; + std::stringstream ss; + ss << "A" + << "%" << tmpIndex; + childArrayName = ss.str(); + tmpIndex++; + + auto fatherArrayValue = dynamic_cast(fatherArray); + auto childArray = new AllocaInst(fatherArrayValue->getType(), subDims, block, childArrayName); + auto inst = new GetSubArrayInst(fatherArray, childArray, indices, block, name); + assert(inst); + block->getInstructions().emplace(position, inst); + return inst; + } ///< 创建获取部分数组指令 + auto createMemsetInst(Value *pointer, Value *begin, Value *size, Value *value, const std::string &name = "") + -> MemsetInst * { + auto inst = new MemsetInst(pointer, begin, size, value, block, name); + assert(inst); + block->getInstructions().emplace(position, inst); + return inst; + } ///< 创建memset指令 + auto createStoreInst(Value *value, Value *pointer, const std::vector &indices = {}, + const std::string &name = "") -> StoreInst * { auto inst = new StoreInst(value, pointer, indices, block, name); assert(inst); block->getInstructions().emplace(position, inst); return inst; - } + } ///< 创建store指令 + auto createPhiInst(Type *type, Value *lhs, BasicBlock *parent, const std::string &name = "") -> PhiInst * { + auto predNum = parent->getNumPredecessors(); + std::vector rhs; + for (size_t i = 0; i < predNum; i++) { + rhs.push_back(lhs); + } + auto inst = new PhiInst(type, lhs, rhs, lhs, parent, name); + assert(inst); + parent->getInstructions().emplace(parent->begin(), inst); + return inst; + } ///< 创建Phi指令 }; -} // namespace sysy \ No newline at end of file +} // namespace sysy From c1583e447dbd04bdfce0d1381c070ef45a44984d Mon Sep 17 00:00:00 2001 From: rain2133 <1370973498@qq.com> Date: Sat, 21 Jun 2025 13:44:51 +0800 Subject: [PATCH 05/13] =?UTF-8?q?=E6=9B=B4=E6=94=B9g4=E6=96=87=E4=BB=B6?= =?UTF-8?q?=EF=BC=8C=E4=BC=98=E5=8C=96IR=E7=94=9F=E6=88=90=E6=B5=81?= =?UTF-8?q?=E7=A8=8B?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/SysY.g4 | 13 +-- src/include/SysYIRGenerator.h | 158 +++++++++++++++++----------------- 2 files changed, 85 insertions(+), 86 deletions(-) diff --git a/src/SysY.g4 b/src/SysY.g4 index b3ed583..a9e4208 100644 --- a/src/SysY.g4 +++ b/src/SysY.g4 @@ -101,7 +101,10 @@ BLOCKCOMMENT: '/*' .*? '*/' -> skip; // CompUnit: (CompUnit)? (decl |funcDef); -compUnit: (decl |funcDef)+; +compUnit: (globalDecl |funcDef)+; + +globalDecl: constDecl # globalConstDecl + | varDecl # globalVarDecl; decl: constDecl | varDecl; @@ -111,16 +114,16 @@ bType: INT | FLOAT; constDef: Ident (LBRACK constExp RBRACK)* ASSIGN constInitVal; -constInitVal: constExp - | LBRACE (constInitVal (COMMA constInitVal)*)? RBRACE; +constInitVal: constExp # constScalarInitValue + | LBRACE (constInitVal (COMMA constInitVal)*)? RBRACE # constArrayInitValue; varDecl: bType varDef (COMMA varDef)* SEMICOLON; varDef: Ident (LBRACK constExp RBRACK)* | Ident (LBRACK constExp RBRACK)* ASSIGN initVal; -initVal: exp - | LBRACE (initVal (COMMA initVal)*)? RBRACE; +initVal: exp # scalarInitValue + | LBRACE (initVal (COMMA initVal)*)? RBRACE # arrayInitValue; funcType: VOID | INT | FLOAT; diff --git a/src/include/SysYIRGenerator.h b/src/include/SysYIRGenerator.h index 638986a..b4ead27 100644 --- a/src/include/SysYIRGenerator.h +++ b/src/include/SysYIRGenerator.h @@ -9,130 +9,126 @@ namespace sysy { -class SymbolTable{ -private: - enum Kind - { - kModule, - kFunction, - kBlock, - }; - - std::forward_list>> Scopes; - -public: - struct ModuleScope { - SymbolTable& tables_ref; - ModuleScope(SymbolTable& tables) : tables_ref(tables) { - tables.enter(kModule); - } - ~ModuleScope() { tables_ref.exit(); } - }; - struct FunctionScope { - SymbolTable& tables_ref; - FunctionScope(SymbolTable& tables) : tables_ref(tables) { - tables.enter(kFunction); - } - ~FunctionScope() { tables_ref.exit(); } - }; - struct BlockScope { - SymbolTable& tables_ref; - BlockScope(SymbolTable& tables) : tables_ref(tables) { - tables.enter(kBlock); - } - ~BlockScope() { tables_ref.exit(); } - }; - - SymbolTable() = default; - - bool isModuleScope() const { return Scopes.front().first == kModule; } - bool isFunctionScope() const { return Scopes.front().first == kFunction; } - bool isBlockScope() const { return Scopes.front().first == kBlock; } - Value *lookup(const std::string &name) const { - for (auto &scope : Scopes) { - auto iter = scope.second.find(name); - if (iter != scope.second.end()) - return iter->second; - } - return nullptr; - } - auto insert(const std::string &name, Value *value) { - assert(not Scopes.empty()); - return Scopes.front().second.emplace(name, value); - } -private: - void enter(Kind kind) { - Scopes.emplace_front(); - Scopes.front().first = kind; - } - void exit() { - Scopes.pop_front(); - } - -}; class SysYIRGenerator : public SysYBaseVisitor { + +// @brief 用于存储数组值的树结构 +// 多位数组本质上是一维数组的嵌套可以用树来表示。 +class ArrayValueTree { +private: + Value *value = nullptr; /// 该节点存储的value + std::vector> children; /// 子节点列表 + +public: + ArrayValueTree() = default; + +public: + auto getValue() const -> Value * { return value; } + auto getChildren() const + -> const std::vector> & { + return children; + } + + void setValue(Value *newValue) { value = newValue; } + void addChild(ArrayValueTree *newChild) { children.emplace_back(newChild); } + void addChildren(const std::vector &newChildren) { + for (const auto &child : newChildren) { + children.emplace_back(child); + } + } +}; + + +class Utils { +public: + // transform a tree of ArrayValueTree to a ValueCounter + static void tree2Array(Type *type, ArrayValueTree *root, + const std::vector &dims, unsigned numDims, + ValueCounter &result, IRBuilder *builder); + static void + createExternalFunction(const std::vector ¶mTypes, + const std::vector ¶mNames, + const std::vector> ¶mDims, + Type *returnType, const std::string &funcName, + Module *pModule, IRBuilder *pBuilder); + + static void initExternalFunction(Module *pModule, IRBuilder *pBuilder); +}; + private: std::unique_ptr module; IRBuilder builder; - SymbolTable symbols_table; - - int trueBlockNum = 0, falseBlockNum = 0; - - int d = 0, n = 0; - vector path; - bool isalloca; - AllocaInst *current_alloca; - GlobalValue *current_global; - Type* current_type; - int numdims = 0; public: SysYIRGenerator() = default; public: Module *get() const { return module.get(); } - + IRBuilder *getBuilder(){ return &builder; } public: std::any visitCompUnit(SysYParser::CompUnitContext *ctx) override; - std::any visitDecl(SysYParser::DeclContext *ctx) override; + + std::any visitGlobalConstDecl(SysYParser::GlobalConstDeclContext *ctx) override; + std::any visitGlobalVarDecl(SysYParser::GlobalVarDeclContext *ctx) override; + + std::any visitDecl(SysYParser::DeclContext *ctx) override ; std::any visitConstDecl(SysYParser::ConstDeclContext *ctx) override; + std::any visitVarDecl(SysYParser::VarDeclContext *ctx) override; + std::any visitBType(SysYParser::BTypeContext *ctx) override; - std::any visitConstDef(SysYParser::ConstDefContext *ctx) override; - std::any visitConstInitVal(SysYParser::ConstInitValContext *ctx) override; + + // std::any visitConstDef(SysYParser::ConstDefContext *ctx) override; + // std::any visitVarDef(SysYParser::VarDefContext *ctx) override; + + std::any visitScalarInitValue(SysYParser::ScalarInitValueContext *ctx) override; + std::any visitArrayInitValue(SysYParser::ArrayInitValueContext *ctx) override; + + std::any visitConstScalarInitValue(SysYParser::ConstScalarInitValueContext *ctx) override; + std::any visitConstArrayInitValue(SysYParser::ConstArrayInitValueContext *ctx) override; + + // std::any visitConstInitVal(SysYParser::ConstInitValContext *ctx) override; std::any visitFuncType(SysYParser::FuncTypeContext* ctx) override; std::any visitFuncDef(SysYParser::FuncDefContext* ctx) override; - std::any visitVarDecl(SysYParser::VarDeclContext *ctx) override; - std::any visitVarDef(SysYParser::VarDefContext *ctx) override; - std::any visitInitVal(SysYParser::InitValContext *ctx) override; + // std::any visitInitVal(SysYParser::InitValContext *ctx) override; // std::any visitFuncFParam(SysYParser::FuncFParamContext *ctx) override; // std::any visitFuncFParams(SysYParser::FuncFParamsContext *ctx) override; + std::any visitBlockStmt(SysYParser::BlockStmtContext* ctx) override; // std::any visitStmt(SysYParser::StmtContext *ctx) override; std::any visitAssignStmt(SysYParser::AssignStmtContext *ctx) override; + // std::any visitExpStmt(SysYParser::ExpStmtContext *ctx) override; + // std::any visitBlkStmt(SysYParser::BlkStmtContext *ctx) override; std::any visitIfStmt(SysYParser::IfStmtContext *ctx) override; std::any visitWhileStmt(SysYParser::WhileStmtContext *ctx) override; std::any visitBreakStmt(SysYParser::BreakStmtContext *ctx) override; std::any visitContinueStmt(SysYParser::ContinueStmtContext *ctx) override; std::any visitReturnStmt(SysYParser::ReturnStmtContext *ctx) override; + // std::any visitExp(SysYParser::ExpContext *ctx) override; + // std::any visitCond(SysYParser::CondContext *ctx) override; + std::any visitLValue(SysYParser::LValueContext *ctx) override; std::any visitPrimExp(SysYParser::PrimExpContext *ctx) override; + // std::any visitParenExp(SysYParser::ParenExpContext *ctx) override; std::any visitNumber(SysYParser::NumberContext *ctx) override; // std::any visitString(SysYParser::StringContext *ctx) override; + std::any visitCall(SysYParser::CallContext *ctx) override; + // std::any visitUnaryExp(SysYParser::UnaryExpContext *ctx) override; // std::any visitUnaryOp(SysYParser::UnaryOpContext *ctx) override; + std::any visitUnExp(SysYParser::UnExpContext *ctx) override; - // std::any visitFuncRParams(SysYParser::FuncRParamsContext *ctx) override; + std::any visitFuncRParams(SysYParser::FuncRParamsContext *ctx) override; std::any visitMulExp(SysYParser::MulExpContext *ctx) override; std::any visitAddExp(SysYParser::AddExpContext *ctx) override; std::any visitRelExp(SysYParser::RelExpContext *ctx) override; std::any visitEqExp(SysYParser::EqExpContext *ctx) override; std::any visitLAndExp(SysYParser::LAndExpContext *ctx) override; std::any visitLOrExp(SysYParser::LOrExpContext *ctx) override; - std::any visitConstExp(SysYParser::ConstExpContext *ctx) override; + + // std::any visitConstExp(SysYParser::ConstExpContext *ctx) override; }; // class SysYIRGenerator From 2b038e671b3beaeaad833b7bbe2d5b59d8d48df5 Mon Sep 17 00:00:00 2001 From: rain2133 <1370973498@qq.com> Date: Sat, 21 Jun 2025 14:33:22 +0800 Subject: [PATCH 06/13] =?UTF-8?q?=E4=BF=AE=E5=A4=8Dbug?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/include/IR.h | 27 ++++++--------------------- src/include/SysYIRGenerator.h | 2 +- 2 files changed, 7 insertions(+), 22 deletions(-) diff --git a/src/include/IR.h b/src/include/IR.h index f1759cd..60db35a 100644 --- a/src/include/IR.h +++ b/src/include/IR.h @@ -188,24 +188,6 @@ class Use { void setValue(Value *newValue) { value = newValue; } ///< 将被使用的值设置为newValue }; -template -inline std::enable_if_t, bool> -isa(const Value *value) { - return T::classof(value); -} - -template -inline std::enable_if_t, T *> -dyncast(Value *value) { - return isa(value) ? static_cast(value) : nullptr; -} - -template -inline std::enable_if_t, const T *> -dyncast(const Value *value) { - return isa(value) ? static_cast(value) : nullptr; -} - //! The base class of all value types class Value { @@ -1001,7 +983,7 @@ protected: } public: - BasicBlock *getBlock() const { return dyncast(getOperand(0)); } + BasicBlock *getBlock() const { return dynamic_cast(getOperand(0)); } auto getArguments() const { return make_range(std::next(operand_begin()), operand_end()); } @@ -1029,10 +1011,10 @@ protected: public: Value *getCondition() const { return getOperand(0); } BasicBlock *getThenBlock() const { - return dyncast(getOperand(1)); + return dynamic_cast(getOperand(1)); } BasicBlock *getElseBlock() const { - return dyncast(getOperand(2)); + return dynamic_cast(getOperand(2)); } auto getThenArguments() const { auto begin = std::next(operand_begin(), 3); @@ -1213,6 +1195,8 @@ public: }; +class GlobalValue; + // 循环类 class Loop { public: @@ -1658,6 +1642,7 @@ class Module { } return result.first->second.get(); } ///< 创建外部函数 + ///< 变量创建伴随着符号表的更新 auto createGlobalValue(const std::string &name, Type *type, const std::vector &dims = {}, const ValueCounter &init = {}) -> GlobalValue * { bool isFinished = variableTable.isCurNodeNull(); diff --git a/src/include/SysYIRGenerator.h b/src/include/SysYIRGenerator.h index b4ead27..a5f5a91 100644 --- a/src/include/SysYIRGenerator.h +++ b/src/include/SysYIRGenerator.h @@ -71,7 +71,7 @@ public: std::any visitGlobalConstDecl(SysYParser::GlobalConstDeclContext *ctx) override; std::any visitGlobalVarDecl(SysYParser::GlobalVarDeclContext *ctx) override; - std::any visitDecl(SysYParser::DeclContext *ctx) override ; + // std::any visitDecl(SysYParser::DeclContext *ctx) override; std::any visitConstDecl(SysYParser::ConstDeclContext *ctx) override; std::any visitVarDecl(SysYParser::VarDeclContext *ctx) override; From 8109d4423228e7ec85d38e02f5b7a318cdbcd389 Mon Sep 17 00:00:00 2001 From: rain2133 <1370973498@qq.com> Date: Sat, 21 Jun 2025 14:33:56 +0800 Subject: [PATCH 07/13] =?UTF-8?q?=E5=B7=A5=E5=85=B7=E7=B1=BB=E6=96=B9?= =?UTF-8?q?=E6=B3=95=E9=83=A8=E5=88=86=E5=AE=9E=E7=8E=B0=EF=BC=8C=E5=AE=9E?= =?UTF-8?q?=E7=8E=B0=E9=83=A8=E5=88=86IR=E7=94=9F=E6=88=90?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/SysYIRGenerator.cpp | 894 +++++----------------------------------- 1 file changed, 112 insertions(+), 782 deletions(-) diff --git a/src/SysYIRGenerator.cpp b/src/SysYIRGenerator.cpp index 5a7ccc2..c70869f 100644 --- a/src/SysYIRGenerator.cpp +++ b/src/SysYIRGenerator.cpp @@ -6,15 +6,20 @@ #include "IR.h" #include #include +#include +#include +#include +#include using namespace std; #include "SysYIRGenerator.h" + namespace sysy { /* * @brief: visit compUnit * @details: - * compUnit: (decl | funcDef)+; + * compUnit: (globalDecl | funcDef)+; */ std::any SysYIRGenerator::visitCompUnit(SysYParser::CompUnitContext *ctx) { // create the IR module @@ -22,816 +27,141 @@ std::any SysYIRGenerator::visitCompUnit(SysYParser::CompUnitContext *ctx) { assert(pModule); module.reset(pModule); - SymbolTable::ModuleScope scope(symbols_table); + // SymbolTable::ModuleScope scope(symbols_table); - // 待添加运行时库函数getint等 - // generates globals and functions - auto type_i32 = Type::getIntType(); - auto type_f32 = Type::getFloatType(); - auto type_void = Type::getVoidType(); - auto type_i32p = Type::getPointerType(type_i32); - auto type_f32p = Type::getPointerType(type_f32); - - //runtime library - module->createFunction("getint", Type::getFunctionType(type_i32, {})); - module->createFunction("getch", Type::getFunctionType(type_i32, {})); - module->createFunction("getfloat", Type::getFunctionType(type_f32, {})); - symbols_table.insert("getint", module->getFunction("getint")); - symbols_table.insert("getch", module->getFunction("getch")); - symbols_table.insert("getfloat", module->getFunction("getfloat")); - - module->createFunction("getarray", Type::getFunctionType(type_i32, {type_i32p})); - module->createFunction("getfarray", Type::getFunctionType(type_i32, {type_f32p})); - symbols_table.insert("getarray", module->getFunction("getarray")); - symbols_table.insert("getfarray", module->getFunction("getfarray")); + Utils::initExternalFunction(pModule, &builder); - module->createFunction("putint", Type::getFunctionType(type_void, {type_i32})); - module->createFunction("putch", Type::getFunctionType(type_void, {type_i32})); - module->createFunction("putfloat", Type::getFunctionType(type_void, {type_f32})); - symbols_table.insert("putint", module->getFunction("putint")); - symbols_table.insert("putch", module->getFunction("putch")); - symbols_table.insert("putfloat", module->getFunction("putfloat")); - - module->createFunction("putarray", Type::getFunctionType(type_void, {type_i32, type_i32p})); - module->createFunction("putfarray", Type::getFunctionType(type_void, {type_i32, type_f32p})); - symbols_table.insert("putarray", module->getFunction("putarray")); - symbols_table.insert("putfarray", module->getFunction("putfarray")); - - module->createFunction("putf", Type::getFunctionType(type_void, {})); - symbols_table.insert("putf", module->getFunction("putf")); - - module->createFunction("starttime", Type::getFunctionType(type_void, {type_i32})); - module->createFunction("stoptime", Type::getFunctionType(type_void, {type_i32})); - symbols_table.insert("starttime", module->getFunction("starttime")); - symbols_table.insert("stoptime", module->getFunction("stoptime")); - - // visit all decls and funcDefs - for(auto decl:ctx->decl()){ - decl->accept(this); - } - for(auto funcDef:ctx->funcDef()){ - builder = IRBuilder(); - printf("entry funcDef\n"); - funcDef->accept(this); - } - // return the IR module ? + pModule->enterNewScope(); + visitChildren(ctx); + pModule->leaveScope(); return pModule; } -/* - * @brief: visit decl - * @details: - * decl: constDecl | varDecl; - * constDecl: CONST bType constDef (COMMA constDef)* SEMI; - * varDecl: bType varDef (COMMA varDef)* SEMI; - * constDecl and varDecl shares similar syntax structure - * we consider them together? not sure - */ -std::any SysYIRGenerator::visitDecl(SysYParser::DeclContext *ctx) { - if(ctx->constDecl()) - return visitConstDecl(ctx->constDecl()); - else if(ctx->varDecl()) - return visitVarDecl(ctx->varDecl()); - std::cerr << "error unkown decl" << ctx->getText() << std::endl; - return std::any(); -} +std::any SysYIRGenerator::visitGlobalConstDecl(SysYParser::GlobalConstDeclContext *ctx){ + auto constDecl = ctx->constDecl(); + Type* type = std::any_cast(visitBType(constDecl->bType())); + for (const auto &constDef : constDecl->constDef()) { + std::vector dims = {}; + std::string name = constDef->Ident()->getText(); + auto constExps = constDef->constExp(); + if (!constExps.empty()) { + for (const auto &constExp : constExps) { + dims.push_back(std::any_cast(visitConstExp(constExp))); + } + } -/* - * @brief: visit constdecl - * @details: - * constDecl: CONST bType constDef (COMMA constDef)* SEMI; - * constDef: Ident (LBRACK constExp RBRACK)* (ASSIGN constInitVal)?; - */ -std::any SysYIRGenerator::visitConstDecl(SysYParser::ConstDeclContext *ctx) { - cout << "visitconstDecl" << endl; - current_type = any_cast(ctx->bType()->accept(this)); - for(auto constDef:ctx->constDef()){ - constDef->accept(this); + ArrayValueTree* root = std::any_cast(constDef->constInitVal()->accept(this)); + ValueCounter values; + Utils::tree2Array(type, root, dims, dims.size(), values, &builder); + delete root; + // 创建全局常量变量,并更新符号表 + module->createConstVar(name, Type::getPointerType(type), values, dims); } return std::any(); } -/* - * @brief: visit btype - * @details: - * bType: INT | FLOAT; - */ -std::any SysYIRGenerator::visitBType(SysYParser::BTypeContext *ctx) { - if(ctx->INT()) - return Type::getPointerType(Type::getIntType()); - else if(ctx->FLOAT()) - return Type::getPointerType(Type::getFloatType()); - std::cerr << "error: unknown type" << ctx->getText() << std::endl; + +std::any SysYIRGenerator::visitGlobalVarDecl(SysYParser::GlobalVarDeclContext *ctx) { + auto varDecl = ctx->varDecl(); + Type* type = std::any_cast(visitBType(varDecl->bType())); + for (const auto &varDef : varDecl->varDef()) { + std::vector dims = {}; + std::string name = varDef->Ident()->getText(); + auto constExps = varDef->constExp(); + if (!constExps.empty()) { + for (const auto &constExp : constExps) { + dims.push_back(std::any_cast(visitConstExp(constExp))); + } + } + + ArrayValueTree* root = std::any_cast(varDef->initVal()->accept(this)); + ValueCounter values; + Utils::tree2Array(type, root, dims, dims.size(), values, &builder); + delete root; + // 创建全局变量,并更新符号表 + module->createGlobalValue(name, Type::getPointerType(type), dims, values); + } return std::any(); } -/* - * @brief: visit constDef - * @details: - * constDef: Ident (LBRACK constExp RBRACK)* (ASSIGN constInitVal)?; - * constInitVal: constExp | LBRACE (constInitVal (COMMA constInitVal)*)? RBRACE; - */ -std::any SysYIRGenerator::visitConstDef(SysYParser::ConstDefContext *ctx){ - auto name = ctx->Ident()->getText(); - Type* type = current_type; - Value* init = ctx->constInitVal() ? any_cast(ctx->constInitVal()->accept(this)) : (Value *)nullptr; - if (ctx->constExp().empty()){ - //scalar - if(init){ - if(symbols_table.isModuleScope()){ - assert(init->isConstant() && "global must be initialized by constant"); - Value* global = module->createGlobalValue(name, type, {}, init); - symbols_table.insert(name, global); - cout << "add module const " << name ; - if(init){ - cout << " inited by " ; - init->print(cout); +void Utils::tree2Array(Type *type, ArrayValueTree *root, + const std::vector &dims, unsigned numDims, + ValueCounter &result, IRBuilder *builder) { + auto value = root->getValue(); + auto &children = root->getChildren(); + if (value != nullptr) { + if (type == value->getType()) { + result.push_back(value); + } else { + if (type == Type::getFloatType()) { + auto constValue = dynamic_cast(value); + if (constValue != nullptr) { + result.push_back( + ConstantValue::get(static_cast(constValue->getInt()))); + } else { + result.push_back(builder->createIToFInst(value)); } - cout << '\n'; - } - else{ - Value* alloca = builder.createAllocaInst(type, {}, name); - Value* store = builder.createStoreInst(init, alloca); - symbols_table.insert(name, alloca); - cout << "add local const " << name ; - if(init){ - cout << " inited by " ; - init->print(cout); + } else { + auto constValue = dynamic_cast(value); + if (constValue != nullptr) { + result.push_back( + ConstantValue::get(static_cast(constValue->getFloat()))); + } else { + result.push_back(builder->createFtoIInst(value)); } - cout << '\n'; } } - else{ - assert(false && "const without initialization"); - } - } - else{ - //array - std::cerr << "array constDef not implemented yet" << std::endl; - } - printf("visitConstDef %s\n",name.c_str()); - return std::any(); -} - - -/* - * @brief: visit constInitVal - * @details: - * constInitVal: constExp - * | LBRACE (constInitVal (COMMA constInitVal)*)? RBRACE; - */ -std::any SysYIRGenerator::visitConstInitVal(SysYParser::ConstInitValContext *ctx){ - Value* initvalue; - if(ctx->constExp()) - initvalue = any_cast(ctx->constExp()->accept(this)); - else{ - //还未实现数组初始化等功能待验证 - std::cerr << "array initvalue not implemented yet" << std::endl; - // auto numConstInitVals = ctx->constInitVal().size(); - // vector initvalues; - // for(int i = 0; i < numConstInitVals; i++) - // initvalues.push_back(any_cast(ctx->constInitVal(i)->accept(this))); - // initvalue = ConstantValue::getArray(initvalues); - } - return initvalue; -} - -/* - * @brief: visit function type - * @details: - * funcType: VOID | INT | FLOAT; - */ -std::any SysYIRGenerator::visitFuncType(SysYParser::FuncTypeContext* ctx){ - if(ctx->INT()) - return Type::getIntType(); - else if(ctx->FLOAT()) - return Type::getFloatType(); - else if(ctx->VOID()) - return Type::getVoidType(); - std::cerr << "invalid function type: " << ctx->getText() << std::endl; - return std::any(); -} - -/* - * @brief: visit function define - * @details: - * funcDef: funcType Ident LPAREN funcFParams? RPAREN blockStmt; - * funcFParams: funcFParam (COMMA funcFParam)*; - * funcFParam: bType Ident (LBRACK RBRACK (LBRACK exp RBRACK)*)?; - * entry -> next -> others -> exit - * entry: allocas, br - * next: retval, params, br - * other: blockStmt init block - * exit: load retval, ret - */ -std::any SysYIRGenerator::visitFuncDef(SysYParser::FuncDefContext* ctx){ - - auto funcName = ctx->Ident()->getText(); - auto returnType = any_cast(ctx->funcType()->accept(this)); - auto func = module->getFunction(funcName); - - cout << "func: "; - returnType->print(cout); - cout << ' '<< funcName.c_str() << endl; - - vector paramTypes; - vector paramNames; - - if(ctx->funcFParams()){ - for(auto funcParam:ctx->funcFParams()->funcFParam()){ - Type* paramType = any_cast(funcParam->bType()->accept(this)); - paramTypes.push_back(paramType); - paramNames.push_back(funcParam->Ident()->getText()); - } - } - - - auto funcType = FunctionType::get(returnType, paramTypes); - auto function = module->createFunction(funcName, funcType); - - SymbolTable::FunctionScope scope(symbols_table); - BasicBlock* entryblock = function->getEntryBlock(); - for(size_t i = 0; i < paramTypes.size(); i++) - entryblock->createArgument(paramTypes[i], paramNames[i]); - - for(auto& arg: entryblock->getArguments()) - symbols_table.insert(arg->getName(), (Value *)arg.get()); - - cout << "setposition entryblock" << endl; - builder.setPosition(entryblock, entryblock->end()); - - ctx->blockStmt()->accept(this); - - return std::any(); -} - -/* - * @brief: visit varDecl - * @details: - * varDecl: bType varDef (COMMA varDef)* SEMI; - */ -std::any SysYIRGenerator::visitVarDecl(SysYParser::VarDeclContext *ctx){ - cout << "visitVarDecl" << endl; - current_type = any_cast(ctx->bType()->accept(this)); - for(auto varDef:ctx->varDef()){ - varDef->accept(this); - } - return std::any(); -} - -/* - * @brief: visit varDef - * @details: - * varDef: Ident (LBRACK constExp RBRACK)* (ASSIGN initVal)?; - */ -std::any SysYIRGenerator::visitVarDef(SysYParser::VarDefContext *ctx){ - const std::string name = ctx->Ident()->getText(); - Type* type = current_type; - Value* init = ctx->initVal() ? any_cast(ctx->initVal()->accept(this)) : nullptr; - // const std::vector dims = {}; - - cout << "vardef: "; - current_type->print(cout); - cout << ' ' << name << endl; - - if(ctx->constExp().empty()){ - //scalar - if(symbols_table.isModuleScope()){ - - if(init) - assert(init->isConstant() && "global must be initialized by constant"); - Value* global = module->createGlobalValue(name, type, {}, init); - symbols_table.insert(name, global); - cout << "add module var " << name ; - // if(init){ - // cout << " inited by " ; - // init->print(cout); - // } - cout << '\n'; - } - else{ - - Value* alloca = builder.createAllocaInst(type, {}, name); - cout << "creatalloca" << endl; - alloca->print(cout); - Value* store = (StoreInst *)nullptr; - if(init != nullptr) - store = builder.createStoreInst(alloca, init, {}, name); - symbols_table.insert(name, alloca); - // alloca->setName(name); - cout << "add local var " ; - alloca->print(cout); - // if(init){ - // cout << " inited by " ; - // init->print(cout); - // } - cout << '\n'; - } - } - else{ - //array - std::cerr << "array varDef not implemented yet" << std::endl; - } - return std::any(); -} - -/* - * @brief: visit initVal - * @details: - * initVal: exp | LBRACE (initVal (COMMA initVal)*)? RBRACE; - */ -std::any SysYIRGenerator::visitInitVal(SysYParser::InitValContext *ctx) { - Value* initvalue = nullptr; - if(ctx->exp()) - initvalue = any_cast(ctx->exp()->accept(this)); - else{ - //还未实现数组初始化等功能待验证 - std::cerr << "array initvalue not implemented yet" << std::endl; - // auto numConstInitVals = ctx->constInitVal().size(); - // vector initvalues; - // for(int i = 0; i < numConstInitVals; i++) - // initvalues.push_back(any_cast(ctx->constInitVal(i)->accept(this))); - // initvalue = ConstantValue::getArray(initvalues); - } - return initvalue; -} - -// std::any SysYIRGenerator::visitFuncFParams(SysYParser::FuncFParamsContext* ctx){ -// return visitChildren(ctx); -// } -// std::any SysYIRGenerator::visitFuncFParam(SysYParser::FuncFParamContext *ctx) { -// return visitChildren(ctx); -// } - -/* - * @brief: visit blockStmt - * @details: - * blockStmt: LBRACE blockItem* RBRACE; - * blockItem: decl | stmt; - */ -std::any SysYIRGenerator::visitBlockStmt(SysYParser::BlockStmtContext* ctx){ - SymbolTable::BlockScope scope(symbols_table); - for (auto item : ctx->blockItem()){ - item->accept(this); - // if(builder.getBasicBlock()->isTerminal()){ - // break; - // } - } - return std::any(); -} - -/* - * @brief: visit ifstmt - * @details: - * ifStmt: IF LPAREN cond RPAREN stmt (ELSE stmt)?; - */ -std::any SysYIRGenerator::visitIfStmt(SysYParser::IfStmtContext *ctx) { - auto condition = any_cast(ctx->cond()->accept(this)); - - auto thenBlock = builder.getBasicBlock()->getParent()->addBasicBlock("then"); - auto elseBlock = builder.getBasicBlock()->getParent()->addBasicBlock("else"); - - auto condbr = builder.createCondBrInst(condition,thenBlock,elseBlock,{},{}); - - builder.setPosition(thenBlock, thenBlock->end()); - ctx->stmt(0)->accept(this); - - if(ctx->ELSE()){ - builder.setPosition(elseBlock, elseBlock->end()); - ctx->stmt(1)->accept(this); - } - //无条件跳转到下一个基本块 - // builder.createUncondBrInst(builder.getBasicBlock()->getParent()->addBasicBlock("next"),{}); - return std::any(); -} - -/* - * @brief: visit whilestmt - * @details: - * whileStmt: WHILE LPAREN cond RPAREN stmt; - */ - - std::any SysYIRGenerator::visitWhileStmt(SysYParser::WhileStmtContext* ctx) { - //需要解决一个函数多个循环的命名问题 - auto header = builder.getBasicBlock()->getParent()->addBasicBlock("header"); - auto body = builder.getBasicBlock()->getParent()->addBasicBlock("body"); - auto exit = builder.getBasicBlock()->getParent()->addBasicBlock("exit"); - - SymbolTable::BlockScope scope(symbols_table); - - { // visit header block - builder.setPosition(header, header->end()); - auto cond = any_cast(ctx->cond()->accept(this)); - auto condbr = builder.createCondBrInst(cond, body, exit, {}, {}); + return; } - { // visit body block - builder.setPosition(body, body->end()); - ctx->stmt()->accept(this); - auto uncondbr = builder.createUncondBrInst(header, {}); - } - - // visit exit block - builder.setPosition(exit, exit->end()); - //无条件跳转到下一个基本块以及一些参数传递 - - return std::any(); -} - -/* - * @brief: visit breakstmt - * @details: - * breakStmt: BREAK SEMICOLON; - */ -std::any SysYIRGenerator::visitBreakStmt(SysYParser::BreakStmtContext* ctx) { - //如何获取break所在body对应的header块 - return std::any(); -} -/* - * @brief Visit ReturnStmt - * returnStmt: RETURN exp? SEMICOLON; - */ -std::any SysYIRGenerator::visitReturnStmt(SysYParser::ReturnStmtContext* ctx) { - cout << "visitReturnStmt" << endl; - // auto value = ctx->exp() ? any_cast_Value(visit(ctx->exp())) : nullptr; - Value* value = ctx->exp() ? any_cast(ctx->exp()->accept(this)) : nullptr; - - const auto func = builder.getBasicBlock()->getParent(); - - assert(func && "ret stmt block parent err!"); - - // 匹配 返回值类型 与 函数定义类型 - if (func->getReturnType()->isVoid()) { - if (ctx->exp()) - assert(false && "the returned value is not matching the function"); - - auto ret = builder.createReturnInst(); - return std::any(); - } - - assert(ctx->exp() && "the returned value is not matching the function"); - - auto ret = builder.createReturnInst(value); - - //需要增加无条件跳转吗 - return std::any(); -} - -/* - * @brief: visit continuestmt - * @details: - * continueStmt: CONTINUE SEMICOLON; - */ -std::any SysYIRGenerator::visitContinueStmt(SysYParser::ContinueStmtContext* ctx) { - //如何获取continue所在body对应的header块 - return std::any(); -} - - -/* - * @brief visit assign stmt - * @details: - * assignStmt: lValue ASSIGN exp SEMICOLON - */ -std::any SysYIRGenerator::visitAssignStmt(SysYParser::AssignStmtContext* ctx) { - cout << "visitassignstme :\n"; - auto lvalue = any_cast(ctx->lValue()->accept(this)); - cout << "getlval" << endl;//lvalue->print(cout);cout << ')'; - auto rvalue = any_cast(ctx->exp()->accept(this)); - //可能要考虑类型转换例如int a = 1.0 - cout << "getrval" << endl;//rvalue->print(cout);cout << ")\n"; - builder.createStoreInst(rvalue, lvalue, {}, {}); - return std::any(); -} - -/* - * @brief: visit lValue - * @details: - * lValue: Ident (LBRACK exp RBRACK)*; - */ -std::any SysYIRGenerator::visitLValue(SysYParser::LValueContext* ctx) { - cout << "visitLValue" << endl; - auto name = ctx->Ident()->getText(); - Value* value = symbols_table.lookup(name); - - assert(value && "lvalue not found"); - - if(ctx->exp().size() == 0){ - //scalar - cout << "lvalue: " << name << endl; - return value; - } - else{ - //array - std::cerr << "array lvalue not implemented yet" << std::endl; - } - std::cerr << "error lvalue" << ctx->getText() << std::endl; - return std::any(); -} - -std::any SysYIRGenerator::visitPrimExp(SysYParser::PrimExpContext *ctx){ - cout << "visitPrimExp" << endl; - return visitChildren(ctx); -} -// std::any SysYIRGenerator::visitExp(SysYParser::ExpContext* ctx) { -// cout << "visitExp" << endl; -// return ctx->addExp()->accept(this); -// } - -std::any SysYIRGenerator::visitNumber(SysYParser::NumberContext *ctx) { - cout << "visitNumber" << endl; - Value* res = nullptr; - if (auto iLiteral = ctx->ILITERAL()) { - /* 基数 (8, 10, 16) */ - const auto text = iLiteral->getText(); - int base = 10; - if (text.find("0x") == 0 || text.find("0X") == 0) { - base = 16; - } else if (text.find("0b") == 0 || text.find("0B") == 0) { - base = 2; - } else if (text.find("0") == 0) { - base = 8; - } - res = ConstantValue::get(Type::getIntType() ,(int)std::stol(text, 0, base)); - } else if (auto fLiteral = ctx->FLITERAL()) { - const auto text = fLiteral->getText(); - res = ConstantValue::get(Type::getFloatType(), (float)std::stof(text)); - } - cout << "number: "; - res->print(cout); - cout << endl; - - return res; -} - -/* - * @brief: visit call - * @details: - * call: Ident LPAREN funcRParams? RPAREN; - */ -std::any SysYIRGenerator::visitCall(SysYParser::CallContext* ctx) { - cout << "visitCall" << endl; - auto funcName = ctx->Ident()->getText(); - auto func = module->getFunction(funcName); - assert(func && "function not found"); - - //需要做类型检查和转换 - std::vector args; - if(ctx->funcRParams()){ - for(auto exp:ctx->funcRParams()->exp()){ - args.push_back(any_cast(exp->accept(this))); - } - } - Value* call = builder.createCallInst(func, args); - return call; -} - -/* - * @brief: visit unexp - * @details: - * unExp: unaryOp unaryExp - */ -std::any SysYIRGenerator::visitUnExp(SysYParser::UnExpContext *ctx) { - cout << "visitUnExp" << endl; - Value* res = nullptr; - auto op = ctx->unaryOp()->getText(); - auto exp = any_cast(ctx->unaryExp()->accept(this)); - if(ctx->unaryOp()->ADD()){ - res = exp; - } - else if(ctx->unaryOp()->SUB()){ - res = builder.createNegInst(exp, exp->getName()); - } - else if(ctx->unaryOp()->NOT()){ - //not将非零值转换为0,零值转换为1 - res = builder.createNotInst(exp, exp->getName()); - } - return res; -} - -/* - * @brief: visit mulexp - * @details: - * mulExp: unaryExp ((MUL | DIV | MOD) unaryExp)* - */ -std::any SysYIRGenerator::visitMulExp(SysYParser::MulExpContext *ctx) { - cout << "visitMulExp" << endl; - Value* res = nullptr; - cout << "mulExplhsin\n"; - Value* lhs = any_cast(ctx->unaryExp(0)->accept(this)); - cout << "mulExplhsout\n"; - if(ctx->unaryExp().size() == 1){ - cout << "unaryExp().size() = 1\n"; - res = lhs; - } - else{ - cout << "unaryExp().size() > 1\n"; - for(size_t i = 1; i < ctx->unaryExp().size(); i++){ - Value* rhs = any_cast(ctx->unaryExp(i)->accept(this)); - auto opNode = dynamic_cast(ctx->children[2 * i - 1]); - - if(opNode->getText() == "*"){ - res = builder.createMulInst(lhs, rhs, lhs->getName() + "*" + rhs->getName()); - } - else if(opNode->getText() == "/"){ - res = builder.createDivInst(lhs, rhs, lhs->getName() + "/" + rhs->getName()); - } - else if(opNode->getText() == "%"){ - std::cerr << "mod not implemented yet" << std::endl; - // res = builder.createModInst(lhs, rhs, lhs->getName() + "%" + rhs->getName()); - } - } - } - return res; -} - -/* - * @brief: visit addexp - * @details: - * addExp: mulExp ((ADD | SUB) mulExp)* - */ -std::any SysYIRGenerator::visitAddExp(SysYParser::AddExpContext *ctx) { - cout << "visitAddExp" << endl; - Value* res = nullptr; - Value* lhs = any_cast(ctx->mulExp(0)->accept(this)); - if(ctx->mulExp().size() == 1){ - cout << "ctx->mulExp().size() = 1\n"; - res = lhs; - } - else{ - for(size_t i = 1; i < ctx->mulExp().size(); i++){ - cout << "i = " << i << "\n"; - Value* rhs = any_cast(ctx->mulExp(i)->accept(this)); - auto opNode = dynamic_cast(ctx->children[2 * i - 1]); - - if(opNode->getText() == "+"){ - res = builder.createAddInst(lhs, rhs, lhs->getName() + "+" + rhs->getName()); - } - else if(opNode->getText() == "-"){ - res = builder.createSubInst(lhs, rhs, lhs->getName() + "-" + rhs->getName()); + auto beforeSize = result.size(); + for (const auto &child : children) { + int begin = result.size(); + int newNumDims = 0; + for (unsigned i = 0; i < numDims - 1; i++) { + auto dim = dynamic_cast(*(dims.rbegin() + i))->getInt(); + if (begin % dim == 0) { + newNumDims += 1; + begin /= dim; + } else { + break; } } - lhs = res; + tree2Array(type, child.get(), dims, newNumDims, result, builder); + } + auto afterSize = result.size(); + + int blockSize = 1; + for (unsigned i = 0; i < numDims; i++) { + blockSize *= dynamic_cast(*(dims.rbegin() + i))->getInt(); } - return res; -} - -/* - * @brief: visit relexp - * @details: - * relExp: addExp ((LT | GT | LE | GE) addExp)* - */ -std::any SysYIRGenerator::visitRelExp(SysYParser::RelExpContext *ctx) { - cout << "visitRelExp" << endl; - Value* res = nullptr; - Value* lhs = any_cast(ctx->addExp(0)->accept(this)); - if(ctx->addExp().size() == 1){ - res = lhs; - } - else{ - for(size_t i = 1; i < ctx->addExp().size(); i++){ - Value* rhs = any_cast(ctx->addExp(i)->accept(this)); - auto opNode = dynamic_cast(ctx->children[2 * i - 1]); - if(lhs->getType() != rhs->getType()){ - std::cerr << "type mismatch:type check not implemented" << std::endl; - } - Type* type = lhs->getType(); - if(opNode->getText() == "<"){ - if(type->isInt()) - res = builder.createICmpLTInst(lhs, rhs, lhs->getName() + "<" + rhs->getName()); - else if(type->isFloat()) - res = builder.createFCmpLTInst(lhs, rhs, lhs->getName() + "<" + rhs->getName()); - } - else if(opNode->getText() == ">"){ - if(type->isInt()) - res = builder.createICmpGTInst(lhs, rhs, lhs->getName() + ">" + rhs->getName()); - else if(type->isFloat()) - res = builder.createFCmpGTInst(lhs, rhs, lhs->getName() + ">" + rhs->getName()); - } - else if(opNode->getText() == "<="){ - if(type->isInt()) - res = builder.createICmpLEInst(lhs, rhs, lhs->getName() + "<=" + rhs->getName()); - else if(type->isFloat()) - res = builder.createFCmpLEInst(lhs, rhs, lhs->getName() + "<=" + rhs->getName()); - } - else if(opNode->getText() == ">="){ - if(type->isInt()) - res = builder.createICmpGEInst(lhs, rhs, lhs->getName() + ">=" + rhs->getName()); - else if(type->isFloat()) - res = builder.createFCmpGEInst(lhs, rhs, lhs->getName() + ">=" + rhs->getName()); - } + int num = blockSize - afterSize + beforeSize; + if (num > 0) { + if (type == Type::getFloatType()) { + result.push_back(ConstantValue::get(0.0F), num); + } else { + result.push_back(ConstantValue::get(0), num); } } - return res; } -/* - * @brief: visit eqexp - * @details: - * eqExp: relExp ((EQ | NEQ) relExp)* - */ -std::any SysYIRGenerator::visitEqExp(SysYParser::EqExpContext* ctx) { - cout << "visitEqExp" << endl; - Value* res = nullptr; - Value* lhs = any_cast(ctx->relExp(0)->accept(this)); - if(ctx->relExp().size() == 1){ - res = lhs; - } - else{ - for(size_t i = 1; i < ctx->relExp().size(); i++){ - Value* rhs = any_cast(ctx->relExp(i)->accept(this)); - auto opNode = dynamic_cast(ctx->children[2 * i - 1]); - if(lhs->getType() != rhs->getType()){ - std::cerr << "type mismatch:type check not implemented" << std::endl; - } - Type* type = lhs->getType(); - if(opNode->getText() == "=="){ - if(type->isInt()) - res = builder.createICmpEQInst(lhs, rhs, lhs->getName() + "==" + rhs->getName()); - else if(type->isFloat()) - res = builder.createFCmpEQInst(lhs, rhs, lhs->getName() + "==" + rhs->getName()); - } - else if(opNode->getText() == "!="){ - if(type->isInt()) - res = builder.createICmpNEInst(lhs, rhs, lhs->getName() + "!=" + rhs->getName()); - else if(type->isFloat()) - res = builder.createFCmpNEInst(lhs, rhs, lhs->getName() + "!=" + rhs->getName()); - } - } - } - return res; -} +void Utils::createExternalFunction( + const std::vector ¶mTypes, + const std::vector ¶mNames, + const std::vector> ¶mDims, Type *returnType, + const std::string &funcName, Module *pModule, IRBuilder *pBuilder) { + auto funcType = Type::getFunctionType(returnType, paramTypes); + auto function = pModule->createExternalFunction(funcName, funcType); + auto entry = function->getEntryBlock(); + pBuilder->setPosition(entry, entry->end()); -/* - * @brief: visit lAndexp - * @details: - * lAndExp: eqExp (AND eqExp)* - */ -std::any SysYIRGenerator::visitLAndExp(SysYParser::LAndExpContext* ctx) { - cout << "visitLAndExp" << endl; - auto currentBlock = builder.getBasicBlock(); - // auto trueBlock = currentBlock->getParent()->addBasicBlock("trueland" + std::to_string(++trueBlockNum)); - auto falseBlock = currentBlock->getParent()->addBasicBlock("falseland" + std::to_string(++falseBlockNum)); - Value* value = any_cast(ctx->eqExp(0)->accept(this)); - auto trueBlock = currentBlock; - for(size_t i = 1; i < ctx->eqExp().size(); i++){ - trueBlock = trueBlock->getParent()->addBasicBlock("trueland" + std::to_string(++trueBlockNum)); - builder.createCondBrInst(value, currentBlock, falseBlock, {}, {}); - builder.setPosition(trueBlock, trueBlock->end()); - value = any_cast(ctx->eqExp(i)->accept(this)); + for (size_t i = 0; i < paramTypes.size(); ++i) { + auto alloca = pBuilder->createAllocaInst( + Type::getPointerType(paramTypes[i]), paramDims[i], paramNames[i]); + entry->insertArgument(alloca); + // pModule->addVariable(paramNames[i], alloca); } - //结构trueblk条件跳转到falseblk - // - trueBlock1->trueBlock2->trueBlock3->...->trueBlockn->nextblk - //entry-| - // -falseBlock->nextblk - //需要在最后一个trueblock的末尾加上无条件跳转到下一个基本块的指令 - // builder.createCondBrInst(value, trueBlock, falseBlock, {}, {}); - return std::any(); -} - -/* - * @brief: visit lOrexp - * @details: - * lOrExp: lAndExp (OR lAndExp)* - */ -std::any SysYIRGenerator::visitLOrExp(SysYParser::LOrExpContext* ctx) { - cout << "visitLOrExp" << endl; - auto currentBlock = builder.getBasicBlock(); - auto trueBlock = currentBlock->getParent()->addBasicBlock("trueland" + std::to_string(++trueBlockNum)); - Value* value = any_cast(ctx->lAndExp(0)->accept(this)); - auto falseBlock = currentBlock; - for(size_t i = 1; i < ctx->lAndExp().size(); i++){ - falseBlock = currentBlock->getParent()->addBasicBlock("falseland" + std::to_string(++falseBlockNum)); - builder.createCondBrInst(value, trueBlock, falseBlock, {}, {}); - builder.setPosition(falseBlock, falseBlock->end()); - value = any_cast(ctx->lAndExp(i)->accept(this)); - } - //结构trueblk条件跳转到falseblk - // - falseBlock1->falseBlock2->falseBlock3->...->falseBlockn->nextblk - //entry-| - // -trueBlock->nextblk - //需要在最后一个falseblock的末尾加上无条件跳转到下一个基本块的指令 - // builder.createCondBrInst(value, trueBlock, falseBlock, {}, {}); - return std::any(); -} - -/* - * @brief: visit constexp - * @details: - * constExp: addExp; - */ -std::any SysYIRGenerator::visitConstExp(SysYParser::ConstExpContext* ctx) { - cout << "visitConstExp" << endl; - ConstantValue* res = nullptr; - Value* value = any_cast(ctx->addExp()->accept(this)); - if(isa(value)){ - res = dyncast(value); - } - else{ - std::cerr << "error constexp" << ctx->getText() << std::endl; - } - return res; } } // namespace sysy \ No newline at end of file From ba5f2a0620a8bc6543032ccd914adfba36230a55 Mon Sep 17 00:00:00 2001 From: rain2133 <1370973498@qq.com> Date: Sat, 21 Jun 2025 15:40:00 +0800 Subject: [PATCH 08/13] =?UTF-8?q?=E5=88=A0=E9=99=A4=E6=A0=BC=E5=BC=8F?= =?UTF-8?q?=E5=8C=96=E5=8A=9F=E8=83=BD?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/ASTPrinter.cpp | 355 --------------------------------------- src/include/ASTPrinter.h | 59 ------- src/sysyc.cpp | 6 - 3 files changed, 420 deletions(-) delete mode 100644 src/ASTPrinter.cpp delete mode 100644 src/include/ASTPrinter.h diff --git a/src/ASTPrinter.cpp b/src/ASTPrinter.cpp deleted file mode 100644 index a29383e..0000000 --- a/src/ASTPrinter.cpp +++ /dev/null @@ -1,355 +0,0 @@ -#include -#include -using namespace std; -#include "ASTPrinter.h" -#include "SysYParser.h" - - -any ASTPrinter::visitCompUnit(SysYParser::CompUnitContext *ctx) { - if(ctx->decl().empty() && ctx->funcDef().empty()) - return nullptr; - for (auto dcl : ctx->decl()) {dcl->accept(this);cout << '\n';}cout << '\n'; - for (auto func : ctx->funcDef()) {func->accept(this);cout << "\n";} - return nullptr; -} - -// std::any ASTPrinter::visitBType(SysYParser::BTypeContext *ctx); -// std::any ASTPrinter::visitDecl(SysYParser::DeclContext *ctx); - -std::any ASTPrinter::visitConstDecl(SysYParser::ConstDeclContext *ctx) { - cout << getIndent() << ctx->CONST()->getText() << ' ' << ctx->bType()->getText() << ' '; - auto numConstDefs = ctx->constDef().size(); - ctx->constDef(0)->accept(this); - for (int i = 1; i < numConstDefs; ++i) { - cout << ctx->COMMA(i - 1)->getText() << ' '; - ctx->constDef(i)->accept(this); - } - cout << ctx->SEMICOLON()->getText() << '\n'; - return nullptr; -} - -std::any ASTPrinter::visitConstDef(SysYParser::ConstDefContext *ctx) { - cout << ctx->Ident()->getText(); - auto numConstExps = ctx->constExp().size(); - for (int i = 0; i < numConstExps; ++i) { - cout << ctx->LBRACK(i)->getText(); - ctx->constExp(i)->accept(this); - cout << ctx->RBRACK(i)->getText(); - } - cout << ' ' << ctx->ASSIGN()->getText() << ' '; - ctx->constInitVal()->accept(this); - return nullptr; -} - -// std::any ASTPrinter::visitConstInitVal(SysYParser::ConstInitValContext *ctx); - -std::any ASTPrinter::visitVarDecl(SysYParser::VarDeclContext *ctx){ - cout << getIndent() << ctx->bType()->getText() << ' '; - auto numVarDefs = ctx->varDef().size(); - ctx->varDef(0)->accept(this); - for (int i = 1; i < numVarDefs; ++i) { - cout << ", "; - ctx->varDef(i)->accept(this); - } - cout << ctx->SEMICOLON()->getText() << '\n'; - return nullptr; -} - -std::any ASTPrinter::visitVarDef(SysYParser::VarDefContext *ctx){ - cout << ctx->Ident()->getText(); - auto numConstExps = ctx->constExp().size(); - for (int i = 0; i < numConstExps; ++i) { - cout << ctx->LBRACK(i)->getText(); - ctx->constExp(i)->accept(this); - cout << ctx->RBRACK(i)->getText(); - } - if (ctx->initVal()) { - cout << ' ' << ctx->ASSIGN()->getText() << ' '; - ctx->initVal()->accept(this); - } - return nullptr; -} - -std::any ASTPrinter::visitInitVal(SysYParser::InitValContext *ctx){ - if (ctx->exp()) { - ctx->exp()->accept(this); - } else { - cout << ctx->LBRACE()->getText(); - auto numInitVals = ctx->initVal().size(); - ctx->initVal(0)->accept(this); - for (int i = 1; i < numInitVals; ++i) { - cout << ctx->COMMA(i - 1)->getText() << ' '; - ctx->initVal(i)->accept(this); - } - cout << ctx->RBRACE()->getText(); - } - return nullptr; -} - -std::any ASTPrinter::visitFuncDef(SysYParser::FuncDefContext *ctx){ - cout << getIndent() << ctx->funcType()->getText() << ' ' << ctx->Ident()->getText(); - cout << ctx->LPAREN()->getText(); - if (ctx->funcFParams()) ctx->funcFParams()->accept(this); - if(ctx->RPAREN()) - cout << ctx->RPAREN()->getText(); - else - cout << ""; - ctx->blockStmt()->accept(this); - return nullptr; -} - -// std::any ASTPrinter::visitFuncType(SysYParser::FuncTypeContext *ctx); - -std::any ASTPrinter::visitFuncFParams(SysYParser::FuncFParamsContext *ctx){ - auto numFuncFParams = ctx->funcFParam().size(); - ctx->funcFParam(0)->accept(this); - for (int i = 1; i < numFuncFParams; ++i) { - cout << ctx->COMMA(i - 1)->getText() << ' '; - ctx->funcFParam(i)->accept(this); - } - return nullptr; -} - -std::any ASTPrinter::visitFuncFParam(SysYParser::FuncFParamContext *ctx){ - cout << ctx->bType()->getText() << ' ' << ctx->Ident()->getText(); - if (!ctx->exp().empty()) { - cout << "[]"; - for (auto exp : ctx->exp()) { - cout << '['; - exp->accept(this); - cout << ']'; - } - } - return nullptr; -} - -std::any ASTPrinter::visitBlockStmt(SysYParser::BlockStmtContext *ctx){ - cout << ctx->LBRACE()->getText() << endl; - indentLevel++; - for (auto item : ctx->blockItem()) item->accept(this); - indentLevel--; - cout << getIndent() << ctx->RBRACE()->getText() << endl; - return nullptr; -} -// std::any ASTPrinter::visitBlockItem(SysYParser::BlockItemContext *ctx); - -std::any ASTPrinter::visitAssignStmt(SysYParser::AssignStmtContext *ctx){ - cout << getIndent(); - ctx->lValue()->accept(this); - cout << ' ' << ctx->ASSIGN()->getText() << ' '; - ctx->exp()->accept(this); - cout << ctx->SEMICOLON()->getText() << '\n'; - return nullptr; -} - -std::any ASTPrinter::visitExpStmt(SysYParser::ExpStmtContext *ctx){ - cout << getIndent(); - if (ctx->exp()) { - ctx->exp()->accept(this); - } - cout << ctx->SEMICOLON()->getText() << '\n'; - return nullptr; -} - -std::any ASTPrinter::visitIfStmt(SysYParser::IfStmtContext *ctx){ - cout << getIndent() << ctx->IF()->getText() << ' ' << ctx->LPAREN()->getText(); - ctx->cond()->accept(this); - cout << ctx->RPAREN()->getText() << ' '; - //格式化有问题 - if(ctx->stmt(0)) { - ctx->stmt(0)->accept(this); - } - else { - cout << '{' << endl; - indentLevel++; - ctx->stmt(0)->accept(this); - indentLevel--; - cout << getIndent() << '}' << endl; - } - if (ctx->ELSE()) { - cout << getIndent() << ctx->ELSE()->getText() << ' '; - ctx->stmt(1)->accept(this); - } - return nullptr; -} - -std::any ASTPrinter::visitWhileStmt(SysYParser::WhileStmtContext *ctx){ - cout << getIndent() << ctx->WHILE()->getText() << ' ' << ctx->LPAREN()->getText(); - ctx->cond()->accept(this); - cout << ctx->RPAREN()->getText() << ' '; - ctx->stmt()->accept(this); - return nullptr; -} - -std::any ASTPrinter::visitBreakStmt(SysYParser::BreakStmtContext *ctx){ - cout << getIndent() << ctx->BREAK()->getText() << ctx->SEMICOLON()->getText() << '\n'; - return nullptr; -} - -std::any ASTPrinter::visitContinueStmt(SysYParser::ContinueStmtContext *ctx){ - cout << getIndent() << ctx->CONTINUE()->getText() << ctx->SEMICOLON()->getText() << '\n'; - return nullptr; -} - -std::any ASTPrinter::visitReturnStmt(SysYParser::ReturnStmtContext *ctx){ - cout << getIndent() << ctx->RETURN()->getText() << ' '; - if (ctx->exp()) { - ctx->exp()->accept(this); - } - cout << ctx->SEMICOLON()->getText() << '\n'; - return nullptr; -} - -// std::any ASTPrinter::visitExp(SysYParser::ExpContext *ctx); -// std::any ASTPrinter::visitCond(SysYParser::CondContext *ctx); -std::any ASTPrinter::visitLValue(SysYParser::LValueContext *ctx){ - cout << ctx->Ident()->getText(); - for (auto exp : ctx->exp()) { - cout << "["; - exp->accept(this); - cout << "]"; - } - return nullptr; -} -// std::any ASTPrinter::visitPrimaryExp(SysYParser::PrimaryExpContext *ctx); -std::any ASTPrinter::visitParenExp(SysYParser::ParenExpContext *ctx){ - cout << ctx->LPAREN()->getText(); - ctx->exp()->accept(this); - cout << ctx->RPAREN()->getText(); - return nullptr; -} - -std::any ASTPrinter::visitNumber(SysYParser::NumberContext *ctx) { - if(ctx->ILITERAL())cout << ctx->ILITERAL()->getText(); - if(ctx->FLITERAL())cout << ctx->FLITERAL()->getText(); - return nullptr; -} - -std::any ASTPrinter::visitString(SysYParser::StringContext *ctx) { - cout << ctx->STRING()->getText(); - return nullptr; -} -// std::any visitUnaryExp(SysYParser::UnaryExpContext *ctx); -// std::any ASTPrinter::visitUnaryOp(SysYParser::UnaryOpContext *ctx); -std::any ASTPrinter::visitCall(SysYParser::CallContext *ctx){ - cout << ctx->Ident()->getText() << ctx->LPAREN()->getText(); - if(ctx->funcRParams()) - ctx->funcRParams()->accept(this); - cout << ctx->RPAREN()->getText(); - return nullptr; -} - -any ASTPrinter::visitFuncRParams(SysYParser::FuncRParamsContext *ctx) { - if (ctx->exp().empty()) - return nullptr; - auto numParams = ctx->exp().size(); - ctx->exp(0)->accept(this); - for (int i = 1; i < numParams; ++i) { - cout << ctx->COMMA(i - 1)->getText() << ' '; - ctx->exp(i)->accept(this); - } - return nullptr; -} - -std::any ASTPrinter::visitMulExp(SysYParser::MulExpContext *ctx){ - auto unaryExps = ctx->unaryExp(); - if (unaryExps.size() == 1) { - unaryExps[0]->accept(this); - } else { - for (size_t i = 0; i < unaryExps.size() - 1; ++i) { - auto opNode = dynamic_cast(ctx->children[2 * i + 1]); - if (opNode) { - unaryExps[i]->accept(this); - cout << " " << opNode->getText() << " "; - } - } - unaryExps.back()->accept(this); - } - return nullptr; -} - -std::any ASTPrinter::visitAddExp(SysYParser::AddExpContext *ctx){ - auto mulExps = ctx->mulExp(); - if (mulExps.size() == 1) { - mulExps[0]->accept(this); - } else { - for (size_t i = 0; i < mulExps.size() - 1; ++i) { - auto opNode = dynamic_cast(ctx->children[2 * i + 1]); - if (opNode) { - mulExps[i]->accept(this); - cout << " " << opNode->getText() << " "; - } - } - mulExps.back()->accept(this); - } - return nullptr; -} -// 以下表达式待补全形式同addexp mulexp -std::any ASTPrinter::visitRelExp(SysYParser::RelExpContext *ctx){ - auto relExps = ctx->addExp(); - if (relExps.size() == 1) { - relExps[0]->accept(this); - } else { - for (size_t i = 0; i < relExps.size() - 1; ++i) { - auto opNode = dynamic_cast(ctx->children[2 * i + 1]); - if (opNode) { - relExps[i]->accept(this); - cout << " " << opNode->getText() << " "; - } - } - relExps.back()->accept(this); - } - return nullptr; -} -std::any ASTPrinter::visitEqExp(SysYParser::EqExpContext *ctx){ - auto eqExps = ctx->relExp(); - if (eqExps.size() == 1) { - eqExps[0]->accept(this); - } else { - for (size_t i = 0; i < eqExps.size() - 1; ++i) { - auto opNode = dynamic_cast(ctx->children[2 * i + 1]); - if (opNode) { - eqExps[i]->accept(this); - cout << " " << opNode->getText() << " "; - } - } - eqExps.back()->accept(this); - } - return nullptr; -} -std::any ASTPrinter::visitLAndExp(SysYParser::LAndExpContext *ctx){ - auto lAndExps = ctx->eqExp(); - if (lAndExps.size() == 1) { - lAndExps[0]->accept(this); - } else { - for (size_t i = 0; i < lAndExps.size() - 1; ++i) { - auto opNode = dynamic_cast(ctx->children[2 * i + 1]); - if (opNode) { - lAndExps[i]->accept(this); - cout << " " << opNode->getText() << " "; - } - } - lAndExps.back()->accept(this); - } - return nullptr; -} -std::any ASTPrinter::visitLOrExp(SysYParser::LOrExpContext *ctx){ - auto lOrExps = ctx->lAndExp(); - if (lOrExps.size() == 1) { - lOrExps[0]->accept(this); - } else { - for (size_t i = 0; i < lOrExps.size() - 1; ++i) { - auto opNode = dynamic_cast(ctx->children[2 * i + 1]); - if (opNode) { - lOrExps[i]->accept(this); - cout << " " << opNode->getText() << " "; - } - } - lOrExps.back()->accept(this); - } - return nullptr; -} -std::any ASTPrinter::visitConstExp(SysYParser::ConstExpContext *ctx){ - ctx->addExp()->accept(this); - return nullptr; -} \ No newline at end of file diff --git a/src/include/ASTPrinter.h b/src/include/ASTPrinter.h deleted file mode 100644 index 31b4863..0000000 --- a/src/include/ASTPrinter.h +++ /dev/null @@ -1,59 +0,0 @@ -#pragma once - -#include "SysYBaseVisitor.h" -#include "SysYParser.h" - -class ASTPrinter : public SysYBaseVisitor { -private: - int indentLevel = 0; - - std::string getIndent() { - return std::string(indentLevel * 4, ' '); - } -public: - std::any visitCompUnit(SysYParser::CompUnitContext *ctx) override; - // std::any visitBType(SysYParser::BTypeContext *ctx) override; - // std::any visitDecl(SysYParser::DeclContext *ctx) override; - std::any visitConstDecl(SysYParser::ConstDeclContext *ctx) override; - std::any visitConstDef(SysYParser::ConstDefContext *ctx) override; - // std::any visitConstInitVal(SysYParser::ConstInitValContext *ctx) override; - std::any visitVarDecl(SysYParser::VarDeclContext *ctx) override; - std::any visitVarDef(SysYParser::VarDefContext *ctx) override; - std::any visitInitVal(SysYParser::InitValContext *ctx) override; - std::any visitFuncDef(SysYParser::FuncDefContext *ctx) override; - // std::any visitFuncType(SysYParser::FuncTypeContext *ctx) override; - std::any visitFuncFParams(SysYParser::FuncFParamsContext *ctx) override; - std::any visitFuncFParam(SysYParser::FuncFParamContext *ctx) override; - std::any visitBlockStmt(SysYParser::BlockStmtContext *ctx) override; - - // std::any visitBlockItem(SysYParser::BlockItemContext *ctx) override; - // std::any visitStmt(SysYParser::StmtContext *ctx) override; - - std::any visitAssignStmt(SysYParser::AssignStmtContext *ctx) override; - std::any visitExpStmt(SysYParser::ExpStmtContext *ctx) override; - std::any visitIfStmt(SysYParser::IfStmtContext *ctx) override; - std::any visitWhileStmt(SysYParser::WhileStmtContext *ctx) override; - std::any visitBreakStmt(SysYParser::BreakStmtContext *ctx) override; - std::any visitContinueStmt(SysYParser::ContinueStmtContext *ctx) override; - std::any visitReturnStmt(SysYParser::ReturnStmtContext *ctx) override; - - // std::any visitExp(SysYParser::ExpContext *ctx) override; - // std::any visitCond(SysYParser::CondContext *ctx) override; - std::any visitLValue(SysYParser::LValueContext *ctx) override; - // std::any visitPrimaryExp(SysYParser::PrimaryExpContext *ctx) override; - std::any visitParenExp(SysYParser::ParenExpContext *ctx) override; - std::any visitNumber(SysYParser::NumberContext *ctx) override; - std::any visitString(SysYParser::StringContext *ctx) override; - // std::any visitUnaryExp(SysYParser::UnaryExpContext *ctx) override; - std::any visitCall(SysYParser::CallContext *ctx) override; - // std::any visitUnExpOp(SysYParser::UnExpContext *ctx) override; - // std::any visitUnaryOp(SysYParser::UnaryOpContext *ctx) override; - std::any visitFuncRParams(SysYParser::FuncRParamsContext *ctx) override; - std::any visitMulExp(SysYParser::MulExpContext *ctx) override; - std::any visitAddExp(SysYParser::AddExpContext *ctx) override; - std::any visitRelExp(SysYParser::RelExpContext *ctx) override; - std::any visitEqExp(SysYParser::EqExpContext *ctx) override; - std::any visitLAndExp(SysYParser::LAndExpContext *ctx) override; - std::any visitLOrExp(SysYParser::LOrExpContext *ctx) override; - std::any visitConstExp(SysYParser::ConstExpContext *ctx) override; -}; diff --git a/src/sysyc.cpp b/src/sysyc.cpp index b509fab..1328f55 100644 --- a/src/sysyc.cpp +++ b/src/sysyc.cpp @@ -70,12 +70,6 @@ int main(int argc, char **argv) { return EXIT_SUCCESS; } - // pretty format the input file - if (argFormat) { - ASTPrinter printer; - printer.visitCompUnit(moduleAST); - return EXIT_SUCCESS; - } // visit AST to generate IR From 3ed1c7fecdf8bfad15eda1ecf78e977a85b62428 Mon Sep 17 00:00:00 2001 From: rain2133 <1370973498@qq.com> Date: Sat, 21 Jun 2025 16:39:13 +0800 Subject: [PATCH 09/13] =?UTF-8?q?=E6=9B=B4=E6=94=B9=E5=89=8D=E7=BD=AE?= =?UTF-8?q?=E5=A3=B0=E6=98=8E=EF=BC=8CIR=E7=94=9F=E6=88=90=E6=9B=B4?= =?UTF-8?q?=E6=96=B0?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/SysYIRGenerator.cpp | 348 ++++++++++++++++++++++++++++++++++++ src/include/IR.h | 386 ++++++++++++++++++++-------------------- src/include/IRBuilder.h | 129 +++++++------- 3 files changed, 602 insertions(+), 261 deletions(-) diff --git a/src/SysYIRGenerator.cpp b/src/SysYIRGenerator.cpp index c70869f..d23bf28 100644 --- a/src/SysYIRGenerator.cpp +++ b/src/SysYIRGenerator.cpp @@ -83,6 +83,354 @@ std::any SysYIRGenerator::visitGlobalVarDecl(SysYParser::GlobalVarDeclContext *c return std::any(); } +std::any SysYIRGenerator::visitConstDecl(SysYParser::ConstDeclContext *ctx){ + Type* type = std::any_cast(visitBType(ctx->bType())); + for (const auto constDef : ctx->constDef()) { + std::vector dims = {}; + std::string name = constDef->Ident()->getText(); + auto constExps = constDef->constExp(); + if (!constExps.empty()) { + for (const auto constExp : constExps) { + dims.push_back(std::any_cast(visitConstExp(constExp))); + } + } + + ArrayValueTree* root = std::any_cast(constDef->constInitVal()->accept(this)); + ValueCounter values; + Utils::tree2Array(type, root, dims, dims.size(), values, &builder); + delete root; + + module->createConstVar(name, Type::getPointerType(type), values, dims); + } + return 0; +} + +std::any SysYIRGenerator::visitVarDecl(SysYParser::VarDeclContext *ctx) { + Type* type = std::any_cast(visitBType(ctx->bType())); + for (const auto varDef : ctx->varDef()) { + std::vector dims = {}; + std::string name = varDef->Ident()->getText(); + auto constExps = varDef->constExp(); + if (!constExps.empty()) { + for (const auto &constExp : constExps) { + dims.push_back(std::any_cast(visitConstExp(constExp))); + } + } + + AllocaInst* alloca = + builder.createAllocaInst(Type::getPointerType(type), dims, name); + + if (varDef->initVal() != nullptr) { + ValueCounter values; + // 这里的varDef->initVal()可能是ScalarInitValue或ArrayInitValue + ArrayValueTree* root = std::any_cast(varDef->initVal()->accept(this)); + Utils::tree2Array(type, root, dims, dims.size(), values, &builder); + delete root; + if (dims.empty()) { + builder.createStoreInst(values.getValue(0), alloca); + } else { + // 对于多维数组,使用memset初始化 + // 计算每个维度的大小 + // 这里的values.getNumbers()返回的是每个维度的大小 + // 这里的values.getValues()返回的是每个维度对应的值 + // 例如:对于一个二维数组,values.getNumbers()可能是[3, 4],表示3行4列 + // values.getValues()可能是[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12] + // 对于每个维度,使用memset将对应的值填充到数组中 + // 这里的alloca是一个指向数组的指针 + std::vector & counterNumbers = values.getNumbers(); + std::vector & counterValues = values.getValues(); + unsigned begin = 0; + for (size_t i = 0; i < counterNumbers.size(); i++) { + + builder.createMemsetInst( + alloca, ConstantValue::get(static_cast(begin)), + ConstantValue::get(static_cast(counterNumbers[i])), + counterValues[i]); + begin += counterNumbers[i]; + } + } + } + module->addVariable(name, alloca); + } + return std::any(); +} + +std::any SysYIRGenerator::visitBType(SysYParser::BTypeContext *ctx) { + return ctx->INT() != nullptr ? Type::getIntType() : Type::getFloatType(); +} + +std::any SysYIRGenerator::visitScalarInitValue(SysYParser::ScalarInitValueContext *ctx) { + AllocaInst* alloca = std::any_cast(visitExp(ctx->exp())); + ArrayValueTree* result = new ArrayValueTree(); + result->setValue(alloca); + return result; +} + +std::any SysYIRGenerator::visitArrayInitValue(SysYParser::ArrayInitValueContext *ctx) { + std::vector children; + for (const auto &initVal : ctx->initVal()) + children.push_back(std::any_cast(initVal->accept(this))); + ArrayValueTree* result = new ArrayValueTree(); + result->addChildren(children); + return result; +} + +std::any SysYIRGenerator::visitConstScalarInitValue(SysYParser::ConstScalarInitValueContext *ctx) { + AllocaInst* alloca = std::any_cast(visitConstExp(ctx->constExp())); + ArrayValueTree* result = new ArrayValueTree(); + result->setValue(alloca); + return result; +} + +std::any SysYIRGenerator::visitConstArrayInitValue(SysYParser::ConstArrayInitValueContext *ctx) { + std::vector children; + for (const auto &constInitVal : ctx->constInitVal()) + children.push_back(std::any_cast(constInitVal->accept(this))); + ArrayValueTree* result = new ArrayValueTree(); + result->addChildren(children); + return result; +} + +std::any SysYIRGenerator::visitFuncType(SysYParser::FuncTypeContext *ctx) { + if (ctx->INT() != nullptr) + return Type::getIntType(); + if (ctx->FLOAT() != nullptr) + return Type::getFloatType(); + return Type::getVoidType(); +} + +std::any SysYIRGenerator::visitFuncDef(SysYParser::FuncDefContext *ctx){ + // 更新作用域 + module->enterNewScope(); + + auto name = ctx->Ident()->getText(); + std::vector paramTypes; + std::vector paramNames; + std::vector> paramDims; + + if (ctx->funcFParams() != nullptr) { + auto params = ctx->funcFParams()->funcFParam(); + for (const auto ¶m : params) { + paramTypes.push_back(std::any_cast(visitBType(param->bType()))); + paramNames.push_back(param->Ident()->getText()); + std::vector dims = {}; + if (param->exp() != nullptr) { + dims.push_back(ConstantValue::get(-1)); // 第一个维度不确定 + for (const auto &exp : param->exp()) { + dims.push_back(std::any_cast(visitExp(exp))); + } + } + paramDims.emplace_back(dims); + } + } + + Type *returnType = std::any_cast(visitFuncType(ctx->funcType())); + FunctionType* funcType = Type::getFunctionType(returnType, paramTypes); + Function* function = module->createFunction(name, funcType); + BasicBlock* entry = function->getEntryBlock(); + builder.setPosition(entry, entry->end()); + + for (size_t i = 0; i < paramTypes.size(); ++i) { + AllocaInst* alloca = builder.createAllocaInst(Type::getPointerType(paramTypes[i]), + paramDims[i], paramNames[i]); + entry->insertArgument(alloca); + module->addVariable(paramNames[i], alloca); + } + + for (auto item : ctx->blockStmt()->blockItem()) { + visitBlockItem(item); + } + + module->leaveScope(); + + return std::any; +} + +std::any SysYIRGenerator::visitBlockStmt(SysYParser::BlockStmtContext *ctx) { + module->enterNewScope(); + for (auto item : ctx->blockItem()) + visitBlockItem(item); + module->leaveScope(); + return 0; +} + +std::any SysYIRGenerator::visitAssignStmt(SysYParser::AssignStmtContext *ctx) { + auto lVal = ctx->lValue(); + std::string name = lVal->Ident()->getText(); + std::vector dims; + for (const auto &exp : lVal->exp()) { + dims.push_back(std::any_cast(visitExp(exp))); + } + + User* variable = module->getVariable(name); + Value* value = std::any_cast(visitExp(ctx->exp())); + PointerType* variableType =dynamic_cast(variable->getType())->getBaseType(); + + // 左值右值类型不同处理 + if (variableType != value->getType()) { + ConstantValue * constValue = dynamic_cast(value); + if (constValue != nullptr) { + if (variableType == Type::getFloatType()) { + value = ConstantValue::get(static_cast(constValue->getInt())); + } else { + value = ConstantValue::get(static_cast(constValue->getFloat())); + } + } else { + if (variableType == Type::getFloatType()) { + value = builder.createIToFInst(value); + } else { + value = builder.createFtoIInst(value); + } + } + } + builder.createStoreInst(value, variable, dims, variable->getName()); + + return std::any(); +} + + +std::any SysYIRGenerator::visitIfStmt(SysYParser::IfStmtContext *ctx) { + // labels string stream + + std::stringstream labelstring; + Function * function = builder.getBasicBlock()->getParent(); + + BasicBlock* thenBlock = new BasicBlock(function); + BasicBlock* exitBlock = new BasicBlock(function); + + if (ctx->stmt().size() > 1) { + BasicBlock* elseBlock = new BasicBlock(function); + + builder.pushTrueBlock(thenBlock); + builder.pushFalseBlock(elseBlock); + // 访问条件表达式 + visitCond(ctx->cond()); + builder.popTrueBlock(); + builder.popFalseBlock(); + + labelstring << "then.L" << builder.getLabelIndex(); + thenBlock->setName(labelstring.str()); + labelstring.str(""); + function->addBasicBlock(thenBlock); + builder.setPosition(thenBlock, thenBlock->end()); + + auto block = dynamic_cast(ctx->stmt(0)); + // 如果是块语句,直接访问 + // 否则访问语句 + if (block != nullptr) { + visitBlockStmt(block); + } else { + module->enterNewScope(); + ctx->stmt(0)->accept(this); + module->leaveScope(); + } + builder.createUncondBrInst(exitBlock, {}); + BasicBlock::conectBlocks(builder.getBasicBlock(), exitBlock); + + labelstring << "else.L" << builder.getLabelIndex(); + elseBlock->setName(labelstring.str()); + labelstring.str(""); + function->addBasicBlock(elseBlock); + builder.setPosition(elseBlock, elseBlock->end()); + + block = dynamic_cast(ctx->stmt(1)); + if (block != nullptr) { + visitBlockStmt(block); + } else { + module->enterNewScope(); + ctx->stmt(1)->accept(this); + module->leaveScope(); + } + BasicBlock::conectBlocks(builder.getBasicBlock(), exitBlock); + + labelstring << "exit.L" << builder.getLabelIndex(); + exitBlock->setName(labelstring.str()); + labelstring.str(""); + function->addBasicBlock(exitBlock); + builder.setPosition(exitBlock, exitBlock->end()); + + } else { + builder.pushTrueBlock(thenBlock); + builder.pushFalseBlock(exitBlock); + visitCond(ctx->cond()); + builder.popTrueBlock(); + builder.popFalseBlock(); + + labelstring << "then.L" << builder.getLabelIndex(); + thenBlock->setName(labelstring.str()); + labelstring.str(""); + function->addBasicBlock(thenBlock); + builder.setPosition(thenBlock, thenBlock->end()); + + auto block = dynamic_cast(ctx->stmt(0)); + if (block != nullptr) { + visitBlockStmt(block); + } else { + module->enterNewScope(); + ctx->stmt(0)->accept(this); + module->leaveScope(); + } + BasicBlock::conectBlocks(builder.getBasicBlock(), exitBlock); + + labelstring << "exit.L" << builder.getLabelIndex(); + exitBlock->setName(labelstring.str()); + labelstring.str(""); + function->addBasicBlock(exitBlock); + builder.setPosition(exitBlock, exitBlock->end()); + } + return std::any(); +} + +std::any SysYIRGenerator::visitWhileStmt(SysYParser::WhileStmtContext *ctx) { + + BasicBlock* curBlock = builder.getBasicBlock(); + Function* function = builder.getBasicBlock()->getParent(); + + std::stringstream labelstring; + labelstring << "head.L" << builder.getLabelIndex(); + BasicBlock *headBlock = function->addBasicBlock(labelstring.str()); + labelstring.str(""); + + BasicBlock::conectBlocks(curBlock, headBlock); + builder.setPosition(headBlock, headBlock->end()) + + function->addBasicBlock(condBlock); + builder.setPosition(condBlock, condBlock->end()); + + builder.pushTrueBlock(bodyBlock); + builder.pushFalseBlock(exitBlock); + + visitCond(ctx->cond()); + + builder.popTrueBlock(); + builder.popFalseBlock(); + + labelstring << "body.L" << builder.getLabelIndex(); + bodyBlock->setName(labelstring.str()); + labelstring.str(""); + function->addBasicBlock(bodyBlock); + builder.setPosition(bodyBlock, bodyBlock->end()); + + module->enterNewScope(); + for (auto item : ctx->blockStmt()->blockItem()) { + visitBlockItem(item); + } + module->leaveScope(); + + builder.createUncondBrInst(condBlock, {}); + + BasicBlock::conectBlocks(builder.getBasicBlock(), condBlock); + + labelstring << "exit.L" << builder.getLabelIndex(); + exitBlock->setName(labelstring.str()); + labelstring.str(""); + + function->addBasicBlock(exitBlock); + + builder.setPosition(exitBlock, exitBlock->end()); + + return std::any(); +} void Utils::tree2Array(Type *type, ArrayValueTree *root, const std::vector &dims, unsigned numDims, diff --git a/src/include/IR.h b/src/include/IR.h index 60db35a..1de2e23 100644 --- a/src/include/IR.h +++ b/src/include/IR.h @@ -58,23 +58,23 @@ class Type { virtual ~Type() = default; public: - static auto getIntType() -> Type *; ///< 返回表示Int类型的Type指针 - static auto getFloatType() -> Type *; ///< 返回表示Float类型的Type指针 - static auto getVoidType() -> Type *; ///< 返回表示Void类型的Type指针 - static auto getLabelType() -> Type *; ///< 返回表示Label类型的Type指针 - static auto getPointerType(Type *baseType) -> Type *; ///< 返回表示指向baseType类型的Pointer类型的Type指针 - static auto getFunctionType(Type *returnType, const std::vector ¶mTypes = {}) -> Type *; + static Type* getIntType(); ///< 返回表示Int类型的Type指针 + static Type* getFloatType(); ///< 返回表示Float类型的Type指针 + static Type* getVoidType(); ///< 返回表示Void类型的Type指针 + static Type* getLabelType(); ///< 返回表示Label类型的Type指针 + static Type* getPointerType(Type *baseType); ///< 返回表示指向baseType类型的Pointer类型的Type指针 + static Type* getFunctionType(Type *returnType, const std::vector ¶mTypes = {}); ///< 返回表示返回类型为returnType,形参类型列表为paramTypes的函数类型的Type指针 public: - auto getKind() const -> Kind { return kind; } ///< 返回Type对象代表原始标量类型 - auto isInt() const -> bool { return kind == kInt; } ///< 判定是否为Int类型 - auto isFloat() const -> bool { return kind == kFloat; } ///< 判定是否为Float类型 - auto isVoid() const -> bool { return kind == kVoid; } ///< 判定是否为Void类型 - auto isLabel() const -> bool { return kind == kLabel; } ///< 判定是否为Label类型 - auto isPointer() const -> bool { return kind == kPointer; } ///< 判定是否为Pointer类型 - auto isFunction() const -> bool { return kind == kFunction; } ///< 判定是否为Function类型 - auto getSize() const -> unsigned; ///< 返回类型所占的空间大小(字节) + Kind getKind() const { return kind; } ///< 返回Type对象代表原始标量类型 + bool isInt() const { return kind == kInt; } ///< 判定是否为Int类型 + bool isFloat() const { return kind == kFloat; } ///< 判定是否为Float类型 + bool isVoid() const { return kind == kVoid; } ///< 判定是否为Void类型 + bool isLabel() const { return kind == kLabel; } ///< 判定是否为Label类型 + bool isPointer() const { return kind == kPointer; } ///< 判定是否为Pointer类型 + bool isFunction() const { return kind == kFunction; } ///< 判定是否为Function类型 + unsigned getSize() const; ///< 返回类型所占的空间大小(字节) /// 尝试将一个变量转换为给定的Type及其派生类类型的变量 template auto as() const -> std::enable_if_t, T *> { @@ -90,10 +90,10 @@ class PointerType : public Type { explicit PointerType(Type *baseType) : Type(kPointer), baseType(baseType) {} public: - static auto get(Type *baseType) -> PointerType *; ///< 获取指向baseType的Pointer类型 + static PointerType* get(Type *baseType); ///< 获取指向baseType的Pointer类型 public: - auto getBaseType() const -> Type * { return baseType; } ///< 获取指向的类型 + Type* getBaseType() const { return baseType; } ///< 获取指向的类型 }; class FunctionType : public Type { @@ -107,12 +107,12 @@ class FunctionType : public Type { public: /// 获取返回值类型为returnType, 形参类型列表为paramTypes的Function类型 - static auto get(Type *returnType, const std::vector ¶mTypes = {}) -> FunctionType *; + static FunctionType* get(Type *returnType, const std::vector ¶mTypes = {}); public: - auto getReturnType() const -> Type * { return returnType; } ///< 获取返回值类信息 + Type* getReturnType() const { return returnType; } ///< 获取返回值类信息 auto getParamTypes() const { return make_range(paramTypes); } ///< 获取形参类型列表 - auto getNumParams() const -> unsigned { return paramTypes.size(); } ///< 获取形参数量 + unsigned getNumParams() const { return paramTypes.size(); } ///< 获取形参数量 }; /*! @@ -182,9 +182,9 @@ class Use { Use(unsigned index, User *user, Value *value) : index(index), user(user), value(value) {} public: - auto getIndex() const -> unsigned { return index; } ///< 返回value在User操作数中的位置 - auto getUser() const -> User * { return user; } ///< 返回使用者 - auto getValue() const -> Value * { return value; } ///< 返回被使用的值 + unsigned getIndex() const { return index; } ///< 返回value在User操作数中的位置 + User* getUser() const { return user; } ///< 返回使用者 + Value* getValue() const { return value; } ///< 返回被使用的值 void setValue(Value *newValue) { value = newValue; } ///< 将被使用的值设置为newValue }; @@ -202,19 +202,17 @@ class Value { public: void setName(const std::string &newName) { name = newName; } ///< 设置名字 - auto getName() const -> const std::string & { return name; } ///< 获取名字 - auto getType() const -> Type * { return type; } ///< 返回值的类型 - auto isInt() const -> bool { return type->isInt(); } ///< 判定是否为Int类型 - auto isFloat() const -> bool { return type->isFloat(); } ///< 判定是否为Float类型 - auto isPointer() const -> bool { return type->isPointer(); } ///< 判定是否为Pointer类型 - auto getUses() -> std::list> & { return uses; } ///< 获取使用关系列表 + const std::string& getName() const { return name; } ///< 获取名字 + Type* getType() const { return type; } ///< 返回值的类型 + bool isInt() const { return type->isInt(); } ///< 判定是否为Int类型 + bool isFloat() const { return type->isFloat(); } ///< 判定是否为Float类型 + bool isPointer() const { return type->isPointer(); } ///< 判定是否为Pointer类型 + std::list>& getUses() { return uses; } ///< 获取使用关系列表 void addUse(const std::shared_ptr &use) { uses.push_back(use); } ///< 添加使用关系 void replaceAllUsesWith(Value *value); ///< 将原来使用该value的使用者全变为使用给定参数value并修改相应use关系 void removeUse(const std::shared_ptr &use) { uses.remove(use); } ///< 删除使用关系use }; - - /** * ValueCounter 需要理解为一个Value *的计数器。 * 它的主要目的是为了节省存储空间和方便Memset指令的创建。 @@ -236,8 +234,8 @@ class ValueCounter { ValueCounter() = default; public: - auto size() const -> unsigned { return __size; } ///< 返回总的Value数量 - auto getValue(unsigned index) const -> Value * { + unsigned size() const { return __size; } ///< 返回总的Value数量 + Value* getValue(unsigned index) const { if (index >= __size) { return nullptr; } @@ -252,8 +250,8 @@ class ValueCounter { return nullptr; } ///< 根据位置index获取Value * - auto getValues() const -> const std::vector & { return __counterValues; } ///< 获取互异Value *列表 - auto getNumbers() const -> const std::vector & { return __counterNumbers; } ///< 获取Value *重复数量列表 + const std::vector& getValues() const { return __counterValues; } ///< 获取互异Value *列表 + const std::vector& getNumbers() const { return __counterNumbers; } ///< 获取Value *重复数量列表 void push_back(Value *value, unsigned num = 1) { if (__size != 0 && __counterValues.back() == value) { *(__counterNumbers.end() - 1) += num; @@ -293,20 +291,20 @@ class ConstantValue : public Value { : Value(Type::getFloatType(), name), fScalar(value) {} public: - static auto get(int value) -> ConstantValue *; ///< 获取一个int类型的ConstValue *,其值为value - static auto get(float value) -> ConstantValue *; ///< 获取一个float类型的ConstValue *,其值为value + static ConstantValue* get(int value); ///< 获取一个int类型的ConstValue *,其值为value + static ConstantValue* get(float value); ///< 获取一个float类型的ConstValue *,其值为value public: - auto getInt() const -> int { + int getInt() const { assert(isInt()); return iScalar; } ///< 返回int类型的值 - auto getFloat() const -> float { + float getFloat() const { assert(isFloat()); return fScalar; } ///< 返回float类型的值 template - auto getValue() const -> T { + T getValue() const { if (std::is_same::value && isInt()) { return getInt(); } @@ -369,23 +367,23 @@ class BasicBlock; } ///< 基本块的析构函数,同时删除其前驱后继关系 public: - auto getNumInstructions() const -> unsigned { return instructions.size(); } ///< 获取指令数量 - auto getNumArguments() const -> unsigned { return arguments.size(); } ///< 获取形式参数数量 - auto getNumPredecessors() const -> unsigned { return predecessors.size(); } ///< 获取前驱数量 - auto getNumSuccessors() const -> unsigned { return successors.size(); } ///< 获取后继数量 - auto getParent() const -> Function * { return parent; } ///< 获取父函数 + unsigned getNumInstructions() const { return instructions.size(); } ///< 获取指令数量 + unsigned getNumArguments() const { return arguments.size(); } ///< 获取形式参数数量 + unsigned getNumPredecessors() const { return predecessors.size(); } ///< 获取前驱数量 + unsigned getNumSuccessors() const { return successors.size(); } ///< 获取后继数量 + Function* getParent() const { return parent; } ///< 获取父函数 void setParent(Function *func) { parent = func; } ///< 设置父函数 - auto getInstructions() -> inst_list & { return instructions; } ///< 获取指令列表 - auto getArguments() -> arg_list & { return arguments; } ///< 获取分配空间后的形式参数列表 - auto getPredecessors() const -> const block_list & { return predecessors; } ///< 获取前驱列表 - auto getSuccessors() -> block_list & { return successors; } ///< 获取后继列表 - auto getDominants() -> block_set & { return dominants; } - auto getIdom() -> BasicBlock * { return idom; } - auto getSdoms() -> block_list & { return sdoms; } - auto getDFs() -> block_set & { return dominant_frontiers; } - auto begin() -> iterator { return instructions.begin(); } ///< 返回指向指令列表开头的迭代器 - auto end() -> iterator { return instructions.end(); } ///< 返回指向指令列表末尾的迭代器 - auto terminator() -> iterator { return std::prev(end()); } ///< 基本块最后的IR + inst_list& getInstructions() { return instructions; } ///< 获取指令列表 + arg_list& getArguments() { return arguments; } ///< 获取分配空间后的形式参数列表 + const block_list& getPredecessors() const { return predecessors; } ///< 获取前驱列表 + block_list& getSuccessors() { return successors; } ///< 获取后继列表 + block_set& getDominants() { return dominants; } + BasicBlock* getIdom() { return idom; } + block_list& getSdoms() { return sdoms; } + block_set& getDFs() { return dominant_frontiers; } + iterator begin() { return instructions.begin(); } ///< 返回指向指令列表开头的迭代器 + iterator end() { return instructions.end(); } ///< 返回指向指令列表末尾的迭代器 + iterator terminator() { return std::prev(end()); } ///< 基本块最后的IR void insertArgument(AllocaInst *inst) { arguments.push_back(inst); } ///< 插入分配空间后的形式参数 void addPredecessor(BasicBlock *block) { if (std::find(predecessors.begin(), predecessors.end(), block) == predecessors.end()) { @@ -411,18 +409,18 @@ class BasicBlock; void addSdoms(BasicBlock *block) { sdoms.push_back(block); } void clearSdoms() { sdoms.clear(); } // 重载1,参数为 BasicBlock* - auto addDominants(BasicBlock *block) -> void { dominants.emplace(block); } + void addDominants(BasicBlock *block) { dominants.emplace(block); } // 重载2,参数为 block_set - auto addDominants(const block_set &blocks) -> void { dominants.insert(blocks.begin(), blocks.end()); } - auto setDominants(BasicBlock *block) -> void { + void addDominants(const block_set &blocks) { dominants.insert(blocks.begin(), blocks.end()); } + void setDominants(BasicBlock *block) { dominants.clear(); addDominants(block); } - auto setDominants(const block_set &doms) -> void { + void setDominants(const block_set &doms) { dominants.clear(); addDominants(doms); } - auto setDFs(const block_set &df) -> void { + void setDFs(const block_set &df) { dominant_frontiers.clear(); for (auto elem : df) { dominant_frontiers.emplace(elem); @@ -453,7 +451,7 @@ class BasicBlock; } } ///< 替换前驱 // 获取支配树中该块的所有子节点,包括子节点的子节点等,迭代实现 - auto getChildren() -> block_list { + block_list getChildren() { std::queue q; block_list children; for (auto sdom : sdoms) { @@ -472,22 +470,20 @@ class BasicBlock; return children; } - auto setreachableTrue() -> void { reachable = true; } ///< 设置可达 - - auto setreachableFalse() -> void { reachable = false; } ///< 设置不可达 - - auto getreachable() -> bool { return reachable; } ///< 返回可达状态 + void setreachableTrue() { reachable = true; } ///< 设置可达 + void setreachableFalse() { reachable = false; } ///< 设置不可达 + bool getreachable() { return reachable; } ///< 返回可达状态 static void conectBlocks(BasicBlock *prev, BasicBlock *next) { prev->addSuccessor(next); next->addPredecessor(prev); } ///< 连接两个块,即设置两个基本块的前驱后继关系 void setLoop(Loop *loop2set) { loopbelong = loop2set; } ///< 设置所属循环 - auto getLoop() { return loopbelong; } ///< 获得所属循环 + Loop* getLoop() { return loopbelong; } ///< 获得所属循环 void setLoopDepth(int loopdepth2set) { loopdepth = loopdepth2set; } ///< 设置循环深度 - auto getLoopDepth() { return loopdepth; } ///< 获得其在循环的深度 + int getLoopDepth() { return loopdepth; } ///< 获得其在循环的深度 void removeInst(iterator pos) { instructions.erase(pos); } ///< 删除指令 - auto moveInst(iterator sourcePos, iterator targetPos, BasicBlock *block) -> iterator; ///< 移动指令 + iterator moveInst(iterator sourcePos, iterator targetPos, BasicBlock *block); ///< 移动指令 }; //! User is the abstract base type of `Value` types which use other `Value` as @@ -504,11 +500,11 @@ class User : public Value { explicit User(Type *type, const std::string &name = "") : Value(type, name) {} public: - auto getNumOperands() const -> unsigned { return operands.size(); } ///< 获取操作数数量 + unsigned getNumOperands() const { return operands.size(); } ///< 获取操作数数量 auto operand_begin() const { return operands.begin(); } ///< 返回操作数列表的开头迭代器 auto operand_end() const { return operands.end(); } ///< 返回操作数列表的结尾迭代器 auto getOperands() const { return make_range(operand_begin(), operand_end()); } ///< 获取操作数列表 - auto getOperand(unsigned index) const -> Value * { return operands[index]->getValue(); } ///< 获取位置为index的操作数 + Value* getOperand(unsigned index) const { return operands[index]->getValue(); } ///< 获取位置为index的操作数 void addOperand(Value *value) { operands.emplace_back(std::make_shared(operands.size(), this, value)); value->addUse(operands.back()); @@ -528,8 +524,6 @@ class User : public Value { void setOperand(unsigned index, Value *value); ///< 设置操作数 }; - - class GetSubArrayInst; /** * 左值 具有地址的对象 @@ -551,11 +545,11 @@ class LVal { virtual unsigned getLValNumDims() const = 0; ///< 获取左值的维度数量 public: - auto getFatherLVal() const -> LVal * { return fatherLVal; } ///< 获取父左值 - auto getChildrenLVals() const -> const std::list> & { + LVal* getFatherLVal() const { return fatherLVal; } ///< 获取父左值 + const std::list>& getChildrenLVals() const { return childrenLVals; } ///< 获取子左值列表 - auto getAncestorLVal() const -> LVal * { + LVal* getAncestorLVal() const { auto curLVal = const_cast(this); while (curLVal->getFatherLVal() != nullptr) { curLVal = curLVal->getFatherLVal(); @@ -570,7 +564,7 @@ class LVal { [child](const std::unique_ptr &ptr) { return ptr.get() == child; }); childrenLVals.erase(iter); } ///< 移除子左值 - auto getDefineInst() const -> GetSubArrayInst * { return defineInst; } ///< 获取定义指令 + GetSubArrayInst* getDefineInst() const { return defineInst; } ///< 获取定义指令 }; /*! @@ -734,8 +728,8 @@ public: } } ///< 根据指令标识码获取字符串 - BasicBlock *getParent() const { return parent; } - Function *getFunction() const { return parent->getParent(); } + BasicBlock* getParent() const { return parent; } + Function* getFunction() const { return parent->getParent(); } void setParent(BasicBlock *bb) { parent = bb; } bool isBinary() const { @@ -809,10 +803,10 @@ class LaInst : public Instruction { } public: - auto getNumIndices() const -> unsigned { return getNumOperands() - 1; } ///< 获取索引长度 - auto getPointer() const -> Value * { return getOperand(0); } ///< 获取目标变量的Value指针 + unsigned getNumIndices() const { return getNumOperands() - 1; } ///< 获取索引长度 + Value* getPointer() const { return getOperand(0); } ///< 获取目标变量的Value指针 auto getIndices() const { return make_range(std::next(operand_begin()), operand_end()); } ///< 获取索引列表 - auto getIndex(unsigned index) const -> Value * { return getOperand(index + 1); } ///< 获取位置为index的索引分量 + Value* getIndex(unsigned index) const { return getOperand(index + 1); } ///< 获取位置为index的索引分量 }; class PhiInst : public Instruction { @@ -832,10 +826,10 @@ class PhiInst : public Instruction { } public: - Value * getMapVal() { return map_val; } - Value * getPointer() const { return getOperand(0); } + Value* getMapVal() { return map_val; } + Value* getPointer() const { return getOperand(0); } auto getValues() { return make_range(std::next(operand_begin()), operand_end()); } - Value * getValue(unsigned index) const { return getOperand(index + 1); } + Value* getValue(unsigned index) const { return getOperand(index + 1); } }; @@ -849,7 +843,7 @@ protected: public: - Function *getCallee() const; + Function* getCallee() const; auto getArguments() const { return make_range(std::next(operand_begin()), operand_end()); } @@ -870,7 +864,7 @@ protected: public: - Value *getOperand() const { return User::getOperand(0); } + Value* getOperand() const { return User::getOperand(0); } }; // class UnaryInst @@ -887,10 +881,10 @@ class BinaryInst : public Instruction { } public: - Value *getLhs() const { return getOperand(0); } - Value *getRhs() const { return getOperand(1); } + Value* getLhs() const { return getOperand(0); } + Value* getRhs() const { return getOperand(1); } template - auto eval(T lhs, T rhs) -> T { + T eval(T lhs, T rhs) { switch (getKind()) { case kAdd: return lhs + rhs; @@ -963,7 +957,7 @@ class ReturnInst : public Instruction { public: bool hasReturnValue() const { return not operands.empty(); } - Value *getReturnValue() const { + Value* getReturnValue() const { return hasReturnValue() ? getOperand(0) : nullptr; } }; @@ -983,7 +977,7 @@ protected: } public: - BasicBlock *getBlock() const { return dynamic_cast(getOperand(0)); } + BasicBlock* getBlock() const { return dynamic_cast(getOperand(0)); } auto getArguments() const { return make_range(std::next(operand_begin()), operand_end()); } @@ -1009,11 +1003,11 @@ protected: addOperands(elseArgs); } public: - Value *getCondition() const { return getOperand(0); } - BasicBlock *getThenBlock() const { + Value* getCondition() const { return getOperand(0); } + BasicBlock* getThenBlock() const { return dynamic_cast(getOperand(1)); } - BasicBlock *getElseBlock() const { + BasicBlock* getElseBlock() const { return dynamic_cast(getOperand(2)); } auto getThenArguments() const { @@ -1053,7 +1047,7 @@ public: int getNumDims() const { return getNumOperands(); } auto getDims() const { return getOperands(); } - Value *getDim(int index) { return getOperand(index); } + Value* getDim(int index) { return getOperand(index); } }; // class AllocaInst @@ -1083,12 +1077,12 @@ class GetSubArrayInst : public Instruction { } public: - auto getFatherArray() const -> Value * { return getOperand(0); } ///< 获取父数组 - auto getChildArray() const -> Value * { return getOperand(1); } ///< 获取子数组 - auto getFatherLVal() const -> LVal * { return dynamic_cast(getOperand(0)); } ///< 获取父左值 - auto getChildLVal() const -> LVal * { return dynamic_cast(getOperand(1)); } ///< 获取子左值 + Value* getFatherArray() const { return getOperand(0); } ///< 获取父数组 + Value* getChildArray() const { return getOperand(1); } ///< 获取子数组 + LVal* getFatherLVal() const { return dynamic_cast(getOperand(0)); } ///< 获取父左值 + LVal* getChildLVal() const { return dynamic_cast(getOperand(1)); } ///< 获取子左值 auto getIndices() const { return make_range(std::next(operand_begin(), 2), operand_end()); } ///< 获取索引 - auto getNumIndices() const -> unsigned { return getNumOperands() - 2; } ///< 获取索引数量 + unsigned getNumIndices() const { return getNumOperands() - 2; } ///< 获取索引数量 }; //! Load a value from memory address specified by a pointer value @@ -1107,11 +1101,11 @@ protected: public: int getNumIndices() const { return getNumOperands() - 1; } - Value *getPointer() const { return getOperand(0); } + Value* getPointer() const { return getOperand(0); } auto getIndices() const { return make_range(std::next(operand_begin()), operand_end()); } - Value *getIndex(int index) const { return getOperand(index + 1); } + Value* getIndex(int index) const { return getOperand(index + 1); } std::list getAncestorIndices() const { std::list indices; for (const auto &index : getIndices()) { @@ -1147,12 +1141,12 @@ protected: public: int getNumIndices() const { return getNumOperands() - 2; } - Value *getValue() const { return getOperand(0); } - Value *getPointer() const { return getOperand(1); } + Value* getValue() const { return getOperand(0); } + Value* getPointer() const { return getOperand(1); } auto getIndices() const { return make_range(std::next(operand_begin(), 2), operand_end()); } - Value *getIndex(int index) const { return getOperand(index + 2); } + Value* getIndex(int index) const { return getOperand(index + 2); } std::list getAncestorIndices() const { std::list indices; for (const auto &index : getIndices()) { @@ -1178,6 +1172,13 @@ class MemsetInst : public Instruction { friend class Function; protected: + //! Create a memset instruction. + //! \param pointer The pointer to the memory location to be set. + //! \param begin The starting address of the memory region to be set. + //! \param size The size of the memory region to be set. + //! \param value The value to set the memory region to. + //! \param parent The parent basic block of this instruction. + //! \param name The name of this instruction. MemsetInst(Value *pointer, Value *begin, Value *size, Value *value, BasicBlock *parent = nullptr, const std::string &name = "") : Instruction(kMemset, Type::getVoidType(), parent, name) { @@ -1188,10 +1189,10 @@ protected: } public: - Value *getPointer() const { return getOperand(0); } - Value *getBegin() const { return getOperand(1); } - Value *getSize() const { return getOperand(2); } - Value *getValue() const { return getOperand(3); } + Value* getPointer() const { return getOperand(0); } + Value* getBegin() const { return getOperand(1); } + Value* getSize() const { return getOperand(2); } + Value* getValue() const { return getOperand(3); } }; @@ -1241,26 +1242,26 @@ public: loopCount = loopCount + 1; loopID = loopCount; } - auto getindBegin() { return indBegin; } ///< 获得循环开始值 - auto getindStep() { return indStep; } ///< 获得循环步长 - auto setindBegin(ConstantValue *indBegin2set) { indBegin = indBegin2set; } ///< 设置循环开始值 - auto setindStep(ConstantValue *indStep2set) { indStep = indStep2set; } ///< 设置循环步长 - auto setStepType(int StepType2Set) { StepType = StepType2Set; } ///< 设置循环变量规则 - auto getStepType() { return StepType; } ///< 获得循环变量规则 - auto getLoopID() -> size_t { return loopID; } + ConstantValue* getindBegin() { return indBegin; } ///< 获得循环开始值 + ConstantValue* getindStep() { return indStep; } ///< 获得循环步长 + void setindBegin(ConstantValue *indBegin2set) { indBegin = indBegin2set; } ///< 设置循环开始值 + void setindStep(ConstantValue *indStep2set) { indStep = indStep2set; } ///< 设置循环步长 + void setStepType(int StepType2Set) { StepType = StepType2Set; } ///< 设置循环变量规则 + int getStepType() { return StepType; } ///< 获得循环变量规则 + size_t getLoopID() { return loopID; } - BasicBlock *getHeader() const { return headerBlock; } - BasicBlock *getPreheaderBlock() const { return preheaderBlock; } - block_list &getLatchBlocks() { return latchBlock; } - block_set &getExitingBlocks() { return exitingBlocks; } - block_set &getExitBlocks() { return exitBlocks; } - Loop *getParentLoop() const { return parentloop; } + BasicBlock* getHeader() const { return headerBlock; } + BasicBlock* getPreheaderBlock() const { return preheaderBlock; } + block_list& getLatchBlocks() { return latchBlock; } + block_set& getExitingBlocks() { return exitingBlocks; } + block_set& getExitBlocks() { return exitBlocks; } + Loop* getParentLoop() const { return parentloop; } void setParentLoop(Loop *parent) { parentloop = parent; } void addBasicBlock(BasicBlock *bb) { blocksInLoop.push_back(bb); } void addSubLoop(Loop *loop) { subLoops.push_back(loop); } void setLoopDepth(unsigned depth) { loopDepth = depth; } - block_list &getBasicBlocks() { return blocksInLoop; } - Loop_list &getSubLoops() { return subLoops; } + block_list& getBasicBlocks() { return blocksInLoop; } + Loop_list& getSubLoops() { return subLoops; } unsigned getLoopDepth() const { return loopDepth; } bool isLoopContainsBasicBlock(BasicBlock *bb) const { @@ -1280,14 +1281,14 @@ public: void setIndEnd(Value *value) { indEnd = value; } void setIndPhi(AllocaInst *phi) { IndPhi = phi; } - Value *getIndEnd() const { return indEnd; } - AllocaInst *getIndPhi() const { return IndPhi; } - Instruction *getIndCondVar() const { return indCondVar; } + Value* getIndEnd() const { return indEnd; } + AllocaInst* getIndPhi() const { return IndPhi; } + Instruction* getIndCondVar() const { return indCondVar; } - auto addGlobalValuechange(GlobalValue *globalvaluechange2add) { + void addGlobalValuechange(GlobalValue *globalvaluechange2add) { GlobalValuechange.insert(globalvaluechange2add); } ///<添加在循环中改变的全局变量 - auto getGlobalValuechange() -> std::set & { + std::set& getGlobalValuechange() { return GlobalValuechange; } ///<获得在循环中改变的所有全局变量 @@ -1335,83 +1336,82 @@ protected: std::unordered_map> value2UseBlocks; //< value -- use blocks mapping public: - static auto getcloneIndex() -> unsigned { + static unsigned getcloneIndex() { static unsigned cloneIndex = 0; cloneIndex += 1; return cloneIndex - 1; } - auto clone(const std::string &suffix = "_" + std::to_string(getcloneIndex()) + "@") const - -> Function *; ///< 复制函数 - auto getCallees() -> const std::set & { return callees; } - auto addCallee(Function *callee) -> void { callees.insert(callee); } - auto removeCallee(Function *callee) -> void { callees.erase(callee); } - auto clearCallees() -> void { callees.clear(); } - auto getCalleesWithNoExternalAndSelf() -> std::set; - auto getAttribute() const -> FunctionAttribute { return attribute; } ///< 获取函数属性 - auto setAttribute(FunctionAttribute attr) -> void { + Function* clone(const std::string &suffix = "_" + std::to_string(getcloneIndex()) + "@") const; ///< 复制函数 + const std::set& getCallees() { return callees; } + void addCallee(Function *callee) { callees.insert(callee); } + void removeCallee(Function *callee) { callees.erase(callee); } + void clearCallees() { callees.clear(); } + std::set getCalleesWithNoExternalAndSelf(); + FunctionAttribute getAttribute() const { return attribute; } ///< 获取函数属性 + void setAttribute(FunctionAttribute attr) { attribute = static_cast(attribute | attr); } ///< 设置函数属性 - auto clearAttribute() -> void { attribute = PlaceHolder; } ///< 清楚所有函数属性,只保留PlaceHolder - auto getLoopOfBasicBlock(BasicBlock *bb) -> Loop * { + void clearAttribute() { attribute = PlaceHolder; } ///< 清楚所有函数属性,只保留PlaceHolder + Loop* getLoopOfBasicBlock(BasicBlock *bb) { return basicblock2Loop.count(bb) != 0 ? basicblock2Loop[bb] : nullptr; } ///< 获得块所在循环 - auto getLoopDepthByBlock(BasicBlock *basicblock2Check) { + unsigned getLoopDepthByBlock(BasicBlock *basicblock2Check) { if (getLoopOfBasicBlock(basicblock2Check) != nullptr) { auto loop = getLoopOfBasicBlock(basicblock2Check); return loop->getLoopDepth(); } return static_cast(0); } ///< 通过块,获得其所在循环深度 - auto addBBToLoop(BasicBlock *bb, Loop *LoopToadd) { basicblock2Loop[bb] = LoopToadd; } ///< 添加块与循环的映射 - auto getBBToLoopRef() -> std::unordered_map & { + void addBBToLoop(BasicBlock *bb, Loop *LoopToadd) { basicblock2Loop[bb] = LoopToadd; } ///< 添加块与循环的映射 + std::unordered_map& getBBToLoopRef() { return basicblock2Loop; } ///< 获得块-循环映射表 // auto getNewLoopPtr(BasicBlock *header) -> Loop * { return new Loop(header); } - auto getReturnType() const -> Type * { return getType()->as()->getReturnType(); } ///< 获取返回值类型 + Type* getReturnType() const { return getType()->as()->getReturnType(); } ///< 获取返回值类型 auto getParamTypes() const { return getType()->as()->getParamTypes(); } ///< 获取形式参数类型列表 auto getBasicBlocks() { return make_range(blocks); } ///< 获取基本块列表 - auto getBasicBlocks_NoRange() -> block_list & { return blocks; } - auto getEntryBlock() -> BasicBlock * { return blocks.front().get(); } ///< 获取入口块 - auto removeBasicBlock(BasicBlock *blockToRemove) { + block_list& getBasicBlocks_NoRange() { return blocks; } + BasicBlock* getEntryBlock() { return blocks.front().get(); } ///< 获取入口块 + void removeBasicBlock(BasicBlock *blockToRemove) { auto is_same_ptr = [blockToRemove](const std::unique_ptr &ptr) { return ptr.get() == blockToRemove; }; blocks.remove_if(is_same_ptr); // blocks.erase(std::remove_if(blocks.begin(), blocks.end(), is_same_ptr), blocks.end()); } ///< 将该块从function的blocks中删除 // auto getBasicBlocks_NoRange() -> block_list & { return blocks; } - auto addBasicBlock(const std::string &name = "") -> BasicBlock * { + BasicBlock* addBasicBlock(const std::string &name = "") { blocks.emplace_back(new BasicBlock(this, name)); return blocks.back().get(); } ///< 添加新的基本块 - auto addBasicBlock(BasicBlock *block) -> BasicBlock * { + BasicBlock* addBasicBlock(BasicBlock *block) { blocks.emplace_back(block); return block; } ///< 添加基本块到blocks中 - auto addBasicBlockFront(BasicBlock *block) -> BasicBlock * { + BasicBlock* addBasicBlockFront(BasicBlock *block) { blocks.emplace_front(block); return block; } // 从前端插入新的基本块 /** value -- alloc blocks mapping */ - auto addValue2AllocBlocks(Value *value, BasicBlock *block) -> void { + void addValue2AllocBlocks(Value *value, BasicBlock *block) { value2AllocBlocks[value] = block; } ///< 添加value -- alloc block mapping - auto getAllocBlockByValue(Value *value) -> BasicBlock * { + BasicBlock* getAllocBlockByValue(Value *value) { if (value2AllocBlocks.count(value) > 0) { return value2AllocBlocks[value]; } return nullptr; } ///< 通过value获取alloc block - auto getValue2AllocBlocks() -> std::unordered_map & { + std::unordered_map& getValue2AllocBlocks() { return value2AllocBlocks; } ///< 获取所有value -- alloc block mappings - auto removeValue2AllocBlock(Value *value) -> void { + void removeValue2AllocBlock(Value *value) { value2AllocBlocks.erase(value); } ///< 删除value -- alloc block mapping /** value -- define blocks mapping */ - auto addValue2DefBlocks(Value *value, BasicBlock *block) -> void { + void addValue2DefBlocks(Value *value, BasicBlock *block) { ++value2DefBlocks[value][block]; } ///< 添加value -- define block mapping // keep in mind that the return is not a reference. - auto getDefBlocksByValue(Value *value) -> std::unordered_set { + std::unordered_set getDefBlocksByValue(Value *value) { std::unordered_set blocks; if (value2DefBlocks.count(value) > 0) { for (const auto &pair : value2DefBlocks[value]) { @@ -1420,10 +1420,10 @@ protected: } return blocks; } ///< 通过value获取define blocks - auto getValue2DefBlocks() -> std::unordered_map> & { + std::unordered_map>& getValue2DefBlocks() { return value2DefBlocks; } ///< 获取所有value -- define blocks mappings - auto removeValue2DefBlock(Value *value, BasicBlock *block) -> bool { + bool removeValue2DefBlock(Value *value, BasicBlock *block) { bool changed = false; if (--value2DefBlocks[value][block] == 0) { value2DefBlocks[value].erase(block); @@ -1434,7 +1434,7 @@ protected: } return changed; } ///< 删除value -- define block mapping - auto getValuesOfDefBlock() -> std::unordered_set { + std::unordered_set getValuesOfDefBlock() { std::unordered_set values; for (const auto &pair : value2DefBlocks) { values.insert(pair.first); @@ -1442,11 +1442,11 @@ protected: return values; } ///< 获取所有定义过的value /** value -- use blocks mapping */ - auto addValue2UseBlocks(Value *value, BasicBlock *block) -> void { + void addValue2UseBlocks(Value *value, BasicBlock *block) { ++value2UseBlocks[value][block]; } ///< 添加value -- use block mapping // keep in mind that the return is not a reference. - auto getUseBlocksByValue(Value *value) -> std::unordered_set { + std::unordered_set getUseBlocksByValue(Value *value) { std::unordered_set blocks; if (value2UseBlocks.count(value) > 0) { for (const auto &pair : value2UseBlocks[value]) { @@ -1455,10 +1455,10 @@ protected: } return blocks; } ///< 通过value获取use blocks - auto getValue2UseBlocks() -> std::unordered_map> & { + std::unordered_map>& getValue2UseBlocks() { return value2UseBlocks; } ///< 获取所有value -- use blocks mappings - auto removeValue2UseBlock(Value *value, BasicBlock *block) -> bool { + bool removeValue2UseBlock(Value *value, BasicBlock *block) { bool changed = false; if (--value2UseBlocks[value][block] == 0) { value2UseBlocks[value].erase(block); @@ -1469,8 +1469,8 @@ protected: } return changed; } ///< 删除value -- use block mapping - auto addIndirectAlloca(AllocaInst *alloca) { indirectAllocas.emplace_back(alloca); } ///< 添加间接分配 - auto getIndirectAllocas() -> std::list> & { + void addIndirectAlloca(AllocaInst *alloca) { indirectAllocas.emplace_back(alloca); } ///< 添加间接分配 + std::list>& getIndirectAllocas() { return indirectAllocas; } ///< 获取间接分配列表 @@ -1478,8 +1478,8 @@ protected: void addLoop(Loop *loop) { loops.emplace_back(loop); } ///< 添加循环(非顶层) void addTopLoop(Loop *loop) { topLoops.emplace_back(loop); } ///< 添加顶层循环 - auto getLoops() -> Loop_list & { return loops; } ///< 获得循环(非顶层) - auto getTopLoops() -> Loop_list & { return topLoops; } ///< 获得顶层循环 + Loop_list& getLoops() { return loops; } ///< 获得循环(非顶层) + Loop_list& getTopLoops() { return topLoops; } ///< 获得顶层循环 /** loop -- end */ }; // class Function @@ -1532,7 +1532,7 @@ public: Value* getByIndex(unsigned index) const { return initValues.getValue(index); } ///< 通过一维偏移量index获取初始值 - Value * getByIndices(const std::vector &indices) const { + Value* getByIndices(const std::vector &indices) const { int index = 0; for (size_t i = 0; i < indices.size(); i++) { index = dynamic_cast(getDim(i))->getInt() * index + @@ -1540,7 +1540,7 @@ public: } return getByIndex(index); } ///< 通过多维索引indices获取初始值 - const ValueCounter &getInitValues() const { return initValues; } + const ValueCounter& getInitValues() const { return initValues; } }; // class GlobalValue @@ -1583,13 +1583,13 @@ class ConstantVariable : public User, public LVal { return getByIndex(index); } ///< 通过多维索引indices获取初始值 unsigned getNumDims() const { return numDims; } ///< 获取维度数量 - Value* getDim(unsigned index) const { return getOperand(index); } ///< 获取位置为index的维度 + Value* getDim(unsigned index) const { return getOperand(index); } ///< 获取位置为index的维度 auto getDims() const { return getOperands(); } ///< 获取维度列表 const ValueCounter& getInitValues() const { return initValues; } ///< 获取初始值 }; using SymbolTableNode = struct SymbolTableNode { - struct SymbolTableNode *pNode; ///< 父节点 + SymbolTableNode *pNode; ///< 父节点 std::vector children; ///< 子节点列表 std::map varList; ///< 变量列表 }; @@ -1606,15 +1606,15 @@ class SymbolTable { public: SymbolTable() = default; - auto getVariable(const std::string &name) const -> User *; ///< 根据名字name以及当前作用域获取变量 - auto addVariable(const std::string &name, User *variable) -> User *; ///< 添加变量 - auto getGlobals() -> std::vector> &; ///< 获取全局变量列表 - auto getConsts() const -> const std::vector> &; ///< 获取常量列表 + User* getVariable(const std::string &name) const; ///< 根据名字name以及当前作用域获取变量 + User* addVariable(const std::string &name, User *variable); ///< 添加变量 + std::vector>& getGlobals(); ///< 获取全局变量列表 + const std::vector>& getConsts() const; ///< 获取常量列表 void enterNewScope(); ///< 进入新的作用域 void leaveScope(); ///< 离开作用域 - auto isInGlobalScope() const -> bool; ///< 是否位于全局作用域 + bool isInGlobalScope() const; ///< 是否位于全局作用域 void enterGlobalScope(); ///< 进入全局作用域 - auto isCurNodeNull() -> bool { return curNode == nullptr; } + bool isCurNodeNull() { return curNode == nullptr; } }; //! IR unit for representing a SysY compile unit @@ -1628,14 +1628,14 @@ class Module { Module() = default; public: - auto createFunction(const std::string &name, Type *type) -> Function * { + Function* createFunction(const std::string &name, Type *type) { auto result = functions.try_emplace(name, new Function(this, type, name)); if (!result.second) { return nullptr; } return result.first->second.get(); } ///< 创建函数 - auto createExternalFunction(const std::string &name, Type *type) -> Function * { + Function* createExternalFunction(const std::string &name, Type *type) { auto result = externalFunctions.try_emplace(name, new Function(this, type, name)); if (!result.second) { return nullptr; @@ -1643,8 +1643,8 @@ class Module { return result.first->second.get(); } ///< 创建外部函数 ///< 变量创建伴随着符号表的更新 - auto createGlobalValue(const std::string &name, Type *type, const std::vector &dims = {}, - const ValueCounter &init = {}) -> GlobalValue * { + GlobalValue* createGlobalValue(const std::string &name, Type *type, const std::vector &dims = {}, + const ValueCounter &init = {}) { bool isFinished = variableTable.isCurNodeNull(); if (isFinished) { variableTable.enterGlobalScope(); @@ -1658,8 +1658,8 @@ class Module { } return dynamic_cast(result); } ///< 创建全局变量 - auto createConstVar(const std::string &name, Type *type, const ValueCounter &init, - const std::vector &dims = {}) -> ConstantVariable * { + ConstantVariable* createConstVar(const std::string &name, Type *type, const ValueCounter &init, + const std::vector &dims = {}) { auto result = variableTable.addVariable(name, new ConstantVariable(this, type, name, init, dims)); if (result == nullptr) { return nullptr; @@ -1669,38 +1669,38 @@ class Module { void addVariable(const std::string &name, AllocaInst *variable) { variableTable.addVariable(name, variable); } ///< 添加变量 - auto getVariable(const std::string &name) -> User * { + User* getVariable(const std::string &name) { return variableTable.getVariable(name); } ///< 根据名字name和当前作用域获取变量 - auto getFunction(const std::string &name) const -> Function * { + Function* getFunction(const std::string &name) const { auto result = functions.find(name); if (result == functions.end()) { return nullptr; } return result->second.get(); } ///< 获取函数 - auto getExternalFunction(const std::string &name) const -> Function * { + Function* getExternalFunction(const std::string &name) const { auto result = externalFunctions.find(name); if (result == functions.end()) { return nullptr; } return result->second.get(); } ///< 获取外部函数 - auto getFunctions() -> std::map> & { return functions; } ///< 获取函数列表 - auto getExternalFunctions() const -> const std::map> & { + std::map>& getFunctions() { return functions; } ///< 获取函数列表 + const std::map>& getExternalFunctions() const { return externalFunctions; } ///< 获取外部函数列表 - auto getGlobals() -> std::vector> & { + std::vector>& getGlobals() { return variableTable.getGlobals(); } ///< 获取全局变量列表 - auto getConsts() const -> const std::vector> & { + const std::vector>& getConsts() const { return variableTable.getConsts(); } ///< 获取常量列表 void enterNewScope() { variableTable.enterNewScope(); } ///< 进入新的作用域 void leaveScope() { variableTable.leaveScope(); } ///< 离开作用域 - auto isInGlobalArea() const -> bool { return variableTable.isInGlobalScope(); } ///< 是否位于全局作用域 + bool isInGlobalArea() const { return variableTable.isInGlobalScope(); } ///< 是否位于全局作用域 }; /*! diff --git a/src/include/IRBuilder.h b/src/include/IRBuilder.h index 7335ea8..9189cba 100644 --- a/src/include/IRBuilder.h +++ b/src/include/IRBuilder.h @@ -39,41 +39,41 @@ class IRBuilder { : labelIndex(0), tmpIndex(0), block(block), position(position) {} public: - auto getLabelIndex() -> unsigned { + unsigned getLabelIndex() { labelIndex += 1; return labelIndex - 1; } ///< 获取基本块标签编号 - auto getTmpIndex() -> unsigned { + unsigned getTmpIndex() { tmpIndex += 1; return tmpIndex - 1; } ///< 获取临时变量编号 - auto getBasicBlock() const -> BasicBlock * { return block; } ///< 获取当前基本块 - auto getBreakBlock() const -> BasicBlock * { return breakBlocks.back(); } ///< 获取break目标块 - auto popBreakBlock() -> BasicBlock * { + BasicBlock * getBasicBlock() const { return block; } ///< 获取当前基本块 + BasicBlock * getBreakBlock() const { return breakBlocks.back(); } ///< 获取break目标块 + BasicBlock * popBreakBlock() { auto result = breakBlocks.back(); breakBlocks.pop_back(); return result; } ///< 弹出break目标块 - auto getContinueBlock() const -> BasicBlock * { return continueBlocks.back(); } ///< 获取continue目标块 - auto popContinueBlock() -> BasicBlock * { + BasicBlock * getContinueBlock() const { return continueBlocks.back(); } ///< 获取continue目标块 + BasicBlock * popContinueBlock() { auto result = continueBlocks.back(); continueBlocks.pop_back(); return result; } ///< 弹出continue目标块 - auto getTrueBlock() const -> BasicBlock * { return trueBlocks.back(); } ///< 获取true分支基本块 - auto getFalseBlock() const -> BasicBlock * { return falseBlocks.back(); } ///< 获取false分支基本块 - auto popTrueBlock() -> BasicBlock * { + BasicBlock * getTrueBlock() const { return trueBlocks.back(); } ///< 获取true分支基本块 + BasicBlock * getFalseBlock() const { return falseBlocks.back(); } ///< 获取false分支基本块 + BasicBlock * popTrueBlock() { auto result = trueBlocks.back(); trueBlocks.pop_back(); return result; } ///< 弹出true分支基本块 - auto popFalseBlock() -> BasicBlock * { + BasicBlock * popFalseBlock() { auto result = falseBlocks.back(); falseBlocks.pop_back(); return result; } ///< 弹出false分支基本块 - auto getPosition() const -> BasicBlock::iterator { return position; } ///< 获取当前基本块指令列表位置的迭代器 + BasicBlock::iterator getPosition() const { return position; } ///< 获取当前基本块指令列表位置的迭代器 void setPosition(BasicBlock *block, BasicBlock::iterator position) { this->block = block; this->position = position; @@ -87,13 +87,12 @@ class IRBuilder { void pushFalseBlock(BasicBlock *block) { falseBlocks.push_back(block); } ///< 压入false分支基本块 public: - auto insertInst(Instruction *inst) -> Instruction * { + Instruction * insertInst(Instruction *inst) { assert(inst); block->getInstructions().emplace(position, inst); return inst; } ///< 插入指令 - auto createUnaryInst(Instruction::Kind kind, Type *type, Value *operand, const std::string &name = "") - -> UnaryInst * { + UnaryInst * createUnaryInst(Instruction::Kind kind, Type *type, Value *operand, const std::string &name = "") { std::string newName; if (name.empty()) { std::stringstream ss; @@ -109,32 +108,31 @@ class IRBuilder { block->getInstructions().emplace(position, inst); return inst; } ///< 创建一元指令 - auto createNegInst(Value *operand, const std::string &name = "") -> UnaryInst * { + UnaryInst * createNegInst(Value *operand, const std::string &name = "") { return createUnaryInst(Instruction::kNeg, Type::getIntType(), operand, name); } ///< 创建取反指令 - auto createNotInst(Value *operand, const std::string &name = "") -> UnaryInst * { + UnaryInst * createNotInst(Value *operand, const std::string &name = "") { return createUnaryInst(Instruction::kNot, Type::getIntType(), operand, name); } ///< 创建取非指令 - auto createFtoIInst(Value *operand, const std::string &name = "") -> UnaryInst * { + UnaryInst * createFtoIInst(Value *operand, const std::string &name = "") { return createUnaryInst(Instruction::kFtoI, Type::getIntType(), operand, name); } ///< 创建浮点转整型指令 - auto createBitFtoIInst(Value *operand, const std::string &name = "") -> UnaryInst * { + UnaryInst * createBitFtoIInst(Value *operand, const std::string &name = "") { return createUnaryInst(Instruction::kBitFtoI, Type::getIntType(), operand, name); } ///< 创建按位浮点转整型指令 - auto createFNegInst(Value *operand, const std::string &name = "") -> UnaryInst * { + UnaryInst * createFNegInst(Value *operand, const std::string &name = "") { return createUnaryInst(Instruction::kFNeg, Type::getFloatType(), operand, name); } ///< 创建浮点取反指令 - auto createFNotInst(Value *operand, const std::string &name = "") -> UnaryInst * { + UnaryInst * createFNotInst(Value *operand, const std::string &name = "") { return createUnaryInst(Instruction::kFNot, Type::getIntType(), operand, name); } ///< 创建浮点取非指令 - auto createIToFInst(Value *operand, const std::string &name = "") -> UnaryInst * { + UnaryInst * createIToFInst(Value *operand, const std::string &name = "") { return createUnaryInst(Instruction::kItoF, Type::getFloatType(), operand, name); } ///< 创建整型转浮点指令 - auto createBitItoFInst(Value *operand, const std::string &name = "") -> UnaryInst * { + UnaryInst * createBitItoFInst(Value *operand, const std::string &name = "") { return createUnaryInst(Instruction::kBitItoF, Type::getFloatType(), operand, name); } ///< 创建按位整型转浮点指令 - auto createBinaryInst(Instruction::Kind kind, Type *type, Value *lhs, Value *rhs, const std::string &name = "") - -> BinaryInst * { + BinaryInst * createBinaryInst(Instruction::Kind kind, Type *type, Value *lhs, Value *rhs, const std::string &name = "") { std::string newName; if (name.empty()) { std::stringstream ss; @@ -150,76 +148,76 @@ class IRBuilder { block->getInstructions().emplace(position, inst); return inst; } ///< 创建二元指令 - auto createAddInst(Value *lhs, Value *rhs, const std::string &name = "") -> BinaryInst * { + BinaryInst * createAddInst(Value *lhs, Value *rhs, const std::string &name = "") { return createBinaryInst(Instruction::kAdd, Type::getIntType(), lhs, rhs, name); } ///< 创建加法指令 - auto createSubInst(Value *lhs, Value *rhs, const std::string &name = "") -> BinaryInst * { + BinaryInst * createSubInst(Value *lhs, Value *rhs, const std::string &name = "") { return createBinaryInst(Instruction::kSub, Type::getIntType(), lhs, rhs, name); } ///< 创建减法指令 - auto createMulInst(Value *lhs, Value *rhs, const std::string &name = "") -> BinaryInst * { + BinaryInst * createMulInst(Value *lhs, Value *rhs, const std::string &name = "") { return createBinaryInst(Instruction::kMul, Type::getIntType(), lhs, rhs, name); } ///< 创建乘法指令 - auto createDivInst(Value *lhs, Value *rhs, const std::string &name = "") -> BinaryInst * { + BinaryInst * createDivInst(Value *lhs, Value *rhs, const std::string &name = "") { return createBinaryInst(Instruction::kDiv, Type::getIntType(), lhs, rhs, name); } ///< 创建除法指令 - auto createRemInst(Value *lhs, Value *rhs, const std::string &name = "") -> BinaryInst * { + BinaryInst * createRemInst(Value *lhs, Value *rhs, const std::string &name = "") { return createBinaryInst(Instruction::kRem, Type::getIntType(), lhs, rhs, name); } ///< 创建取余指令 - auto createICmpEQInst(Value *lhs, Value *rhs, const std::string &name = "") -> BinaryInst * { + BinaryInst * createICmpEQInst(Value *lhs, Value *rhs, const std::string &name = "") { return createBinaryInst(Instruction::kICmpEQ, Type::getIntType(), lhs, rhs, name); } ///< 创建相等设置指令 - auto createICmpNEInst(Value *lhs, Value *rhs, const std::string &name = "") -> BinaryInst * { + BinaryInst * createICmpNEInst(Value *lhs, Value *rhs, const std::string &name = "") { return createBinaryInst(Instruction::kICmpNE, Type::getIntType(), lhs, rhs, name); } ///< 创建不相等设置指令 - auto createICmpLTInst(Value *lhs, Value *rhs, const std::string &name = "") -> BinaryInst * { + BinaryInst * createICmpLTInst(Value *lhs, Value *rhs, const std::string &name = "") { return createBinaryInst(Instruction::kICmpLT, Type::getIntType(), lhs, rhs, name); } ///< 创建小于设置指令 - auto createICmpLEInst(Value *lhs, Value *rhs, const std::string &name = "") -> BinaryInst * { + BinaryInst * createICmpLEInst(Value *lhs, Value *rhs, const std::string &name = "") { return createBinaryInst(Instruction::kICmpLE, Type::getIntType(), lhs, rhs, name); } ///< 创建小于等于设置指令 - auto createICmpGTInst(Value *lhs, Value *rhs, const std::string &name = "") -> BinaryInst * { + BinaryInst * createICmpGTInst(Value *lhs, Value *rhs, const std::string &name = "") { return createBinaryInst(Instruction::kICmpGT, Type::getIntType(), lhs, rhs, name); } ///< 创建大于设置指令 - auto createICmpGEInst(Value *lhs, Value *rhs, const std::string &name = "") -> BinaryInst * { + BinaryInst * createICmpGEInst(Value *lhs, Value *rhs, const std::string &name = "") { return createBinaryInst(Instruction::kICmpGE, Type::getIntType(), lhs, rhs, name); } ///< 创建大于等于设置指令 - auto createFAddInst(Value *lhs, Value *rhs, const std::string &name = "") -> BinaryInst * { + BinaryInst * createFAddInst(Value *lhs, Value *rhs, const std::string &name = "") { return createBinaryInst(Instruction::kFAdd, Type::getFloatType(), lhs, rhs, name); } ///< 创建浮点加法指令 - auto createFSubInst(Value *lhs, Value *rhs, const std::string &name = "") -> BinaryInst * { + BinaryInst * createFSubInst(Value *lhs, Value *rhs, const std::string &name = "") { return createBinaryInst(Instruction::kFSub, Type::getFloatType(), lhs, rhs, name); } ///< 创建浮点减法指令 - auto createFMulInst(Value *lhs, Value *rhs, const std::string &name = "") -> BinaryInst * { + BinaryInst * createFMulInst(Value *lhs, Value *rhs, const std::string &name = "") { return createBinaryInst(Instruction::kFMul, Type::getFloatType(), lhs, rhs, name); } ///< 创建浮点乘法指令 - auto createFDivInst(Value *lhs, Value *rhs, const std::string &name = "") -> BinaryInst * { + BinaryInst * createFDivInst(Value *lhs, Value *rhs, const std::string &name = "") { return createBinaryInst(Instruction::kFDiv, Type::getFloatType(), lhs, rhs, name); } ///< 创建浮点除法指令 - auto createFCmpEQInst(Value *lhs, Value *rhs, const std::string &name = "") -> BinaryInst * { + BinaryInst * createFCmpEQInst(Value *lhs, Value *rhs, const std::string &name = "") { return createBinaryInst(Instruction::kFCmpEQ, Type::getIntType(), lhs, rhs, name); } ///< 创建浮点相等设置指令 - auto createFCmpNEInst(Value *lhs, Value *rhs, const std::string &name = "") -> BinaryInst * { + BinaryInst * createFCmpNEInst(Value *lhs, Value *rhs, const std::string &name = "") { return createBinaryInst(Instruction::kFCmpNE, Type::getIntType(), lhs, rhs, name); } ///< 创建浮点不相等设置指令 - auto createFCmpLTInst(Value *lhs, Value *rhs, const std::string &name = "") -> BinaryInst * { + BinaryInst * createFCmpLTInst(Value *lhs, Value *rhs, const std::string &name = "") { return createBinaryInst(Instruction::kFCmpLT, Type::getIntType(), lhs, rhs, name); } ///< 创建浮点小于设置指令 - auto createFCmpLEInst(Value *lhs, Value *rhs, const std::string &name = "") -> BinaryInst * { + BinaryInst * createFCmpLEInst(Value *lhs, Value *rhs, const std::string &name = "") { return createBinaryInst(Instruction::kFCmpLE, Type::getIntType(), lhs, rhs, name); } ///< 创建浮点小于等于设置指令 - auto createFCmpGTInst(Value *lhs, Value *rhs, const std::string &name = "") -> BinaryInst * { + BinaryInst * createFCmpGTInst(Value *lhs, Value *rhs, const std::string &name = "") { return createBinaryInst(Instruction::kFCmpGT, Type::getIntType(), lhs, rhs, name); } ///< 创建浮点大于设置指令 - auto createFCmpGEInst(Value *lhs, Value *rhs, const std::string &name = "") -> BinaryInst * { + BinaryInst * createFCmpGEInst(Value *lhs, Value *rhs, const std::string &name = "") { return createBinaryInst(Instruction::kFCmpGE, Type::getIntType(), lhs, rhs, name); } ///< 创建浮点相大于等于设置指令 - auto createAndInst(Value *lhs, Value *rhs, const std::string &name = "") -> BinaryInst * { + BinaryInst * createAndInst(Value *lhs, Value *rhs, const std::string &name = "") { return createBinaryInst(Instruction::kAnd, Type::getIntType(), lhs, rhs, name); } ///< 创建按位且指令 - auto createOrInst(Value *lhs, Value *rhs, const std::string &name = "") -> BinaryInst * { + BinaryInst * createOrInst(Value *lhs, Value *rhs, const std::string &name = "") { return createBinaryInst(Instruction::kOr, Type::getIntType(), lhs, rhs, name); } ///< 创建按位或指令 - auto createCallInst(Function *callee, const std::vector &args, const std::string &name = "") -> CallInst * { + CallInst * createCallInst(Function *callee, const std::vector &args, const std::string &name = "") { std::string newName; if (name.empty() && callee->getReturnType() != Type::getVoidType()) { std::stringstream ss; @@ -235,40 +233,38 @@ class IRBuilder { block->getInstructions().emplace(position, inst); return inst; } ///< 创建Call指令 - auto createReturnInst(Value *value = nullptr, const std::string &name = "") -> ReturnInst * { + ReturnInst * createReturnInst(Value *value = nullptr, const std::string &name = "") { auto inst = new ReturnInst(value, block, name); assert(inst); block->getInstructions().emplace(position, inst); return inst; } ///< 创建return指令 - auto createUncondBrInst(BasicBlock *thenBlock, const std::vector &args) -> UncondBrInst * { + UncondBrInst * createUncondBrInst(BasicBlock *thenBlock, const std::vector &args) { auto inst = new UncondBrInst(thenBlock, args, block); assert(inst); block->getInstructions().emplace(position, inst); return inst; } ///< 创建无条件指令 - auto createCondBrInst(Value *condition, BasicBlock *thenBlock, BasicBlock *elseBlock, - const std::vector &thenArgs, const std::vector &elseArgs) -> CondBrInst * { + CondBrInst * createCondBrInst(Value *condition, BasicBlock *thenBlock, BasicBlock *elseBlock, + const std::vector &thenArgs, const std::vector &elseArgs) { auto inst = new CondBrInst(condition, thenBlock, elseBlock, thenArgs, elseArgs, block); assert(inst); block->getInstructions().emplace(position, inst); return inst; } ///< 创建条件跳转指令 - auto createAllocaInst(Type *type, const std::vector &dims = {}, const std::string &name = "") - -> AllocaInst * { + AllocaInst * createAllocaInst(Type *type, const std::vector &dims = {}, const std::string &name = "") { auto inst = new AllocaInst(type, dims, block, name); assert(inst); block->getInstructions().emplace(position, inst); return inst; } ///< 创建分配指令 - auto createAllocaInstWithoutInsert(Type *type, const std::vector &dims = {}, BasicBlock *parent = nullptr, - const std::string &name = "") -> AllocaInst * { + AllocaInst * createAllocaInstWithoutInsert(Type *type, const std::vector &dims = {}, BasicBlock *parent = nullptr, + const std::string &name = "") { auto inst = new AllocaInst(type, dims, parent, name); assert(inst); return inst; } ///< 创建不插入指令列表的分配指令 - auto createLoadInst(Value *pointer, const std::vector &indices = {}, const std::string &name = "") - -> LoadInst * { + LoadInst * createLoadInst(Value *pointer, const std::vector &indices = {}, const std::string &name = "") { std::string newName; if (name.empty()) { std::stringstream ss; @@ -284,8 +280,7 @@ class IRBuilder { block->getInstructions().emplace(position, inst); return inst; } ///< 创建load指令 - auto createLaInst(Value *pointer, const std::vector &indices = {}, const std::string &name = "") - -> LaInst * { + LaInst * createLaInst(Value *pointer, const std::vector &indices = {}, const std::string &name = "") { std::string newName; if (name.empty()) { std::stringstream ss; @@ -301,8 +296,7 @@ class IRBuilder { block->getInstructions().emplace(position, inst); return inst; } ///< 创建la指令 - auto createGetSubArray(LVal *fatherArray, const std::vector &indices, const std::string &name = "") - -> GetSubArrayInst * { + GetSubArrayInst * createGetSubArray(LVal *fatherArray, const std::vector &indices, const std::string &name = "") { assert(fatherArray->getLValNumDims() > indices.size()); std::vector subDims; auto dims = fatherArray->getLValDims(); @@ -326,21 +320,20 @@ class IRBuilder { block->getInstructions().emplace(position, inst); return inst; } ///< 创建获取部分数组指令 - auto createMemsetInst(Value *pointer, Value *begin, Value *size, Value *value, const std::string &name = "") - -> MemsetInst * { + MemsetInst * createMemsetInst(Value *pointer, Value *begin, Value *size, Value *value, const std::string &name = "") { auto inst = new MemsetInst(pointer, begin, size, value, block, name); assert(inst); block->getInstructions().emplace(position, inst); return inst; } ///< 创建memset指令 - auto createStoreInst(Value *value, Value *pointer, const std::vector &indices = {}, - const std::string &name = "") -> StoreInst * { + StoreInst * createStoreInst(Value *value, Value *pointer, const std::vector &indices = {}, + const std::string &name = "") { auto inst = new StoreInst(value, pointer, indices, block, name); assert(inst); block->getInstructions().emplace(position, inst); return inst; } ///< 创建store指令 - auto createPhiInst(Type *type, Value *lhs, BasicBlock *parent, const std::string &name = "") -> PhiInst * { + PhiInst * createPhiInst(Type *type, Value *lhs, BasicBlock *parent, const std::string &name = "") { auto predNum = parent->getNumPredecessors(); std::vector rhs; for (size_t i = 0; i < predNum; i++) { From 0a04c816cf409db2abb154b9a7b950760387f788 Mon Sep 17 00:00:00 2001 From: rain2133 <1370973498@qq.com> Date: Sat, 21 Jun 2025 18:06:29 +0800 Subject: [PATCH 10/13] =?UTF-8?q?=E6=9B=B4=E6=96=B0IR=EF=BC=8C.g4=E4=BF=AE?= =?UTF-8?q?=E6=94=B9?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/SysY.g4 | 8 +- src/SysYIRGenerator.cpp | 457 ++++++++++++++++++++++++++++++++-- src/include/IR.h | 1 + src/include/SysYIRGenerator.h | 1 + 4 files changed, 446 insertions(+), 21 deletions(-) diff --git a/src/SysY.g4 b/src/SysY.g4 index a9e4208..d614ec4 100644 --- a/src/SysY.g4 +++ b/src/SysY.g4 @@ -153,10 +153,10 @@ cond: lOrExp; lValue: Ident (LBRACK exp RBRACK)*; // 为了方便测试 primaryExp 可以是一个string -primaryExp: LPAREN exp RPAREN #parenExp - | lValue #lVal - | number #num - | string #str; +primaryExp: LPAREN exp RPAREN + | lValue + | number + | string; number: ILITERAL | FLITERAL; unaryExp: primaryExp #primExp diff --git a/src/SysYIRGenerator.cpp b/src/SysYIRGenerator.cpp index d23bf28..6abdeff 100644 --- a/src/SysYIRGenerator.cpp +++ b/src/SysYIRGenerator.cpp @@ -382,7 +382,8 @@ std::any SysYIRGenerator::visitIfStmt(SysYParser::IfStmtContext *ctx) { } std::any SysYIRGenerator::visitWhileStmt(SysYParser::WhileStmtContext *ctx) { - + // while structure: + // curblock -> headBlock -> bodyBlock -> exitBlock BasicBlock* curBlock = builder.getBasicBlock(); Function* function = builder.getBasicBlock()->getParent(); @@ -390,18 +391,16 @@ std::any SysYIRGenerator::visitWhileStmt(SysYParser::WhileStmtContext *ctx) { labelstring << "head.L" << builder.getLabelIndex(); BasicBlock *headBlock = function->addBasicBlock(labelstring.str()); labelstring.str(""); - BasicBlock::conectBlocks(curBlock, headBlock); - builder.setPosition(headBlock, headBlock->end()) + builder.setPosition(headBlock, headBlock->end()); - function->addBasicBlock(condBlock); - builder.setPosition(condBlock, condBlock->end()); + BasicBlock* bodyBlock = new BasicBlock(function); + BasicBlock* exitBlock = new BasicBlock(function); builder.pushTrueBlock(bodyBlock); builder.pushFalseBlock(exitBlock); - + // 访问条件表达式 visitCond(ctx->cond()); - builder.popTrueBlock(); builder.popFalseBlock(); @@ -411,27 +410,451 @@ std::any SysYIRGenerator::visitWhileStmt(SysYParser::WhileStmtContext *ctx) { function->addBasicBlock(bodyBlock); builder.setPosition(bodyBlock, bodyBlock->end()); - module->enterNewScope(); - for (auto item : ctx->blockStmt()->blockItem()) { - visitBlockItem(item); - } - module->leaveScope(); - - builder.createUncondBrInst(condBlock, {}); + builder.pushBreakBlock(exitBlock); + builder.pushContinueBlock(headBlock); - BasicBlock::conectBlocks(builder.getBasicBlock(), condBlock); + auto block = dynamic_cast(ctx->stmt()); + + if( block != nullptr) { + visitBlockStmt(block); + } else { + module->enterNewScope(); + ctx->stmt()->accept(this); + module->leaveScope(); + } + + builder.createUncondBrInst(headBlock, {}); + BasicBlock::conectBlocks(builder.getBasicBlock(), exitBlock); + builder.popBreakBlock(); + builder.popContinueBlock(); labelstring << "exit.L" << builder.getLabelIndex(); exitBlock->setName(labelstring.str()); labelstring.str(""); - function->addBasicBlock(exitBlock); - builder.setPosition(exitBlock, exitBlock->end()); return std::any(); } +std::any SysYIRGenerator::visitBreakStmt(SysYParser::BreakStmtContext *ctx) { + BasicBlock* breakBlock = builder.getBreakBlock(); + builder.pushBreakBlock(breakBlock); + BasicBlock::conectBlocks(builder.getBasicBlock(), breakBlock); + return std::any(); +} + +std::any SysYIRGenerator::visitContinueStmt(SysYParser::ContinueStmtContext *ctx) { + BasicBlock* continueBlock = builder.getContinueBlock(); + builder.createUncondBrInst(continueBlock, {}); + return std::any(); +} + +std::any SysYIRGenerator::visitReturnStmt(SysYParser::ReturnStmtContext *ctx) { + Value* returnValue = nullptr; + if (ctx->exp() != nullptr) { + returnValue = std::any_cast(visitExp(ctx->exp())); + } + + Type* funcType = builder.getBasicBlock()->getParent()->getType(); + if (funcType!= returnValue->getType() && returnValue != nullptr) { + ConstantValue * constValue = dynamic_cast(returnValue); + if (constValue != nullptr) { + if (funcType == Type::getFloatType()) { + returnValue = ConstantValue::get(static_cast(constValue->getInt())); + } else { + returnValue = ConstantValue::get(static_cast(constValue->getFloat())); + } + } else { + if (funcType == Type::getFloatType()) { + returnValue = builder.createIToFInst(returnValue); + } else { + returnValue = builder.createFtoIInst(returnValue); + } + } + } + builder.createRetInst(returnValue); + return std::any(); +} + + +std::any SysYIRGenerator::visitLVal(SysYParser::LValContext *ctx) { + std::string name = ctx->Ident()->getText(); + User* variable = module->getVariable(name); + + Value* value = nullptr; + std::vector dims; + for (const auto &exp : ctx->exp()) { + dims.push_back(std::any_cast(visitExp(exp))); + } + + if (variable == nullptr) { + throw std::runtime_error("Variable " + name + " not found."); + } + + bool indicesConstant = true; + for (const auto &index : indices) { + if (dynamic_cast(index) == nullptr) { + indicesConstant = false; + break; + } + } + + ConstantVariable* constVar = dynamic_cast(variable); + GlobalValue* globalVar = dynamic_cast(variable); + AllocaInst* localVar = dynamic_cast(variable); + if (constVar != nullptr && indicesConstant) { + // 如果是常量变量,且索引是常量,则直接获取子数组 + value = constVar->getByIndices(indices); + } else if (module->isInGlobalArea() && (globalVar != nullptr)) { + assert(indicesConstant); + value = globalVar->getByIndices(indices); + } else { + if ((globalVar != nullptr && globalVar->getNumDims() > indices.size()) || + (localVar != nullptr && localVar->getNumDims() > indices.size()) || + (constVar != nullptr && constVar->getNumDims() > indices.size())) { + // value = builder.createLaInst(variable, indices); + // 如果变量是全局变量或局部变量,且索引数量小于维度数量,则创建createGetSubArray获取子数组 + auto getArrayInst = + builder.createGetSubArray(dynamic_cast(variable), indices); + value = getArrayInst->getChildArray(); + } else { + value = builder.createLoadInst(variable, indices); + } + } + + return value; +} + +std::any SysYIRGenerator::visitPrimaryExp(SysYParser::PrimaryExpContext *ctx) { + if (ctx->exp() != nullptr) + return visitExp(ctx->exp()); + if (ctx->lVal() != nullptr) + return visitLVal(ctx->lVal()); + if (ctx->number() != nullptr) + return visitNumber(ctx->number()); + // if (ctx->string() != nullptr) { + // std::string str = ctx->string()->getText(); + // str = str.substr(1, str.size() - 2); // 去掉双引号 + // return ConstantValue::get(str); + // } + return visitNumber(ctx->number()); +} + +std::any SysYIRGenerator::visitNumber(SysYParser::NumberContext *ctx) { + if (ctx->ILITERAL() != nullptr) { + int value = std::stol(ctx->ILITERAL()->getText(), nullptr, 0); + return static_cast(ConstantValue::get(Type::getIntType(), value)); + } else if (ctx->FLITERAL() != nullptr) { + float value = std::stof(ctx->FLITERAL()->getText()); + return static_cast(ConstantValue::get(Type::getFloatType(), value)); + } + throw std::runtime_error("Unknown number type."); + return std::any(); // 不会到达这里 +} + +std::any SysYIRGenerator::visitCall(SysYParser::CallContext *ctx) { + std::string funcName = ctx->Ident()->getText(); + Function *function = module->getFunction(funcName); + if (function == nullptr) { + function = module->getExternalFunction(name); + if (function == nullptr) { + std::cout << "The function " << name << " no defined." << std::endl; + assert(function); + } + } + + std::vector args = {}; + if (name == "starttime" || name == "stoptime") { + // 如果是starttime或stoptime函数 + // TODO: 这里需要处理starttime和stoptime函数的参数 + // args.emplace_back() + } else { + if (ctx->funcRParams() != nullptr) { + args = std::any_cast>(visitFuncRParams(ctx->funcRParams())); + } + + auto params = function->getEntryBlock()->getArguments(); + for (size_t i = 0; i < args.size(); i++) { + // 参数类型转换 + if (params[i]->getType() != args[i]->getType() && + (params[i]->getNumDims() != 0 || + params[i]->getType()->as()->getBaseType() != args[i]->getType())) { + ConstantValue * constValue = dynamic_cast(args[i]); + if (constValue != nullptr) { + if (params[i]->getType() == Type::getPointerType(Type::getFloatType())) { + args[i] = ConstantValue::get(static_cast(constValue->getInt())); + } else { + args[i] = ConstantValue::get(static_cast(constValue->getFloat())); + } + } else { + if (params[i]->getType() == Type::getPointerType(Type::getFloatType())) { + args[i] = builder.createIToFInst(args[i]); + } else { + args[i] = builder.createFtoIInst(args[i]); + } + } + } + } + } + + return static_cast(builder.createCallInst(function, args)); +} + +std::any SysYIRGenerator::visitUnExp(SysYParser::UnExpContext *ctx) { + Value* value = std::any_cast(visitUnaryExp(ctx->unaryExp())); + Value* result = value; + if (ctx->unaryOp()->SUB() != nullptr) { + ConstantValue * constValue = dynamic_cast(value); + if (constValue != nullptr) { + if (constValue->isFloat()) { + result = ConstantValue::get(-constValue->getFloat()); + } else { + result = ConstantValue::get(-constValue->getInt()); + } + } else if (value != nullptr) { + if (value->getType() == Type::getIntType()) { + result = builder.createNegInst(value); + } else { + result = builder.createFNegInst(value); + } + } else { + std::cout << "UnExp: value is nullptr." << std::endl; + assert(false); + } + } else if (ctx->unaryOp()->NOT() != nullptr) { + auto constValue = dynamic_cast(value); + if (constValue != nullptr) { + if (constValue->isFloat()) { + result = + ConstantValue::get(1 - (constValue->getFloat() != 0.0F ? 1 : 0)); + } else { + result = ConstantValue::get(1 - (constValue->getInt() != 0 ? 1 : 0)); + } + } else if (value != nullptr) { + if (value->getType() == Type::getIntType()) { + result = builder.createNotInst(value); + } else { + result = builder.createFNotInst(value); + } + } else { + std::cout << "UnExp: value is nullptr." << std::endl; + assert(false); + } + } + return result; +} + +std::any SysYIRGenerator::visitFuncRParams(SysYParser::FuncRParamsContext *ctx) { + std::vector params; + for (const auto &exp : ctx->exp()) + params.push_back(std::any_cast(visitExp(exp))); + return params; +} + + +std::any SysYIRGenerator::visitMulExp(SysYParser::MulExpContext *ctx) { + auto result = std::any_cast(visitUnaryExp(ctx->unaryExp(0))); + + for (size_t i = 1; i < ctx->unaryExp().size(); i++) { + auto op = ctx->mulOp(i - 1); + Value* operand = std::any_cast(visitUnaryExp(ctx->unaryExp(i))); + + Type* resultType = result->getType(); + Type* operandType = operand->getType(); + + if (resultType == Type::getFloatType() || operandType == Type::getFloatType()) { + // 如果有一个操作数是浮点数,则将两个操作数都转换为浮点数 + if (operandType != Type::getFloatType()) { + ConstantValue * constValue = dynamic_cast(operand); + if (constValue != nullptr) + operand = ConstantValue::get(static_cast(constValue->getInt())); + else + operand = builder.createIToFInst(operand); + } else if (resultType != Type::getFloatType()) { + ConstantValue* constResult = dynamic_cast(result); + if (constResult != nullptr) + result = ConstantValue::get(static_cast(constResult->getInt())); + else + result = builder.createIToFInst(result); + } + + ConstantValue* constResult = dynamic_cast(result); + ConstantValue* constOperand = dynamic_cast(operand); + if (op->MUL() != nullptr) { + if ((constOperand != nullptr) && (constResult != nullptr)) { + result = ConstantValue::get(constResult->getFloat() * + constOperand->getFloat()); + } else { + result = builder.createFMulInst(result, operand); + } + } else if (op->DIV() != nullptr) { + if ((constOperand != nullptr) && (constResult != nullptr)) { + result = ConstantValue::get(constResult->getFloat() / + constOperand->getFloat()); + } else { + result = builder.createFDivInst(result, operand); + } + } else { + // float类型的取模操作不允许 + std::cout << "MulExp: float type mod operation is not allowed." << std::endl; + assert(false); + } + } else { + ConstantValue * constResult = dynamic_cast(result); + ConstantValue * constOperand = dynamic_cast(operand); + if (op->MUL() != nullptr) { + if ((constOperand != nullptr) && (constResult != nullptr)) + result = ConstantValue::get(constResult->getInt() * constOperand->getInt()); + else + result = builder.createMulInst(result, operand); + } else if (op->DIV() != nullptr) { + if ((constOperand != nullptr) && (constResult != nullptr)) + result = ConstantValue::get(constResult->getInt() / constOperand->getInt()); + else + result = builder.createDivInst(result, operand); + } else { + if ((constOperand != nullptr) && (constResult != nullptr)) + result = ConstantValue::get(constResult->getInt() % constOperand->getInt()); + else + result = builder.createRemInst(result, operand); + } + } + } + + return result; +} + + +std::any SysYIRGenerator::visitAddExp(SysYParser::AddExpContext *ctx) { + Value* result = std::any_cast(visitMulExp(ctx->mulExp(0))); + + for (size_t i = 1; i < ctx->mulExp().size(); i++) { + auto op = ctx->addOp(i - 1); + + Value* operand = std::any_cast(visitMulExp(ctx->mulExp(i))); + Type* resultType = result->getType(); + Type* operandType = operand->getType(); + + if (resultType == Type::getFloatType() || operandType == Type::getFloatType()) { + // 类型转换 + if (operandType != Type::getFloatType()) { + Value* constOperand = dynamic_cast(operand); + if (constOperand != nullptr) + operand = ConstantValue::get(static_cast(constOperand->getInt())); + else + operand = builder.createIToFInst(operand); + } else if (resultType != Type::getFloatType()) { + Value* constResult = dynamic_cast(result); + if (constResult != nullptr) + result = ConstantValue::get(static_cast(constResult->getInt())); + else + result = builder.createIToFInst(result); + } + + Value* constResult = dynamic_cast(result); + Value* constOperand = dynamic_cast(operand); + if (op->ADD() != nullptr) { + if ((constResult != nullptr) && (constOperand != nullptr)) + result = ConstantValue::get(constResult->getFloat() + constOperand->getFloat()); + else + result = builder.createFAddInst(result, operand); + } else { + if ((constResult != nullptr) && (constOperand != nullptr)) + result = ConstantValue::get(constResult->getFloat() - constOperand->getFloat()); + else + result = builder.createFSubInst(result, operand); + } + } else { + Value* constResult = dynamic_cast(result); + Value* constOperand = dynamic_cast(operand); + if (op->ADD() != nullptr) { + if ((constResult != nullptr) && (constOperand != nullptr)) + result = ConstantValue::get(constResult->getInt() + constOperand->getInt()); + else + result = builder.createAddInst(result, operand); + } else { + if ((constResult != nullptr) && (constOperand != nullptr)) + result = ConstantValue::get(constResult->getInt() - constOperand->getInt()); + else + result = builder.createSubInst(result, operand); + } + } + } + + return result; +} + +std:any SysYIRGenerator::visitRelExp(SysYParser::RelExpContext *ctx) { + Value* result = std::any_cast(visitAddExp(ctx->addExp(0))); + + for (size_t i = 1; i < ctx->addExp().size(); i++) { + auto op = ctx->relOp(i - 1); + Value* operand = std::any_cast(visitAddExp(ctx->addExp(i))); + + Type* resultType = result->getType(); + Type* operandType = operand->getType(); + + ConstantValue* constResult = dynamic_cast(result); + ConstantValue* constOperand = dynamic_cast(operand); + + // 常量比较 + if ((constResult != nullptr) && (constOperand != nullptr)) { + auto operand1 = constResult->isFloat() ? constResult->getFloat() + : constResult->getInt(); + auto operand2 = constOperand->isFloat() ? constOperand->getFloat() + : constOperand->getInt(); + + if (op->LT() != nullptr) result = ConstantValue::get(operand1 < operand2 ? 1 : 0); + else if (op->GT() != nullptr) result = ConstantValue::get(operand1 > operand2 ? 1 : 0); + else if (op->LE() != nullptr) result = ConstantValue::get(operand1 <= operand2 ? 1 : 0); + else if (op->GE() != nullptr) result = ConstantValue::get(operand1 >= operand2 ? 1 : 0); + else assert(false); + + } else { + Type* resultType = result->getType(); + Type* operandType = operand->getType(); + Type* floatType = Type::getFloatType(); + + // 浮点数处理 + if (resultType == floatType || operandType == floatType) { + if (resultType != floatType) { + if (constResult != nullptr) + result = ConstantValue::get(static_cast(constResult->getInt())); + else + result = builder.createIToFInst(result); + + } + if (operandType != floatType) { + if (constOperand != nullptr) + operand = ConstantValue::get(static_cast(constOperand->getInt())); + else + operand = builder.createIToFInst(operand); + + } + + if (op->LT() != nullptr) result = builder.createFCmpLTInst(result, operand); + else if (op->GT() != nullptr) result = builder.createFCmpGTInst(result, operand); + else if (op->LE() != nullptr) result = builder.createFCmpLEInst(result, operand); + else if (op->GE() != nullptr) result = builder.createFCmpGEInst(result, operand); + else assert(false); + + } else { + // 整数处理 + if (op->LT() != nullptr) result = builder.createICmpLTInst(result, operand); + else if (op->GT() != nullptr) result = builder.createICmpGTInst(result, operand); + else if (op->LE() != nullptr) result = builder.createICmpLEInst(result, operand); + else if (op->GE() != nullptr) result = builder.createICmpGEInst(result, operand); + else assert(false); + + } + } + } + + return result; +} + + void Utils::tree2Array(Type *type, ArrayValueTree *root, const std::vector &dims, unsigned numDims, ValueCounter &result, IRBuilder *builder) { diff --git a/src/include/IR.h b/src/include/IR.h index 1de2e23..3182a9a 100644 --- a/src/include/IR.h +++ b/src/include/IR.h @@ -1575,6 +1575,7 @@ class ConstantVariable : public User, public LVal { Value* getByIndex(unsigned index) const { return initValues.getValue(index); } ///< 通过一维位置index获取值 Value* getByIndices(const std::vector &indices) const { int index = 0; + // 计算偏移量 for (size_t i = 0; i < indices.size(); i++) { index = dynamic_cast(getDim(i))->getInt() * index + dynamic_cast(indices[i])->getInt(); diff --git a/src/include/SysYIRGenerator.h b/src/include/SysYIRGenerator.h index a5f5a91..e203016 100644 --- a/src/include/SysYIRGenerator.h +++ b/src/include/SysYIRGenerator.h @@ -120,6 +120,7 @@ public: // std::any visitUnaryOp(SysYParser::UnaryOpContext *ctx) override; std::any visitUnExp(SysYParser::UnExpContext *ctx) override; + std::any visitFuncRParams(SysYParser::FuncRParamsContext *ctx) override; std::any visitMulExp(SysYParser::MulExpContext *ctx) override; std::any visitAddExp(SysYParser::AddExpContext *ctx) override; From 73b382773a8207a2b168caa65c8eeff494768085 Mon Sep 17 00:00:00 2001 From: rain2133 <1370973498@qq.com> Date: Sat, 21 Jun 2025 18:07:32 +0800 Subject: [PATCH 11/13] =?UTF-8?q?=E6=9A=82=E5=AD=98=E6=97=A7=E7=AC=A6?= =?UTF-8?q?=E5=8F=B7=E8=A1=A8=E7=BB=93=E6=9E=84=E5=AE=9A=E4=B9=89=EF=BC=8C?= =?UTF-8?q?TODO.md=E4=B8=AD=E6=B7=BB=E5=8A=A0=E7=9B=B8=E5=85=B3=E8=AF=B4?= =?UTF-8?q?=E6=98=8E?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- TODO.md | 1 + olddef.h | 62 ++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 63 insertions(+) create mode 100644 olddef.h diff --git a/TODO.md b/TODO.md index c386b0d..aae2f4e 100644 --- a/TODO.md +++ b/TODO.md @@ -8,6 +8,7 @@ - `IRBuilder`:构建指令和基本块的工具类(你们正在实现的部分) ### 2. **中端必要优化(最小集合)** +常量传播 | 优化阶段 | 关键作用 | 是否必须 | |-------------------|----------------------------------|----------| | `Mem2Reg` | 消除冗余内存访问,转换为SSA形式 | ✅ 核心 | diff --git a/olddef.h b/olddef.h new file mode 100644 index 0000000..c7b9522 --- /dev/null +++ b/olddef.h @@ -0,0 +1,62 @@ + +class SymbolTable{ +private: + enum Kind + { + kModule, + kFunction, + kBlock, + }; + + std::forward_list>> Scopes; + +public: + struct ModuleScope { + SymbolTable& tables_ref; + ModuleScope(SymbolTable& tables) : tables_ref(tables) { + tables.enter(kModule); + } + ~ModuleScope() { tables_ref.exit(); } + }; + struct FunctionScope { + SymbolTable& tables_ref; + FunctionScope(SymbolTable& tables) : tables_ref(tables) { + tables.enter(kFunction); + } + ~FunctionScope() { tables_ref.exit(); } + }; + struct BlockScope { + SymbolTable& tables_ref; + BlockScope(SymbolTable& tables) : tables_ref(tables) { + tables.enter(kBlock); + } + ~BlockScope() { tables_ref.exit(); } + }; + + SymbolTable() = default; + + bool isModuleScope() const { return Scopes.front().first == kModule; } + bool isFunctionScope() const { return Scopes.front().first == kFunction; } + bool isBlockScope() const { return Scopes.front().first == kBlock; } + Value *lookup(const std::string &name) const { + for (auto &scope : Scopes) { + auto iter = scope.second.find(name); + if (iter != scope.second.end()) + return iter->second; + } + return nullptr; + } + auto insert(const std::string &name, Value *value) { + assert(not Scopes.empty()); + return Scopes.front().second.emplace(name, value); + } +private: + void enter(Kind kind) { + Scopes.emplace_front(); + Scopes.front().first = kind; + } + void exit() { + Scopes.pop_front(); + } + +}; \ No newline at end of file From 4828c18f96e36c59090ec1c411dd62addaf635ac Mon Sep 17 00:00:00 2001 From: rain2133 <1370973498@qq.com> Date: Sun, 22 Jun 2025 00:25:43 +0800 Subject: [PATCH 12/13] =?UTF-8?q?=E5=89=8D=E7=AB=AF=E5=9F=BA=E6=9C=AC?= =?UTF-8?q?=E6=9E=84=E5=BB=BA=E5=AE=8C=E6=AF=95=EF=BC=8Cbuild=E5=89=8D?= =?UTF-8?q?=E7=AB=AF=E9=83=A8=E5=88=86=E6=97=A0=E6=8A=A5=E9=94=99=EF=BC=8C?= =?UTF-8?q?argument=E7=B1=BB=E5=88=A0=E9=99=A4=E5=90=8E=E7=AB=AF=E6=8A=A5?= =?UTF-8?q?=E9=94=99=EF=BC=8CllvmIR=E8=BE=93=E5=87=BA=E5=BE=85=E5=AE=8C?= =?UTF-8?q?=E6=88=90?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/CMakeLists.txt | 1 - src/SysY.g4 | 7 +- src/SysYIRGenerator.cpp | 306 ++++++++++++++++++++++++---------- src/include/SysYIRGenerator.h | 13 +- src/sysyc.cpp | 3 +- 5 files changed, 231 insertions(+), 99 deletions(-) diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index 7e8572c..e8c78f6 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -13,7 +13,6 @@ target_link_libraries(SysYParser PUBLIC antlr4_shared) add_executable(sysyc sysyc.cpp - ASTPrinter.cpp IR.cpp SysYIRGenerator.cpp Backend.cpp diff --git a/src/SysY.g4 b/src/SysY.g4 index d614ec4..ad74a0c 100644 --- a/src/SysY.g4 +++ b/src/SysY.g4 @@ -159,9 +159,10 @@ primaryExp: LPAREN exp RPAREN | string; number: ILITERAL | FLITERAL; -unaryExp: primaryExp #primExp - | Ident LPAREN (funcRParams)? RPAREN #call - | unaryOp unaryExp #unExp; +call: Ident LPAREN (funcRParams)? RPAREN; +unaryExp: primaryExp + | call + | unaryOp unaryExp; unaryOp: ADD|SUB|NOT; funcRParams: exp (COMMA exp)*; diff --git a/src/SysYIRGenerator.cpp b/src/SysYIRGenerator.cpp index 6abdeff..7e4ed85 100644 --- a/src/SysYIRGenerator.cpp +++ b/src/SysYIRGenerator.cpp @@ -137,8 +137,8 @@ std::any SysYIRGenerator::visitVarDecl(SysYParser::VarDeclContext *ctx) { // values.getValues()可能是[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12] // 对于每个维度,使用memset将对应的值填充到数组中 // 这里的alloca是一个指向数组的指针 - std::vector & counterNumbers = values.getNumbers(); - std::vector & counterValues = values.getValues(); + const std::vector & counterNumbers = values.getNumbers(); + const std::vector & counterValues = values.getValues(); unsigned begin = 0; for (size_t i = 0; i < counterNumbers.size(); i++) { @@ -160,9 +160,9 @@ std::any SysYIRGenerator::visitBType(SysYParser::BTypeContext *ctx) { } std::any SysYIRGenerator::visitScalarInitValue(SysYParser::ScalarInitValueContext *ctx) { - AllocaInst* alloca = std::any_cast(visitExp(ctx->exp())); + Value* value = std::any_cast(visitExp(ctx->exp())); ArrayValueTree* result = new ArrayValueTree(); - result->setValue(alloca); + result->setValue(value); return result; } @@ -176,9 +176,9 @@ std::any SysYIRGenerator::visitArrayInitValue(SysYParser::ArrayInitValueContext } std::any SysYIRGenerator::visitConstScalarInitValue(SysYParser::ConstScalarInitValueContext *ctx) { - AllocaInst* alloca = std::any_cast(visitConstExp(ctx->constExp())); + Value* value = std::any_cast(visitConstExp(ctx->constExp())); ArrayValueTree* result = new ArrayValueTree(); - result->setValue(alloca); + result->setValue(value); return result; } @@ -214,7 +214,7 @@ std::any SysYIRGenerator::visitFuncDef(SysYParser::FuncDefContext *ctx){ paramTypes.push_back(std::any_cast(visitBType(param->bType()))); paramNames.push_back(param->Ident()->getText()); std::vector dims = {}; - if (param->exp() != nullptr) { + if (!param->LBRACK().empty()) { dims.push_back(ConstantValue::get(-1)); // 第一个维度不确定 for (const auto &exp : param->exp()) { dims.push_back(std::any_cast(visitExp(exp))); @@ -224,8 +224,8 @@ std::any SysYIRGenerator::visitFuncDef(SysYParser::FuncDefContext *ctx){ } } - Type *returnType = std::any_cast(visitFuncType(ctx->funcType())); - FunctionType* funcType = Type::getFunctionType(returnType, paramTypes); + Type* returnType = std::any_cast(visitFuncType(ctx->funcType())); + Type* funcType = Type::getFunctionType(returnType, paramTypes); Function* function = module->createFunction(name, funcType); BasicBlock* entry = function->getEntryBlock(); builder.setPosition(entry, entry->end()); @@ -243,7 +243,7 @@ std::any SysYIRGenerator::visitFuncDef(SysYParser::FuncDefContext *ctx){ module->leaveScope(); - return std::any; + return std::any(); } std::any SysYIRGenerator::visitBlockStmt(SysYParser::BlockStmtContext *ctx) { @@ -264,7 +264,7 @@ std::any SysYIRGenerator::visitAssignStmt(SysYParser::AssignStmtContext *ctx) { User* variable = module->getVariable(name); Value* value = std::any_cast(visitExp(ctx->exp())); - PointerType* variableType =dynamic_cast(variable->getType())->getBaseType(); + Type* variableType = dynamic_cast(variable->getType())->getBaseType(); // 左值右值类型不同处理 if (variableType != value->getType()) { @@ -473,12 +473,12 @@ std::any SysYIRGenerator::visitReturnStmt(SysYParser::ReturnStmtContext *ctx) { } } } - builder.createRetInst(returnValue); + builder.createReturnInst(returnValue); return std::any(); } -std::any SysYIRGenerator::visitLVal(SysYParser::LValContext *ctx) { +std::any SysYIRGenerator::visitLValue(SysYParser::LValueContext *ctx) { std::string name = ctx->Ident()->getText(); User* variable = module->getVariable(name); @@ -493,8 +493,8 @@ std::any SysYIRGenerator::visitLVal(SysYParser::LValContext *ctx) { } bool indicesConstant = true; - for (const auto &index : indices) { - if (dynamic_cast(index) == nullptr) { + for (const auto &dim : dims) { + if (dynamic_cast(dim) == nullptr) { indicesConstant = false; break; } @@ -505,21 +505,21 @@ std::any SysYIRGenerator::visitLVal(SysYParser::LValContext *ctx) { AllocaInst* localVar = dynamic_cast(variable); if (constVar != nullptr && indicesConstant) { // 如果是常量变量,且索引是常量,则直接获取子数组 - value = constVar->getByIndices(indices); + value = constVar->getByIndices(dims); } else if (module->isInGlobalArea() && (globalVar != nullptr)) { assert(indicesConstant); - value = globalVar->getByIndices(indices); + value = globalVar->getByIndices(dims); } else { - if ((globalVar != nullptr && globalVar->getNumDims() > indices.size()) || - (localVar != nullptr && localVar->getNumDims() > indices.size()) || - (constVar != nullptr && constVar->getNumDims() > indices.size())) { + if ((globalVar != nullptr && globalVar->getNumDims() > dims.size()) || + (localVar != nullptr && localVar->getNumDims() > dims.size()) || + (constVar != nullptr && constVar->getNumDims() > dims.size())) { // value = builder.createLaInst(variable, indices); // 如果变量是全局变量或局部变量,且索引数量小于维度数量,则创建createGetSubArray获取子数组 auto getArrayInst = - builder.createGetSubArray(dynamic_cast(variable), indices); + builder.createGetSubArray(dynamic_cast(variable), dims); value = getArrayInst->getChildArray(); } else { - value = builder.createLoadInst(variable, indices); + value = builder.createLoadInst(variable, dims); } } @@ -529,25 +529,23 @@ std::any SysYIRGenerator::visitLVal(SysYParser::LValContext *ctx) { std::any SysYIRGenerator::visitPrimaryExp(SysYParser::PrimaryExpContext *ctx) { if (ctx->exp() != nullptr) return visitExp(ctx->exp()); - if (ctx->lVal() != nullptr) - return visitLVal(ctx->lVal()); + if (ctx->lValue() != nullptr) + return visitLValue(ctx->lValue()); if (ctx->number() != nullptr) return visitNumber(ctx->number()); - // if (ctx->string() != nullptr) { - // std::string str = ctx->string()->getText(); - // str = str.substr(1, str.size() - 2); // 去掉双引号 - // return ConstantValue::get(str); - // } + if (ctx->string() != nullptr) { + cout << "String literal not supported in SysYIRGenerator." << endl; + } return visitNumber(ctx->number()); } std::any SysYIRGenerator::visitNumber(SysYParser::NumberContext *ctx) { if (ctx->ILITERAL() != nullptr) { int value = std::stol(ctx->ILITERAL()->getText(), nullptr, 0); - return static_cast(ConstantValue::get(Type::getIntType(), value)); + return static_cast(ConstantValue::get(value)); } else if (ctx->FLITERAL() != nullptr) { float value = std::stof(ctx->FLITERAL()->getText()); - return static_cast(ConstantValue::get(Type::getFloatType(), value)); + return static_cast(ConstantValue::get(value)); } throw std::runtime_error("Unknown number type."); return std::any(); // 不会到达这里 @@ -557,15 +555,15 @@ std::any SysYIRGenerator::visitCall(SysYParser::CallContext *ctx) { std::string funcName = ctx->Ident()->getText(); Function *function = module->getFunction(funcName); if (function == nullptr) { - function = module->getExternalFunction(name); + function = module->getExternalFunction(funcName); if (function == nullptr) { - std::cout << "The function " << name << " no defined." << std::endl; + std::cout << "The function " << funcName << " no defined." << std::endl; assert(function); } } std::vector args = {}; - if (name == "starttime" || name == "stoptime") { + if (funcName == "starttime" || funcName == "stoptime") { // 如果是starttime或stoptime函数 // TODO: 这里需要处理starttime和stoptime函数的参数 // args.emplace_back() @@ -601,7 +599,12 @@ std::any SysYIRGenerator::visitCall(SysYParser::CallContext *ctx) { return static_cast(builder.createCallInst(function, args)); } -std::any SysYIRGenerator::visitUnExp(SysYParser::UnExpContext *ctx) { +std::any SysYIRGenerator::visitUnaryExp(SysYParser::UnaryExpContext *ctx) { + if (ctx->primaryExp() != nullptr) + return visitPrimaryExp(ctx->primaryExp()); + if (ctx->call() != nullptr) + return visitCall(ctx->call()); + Value* value = std::any_cast(visitUnaryExp(ctx->unaryExp())); Value* result = value; if (ctx->unaryOp()->SUB() != nullptr) { @@ -654,24 +657,27 @@ std::any SysYIRGenerator::visitFuncRParams(SysYParser::FuncRParamsContext *ctx) std::any SysYIRGenerator::visitMulExp(SysYParser::MulExpContext *ctx) { - auto result = std::any_cast(visitUnaryExp(ctx->unaryExp(0))); - + Value * result = std::any_cast(visitUnaryExp(ctx->unaryExp(0))); + for (size_t i = 1; i < ctx->unaryExp().size(); i++) { - auto op = ctx->mulOp(i - 1); + auto opNode = dynamic_cast(ctx->children[2*i-1]); + int opType = opNode->getSymbol()->getType(); + Value* operand = std::any_cast(visitUnaryExp(ctx->unaryExp(i))); Type* resultType = result->getType(); Type* operandType = operand->getType(); + Type* floatType = Type::getFloatType(); - if (resultType == Type::getFloatType() || operandType == Type::getFloatType()) { + if (resultType == floatType || operandType == floatType) { // 如果有一个操作数是浮点数,则将两个操作数都转换为浮点数 - if (operandType != Type::getFloatType()) { + if (operandType != floatType) { ConstantValue * constValue = dynamic_cast(operand); if (constValue != nullptr) operand = ConstantValue::get(static_cast(constValue->getInt())); else operand = builder.createIToFInst(operand); - } else if (resultType != Type::getFloatType()) { + } else if (resultType != floatType) { ConstantValue* constResult = dynamic_cast(result); if (constResult != nullptr) result = ConstantValue::get(static_cast(constResult->getInt())); @@ -681,14 +687,14 @@ std::any SysYIRGenerator::visitMulExp(SysYParser::MulExpContext *ctx) { ConstantValue* constResult = dynamic_cast(result); ConstantValue* constOperand = dynamic_cast(operand); - if (op->MUL() != nullptr) { + if (opType == SysYParser::MUL) { if ((constOperand != nullptr) && (constResult != nullptr)) { result = ConstantValue::get(constResult->getFloat() * constOperand->getFloat()); } else { result = builder.createFMulInst(result, operand); } - } else if (op->DIV() != nullptr) { + } else if (opType == SysYParser::DIV) { if ((constOperand != nullptr) && (constResult != nullptr)) { result = ConstantValue::get(constResult->getFloat() / constOperand->getFloat()); @@ -703,12 +709,12 @@ std::any SysYIRGenerator::visitMulExp(SysYParser::MulExpContext *ctx) { } else { ConstantValue * constResult = dynamic_cast(result); ConstantValue * constOperand = dynamic_cast(operand); - if (op->MUL() != nullptr) { + if (opType == SysYParser::MUL) { if ((constOperand != nullptr) && (constResult != nullptr)) result = ConstantValue::get(constResult->getInt() * constOperand->getInt()); else result = builder.createMulInst(result, operand); - } else if (op->DIV() != nullptr) { + } else if (opType == SysYParser::DIV) { if ((constOperand != nullptr) && (constResult != nullptr)) result = ConstantValue::get(constResult->getInt() / constOperand->getInt()); else @@ -730,31 +736,33 @@ std::any SysYIRGenerator::visitAddExp(SysYParser::AddExpContext *ctx) { Value* result = std::any_cast(visitMulExp(ctx->mulExp(0))); for (size_t i = 1; i < ctx->mulExp().size(); i++) { - auto op = ctx->addOp(i - 1); + auto opNode = dynamic_cast(ctx->children[2*i-1]); + int opType = opNode->getSymbol()->getType(); Value* operand = std::any_cast(visitMulExp(ctx->mulExp(i))); Type* resultType = result->getType(); Type* operandType = operand->getType(); + Type* floatType = Type::getFloatType(); - if (resultType == Type::getFloatType() || operandType == Type::getFloatType()) { + if (resultType == floatType || operandType == floatType) { // 类型转换 - if (operandType != Type::getFloatType()) { - Value* constOperand = dynamic_cast(operand); + if (operandType != floatType) { + ConstantValue * constOperand = dynamic_cast(operand); if (constOperand != nullptr) operand = ConstantValue::get(static_cast(constOperand->getInt())); else operand = builder.createIToFInst(operand); - } else if (resultType != Type::getFloatType()) { - Value* constResult = dynamic_cast(result); + } else if (resultType != floatType) { + ConstantValue * constResult = dynamic_cast(result); if (constResult != nullptr) result = ConstantValue::get(static_cast(constResult->getInt())); else result = builder.createIToFInst(result); } - Value* constResult = dynamic_cast(result); - Value* constOperand = dynamic_cast(operand); - if (op->ADD() != nullptr) { + ConstantValue * constResult = dynamic_cast(result); + ConstantValue * constOperand = dynamic_cast(operand); + if (opType == SysYParser::ADD) { if ((constResult != nullptr) && (constOperand != nullptr)) result = ConstantValue::get(constResult->getFloat() + constOperand->getFloat()); else @@ -766,9 +774,9 @@ std::any SysYIRGenerator::visitAddExp(SysYParser::AddExpContext *ctx) { result = builder.createFSubInst(result, operand); } } else { - Value* constResult = dynamic_cast(result); - Value* constOperand = dynamic_cast(operand); - if (op->ADD() != nullptr) { + ConstantValue * constResult = dynamic_cast(result); + ConstantValue * constOperand = dynamic_cast(operand); + if (opType == SysYParser::ADD) { if ((constResult != nullptr) && (constOperand != nullptr)) result = ConstantValue::get(constResult->getInt() + constOperand->getInt()); else @@ -785,11 +793,13 @@ std::any SysYIRGenerator::visitAddExp(SysYParser::AddExpContext *ctx) { return result; } -std:any SysYIRGenerator::visitRelExp(SysYParser::RelExpContext *ctx) { +std::any SysYIRGenerator::visitRelExp(SysYParser::RelExpContext *ctx) { Value* result = std::any_cast(visitAddExp(ctx->addExp(0))); for (size_t i = 1; i < ctx->addExp().size(); i++) { - auto op = ctx->relOp(i - 1); + auto opNode = dynamic_cast(ctx->children[2*i-1]); + int opType = opNode->getSymbol()->getType(); + Value* operand = std::any_cast(visitAddExp(ctx->addExp(i))); Type* resultType = result->getType(); @@ -805,10 +815,10 @@ std:any SysYIRGenerator::visitRelExp(SysYParser::RelExpContext *ctx) { auto operand2 = constOperand->isFloat() ? constOperand->getFloat() : constOperand->getInt(); - if (op->LT() != nullptr) result = ConstantValue::get(operand1 < operand2 ? 1 : 0); - else if (op->GT() != nullptr) result = ConstantValue::get(operand1 > operand2 ? 1 : 0); - else if (op->LE() != nullptr) result = ConstantValue::get(operand1 <= operand2 ? 1 : 0); - else if (op->GE() != nullptr) result = ConstantValue::get(operand1 >= operand2 ? 1 : 0); + if (opType == SysYParser::LT) result = ConstantValue::get(operand1 < operand2 ? 1 : 0); + else if (opType == SysYParser::GT) result = ConstantValue::get(operand1 > operand2 ? 1 : 0); + else if (opType == SysYParser::LE) result = ConstantValue::get(operand1 <= operand2 ? 1 : 0); + else if (opType == SysYParser::GE) result = ConstantValue::get(operand1 >= operand2 ? 1 : 0); else assert(false); } else { @@ -833,18 +843,18 @@ std:any SysYIRGenerator::visitRelExp(SysYParser::RelExpContext *ctx) { } - if (op->LT() != nullptr) result = builder.createFCmpLTInst(result, operand); - else if (op->GT() != nullptr) result = builder.createFCmpGTInst(result, operand); - else if (op->LE() != nullptr) result = builder.createFCmpLEInst(result, operand); - else if (op->GE() != nullptr) result = builder.createFCmpGEInst(result, operand); + if (opType == SysYParser::LT) result = builder.createFCmpLTInst(result, operand); + else if (opType == SysYParser::GT) result = builder.createFCmpGTInst(result, operand); + else if (opType == SysYParser::LE) result = builder.createFCmpLEInst(result, operand); + else if (opType == SysYParser::GE) result = builder.createFCmpGEInst(result, operand); else assert(false); } else { // 整数处理 - if (op->LT() != nullptr) result = builder.createICmpLTInst(result, operand); - else if (op->GT() != nullptr) result = builder.createICmpGTInst(result, operand); - else if (op->LE() != nullptr) result = builder.createICmpLEInst(result, operand); - else if (op->GE() != nullptr) result = builder.createICmpGEInst(result, operand); + if (opType == SysYParser::LT) result = builder.createICmpLTInst(result, operand); + else if (opType == SysYParser::GT) result = builder.createICmpGTInst(result, operand); + else if (opType == SysYParser::LE) result = builder.createICmpLEInst(result, operand); + else if (opType == SysYParser::GE) result = builder.createICmpGEInst(result, operand); else assert(false); } @@ -855,31 +865,154 @@ std:any SysYIRGenerator::visitRelExp(SysYParser::RelExpContext *ctx) { } +std::any SysYIRGenerator::visitEqExp(SysYParser::EqExpContext *ctx) { + Value * result = std::any_cast(visitRelExp(ctx->relExp(0))); + + for (size_t i = 1; i < ctx->relExp().size(); i++) { + auto opNode = dynamic_cast(ctx->children[2*i-1]); + int opType = opNode->getSymbol()->getType(); + + Value * operand = std::any_cast(visitRelExp(ctx->relExp(i))); + + ConstantValue* constResult = dynamic_cast(result); + ConstantValue* constOperand = dynamic_cast(operand); + + if ((constResult != nullptr) && (constOperand != nullptr)) { + auto operand1 = constResult->isFloat() ? constResult->getFloat() + : constResult->getInt(); + auto operand2 = constOperand->isFloat() ? constOperand->getFloat() + : constOperand->getInt(); + + if (opType == SysYParser::EQ) result = ConstantValue::get(operand1 == operand2 ? 1 : 0); + else if (opType == SysYParser::NE) result = ConstantValue::get(operand1 != operand2 ? 1 : 0); + else assert(false); + + } else { + Type* resultType = result->getType(); + Type* operandType = operand->getType(); + Type* floatType = Type::getFloatType(); + + if (resultType == floatType || operandType == floatType) { + if (resultType != floatType) { + if (constResult != nullptr) + result = ConstantValue::get(static_cast(constResult->getInt())); + else + result = builder.createIToFInst(result); + } + if (operandType != floatType) { + if (constOperand != nullptr) + operand = ConstantValue::get(static_cast(constOperand->getInt())); + else + operand = builder.createIToFInst(operand); + } + + if (opType == SysYParser::EQ) result = builder.createFCmpEQInst(result, operand); + else if (opType == SysYParser::NE) result = builder.createFCmpNEInst(result, operand); + else assert(false); + + } else { + + if (opType == SysYParser::EQ) result = builder.createICmpEQInst(result, operand); + else if (opType == SysYParser::NE) result = builder.createICmpNEInst(result, operand); + else assert(false); + + } + } + } + + if (ctx->relExp().size() == 1) { + ConstantValue * constResult = dynamic_cast(result); + // 如果只有一个关系表达式,则将结果转换为0或1 + if (constResult != nullptr) { + if (constResult->isFloat()) + result = ConstantValue::get(constResult->getFloat() != 0.0F ? 1 : 0); + else + result = ConstantValue::get(constResult->getInt() != 0 ? 1 : 0); + } + } + + return result; +} + +std::any SysYIRGenerator::visitLAndExp(SysYParser::LAndExpContext *ctx){ + std::stringstream labelstring; + BasicBlock *curBlock = builder.getBasicBlock(); + Function *function = builder.getBasicBlock()->getParent(); + BasicBlock *trueBlock = builder.getTrueBlock(); + BasicBlock *falseBlock = builder.getFalseBlock(); + auto conds = ctx->eqExp(); + + for (size_t i = 0; i < conds.size() - 1; i++) { + + labelstring << "AND.L" << builder.getLabelIndex(); + BasicBlock *newtrueBlock = function->addBasicBlock(labelstring.str()); + labelstring.str(""); + + auto cond = std::any_cast(visitEqExp(ctx->eqExp(i))); + builder.createCondBrInst(cond, newtrueBlock, falseBlock, {}, {}); + + BasicBlock::conectBlocks(curBlock, newtrueBlock); + BasicBlock::conectBlocks(curBlock, falseBlock); + + curBlock = newtrueBlock; + builder.setPosition(curBlock, curBlock->end()); + } + + auto cond = std::any_cast(visitEqExp(conds.back())); + builder.createCondBrInst(cond, trueBlock, falseBlock, {}, {}); + + BasicBlock::conectBlocks(curBlock, trueBlock); + BasicBlock::conectBlocks(curBlock, falseBlock); + + return std::any(); +} + +auto SysYIRGenerator::visitLOrExp(SysYParser::LOrExpContext *ctx) -> std::any { + std::stringstream labelstring; + BasicBlock *curBlock = builder.getBasicBlock(); + Function *function = curBlock->getParent(); + auto conds = ctx->lAndExp(); + + for (size_t i = 0; i < conds.size() - 1; i++) { + labelstring << "OR.L" << builder.getLabelIndex(); + BasicBlock *newFalseBlock = function->addBasicBlock(labelstring.str()); + labelstring.str(""); + + builder.pushFalseBlock(newFalseBlock); + visitLAndExp(ctx->lAndExp(i)); + builder.popFalseBlock(); + + builder.setPosition(newFalseBlock, newFalseBlock->end()); + } + + visitLAndExp(conds.back()); + + return std::any(); +} + void Utils::tree2Array(Type *type, ArrayValueTree *root, const std::vector &dims, unsigned numDims, ValueCounter &result, IRBuilder *builder) { - auto value = root->getValue(); + Value* value = root->getValue(); auto &children = root->getChildren(); if (value != nullptr) { if (type == value->getType()) { result.push_back(value); } else { if (type == Type::getFloatType()) { - auto constValue = dynamic_cast(value); - if (constValue != nullptr) { - result.push_back( - ConstantValue::get(static_cast(constValue->getInt()))); - } else { + ConstantValue* constValue = dynamic_cast(value); + if (constValue != nullptr) + result.push_back(ConstantValue::get(static_cast(constValue->getInt()))); + else result.push_back(builder->createIToFInst(value)); - } + } else { - auto constValue = dynamic_cast(value); - if (constValue != nullptr) { - result.push_back( - ConstantValue::get(static_cast(constValue->getFloat()))); - } else { + ConstantValue* constValue = dynamic_cast(value); + if (constValue != nullptr) + result.push_back(ConstantValue::get(static_cast(constValue->getFloat()))); + else result.push_back(builder->createFtoIInst(value)); - } + } } return; @@ -909,11 +1042,10 @@ void Utils::tree2Array(Type *type, ArrayValueTree *root, int num = blockSize - afterSize + beforeSize; if (num > 0) { - if (type == Type::getFloatType()) { + if (type == Type::getFloatType()) result.push_back(ConstantValue::get(0.0F), num); - } else { + else result.push_back(ConstantValue::get(0), num); - } } } diff --git a/src/include/SysYIRGenerator.h b/src/include/SysYIRGenerator.h index e203016..445a856 100644 --- a/src/include/SysYIRGenerator.h +++ b/src/include/SysYIRGenerator.h @@ -10,8 +10,6 @@ namespace sysy { -class SysYIRGenerator : public SysYBaseVisitor { - // @brief 用于存储数组值的树结构 // 多位数组本质上是一维数组的嵌套可以用树来表示。 class ArrayValueTree { @@ -55,6 +53,8 @@ public: static void initExternalFunction(Module *pModule, IRBuilder *pBuilder); }; +class SysYIRGenerator : public SysYBaseVisitor { + private: std::unique_ptr module; IRBuilder builder; @@ -108,7 +108,8 @@ public: // std::any visitCond(SysYParser::CondContext *ctx) override; std::any visitLValue(SysYParser::LValueContext *ctx) override; - std::any visitPrimExp(SysYParser::PrimExpContext *ctx) override; + + std::any visitPrimaryExp(SysYParser::PrimaryExpContext *ctx) override; // std::any visitParenExp(SysYParser::ParenExpContext *ctx) override; std::any visitNumber(SysYParser::NumberContext *ctx) override; @@ -116,11 +117,11 @@ public: std::any visitCall(SysYParser::CallContext *ctx) override; - // std::any visitUnaryExp(SysYParser::UnaryExpContext *ctx) override; + std::any visitUnaryExp(SysYParser::UnaryExpContext *ctx) override; // std::any visitUnaryOp(SysYParser::UnaryOpContext *ctx) override; - std::any visitUnExp(SysYParser::UnExpContext *ctx) override; - + // std::any visitUnExp(SysYParser::UnExpContext *ctx) override; + std::any visitFuncRParams(SysYParser::FuncRParamsContext *ctx) override; std::any visitMulExp(SysYParser::MulExpContext *ctx) override; std::any visitAddExp(SysYParser::AddExpContext *ctx) override; diff --git a/src/sysyc.cpp b/src/sysyc.cpp index 1328f55..40116e1 100644 --- a/src/sysyc.cpp +++ b/src/sysyc.cpp @@ -6,7 +6,6 @@ using namespace std; #include "SysYLexer.h" #include "SysYParser.h" using namespace antlr4; -#include "ASTPrinter.h" #include "Backend.h" #include "SysYIRGenerator.h" #include "LLVMIRGenerator.h" @@ -77,7 +76,7 @@ int main(int argc, char **argv) { SysYIRGenerator generator; generator.visitCompUnit(moduleAST); auto moduleIR = generator.get(); - moduleIR->print(cout); + // moduleIR->print(cout); return EXIT_SUCCESS; } else if (argStopAfter == "llvmir") { LLVMIRGenerator llvmirGenerator; From d90330af3f06ece00d0658a9f3ae866c0f7516a2 Mon Sep 17 00:00:00 2001 From: rain2133 <1370973498@qq.com> Date: Sun, 22 Jun 2025 14:14:02 +0800 Subject: [PATCH 13/13] add Utils::initExternalFunction --- src/CMakeLists.txt | 4 +- src/SysYIRGenerator.cpp | 109 ++++++++++++++++++++++++++++++++++++++++ src/sysyc.cpp | 17 ++++--- 3 files changed, 120 insertions(+), 10 deletions(-) diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index e8c78f6..6bfd3e4 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -15,9 +15,9 @@ add_executable(sysyc sysyc.cpp IR.cpp SysYIRGenerator.cpp - Backend.cpp + # Backend.cpp # LLVMIRGenerator.cpp - LLVMIRGenerator_1.cpp + # LLVMIRGenerator_1.cpp ) target_include_directories(sysyc PRIVATE ${CMAKE_CURRENT_BINARY_DIR} ${CMAKE_CURRENT_SOURCE_DIR}/include) target_compile_options(sysyc PRIVATE -frtti) diff --git a/src/SysYIRGenerator.cpp b/src/SysYIRGenerator.cpp index 7e4ed85..1718aba 100644 --- a/src/SysYIRGenerator.cpp +++ b/src/SysYIRGenerator.cpp @@ -1067,4 +1067,113 @@ void Utils::createExternalFunction( } } +void Utils::initExternalFunction(Module *pModule, IRBuilder *pBuilder) { + std::vector paramTypes; + std::vector paramNames; + std::vector> paramDims; + Type *returnType; + std::string funcName; + + returnType = Type::getIntType(); + funcName = "getint"; + Utils::createExternalFunction(paramTypes, paramNames, paramDims, returnType, + funcName, pModule, pBuilder); + funcName = "getch"; + Utils::createExternalFunction(paramTypes, paramNames, paramDims, returnType, + funcName, pModule, pBuilder); + paramTypes.push_back(Type::getIntType()); + paramNames.emplace_back("x"); + paramDims.push_back(std::vector{ConstantValue::get(-1)}); + funcName = "getarray"; + Utils::createExternalFunction(paramTypes, paramNames, paramDims, returnType, + funcName, pModule, pBuilder); + + returnType = Type::getFloatType(); + paramTypes.clear(); + paramNames.clear(); + paramDims.clear(); + funcName = "getfloat"; + Utils::createExternalFunction(paramTypes, paramNames, paramDims, returnType, + funcName, pModule, pBuilder); + + returnType = Type::getIntType(); + paramTypes.push_back(Type::getFloatType()); + paramNames.emplace_back("x"); + paramDims.push_back(std::vector{ConstantValue::get(-1)}); + funcName = "getfarray"; + Utils::createExternalFunction(paramTypes, paramNames, paramDims, returnType, + funcName, pModule, pBuilder); + + returnType = Type::getVoidType(); + + paramTypes.clear(); + paramTypes.push_back(Type::getIntType()); + paramDims.clear(); + paramDims.emplace_back(); + funcName = "putint"; + Utils::createExternalFunction(paramTypes, paramNames, paramDims, returnType, + funcName, pModule, pBuilder); + + funcName = "putch"; + Utils::createExternalFunction(paramTypes, paramNames, paramDims, returnType, + funcName, pModule, pBuilder); + + paramTypes.clear(); + paramTypes.push_back(Type::getIntType()); + paramTypes.push_back(Type::getIntType()); + paramDims.clear(); + paramDims.emplace_back(); + paramDims.push_back(std::vector{ConstantValue::get(-1)}); + paramNames.clear(); + paramNames.emplace_back("n"); + paramNames.emplace_back("a"); + funcName = "putarray"; + Utils::createExternalFunction(paramTypes, paramNames, paramDims, returnType, + funcName, pModule, pBuilder); + + paramTypes.clear(); + paramTypes.push_back(Type::getFloatType()); + paramDims.clear(); + paramDims.emplace_back(); + paramNames.clear(); + paramNames.emplace_back("a"); + funcName = "putfloat"; + Utils::createExternalFunction(paramTypes, paramNames, paramDims, returnType, + funcName, pModule, pBuilder); + + paramTypes.clear(); + paramTypes.push_back(Type::getIntType()); + paramTypes.push_back(Type::getFloatType()); + paramDims.clear(); + paramDims.emplace_back(); + paramDims.push_back(std::vector{ConstantValue::get(-1)}); + paramNames.clear(); + paramNames.emplace_back("n"); + paramNames.emplace_back("a"); + funcName = "putfarray"; + Utils::createExternalFunction(paramTypes, paramNames, paramDims, returnType, + funcName, pModule, pBuilder); + + paramTypes.clear(); + paramTypes.push_back(Type::getIntType()); + paramDims.clear(); + paramDims.emplace_back(); + paramNames.clear(); + paramNames.emplace_back("__LINE__"); + funcName = "starttime"; + Utils::createExternalFunction(paramTypes, paramNames, paramDims, returnType, + funcName, pModule, pBuilder); + + paramTypes.clear(); + paramTypes.push_back(Type::getIntType()); + paramDims.clear(); + paramDims.emplace_back(); + paramNames.clear(); + paramNames.emplace_back("__LINE__"); + funcName = "stoptime"; + Utils::createExternalFunction(paramTypes, paramNames, paramDims, returnType, + funcName, pModule, pBuilder); + +} + } // namespace sysy \ No newline at end of file diff --git a/src/sysyc.cpp b/src/sysyc.cpp index 40116e1..0f01099 100644 --- a/src/sysyc.cpp +++ b/src/sysyc.cpp @@ -6,9 +6,9 @@ using namespace std; #include "SysYLexer.h" #include "SysYParser.h" using namespace antlr4; -#include "Backend.h" +// #include "Backend.h" #include "SysYIRGenerator.h" -#include "LLVMIRGenerator.h" +// #include "LLVMIRGenerator.h" using namespace sysy; static string argStopAfter; @@ -78,12 +78,13 @@ int main(int argc, char **argv) { auto moduleIR = generator.get(); // moduleIR->print(cout); return EXIT_SUCCESS; - } else if (argStopAfter == "llvmir") { - LLVMIRGenerator llvmirGenerator; - llvmirGenerator.generateIR(moduleAST); // 使用公共接口生成 IR - cout << llvmirGenerator.getIR(); - return EXIT_SUCCESS; - } + } + // else if (argStopAfter == "llvmir") { + // LLVMIRGenerator llvmirGenerator; + // llvmirGenerator.generateIR(moduleAST); // 使用公共接口生成 IR + // cout << llvmirGenerator.getIR(); + // return EXIT_SUCCESS; + // } // // generate assembly // CodeGen codegen(moduleIR);