[midend-mem2reg]修复assignstmt对lvalue的错误解析(lvaue会被exp解释为值,而被assign解释为地址)

This commit is contained in:
rain2133
2025-07-25 20:00:41 +08:00
parent e2c97fd171
commit 04c5c6b44d
3 changed files with 69 additions and 33 deletions

View File

@ -647,7 +647,7 @@ Function * CallInst::getCallee() const { return dynamic_cast<Function *>(getOper
/**
* 获取变量指针
*/
auto SymbolTable::getVariable(const std::string &name) const -> User * {
auto SymbolTable::getVariable(const std::string &name) const -> Value * {
auto node = curNode;
while (node != nullptr) {
auto iter = node->varList.find(name);
@ -662,8 +662,8 @@ auto SymbolTable::getVariable(const std::string &name) const -> User * {
/**
* 添加变量到符号表
*/
auto SymbolTable::addVariable(const std::string &name, User *variable) -> User * {
User *result = nullptr;
auto SymbolTable::addVariable(const std::string &name, Value *variable) -> Value * {
Value *result = nullptr;
if (curNode != nullptr) {
std::stringstream ss;
auto iter = variableIndex.find(name);

View File

@ -450,44 +450,79 @@ std::any SysYIRGenerator::visitBlockStmt(SysYParser::BlockStmtContext *ctx) {
std::any SysYIRGenerator::visitAssignStmt(SysYParser::AssignStmtContext *ctx) {
auto lVal = ctx->lValue();
std::string name = lVal->Ident()->getText();
std::vector<Value *> dims;
for (const auto &exp : lVal->exp()) {
dims.push_back(std::any_cast<Value *>(visitExp(exp)));
Value* LValue = nullptr;
Value* variable = module->getVariable(name); // 左值
vector<Value *> indices;
if (lVal->exp().size() > 0) {
// 如果有下标,访问表达式获取下标值
for (const auto &exp : lVal->exp()) {
Value* indexValue = std::any_cast<Value *>(visitExp(exp));
indices.push_back(indexValue);
}
}
if (indices.empty()) {
// variable 本身就是指向标量的指针 (e.g., int* %a)
if (dynamic_cast<AllocaInst*>(variable) || dynamic_cast<GlobalValue*>(variable)) {
LValue = variable;
}
}
else {
// 对于数组或多维数组的左值处理
// 需要获取 GEP 地址
Value* gepBasePointer = nullptr;
std::vector<Value*> gepIndices;
if (AllocaInst *alloc = dynamic_cast<AllocaInst *>(variable)) {
Type* allocatedType = alloc->getType()->as<PointerType>()->getBaseType();
if (allocatedType->isPointer()) {
gepBasePointer = builder.createLoadInst(alloc);
gepIndices = indices;
} else {
gepBasePointer = alloc;
gepIndices.push_back(ConstantInteger::get(0));
gepIndices.insert(gepIndices.end(), indices.begin(), indices.end());
}
} else if (GlobalValue *glob = dynamic_cast<GlobalValue *>(variable)) {
// 情况 B: 全局变量 (GlobalValue)
gepBasePointer = glob;
gepIndices.push_back(ConstantInteger::get(0));
gepIndices.insert(gepIndices.end(), indices.begin(), indices.end());
} else if (ConstantVariable *constV = dynamic_cast<ConstantVariable *>(variable)) {
gepBasePointer = constV;
gepIndices.push_back(ConstantInteger::get(0));
gepIndices.insert(gepIndices.end(), indices.begin(), indices.end());
}
// 左值为地址
LValue = getGEPAddressInst(gepBasePointer, gepIndices);
}
auto variable = module->getVariable(name); // 获取 AllocaInst 或 GlobalValue
Value* value = std::any_cast<Value *>(visitExp(ctx->exp())); // 右值
Type* targetElementType = variable->getType(); // 从基指针指向的类型开始
Value* RValue = std::any_cast<Value *>(visitExp(ctx->exp())); // 右值
//根据 dims 确定最终元素的类型
targetElementType = builder.getIndexedType(targetElementType, dims);
// 先推断 LValue 的类型
// 如果 LValue 是指向数组的指针,则需要根据 indices 获取正确的类型
// 如果 LValue 是标量,则直接使用其类型
// 注意LValue 的类型可能是指向数组的指针 (e.g., int(*)[3]) 或者指向标量的指针 (e.g., int*) 也能推断
Type* LType = builder.getIndexedType(variable->getType(), indices);
Type* RType = RValue->getType();
// 左值右值类型不同处理:根据最终元素类型进行转换
if (targetElementType != value->getType()) {
ConstantValue * constValue = dynamic_cast<ConstantValue *>(value);
if (LType != RType) {
ConstantValue * constValue = dynamic_cast<ConstantValue *>(RValue);
if (constValue != nullptr) {
if (targetElementType == Type::getFloatType()) {
value = ConstantFloating::get(static_cast<float>(constValue->getFloat()));
if (LType == Type::getFloatType()) {
RValue = ConstantFloating::get(static_cast<float>(constValue->getFloat()));
} else { // 假设如果不是浮点型,就是整型
value = ConstantInteger::get(static_cast<int>(constValue->getInt()));
RValue = ConstantInteger::get(static_cast<int>(constValue->getInt()));
}
} else {
if (targetElementType == Type::getFloatType()) {
value = builder.createIToFInst(value);
if (LType == Type::getFloatType()) {
RValue = builder.createIToFInst(RValue);
} else { // 假设如果不是浮点型,就是整型
value = builder.createFtoIInst(value);
RValue = builder.createFtoIInst(RValue);
}
}
}
// 计算目标地址:如果 dims 为空,就是变量本身地址;否则通过 GEP 计算
Value* targetAddress = variable;
if (!dims.empty()) {
targetAddress = getGEPAddressInst(variable, dims);
}
builder.createStoreInst(value, targetAddress);
builder.createStoreInst(RValue, LValue);
return std::any();
}
@ -711,7 +746,7 @@ unsigned SysYIRGenerator::countArrayDimensions(Type* type) {
std::any SysYIRGenerator::visitLValue(SysYParser::LValueContext *ctx) {
std::string name = ctx->Ident()->getText();
User* variable = module->getVariable(name);
Value* variable = module->getVariable(name);
Value* value = nullptr;

View File

@ -521,6 +521,7 @@ public:
Function* getParent() const { return parent; }
void setParent(Function *func) { parent = func; }
inst_list& getInstructions() { return instructions; }
auto getInstructions_Range() const { return make_range(instructions); }
arg_list& getArguments() { return arguments; }
block_list& getPredecessors() { return predecessors; }
void clearPredecessors() { predecessors.clear(); }
@ -1404,7 +1405,7 @@ class ConstantVariable : public User {
using SymbolTableNode = struct SymbolTableNode {
SymbolTableNode *pNode; ///< 父节点
std::vector<SymbolTableNode *> children; ///< 子节点列表
std::map<std::string, User *> varList; ///< 变量列表
std::map<std::string, Value *> varList; ///< 变量列表
};
@ -1419,8 +1420,8 @@ class SymbolTable {
public:
SymbolTable() = default;
User* getVariable(const std::string &name) const; ///< 根据名字name以及当前作用域获取变量
User* addVariable(const std::string &name, User *variable); ///< 添加变量
Value* getVariable(const std::string &name) const; ///< 根据名字name以及当前作用域获取变量
Value* addVariable(const std::string &name, Value *variable); ///< 添加变量
std::vector<std::unique_ptr<GlobalValue>>& getGlobals(); ///< 获取全局变量列表
const std::vector<std::unique_ptr<ConstantVariable>>& getConsts() const; ///< 获取常量列表
void enterNewScope(); ///< 进入新的作用域
@ -1482,7 +1483,7 @@ class Module {
void addVariable(const std::string &name, AllocaInst *variable) {
variableTable.addVariable(name, variable);
} ///< 添加变量
User* getVariable(const std::string &name) {
Value* getVariable(const std::string &name) {
return variableTable.getVariable(name);
} ///< 根据名字name和当前作用域获取变量
Function* getFunction(const std::string &name) const {