diff --git a/src/SysYIRGenerator.cpp b/src/SysYIRGenerator.cpp index 3eccc14..94e7640 100644 --- a/src/SysYIRGenerator.cpp +++ b/src/SysYIRGenerator.cpp @@ -6,18 +6,346 @@ using namespace std; namespace sysy { -any SysYIRGenerator::visitCompUnit(SysYParser::CompUnitContext *ctx) { +/* + * @brief: visit compUnit + * @details: + * compUnit: (decl | funcDef)+; + */ +std::any SysYIRGenerator::visitCompUnit(SysYParser::CompUnitContext *ctx) { // create the IR module auto pModule = new Module(); assert(pModule); module.reset(pModule); + + SymbolTable::ModuleScope scope(symbols_table); + + // 待添加运行时库函数getint等 // generates globals and functions + visitChildren(ctx); // return the IR module return pModule; } -std::any -SysYIRGenerator::visitFuncRParams(SysYParser::FuncRParamsContext *ctx) { + +/* + * @brief: visit decl + * @details: + * decl: constDecl | varDecl; + * constDecl: CONST bType constDef (COMMA constDef)* SEMI; + * varDecl: bType varDef (COMMA varDef)* SEMI; + * constDecl and varDecl shares similar syntax structure + * we consider them together? not sure + */ +std::any SysYIRGenerator::visitDecl(SysYParser::DeclContext *ctx) { + if(ctx->constDecl()){ + return visitConstDecl(ctx->constDecl()); + }else if(ctx->varDecl()){ + return visitVarDecl(ctx->varDecl()); + } + return nullptr; +} + +/* + * @brief: visit constdecl + * @details: + * constDecl: CONST bType constDef (COMMA constDef)* SEMI; + */ +std::any SysYIRGenerator::visitConstDecl(SysYParser::ConstDeclContext *ctx) { + auto type = Type::getPointerType(any_cast(ctx->bType()->accept(this))); + if(symbols_table.isModuleScope()) + visitConstGlobalDecl(ctx, type); + else + visitConstLocalDecl(ctx, type); + return std::any(); +} + +/* + * @brief: visit btype + * @details: + * bType: INT | FLOAT; + */ +std::any SysYIRGenerator::visitBType(SysYParser::BTypeContext *ctx) { + return ctx->INT() ? Type::getIntType() : Type::getFloatType(); +} + +// std::any visitConstDef(SysYParser::ConstDefContext *ctx); + +std::any SysYIRGenerator::visitConstGlobalDecl(SysYParser::ConstDeclContext *ctx, Type* type) { + std::vector values; + for (auto constDef : ctx->constDef()) { + + auto name = constDef->Ident()->getText(); + // get its dimensions + vector dims; + for (auto dim : constDef->constExp()) + dims.push_back(any_cast(dim->accept(this))); + + if (dims.size() == 0) { + auto init = constDef->ASSIGN() ? any_cast((constDef->constInitVal()->constExp()->accept(this))) + : nullptr; + if (init && isa(init)){ + Type *btype = type->as()->getBaseType(); + if (btype->isInt() && init->getType()->isFloat()) + init = ConstantValue::get((int)dynamic_cast(init)->getFloat()); + else if (btype->isFloat() && init->getType()->isInt()) + init = ConstantValue::get((float)dynamic_cast(init)->getInt()); + } + + auto global_value = module->createGlobalValue(name, type, dims, init); + + symbols_table.insert(name, global_value); + values.push_back(global_value); + } + else{ + auto init = constDef->ASSIGN() ? any_cast(dims[0]) + : nullptr; + auto global_value = module->createGlobalValue(name, type, dims, init); + if (constDef->ASSIGN()) { + d = 0; + n = 0; + path.clear(); + path = vector(dims.size(), 0); + isalloca = false; + current_type = global_value->getType()->as()->getBaseType(); + current_global = global_value; + numdims = global_value->getNumDims(); + for (auto init : constDef->constInitVal()->constInitVal()) + init->accept(this); + // visitConstInitValue(init); + } + symbols_table.insert(name, global_value); + values.push_back(global_value); + } + } + return values; +} + +std::any SysYIRGenerator::visitVarGlobalDecl(SysYParser::VarDeclContext *ctx, Type* type){ + std::vector values; + for (auto varDef : ctx->varDef()) { + + auto name = varDef->Ident()->getText(); + // get its dimensions + vector dims; + for (auto dim : varDef->constExp()) + dims.push_back(any_cast(dim->accept(this))); + + if (dims.size() == 0) { + auto init = varDef->ASSIGN() ? any_cast((varDef->initVal()->exp()->accept(this))) + : nullptr; + if (init && isa(init)){ + Type *btype = type->as()->getBaseType(); + if (btype->isInt() && init->getType()->isFloat()) + init = ConstantValue::get((int)dynamic_cast(init)->getFloat()); + else if (btype->isFloat() && init->getType()->isInt()) + init = ConstantValue::get((float)dynamic_cast(init)->getInt()); + } + + auto global_value = module->createGlobalValue(name, type, dims, init); + + symbols_table.insert(name, global_value); + values.push_back(global_value); + } + else{ + auto init = varDef->ASSIGN() ? any_cast(dims[0]) + : nullptr; + auto global_value = module->createGlobalValue(name, type, dims, init); + if (varDef->ASSIGN()) { + d = 0; + n = 0; + path.clear(); + path = vector(dims.size(), 0); + isalloca = false; + current_type = global_value->getType()->as()->getBaseType(); + current_global = global_value; + numdims = global_value->getNumDims(); + for (auto init : varDef->initVal()->initVal()) + init->accept(this); + // visitInitValue(init); + } + symbols_table.insert(name, global_value); + values.push_back(global_value); + } + } + return values; +} + +std::any SysYIRGenerator::visitConstLocalDecl(SysYParser::ConstDeclContext *ctx, Type* type){ + std::vector values; + // handle variables + for (auto constDef : ctx->constDef()) { + + auto name = constDef->Ident()->getText(); + vector dims; + for (auto dim : constDef->constExp()) + dims.push_back(any_cast(dim->accept(this))); + auto alloca = builder.createAllocaInst(type, dims, name); + symbols_table.insert(name, alloca); + + if (constDef->ASSIGN()) { + if (alloca->getNumDims() == 0) { + + auto value = any_cast(constDef->constInitVal()->constExp()->accept(this)); + + if (isa(value)) { + if (ctx->bType()->INT() && dynamic_cast(value)->isFloat()) + value = ConstantValue::get((int)dynamic_cast(value)->getFloat()); + else if (ctx->bType()->FLOAT() && dynamic_cast(value)->isInt()) + value = ConstantValue::get((float)dynamic_cast(value)->getInt()); + } + else if (alloca->getType()->as()->getBaseType()->isInt() && value->getType()->isFloat()) + value = builder.createFtoIInst(value); + else if (alloca->getType()->as()->getBaseType()->isFloat() && value->getType()->isInt()) + value = builder.createIToFInst(value); + + auto store = builder.createStoreInst(value, alloca); + } + else{ + d = 0; + n = 0; + path.clear(); + path = vector(alloca->getNumDims(), 0); + isalloca = true; + current_alloca = alloca; + current_type = alloca->getType()->as()->getBaseType(); + numdims = alloca->getNumDims(); + for (auto init : constDef->constInitVal()->constInitVal()) + init->accept(this); + } + } + + values.push_back(alloca); + } + return values; +} + +std::any SysYIRGenerator::visitVarLocalDecl(SysYParser::VarDeclContext *ctx, Type* type){ + std::vector values; + for (auto varDef : ctx->varDef()) { + + auto name = varDef->Ident()->getText(); + vector dims; + for (auto dim : varDef->constExp()) + dims.push_back(any_cast(dim->accept(this))); + auto alloca = builder.createAllocaInst(type, dims, name); + symbols_table.insert(name, alloca); + + if (varDef->ASSIGN()) { + if (alloca->getNumDims() == 0) { + + auto value = any_cast(varDef->initVal()->exp()->accept(this)); + + if (isa(value)) { + if (ctx->bType()->INT() && dynamic_cast(value)->isFloat()) + value = ConstantValue::get((int)dynamic_cast(value)->getFloat()); + else if (ctx->bType()->FLOAT() && dynamic_cast(value)->isInt()) + value = ConstantValue::get((float)dynamic_cast(value)->getInt()); + } + else if (alloca->getType()->as()->getBaseType()->isInt() && value->getType()->isFloat()) + value = builder.createFtoIInst(value); + else if (alloca->getType()->as()->getBaseType()->isFloat() && value->getType()->isInt()) + value = builder.createIToFInst(value); + + auto store = builder.createStoreInst(value, alloca); + } + else{ + d = 0; + n = 0; + path.clear(); + path = vector(alloca->getNumDims(), 0); + isalloca = true; + current_alloca = alloca; + current_type = alloca->getType()->as()->getBaseType(); + numdims = alloca->getNumDims(); + for (auto init : varDef->initVal()->initVal()) + init->accept(this); + } + } + + values.push_back(alloca); + } + return values; +} + + + +/* + * @brief: visit constInitVal + * @details: + * constInitVal: constExp + * | LBRACE (constInitVal (COMMA constInitVal)*)? RBRACE; + */ +std::any visitConstInitVal(SysYParser::ConstInitValContext *ctx){ + +} + + +/* + * @brief: visit function type + * @details: + * funcType: VOID | INT | FLOAT; + */ +std::any SysYIRGenerator::visitFuncType(SysYParser::FuncTypeContext* ctx){ + return ctx->INT() ? Type::getIntType() : (ctx->FLOAT() ? Type::getFloatType() : Type::getVoidType()); +} + +/* + * @brief: visit function define + * @details: + * funcDef: funcType Ident LPAREN funcFParams? RPAREN blockStmt; + * funcFParams: funcFParam (COMMA funcFParam)*; + * funcFParam: bType Ident (LBRACK RBRACK (LBRACK exp RBRACK)*)?; + * entry -> next -> others -> exit + * entry: allocas, br + * next: retval, params, br + * other: blockStmt init block + * exit: load retval, ret + */ +std::any SysYIRGenerator::visitFuncDef(SysYParser::FuncDefContext* ctx){ + auto funcName = ctx->Ident()->getText(); + auto funcParams = ctx->funcFParams()->funcFParam(); + Type* returnType = any_cast(ctx->funcType()->accept(this)); + + vector paramTypes; + vector paramNames; + + for(auto funcParam:funcParams){ + Type* paramType = any_cast(funcParam->bType()->accept(this)); + paramTypes.push_back(paramType); + paramNames.push_back(funcParam->Ident()->getText()); + } + + auto funcType = FunctionType::get(returnType, paramTypes); + auto function = module->createFunction(funcName, funcType); + + auto entry = function->getEntryBlock(); + for(size_t i = 0; i < paramTypes.size(); i++) + entry->createArgument(paramTypes[i], paramNames[i]); + builder.setPosition(entry, entry->end()); + ctx->blockStmt()->accept(this); + + return function; +} + +/* + * @brief: visit blockStmt + * @details: + * blockStmt: LBRACE blockItem* RBRACE; + * blockItem: decl | stmt; + */ +std::any SysYIRGenerator::visitBlockStmt(SysYParser::BlockStmtContext* ctx){ + + SymbolTable::BlockScope scope(symbols_table); + + for (auto item : ctx->blockItem()) + item->accept(this); + + builder.getBasicBlock(); + return std::any(); +} + + +std::any SysYIRGenerator::visitFuncRParams(SysYParser::FuncRParamsContext *ctx) { return visitChildren(ctx); } std::any SysYIRGenerator::visitNumber(SysYParser::NumberContext *ctx) { diff --git a/src/SysYIRGenerator.h b/src/SysYIRGenerator.h index bfc1865..b0419a6 100644 --- a/src/SysYIRGenerator.h +++ b/src/SysYIRGenerator.h @@ -5,13 +5,86 @@ #include "SysYBaseVisitor.h" #include "SysYParser.h" #include +#include +#include namespace sysy { +class SymbolTable{ +private: + enum Kind + { + kModule, + kFunction, + kBlock, + }; + + std::forward_list>> Scopes; + +public: + struct ModuleScope { + SymbolTable& tables_ref; + ModuleScope(SymbolTable& tables) : tables_ref(tables) { + tables.enter(kModule); + } + ~ModuleScope() { tables_ref.exit(); } + }; + struct FunctionScope { + SymbolTable& tables_ref; + FunctionScope(SymbolTable& tables) : tables_ref(tables) { + tables.enter(kFunction); + } + ~FunctionScope() { tables_ref.exit(); } + }; + struct BlockScope { + SymbolTable& tables_ref; + BlockScope(SymbolTable& tables) : tables_ref(tables) { + tables.enter(kBlock); + } + ~BlockScope() { tables_ref.exit(); } + }; + + SymbolTable() = default; + + bool isModuleScope() const { return Scopes.front().first == kModule; } + bool isFunctionScope() const { return Scopes.front().first == kFunction; } + bool isBlockScope() const { return Scopes.front().first == kBlock; } + Value *lookup(const std::string &name) const { + for (auto &scope : Scopes) { + auto iter = scope.second.find(name); + if (iter != scope.second.end()) + return iter->second; + } + return nullptr; + } + auto insert(const std::string &name, Value *value) { + assert(not Scopes.empty()); + return Scopes.front().second.emplace(name, value); + } +private: + void enter(Kind kind) { + Scopes.emplace_front(); + Scopes.front().first = kind; + } + void exit() { + Scopes.pop_front(); + } + +}; + class SysYIRGenerator : public SysYBaseVisitor { private: std::unique_ptr module; IRBuilder builder; + SymbolTable symbols_table; + //array init use variables + int d = 0, n = 0; + vector path; + bool isalloca; + AllocaInst *current_alloca; + GlobalValue *current_global; + Type *current_type; + int numdims = 0; public: SysYIRGenerator() = default; @@ -21,9 +94,50 @@ public: public: std::any visitCompUnit(SysYParser::CompUnitContext *ctx) override; + + std::any visitDecl(SysYParser::DeclContext *ctx) override; + + std::any visitConstDecl(SysYParser::ConstDeclContext *ctx) override; + + std::any visitBType(SysYParser::BTypeContext *ctx) override; + + std::any visitConstDef(SysYParser::ConstDefContext *ctx) override; + + std::any visitConstInitVal(SysYParser::ConstInitValContext *ctx) override; + + std::any visitFuncType(SysYParser::FuncTypeContext* ctx) override; + + std::any visitFuncDef(SysYParser::FuncDefContext* ctx) override; + + + std::any visitVarDecl(SysYParser::VarDeclContext *ctx) override; + + std::any visitVarDef(SysYParser::VarDefContext *ctx, Type* btype); + + std::any visitFuncFParam(SysYParser::FuncFParamContext *ctx) override; + + std::any visitFuncFParams(SysYParser::FuncFParamsContext *ctx) override; + + std::any visitBlockStmt(SysYParser::BlockStmtContext* ctx) override; + + std::any visitStmt(SysYParser::StmtContext *ctx) override; + std::any visitFuncRParams(SysYParser::FuncRParamsContext *ctx) override; + std::any visitNumber(SysYParser::NumberContext *ctx) override; + std::any visitString(SysYParser::StringContext *ctx) override; + +private: + std::any visitConstGlobalDecl(SysYParser::ConstDeclContext *ctx, Type* type); + std::any visitVarGlobalDecl(SysYParser::VarDeclContext *ctx, Type* type); + std::any visitConstLocalDecl(SysYParser::ConstDeclContext *ctx, Type* type); + std::any visitVarLocalDecl(SysYParser::VarDeclContext *ctx, Type* type); + Type *getArithmeticResultType(Type *lhs, Type *rhs) { + assert(lhs->isIntOrFloat() and rhs->isIntOrFloat()); + return lhs == rhs ? lhs : Type::getFloatType(); + } + }; // class SysYIRGenerator } // namespace sysy \ No newline at end of file