引入了常量池优化,修改constvalue类并对IR生成修复,能够编译通过
This commit is contained in:
99
src/IR.cpp
99
src/IR.cpp
@ -13,6 +13,8 @@
|
||||
#include <string>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
#include <variant>
|
||||
#include <iomanip>
|
||||
using namespace std;
|
||||
|
||||
namespace sysy {
|
||||
@ -80,6 +82,15 @@ Type *Type::getFunctionType(Type *returnType,
|
||||
return FunctionType::get(returnType, paramTypes);
|
||||
}
|
||||
|
||||
Type *Type::getArrayType(Type *elementType, const vector<int> &dims) {
|
||||
// forward to ArrayType
|
||||
return ArrayType::get(elementType, dims);
|
||||
}
|
||||
|
||||
ArrayType* Type::asArrayType() const {
|
||||
return isArray() ? dynamic_cast<ArrayType*>(const_cast<Type*>(this)) : nullptr;
|
||||
}
|
||||
|
||||
int Type::getSize() const {
|
||||
switch (kind) {
|
||||
case kInt:
|
||||
@ -177,33 +188,75 @@ bool Value::isConstant() const {
|
||||
return false;
|
||||
}
|
||||
|
||||
ConstantValue *ConstantValue::get(int value) {
|
||||
static std::map<int, std::unique_ptr<ConstantValue>> intConstants;
|
||||
auto iter = intConstants.find(value);
|
||||
if (iter != intConstants.end())
|
||||
return iter->second.get();
|
||||
auto constant = new ConstantValue(value);
|
||||
assert(constant);
|
||||
auto result = intConstants.emplace(value, constant);
|
||||
return result.first->second.get();
|
||||
|
||||
// 定义静态常量池
|
||||
std::unordered_map<ConstantValueKey, ConstantValue*, ConstantValue::ConstantValueHash> ConstantValue::constantPool;
|
||||
|
||||
// 常量池实现
|
||||
ConstantValue* ConstantValue::get(Type* type, int32_t value) {
|
||||
ConstantValueKey key = {type, ConstantValVariant(value)};
|
||||
|
||||
if (auto it = constantPool.find(key); it != constantPool.end()) {
|
||||
return it->second;
|
||||
}
|
||||
|
||||
ConstantValue* constant = new ConstantInt(type, value);
|
||||
constantPool[key] = constant;
|
||||
return constant;
|
||||
}
|
||||
|
||||
ConstantValue *ConstantValue::get(float value) {
|
||||
static std::map<float, std::unique_ptr<ConstantValue>> floatConstants;
|
||||
auto iter = floatConstants.find(value);
|
||||
if (iter != floatConstants.end())
|
||||
return iter->second.get();
|
||||
auto constant = new ConstantValue(value);
|
||||
assert(constant);
|
||||
auto result = floatConstants.emplace(value, constant);
|
||||
return result.first->second.get();
|
||||
ConstantValue* ConstantValue::get(Type* type, float value) {
|
||||
ConstantValueKey key = {type, ConstantValVariant(value)};
|
||||
|
||||
if (auto it = constantPool.find(key); it != constantPool.end()) {
|
||||
return it->second;
|
||||
}
|
||||
|
||||
ConstantValue* constant = new ConstantFloat(type, value);
|
||||
constantPool[key] = constant;
|
||||
return constant;
|
||||
}
|
||||
|
||||
void ConstantValue::print(ostream &os) const {
|
||||
if (isInt())
|
||||
os << getInt();
|
||||
else
|
||||
os << getFloat();
|
||||
ConstantValue* ConstantValue::getInt32(int32_t value) {
|
||||
return get(Type::getIntType(), value);
|
||||
}
|
||||
|
||||
ConstantValue* ConstantValue::getFloat32(float value) {
|
||||
return get(Type::getFloatType(), value);
|
||||
}
|
||||
|
||||
ConstantValue* ConstantValue::getTrue() {
|
||||
return get(Type::getIntType(), 1);
|
||||
}
|
||||
|
||||
ConstantValue* ConstantValue::getFalse() {
|
||||
return get(Type::getIntType(), 0);
|
||||
}
|
||||
|
||||
|
||||
|
||||
void ConstantValue::print(std::ostream &os) const {
|
||||
// 根据类型调用相应的打印实现
|
||||
if (auto intConst = dynamic_cast<const ConstantInt*>(this)) {
|
||||
intConst->print(os);
|
||||
}
|
||||
else if (auto floatConst = dynamic_cast<const ConstantFloat*>(this)) {
|
||||
floatConst->print(os);
|
||||
}
|
||||
else {
|
||||
os << "???"; // 未知常量类型
|
||||
}
|
||||
}
|
||||
|
||||
void ConstantInt::print(std::ostream &os) const {
|
||||
os << value;
|
||||
}
|
||||
void ConstantFloat::print(std::ostream &os) const {
|
||||
if (value == static_cast<int>(value)) {
|
||||
os << value << ".0"; // 确保输出带小数点
|
||||
} else {
|
||||
os << std::fixed << std::setprecision(6) << value;
|
||||
}
|
||||
}
|
||||
|
||||
Argument::Argument(Type *type, BasicBlock *block, int index,
|
||||
|
||||
145
src/IR.h
145
src/IR.h
@ -11,6 +11,9 @@
|
||||
#include <string>
|
||||
#include <type_traits>
|
||||
#include <vector>
|
||||
#include <variant>
|
||||
#include <unordered_map>
|
||||
#include <cmath>
|
||||
|
||||
namespace sysy {
|
||||
|
||||
@ -33,6 +36,9 @@ namespace sysy {
|
||||
* include `int`, `float`, `void`, and the label type representing branch
|
||||
* targets
|
||||
*/
|
||||
|
||||
class ArrayType;
|
||||
|
||||
class Type {
|
||||
public:
|
||||
enum Kind {
|
||||
@ -58,9 +64,7 @@ public:
|
||||
static Type *getPointerType(Type *baseType);
|
||||
static Type *getFunctionType(Type *returnType,
|
||||
const std::vector<Type *> ¶mTypes = {});
|
||||
static Type *getArrayType(Type *elementType, const std::vector<int> &dims = {}) {
|
||||
return ArrayType::get(elementType, dims);
|
||||
}
|
||||
static Type *getArrayType(Type *elementType, const std::vector<int> &dims = {});
|
||||
|
||||
public:
|
||||
Kind getKind() const { return kind; }
|
||||
@ -73,9 +77,9 @@ public:
|
||||
bool isArray() const { return kind == kArray; }
|
||||
bool isIntOrFloat() const { return kind == kInt or kind == kFloat; }
|
||||
int getSize() const;
|
||||
ArrayType *asArrayType() const {
|
||||
return isArray() ? static_cast<ArrayType*>(const_cast<Type*>(this)) : nullptr;
|
||||
}
|
||||
|
||||
ArrayType* asArrayType() const;
|
||||
|
||||
template <typename T>
|
||||
std::enable_if_t<std::is_base_of_v<Type, T>, T *> as() const {
|
||||
return dynamic_cast<T *>(const_cast<Type *>(this));
|
||||
@ -335,41 +339,114 @@ public:
|
||||
* `ConstantValue`s are not defined by instructions, and do not use any other
|
||||
* `Value`s. It's type is either `int` or `float`.
|
||||
*/
|
||||
|
||||
class ConstantInt;
|
||||
class ConstantFloat;
|
||||
//常量池优化
|
||||
|
||||
using ConstantValVariant = std::variant<int32_t, float>;
|
||||
using ConstantValueKey = std::pair<Type*, ConstantValVariant>;
|
||||
|
||||
class ConstantValue : public Value {
|
||||
protected:
|
||||
union {
|
||||
int iScalar;
|
||||
float fScalar;
|
||||
};
|
||||
|
||||
protected:
|
||||
ConstantValue(int value)
|
||||
: Value(kConstant, Type::getIntType(), ""), iScalar(value) {}
|
||||
ConstantValue(float value)
|
||||
: Value(kConstant, Type::getFloatType(), ""), fScalar(value) {}
|
||||
|
||||
ConstantValue(Type* type)
|
||||
: Value(kConstant, type, "") {}
|
||||
public:
|
||||
static ConstantValue *get(int value);
|
||||
static ConstantValue *get(float value);
|
||||
|
||||
public:
|
||||
static bool classof(const Value *value) {
|
||||
struct ConstantValueHash;
|
||||
struct ConstantValueEqual;
|
||||
|
||||
static std::unordered_map<ConstantValueKey, ConstantValue*, ConstantValueHash> constantPool;
|
||||
|
||||
virtual ~ConstantValue() = default;
|
||||
|
||||
static ConstantValue* get(Type* type, int32_t value);
|
||||
static ConstantValue* get(Type* type, float value);
|
||||
|
||||
static bool classof(const Value* value) {
|
||||
return value->getKind() == kConstant;
|
||||
}
|
||||
|
||||
virtual int32_t getInt() const = 0;
|
||||
virtual float getFloat() const = 0;
|
||||
virtual bool isZero() const = 0;
|
||||
virtual bool isOne() const = 0;
|
||||
|
||||
|
||||
static ConstantValue* getInt32(int32_t value);
|
||||
static ConstantValue* getFloat32(float value);
|
||||
static ConstantValue* getTrue() ;
|
||||
static ConstantValue* getFalse();
|
||||
|
||||
public:
|
||||
int getInt() const {
|
||||
assert(isInt());
|
||||
return iScalar;
|
||||
}
|
||||
float getFloat() const {
|
||||
assert(isFloat());
|
||||
return fScalar;
|
||||
}
|
||||
|
||||
public:
|
||||
void print(std::ostream &os) const override;
|
||||
}; // class ConstantValue
|
||||
};
|
||||
|
||||
struct ConstantValue::ConstantValueHash {
|
||||
std::size_t operator()(const ConstantValueKey& key) const {
|
||||
std::size_t typeHash = std::hash<Type*>{}(key.first);
|
||||
std::size_t valHash = 0;
|
||||
if (key.first->isInt()) {
|
||||
valHash = std::hash<int32_t>{}(std::get<int32_t>(key.second));
|
||||
} else if (key.first->isFloat()) {
|
||||
// 修复5: 确保float哈希正确
|
||||
valHash = std::hash<float>{}(std::get<float>(key.second));
|
||||
}
|
||||
return typeHash ^ (valHash << 1);
|
||||
}
|
||||
};
|
||||
|
||||
struct ConstantValue::ConstantValueEqual {
|
||||
bool operator()(const ConstantValueKey& lhs, const ConstantValueKey& rhs) const {
|
||||
if (lhs.first != rhs.first) return false;
|
||||
if (lhs.first->isInt()) {
|
||||
return std::get<int32_t>(lhs.second) == std::get<int32_t>(rhs.second);
|
||||
} else if (lhs.first->isFloat()) {
|
||||
// 修复6: 使用浮点比较容差
|
||||
const float eps = 1e-6;
|
||||
return fabs(std::get<float>(lhs.second) - std::get<float>(rhs.second)) < eps;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
};
|
||||
|
||||
class ConstantInt : public ConstantValue {
|
||||
int32_t value;
|
||||
friend class ConstantValue;
|
||||
|
||||
protected:
|
||||
ConstantInt(Type* type, int32_t value)
|
||||
: ConstantValue(type), value(value) {
|
||||
assert(type->isInt() && "Invalid type for ConstantInt");
|
||||
}
|
||||
public:
|
||||
static ConstantInt* get(Type* type, int32_t value);
|
||||
|
||||
int32_t getInt() const override { return value; }
|
||||
float getFloat() const override { return static_cast<float>(value); }
|
||||
bool isZero() const override { return value == 0; }
|
||||
bool isOne() const override { return value == 1; }
|
||||
|
||||
void print(std::ostream& os) const override ;
|
||||
};
|
||||
|
||||
class ConstantFloat : public ConstantValue {
|
||||
float value;
|
||||
friend class ConstantValue;
|
||||
|
||||
protected:
|
||||
ConstantFloat(Type* type, float value)
|
||||
: ConstantValue(type), value(value) {
|
||||
assert(type->isFloat() && "Invalid type for ConstantFloat");
|
||||
}
|
||||
public:
|
||||
static ConstantFloat* get(Type* type, float value);
|
||||
|
||||
int32_t getInt() const override { return static_cast<int32_t>(value); }
|
||||
float getFloat() const override { return value; }
|
||||
bool isZero() const override { return value == 0.0f; }
|
||||
bool isOne() const override { return value == 1.0f; }
|
||||
|
||||
void print(std::ostream& os) const override;
|
||||
};
|
||||
|
||||
class BasicBlock;
|
||||
/*!
|
||||
|
||||
@ -91,7 +91,7 @@ std::any LLVMIRGenerator::visitVarDecl(SysYParser::VarDeclContext* ctx) {
|
||||
|
||||
if (varDef->ASSIGN()) {
|
||||
value = std::any_cast<std::string>(varDef->initVal()->accept(this));
|
||||
if (irTmpTable.find(value) != irTmpTable.end() && isa<sysy::ConstantValue>(irTmpTable[value])) {
|
||||
if (irTmpTable.find(value) != irTmpTable.end() && sysy::isa<sysy::ConstantValue>(irTmpTable[value])) {
|
||||
initValue = irTmpTable[value];
|
||||
}
|
||||
}
|
||||
@ -134,7 +134,7 @@ std::any LLVMIRGenerator::visitConstDecl(SysYParser::ConstDeclContext* ctx) {
|
||||
|
||||
try {
|
||||
value = std::any_cast<std::string>(constDef->constInitVal()->accept(this));
|
||||
if (isa<sysy::ConstantValue>(irTmpTable[value])) {
|
||||
if (sysy::isa<sysy::ConstantValue>(irTmpTable[value])) {
|
||||
initValue = irTmpTable[value];
|
||||
}
|
||||
} catch (...) {
|
||||
@ -310,7 +310,7 @@ std::any LLVMIRGenerator::visitFuncDef(SysYParser::FuncDefContext* ctx) {
|
||||
} else {
|
||||
irStream << " ret " << currentReturnType << " 0\n";
|
||||
sysy::IRBuilder builder(currentIRBlock);
|
||||
builder.createReturnInst(sysy::ConstantValue::get(0));
|
||||
builder.createReturnInst(sysy::ConstantValue::get(getIRType("int"),0));
|
||||
}
|
||||
}
|
||||
irStream << "}\n";
|
||||
@ -524,10 +524,10 @@ std::any LLVMIRGenerator::visitNumber(SysYParser::NumberContext* ctx) {
|
||||
sysy::Value* irValue = nullptr;
|
||||
if (ctx->ILITERAL()) {
|
||||
value = ctx->ILITERAL()->getText();
|
||||
irValue = sysy::ConstantValue::get(std::stoi(value));
|
||||
irValue = sysy::ConstantValue::get(getIRType("int"), std::stoi(value));
|
||||
} else if (ctx->FLITERAL()) {
|
||||
value = ctx->FLITERAL()->getText();
|
||||
irValue = sysy::ConstantValue::get(std::stof(value));
|
||||
irValue = sysy::ConstantValue::get(getIRType("float"), std::stof(value));
|
||||
} else {
|
||||
value = "";
|
||||
}
|
||||
|
||||
@ -552,10 +552,10 @@ std::any SysYIRGenerator::visitNumber(SysYParser::NumberContext *ctx) {
|
||||
} else if (text.find("0") == 0) {
|
||||
base = 8;
|
||||
}
|
||||
res = ConstantValue::get((int)std::stol(text, 0, base));
|
||||
res = ConstantValue::get(Type::getIntType() ,(int)std::stol(text, 0, base));
|
||||
} else if (auto fLiteral = ctx->FLITERAL()) {
|
||||
const auto text = fLiteral->getText();
|
||||
res = ConstantValue::get((float)std::stof(text));
|
||||
res = ConstantValue::get(Type::getFloatType(), (float)std::stof(text));
|
||||
}
|
||||
cout << "number: ";
|
||||
res->print(cout);
|
||||
|
||||
Reference in New Issue
Block a user