145 lines
4.8 KiB
C++
145 lines
4.8 KiB
C++
#include "../../include/midend/Pass/Optimize/LargeArrayToGlobal.h"

#include "../../IR.h"

#include <iostream>
#include <limits>
#include <sstream>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>
|
|
|
|
namespace sysy {
|
|
|
|
// Helper function to convert type to string
|
|
static std::string typeToString(Type *type) {
|
|
if (!type) return "null";
|
|
|
|
switch (type->getKind()) {
|
|
case Type::kInt:
|
|
return "int";
|
|
case Type::kFloat:
|
|
return "float";
|
|
case Type::kPointer:
|
|
return "ptr";
|
|
case Type::kArray: {
|
|
auto *arrayType = type->as<ArrayType>();
|
|
return "[" + std::to_string(arrayType->getNumElements()) + " x " +
|
|
typeToString(arrayType->getElementType()) + "]";
|
|
}
|
|
default:
|
|
return "unknown";
|
|
}
|
|
}
|
|
|
|
// Pass identifier: a static member initialized to its own address.
// NOTE(review): presumably the pass framework keys on this address to
// identify the pass — confirm against the pass manager.
void *LargeArrayToGlobalPass::ID{&LargeArrayToGlobalPass::ID};
|
|
|
|
/// Module pass entry point. Every alloca whose allocated type measures at
/// least 1024 bytes (per calculateTypeSize) is handed to convertAllocaToGlobal,
/// which turns it into a module-level global.
///
/// Candidates are collected in a first sweep and converted in a second one,
/// so no basic block's instruction list is modified while it is iterated.
///
/// NOTE(review): a global is shared by all activations of a function, so this
/// rewrite is only sound if the function is never re-entered while an earlier
/// activation is live (e.g. via recursion). No such check is visible here —
/// confirm the pipeline guarantees it.
///
/// @param M   module to transform; a null module is a no-op.
/// @param AM  analysis manager (not used in this body).
/// @return true if at least one alloca was passed to convertAllocaToGlobal.
bool LargeArrayToGlobalPass::runOnModule(Module *M, AnalysisManager &AM) {
  if (M == nullptr) {
    return false;
  }

  // Arrays of 1KB (1024 bytes) or larger are moved to global storage.
  constexpr unsigned kLargeArrayBytes = 1024;

  std::vector<std::pair<AllocaInst *, Function *>> candidates;

  for (auto &entry : M->getFunctions()) {
    Function *func = entry.second.get();
    if (func == nullptr) {
      continue;
    }
    const bool hasBlocks =
        func->getBasicBlocks().begin() != func->getBasicBlocks().end();
    if (!hasBlocks) {
      continue;
    }

    for (auto &block : func->getBasicBlocks()) {
      for (auto &instPtr : block->getInstructions()) {
        auto *allocaInst = dynamic_cast<AllocaInst *>(instPtr.get());
        if (allocaInst == nullptr) {
          continue;
        }

        Type *allocTy = allocaInst->getAllocatedType();
        const unsigned bytes = calculateTypeSize(allocTy);

        if (DEBUG) {
          std::cout << "LargeArrayToGlobalPass: Found alloca with size " << bytes
                    << " for type " << typeToString(allocTy) << std::endl;
        }

        if (bytes < kLargeArrayBytes) {
          continue;
        }

        if (DEBUG) {
          std::cout << "LargeArrayToGlobalPass: Converting array of size " << bytes
                    << " to global" << std::endl;
        }
        candidates.emplace_back(allocaInst, func);
      }
    }
  }

  // Second sweep: rewrite the collected allocas now that iteration is done.
  for (const auto &[allocaInst, func] : candidates) {
    convertAllocaToGlobal(allocaInst, func, M);
  }

  return !candidates.empty();
}
|
|
|
|
/// Computes the storage size in bytes of @p type as this pass models it.
/// int and float count as 4 bytes and pointers as 8 bytes. Arrays count as
/// element count times element size. A null type or any other kind counts as 0.
///
/// The array product saturates at std::numeric_limits<unsigned>::max() rather
/// than wrapping. With plain 32-bit unsigned multiplication, an array of 2^30
/// ints (4 GiB) wrapped to 0 and so was never reported as large.
unsigned LargeArrayToGlobalPass::calculateTypeSize(Type *type) {
  if (!type) return 0;

  switch (type->getKind()) {
  case Type::kInt:
  case Type::kFloat:
    return 4;
  case Type::kPointer:
    return 8;
  case Type::kArray: {
    auto *arrayType = type->as<ArrayType>();
    const unsigned long long count =
        static_cast<unsigned long long>(arrayType->getNumElements());
    const unsigned long long elementSize =
        calculateTypeSize(arrayType->getElementType());
    const unsigned long long limit = std::numeric_limits<unsigned>::max();

    // Division-based check so the product itself can never overflow.
    if (elementSize != 0 && count > limit / elementSize) {
      return std::numeric_limits<unsigned>::max();
    }
    return static_cast<unsigned>(count * elementSize);
  }
  default:
    return 0;
  }
}
|
|
|
|
/// Replaces a stack allocation with a module-level global.
///
/// Picks a fresh name and asks @p M for a global whose type is a pointer to
/// the alloca's allocated type. It then redirects every use of @p alloca to
/// that global and erases @p alloca from its function. If the module returns
/// a null global, the function returns early and the alloca and its uses are
/// left untouched.
///
/// @param alloca  stack allocation to replace. NOTE(review): erasing it from
///                the instruction list presumably destroys it (the list
///                appears to hold owning smart pointers) — do not use it after.
/// @param F       function containing @p alloca; searched for the alloca.
/// @param M       module that receives the new global.
void LargeArrayToGlobalPass::convertAllocaToGlobal(AllocaInst *alloca, Function *F, Module *M) {
  Type *allocatedType = alloca->getAllocatedType();

  // Create a unique name for the global variable
  std::string globalName = generateUniqueGlobalName(alloca, F);

  // Create the global variable - GlobalValue expects pointer type
  Type *pointerType = Type::getPointerType(allocatedType);
  GlobalValue *globalVar = M->createGlobalValue(globalName, pointerType);
  if (!globalVar) {
    return;
  }

  // Replace all uses of the alloca with the global variable
  alloca->replaceAllUsesWith(globalVar);

  // Remove the alloca from the block that holds it. An instruction sits in a
  // single block, so stop as soon as it has been erased instead of scanning
  // the remaining blocks of the function.
  for (auto &BB : F->getBasicBlocks()) {
    auto &instructions = BB->getInstructions();
    for (auto it = instructions.begin(); it != instructions.end(); ++it) {
      if (it->get() == alloca) {
        instructions.erase(it);
        return;
      }
    }
  }
}
|
|
|
|
/// Builds the name of the global that will replace @p alloca, in the form
/// "<function>.<variable>.<n>". "array" stands in for an unnamed alloca.
///
/// <n> counts how many names have already been produced for the same
/// "<function>.<variable>" prefix. The counter map is a function-local static:
/// it is never reset, so it persists across calls (and across pass runs in the
/// same process), and it is not guarded by any lock.
std::string LargeArrayToGlobalPass::generateUniqueGlobalName(AllocaInst *alloca, Function *F) {
  static std::unordered_map<std::string, int> nameCounter;

  const std::string varName = alloca->getName();
  const std::string prefix =
      F->getName() + "." + (varName.empty() ? std::string("array") : varName);

  const int suffix = nameCounter[prefix]++; // first use yields 0
  return prefix + "." + std::to_string(suffix);
}
|
|
|
|
} // namespace sysy
|