diff --git a/src/backend/RISCv64/RISCv64Backend.cpp b/src/backend/RISCv64/RISCv64Backend.cpp index 2797eb7..2edb21d 100644 --- a/src/backend/RISCv64/RISCv64Backend.cpp +++ b/src/backend/RISCv64/RISCv64Backend.cpp @@ -12,6 +12,39 @@ std::string RISCv64CodeGen::code_gen() { return module_gen(); } +unsigned RISCv64CodeGen::getTypeSizeInBytes(Type* type) { + if (!type) { + assert(false && "Cannot get size of a null type."); + return 0; + } + + switch (type->getKind()) { + // 对于SysY语言,基本类型int和float都占用4字节 + case Type::kInt: + case Type::kFloat: + return 4; + + // 指针类型在RISC-V 64位架构下占用8字节 + // 虽然SysY没有'int*'语法,但数组变量在IR层面本身就是指针类型 + case Type::kPointer: + return 8; + + // 数组类型的总大小 = 元素数量 * 单个元素的大小 + case Type::kArray: { + auto arrayType = type->as(); + // 递归调用以计算元素大小 + return arrayType->getNumElements() * getTypeSizeInBytes(arrayType->getElementType()); + } + + // 其他类型,如Void, Label等不占用栈空间,或者不应该出现在这里 + default: + // 如果遇到未处理的类型,触发断言,方便调试 + // assert(false && "Unsupported type for size calculation."); + return 0; // 对于像Label或Void这样的类型,返回0是合理的 + } +} + + void printInitializer(std::stringstream& ss, const ValueCounter& init_values) { for (size_t i = 0; i < init_values.getValues().size(); ++i) { auto val = init_values.getValues()[i]; @@ -39,18 +72,36 @@ std::string RISCv64CodeGen::module_gen() { for (const auto& global_ptr : module->getGlobals()) { GlobalValue* global = global_ptr.get(); + + // [核心修改] 使用更健壮的逻辑来判断是否为大型零初始化数组 + bool is_all_zeros = true; const auto& init_values = global->getInitValues(); - // 判断是否为大型零初始化数组,以便放入.bss段 - bool is_large_zero_array = false; - if (init_values.getValues().size() == 1) { - if (auto const_val = dynamic_cast(init_values.getValues()[0])) { - if (const_val->isInt() && const_val->getInt() == 0 && init_values.getNumbers()[0] > 16) { - is_large_zero_array = true; + // 检查初始化值是否全部为0 + if (init_values.getValues().empty()) { + // 如果 ValueCounter 为空,GlobalValue 的构造函数会确保它是零初始化的 + is_all_zeros = true; + } else { + for (auto val : init_values.getValues()) { + if (auto const_val = dynamic_cast(val)) { + if (!const_val->isZero()) { + is_all_zeros = false; + break; + } + } else { + // 如果初始值包含非常量(例如,另一个全局变量的地址),则不认为是纯零初始化 + is_all_zeros = false; + break; } } } + // 使用 getTypeSizeInBytes 检查总大小是否超过阈值 (16个整数 = 64字节) + Type* allocated_type = global->getType()->as()->getBaseType(); + unsigned total_size = getTypeSizeInBytes(allocated_type); + + bool is_large_zero_array = is_all_zeros && (total_size > 64); + if (is_large_zero_array) { bss_globals.push_back(global); } else { @@ -58,12 +109,12 @@ std::string RISCv64CodeGen::module_gen() { } } - // --- 步骤2:生成 .bss 段的代码 (这部分不变) --- + // --- 步骤2:生成 .bss 段的代码 --- if (!bss_globals.empty()) { ss << ".bss\n"; for (GlobalValue* global : bss_globals) { - unsigned count = global->getInitValues().getNumbers()[0]; - unsigned total_size = count * 4; // 假设元素都是4字节 + Type* allocated_type = global->getType()->as()->getBaseType(); + unsigned total_size = getTypeSizeInBytes(allocated_type); ss << " .align 3\n"; ss << ".globl " << global->getName() << "\n"; @@ -74,33 +125,45 @@ std::string RISCv64CodeGen::module_gen() { } } - // --- [修改] 步骤3:生成 .data 段的代码 --- - // 我们需要检查 data_globals 和 常量列表是否都为空 + // --- 步骤3:生成 .data 段的代码 --- if (!data_globals.empty() || !module->getConsts().empty()) { ss << ".data\n"; - // a. 先处理普通的全局变量 (GlobalValue) + // a. 处理普通的全局变量 (GlobalValue) for (GlobalValue* global : data_globals) { + Type* allocated_type = global->getType()->as()->getBaseType(); + unsigned total_size = getTypeSizeInBytes(allocated_type); + + ss << " .align 3\n"; ss << ".globl " << global->getName() << "\n"; + ss << ".type " << global->getName() << ", @object\n"; + ss << ".size " << global->getName() << ", " << total_size << "\n"; ss << global->getName() << ":\n"; printInitializer(ss, global->getInitValues()); } - // b. [新增] 再处理全局常量 (ConstantVariable) + // b. 处理全局常量 (ConstantVariable) for (const auto& const_ptr : module->getConsts()) { ConstantVariable* cnst = const_ptr.get(); + Type* allocated_type = cnst->getType()->as()->getBaseType(); + unsigned total_size = getTypeSizeInBytes(allocated_type); + + ss << " .align 3\n"; ss << ".globl " << cnst->getName() << "\n"; + ss << ".type " << cnst->getName() << ", @object\n"; + ss << ".size " << cnst->getName() << ", " << total_size << "\n"; ss << cnst->getName() << ":\n"; printInitializer(ss, cnst->getInitValues()); } } - // --- 处理函数 (.text段) 的逻辑保持不变 --- + // --- 步骤4:处理函数 (.text段) 的逻辑 --- if (!module->getFunctions().empty()) { ss << ".text\n"; for (const auto& func_pair : module->getFunctions()) { - if (func_pair.second.get()) { + if (func_pair.second.get() && !func_pair.second->getBasicBlocks().empty()) { ss << function_gen(func_pair.second.get()); + if (DEBUG) std::cerr << "Function: " << func_pair.first << " generated.\n"; } } } diff --git a/src/include/backend/RISCv64/RISCv64Backend.h b/src/include/backend/RISCv64/RISCv64Backend.h index 403d586..9e179d9 100644 --- a/src/include/backend/RISCv64/RISCv64Backend.h +++ b/src/include/backend/RISCv64/RISCv64Backend.h @@ -22,6 +22,10 @@ private: // 函数级代码生成 (实现新的流水线) std::string function_gen(Function* func); + + // 私有辅助函数,用于根据类型计算其占用的字节数。 + unsigned getTypeSizeInBytes(Type* type); + Module* module; }; diff --git a/src/midend/SysYIRGenerator.cpp b/src/midend/SysYIRGenerator.cpp index cd0ac3c..97b25b5 100644 --- a/src/midend/SysYIRGenerator.cpp +++ b/src/midend/SysYIRGenerator.cpp @@ -653,7 +653,44 @@ std::any SysYIRGenerator::visitConstDecl(SysYParser::ConstDeclContext *ctx) { Value *currentValue = counterValues[k]; unsigned currentRepeatNum = counterNumbers[k]; + // 检查是否是0,并且重复次数足够大(例如 >16),才用 memset + if (ConstantInteger *constInt = dynamic_cast(currentValue)) { + if (constInt->getInt() == 0 && currentRepeatNum >= 16) { // 阈值可调整(如16、32等) + // 计算 memset 的起始地址(基于当前线性偏移量) + std::vector memsetStartIndices; + int tempLinearIndex = linearIndexOffset; + + // 将线性索引转换为多维索引 + for (int dimIdx = dimSizes.size() - 1; dimIdx >= 0; --dimIdx) { + memsetStartIndices.insert(memsetStartIndices.begin(), + ConstantInteger::get(static_cast(tempLinearIndex % dimSizes[dimIdx]))); + tempLinearIndex /= dimSizes[dimIdx]; + } + + // 构造 GEP 计算 memset 的起始地址 + std::vector gepIndicesForMemset; + gepIndicesForMemset.push_back(ConstantInteger::get(0)); // 跳过 alloca 类型 + gepIndicesForMemset.insert(gepIndicesForMemset.end(), memsetStartIndices.begin(), + memsetStartIndices.end()); + + Value *memsetPtr = builder.createGetElementPtrInst(alloca, gepIndicesForMemset); + + // 计算 memset 的字节数 = 元素个数 × 元素大小 + Type *elementType = type;; + uint64_t elementSize = elementType->getSize(); + Value *size = ConstantInteger::get(currentRepeatNum * elementSize); + + // 生成 memset 指令(假设你的 IRBuilder 有 createMemset 方法) + builder.createMemsetInst(memsetPtr, ConstantInteger::get(0), size, ConstantInteger::get(0)); + + // 跳过这些已处理的0 + linearIndexOffset += currentRepeatNum; + continue; // 直接进入下一次循环 + } + } + for (unsigned i = 0; i < currentRepeatNum; ++i) { + // 对于非零值,生成对应的 store 指令 std::vector currentIndices; int tempLinearIndex = linearIndexOffset + i; // 使用偏移量和当前重复次数内的索引 @@ -761,39 +798,73 @@ std::any SysYIRGenerator::visitVarDecl(SysYParser::VarDeclContext *ctx) { ConstantInteger::get(0)); } else { - + int linearIndexOffset = 0; // 用于追踪当前处理的线性索引的偏移量 for (int k = 0; k < counterValues.size(); ++k) { - // 当前 Value 的值和重复次数 - Value* currentValue = counterValues[k]; - unsigned currentRepeatNum = counterNumbers[k]; + // 当前 Value 的值和重复次数 + Value *currentValue = counterValues[k]; + unsigned currentRepeatNum = counterNumbers[k]; + // 检查是否是0,并且重复次数足够大(例如 >16),才用 memset + if (ConstantInteger *constInt = dynamic_cast(currentValue)) { + if (constInt->getInt() == 0 && currentRepeatNum >= 16) { // 阈值可调整(如16、32等) + // 计算 memset 的起始地址(基于当前线性偏移量) + std::vector memsetStartIndices; + int tempLinearIndex = linearIndexOffset; - for (unsigned i = 0; i < currentRepeatNum; ++i) { - std::vector currentIndices; - int tempLinearIndex = linearIndexOffset + i; // 使用偏移量和当前重复次数内的索引 + // 将线性索引转换为多维索引 + for (int dimIdx = dimSizes.size() - 1; dimIdx >= 0; --dimIdx) { + memsetStartIndices.insert(memsetStartIndices.begin(), + ConstantInteger::get(static_cast(tempLinearIndex % dimSizes[dimIdx]))); + tempLinearIndex /= dimSizes[dimIdx]; + } - // 将线性索引转换为多维索引 - for (int dimIdx = dimSizes.size() - 1; dimIdx >= 0; --dimIdx) { - currentIndices.insert(currentIndices.begin(), - ConstantInteger::get(static_cast(tempLinearIndex % dimSizes[dimIdx]))); - tempLinearIndex /= dimSizes[dimIdx]; - } - - // 对于局部数组,alloca 本身就是 GEP 的基指针。 - // GEP 的第一个索引必须是 0,用于“步过”整个数组。 - std::vector gepIndicesForInit; - gepIndicesForInit.push_back(ConstantInteger::get(0)); - gepIndicesForInit.insert(gepIndicesForInit.end(), currentIndices.begin(), currentIndices.end()); - - // 计算元素的地址 - Value* elementAddress = getGEPAddressInst(alloca, gepIndicesForInit); - // 生成 store 指令 - builder.createStoreInst(currentValue, elementAddress); + // 构造 GEP 计算 memset 的起始地址 + std::vector gepIndicesForMemset; + gepIndicesForMemset.push_back(ConstantInteger::get(0)); // 跳过 alloca 类型 + gepIndicesForMemset.insert(gepIndicesForMemset.end(), memsetStartIndices.begin(), + memsetStartIndices.end()); + + Value *memsetPtr = builder.createGetElementPtrInst(alloca, gepIndicesForMemset); + + // 计算 memset 的字节数 = 元素个数 × 元素大小 + Type *elementType = type; + ; + uint64_t elementSize = elementType->getSize(); + Value *size = ConstantInteger::get(currentRepeatNum * elementSize); + + // 生成 memset 指令(假设你的 IRBuilder 有 createMemset 方法) + builder.createMemsetInst(memsetPtr, ConstantInteger::get(0), size, ConstantInteger::get(0)); + + // 跳过这些已处理的0 + linearIndexOffset += currentRepeatNum; + continue; // 直接进入下一次循环 } - // 更新线性索引偏移量,以便下一次迭代从正确的位置开始 - linearIndexOffset += currentRepeatNum; - } + } + for (unsigned i = 0; i < currentRepeatNum; ++i) { + std::vector currentIndices; + int tempLinearIndex = linearIndexOffset + i; // 使用偏移量和当前重复次数内的索引 + // 将线性索引转换为多维索引 + for (int dimIdx = dimSizes.size() - 1; dimIdx >= 0; --dimIdx) { + currentIndices.insert(currentIndices.begin(), + ConstantInteger::get(static_cast(tempLinearIndex % dimSizes[dimIdx]))); + tempLinearIndex /= dimSizes[dimIdx]; + } + + // 对于局部数组,alloca 本身就是 GEP 的基指针。 + // GEP 的第一个索引必须是 0,用于“步过”整个数组。 + std::vector gepIndicesForInit; + gepIndicesForInit.push_back(ConstantInteger::get(0)); + gepIndicesForInit.insert(gepIndicesForInit.end(), currentIndices.begin(), currentIndices.end()); + + // 计算元素的地址 + Value *elementAddress = getGEPAddressInst(alloca, gepIndicesForInit); + // 生成 store 指令 + builder.createStoreInst(currentValue, elementAddress); + } + // 更新线性索引偏移量,以便下一次迭代从正确的位置开始 + linearIndexOffset += currentRepeatNum; + } } } }