Skip to content

Commit

Permalink
Add float16, bfloat16 and float128 IR types
Browse files Browse the repository at this point in the history
  • Loading branch information
arshajii committed Nov 11, 2023
1 parent f255698 commit a739bef
Show file tree
Hide file tree
Showing 8 changed files with 115 additions and 0 deletions.
28 changes: 28 additions & 0 deletions codon/cir/llvm/llvisitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2086,6 +2086,18 @@ llvm::Type *LLVMVisitor::getLLVMType(types::Type *t) {
return B->getFloatTy();
}

if (auto *x = cast<types::Float16Type>(t)) {
return B->getHalfTy();
}

if (auto *x = cast<types::BFloat16Type>(t)) {
return B->getBFloatTy();
}

if (auto *x = cast<types::Float128Type>(t)) {
return llvm::Type::getFP128Ty(*context);
}

if (auto *x = cast<types::BoolType>(t)) {
return B->getInt8Ty();
}
Expand Down Expand Up @@ -2203,6 +2215,22 @@ llvm::DIType *LLVMVisitor::getDITypeHelper(
x->getName(), layout.getTypeAllocSizeInBits(type), llvm::dwarf::DW_ATE_float);
}

if (auto *x = cast<types::Float16Type>(t)) {
return db.builder->createBasicType(
x->getName(), layout.getTypeAllocSizeInBits(type), llvm::dwarf::DW_ATE_float);
}

if (auto *x = cast<types::BFloat16Type>(t)) {
return db.builder->createBasicType(
x->getName(), layout.getTypeAllocSizeInBits(type), llvm::dwarf::DW_ATE_float);
}

if (auto *x = cast<types::Float128Type>(t)) {
return db.builder->createBasicType(x->getName(),
layout.getTypeAllocSizeInBits(type),
llvm::dwarf::DW_ATE_HP_float128);
}

