更新IR,.g4修改

This commit is contained in:
rain2133
2025-06-21 18:06:29 +08:00
parent 3ed1c7fecd
commit 0a04c816cf
4 changed files with 446 additions and 21 deletions

View File

@ -153,10 +153,10 @@ cond: lOrExp;
lValue: Ident (LBRACK exp RBRACK)*;
// 为了方便测试 primaryExp 可以是一个string
primaryExp: LPAREN exp RPAREN #parenExp
| lValue #lVal
| number #num
| string #str;
primaryExp: LPAREN exp RPAREN
| lValue
| number
| string;
number: ILITERAL | FLITERAL;
unaryExp: primaryExp #primExp

View File

@ -382,7 +382,8 @@ std::any SysYIRGenerator::visitIfStmt(SysYParser::IfStmtContext *ctx) {
}
std::any SysYIRGenerator::visitWhileStmt(SysYParser::WhileStmtContext *ctx) {
// while structure:
// curblock -> headBlock -> bodyBlock -> exitBlock
BasicBlock* curBlock = builder.getBasicBlock();
Function* function = builder.getBasicBlock()->getParent();
@ -390,18 +391,16 @@ std::any SysYIRGenerator::visitWhileStmt(SysYParser::WhileStmtContext *ctx) {
labelstring << "head.L" << builder.getLabelIndex();
BasicBlock *headBlock = function->addBasicBlock(labelstring.str());
labelstring.str("");
BasicBlock::conectBlocks(curBlock, headBlock);
builder.setPosition(headBlock, headBlock->end())
builder.setPosition(headBlock, headBlock->end());
function->addBasicBlock(condBlock);
builder.setPosition(condBlock, condBlock->end());
BasicBlock* bodyBlock = new BasicBlock(function);
BasicBlock* exitBlock = new BasicBlock(function);
builder.pushTrueBlock(bodyBlock);
builder.pushFalseBlock(exitBlock);
// 访问条件表达式
visitCond(ctx->cond());
builder.popTrueBlock();
builder.popFalseBlock();
@ -411,27 +410,451 @@ std::any SysYIRGenerator::visitWhileStmt(SysYParser::WhileStmtContext *ctx) {
function->addBasicBlock(bodyBlock);
builder.setPosition(bodyBlock, bodyBlock->end());
module->enterNewScope();
for (auto item : ctx->blockStmt()->blockItem()) {
visitBlockItem(item);
}
module->leaveScope();
builder.createUncondBrInst(condBlock, {});
builder.pushBreakBlock(exitBlock);
builder.pushContinueBlock(headBlock);
BasicBlock::conectBlocks(builder.getBasicBlock(), condBlock);
auto block = dynamic_cast<SysYParser::BlockStmtContext *>(ctx->stmt());
if( block != nullptr) {
visitBlockStmt(block);
} else {
module->enterNewScope();
ctx->stmt()->accept(this);
module->leaveScope();
}
builder.createUncondBrInst(headBlock, {});
BasicBlock::conectBlocks(builder.getBasicBlock(), exitBlock);
builder.popBreakBlock();
builder.popContinueBlock();
labelstring << "exit.L" << builder.getLabelIndex();
exitBlock->setName(labelstring.str());
labelstring.str("");
function->addBasicBlock(exitBlock);
builder.setPosition(exitBlock, exitBlock->end());
return std::any();
}
std::any SysYIRGenerator::visitBreakStmt(SysYParser::BreakStmtContext *ctx) {
BasicBlock* breakBlock = builder.getBreakBlock();
builder.pushBreakBlock(breakBlock);
BasicBlock::conectBlocks(builder.getBasicBlock(), breakBlock);
return std::any();
}
std::any SysYIRGenerator::visitContinueStmt(SysYParser::ContinueStmtContext *ctx) {
BasicBlock* continueBlock = builder.getContinueBlock();
builder.createUncondBrInst(continueBlock, {});
return std::any();
}
std::any SysYIRGenerator::visitReturnStmt(SysYParser::ReturnStmtContext *ctx) {
Value* returnValue = nullptr;
if (ctx->exp() != nullptr) {
returnValue = std::any_cast<Value *>(visitExp(ctx->exp()));
}
Type* funcType = builder.getBasicBlock()->getParent()->getType();
if (funcType!= returnValue->getType() && returnValue != nullptr) {
ConstantValue * constValue = dynamic_cast<ConstantValue *>(returnValue);
if (constValue != nullptr) {
if (funcType == Type::getFloatType()) {
returnValue = ConstantValue::get(static_cast<float>(constValue->getInt()));
} else {
returnValue = ConstantValue::get(static_cast<int>(constValue->getFloat()));
}
} else {
if (funcType == Type::getFloatType()) {
returnValue = builder.createIToFInst(returnValue);
} else {
returnValue = builder.createFtoIInst(returnValue);
}
}
}
builder.createRetInst(returnValue);
return std::any();
}
std::any SysYIRGenerator::visitLVal(SysYParser::LValContext *ctx) {
std::string name = ctx->Ident()->getText();
User* variable = module->getVariable(name);
Value* value = nullptr;
std::vector<Value *> dims;
for (const auto &exp : ctx->exp()) {
dims.push_back(std::any_cast<Value *>(visitExp(exp)));
}
if (variable == nullptr) {
throw std::runtime_error("Variable " + name + " not found.");
}
bool indicesConstant = true;
for (const auto &index : indices) {
if (dynamic_cast<ConstantValue *>(index) == nullptr) {
indicesConstant = false;
break;
}
}
ConstantVariable* constVar = dynamic_cast<ConstantVariable *>(variable);
GlobalValue* globalVar = dynamic_cast<GlobalValue *>(variable);
AllocaInst* localVar = dynamic_cast<AllocaInst *>(variable);
if (constVar != nullptr && indicesConstant) {
// 如果是常量变量,且索引是常量,则直接获取子数组
value = constVar->getByIndices(indices);
} else if (module->isInGlobalArea() && (globalVar != nullptr)) {
assert(indicesConstant);
value = globalVar->getByIndices(indices);
} else {
if ((globalVar != nullptr && globalVar->getNumDims() > indices.size()) ||
(localVar != nullptr && localVar->getNumDims() > indices.size()) ||
(constVar != nullptr && constVar->getNumDims() > indices.size())) {
// value = builder.createLaInst(variable, indices);
// 如果变量是全局变量或局部变量且索引数量小于维度数量则创建createGetSubArray获取子数组
auto getArrayInst =
builder.createGetSubArray(dynamic_cast<LVal *>(variable), indices);
value = getArrayInst->getChildArray();
} else {
value = builder.createLoadInst(variable, indices);
}
}
return value;
}
std::any SysYIRGenerator::visitPrimaryExp(SysYParser::PrimaryExpContext *ctx) {
if (ctx->exp() != nullptr)
return visitExp(ctx->exp());
if (ctx->lVal() != nullptr)
return visitLVal(ctx->lVal());
if (ctx->number() != nullptr)
return visitNumber(ctx->number());
// if (ctx->string() != nullptr) {
// std::string str = ctx->string()->getText();
// str = str.substr(1, str.size() - 2); // 去掉双引号
// return ConstantValue::get(str);
// }
return visitNumber(ctx->number());
}
std::any SysYIRGenerator::visitNumber(SysYParser::NumberContext *ctx) {
if (ctx->ILITERAL() != nullptr) {
int value = std::stol(ctx->ILITERAL()->getText(), nullptr, 0);
return static_cast<Value *>(ConstantValue::get(Type::getIntType(), value));
} else if (ctx->FLITERAL() != nullptr) {
float value = std::stof(ctx->FLITERAL()->getText());
return static_cast<Value *>(ConstantValue::get(Type::getFloatType(), value));
}
throw std::runtime_error("Unknown number type.");
return std::any(); // 不会到达这里
}
std::any SysYIRGenerator::visitCall(SysYParser::CallContext *ctx) {
std::string funcName = ctx->Ident()->getText();
Function *function = module->getFunction(funcName);
if (function == nullptr) {
function = module->getExternalFunction(name);
if (function == nullptr) {
std::cout << "The function " << name << " no defined." << std::endl;
assert(function);
}
}
std::vector<Value *> args = {};
if (name == "starttime" || name == "stoptime") {
// 如果是starttime或stoptime函数
// TODO: 这里需要处理starttime和stoptime函数的参数
// args.emplace_back()
} else {
if (ctx->funcRParams() != nullptr) {
args = std::any_cast<std::vector<Value *>>(visitFuncRParams(ctx->funcRParams()));
}
auto params = function->getEntryBlock()->getArguments();
for (size_t i = 0; i < args.size(); i++) {
// 参数类型转换
if (params[i]->getType() != args[i]->getType() &&
(params[i]->getNumDims() != 0 ||
params[i]->getType()->as<PointerType>()->getBaseType() != args[i]->getType())) {
ConstantValue * constValue = dynamic_cast<ConstantValue *>(args[i]);
if (constValue != nullptr) {
if (params[i]->getType() == Type::getPointerType(Type::getFloatType())) {
args[i] = ConstantValue::get(static_cast<float>(constValue->getInt()));
} else {
args[i] = ConstantValue::get(static_cast<int>(constValue->getFloat()));
}
} else {
if (params[i]->getType() == Type::getPointerType(Type::getFloatType())) {
args[i] = builder.createIToFInst(args[i]);
} else {
args[i] = builder.createFtoIInst(args[i]);
}
}
}
}
}
return static_cast<Value *>(builder.createCallInst(function, args));
}
std::any SysYIRGenerator::visitUnExp(SysYParser::UnExpContext *ctx) {
Value* value = std::any_cast<Value *>(visitUnaryExp(ctx->unaryExp()));
Value* result = value;
if (ctx->unaryOp()->SUB() != nullptr) {
ConstantValue * constValue = dynamic_cast<ConstantValue *>(value);
if (constValue != nullptr) {
if (constValue->isFloat()) {
result = ConstantValue::get(-constValue->getFloat());
} else {
result = ConstantValue::get(-constValue->getInt());
}
} else if (value != nullptr) {
if (value->getType() == Type::getIntType()) {
result = builder.createNegInst(value);
} else {
result = builder.createFNegInst(value);
}
} else {
std::cout << "UnExp: value is nullptr." << std::endl;
assert(false);
}
} else if (ctx->unaryOp()->NOT() != nullptr) {
auto constValue = dynamic_cast<ConstantValue *>(value);
if (constValue != nullptr) {
if (constValue->isFloat()) {
result =
ConstantValue::get(1 - (constValue->getFloat() != 0.0F ? 1 : 0));
} else {
result = ConstantValue::get(1 - (constValue->getInt() != 0 ? 1 : 0));
}
} else if (value != nullptr) {
if (value->getType() == Type::getIntType()) {
result = builder.createNotInst(value);
} else {
result = builder.createFNotInst(value);
}
} else {
std::cout << "UnExp: value is nullptr." << std::endl;
assert(false);
}
}
return result;
}
std::any SysYIRGenerator::visitFuncRParams(SysYParser::FuncRParamsContext *ctx) {
std::vector<Value *> params;
for (const auto &exp : ctx->exp())
params.push_back(std::any_cast<Value *>(visitExp(exp)));
return params;
}
std::any SysYIRGenerator::visitMulExp(SysYParser::MulExpContext *ctx) {
auto result = std::any_cast<Value *>(visitUnaryExp(ctx->unaryExp(0)));
for (size_t i = 1; i < ctx->unaryExp().size(); i++) {
auto op = ctx->mulOp(i - 1);
Value* operand = std::any_cast<Value *>(visitUnaryExp(ctx->unaryExp(i)));
Type* resultType = result->getType();
Type* operandType = operand->getType();
if (resultType == Type::getFloatType() || operandType == Type::getFloatType()) {
// 如果有一个操作数是浮点数,则将两个操作数都转换为浮点数
if (operandType != Type::getFloatType()) {
ConstantValue * constValue = dynamic_cast<ConstantValue *>(operand);
if (constValue != nullptr)
operand = ConstantValue::get(static_cast<float>(constValue->getInt()));
else
operand = builder.createIToFInst(operand);
} else if (resultType != Type::getFloatType()) {
ConstantValue* constResult = dynamic_cast<ConstantValue *>(result);
if (constResult != nullptr)
result = ConstantValue::get(static_cast<float>(constResult->getInt()));
else
result = builder.createIToFInst(result);
}
ConstantValue* constResult = dynamic_cast<ConstantValue *>(result);
ConstantValue* constOperand = dynamic_cast<ConstantValue *>(operand);
if (op->MUL() != nullptr) {
if ((constOperand != nullptr) && (constResult != nullptr)) {
result = ConstantValue::get(constResult->getFloat() *
constOperand->getFloat());
} else {
result = builder.createFMulInst(result, operand);
}
} else if (op->DIV() != nullptr) {
if ((constOperand != nullptr) && (constResult != nullptr)) {
result = ConstantValue::get(constResult->getFloat() /
constOperand->getFloat());
} else {
result = builder.createFDivInst(result, operand);
}
} else {
// float类型的取模操作不允许
std::cout << "MulExp: float type mod operation is not allowed." << std::endl;
assert(false);
}
} else {
ConstantValue * constResult = dynamic_cast<ConstantValue *>(result);
ConstantValue * constOperand = dynamic_cast<ConstantValue *>(operand);
if (op->MUL() != nullptr) {
if ((constOperand != nullptr) && (constResult != nullptr))
result = ConstantValue::get(constResult->getInt() * constOperand->getInt());
else
result = builder.createMulInst(result, operand);
} else if (op->DIV() != nullptr) {
if ((constOperand != nullptr) && (constResult != nullptr))
result = ConstantValue::get(constResult->getInt() / constOperand->getInt());
else
result = builder.createDivInst(result, operand);
} else {
if ((constOperand != nullptr) && (constResult != nullptr))
result = ConstantValue::get(constResult->getInt() % constOperand->getInt());
else
result = builder.createRemInst(result, operand);
}
}
}
return result;
}
std::any SysYIRGenerator::visitAddExp(SysYParser::AddExpContext *ctx) {
Value* result = std::any_cast<Value *>(visitMulExp(ctx->mulExp(0)));
for (size_t i = 1; i < ctx->mulExp().size(); i++) {
auto op = ctx->addOp(i - 1);
Value* operand = std::any_cast<Value *>(visitMulExp(ctx->mulExp(i)));
Type* resultType = result->getType();
Type* operandType = operand->getType();
if (resultType == Type::getFloatType() || operandType == Type::getFloatType()) {
// 类型转换
if (operandType != Type::getFloatType()) {
Value* constOperand = dynamic_cast<ConstantValue *>(operand);
if (constOperand != nullptr)
operand = ConstantValue::get(static_cast<float>(constOperand->getInt()));
else
operand = builder.createIToFInst(operand);
} else if (resultType != Type::getFloatType()) {
Value* constResult = dynamic_cast<ConstantValue *>(result);
if (constResult != nullptr)
result = ConstantValue::get(static_cast<float>(constResult->getInt()));
else
result = builder.createIToFInst(result);
}
Value* constResult = dynamic_cast<ConstantValue *>(result);
Value* constOperand = dynamic_cast<ConstantValue *>(operand);
if (op->ADD() != nullptr) {
if ((constResult != nullptr) && (constOperand != nullptr))
result = ConstantValue::get(constResult->getFloat() + constOperand->getFloat());
else
result = builder.createFAddInst(result, operand);
} else {
if ((constResult != nullptr) && (constOperand != nullptr))
result = ConstantValue::get(constResult->getFloat() - constOperand->getFloat());
else
result = builder.createFSubInst(result, operand);
}
} else {
Value* constResult = dynamic_cast<ConstantValue *>(result);
Value* constOperand = dynamic_cast<ConstantValue *>(operand);
if (op->ADD() != nullptr) {
if ((constResult != nullptr) && (constOperand != nullptr))
result = ConstantValue::get(constResult->getInt() + constOperand->getInt());
else
result = builder.createAddInst(result, operand);
} else {
if ((constResult != nullptr) && (constOperand != nullptr))
result = ConstantValue::get(constResult->getInt() - constOperand->getInt());
else
result = builder.createSubInst(result, operand);
}
}
}
return result;
}
std:any SysYIRGenerator::visitRelExp(SysYParser::RelExpContext *ctx) {
Value* result = std::any_cast<Value *>(visitAddExp(ctx->addExp(0)));
for (size_t i = 1; i < ctx->addExp().size(); i++) {
auto op = ctx->relOp(i - 1);
Value* operand = std::any_cast<Value *>(visitAddExp(ctx->addExp(i)));
Type* resultType = result->getType();
Type* operandType = operand->getType();
ConstantValue* constResult = dynamic_cast<ConstantValue *>(result);
ConstantValue* constOperand = dynamic_cast<ConstantValue *>(operand);
// 常量比较
if ((constResult != nullptr) && (constOperand != nullptr)) {
auto operand1 = constResult->isFloat() ? constResult->getFloat()
: constResult->getInt();
auto operand2 = constOperand->isFloat() ? constOperand->getFloat()
: constOperand->getInt();
if (op->LT() != nullptr) result = ConstantValue::get(operand1 < operand2 ? 1 : 0);
else if (op->GT() != nullptr) result = ConstantValue::get(operand1 > operand2 ? 1 : 0);
else if (op->LE() != nullptr) result = ConstantValue::get(operand1 <= operand2 ? 1 : 0);
else if (op->GE() != nullptr) result = ConstantValue::get(operand1 >= operand2 ? 1 : 0);
else assert(false);
} else {
Type* resultType = result->getType();
Type* operandType = operand->getType();
Type* floatType = Type::getFloatType();
// 浮点数处理
if (resultType == floatType || operandType == floatType) {
if (resultType != floatType) {
if (constResult != nullptr)
result = ConstantValue::get(static_cast<float>(constResult->getInt()));
else
result = builder.createIToFInst(result);
}
if (operandType != floatType) {
if (constOperand != nullptr)
operand = ConstantValue::get(static_cast<float>(constOperand->getInt()));
else
operand = builder.createIToFInst(operand);
}
if (op->LT() != nullptr) result = builder.createFCmpLTInst(result, operand);
else if (op->GT() != nullptr) result = builder.createFCmpGTInst(result, operand);
else if (op->LE() != nullptr) result = builder.createFCmpLEInst(result, operand);
else if (op->GE() != nullptr) result = builder.createFCmpGEInst(result, operand);
else assert(false);
} else {
// 整数处理
if (op->LT() != nullptr) result = builder.createICmpLTInst(result, operand);
else if (op->GT() != nullptr) result = builder.createICmpGTInst(result, operand);
else if (op->LE() != nullptr) result = builder.createICmpLEInst(result, operand);
else if (op->GE() != nullptr) result = builder.createICmpGEInst(result, operand);
else assert(false);
}
}
}
return result;
}
void Utils::tree2Array(Type *type, ArrayValueTree *root,
const std::vector<Value *> &dims, unsigned numDims,
ValueCounter &result, IRBuilder *builder) {

View File

@ -1575,6 +1575,7 @@ class ConstantVariable : public User, public LVal {
Value* getByIndex(unsigned index) const { return initValues.getValue(index); } ///< 通过一维位置index获取值
Value* getByIndices(const std::vector<Value *> &indices) const {
int index = 0;
// 计算偏移量
for (size_t i = 0; i < indices.size(); i++) {
index = dynamic_cast<ConstantValue *>(getDim(i))->getInt() * index +
dynamic_cast<ConstantValue *>(indices[i])->getInt();

View File

@ -120,6 +120,7 @@ public:
// std::any visitUnaryOp(SysYParser::UnaryOpContext *ctx) override;
std::any visitUnExp(SysYParser::UnExpContext *ctx) override;
std::any visitFuncRParams(SysYParser::FuncRParamsContext *ctx) override;
std::any visitMulExp(SysYParser::MulExpContext *ctx) override;
std::any visitAddExp(SysYParser::AddExpContext *ctx) override;