#include "IR.h"

#include <algorithm>
#include <cassert>
#include <cstddef>
#include <iterator>
#include <map>
#include <memory>
#include <queue>
#include <set>
#include <sstream>
#include <string>
#include <unordered_map>
#include <variant>
#include <vector>

#include "IRBuilder.h"

/**
 * @file IR.cpp
 *
 * @brief Source file defining the IR types and the operations on them.
 */

namespace sysy {

//===----------------------------------------------------------------------===//
// Types
//===----------------------------------------------------------------------===//

/** @return the single shared Type object of kind kInt. */
auto Type::getIntType() -> Type * {
  static Type intType(kInt);
  return &intType;
}

/** @return the single shared Type object of kind kFloat. */
auto Type::getFloatType() -> Type * {
  static Type floatType(kFloat);
  return &floatType;
}

/** @return the single shared Type object of kind kVoid. */
auto Type::getVoidType() -> Type * {
  static Type voidType(kVoid);
  return &voidType;
}

/** @return the single shared Type object of kind kLabel. */
auto Type::getLabelType() -> Type * {
  static Type labelType(kLabel);
  return &labelType;
}

/** Convenience forwarder to PointerType::get. */
auto Type::getPointerType(Type *baseType) -> Type * {
  return PointerType::get(baseType);
}

/** Convenience forwarder to FunctionType::get. */
auto Type::getFunctionType(Type *returnType, const std::vector<Type *> &paramTypes) -> Type * {
  return FunctionType::get(returnType, paramTypes);
}

/** Convenience forwarder to ArrayType::get. */
auto Type::getArrayType(Type *elementType, unsigned numElements) -> Type * {
  return ArrayType::get(elementType, numElements);
}

/**
 * Size of a value of this type.
 *
 * @return 4 for int/float, 8 for label/pointer/function, element size times
 *         element count for arrays, and 0 for void (or an unknown kind).
 */
auto Type::getSize() const -> unsigned {
  switch (kind) {
  case kInt:
  case kFloat:
    return 4;
  case kLabel:
  case kPointer:
  case kFunction:
    return 8;
  case Kind::kArray: {
    const auto *arrType = static_cast<const ArrayType *>(this);
    return arrType->getElementType()->getSize() * arrType->getNumElements();
  }
  case kVoid:
    return 0;
  }
  return 0;
}

/**
 * Returns the pooled pointer type whose pointee is @p baseType, creating it on
 * first request. The pool is a function-local static that owns every instance.
 */
PointerType *PointerType::get(Type *baseType) {
  static std::map<Type *, std::unique_ptr<PointerType>> pointerTypes;
  auto &slot = pointerTypes[baseType];
  if (!slot) {
    slot.reset(new PointerType(baseType));
  }
  return slot.get();
}

/**
 * Returns the pooled function type with the given return type and parameter
 * types, creating it on first request. Lookup is a linear scan of the pool.
 */
FunctionType *FunctionType::get(Type *returnType, const std::vector<Type *> &paramTypes) {
  static std::set<std::unique_ptr<FunctionType>> functionTypes;

  const auto sameSignature = [&](const std::unique_ptr<FunctionType> &candidate) -> bool {
    if (returnType != candidate->getReturnType() ||
        paramTypes.size() != static_cast<std::size_t>(candidate->getParamTypes().size())) {
      return false;
    }
    return std::equal(paramTypes.begin(), paramTypes.end(), candidate->getParamTypes().begin());
  };

  auto iter = std::find_if(functionTypes.begin(), functionTypes.end(), sameSignature);
  if (iter != functionTypes.end()) {
    return iter->get();
  }

  // Own the new object before touching the container so it cannot leak.
  std::unique_ptr<FunctionType> created(new FunctionType(returnType, paramTypes));
  FunctionType *result = created.get();
  functionTypes.insert(std::move(created));
  return result;
}

/**
 * Returns the pooled array type with the given element type and length,
 * creating it on first request. Lookup is a linear scan of the pool.
 */
ArrayType *ArrayType::get(Type *elementType, unsigned numElements) {
  static std::set<std::unique_ptr<ArrayType>> arrayTypes;

  const auto sameShape = [&](const std::unique_ptr<ArrayType> &candidate) -> bool {
    return elementType == candidate->getElementType() && numElements == candidate->getNumElements();
  };

  auto iter = std::find_if(arrayTypes.begin(), arrayTypes.end(), sameShape);
  if (iter != arrayTypes.end()) {
    return iter->get();
  }

  // Own the new object before touching the container so it cannot leak.
  std::unique_ptr<ArrayType> created(new ArrayType(elementType, numElements));
  ArrayType *result = created.get();
  arrayTypes.insert(std::move(created));
  return result;
}

//===----------------------------------------------------------------------===//
// Values
//===----------------------------------------------------------------------===//

/**
 * Redirects every recorded use of this value to @p value and then empties this
 * value's use list.
 *
 * Replacing a value with itself is a no-op: setOperand() appends to the
 * replacement's use list, which would otherwise be the very list being
 * iterated here.
 */
void Value::replaceAllUsesWith(Value *value) {
  if (value == this) {
    return;
  }
  for (auto &use : uses) {
    use->getUser()->setOperand(use->getIndex(), value);
  }
  uses.clear();
}

// Definitions of the static pools. decltype keeps each definition in lock-step
// with the container type chosen by its declaration in IR.h.
decltype(ConstantValue::mConstantPool) ConstantValue::mConstantPool;
decltype(UndefinedValue::UndefValues) UndefinedValue::UndefValues;

/**
 * Returns the pooled constant for (@p type, @p val), creating it on first
 * request.
 *
 * @return a ConstantInteger when @p val holds an int, a ConstantFloating when
 *         it holds a float. Any other alternative asserts; in builds without
 *         assertions it returns nullptr and nothing is added to the pool.
 */
ConstantValue *ConstantValue::get(Type *type, ConstantValVariant val) {
  ConstantValueKey key = {type, val};
  auto it = mConstantPool.find(key);
  if (it != mConstantPool.end()) {
    return it->second;
  }

  ConstantValue *newConstant = nullptr;
  if (std::holds_alternative<int>(val)) {
    newConstant = new ConstantInteger(type, std::get<int>(val));
  } else if (std::holds_alternative<float>(val)) {
    newConstant = new ConstantFloating(type, std::get<float>(val));
  } else {
    assert(false && "Unsupported ConstantValVariant type");
    return nullptr;  // never cache a null constant
  }

  mConstantPool.emplace(key, newConstant);
  return newConstant;
}

/** Typed wrapper around ConstantValue::get for integer constants. */
ConstantInteger *ConstantInteger::get(Type *type, int val) {
  return dynamic_cast<ConstantInteger *>(ConstantValue::get(type, val));
}

/** Typed wrapper around ConstantValue::get for floating-point constants. */
ConstantFloating *ConstantFloating::get(Type *type, float val) {
  return dynamic_cast<ConstantFloating *>(ConstantValue::get(type, val));
}

/**
 * Returns the pooled undefined value of @p type, creating it on first request.
 * Asserts that @p type is not void.
 */
UndefinedValue *UndefinedValue::get(Type *type) {
  assert(!type->isVoid() && "Cannot get UndefinedValue of void type!");
  auto it = UndefValues.find(type);
  if (it != UndefValues.end()) {
    return it->second;
  }
  UndefinedValue *newUndef = new UndefinedValue(type);
  UndefValues[type] = newUndef;
  return newUndef;
}

//===----------------------------------------------------------------------===//
// Function
//===----------------------------------------------------------------------===//

/**
 * @return the callees of this function, excluding this function itself and
 *         every callee whose name is listed in the parent's external
 *         functions.
 */
auto Function::getCalleesWithNoExternalAndSelf() -> std::set<Function *> {
  std::set<Function *> result;
  const auto &externalFunctions = parent->getExternalFunctions();
  for (auto callee : callees) {
    if (callee != this && externalFunctions.count(callee->getName()) == 0U) {
      result.insert(callee);
    }
  }
  return result;
}

/**
 * Clones this function; needed by later function-level optimizations
 * (inlining and the like).
 *
 * The clone gets one new block per original block (same names, same
 * predecessor/successor shape) and one new instruction per original
 * instruction. Names of alloca, binary, unary, call and load instructions get
 * @p suffix appended; the other cloned kinds keep their original name.
 * Globals, constants and other functions are shared with the original.
 *
 * @param suffix text appended to the names of cloned value-producing
 *               instructions.
 * @return the newly allocated function; the caller takes ownership.
 *
 * NOTE(review): the clone is created with the same name as the original; the
 *   suffix is not applied to the function name.
 * NOTE(review): nothing in this function maps the original arguments to new
 *   ones, so any operand or the final argument loop that looks an argument up
 *   with .at() throws std::out_of_range. Confirm how arguments are meant to be
 *   cloned.
 * NOTE(review): phi instructions are skipped and GEP instructions are not
 *   handled (TODO), so functions containing them cannot be cloned yet.
 * NOTE(review): an instruction whose operand is never created is re-queued
 *   forever by the worklist below.
 */
Function *Function::clone(const std::string &suffix) const {
  IRBuilder builder;
  auto newFunction = new Function(parent, type, name);

  const auto withSuffix = [&suffix](const std::string &base) { return base + suffix; };

  // ---- 1. Blocks: the entry block already exists, the rest are appended. ----
  std::map<BasicBlock *, BasicBlock *> oldNewBlockMap;
  newFunction->getEntryBlock()->setName(blocks.front()->getName());
  oldNewBlockMap.emplace(blocks.front().get(), newFunction->getEntryBlock());
  for (auto blockIter = std::next(blocks.begin()); blockIter != blocks.end(); ++blockIter) {
    auto newBlock = newFunction->addBasicBlock((*blockIter)->getName());
    oldNewBlockMap.emplace(blockIter->get(), newBlock);
  }

  // Mirror the CFG edges onto the new blocks.
  for (const auto &[oldBlock, newBlock] : oldNewBlockMap) {
    for (const auto &oldPred : oldBlock->getPredecessors()) {
      newBlock->addPredecessor(oldNewBlockMap.at(oldPred));
    }
    for (const auto &oldSucc : oldBlock->getSuccessors()) {
      newBlock->addSuccessor(oldNewBlockMap.at(oldSucc));
    }
  }

  // ---- 2. Bookkeeping for values. ----
  std::map<Value *, Value *> oldNewValueMap;  // original value -> value to use in the clone
  std::map<Value *, bool> isAddedToCreate;    // has the value ever been queued?
  std::map<Value *, bool> isCreated;          // has the value's clone been created?
  std::queue<Value *> toCreate;               // worklist of instructions to clone

  for (const auto &oldBlock : blocks) {
    for (const auto &inst : oldBlock->getInstructions()) {
      isAddedToCreate.emplace(inst.get(), false);
      isCreated.emplace(inst.get(), false);
    }
  }

  // ---- 3. Allocas are cloned up front, in order of first encounter. ----
  // TODO: infer the dims from the type.
  const auto cloneAllocaOnce = [&](Value *candidate) {
    if (oldNewValueMap.find(candidate) != oldNewValueMap.end()) {
      return;
    }
    auto oldAllocInst = dynamic_cast<AllocaInst *>(candidate);
    if (oldAllocInst == nullptr) {
      return;
    }
    auto newAllocInst = new AllocaInst(oldAllocInst->getType(),
                                       oldNewBlockMap.at(oldAllocInst->getParent()),
                                       withSuffix(oldAllocInst->getName()));
    oldNewValueMap.emplace(oldAllocInst, newAllocInst);
    isAddedToCreate[oldAllocInst] = true;
    isCreated[oldAllocInst] = true;
  };

  for (const auto &oldBlock : blocks) {
    for (const auto &inst : oldBlock->getInstructions()) {
      for (const auto &valueUse : inst->getOperands()) {
        cloneAllocaOnce(valueUse->getValue());
      }
      if (inst->getKind() == Instruction::kAlloca) {
        cloneAllocaOnce(inst.get());
      }
    }
  }

  // ---- 4. Globals, constants and functions are shared, not cloned. ----
  for (const auto &oldBlock : blocks) {
    for (const auto &inst : oldBlock->getInstructions()) {
      for (const auto &valueUse : inst->getOperands()) {
        auto value = valueUse->getValue();
        if (oldNewValueMap.find(value) != oldNewValueMap.end()) {
          continue;
        }
        auto functionValue = dynamic_cast<Function *>(value);
        const bool isShared = dynamic_cast<GlobalValue *>(value) != nullptr ||
                              dynamic_cast<ConstantVariable *>(value) != nullptr ||
                              dynamic_cast<ConstantValue *>(value) != nullptr ||
                              functionValue != nullptr;
        if (!isShared) {
          continue;
        }
        // A reference to this function itself is redirected to the clone.
        if (functionValue == this) {
          oldNewValueMap.emplace(value, newFunction);
        } else {
          oldNewValueMap.emplace(value, value);
        }
        isCreated.emplace(value, true);
        isAddedToCreate.emplace(value, true);
      }
    }
  }

  // An instruction can be cloned once every operand that is itself an
  // instruction has been cloned. Non-instruction operands (blocks, constants,
  // ...) never block cloning.
  const auto operandsReady = [&](Instruction *candidate) -> bool {
    for (const auto &use : candidate->getOperands()) {
      auto producer = dynamic_cast<Instruction *>(use->getValue());
      if (producer != nullptr && !isCreated.at(producer)) {
        return false;
      }
    }
    return true;
  };

  // ---- 5. Seed the worklist with every ready non-alloca instruction. ----
  for (const auto &oldBlock : blocks) {
    for (const auto &inst : oldBlock->getInstructions()) {
      if (inst->getKind() != Instruction::kAlloca && operandsReady(inst.get())) {
        toCreate.push(inst.get());
        isAddedToCreate.at(inst.get()) = true;
      }
    }
  }

  // ---- 6. Clone instructions in dependency order. ----
  while (!toCreate.empty()) {
    auto inst = dynamic_cast<Instruction *>(toCreate.front());
    toCreate.pop();

    if (!operandsReady(inst)) {
      toCreate.push(inst);  // retry after its operands have been cloned
      continue;
    }
    isCreated.at(inst) = true;

    switch (inst->getKind()) {
    case Instruction::kAdd:
    case Instruction::kSub:
    case Instruction::kMul:
    case Instruction::kDiv:
    case Instruction::kRem:
    case Instruction::kICmpEQ:
    case Instruction::kICmpNE:
    case Instruction::kICmpLT:
    case Instruction::kICmpGT:
    case Instruction::kICmpLE:
    case Instruction::kICmpGE:
    case Instruction::kAnd:
    case Instruction::kOr:
    case Instruction::kFAdd:
    case Instruction::kFSub:
    case Instruction::kFMul:
    case Instruction::kFDiv:
    case Instruction::kFCmpEQ:
    case Instruction::kFCmpNE:
    case Instruction::kFCmpLT:
    case Instruction::kFCmpGT:
    case Instruction::kFCmpLE:
    case Instruction::kFCmpGE: {
      auto oldBinaryInst = dynamic_cast<BinaryInst *>(inst);
      // .at() rather than operator[]: an unmapped operand must not silently
      // become a null operand of the new instruction.
      Value *newLhs = oldNewValueMap.at(oldBinaryInst->getLhs());
      Value *newRhs = oldNewValueMap.at(oldBinaryInst->getRhs());
      auto newBinaryInst = new BinaryInst(oldBinaryInst->getKind(), oldBinaryInst->getType(), newLhs, newRhs,
                                          oldNewBlockMap.at(oldBinaryInst->getParent()),
                                          withSuffix(oldBinaryInst->getName()));
      oldNewValueMap.emplace(oldBinaryInst, newBinaryInst);
      break;
    }

    case Instruction::kNeg:
    case Instruction::kNot:
    case Instruction::kFNeg:
    case Instruction::kFNot:
    case Instruction::kItoF:
    case Instruction::kFtoI: {
      auto oldUnaryInst = dynamic_cast<UnaryInst *>(inst);
      Value *newHs = oldNewValueMap.at(oldUnaryInst->getOperand());
      auto newUnaryInst = new UnaryInst(oldUnaryInst->getKind(), oldUnaryInst->getType(), newHs,
                                        oldNewBlockMap.at(oldUnaryInst->getParent()),
                                        withSuffix(oldUnaryInst->getName()));
      oldNewValueMap.emplace(oldUnaryInst, newUnaryInst);
      break;
    }

    case Instruction::kCall: {
      auto oldCallInst = dynamic_cast<CallInst *>(inst);
      std::vector<Value *> newArguments;
      for (const auto &arg : oldCallInst->getArguments()) {
        newArguments.emplace_back(oldNewValueMap.at(arg->getValue()));
      }
      // The callee is taken from the original call as-is, so a recursive call
      // inside the clone still targets the original function.
      auto newCallInst = new CallInst(oldCallInst->getCallee(), newArguments,
                                      oldNewBlockMap.at(oldCallInst->getParent()),
                                      withSuffix(oldCallInst->getName()));
      oldNewValueMap.emplace(oldCallInst, newCallInst);
      break;
    }

    case Instruction::kCondBr: {
      auto oldCondBrInst = dynamic_cast<CondBrInst *>(inst);
      Value *newCond = oldNewValueMap.at(oldCondBrInst->getCondition());
      auto newCondBrInst = new CondBrInst(newCond,
                                          oldNewBlockMap.at(oldCondBrInst->getThenBlock()),
                                          oldNewBlockMap.at(oldCondBrInst->getElseBlock()),
                                          oldNewBlockMap.at(oldCondBrInst->getParent()));
      oldNewValueMap.emplace(oldCondBrInst, newCondBrInst);
      break;
    }

    case Instruction::kBr: {
      auto oldBrInst = dynamic_cast<UncondBrInst *>(inst);
      auto newBrInst = new UncondBrInst(oldNewBlockMap.at(oldBrInst->getBlock()),
                                        oldNewBlockMap.at(oldBrInst->getParent()));
      oldNewValueMap.emplace(oldBrInst, newBrInst);
      break;
    }

    case Instruction::kReturn: {
      auto oldReturnInst = dynamic_cast<ReturnInst *>(inst);
      auto oldRval = oldReturnInst->getReturnValue();
      Value *newRval = nullptr;  // stays null when the original returns no value
      if (oldRval != nullptr) {
        newRval = oldNewValueMap.at(oldRval);
      }
      auto newReturnInst = new ReturnInst(newRval, oldNewBlockMap.at(oldReturnInst->getParent()),
                                          oldReturnInst->getName());
      oldNewValueMap.emplace(oldReturnInst, newReturnInst);
      break;
    }

    case Instruction::kAlloca: {
      // Allocas are cloned in step 3 and are never queued.
      assert(false);
      break;  // do not fall through into the load case when asserts are off
    }

    case Instruction::kLoad: {
      auto oldLoadInst = dynamic_cast<LoadInst *>(inst);
      Value *newPointer = oldNewValueMap.at(oldLoadInst->getPointer());
      // TODO: the type of the new load has to be inferred from the old load.
      auto newLoadInst = new LoadInst(newPointer, oldNewBlockMap.at(oldLoadInst->getParent()),
                                      withSuffix(oldLoadInst->getName()));
      oldNewValueMap.emplace(oldLoadInst, newLoadInst);
      break;
    }

    case Instruction::kStore: {
      auto oldStoreInst = dynamic_cast<StoreInst *>(inst);
      Value *newPointer = oldNewValueMap.at(oldStoreInst->getPointer());
      Value *newValue = oldNewValueMap.at(oldStoreInst->getValue());
      // TODO: indices have to be inferred from the type of the old store.
      auto newStoreInst = new StoreInst(newValue, newPointer, oldNewBlockMap.at(oldStoreInst->getParent()),
                                        oldStoreInst->getName());
      oldNewValueMap.emplace(oldStoreInst, newStoreInst);
      break;
    }

    // TODO: clone GEP instructions.

    case Instruction::kMemset: {
      auto oldMemsetInst = dynamic_cast<MemsetInst *>(inst);
      Value *newPointer = oldNewValueMap.at(oldMemsetInst->getPointer());
      Value *newValue = oldNewValueMap.at(oldMemsetInst->getValue());
      // NOTE(review): begin and size are passed through unmapped — confirm
      // they can never refer to values of the original function.
      auto newMemsetInst = new MemsetInst(newPointer, oldMemsetInst->getBegin(), oldMemsetInst->getSize(),
                                          newValue, oldNewBlockMap.at(oldMemsetInst->getParent()),
                                          oldMemsetInst->getName());
      oldNewValueMap.emplace(oldMemsetInst, newMemsetInst);
      break;
    }

    case Instruction::kInvalid:
    case Instruction::kPhi: {
      // No clone is produced for these kinds.
      break;
    }

    default:
      assert(false);
    }

    // Every user of the instruction just handled becomes a candidate.
    for (const auto &userUse : inst->getUses()) {
      auto user = userUse->getUser();
      if (!isAddedToCreate.at(user)) {
        toCreate.push(user);
        isAddedToCreate.at(user) = true;
      }
    }
  }

  // ---- 7. Place the clones into their blocks in the original order. ----
  for (const auto &oldBlock : blocks) {
    auto newBlock = oldNewBlockMap.at(oldBlock.get());
    builder.setPosition(newBlock, newBlock->end());
    for (const auto &inst : oldBlock->getInstructions()) {
      builder.insertInst(dynamic_cast<Instruction *>(oldNewValueMap.at(inst.get())));
    }
  }

  // ---- 8. Attach the cloned arguments (see the NOTE in the header). ----
  for (const auto &arg : arguments) {
    auto newArg = dynamic_cast<Argument *>(oldNewValueMap.at(arg));
    if (newArg != nullptr) {
      newFunction->insertArgument(newArg);
    }
  }

  return newFunction;
}

//===----------------------------------------------------------------------===//
// User
//===----------------------------------------------------------------------===//

/**
 * Sets operand @p index to @p value and records the use on @p value.
 *
 * The use is NOT removed from the previous value's use list; use
 * replaceOperand() when the previous value must forget this use.
 */
void User::setOperand(unsigned index, Value *value) {
  assert(index < getNumOperands());
  operands[index]->setValue(value);
  value->addUse(operands[index]);
}

/**
 * Replaces operand @p index with @p value: the use is removed from the
 * previous value (if any), re-pointed, and recorded on @p value.
 */
void User::replaceOperand(unsigned index, Value *value) {
  assert(index < getNumOperands());
  auto &use = operands[index];
  if (auto *previous = use->getValue()) {
    previous->removeUse(use);
  }
  use->setValue(value);
  value->addUse(use);
}

//===----------------------------------------------------------------------===//
// PhiInst
//
// Operand layout used below: operand 2*i is the i-th incoming value and
// operand 2*i + 1 is the i-th incoming block; vsize is the number of pairs.
//===----------------------------------------------------------------------===//

/** @return the incoming value for @p blk, or nullptr if @p blk is not an incoming block. */
Value *PhiInst::getvalfromBlk(BasicBlock *blk) {
  refreshB2VMap();
  auto found = blk2val.find(blk);
  return found != blk2val.end() ? found->second : nullptr;
}

/** @return the block paired with the first incoming value equal to @p val, or nullptr if there is none. */
BasicBlock *PhiInst::getBlkfromVal(Value *val) {
  for (unsigned i = 0; i < vsize; i++) {
    if (getValue(i) == val) {
      return getBlock(i);
    }
  }
  return nullptr;
}

/**
 * Removes the first incoming pair whose value is @p val.
 * Does nothing when no incoming value equals @p val.
 */
void PhiInst::delValue(Value *val) {
  unsigned i = 0;
  while (i < vsize && getValue(i) != val) {
    i++;
  }
  if (i == vsize) {
    return;  // not an incoming value: nothing to remove
  }
  BasicBlock *blk = getBlock(i);
  removeOperand(2 * i + 1);  // remove the block first so the value index stays valid
  removeOperand(2 * i);      // then the value
  vsize--;
  blk2val.erase(blk);
}

/**
 * Removes the first incoming pair whose block is @p blk.
 * Apart from rebuilding the block-to-value map, does nothing when @p blk is
 * not an incoming block.
 */
void PhiInst::delBlk(BasicBlock *blk) {
  refreshB2VMap();
  unsigned i = 0;
  while (i < vsize && getBlock(i) != blk) {
    i++;
  }
  if (i == vsize) {
    return;  // not an incoming block: nothing to remove
  }
  removeOperand(2 * i + 1);  // remove the block first so the value index stays valid
  removeOperand(2 * i);      // then the value
  vsize--;
  blk2val.erase(blk);
}

/**
 * Replaces the block of the @p k-th incoming pair with @p newBlk, keeping the
 * value. The operand is swapped with replaceOperand() so the old block no
 * longer lists this phi among its uses.
 */
void PhiInst::replaceBlk(BasicBlock *newBlk, unsigned k) {
  refreshB2VMap();
  BasicBlock *oldBlk = getBlock(k);
  Value *val = blk2val.at(oldBlk);
  replaceOperand(2 * k + 1, newBlk);
  blk2val.erase(oldBlk);
  blk2val.emplace(newBlk, val);
}

/**
 * Re-homes the incoming value of @p oldBlk to @p newBlk by deleting the old
 * pair and adding a new one. blk2val.at() throws std::out_of_range when
 * @p oldBlk is not an incoming block.
 */
void PhiInst::replaceold2new(BasicBlock *oldBlk, BasicBlock *newBlk) {
  refreshB2VMap();
  Value *val = blk2val.at(oldBlk);
  delBlk(oldBlk);
  addIncoming(val, newBlk);
}

/**
 * Rebuilds the block-to-value map from the operand list. When a block occurs
 * more than once, the first pair wins (emplace does not overwrite).
 */
void PhiInst::refreshB2VMap() {
  blk2val.clear();
  for (unsigned i = 0; i < vsize; i++) {
    blk2val.emplace(getBlock(i), getValue(i));
  }
}

//===----------------------------------------------------------------------===//
// CallInst
//===----------------------------------------------------------------------===//

/**
 * Builds a call whose result type is the callee's return type. Operand 0 is
 * the callee; the arguments follow in order.
 */
CallInst::CallInst(Function *callee, const std::vector<Value *> &args, BasicBlock *parent,
                   const std::string &name)
    : Instruction(kCall, callee->getReturnType(), parent, name) {
  addOperand(callee);
  for (auto arg : args) {
    addOperand(arg);
  }
}

/** @return the called function (operand 0), or nullptr if operand 0 is not a Function. */
Function *CallInst::getCallee() const {
  return dynamic_cast<Function *>(getOperand(0));
}

//===----------------------------------------------------------------------===//
// SymbolTable
//===----------------------------------------------------------------------===//

/**
 * Looks @p name up in the current scope and then in each enclosing scope.
 *
 * @return the variable if found, nullptr otherwise.
 */
auto SymbolTable::getVariable(const std::string &name) const -> Value * {
  for (auto node = curNode; node != nullptr; node = node->pNode) {
    auto iter = node->varList.find(name);
    if (iter != node->varList.end()) {
      return iter->second;
    }
  }
  return nullptr;
}

/**
 * Adds @p variable to the current scope under @p name.
 *
 * The variable is renamed to @p name followed by a per-name counter (0 for
 * the first variable of that name, then 1, 2, ...). Global values and constant
 * variables are additionally appended to the corresponding table-wide lists.
 *
 * @return @p variable, or nullptr when there is no current scope.
 */
auto SymbolTable::addVariable(const std::string &name, Value *variable) -> Value * {
  if (curNode == nullptr) {
    return nullptr;
  }

  std::stringstream ss;
  auto iter = variableIndex.find(name);
  if (iter != variableIndex.end()) {
    ss << name << iter->second;
    iter->second += 1;
  } else {
    variableIndex.emplace(name, 1);
    ss << name << 0;
  }
  variable->setName(ss.str());
  curNode->varList.emplace(name, variable);

  if (auto global = dynamic_cast<GlobalValue *>(variable)) {
    globals.emplace_back(global);
  } else if (auto constvar = dynamic_cast<ConstantVariable *>(variable)) {
    globalconsts.emplace_back(constvar);
  }
  return variable;
}

/** @return the list of global variables. */
auto SymbolTable::getGlobals() -> std::vector<std::unique_ptr<GlobalValue>> & {
  return globals;
}

/** @return the list of constant variables. */
auto SymbolTable::getConsts() const -> const std::vector<std::unique_ptr<ConstantVariable>> & {
  return globalconsts;
}

/** Opens a new scope nested in the current one and makes it current. */
void SymbolTable::enterNewScope() {
  auto newNode = new SymbolTableNode;
  nodeList.emplace_back(newNode);
  if (curNode != nullptr) {
    curNode->children.emplace_back(newNode);
  }
  newNode->pNode = curNode;
  curNode = newNode;
}

/** Makes the first scope ever created current. Requires that a scope exists. */
void SymbolTable::enterGlobalScope() {
  assert(!nodeList.empty() && "enterGlobalScope() called before any scope was created");
  curNode = nodeList.front().get();
}

/** Leaves the current scope and returns to its parent. Requires a current scope. */
void SymbolTable::leaveScope() {
  assert(curNode != nullptr && "leaveScope() called without a current scope");
  curNode = curNode->pNode;
}

/** @return true when the current scope has no parent. Requires a current scope. */
auto SymbolTable::isInGlobalScope() const -> bool {
  assert(curNode != nullptr && "isInGlobalScope() called without a current scope");
  return curNode->pNode == nullptr;
}

//===----------------------------------------------------------------------===//
// BasicBlock
//===----------------------------------------------------------------------===//

/**
 * Moves the instruction at @p sourcePos of this block to just before
 * @p targetPos in @p block and re-parents it.
 *
 * @return the iterator following the removed position in this block.
 */
auto BasicBlock::moveInst(iterator sourcePos, iterator targetPos, BasicBlock *block) -> iterator {
  auto inst = sourcePos->release();  // take the instruction out of its owning slot
  inst->setParent(block);
  block->instructions.emplace(targetPos, inst);
  return instructions.erase(sourcePos);
}

} // namespace sysy