引入了常量池优化,修改constvalue类并对IR生成修复,能够编译通过

This commit is contained in:
rain2133
2025-06-19 00:18:58 +08:00
parent 1aa785efc3
commit 1de8c0e7d7
4 changed files with 194 additions and 64 deletions

View File

@ -13,6 +13,8 @@
#include <string>
#include <utility>
#include <vector>
#include <variant>
#include <iomanip>
using namespace std;
namespace sysy {
@ -80,6 +82,15 @@ Type *Type::getFunctionType(Type *returnType,
return FunctionType::get(returnType, paramTypes);
}
Type *Type::getArrayType(Type *elementType, const vector<int> &dims) {
// forward to ArrayType
return ArrayType::get(elementType, dims);
}
ArrayType* Type::asArrayType() const {
return isArray() ? dynamic_cast<ArrayType*>(const_cast<Type*>(this)) : nullptr;
}
int Type::getSize() const {
switch (kind) {
case kInt:
@ -177,33 +188,75 @@ bool Value::isConstant() const {
return false;
}
ConstantValue *ConstantValue::get(int value) {
static std::map<int, std::unique_ptr<ConstantValue>> intConstants;
auto iter = intConstants.find(value);
if (iter != intConstants.end())
return iter->second.get();
auto constant = new ConstantValue(value);
assert(constant);
auto result = intConstants.emplace(value, constant);
return result.first->second.get();
// 定义静态常量池
std::unordered_map<ConstantValueKey, ConstantValue*, ConstantValue::ConstantValueHash> ConstantValue::constantPool;
// 常量池实现
ConstantValue* ConstantValue::get(Type* type, int32_t value) {
ConstantValueKey key = {type, ConstantValVariant(value)};
if (auto it = constantPool.find(key); it != constantPool.end()) {
return it->second;
}
ConstantValue* constant = new ConstantInt(type, value);
constantPool[key] = constant;
return constant;
}
ConstantValue *ConstantValue::get(float value) {
static std::map<float, std::unique_ptr<ConstantValue>> floatConstants;
auto iter = floatConstants.find(value);
if (iter != floatConstants.end())
return iter->second.get();
auto constant = new ConstantValue(value);
assert(constant);
auto result = floatConstants.emplace(value, constant);
return result.first->second.get();
ConstantValue* ConstantValue::get(Type* type, float value) {
ConstantValueKey key = {type, ConstantValVariant(value)};
if (auto it = constantPool.find(key); it != constantPool.end()) {
return it->second;
}
ConstantValue* constant = new ConstantFloat(type, value);
constantPool[key] = constant;
return constant;
}
void ConstantValue::print(ostream &os) const {
if (isInt())
os << getInt();
else
os << getFloat();
ConstantValue* ConstantValue::getInt32(int32_t value) {
return get(Type::getIntType(), value);
}
ConstantValue* ConstantValue::getFloat32(float value) {
return get(Type::getFloatType(), value);
}
ConstantValue* ConstantValue::getTrue() {
return get(Type::getIntType(), 1);
}
ConstantValue* ConstantValue::getFalse() {
return get(Type::getIntType(), 0);
}
void ConstantValue::print(std::ostream &os) const {
// 根据类型调用相应的打印实现
if (auto intConst = dynamic_cast<const ConstantInt*>(this)) {
intConst->print(os);
}
else if (auto floatConst = dynamic_cast<const ConstantFloat*>(this)) {
floatConst->print(os);
}
else {
os << "???"; // 未知常量类型
}
}
void ConstantInt::print(std::ostream &os) const {
os << value;
}
void ConstantFloat::print(std::ostream &os) const {
if (value == static_cast<int>(value)) {
os << value << ".0"; // 确保输出带小数点
} else {
os << std::fixed << std::setprecision(6) << value;
}
}
Argument::Argument(Type *type, BasicBlock *block, int index,

145
src/IR.h
View File

@ -11,6 +11,9 @@
#include <string>
#include <type_traits>
#include <vector>
#include <variant>
#include <unordered_map>
#include <cmath>
namespace sysy {
@ -33,6 +36,9 @@ namespace sysy {
* include `int`, `float`, `void`, and the label type representing branch
* targets
*/
class ArrayType;
class Type {
public:
enum Kind {
@ -58,9 +64,7 @@ public:
static Type *getPointerType(Type *baseType);
static Type *getFunctionType(Type *returnType,
const std::vector<Type *> &paramTypes = {});
static Type *getArrayType(Type *elementType, const std::vector<int> &dims = {}) {
return ArrayType::get(elementType, dims);
}
static Type *getArrayType(Type *elementType, const std::vector<int> &dims = {});
public:
Kind getKind() const { return kind; }
@ -73,9 +77,9 @@ public:
bool isArray() const { return kind == kArray; }
bool isIntOrFloat() const { return kind == kInt or kind == kFloat; }
int getSize() const;
ArrayType *asArrayType() const {
return isArray() ? static_cast<ArrayType*>(const_cast<Type*>(this)) : nullptr;
}
ArrayType* asArrayType() const;
template <typename T>
std::enable_if_t<std::is_base_of_v<Type, T>, T *> as() const {
return dynamic_cast<T *>(const_cast<Type *>(this));
@ -335,41 +339,114 @@ public:
* `ConstantValue`s are not defined by instructions, and do not use any other
* `Value`s. It's type is either `int` or `float`.
*/
class ConstantInt;
class ConstantFloat;
//常量池优化
using ConstantValVariant = std::variant<int32_t, float>;
using ConstantValueKey = std::pair<Type*, ConstantValVariant>;
class ConstantValue : public Value {
protected:
union {
int iScalar;
float fScalar;
};
protected:
ConstantValue(int value)
: Value(kConstant, Type::getIntType(), ""), iScalar(value) {}
ConstantValue(float value)
: Value(kConstant, Type::getFloatType(), ""), fScalar(value) {}
ConstantValue(Type* type)
: Value(kConstant, type, "") {}
public:
static ConstantValue *get(int value);
static ConstantValue *get(float value);
public:
static bool classof(const Value *value) {
struct ConstantValueHash;
struct ConstantValueEqual;
static std::unordered_map<ConstantValueKey, ConstantValue*, ConstantValueHash> constantPool;
virtual ~ConstantValue() = default;
static ConstantValue* get(Type* type, int32_t value);
static ConstantValue* get(Type* type, float value);
static bool classof(const Value* value) {
return value->getKind() == kConstant;
}
virtual int32_t getInt() const = 0;
virtual float getFloat() const = 0;
virtual bool isZero() const = 0;
virtual bool isOne() const = 0;
static ConstantValue* getInt32(int32_t value);
static ConstantValue* getFloat32(float value);
static ConstantValue* getTrue() ;
static ConstantValue* getFalse();
public:
int getInt() const {
assert(isInt());
return iScalar;
}
float getFloat() const {
assert(isFloat());
return fScalar;
}
public:
void print(std::ostream &os) const override;
}; // class ConstantValue
};
struct ConstantValue::ConstantValueHash {
std::size_t operator()(const ConstantValueKey& key) const {
std::size_t typeHash = std::hash<Type*>{}(key.first);
std::size_t valHash = 0;
if (key.first->isInt()) {
valHash = std::hash<int32_t>{}(std::get<int32_t>(key.second));
} else if (key.first->isFloat()) {
// 修复5: 确保float哈希正确
valHash = std::hash<float>{}(std::get<float>(key.second));
}
return typeHash ^ (valHash << 1);
}
};
struct ConstantValue::ConstantValueEqual {
bool operator()(const ConstantValueKey& lhs, const ConstantValueKey& rhs) const {
if (lhs.first != rhs.first) return false;
if (lhs.first->isInt()) {
return std::get<int32_t>(lhs.second) == std::get<int32_t>(rhs.second);
} else if (lhs.first->isFloat()) {
// 修复6: 使用浮点比较容差
const float eps = 1e-6;
return fabs(std::get<float>(lhs.second) - std::get<float>(rhs.second)) < eps;
}
return false;
}
};
class ConstantInt : public ConstantValue {
int32_t value;
friend class ConstantValue;
protected:
ConstantInt(Type* type, int32_t value)
: ConstantValue(type), value(value) {
assert(type->isInt() && "Invalid type for ConstantInt");
}
public:
static ConstantInt* get(Type* type, int32_t value);
int32_t getInt() const override { return value; }
float getFloat() const override { return static_cast<float>(value); }
bool isZero() const override { return value == 0; }
bool isOne() const override { return value == 1; }
void print(std::ostream& os) const override ;
};
class ConstantFloat : public ConstantValue {
float value;
friend class ConstantValue;
protected:
ConstantFloat(Type* type, float value)
: ConstantValue(type), value(value) {
assert(type->isFloat() && "Invalid type for ConstantFloat");
}
public:
static ConstantFloat* get(Type* type, float value);
int32_t getInt() const override { return static_cast<int32_t>(value); }
float getFloat() const override { return value; }
bool isZero() const override { return value == 0.0f; }
bool isOne() const override { return value == 1.0f; }
void print(std::ostream& os) const override;
};
class BasicBlock;
/*!

View File

@ -91,7 +91,7 @@ std::any LLVMIRGenerator::visitVarDecl(SysYParser::VarDeclContext* ctx) {
if (varDef->ASSIGN()) {
value = std::any_cast<std::string>(varDef->initVal()->accept(this));
if (irTmpTable.find(value) != irTmpTable.end() && isa<sysy::ConstantValue>(irTmpTable[value])) {
if (irTmpTable.find(value) != irTmpTable.end() && sysy::isa<sysy::ConstantValue>(irTmpTable[value])) {
initValue = irTmpTable[value];
}
}
@ -134,7 +134,7 @@ std::any LLVMIRGenerator::visitConstDecl(SysYParser::ConstDeclContext* ctx) {
try {
value = std::any_cast<std::string>(constDef->constInitVal()->accept(this));
if (isa<sysy::ConstantValue>(irTmpTable[value])) {
if (sysy::isa<sysy::ConstantValue>(irTmpTable[value])) {
initValue = irTmpTable[value];
}
} catch (...) {
@ -310,7 +310,7 @@ std::any LLVMIRGenerator::visitFuncDef(SysYParser::FuncDefContext* ctx) {
} else {
irStream << " ret " << currentReturnType << " 0\n";
sysy::IRBuilder builder(currentIRBlock);
builder.createReturnInst(sysy::ConstantValue::get(0));
builder.createReturnInst(sysy::ConstantValue::get(getIRType("int"),0));
}
}
irStream << "}\n";
@ -524,10 +524,10 @@ std::any LLVMIRGenerator::visitNumber(SysYParser::NumberContext* ctx) {
sysy::Value* irValue = nullptr;
if (ctx->ILITERAL()) {
value = ctx->ILITERAL()->getText();
irValue = sysy::ConstantValue::get(std::stoi(value));
irValue = sysy::ConstantValue::get(getIRType("int"), std::stoi(value));
} else if (ctx->FLITERAL()) {
value = ctx->FLITERAL()->getText();
irValue = sysy::ConstantValue::get(std::stof(value));
irValue = sysy::ConstantValue::get(getIRType("float"), std::stof(value));
} else {
value = "";
}

View File

@ -552,10 +552,10 @@ std::any SysYIRGenerator::visitNumber(SysYParser::NumberContext *ctx) {
} else if (text.find("0") == 0) {
base = 8;
}
res = ConstantValue::get((int)std::stol(text, 0, base));
res = ConstantValue::get(Type::getIntType() ,(int)std::stol(text, 0, base));
} else if (auto fLiteral = ctx->FLITERAL()) {
const auto text = fLiteral->getText();
res = ConstantValue::get((float)std::stof(text));
res = ConstantValue::get(Type::getFloatType(), (float)std::stof(text));
}
cout << "number: ";
res->print(cout);