// SysYIRGenerator.cpp // TODO:类型转换及其检查 // TODO:sysy库函数处理 // TODO:数组处理 // TODO:对while、continue、break的测试 #include "IR.h" #include #include #include #include #include #include #include "SysYIRGenerator.h" using namespace std; namespace sysy { std::pair calculate_signed_magic(int d) { if (d == 0) throw std::runtime_error("Division by zero"); if (d == 1 || d == -1) return {0, 0}; // Not used by strength reduction int k = 0; unsigned int ad = (d > 0) ? d : -d; unsigned int temp = ad; while (temp > 0) { temp >>= 1; k++; } if ((ad & (ad - 1)) == 0) { // if power of 2 k--; } unsigned __int128 m_val = 1; m_val <<= (32 + k - 1); unsigned __int128 m_prime = m_val / ad; long long m = m_prime + 1; return {m, k}; } // 清除因函数调用而失效的表达式缓存(保守策略) void SysYIRGenerator::invalidateExpressionsOnCall() { availableBinaryExpressions.clear(); availableUnaryExpressions.clear(); availableLoads.clear(); availableGEPs.clear(); } // 在进入新的基本块时清空所有表达式缓存 void SysYIRGenerator::enterNewBasicBlock() { availableBinaryExpressions.clear(); availableUnaryExpressions.clear(); availableLoads.clear(); availableGEPs.clear(); } // 清除因变量赋值而失效的表达式缓存 // @param storedAddress: store 指令的目标地址 (例如 AllocaInst* 或 GEPInst*) void SysYIRGenerator::invalidateExpressionsOnStore(Value *storedAddress) { // 遍历二元表达式缓存,移除受影响的条目 // 创建一个临时列表来存储要移除的键,避免在迭代时修改容器 std::vector binaryKeysToRemove; for (const auto &pair : availableBinaryExpressions) { // 检查左操作数 // 如果左操作数是 LoadInst,并且它从 storedAddress 加载 if (auto loadInst = dynamic_cast(pair.first.left)) { if (loadInst->getPointer() == storedAddress) { binaryKeysToRemove.push_back(pair.first); continue; // 这个表达式已标记为移除,跳到下一个 } } // 如果左操作数本身就是被存储的地址 (例如,将一个地址值直接作为操作数,虽然不常见) if (pair.first.left == storedAddress) { binaryKeysToRemove.push_back(pair.first); continue; } // 检查右操作数,逻辑同左操作数 if (auto loadInst = dynamic_cast(pair.first.right)) { if (loadInst->getPointer() == storedAddress) { binaryKeysToRemove.push_back(pair.first); continue; } } if (pair.first.right == storedAddress) { binaryKeysToRemove.push_back(pair.first); continue; } } // 实际移除条目 for (const auto &key : binaryKeysToRemove) { availableBinaryExpressions.erase(key); } // 遍历一元表达式缓存,移除受影响的条目 std::vector unaryKeysToRemove; for (const auto &pair : availableUnaryExpressions) { // 检查操作数 if (auto loadInst = dynamic_cast(pair.first.operand)) { if (loadInst->getPointer() == storedAddress) { unaryKeysToRemove.push_back(pair.first); continue; } } if (pair.first.operand == storedAddress) { unaryKeysToRemove.push_back(pair.first); continue; } } // 实际移除条目 for (const auto &key : unaryKeysToRemove) { availableUnaryExpressions.erase(key); } availableLoads.erase(storedAddress); std::vector gepKeysToRemove; for (const auto &pair : availableGEPs) { // 检查 GEP 的基指针是否受存储影响 if (auto loadInst = dynamic_cast(pair.first.basePointer)) { if (loadInst->getPointer() == storedAddress) { gepKeysToRemove.push_back(pair.first); continue; // 标记此GEP为移除,跳过后续检查 } } // 如果基指针本身就是存储的目标地址 (不常见,但可能) if (pair.first.basePointer == storedAddress) { gepKeysToRemove.push_back(pair.first); continue; } // 检查 GEP 的每个索引是否受存储影响 for (const auto &indexVal : pair.first.indices) { if (auto loadInst = dynamic_cast(indexVal)) { if (loadInst->getPointer() == storedAddress) { gepKeysToRemove.push_back(pair.first); break; // 标记此GEP为移除,并跳出内部循环 } } // 如果索引本身就是存储的目标地址 if (indexVal == storedAddress) { gepKeysToRemove.push_back(pair.first); break; } } } // 实际移除条目 for (const auto &key : gepKeysToRemove) { availableGEPs.erase(key); } } // std::vector BinaryValueStack; ///< 用于存储value的栈 // std::vector BinaryOpStack; ///< 用于存储二元表达式的操作符栈 // // 约定操作符: // // 1: 'ADD', 2: 'SUB', 3: 'MUL', 4: 'DIV', 5: '%', 6: 'PLUS', 7: 'NEG', 8: 'NOT' // enum BinaryOp { // ADD = 1, // SUB = 2, // MUL = 3, // DIV = 4, // MOD = 5, // PLUS = 6, // NEG = 7, // NOT = 8 // }; Value *SysYIRGenerator::promoteType(Value *value, Type *targetType) { //如果是常量则直接返回相应的值 if (targetType == nullptr) { return value; // 如果值为空,那就不需要转换 } ConstantInteger* constInt = dynamic_cast(value); ConstantFloating *constFloat = dynamic_cast(value); if (constInt) { if (targetType->isFloat()) { return ConstantFloating::get(static_cast(constInt->getInt())); } return constInt; // 如果目标类型是int,直接返回原值 } else if (constFloat) { if (targetType->isInt()) { return ConstantInteger::get(static_cast(constFloat->getFloat())); } return constFloat; // 如果目标类型是float,直接返回原值 } if (value->getType()->isInt() && targetType->isFloat()) { return builder.createItoFInst(value); } else if (value->getType()->isFloat() && targetType->isInt()) { return builder.createFtoIInst(value); } // 如果类型已经匹配,直接返回原值 return value; } bool SysYIRGenerator::isRightAssociative(int op) { return (op == BinaryOp::PLUS || op == BinaryOp::NEG || op == BinaryOp::NOT); } void SysYIRGenerator::compute() { // 先将中缀表达式转换为后缀表达式 BinaryRPNStack.clear(); BinaryOpStack.clear(); int begin = BinaryExpStack.size() - BinaryExpLenStack.back(), end = BinaryExpStack.size(); for (int i = begin; i < end; i++) { auto item = BinaryExpStack[i]; if (std::holds_alternative(item)) { // 如果是操作数 (Value*),直接推入后缀表达式栈 BinaryRPNStack.push_back(item); // 直接 push_back item (ValueOrOperator类型) } else { // 如果是操作符 int currentOp = std::get(item); if (currentOp == LPAREN) { // 左括号直接入栈 BinaryOpStack.push_back(currentOp); } else if (currentOp == RPAREN) { // 右括号:将操作符栈中的操作符弹出并添加到后缀表达式栈,直到遇到左括号 while (!BinaryOpStack.empty() && BinaryOpStack.back() != LPAREN) { BinaryRPNStack.push_back(BinaryOpStack.back()); // 直接 push_back int BinaryOpStack.pop_back(); } if (!BinaryOpStack.empty() && BinaryOpStack.back() == LPAREN) { BinaryOpStack.pop_back(); // 弹出左括号,但不添加到后缀表达式栈 } else { // 错误:不匹配的右括号 std::cerr << "Error: Mismatched parentheses in expression." << std::endl; return; } } else { // 普通操作符 while (!BinaryOpStack.empty() && BinaryOpStack.back() != LPAREN) { int stackTopOp = BinaryOpStack.back(); // 如果当前操作符优先级低于栈顶操作符优先级 // 或者 (当前操作符优先级等于栈顶操作符优先级 并且 栈顶操作符是左结合) if (getOperatorPrecedence(currentOp) < getOperatorPrecedence(stackTopOp) || (getOperatorPrecedence(currentOp) == getOperatorPrecedence(stackTopOp) && !isRightAssociative(stackTopOp))) { BinaryRPNStack.push_back(stackTopOp); BinaryOpStack.pop_back(); } else { break; // 否则当前操作符入栈 } } BinaryOpStack.push_back(currentOp); // 当前操作符入栈 } } } // 遍历结束后,将操作符栈中剩余的所有操作符弹出并添加到后缀表达式栈 while (!BinaryOpStack.empty()) { if (BinaryOpStack.back() == LPAREN) { // 错误:不匹配的左括号 std::cerr << "Error: Mismatched parentheses in expression (unclosed parenthesis)." << std::endl; return; } BinaryRPNStack.push_back(BinaryOpStack.back()); // 直接 push_back int BinaryOpStack.pop_back(); } // 弹出BinaryExpStack的表达式 int count = end - begin; for (int i = 0; i < count; i++) { BinaryExpStack.pop_back(); } if (!BinaryExpLenStack.empty()) { BinaryExpLenStack.back() -= count; } // 计算后缀表达式 // 每次计算前清空操作数栈 BinaryValueStack.clear(); // 遍历后缀表达式栈 Type *commonType = nullptr; for(const auto &item : BinaryRPNStack) { if (std::holds_alternative(item)) { // 如果是操作数 (Value*) 检测他的类型 Value *value = std::get(item); if (commonType == nullptr) { commonType = value->getType(); } else if (value->getType() != commonType && value->getType()->isFloat()) { // 如果当前值的类型与commonType不同且是float类型,则提升为float commonType = Type::getFloatType(); break; } } else { continue; } } for (const auto &item : BinaryRPNStack) { if (std::holds_alternative(item)) { // 如果是操作数 (Value*),直接推入操作数栈 BinaryValueStack.push_back(std::get(item)); } else { // 如果是操作符 int op = std::get(item); Value *resultValue = nullptr; Value *lhs = nullptr; Value *rhs = nullptr; Value *operand = nullptr; switch (op) { case BinaryOp::ADD: case BinaryOp::SUB: case BinaryOp::MUL: case BinaryOp::DIV: case BinaryOp::MOD: { // 二元操作符需要两个操作数 if (BinaryValueStack.size() < 2) { std::cerr << "Error: Not enough operands for binary operation: " << op << std::endl; return; // 或者抛出异常 } rhs = BinaryValueStack.back(); BinaryValueStack.pop_back(); lhs = BinaryValueStack.back(); BinaryValueStack.pop_back(); // 类型转换 lhs = promoteType(lhs, commonType); rhs = promoteType(rhs, commonType); // 尝试常量折叠 ConstantValue *lhsConst = dynamic_cast(lhs); ConstantValue *rhsConst = dynamic_cast(rhs); if (lhsConst && rhsConst) { // 如果都是常量,直接计算结果 if (commonType == Type::getIntType()) { int lhsVal = lhsConst->getInt(); int rhsVal = rhsConst->getInt(); switch (op) { case BinaryOp::ADD: resultValue = ConstantInteger::get(lhsVal + rhsVal); break; case BinaryOp::SUB: resultValue = ConstantInteger::get(lhsVal - rhsVal); break; case BinaryOp::MUL: resultValue = ConstantInteger::get(lhsVal * rhsVal); break; case BinaryOp::DIV: if (rhsVal == 0) { std::cerr << "Error: Division by zero." << std::endl; return; } resultValue = sysy::ConstantInteger::get(lhsVal / rhsVal); break; case BinaryOp::MOD: if (rhsVal == 0) { std::cerr << "Error: Modulo by zero." << std::endl; return; } resultValue = sysy::ConstantInteger::get(lhsVal % rhsVal); break; default: std::cerr << "Error: Unknown binary operator for constants: " << op << std::endl; return; } } else if (commonType == Type::getFloatType()) { float lhsVal = lhsConst->getFloat(); float rhsVal = rhsConst->getFloat(); switch (op) { case BinaryOp::ADD: resultValue = ConstantFloating::get(lhsVal + rhsVal); break; case BinaryOp::SUB: resultValue = ConstantFloating::get(lhsVal - rhsVal); break; case BinaryOp::MUL: resultValue = ConstantFloating::get(lhsVal * rhsVal); break; case BinaryOp::DIV: if (rhsVal == 0.0f) { std::cerr << "Error: Division by zero." << std::endl; return; } resultValue = sysy::ConstantFloating::get(lhsVal / rhsVal); break; case BinaryOp::MOD: std::cerr << "Error: Modulo operator not supported for float types." << std::endl; return; default: std::cerr << "Error: Unknown binary operator for float constants: " << op << std::endl; return; } } else { std::cerr << "Error: Unsupported type for binary constant operation." << std::endl; return; } } else { // 否则,创建相应的IR指令 ExpKey currentExpKey(static_cast(op), lhs, rhs); auto it = availableBinaryExpressions.find(currentExpKey); if (it != availableBinaryExpressions.end()) { // 在缓存中找到,重用结果 resultValue = it->second; } else { if (commonType == Type::getIntType()) { switch (op) { case BinaryOp::ADD: resultValue = builder.createAddInst(lhs, rhs); break; case BinaryOp::SUB: resultValue = builder.createSubInst(lhs, rhs); break; case BinaryOp::MUL: resultValue = builder.createMulInst(lhs, rhs); break; case BinaryOp::DIV: resultValue = builder.createDivInst(lhs, rhs); break; case BinaryOp::MOD: resultValue = builder.createRemInst(lhs, rhs); break; } } else if (commonType == Type::getFloatType()) { switch (op) { case BinaryOp::ADD: resultValue = builder.createFAddInst(lhs, rhs); break; case BinaryOp::SUB: resultValue = builder.createFSubInst(lhs, rhs); break; case BinaryOp::MUL: resultValue = builder.createFMulInst(lhs, rhs); break; case BinaryOp::DIV: resultValue = builder.createFDivInst(lhs, rhs); break; case BinaryOp::MOD: std::cerr << "Error: Modulo operator not supported for float types." << std::endl; return; } } else { std::cerr << "Error: Unsupported type for binary instruction." << std::endl; return; } // 将新创建的指令结果添加到缓存 availableBinaryExpressions[currentExpKey] = resultValue; } } break; } case BinaryOp::PLUS: case BinaryOp::NEG: case BinaryOp::NOT: { // 一元操作符需要一个操作数 if (BinaryValueStack.empty()) { std::cerr << "Error: Not enough operands for unary operation: " << op << std::endl; return; } operand = BinaryValueStack.back(); BinaryValueStack.pop_back(); operand = promoteType(operand, commonType); // 尝试常量折叠 ConstantInteger *constInt = dynamic_cast(operand); ConstantFloating *constFloat = dynamic_cast(operand); if (constInt || constFloat) { // 如果是常量,直接计算结果 switch (op) { case BinaryOp::PLUS: resultValue = operand; break; case BinaryOp::NEG: { if (constInt) { resultValue = constInt->getNeg(); } else if (constFloat) { resultValue = constFloat->getNeg(); } else { std::cerr << "Error: Negation not supported for constant operand type." << std::endl; return; } break; } case BinaryOp::NOT: if (constInt) { resultValue = sysy::ConstantInteger::get(constInt->getInt() == 0 ? 1 : 0); } else if (constFloat) { resultValue = sysy::ConstantInteger::get(constFloat->getFloat() == 0.0f ? 1 : 0); } else { std::cerr << "Error: Logical NOT not supported for constant operand type." << std::endl; return; } break; default: std::cerr << "Error: Unknown unary operator for constants: " << op << std::endl; return; } } else { // 否则,创建相应的IR指令 (在这里应用CSE) UnExpKey currentUnExpKey(static_cast(op), operand); auto it = availableUnaryExpressions.find(currentUnExpKey); if (it != availableUnaryExpressions.end()) { // 在缓存中找到,重用结果 resultValue = it->second; } else { switch (op) { case BinaryOp::PLUS: resultValue = operand; // 一元加指令通常直接返回操作数 break; case BinaryOp::NEG: { if (commonType == sysy::Type::getIntType()) { resultValue = builder.createNegInst(operand); } else if (commonType == sysy::Type::getFloatType()) { resultValue = builder.createFNegInst(operand); } else { std::cerr << "Error: Negation not supported for operand type." << std::endl; return; } break; } case BinaryOp::NOT: // 逻辑非 if (commonType == sysy::Type::getIntType()) { resultValue = builder.createNotInst(operand); } else if (commonType == sysy::Type::getFloatType()) { resultValue = builder.createFNotInst(operand); } else { std::cerr << "Error: Logical NOT not supported for operand type." << std::endl; return; } break; default: std::cerr << "Error: Unknown unary operator for instructions: " << op << std::endl; return; } // 将新创建的指令结果添加到缓存 availableUnaryExpressions[currentUnExpKey] = resultValue; } } break; } default: std::cerr << "Error: Unknown operator " << op << " encountered in RPN stack." << std::endl; return; } // 将计算结果或指令结果推入操作数栈 if (resultValue) { BinaryValueStack.push_back(resultValue); } else { std::cerr << "Error: Result value is null after processing operator " << op << "!" << std::endl; return; } } } // 后缀表达式处理完毕,操作数栈的栈顶就是最终结果 if (BinaryValueStack.empty()) { std::cerr << "Error: No values left in BinaryValueStack after processing RPN." << std::endl; return; } if (BinaryValueStack.size() > 1) { std::cerr << "Warning: Multiple values left in BinaryValueStack after processing RPN. Expression might be malformed." << std::endl; } BinaryRPNStack.clear(); // 清空后缀表达式栈 BinaryOpStack.clear(); // 清空操作符栈 return; } Value* SysYIRGenerator::computeExp(SysYParser::ExpContext *ctx, Type* targetType){ if (ctx->addExp() == nullptr) { assert(false && "ExpContext should have an addExp child!"); } BinaryExpLenStack.push_back(0); // 进入新的层次时Push 0 visitAddExp(ctx->addExp()); // if(targetType == nullptr) { // targetType = Type::getIntType(); // 默认目标类型为int // } compute(); // 最后一个Value应该是最终结果 Value* result = BinaryValueStack.back(); BinaryValueStack.pop_back(); // 移除结果值 result = promoteType(result, targetType); // 确保结果类型符合目标类型 // 检查当前层次的操作符数量 int ExpLen = BinaryExpLenStack.back(); BinaryExpLenStack.pop_back(); // 离开层次时将该层次 if (ExpLen > 0) { std::cerr << "Warning: There are still " << ExpLen << " binary val or op left unprocessed in this level!" << std::endl; return nullptr; } return result; } Value* SysYIRGenerator::computeAddExp(SysYParser::AddExpContext *ctx, Type* targetType){ // 根据AddExpContext中的操作符和操作数计算加法表达式 // 这里假设AddExpContext已经被正确填充 if (ctx->mulExp().size() == 0) { assert(false && "AddExpContext should have a mulExp child!"); } BinaryExpLenStack.push_back(0); // 进入新的层次时Push 0 visitMulExp(ctx->mulExp(0)); // BinaryValueStack.push_back(result); for (int i = 1; i < ctx->mulExp().size(); i++) { auto opNode = dynamic_cast(ctx->children[2*i-1]); int opType = opNode->getSymbol()->getType(); switch(opType) { case SysYParser::ADD: BinaryExpStack.push_back(BinaryOp::ADD); BinaryExpLenStack.back()++; break; case SysYParser::SUB: BinaryExpStack.push_back(BinaryOp::SUB); BinaryExpLenStack.back()++; break; default: assert(false && "Unexpected operator in AddExp."); } // BinaryExpStack.push_back(opType); visitMulExp(ctx->mulExp(i)); // BinaryValueStack.push_back(operand); } // if(targetType == nullptr) { // targetType = Type::getIntType(); // 默认目标类型为int // } // 根据后缀表达式的逻辑计算 compute(); // 最后一个Value应该是最终结果 Value* result = BinaryValueStack.back(); BinaryValueStack.pop_back(); // 移除最后一个值,因为它已经被计算 result = promoteType(result, targetType); // 确保结果类型符合目标类型 int ExpLen = BinaryExpLenStack.back(); BinaryExpLenStack.pop_back(); // 离开层次时将该层次 if (ExpLen > 0) { std::cerr << "Warning: There are still " << ExpLen << " binary val or op left unprocessed in this level!" << std::endl; return nullptr; } return result; } Type* SysYIRGenerator::buildArrayType(Type* baseType, const std::vector& dims){ Type* currentType = baseType; // 从最内层维度开始构建 ArrayType // 例如对于 int arr[2][3],先处理 [3],再处理 [2] // 注意:SysY 的 dims 是从最外层到最内层,所以我们需要反向迭代 // 或者调整逻辑,使得从内到外构建 ArrayType // 假设 dims 列表是 [dim1, dim2, dim3...] (例如 [2, 3] for int[2][3]) // 我们需要从最内层维度开始向外构建 ArrayType for (int i = dims.size() - 1; i >= 0; --i) { // 维度大小必须是常量,否则无法构建 ArrayType ConstantInteger* constDim = dynamic_cast(dims[i]); if (constDim == nullptr) { // 如果维度不是常量,可能需要特殊处理,例如将其视为指针 // 对于函数参数 int arr[] 这种,第一个维度可以为未知 // 在这里,我们假设所有声明的数组维度都是常量 assert(false && "Array dimension must be a constant integer!"); return nullptr; } unsigned dimSize = constDim->getInt(); currentType = Type::getArrayType(currentType, dimSize); } return currentType; } // @brief: 获取 GEP 指令的地址 // @param basePointer: GEP 的基指针,已经过适当的加载/处理,类型为 LLVM IR 中的指针类型。 // 例如,对于局部数组,它是 AllocaInst;对于参数数组,它是 LoadInst 的结果。 // @param indices: 已经包含了所有必要的偏移索引 (包括可能的初始 0 索引,由 visitLValue 准备)。 // @return: 计算得到的地址值 (也是一个指针类型) Value* SysYIRGenerator::getGEPAddressInst(Value* basePointer, const std::vector& indices) { // 检查 basePointer 是否为指针类型 assert(basePointer->getType()->isPointer() && "Base pointer must be a pointer type!"); // `indices` 向量现在由调用方(如 visitLValue, visitVarDecl, visitAssignStmt)负责完整准备, // 包括是否需要添加初始的 `0` 索引。 // 所以这里直接将其传递给 `builder.createGetElementPtrInst`。 GEPKey key = {basePointer, indices}; // 尝试从缓存中查找 auto it = availableGEPs.find(key); if (it != availableGEPs.end()) { return it->second; // 缓存命中,返回已有的 GEPInst* } // 缓存未命中,创建新的 GEPInst Value* gepInst = builder.createGetElementPtrInst(basePointer, indices); // 假设 builder 提供了 createGEPInst 方法 availableGEPs[key] = gepInst; // 将新的 GEPInst* 加入缓存 return gepInst; } /* * @brief: visit compUnit * @details: * compUnit: (globalDecl | funcDef)+; */ std::any SysYIRGenerator::visitCompUnit(SysYParser::CompUnitContext *ctx) { // create the IR module auto pModule = new Module(); assert(pModule); module.reset(pModule); // SymbolTable::ModuleScope scope(symbols_table); Utils::initExternalFunction(pModule, &builder); pModule->enterNewScope(); visitChildren(ctx); pModule->leaveScope(); Utils::modify_timefuncname(pModule); return pModule; } std::any SysYIRGenerator::visitGlobalConstDecl(SysYParser::GlobalConstDeclContext *ctx){ auto constDecl = ctx->constDecl(); Type* type = std::any_cast(visitBType(constDecl->bType())); for (const auto &constDef : constDecl->constDef()) { std::vector dims = {}; std::string name = constDef->Ident()->getText(); auto constExps = constDef->constExp(); if (!constExps.empty()) { for (const auto &constExp : constExps) { dims.push_back(std::any_cast(visitConstExp(constExp))); } } ArrayValueTree* root = std::any_cast(constDef->constInitVal()->accept(this)); ValueCounter values; Utils::tree2Array(type, root, dims, dims.size(), values, &builder); delete root; // 创建全局常量变量,并更新符号表 Type* variableType = type; if (!dims.empty()) { // 如果有维度,说明是数组 variableType = buildArrayType(type, dims); // 构建完整的 ArrayType } module->createConstVar(name, Type::getPointerType(variableType), values); } return std::any(); } std::any SysYIRGenerator::visitGlobalVarDecl(SysYParser::GlobalVarDeclContext *ctx) { auto varDecl = ctx->varDecl(); Type* type = std::any_cast(visitBType(varDecl->bType())); for (const auto &varDef : varDecl->varDef()) { std::vector dims = {}; std::string name = varDef->Ident()->getText(); auto constExps = varDef->constExp(); if (!constExps.empty()) { for (const auto &constExp : constExps) { dims.push_back(std::any_cast(visitConstExp(constExp))); } } ValueCounter values = {}; if (varDef->initVal() != nullptr) { ArrayValueTree* root = std::any_cast(varDef->initVal()->accept(this)); Utils::tree2Array(type, root, dims, dims.size(), values, &builder); delete root; } // 创建全局变量,并更新符号表 Type* variableType = type; if (!dims.empty()) { // 如果有维度,说明是数组 variableType = buildArrayType(type, dims); // 构建完整的 ArrayType } module->createGlobalValue(name, Type::getPointerType(variableType), values); } return std::any(); } std::any SysYIRGenerator::visitConstDecl(SysYParser::ConstDeclContext *ctx) { Type *type = std::any_cast(visitBType(ctx->bType())); for (const auto constDef : ctx->constDef()) { std::vector dims = {}; std::string name = constDef->Ident()->getText(); auto constExps = constDef->constExp(); if (!constExps.empty()) { for (const auto constExp : constExps) { dims.push_back(std::any_cast(visitConstExp(constExp))); } } Type *variableType = type; if (!dims.empty()) { variableType = buildArrayType(type, dims); // 构建完整的 ArrayType } // 显式地为局部常量在栈上分配空间 // alloca 的类型将是指针指向常量类型,例如 `int*` 或 `int[2][3]*` // 将alloca全部集中到entry中 auto entry = builder.getBasicBlock()->getParent()->getEntryBlock(); auto it = builder.getPosition(); auto nowblk = builder.getBasicBlock(); builder.setPosition(entry, entry->terminator()); AllocaInst *alloca = builder.createAllocaInst(Type::getPointerType(variableType), name); builder.setPosition(nowblk, it); ArrayValueTree *root = std::any_cast(constDef->constInitVal()->accept(this)); ValueCounter values; Utils::tree2Array(type, root, dims, dims.size(), values, &builder); delete root; // 根据维度信息进行 store 初始化 if (dims.empty()) { // 标量常量初始化 // 局部常量必须有初始值,且通常是单个值 if (!values.getValues().empty()) { builder.createStoreInst(values.getValue(0), alloca); } else { // 错误处理:局部标量常量缺少初始化值 // 或者可以考虑默认初始化为0,但这通常不符合常量的语义 assert(false && "Local scalar constant must have an initialization value!"); return std::any(); // 直接返回,避免继续执行 } } else { // 数组常量初始化 const std::vector &counterValues = values.getValues(); const std::vector &counterNumbers = values.getNumbers(); int numElements = 1; std::vector dimSizes; for (Value *dimVal : dims) { if (ConstantInteger *constInt = dynamic_cast(dimVal)) { int dimSize = constInt->getInt(); numElements *= dimSize; dimSizes.push_back(dimSize); } // TODO else 错误处理:数组维度必须是常量(对于静态分配) else { assert(false && "Array dimension must be a constant integer!"); return std::any(); // 直接返回,避免继续执行 } } unsigned int elementSizeInBytes = type->getSize(); unsigned int totalSizeInBytes = numElements * elementSizeInBytes; // 检查是否所有初始化值都是零 bool allValuesAreZero = false; if (counterValues.empty()) { // 如果没有提供初始化值,通常视为全零初始化 allValuesAreZero = true; } else { allValuesAreZero = true; for (Value *val : counterValues) { if (ConstantInteger *constInt = dynamic_cast(val)) { if (constInt->getInt() != 0) { allValuesAreZero = false; break; } } else { // 如果不是常量整数,则不能确定是零 allValuesAreZero = false; break; } } } if (allValuesAreZero) { builder.createMemsetInst(alloca, ConstantInteger::get(0), ConstantInteger::get(totalSizeInBytes), ConstantInteger::get(0)); } else { int linearIndexOffset = 0; // 用于追踪当前处理的线性索引的偏移量 for (int k = 0; k < counterValues.size(); ++k) { // 当前 Value 的值和重复次数 Value *currentValue = counterValues[k]; unsigned currentRepeatNum = counterNumbers[k]; // 检查是否是0,并且重复次数足够大(例如 >16),才用 memset if (ConstantInteger *constInt = dynamic_cast(currentValue)) { if (constInt->getInt() == 0 && currentRepeatNum >= 16) { // 阈值可调整(如16、32等) // 计算 memset 的起始地址(基于当前线性偏移量) std::vector memsetStartIndices; int tempLinearIndex = linearIndexOffset; // 将线性索引转换为多维索引 for (int dimIdx = dimSizes.size() - 1; dimIdx >= 0; --dimIdx) { memsetStartIndices.insert(memsetStartIndices.begin(), ConstantInteger::get(static_cast(tempLinearIndex % dimSizes[dimIdx]))); tempLinearIndex /= dimSizes[dimIdx]; } // 构造 GEP 计算 memset 的起始地址 std::vector gepIndicesForMemset; gepIndicesForMemset.push_back(ConstantInteger::get(0)); // 跳过 alloca 类型 gepIndicesForMemset.insert(gepIndicesForMemset.end(), memsetStartIndices.begin(), memsetStartIndices.end()); Value *memsetPtr = builder.createGetElementPtrInst(alloca, gepIndicesForMemset); // 计算 memset 的字节数 = 元素个数 × 元素大小 Type *elementType = type;; uint64_t elementSize = elementType->getSize(); Value *size = ConstantInteger::get(currentRepeatNum * elementSize); // 生成 memset 指令(假设你的 IRBuilder 有 createMemset 方法) builder.createMemsetInst(memsetPtr, ConstantInteger::get(0), size, ConstantInteger::get(0)); // 跳过这些已处理的0 linearIndexOffset += currentRepeatNum; continue; // 直接进入下一次循环 } } for (unsigned i = 0; i < currentRepeatNum; ++i) { // 对于非零值,生成对应的 store 指令 std::vector currentIndices; int tempLinearIndex = linearIndexOffset + i; // 使用偏移量和当前重复次数内的索引 // 将线性索引转换为多维索引 for (int dimIdx = dimSizes.size() - 1; dimIdx >= 0; --dimIdx) { currentIndices.insert(currentIndices.begin(), ConstantInteger::get(static_cast(tempLinearIndex % dimSizes[dimIdx]))); tempLinearIndex /= dimSizes[dimIdx]; } // 对于局部数组,alloca 本身就是 GEP 的基指针。 // GEP 的第一个索引必须是 0,用于“步过”整个数组。 std::vector gepIndicesForInit; gepIndicesForInit.push_back(ConstantInteger::get(0)); gepIndicesForInit.insert(gepIndicesForInit.end(), currentIndices.begin(), currentIndices.end()); // 计算元素的地址 Value *elementAddress = getGEPAddressInst(alloca, gepIndicesForInit); // 生成 store 指令 builder.createStoreInst(currentValue, elementAddress); } // 更新线性索引偏移量,以便下一次迭代从正确的位置开始 linearIndexOffset += currentRepeatNum; } } } // 更新符号表,将常量名称与 AllocaInst 关联起来 module->addVariable(name, alloca); } return std::any(); } std::any SysYIRGenerator::visitVarDecl(SysYParser::VarDeclContext *ctx) { Type* type = std::any_cast(visitBType(ctx->bType())); for (const auto varDef : ctx->varDef()) { std::vector dims = {}; std::string name = varDef->Ident()->getText(); auto constExps = varDef->constExp(); if (!constExps.empty()) { for (const auto &constExp : constExps) { dims.push_back(std::any_cast(visitConstExp(constExp))); } } Type* variableType = type; if (!dims.empty()) { // 如果有维度,说明是数组 variableType = buildArrayType(type, dims); // 构建完整的 ArrayType } // 对于数组,alloca 的类型将是指针指向数组类型,例如 `int[2][3]*` // 对于标量,alloca 的类型将是指针指向标量类型,例如 `int*` auto entry = builder.getBasicBlock()->getParent()->getEntryBlock(); auto it = builder.getPosition(); auto nowblk = builder.getBasicBlock(); builder.setPosition(entry, entry->terminator()); AllocaInst *alloca = builder.createAllocaInst(Type::getPointerType(variableType), name); builder.setPosition(nowblk, it); if (varDef->initVal() != nullptr) { ValueCounter values; ArrayValueTree* root = std::any_cast(varDef->initVal()->accept(this)); Utils::tree2Array(type, root, dims, dims.size(), values, &builder); delete root; if (dims.empty()) { // 标量变量初始化 builder.createStoreInst(values.getValue(0), alloca); } else { // 数组变量初始化 const std::vector &counterValues = values.getValues(); const std::vector &counterNumbers = values.getNumbers(); int numElements = 1; std::vector dimSizes; for (Value *dimVal : dims) { if (ConstantInteger *constInt = dynamic_cast(dimVal)) { int dimSize = constInt->getInt(); numElements *= dimSize; dimSizes.push_back(dimSize); } // TODO else 错误处理:数组维度必须是常量(对于静态分配) } unsigned int elementSizeInBytes = type->getSize(); unsigned int totalSizeInBytes = numElements * elementSizeInBytes; bool allValuesAreZero = false; if (counterValues.empty()) { allValuesAreZero = true; } else { allValuesAreZero = true; for (Value *val : counterValues){ if (ConstantInteger *constInt = dynamic_cast(val)){ if (constInt->getInt() != 0){ allValuesAreZero = false; break; } } else{ allValuesAreZero = false; break; } } } if (allValuesAreZero) { builder.createMemsetInst( alloca, ConstantInteger::get(0), ConstantInteger::get(totalSizeInBytes), ConstantInteger::get(0)); } else { int linearIndexOffset = 0; // 用于追踪当前处理的线性索引的偏移量 for (int k = 0; k < counterValues.size(); ++k) { // 当前 Value 的值和重复次数 Value *currentValue = counterValues[k]; unsigned currentRepeatNum = counterNumbers[k]; // 检查是否是0,并且重复次数足够大(例如 >16),才用 memset if (ConstantInteger *constInt = dynamic_cast(currentValue)) { if (constInt->getInt() == 0 && currentRepeatNum >= 16) { // 阈值可调整(如16、32等) // 计算 memset 的起始地址(基于当前线性偏移量) std::vector memsetStartIndices; int tempLinearIndex = linearIndexOffset; // 将线性索引转换为多维索引 for (int dimIdx = dimSizes.size() - 1; dimIdx >= 0; --dimIdx) { memsetStartIndices.insert(memsetStartIndices.begin(), ConstantInteger::get(static_cast(tempLinearIndex % dimSizes[dimIdx]))); tempLinearIndex /= dimSizes[dimIdx]; } // 构造 GEP 计算 memset 的起始地址 std::vector gepIndicesForMemset; gepIndicesForMemset.push_back(ConstantInteger::get(0)); // 跳过 alloca 类型 gepIndicesForMemset.insert(gepIndicesForMemset.end(), memsetStartIndices.begin(), memsetStartIndices.end()); Value *memsetPtr = builder.createGetElementPtrInst(alloca, gepIndicesForMemset); // 计算 memset 的字节数 = 元素个数 × 元素大小 Type *elementType = type; ; uint64_t elementSize = elementType->getSize(); Value *size = ConstantInteger::get(currentRepeatNum * elementSize); // 生成 memset 指令(假设你的 IRBuilder 有 createMemset 方法) builder.createMemsetInst(memsetPtr, ConstantInteger::get(0), size, ConstantInteger::get(0)); // 跳过这些已处理的0 linearIndexOffset += currentRepeatNum; continue; // 直接进入下一次循环 } } for (unsigned i = 0; i < currentRepeatNum; ++i) { std::vector currentIndices; int tempLinearIndex = linearIndexOffset + i; // 使用偏移量和当前重复次数内的索引 // 将线性索引转换为多维索引 for (int dimIdx = dimSizes.size() - 1; dimIdx >= 0; --dimIdx) { currentIndices.insert(currentIndices.begin(), ConstantInteger::get(static_cast(tempLinearIndex % dimSizes[dimIdx]))); tempLinearIndex /= dimSizes[dimIdx]; } // 对于局部数组,alloca 本身就是 GEP 的基指针。 // GEP 的第一个索引必须是 0,用于“步过”整个数组。 std::vector gepIndicesForInit; gepIndicesForInit.push_back(ConstantInteger::get(0)); gepIndicesForInit.insert(gepIndicesForInit.end(), currentIndices.begin(), currentIndices.end()); // 计算元素的地址 Value *elementAddress = getGEPAddressInst(alloca, gepIndicesForInit); // 生成 store 指令 builder.createStoreInst(currentValue, elementAddress); } // 更新线性索引偏移量,以便下一次迭代从正确的位置开始 linearIndexOffset += currentRepeatNum; } } } } else { // 如果没有显式初始化值,默认对数组进行零初始化 if (!dims.empty()) { // 只有数组才需要默认的零初始化 int numElements = 1; for (Value *dimVal : dims) { if (ConstantInteger *constInt = dynamic_cast(dimVal)) { numElements *= constInt->getInt(); } } unsigned int elementSizeInBytes = type->getSize(); unsigned int totalSizeInBytes = numElements * elementSizeInBytes; builder.createMemsetInst( alloca, ConstantInteger::get(0), ConstantInteger::get(totalSizeInBytes), ConstantInteger::get(0) ); } } module->addVariable(name, alloca); } return std::any(); } std::any SysYIRGenerator::visitBType(SysYParser::BTypeContext *ctx) { return ctx->INT() != nullptr ? Type::getIntType() : Type::getFloatType(); } std::any SysYIRGenerator::visitScalarInitValue(SysYParser::ScalarInitValueContext *ctx) { // Value* value = std::any_cast(visitExp(ctx->exp())); Value* value = computeExp(ctx->exp()); ArrayValueTree* result = new ArrayValueTree(); result->setValue(value); return result; } std::any SysYIRGenerator::visitArrayInitValue(SysYParser::ArrayInitValueContext *ctx) { std::vector children; for (const auto &initVal : ctx->initVal()) children.push_back(std::any_cast(initVal->accept(this))); ArrayValueTree* result = new ArrayValueTree(); result->addChildren(children); return result; } std::any SysYIRGenerator::visitConstScalarInitValue(SysYParser::ConstScalarInitValueContext *ctx) { Value* value = std::any_cast(visitConstExp(ctx->constExp())); ArrayValueTree* result = new ArrayValueTree(); result->setValue(value); return result; } std::any SysYIRGenerator::visitConstExp(SysYParser::ConstExpContext *ctx){ return computeAddExp(ctx->addExp()); } std::any SysYIRGenerator::visitConstArrayInitValue(SysYParser::ConstArrayInitValueContext *ctx) { std::vector children; for (const auto &constInitVal : ctx->constInitVal()) children.push_back(std::any_cast(constInitVal->accept(this))); ArrayValueTree* result = new ArrayValueTree(); result->addChildren(children); return result; } std::any SysYIRGenerator::visitFuncType(SysYParser::FuncTypeContext *ctx) { if (ctx->INT() != nullptr) return Type::getIntType(); if (ctx->FLOAT() != nullptr) return Type::getFloatType(); return Type::getVoidType(); } std::any SysYIRGenerator::visitFuncDef(SysYParser::FuncDefContext *ctx){ // 更新作用域 module->enterNewScope(); // 清除CSE缓存 enterNewBasicBlock(); auto name = ctx->Ident()->getText(); std::vector paramActualTypes; std::vector paramNames; std::vector> paramDims; if (ctx->funcFParams() != nullptr) { auto params = ctx->funcFParams()->funcFParam(); for (const auto ¶m : params) { Type* baseBType = std::any_cast(visitBType(param->bType())); std::string paramName = param->Ident()->getText(); // 用于收集当前参数的维度信息(如果它是数组) std::vector currentParamDims; if (!param->LBRACK().empty()) { // 如果参数声明中有方括号,说明是数组 // SysY 数组参数的第一个维度可以是未知的(例如 int arr[] 或 int arr[][10]) // 这里的 ConstantInteger::get(-1) 表示未知维度,但对于 LLVM 类型构建,我们主要关注已知维度 currentParamDims.push_back(ConstantInteger::get(-1)); // 标记第一个维度为未知 for (const auto &exp : param->exp()) { // 访问表达式以获取维度大小,这些维度必须是常量 Value* dimVal = computeExp(exp); // 确保维度是常量整数,否则 buildArrayType 会断言失败 assert(dynamic_cast(dimVal) && "Array dimension in parameter must be a constant integer!"); currentParamDims.push_back(dimVal); } } // 根据解析出的信息,确定参数在 LLVM IR 中的实际类型 Type* actualParamType; if (currentParamDims.empty()) { // 情况1:标量参数 (e.g., int x) actualParamType = baseBType; // 实际类型就是基本类型 } else { // 情况2&3:数组参数 (e.g., int arr[] 或 int arr[][10]) // 数组参数在函数传递时会退化为指针。 // 这个指针指向的类型是除第一维外,由后续维度构成的数组类型。 // 从 currentParamDims 中移除第一个标记未知维度的 -1 std::vector fixedDimsForTypeBuilding; if (currentParamDims.size() > 1) { // 如果有固定维度 (e.g., int arr[][10]) // 复制除第一个 -1 之外的所有维度 fixedDimsForTypeBuilding.assign(currentParamDims.begin() + 1, currentParamDims.end()); } Type* pointedToArrayType = baseBType; // 从基本类型开始构建 // 从最内层维度向外层构建数组类型 // buildArrayType 期望 dims 是从最外层到最内层,但它内部反向迭代,所以这里直接传入 // 例如,对于 int arr[][10],fixedDimsForTypeBuilding 包含 [10],构建出 [10 x i32] if (!fixedDimsForTypeBuilding.empty()) { pointedToArrayType = buildArrayType(baseBType, fixedDimsForTypeBuilding); } // 实际参数类型是指向这个构建好的数组类型的指针 actualParamType = Type::getPointerType(pointedToArrayType); // e.g., i32* 或 [10 x i32]* } paramActualTypes.push_back(actualParamType); // 存储参数的实际 LLVM IR 类型 paramNames.push_back(paramName); // 存储参数名称 } } Type* returnType = std::any_cast(visitFuncType(ctx->funcType())); Type* funcType = Type::getFunctionType(returnType, paramActualTypes); Function* function = module->createFunction(name, funcType); BasicBlock* entry = function->getEntryBlock(); builder.setPosition(entry, entry->end()); for(int i = 0; i < paramActualTypes.size(); ++i) { Argument* arg = new Argument(paramActualTypes[i], function, i, paramNames[i]); function->insertArgument(arg); } // 先将所有参数名字注册到符号表中,确保alloca不会使用相同的名字 for (int i = 0; i < paramNames.size(); ++i) { // 预先注册参数名字,这样addVariable就会使用不同的后缀 module->registerParameterName(paramNames[i]); } auto funcArgs = function->getArguments(); std::vector allocas; for (int i = 0; i < paramActualTypes.size(); ++i) { // 使用函数特定的前缀来确保参数alloca名字唯一 std::string allocaName = name + "_param_" + paramNames[i]; AllocaInst *alloca = builder.createAllocaInst(Type::getPointerType(paramActualTypes[i]), allocaName); // 直接设置唯一名字,不依赖addVariable的命名逻辑 alloca->setName(allocaName); allocas.push_back(alloca); // 直接添加到符号表,使用原参数名作为查找键 module->addVariableDirectly(paramNames[i], alloca); } for(int i = 0; i < paramActualTypes.size(); ++i) { Value *argValue = funcArgs[i]; builder.createStoreInst(argValue, allocas[i]); } // 在处理函数体之前,创建一个新的基本块作为函数体的实际入口 // 这样 entryBB 就可以在完成初始化后跳转到这里 BasicBlock* funcBodyEntry = function->addBasicBlock("funcBodyEntry_" + name); // 从 entryBB 无条件跳转到 funcBodyEntry builder.createUncondBrInst(funcBodyEntry); BasicBlock::conectBlocks(entry, funcBodyEntry); // 连接 entryBB 和 funcBodyEntry builder.setPosition(funcBodyEntry,funcBodyEntry->end()); // 将插入点设置到 funcBodyEntry for (auto item : ctx->blockStmt()->blockItem()) { visitBlockItem(item); } // 如果函数没有显式的返回语句,且返回类型不是 void,则需要添加一个默认的返回值 ReturnInst* retinst = nullptr; retinst = dynamic_cast(builder.getBasicBlock()->terminator()->get()); if (!retinst) { if (returnType->isVoid()) { builder.createReturnInst(); } else if (returnType->isInt()) { builder.createReturnInst(ConstantInteger::get(0)); // 默认返回 0 } else if (returnType->isFloat()) { builder.createReturnInst(ConstantFloating::get(0.0f)); // 默认返回 0.0f } else { assert(false && "Function with no explicit return and non-void type should return a value."); } } module->leaveScope(); return std::any(); } std::any SysYIRGenerator::visitBlockStmt(SysYParser::BlockStmtContext *ctx) { module->enterNewScope(); for (auto item : ctx->blockItem()) visitBlockItem(item); module->leaveScope(); return 0; } std::any SysYIRGenerator::visitAssignStmt(SysYParser::AssignStmtContext *ctx) { auto lVal = ctx->lValue(); std::string name = lVal->Ident()->getText(); Value* LValue = nullptr; Value* variable = module->getVariable(name); // 左值 vector indices; if (lVal->exp().size() > 0) { // 如果有下标,访问表达式获取下标值 for (auto &exp : lVal->exp()) { Value* indexValue = std::any_cast(computeExp(exp)); indices.push_back(indexValue); } } if (indices.empty()) { // variable 本身就是指向标量的指针 (e.g., int* %a) if (dynamic_cast(variable) || dynamic_cast(variable)) { LValue = variable; } // 标量变量的类型推断 Type* LType = builder.getIndexedType(variable->getType(), indices); Value* RValue = computeExp(ctx->exp(), LType); // 右值计算 Type* RType = RValue->getType(); // TODO:computeExp处理了类型转换,可以考虑删除判断逻辑 if (LType != RType) { ConstantValue *constValue = dynamic_cast(RValue); if (constValue != nullptr) { if (LType == Type::getFloatType()) { if(dynamic_cast(constValue)) { // 如果是整型常量,转换为浮点型 RValue = ConstantFloating::get(static_cast(constValue->getInt())); } else if (dynamic_cast(constValue)) { // 如果是浮点型常量,直接使用 RValue = ConstantFloating::get(static_cast(constValue->getFloat())); } } else { // 假设如果不是浮点型,就是整型 if(dynamic_cast(constValue)) { // 如果是浮点型常量,转换为整型 RValue = ConstantInteger::get(static_cast(constValue->getFloat())); } else if (dynamic_cast(constValue)) { // 如果是整型常量,直接使用 RValue = ConstantInteger::get(static_cast(constValue->getInt())); } } } else { if (LType == Type::getFloatType() && RType != Type::getFloatType()) { RValue = builder.createItoFInst(RValue); } else if (LType != Type::getFloatType() && RType == Type::getFloatType()) { RValue = builder.createFtoIInst(RValue); } // 如果两者都是同一类型,就不需要转换 } } builder.createStoreInst(RValue, LValue); } else { // 对于数组或多维数组的左值处理 // 需要获取 GEP 地址 Value* gepBasePointer = nullptr; std::vector gepIndices; if (AllocaInst *alloc = dynamic_cast(variable)) { Type* allocatedType = alloc->getType()->as()->getBaseType(); if (allocatedType->isPointer()) { // 尝试从缓存中获取 builder.createLoadInst(alloc) 的结果 auto it = availableLoads.find(alloc); if (it != availableLoads.end()) { gepBasePointer = it->second; // 缓存命中,重用 } else { gepBasePointer = builder.createLoadInst(alloc); // 缓存未命中,创建新的 LoadInst availableLoads[alloc] = gepBasePointer; // 将结果加入缓存 } // --- CSE 结束 --- // gepBasePointer = builder.createLoadInst(alloc); gepIndices = indices; } else { gepBasePointer = alloc; gepIndices.push_back(ConstantInteger::get(0)); gepIndices.insert(gepIndices.end(), indices.begin(), indices.end()); } } else if (GlobalValue *glob = dynamic_cast(variable)) { // 情况 B: 全局变量 (GlobalValue) gepBasePointer = glob; gepIndices.push_back(ConstantInteger::get(0)); gepIndices.insert(gepIndices.end(), indices.begin(), indices.end()); } else if (ConstantVariable *constV = dynamic_cast(variable)) { gepBasePointer = constV; gepIndices.push_back(ConstantInteger::get(0)); gepIndices.insert(gepIndices.end(), indices.begin(), indices.end()); } // 左值为地址 LValue = getGEPAddressInst(gepBasePointer, gepIndices); // 数组变量的类型推断,使用gepIndices和gepBasePointer的类型 Type* LType = builder.getIndexedType(gepBasePointer->getType(), gepIndices); Value* RValue = computeExp(ctx->exp(), LType); // 右值计算 Type* RType = RValue->getType(); // TODO:computeExp处理了类型转换,可以考虑删除判断逻辑 if (LType != RType) { ConstantValue *constValue = dynamic_cast(RValue); if (constValue != nullptr) { if (LType == Type::getFloatType()) { if(dynamic_cast(constValue)) { // 如果是整型常量,转换为浮点型 RValue = ConstantFloating::get(static_cast(constValue->getInt())); } else if (dynamic_cast(constValue)) { // 如果是浮点型常量,直接使用 RValue = ConstantFloating::get(static_cast(constValue->getFloat())); } } else { // 假设如果不是浮点型,就是整型 if(dynamic_cast(constValue)) { // 如果是浮点型常量,转换为整型 RValue = ConstantInteger::get(static_cast(constValue->getFloat())); } else if (dynamic_cast(constValue)) { // 如果是整型常量,直接使用 RValue = ConstantInteger::get(static_cast(constValue->getInt())); } } } else { if (LType == Type::getFloatType() && RType != Type::getFloatType()) { RValue = builder.createItoFInst(RValue); } else if (LType != Type::getFloatType() && RType == Type::getFloatType()) { RValue = builder.createFtoIInst(RValue); } // 如果两者都是同一类型,就不需要转换 } } builder.createStoreInst(RValue, LValue); } invalidateExpressionsOnStore(LValue); return std::any(); } std::any SysYIRGenerator::visitExpStmt(SysYParser::ExpStmtContext *ctx) { // 访问表达式 if (ctx->exp() != nullptr) { computeExp(ctx->exp()); } return std::any(); } std::any SysYIRGenerator::visitIfStmt(SysYParser::IfStmtContext *ctx) { // labels string stream std::stringstream labelstring; Function * function = builder.getBasicBlock()->getParent(); BasicBlock* thenBlock = new BasicBlock(function); BasicBlock* exitBlock = new BasicBlock(function); if (ctx->stmt().size() > 1) { BasicBlock* elseBlock = new BasicBlock(function); builder.pushTrueBlock(thenBlock); builder.pushFalseBlock(elseBlock); // 访问条件表达式 visitCond(ctx->cond()); builder.popTrueBlock(); builder.popFalseBlock(); labelstring << "if_then.L" << builder.getLabelIndex(); thenBlock->setName(labelstring.str()); labelstring.str(""); function->addBasicBlock(thenBlock); builder.setPosition(thenBlock, thenBlock->end()); // CSE清除缓存 enterNewBasicBlock(); auto block = dynamic_cast(ctx->stmt(0)); // 如果是块语句,直接访问 // 否则访问语句 if (block != nullptr) { visitBlockStmt(block); } else { module->enterNewScope(); ctx->stmt(0)->accept(this); module->leaveScope(); } builder.createUncondBrInst(exitBlock); BasicBlock::conectBlocks(builder.getBasicBlock(), exitBlock); labelstring << "if_else.L" << builder.getLabelIndex(); elseBlock->setName(labelstring.str()); labelstring.str(""); function->addBasicBlock(elseBlock); builder.setPosition(elseBlock, elseBlock->end()); // CSE清除缓存 enterNewBasicBlock(); block = dynamic_cast(ctx->stmt(1)); if (block != nullptr) { visitBlockStmt(block); } else { module->enterNewScope(); ctx->stmt(1)->accept(this); module->leaveScope(); } builder.createUncondBrInst(exitBlock); BasicBlock::conectBlocks(builder.getBasicBlock(), exitBlock); labelstring << "if_exit.L" << builder.getLabelIndex(); exitBlock->setName(labelstring.str()); labelstring.str(""); function->addBasicBlock(exitBlock); builder.setPosition(exitBlock, exitBlock->end()); // CSE清除缓存 enterNewBasicBlock(); } else { builder.pushTrueBlock(thenBlock); builder.pushFalseBlock(exitBlock); visitCond(ctx->cond()); builder.popTrueBlock(); builder.popFalseBlock(); labelstring << "if_then.L" << builder.getLabelIndex(); thenBlock->setName(labelstring.str()); labelstring.str(""); function->addBasicBlock(thenBlock); builder.setPosition(thenBlock, thenBlock->end()); // CSE清除缓存 enterNewBasicBlock(); auto block = dynamic_cast(ctx->stmt(0)); if (block != nullptr) { visitBlockStmt(block); } else { module->enterNewScope(); ctx->stmt(0)->accept(this); module->leaveScope(); } builder.createUncondBrInst(exitBlock); BasicBlock::conectBlocks(builder.getBasicBlock(), exitBlock); labelstring << "if_exit.L" << builder.getLabelIndex(); exitBlock->setName(labelstring.str()); labelstring.str(""); function->addBasicBlock(exitBlock); builder.setPosition(exitBlock, exitBlock->end()); // CSE清除缓存 enterNewBasicBlock(); } return std::any(); } std::any SysYIRGenerator::visitWhileStmt(SysYParser::WhileStmtContext *ctx) { // while structure: // curblock -> headBlock -> bodyBlock -> exitBlock BasicBlock* curBlock = builder.getBasicBlock(); Function* function = builder.getBasicBlock()->getParent(); std::stringstream labelstring; labelstring << "while_head.L" << builder.getLabelIndex(); BasicBlock *headBlock = function->addBasicBlock(labelstring.str()); labelstring.str(""); builder.createUncondBrInst(headBlock); BasicBlock::conectBlocks(curBlock, headBlock); builder.setPosition(headBlock, headBlock->end()); // CSE清除缓存 enterNewBasicBlock(); BasicBlock* bodyBlock = new BasicBlock(function); BasicBlock* exitBlock = new BasicBlock(function); builder.pushTrueBlock(bodyBlock); builder.pushFalseBlock(exitBlock); // 访问条件表达式 visitCond(ctx->cond()); builder.popTrueBlock(); builder.popFalseBlock(); labelstring << "while_body.L" << builder.getLabelIndex(); bodyBlock->setName(labelstring.str()); labelstring.str(""); function->addBasicBlock(bodyBlock); builder.setPosition(bodyBlock, bodyBlock->end()); // CSE清除缓存 enterNewBasicBlock(); builder.pushBreakBlock(exitBlock); builder.pushContinueBlock(headBlock); auto block = dynamic_cast(ctx->stmt()); if( block != nullptr) { visitBlockStmt(block); } else { module->enterNewScope(); ctx->stmt()->accept(this); module->leaveScope(); } builder.createUncondBrInst(headBlock); BasicBlock::conectBlocks(builder.getBasicBlock(), headBlock); builder.popBreakBlock(); builder.popContinueBlock(); labelstring << "while_exit.L" << builder.getLabelIndex(); exitBlock->setName(labelstring.str()); labelstring.str(""); function->addBasicBlock(exitBlock); builder.setPosition(exitBlock, exitBlock->end()); // CSE清除缓存 enterNewBasicBlock(); return std::any(); } std::any SysYIRGenerator::visitBreakStmt(SysYParser::BreakStmtContext *ctx) { BasicBlock* breakBlock = builder.getBreakBlock(); builder.createUncondBrInst(breakBlock); BasicBlock::conectBlocks(builder.getBasicBlock(), breakBlock); return std::any(); } std::any SysYIRGenerator::visitContinueStmt(SysYParser::ContinueStmtContext *ctx) { BasicBlock* continueBlock = builder.getContinueBlock(); builder.createUncondBrInst(continueBlock); BasicBlock::conectBlocks(builder.getBasicBlock(), continueBlock); return std::any(); } std::any SysYIRGenerator::visitReturnStmt(SysYParser::ReturnStmtContext *ctx) { Value* returnValue = nullptr; Type* funcType = builder.getBasicBlock()->getParent()->getReturnType(); if (ctx->exp() != nullptr) { returnValue = computeExp(ctx->exp(), funcType); } // TODOL 考虑删除类型转换判断逻辑 if (returnValue != nullptr && funcType!= returnValue->getType()) { ConstantValue * constValue = dynamic_cast(returnValue); if (constValue != nullptr) { if (funcType == Type::getFloatType()) { if(dynamic_cast(constValue)) { // 如果是整型常量,转换为浮点型 returnValue = ConstantFloating::get(static_cast(constValue->getInt())); } else if (dynamic_cast(constValue)) { // 如果是浮点型常量,直接使用 returnValue = ConstantFloating::get(static_cast(constValue->getInt())); } } else { if(dynamic_cast(constValue)) { // 如果是浮点型常量,转换为整型 returnValue = ConstantInteger::get(static_cast(constValue->getFloat())); } else if (dynamic_cast(constValue)) { // 如果是整型常量,直接使用 returnValue = ConstantInteger::get(static_cast(constValue->getFloat())); } } } else { if (funcType == Type::getFloatType()) { returnValue = builder.createItoFInst(returnValue); } else { returnValue = builder.createFtoIInst(returnValue); } } } builder.createReturnInst(returnValue); return std::any(); } // 辅助函数:计算给定类型中嵌套的数组维度数量 // 例如: // - 对于 i32* 类型,它指向 i32,维度为 0。 // - 对于 [10 x i32]* 类型,它指向 [10 x i32],维度为 1。 // - 对于 [20 x [10 x i32]]* 类型,它指向 [20 x [10 x i32]],维度为 2。 unsigned SysYIRGenerator::countArrayDimensions(Type* type) { unsigned dims = 0; Type* currentType = type; // 如果是指针类型,先获取它指向的基础类型 if (currentType->isPointer()) { currentType = currentType->as()->getBaseType(); } // 递归地计算数组的维度层数 while (currentType && currentType->isArray()) { dims++; currentType = currentType->as()->getElementType(); } return dims; } std::any SysYIRGenerator::visitLValue(SysYParser::LValueContext *ctx) { std::string name = ctx->Ident()->getText(); Value* variable = module->getVariable(name); Value* value = nullptr; std::vector dims; for (const auto &exp : ctx->exp()) { Value* expValue = std::any_cast(computeExp(exp)); dims.push_back(expValue); } // 1. 获取变量的声明维度数量 unsigned declaredNumDims = countArrayDimensions(variable->getType()); // 2. 处理常量变量 (ConstantVariable) 且所有索引都是常量的情况 ConstantVariable* constVar = dynamic_cast(variable); if (constVar != nullptr) { bool allIndicesConstant = true; for (const auto &dim : dims) { if (dynamic_cast(dim) == nullptr) { allIndicesConstant = false; break; } } // 如果是常量变量且所有索引都是常量,并且不是数组名单独出现的情况 if (allIndicesConstant && !dims.empty()) { // 如果是常量变量且所有索引都是常量,直接通过 getByIndices 获取编译时值 // 这个方法会根据索引深度返回最终的标量值或指向子数组的指针 (作为 ConstantValue/Variable) return constVar->getByIndices(dims); } // 如果dims为空,检查是否是常量标量 if (dims.empty() && declaredNumDims == 0) { // 常量标量,直接返回其值 // 默认传入空索引列表,表示访问标量本身 return constVar->getByIndices(dims); } // 如果dims为空但不是标量(数组名单独出现),需要走GEP路径来实现数组到指针的退化 } // 3. 处理可变变量 (AllocaInst/GlobalValue) 或带非常量索引的常量变量 // 这里区分标量访问和数组元素/子数组访问 Value *targetAddress = nullptr; // 检查是否是访问标量变量本身(没有索引,且声明维度为0) if (dims.empty() && declaredNumDims == 0) { if (dynamic_cast(variable) || dynamic_cast(variable)) { targetAddress = variable; } else { assert(false && "Unhandled scalar variable type in LValue access."); return static_cast(nullptr); } } else { Value* gepBasePointer = nullptr; std::vector gepIndices; if (AllocaInst *alloc = dynamic_cast(variable)) { Type* allocatedType = alloc->getType()->as()->getBaseType(); if (allocatedType->isPointer()) { gepBasePointer = builder.createLoadInst(alloc); gepIndices = dims; } else { gepBasePointer = alloc; gepIndices.push_back(ConstantInteger::get(0)); if (dims.empty() && declaredNumDims > 0) { // 数组名单独出现(没有索引):在SysY中,多维数组名应该退化为指向第一行的指针 // 对于二维数组 T[M][N],退化为 T(*)[N],需要GEP: getelementptr T[M][N], T[M][N]* ptr, i32 0, i32 0 // 第一个i32 0: 选择数组本身,第二个i32 0: 选择第0行 // 结果类型: T[N]* gepIndices.push_back(ConstantInteger::get(0)); } else { // 正常的数组元素访问 gepIndices.insert(gepIndices.end(), dims.begin(), dims.end()); } } } else if (GlobalValue *glob = dynamic_cast(variable)) { gepBasePointer = glob; gepIndices.push_back(ConstantInteger::get(0)); if (dims.empty() && declaredNumDims > 0) { // 全局数组名单独出现(没有索引):应该退化为指向第一行的指针 // 需要添加一个额外的i32 0索引 gepIndices.push_back(ConstantInteger::get(0)); } else { // 正常的数组元素访问 gepIndices.insert(gepIndices.end(), dims.begin(), dims.end()); } } else if (ConstantVariable *constV = dynamic_cast(variable)) { gepBasePointer = constV; gepIndices.push_back(ConstantInteger::get(0)); if (dims.empty() && declaredNumDims > 0) { // 常量数组名单独出现(没有索引):应该退化为指向第一行的指针 // 需要添加一个额外的i32 0索引 gepIndices.push_back(ConstantInteger::get(0)); } else { // 正常的数组元素访问 gepIndices.insert(gepIndices.end(), dims.begin(), dims.end()); } } else { assert(false && "LValue variable type not supported for GEP base pointer."); return static_cast(nullptr); } targetAddress = getGEPAddressInst(gepBasePointer, gepIndices); } // 如果提供的索引数量少于声明的维度数量,则表示访问的是子数组,返回其地址 (无需加载) if (dims.size() < declaredNumDims) { value = targetAddress; } else { // value = builder.createLoadInst(targetAddress); auto it = availableLoads.find(targetAddress); if (it != availableLoads.end()) { value = it->second; // 缓存命中,重用已有的 LoadInst 结果 } else { // 缓存未命中,创建新的 LoadInst value = builder.createLoadInst(targetAddress); availableLoads[targetAddress] = value; // 将新的 LoadInst 结果加入缓存 } } return value; } std::any SysYIRGenerator::visitPrimaryExp(SysYParser::PrimaryExpContext *ctx) { if (ctx->exp() != nullptr) { BinaryExpStack.push_back(BinaryOp::LPAREN);BinaryExpLenStack.back()++; visitExp(ctx->exp()); BinaryExpStack.push_back(BinaryOp::RPAREN);BinaryExpLenStack.back()++; } if (ctx->lValue() != nullptr) { // 如果是 lValue,将value压入栈中 BinaryExpStack.push_back(std::any_cast(visitLValue(ctx->lValue())));BinaryExpLenStack.back()++; } if (ctx->number() != nullptr) { BinaryExpStack.push_back(std::any_cast(visitNumber(ctx->number())));BinaryExpLenStack.back()++; } if (ctx->string() != nullptr) { cout << "String literal not supported in SysYIRGenerator." << endl; } return std::any(); } std::any SysYIRGenerator::visitNumber(SysYParser::NumberContext *ctx) { if (ctx->ILITERAL() != nullptr) { int value = std::stol(ctx->ILITERAL()->getText(), nullptr, 0); return static_cast(ConstantInteger::get(value)); } else if (ctx->FLITERAL() != nullptr) { float value = std::stof(ctx->FLITERAL()->getText()); return static_cast(ConstantFloating::get(value)); } throw std::runtime_error("Unknown number type."); return std::any(); // 不会到达这里 } std::any SysYIRGenerator::visitCall(SysYParser::CallContext *ctx) { std::string funcName = ctx->Ident()->getText(); Function *function = module->getFunction(funcName); if (function == nullptr) { function = module->getExternalFunction(funcName); if (function == nullptr) { std::cout << "The function " << funcName << " no defined." << std::endl; assert(function); } } std::vector args = {}; if (funcName == "starttime" || funcName == "stoptime") { args.emplace_back( ConstantInteger::get(static_cast(ctx->getStart()->getLine()))); } else { if (ctx->funcRParams() != nullptr) { args = std::any_cast>(visitFuncRParams(ctx->funcRParams())); } // 获取形参列表。`getArguments()` 返回的是 `Argument*` 的集合, // 每个 `Argument` 代表一个函数形参,其 `getType()` 就是指向形参的类型的指针类型。 const auto& formalParams = function->getArguments(); // 检查实参和形参数量是否匹配。 if (args.size() != function->getNumArguments()) { std::cerr << "Error: Function call argument count mismatch for function '" << funcName << "'." << std::endl; assert(false && "Function call argument count mismatch!"); } for (int i = 0; i < args.size(); i++) { // 形参的类型 (e.g., i32, float, i32*, [10 x i32]*) Type* formalParamExpectedValueType = formalParams[i]->getType(); // 实参的实际类型 (e.g., i32, float, i32*, [67 x i32]*) Type* actualArgType = args[i]->getType(); // 如果实参类型与形参类型不匹配,则尝试进行类型转换 if (formalParamExpectedValueType != actualArgType) { ConstantValue *constValue = dynamic_cast(args[i]); if (constValue != nullptr) { if (formalParamExpectedValueType->isInt() && actualArgType->isFloat()) { args[i] = ConstantInteger::get(static_cast(constValue->getFloat())); } else if (formalParamExpectedValueType->isFloat() && actualArgType->isInt()) { args[i] = ConstantFloating::get(static_cast(constValue->getInt())); } else { // 如果是常量但不是简单的 int/float 标量转换, // 或者是指针常量需要 bitcast,则让它进入非常量转换逻辑。 // 例如,一个常量数组的地址,需要 bitcast 成另一种指针类型。 // 目前不知道样例有没有这种情况,所以这里不做处理。 } } else { // 1. 标量值类型转换 (例如:int_reg 到 float_reg,float_reg 到 int_reg) if (formalParamExpectedValueType->isInt() && actualArgType->isFloat()) { args[i] = builder.createFtoIInst(args[i]); } else if (formalParamExpectedValueType->isFloat() && actualArgType->isInt()) { args[i] = builder.createItoFInst(args[i]); } // 2. 指针类型转换 (例如数组退化:`[N x T]*` 到 `T*`,或兼容指针类型之间) // 这种情况常见于数组参数,实参可能是一个更具体的数组指针类型, // 而形参是其退化后的基础指针类型。 else if (formalParamExpectedValueType->isPointer() && actualArgType->isPointer()) { // 检查是否是数组指针到元素指针的decay // 例如:[N x T]* -> T* auto formalPtrType = formalParamExpectedValueType->as(); auto actualPtrType = actualArgType->as(); if (formalPtrType && actualPtrType && actualPtrType->getBaseType()->isArray()) { auto actualArrayType = actualPtrType->getBaseType()->as(); if (actualArrayType && formalPtrType->getBaseType() == actualArrayType->getElementType()) { // 这是数组decay的情况,添加GEP来获取数组的第一个元素 std::vector indices; indices.push_back(ConstantInteger::get(0)); // 第一个索引:解引用指针 indices.push_back(ConstantInteger::get(0)); // 第二个索引:获取数组第一个元素 args[i] = getGEPAddressInst(args[i], indices); } } } // 3. 其他未预期的类型不匹配 // 如果代码执行到这里,说明存在编译器前端未处理的类型不兼容或错误。 else { // assert(false && "Unhandled type mismatch for function call argument."); } } } } } return static_cast(builder.createCallInst(function, args)); } std::any SysYIRGenerator::visitUnaryExp(SysYParser::UnaryExpContext *ctx) { if (ctx->primaryExp() != nullptr) { visitPrimaryExp(ctx->primaryExp()); } else if (ctx->call() != nullptr) { BinaryExpStack.push_back(std::any_cast(visitCall(ctx->call())));BinaryExpLenStack.back()++; invalidateExpressionsOnCall(); } else if (ctx->unaryOp() != nullptr) { // 遇到一元操作符,将其压入 BinaryExpStack auto opNode = dynamic_cast(ctx->unaryOp()->children[0]); int opType = opNode->getSymbol()->getType(); switch(opType) { case SysYParser::ADD: BinaryExpStack.push_back(BinaryOp::PLUS); BinaryExpLenStack.back()++; break; case SysYParser::SUB: BinaryExpStack.push_back(BinaryOp::NEG); BinaryExpLenStack.back()++; break; case SysYParser::NOT: BinaryExpStack.push_back(BinaryOp::NOT); BinaryExpLenStack.back()++; break; default: assert(false && "Unexpected operator in UnaryExp."); } visitUnaryExp(ctx->unaryExp()); } return std::any(); } std::any SysYIRGenerator::visitFuncRParams(SysYParser::FuncRParamsContext *ctx) { std::vector params; for (const auto &exp : ctx->exp()) { auto param = std::any_cast(computeExp(exp)); params.push_back(param); } return params; } std::any SysYIRGenerator::visitMulExp(SysYParser::MulExpContext *ctx) { visitUnaryExp(ctx->unaryExp(0)); for (int i = 1; i < ctx->unaryExp().size(); i++) { auto opNode = dynamic_cast(ctx->children[2*i-1]); int opType = opNode->getSymbol()->getType(); switch(opType) { case SysYParser::MUL: BinaryExpStack.push_back(BinaryOp::MUL); BinaryExpLenStack.back()++; break; case SysYParser::DIV: BinaryExpStack.push_back(BinaryOp::DIV); BinaryExpLenStack.back()++; break; case SysYParser::MOD: BinaryExpStack.push_back(BinaryOp::MOD); BinaryExpLenStack.back()++; break; default: assert(false && "Unexpected operator in MulExp."); } visitUnaryExp(ctx->unaryExp(i)); } return std::any(); } std::any SysYIRGenerator::visitAddExp(SysYParser::AddExpContext *ctx) { visitMulExp(ctx->mulExp(0)); for (int i = 1; i < ctx->mulExp().size(); i++) { auto opNode = dynamic_cast(ctx->children[2*i-1]); int opType = opNode->getSymbol()->getType(); switch(opType) { case SysYParser::ADD: BinaryExpStack.push_back(BinaryOp::ADD); BinaryExpLenStack.back()++; break; case SysYParser::SUB: BinaryExpStack.push_back(BinaryOp::SUB); BinaryExpLenStack.back()++; break; default: assert(false && "Unexpected operator in AddExp."); } visitMulExp(ctx->mulExp(i)); } return std::any(); } std::any SysYIRGenerator::visitRelExp(SysYParser::RelExpContext *ctx) { Value* result = computeAddExp(ctx->addExp(0)); for (int i = 1; i < ctx->addExp().size(); i++) { auto opNode = dynamic_cast(ctx->children[2*i-1]); int opType = opNode->getSymbol()->getType(); Value* operand = computeAddExp(ctx->addExp(i)); Type* resultType = result->getType(); Type* operandType = operand->getType(); ConstantValue* constResult = dynamic_cast(result); ConstantValue* constOperand = dynamic_cast(operand); // 常量比较 if ((constResult != nullptr) && (constOperand != nullptr)) { auto operand1 = constResult->isFloat() ? constResult->getFloat() : constResult->getInt(); auto operand2 = constOperand->isFloat() ? constOperand->getFloat() : constOperand->getInt(); if (opType == SysYParser::LT) result = ConstantInteger::get(operand1 < operand2 ? 1 : 0); else if (opType == SysYParser::GT) result = ConstantInteger::get(operand1 > operand2 ? 1 : 0); else if (opType == SysYParser::LE) result = ConstantInteger::get(operand1 <= operand2 ? 1 : 0); else if (opType == SysYParser::GE) result = ConstantInteger::get(operand1 >= operand2 ? 1 : 0); else assert(false); } else { Type* resultType = result->getType(); Type* operandType = operand->getType(); Type* floatType = Type::getFloatType(); // 浮点数处理 if (resultType == floatType || operandType == floatType) { if (resultType != floatType) { if (constResult != nullptr){ if(dynamic_cast(constResult)) { // 如果是整型常量,转换为浮点型 result = ConstantFloating::get(static_cast(constResult->getInt())); } else if (dynamic_cast(constResult)) { // 如果是浮点型常量,直接使用 result = ConstantFloating::get(static_cast(constResult->getFloat())); } } else result = builder.createItoFInst(result); } if (operandType != floatType) { if (constOperand != nullptr) { if(dynamic_cast(constOperand)) { // 如果是整型常量,转换为浮点型 operand = ConstantFloating::get(static_cast(constOperand->getInt())); } else if (dynamic_cast(constOperand)) { // 如果是浮点型常量,直接使用 operand = ConstantFloating::get(static_cast(constOperand->getFloat())); } } else operand = builder.createItoFInst(operand); } if (opType == SysYParser::LT) result = builder.createFCmpLTInst(result, operand); else if (opType == SysYParser::GT) result = builder.createFCmpGTInst(result, operand); else if (opType == SysYParser::LE) result = builder.createFCmpLEInst(result, operand); else if (opType == SysYParser::GE) result = builder.createFCmpGEInst(result, operand); else assert(false); } else { // 整数处理 if (opType == SysYParser::LT) result = builder.createICmpLTInst(result, operand); else if (opType == SysYParser::GT) result = builder.createICmpGTInst(result, operand); else if (opType == SysYParser::LE) result = builder.createICmpLEInst(result, operand); else if (opType == SysYParser::GE) result = builder.createICmpGEInst(result, operand); else assert(false); } } } return result; } std::any SysYIRGenerator::visitEqExp(SysYParser::EqExpContext *ctx) { // TODO:其实已经保证了result是一个int类型的值可以删除冗余判断逻辑 Value * result = std::any_cast(visitRelExp(ctx->relExp(0))); for (int i = 1; i < ctx->relExp().size(); i++) { auto opNode = dynamic_cast(ctx->children[2*i-1]); int opType = opNode->getSymbol()->getType(); Value * operand = std::any_cast(visitRelExp(ctx->relExp(i))); ConstantValue* constResult = dynamic_cast(result); ConstantValue* constOperand = dynamic_cast(operand); if ((constResult != nullptr) && (constOperand != nullptr)) { auto operand1 = constResult->isFloat() ? constResult->getFloat() : constResult->getInt(); auto operand2 = constOperand->isFloat() ? constOperand->getFloat() : constOperand->getInt(); if (opType == SysYParser::EQ) result = ConstantInteger::get(operand1 == operand2 ? 1 : 0); else if (opType == SysYParser::NE) result = ConstantInteger::get(operand1 != operand2 ? 1 : 0); else assert(false); } else { Type* resultType = result->getType(); Type* operandType = operand->getType(); Type* floatType = Type::getFloatType(); if (resultType == floatType || operandType == floatType) { if (resultType != floatType) { if (constResult != nullptr){ if(dynamic_cast(constResult)) { // 如果是整型常量,转换为浮点型 result = ConstantFloating::get(static_cast(constResult->getInt())); } else if (dynamic_cast(constResult)) { // 如果是浮点型常量,直接使用 result = ConstantFloating::get(static_cast(constResult->getFloat())); } } else result = builder.createItoFInst(result); } if (operandType != floatType) { if (constOperand != nullptr) { if(dynamic_cast(constOperand)) { // 如果是整型常量,转换为浮点型 operand = ConstantFloating::get(static_cast(constOperand->getInt())); } else if (dynamic_cast(constOperand)) { // 如果是浮点型常量,直接使用 operand = ConstantFloating::get(static_cast(constOperand->getFloat())); } } else operand = builder.createItoFInst(operand); } if (opType == SysYParser::EQ) result = builder.createFCmpEQInst(result, operand); else if (opType == SysYParser::NE) result = builder.createFCmpNEInst(result, operand); else assert(false); } else { if (opType == SysYParser::EQ) result = builder.createICmpEQInst(result, operand); else if (opType == SysYParser::NE) result = builder.createICmpNEInst(result, operand); else assert(false); } } } if (ctx->relExp().size() == 1) { ConstantValue * constResult = dynamic_cast(result); // 如果只有一个关系表达式,则将结果转换为0或1 if (constResult != nullptr) { if (constResult->isFloat()) result = ConstantInteger::get(constResult->getFloat() != 0.0F ? 1 : 0); else result = ConstantInteger::get(constResult->getInt() != 0 ? 1 : 0); } } return result; } std::any SysYIRGenerator::visitLAndExp(SysYParser::LAndExpContext *ctx){ std::stringstream labelstring; BasicBlock *curBlock = builder.getBasicBlock(); Function *function = builder.getBasicBlock()->getParent(); BasicBlock *trueBlock = builder.getTrueBlock(); BasicBlock *falseBlock = builder.getFalseBlock(); auto conds = ctx->eqExp(); for (int i = 0; i < conds.size() - 1; i++) { labelstring << "AND.L" << builder.getLabelIndex(); BasicBlock *newtrueBlock = function->addBasicBlock(labelstring.str()); labelstring.str(""); auto cond = std::any_cast(visitEqExp(ctx->eqExp(i))); builder.createCondBrInst(cond, newtrueBlock, falseBlock); BasicBlock::conectBlocks(curBlock, newtrueBlock); BasicBlock::conectBlocks(curBlock, falseBlock); curBlock = newtrueBlock; builder.setPosition(curBlock, curBlock->end()); } auto cond = std::any_cast(visitEqExp(conds.back())); builder.createCondBrInst(cond, trueBlock, falseBlock); BasicBlock::conectBlocks(curBlock, trueBlock); BasicBlock::conectBlocks(curBlock, falseBlock); return std::any(); } auto SysYIRGenerator::visitLOrExp(SysYParser::LOrExpContext *ctx) -> std::any { std::stringstream labelstring; BasicBlock *curBlock = builder.getBasicBlock(); Function *function = curBlock->getParent(); auto conds = ctx->lAndExp(); for (int i = 0; i < conds.size() - 1; i++) { labelstring << "OR.L" << builder.getLabelIndex(); BasicBlock *newFalseBlock = function->addBasicBlock(labelstring.str()); labelstring.str(""); builder.pushFalseBlock(newFalseBlock); visitLAndExp(ctx->lAndExp(i)); builder.popFalseBlock(); builder.setPosition(newFalseBlock, newFalseBlock->end()); } visitLAndExp(conds.back()); return std::any(); } // attention : 这里的type是数组元素的type void Utils::tree2Array(Type *type, ArrayValueTree *root, const std::vector &dims, unsigned numDims, ValueCounter &result, IRBuilder *builder) { Value* value = root->getValue(); auto &children = root->getChildren(); // 类型转换 if (value != nullptr) { if (type == value->getType()) { result.push_back(value); } else { if (type == Type::getFloatType()) { ConstantValue* constValue = dynamic_cast(value); if (constValue != nullptr) { if(dynamic_cast(constValue)) result.push_back(ConstantFloating::get(static_cast(constValue->getInt()))); else if (dynamic_cast(constValue)) result.push_back(ConstantFloating::get(static_cast(constValue->getFloat()))); else assert(false && "Unknown constant type for float conversion."); } else result.push_back(builder->createItoFInst(value)); } else { ConstantValue* constValue = dynamic_cast(value); if (constValue != nullptr){ if(dynamic_cast(constValue)) result.push_back(ConstantInteger::get(constValue->getInt())); else if (dynamic_cast(constValue)) result.push_back(ConstantInteger::get(static_cast(constValue->getFloat()))); else assert(false && "Unknown constant type for int conversion."); } else result.push_back(builder->createFtoIInst(value)); } } return; } auto beforeSize = result.size(); for (const auto &child : children) { int begin = result.size(); int newNumDims = 0; for (unsigned i = 0; i < numDims - 1; i++) { auto dim = dynamic_cast(*(dims.rbegin() + i))->getInt(); if (begin % dim == 0) { newNumDims += 1; begin /= dim; } else { break; } } tree2Array(type, child.get(), dims, newNumDims, result, builder); } auto afterSize = result.size(); int blockSize = 1; for (unsigned i = 0; i < numDims; i++) { blockSize *= dynamic_cast(*(dims.rbegin() + i))->getInt(); } int num = blockSize - afterSize + beforeSize; if (num > 0) { if (type == Type::getFloatType()) result.push_back(ConstantFloating::get(0.0F), num); else result.push_back(ConstantInteger::get(0), num); } } void Utils::createExternalFunction( const std::vector ¶mTypes, const std::vector ¶mNames, const std::vector> ¶mDims, Type *returnType, const std::string &funcName, Module *pModule, IRBuilder *pBuilder) { // 根据paramDims调整参数类型,数组参数需要转换为指针类型 std::vector adjustedParamTypes = paramTypes; for (int i = 0; i < paramTypes.size() && i < paramDims.size(); ++i) { if (!paramDims[i].empty()) { // 如果参数有维度信息,说明是数组参数,转换为指针类型 adjustedParamTypes[i] = Type::getPointerType(paramTypes[i]); } } auto funcType = Type::getFunctionType(returnType, adjustedParamTypes); auto function = pModule->createExternalFunction(funcName, funcType); auto entry = function->getEntryBlock(); pBuilder->setPosition(entry, entry->end()); for (int i = 0; i < paramTypes.size(); ++i) { auto arg = new Argument(adjustedParamTypes[i], function, i, paramNames[i]); auto alloca = pBuilder->createAllocaInst( Type::getPointerType(adjustedParamTypes[i]), paramNames[i]); function->insertArgument(arg); auto store = pBuilder->createStoreInst(arg, alloca); pModule->addVariable(paramNames[i], alloca); } } void Utils::initExternalFunction(Module *pModule, IRBuilder *pBuilder) { std::vector paramTypes; std::vector paramNames; std::vector> paramDims; Type *returnType; std::string funcName; returnType = Type::getIntType(); funcName = "getint"; Utils::createExternalFunction(paramTypes, paramNames, paramDims, returnType, funcName, pModule, pBuilder); funcName = "getch"; Utils::createExternalFunction(paramTypes, paramNames, paramDims, returnType, funcName, pModule, pBuilder); paramTypes.push_back(Type::getIntType()); paramNames.emplace_back("x"); paramDims.push_back(std::vector{ConstantInteger::get(-1)}); funcName = "getarray"; Utils::createExternalFunction(paramTypes, paramNames, paramDims, returnType, funcName, pModule, pBuilder); returnType = Type::getFloatType(); paramTypes.clear(); paramNames.clear(); paramDims.clear(); funcName = "getfloat"; Utils::createExternalFunction(paramTypes, paramNames, paramDims, returnType, funcName, pModule, pBuilder); returnType = Type::getIntType(); paramTypes.push_back(Type::getFloatType()); paramNames.emplace_back("x"); paramDims.push_back(std::vector{ConstantInteger::get(-1)}); funcName = "getfarray"; Utils::createExternalFunction(paramTypes, paramNames, paramDims, returnType, funcName, pModule, pBuilder); returnType = Type::getVoidType(); paramTypes.clear(); paramTypes.push_back(Type::getIntType()); paramDims.clear(); paramDims.emplace_back(); funcName = "putint"; Utils::createExternalFunction(paramTypes, paramNames, paramDims, returnType, funcName, pModule, pBuilder); funcName = "putch"; Utils::createExternalFunction(paramTypes, paramNames, paramDims, returnType, funcName, pModule, pBuilder); paramTypes.clear(); paramTypes.push_back(Type::getIntType()); paramTypes.push_back(Type::getIntType()); paramDims.clear(); paramDims.emplace_back(); paramDims.push_back(std::vector{ConstantInteger::get(-1)}); paramNames.clear(); paramNames.emplace_back("n"); paramNames.emplace_back("a"); funcName = "putarray"; Utils::createExternalFunction(paramTypes, paramNames, paramDims, returnType, funcName, pModule, pBuilder); paramTypes.clear(); paramTypes.push_back(Type::getFloatType()); paramDims.clear(); paramDims.emplace_back(); paramNames.clear(); paramNames.emplace_back("a"); funcName = "putfloat"; Utils::createExternalFunction(paramTypes, paramNames, paramDims, returnType, funcName, pModule, pBuilder); paramTypes.clear(); paramTypes.push_back(Type::getIntType()); paramTypes.push_back(Type::getFloatType()); paramDims.clear(); paramDims.emplace_back(); paramDims.push_back(std::vector{ConstantInteger::get(-1)}); paramNames.clear(); paramNames.emplace_back("n"); paramNames.emplace_back("a"); funcName = "putfarray"; Utils::createExternalFunction(paramTypes, paramNames, paramDims, returnType, funcName, pModule, pBuilder); paramTypes.clear(); paramTypes.push_back(Type::getIntType()); paramDims.clear(); paramDims.emplace_back(); paramNames.clear(); paramNames.emplace_back("__LINE__"); funcName = "starttime"; Utils::createExternalFunction(paramTypes, paramNames, paramDims, returnType, funcName, pModule, pBuilder); paramTypes.clear(); paramTypes.push_back(Type::getIntType()); paramDims.clear(); paramDims.emplace_back(); paramNames.clear(); paramNames.emplace_back("__LINE__"); funcName = "stoptime"; Utils::createExternalFunction(paramTypes, paramNames, paramDims, returnType, funcName, pModule, pBuilder); } void Utils::modify_timefuncname(Module *pModule){ auto starttimeFunc = pModule->getExternalFunction("starttime"); auto stoptimeFunc = pModule->getExternalFunction("stoptime"); starttimeFunc->setName("_sysy_starttime"); stoptimeFunc->setName("_sysy_stoptime"); } } // namespace sysy