if (auto *x = cast<types::BoolType>(t)) {
return db.builder->createBasicType(
x->getName(), layout.getTypeAllocSizeInBits(type), llvm::dwarf::DW_ATE_boolean);
Expand Down
21 changes: 21 additions & 0 deletions codon/cir/module.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,9 @@ const std::string Module::BYTE_NAME = "byte";
const std::string Module::INT_NAME = "int";
const std::string Module::FLOAT_NAME = "float";
const std::string Module::FLOAT32_NAME = "float32";
const std::string Module::FLOAT16_NAME = "float16";
const std::string Module::BFLOAT16_NAME = "bfloat16";
const std::string Module::FLOAT128_NAME = "float128";
const std::string Module::STRING_NAME = "str";

const std::string Module::EQ_MAGIC_NAME = "__eq__";
Expand Down Expand Up @@ -239,6 +242,24 @@ types::Type *Module::getFloat32Type() {
return Nr<types::Float32Type>();
}

types::Type *Module::getFloat16Type() {
if (auto *rVal = getType(FLOAT16_NAME))
return rVal;
return Nr<types::Float16Type>();
}

types::Type *Module::getBFloat16Type() {
if (auto *rVal = getType(BFLOAT16_NAME))
return rVal;
return Nr<types::BFloat16Type>();
}

types::Type *Module::getFloat128Type() {
if (auto *rVal = getType(FLOAT128_NAME))
return rVal;
return Nr<types::Float128Type>();
}

types::Type *Module::getStringType() {
if (auto *rVal = getType(STRING_NAME))
return rVal;
Expand Down
9 changes: 9 additions & 0 deletions codon/cir/module.h
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,9 @@ class Module : public AcceptorExtend<Module, Node> {
static const std::string INT_NAME;
static const std::string FLOAT_NAME;
static const std::string FLOAT32_NAME;
static const std::string FLOAT16_NAME;
static const std::string BFLOAT16_NAME;
static const std::string FLOAT128_NAME;
static const std::string STRING_NAME;

static const std::string EQ_MAGIC_NAME;
Expand Down Expand Up @@ -338,6 +341,12 @@ class Module : public AcceptorExtend<Module, Node> {
types::Type *getFloatType();
/// @return the float32 type
types::Type *getFloat32Type();
/// @return the float16 type
types::Type *getFloat16Type();
/// @return the bfloat16 type
types::Type *getBFloat16Type();
/// @return the float128 type
types::Type *getFloat128Type();
/// @return the string type
types::Type *getStringType();
/// Gets a pointer type.
Expand Down
6 changes: 6 additions & 0 deletions codon/cir/types/types.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,12 @@ const char FloatType::NodeId = 0;

const char Float32Type::NodeId = 0;

const char Float16Type::NodeId = 0;

const char BFloat16Type::NodeId = 0;

const char Float128Type::NodeId = 0;

const char BoolType::NodeId = 0;

const char ByteType::NodeId = 0;
Expand Down
27 changes: 27 additions & 0 deletions codon/cir/types/types.h
Original file line number Diff line number Diff line change
Expand Up @@ -169,6 +169,33 @@ class Float32Type : public AcceptorExtend<Float32Type, PrimitiveType> {
Float32Type() : AcceptorExtend("float32") {}
};

/// Float16 type (16-bit float)
class Float16Type : public AcceptorExtend<Float16Type, PrimitiveType> {
public:
static const char NodeId;

/// Constructs a float16 type.
Float16Type() : AcceptorExtend("float16") {}
};

/// BFloat16 type (16-bit brain float)
class BFloat16Type : public AcceptorExtend<BFloat16Type, PrimitiveType> {
public:
static const char NodeId;

/// Constructs a bfloat16 type.
BFloat16Type() : AcceptorExtend("bfloat16") {}
};

/// Float128 type (128-bit float)
class Float128Type : public AcceptorExtend<Float128Type, PrimitiveType> {
public:
static const char NodeId;

/// Constructs a float128 type.
Float128Type() : AcceptorExtend("float128") {}
};

/// Bool type (8-bit unsigned integer; either 0 or 1)
class BoolType : public AcceptorExtend<BoolType, PrimitiveType> {
public:
Expand Down
9 changes: 9 additions & 0 deletions codon/cir/util/format.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -295,6 +295,15 @@ class FormatVisitor : util::ConstVisitor {
void visit(const types::Float32Type *v) override {
fmt::print(os, FMT_STRING("(float32 '\"{}\")"), v->referenceString());
}
void visit(const types::Float16Type *v) override {
fmt::print(os, FMT_STRING("(float16 '\"{}\")"), v->referenceString());
}
void visit(const types::BFloat16Type *v) override {
fmt::print(os, FMT_STRING("(bfloat16 '\"{}\")"), v->referenceString());
}
void visit(const types::Float128Type *v) override {
fmt::print(os, FMT_STRING("(float128 '\"{}\")"), v->referenceString());
}
void visit(const types::BoolType *v) override {
fmt::print(os, FMT_STRING("(bool '\"{}\")"), v->referenceString());
}
Expand Down
6 changes: 6 additions & 0 deletions codon/cir/util/visitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,9 @@ void Visitor::visit(types::PrimitiveType *x) { defaultVisit(x); }
void Visitor::visit(types::IntType *x) { defaultVisit(x); }
void Visitor::visit(types::FloatType *x) { defaultVisit(x); }
void Visitor::visit(types::Float32Type *x) { defaultVisit(x); }
void Visitor::visit(types::Float16Type *x) { defaultVisit(x); }
void Visitor::visit(types::BFloat16Type *x) { defaultVisit(x); }
void Visitor::visit(types::Float128Type *x) { defaultVisit(x); }
void Visitor::visit(types::BoolType *x) { defaultVisit(x); }
void Visitor::visit(types::ByteType *x) { defaultVisit(x); }
void Visitor::visit(types::VoidType *x) { defaultVisit(x); }
Expand Down Expand Up @@ -114,6 +117,9 @@ void ConstVisitor::visit(const types::PrimitiveType *x) { defaultVisit(x); }
void ConstVisitor::visit(const types::IntType *x) { defaultVisit(x); }
void ConstVisitor::visit(const types::FloatType *x) { defaultVisit(x); }
void ConstVisitor::visit(const types::Float32Type *x) { defaultVisit(x); }
void ConstVisitor::visit(const types::Float16Type *x) { defaultVisit(x); }
void ConstVisitor::visit(const types::BFloat16Type *x) { defaultVisit(x); }
void ConstVisitor::visit(const types::Float128Type *x) { defaultVisit(x); }
void ConstVisitor::visit(const types::BoolType *x) { defaultVisit(x); }
void ConstVisitor::visit(const types::ByteType *x) { defaultVisit(x); }
void ConstVisitor::visit(const types::VoidType *x) { defaultVisit(x); }
Expand Down
9 changes: 9 additions & 0 deletions codon/cir/util/visitor.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,9 @@ class PrimitiveType;
class IntType;
class FloatType;
class Float32Type;
class Float16Type;
class BFloat16Type;
class Float128Type;
class BoolType;
class ByteType;
class VoidType;
Expand Down Expand Up @@ -152,6 +155,9 @@ class Visitor {
VISIT(types::IntType);
VISIT(types::FloatType);
VISIT(types::Float32Type);
VISIT(types::Float16Type);
VISIT(types::BFloat16Type);
VISIT(types::Float128Type);
VISIT(types::BoolType);
VISIT(types::ByteType);
VISIT(types::VoidType);
Expand Down Expand Up @@ -229,6 +235,9 @@ class ConstVisitor {
CONST_VISIT(types::IntType);
CONST_VISIT(types::FloatType);
CONST_VISIT(types::Float32Type);
CONST_VISIT(types::Float16Type);
CONST_VISIT(types::BFloat16Type);
CONST_VISIT(types::Float128Type);
CONST_VISIT(types::BoolType);
CONST_VISIT(types::ByteType);
CONST_VISIT(types::VoidType);
Expand Down

0 comments on commit a739bef

Please sign in to comment.