Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

GPU compilation fixes #496

Merged
merged 15 commits into from
Nov 18, 2023
Merged
4 changes: 2 additions & 2 deletions codon/cir/analyze/dataflow/reaching.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -71,10 +71,10 @@ struct BitSet {
return res;
}

void set(unsigned bit) { words.data()[bit / B] |= (1 << (bit % B)); }
void set(unsigned bit) { words.data()[bit / B] |= (1UL << (bit % B)); }

bool get(unsigned bit) const {
return (words.data()[bit / B] & (1 << (bit % B))) != 0;
return (words.data()[bit / B] & (1UL << (bit % B))) != 0;
}

bool equals(const BitSet &other, unsigned size) {
Expand Down
8 changes: 8 additions & 0 deletions codon/cir/llvm/gpu.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -500,6 +500,14 @@ void moduleToPTX(llvm::Module *M, const std::string &filename,
linkLibdevice(M, libdevice);
remapFunctions(M);

// Strip debug info and remove noinline from functions (added in debug mode).
// Also, tell LLVM that all functions will return.
for (auto &F : *M) {
F.removeFnAttr(llvm::Attribute::AttrKind::NoInline);
F.setWillReturn();
}
llvm::StripDebugInfo(*M);

// Run NVPTX passes and general opt pipeline.
{
llvm::LoopAnalysisManager lam;
Expand Down
28 changes: 28 additions & 0 deletions codon/cir/llvm/llvisitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2086,6 +2086,18 @@ llvm::Type *LLVMVisitor::getLLVMType(types::Type *t) {
return B->getFloatTy();
}

if (auto *x = cast<types::Float16Type>(t)) {
return B->getHalfTy();
}

if (auto *x = cast<types::BFloat16Type>(t)) {
return B->getBFloatTy();
}

if (auto *x = cast<types::Float128Type>(t)) {
return llvm::Type::getFP128Ty(*context);
}

if (auto *x = cast<types::BoolType>(t)) {
return B->getInt8Ty();
}
Expand Down Expand Up @@ -2203,6 +2215,22 @@ llvm::DIType *LLVMVisitor::getDITypeHelper(
x->getName(), layout.getTypeAllocSizeInBits(type), llvm::dwarf::DW_ATE_float);
}

if (auto *x = cast<types::Float16Type>(t)) {
return db.builder->createBasicType(
x->getName(), layout.getTypeAllocSizeInBits(type), llvm::dwarf::DW_ATE_float);
}

if (auto *x = cast<types::BFloat16Type>(t)) {
return db.builder->createBasicType(
x->getName(), layout.getTypeAllocSizeInBits(type), llvm::dwarf::DW_ATE_float);
}

if (auto *x = cast<types::Float128Type>(t)) {
return db.builder->createBasicType(x->getName(),
layout.getTypeAllocSizeInBits(type),
llvm::dwarf::DW_ATE_HP_float128);
}

if (auto *x = cast<types::BoolType>(t)) {
return db.builder->createBasicType(
x->getName(), layout.getTypeAllocSizeInBits(type), llvm::dwarf::DW_ATE_boolean);
Expand Down
21 changes: 21 additions & 0 deletions codon/cir/module.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,9 @@ const std::string Module::BYTE_NAME = "byte";
const std::string Module::INT_NAME = "int";
const std::string Module::FLOAT_NAME = "float";
const std::string Module::FLOAT32_NAME = "float32";
const std::string Module::FLOAT16_NAME = "float16";
const std::string Module::BFLOAT16_NAME = "bfloat16";
const std::string Module::FLOAT128_NAME = "float128";
const std::string Module::STRING_NAME = "str";

const std::string Module::EQ_MAGIC_NAME = "__eq__";
Expand Down Expand Up @@ -239,6 +242,24 @@ types::Type *Module::getFloat32Type() {
return Nr<types::Float32Type>();
}

types::Type *Module::getFloat16Type() {
if (auto *rVal = getType(FLOAT16_NAME))
return rVal;
return Nr<types::Float16Type>();
}

types::Type *Module::getBFloat16Type() {
if (auto *rVal = getType(BFLOAT16_NAME))
return rVal;
return Nr<types::BFloat16Type>();
}

types::Type *Module::getFloat128Type() {
if (auto *rVal = getType(FLOAT128_NAME))
return rVal;
return Nr<types::Float128Type>();
}

types::Type *Module::getStringType() {
if (auto *rVal = getType(STRING_NAME))
return rVal;
Expand Down
9 changes: 9 additions & 0 deletions codon/cir/module.h
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,9 @@ class Module : public AcceptorExtend<Module, Node> {
static const std::string INT_NAME;
static const std::string FLOAT_NAME;
static const std::string FLOAT32_NAME;
static const std::string FLOAT16_NAME;
static const std::string BFLOAT16_NAME;
static const std::string FLOAT128_NAME;
static const std::string STRING_NAME;

static const std::string EQ_MAGIC_NAME;
Expand Down Expand Up @@ -338,6 +341,12 @@ class Module : public AcceptorExtend<Module, Node> {
types::Type *getFloatType();
/// @return the float32 type
types::Type *getFloat32Type();
/// @return the float16 type
types::Type *getFloat16Type();
/// @return the bfloat16 type
types::Type *getBFloat16Type();
/// @return the float128 type
types::Type *getFloat128Type();
/// @return the string type
types::Type *getStringType();
/// Gets a pointer type.
Expand Down
6 changes: 6 additions & 0 deletions codon/cir/types/types.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,12 @@ const char FloatType::NodeId = 0;

const char Float32Type::NodeId = 0;

const char Float16Type::NodeId = 0;

const char BFloat16Type::NodeId = 0;

const char Float128Type::NodeId = 0;

const char BoolType::NodeId = 0;

const char ByteType::NodeId = 0;
Expand Down
27 changes: 27 additions & 0 deletions codon/cir/types/types.h
Original file line number Diff line number Diff line change
Expand Up @@ -169,6 +169,33 @@ class Float32Type : public AcceptorExtend<Float32Type, PrimitiveType> {
Float32Type() : AcceptorExtend("float32") {}
};

/// Float16 type (16-bit float)
class Float16Type : public AcceptorExtend<Float16Type, PrimitiveType> {
public:
static const char NodeId;

/// Constructs a float16 type.
Float16Type() : AcceptorExtend("float16") {}
};

/// BFloat16 type (16-bit brain float)
class BFloat16Type : public AcceptorExtend<BFloat16Type, PrimitiveType> {
public:
static const char NodeId;

/// Constructs a bfloat16 type.
BFloat16Type() : AcceptorExtend("bfloat16") {}
};

/// Float128 type (128-bit float)
class Float128Type : public AcceptorExtend<Float128Type, PrimitiveType> {
public:
static const char NodeId;

/// Constructs a float128 type.
Float128Type() : AcceptorExtend("float128") {}
};

/// Bool type (8-bit unsigned integer; either 0 or 1)
class BoolType : public AcceptorExtend<BoolType, PrimitiveType> {
public:
Expand Down
9 changes: 9 additions & 0 deletions codon/cir/util/format.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -295,6 +295,15 @@ class FormatVisitor : util::ConstVisitor {
void visit(const types::Float32Type *v) override {
fmt::print(os, FMT_STRING("(float32 '\"{}\")"), v->referenceString());
}
void visit(const types::Float16Type *v) override {
fmt::print(os, FMT_STRING("(float16 '\"{}\")"), v->referenceString());
}
void visit(const types::BFloat16Type *v) override {
fmt::print(os, FMT_STRING("(bfloat16 '\"{}\")"), v->referenceString());
}
void visit(const types::Float128Type *v) override {
fmt::print(os, FMT_STRING("(float128 '\"{}\")"), v->referenceString());
}
void visit(const types::BoolType *v) override {
fmt::print(os, FMT_STRING("(bool '\"{}\")"), v->referenceString());
}
Expand Down
6 changes: 6 additions & 0 deletions codon/cir/util/visitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,9 @@ void Visitor::visit(types::PrimitiveType *x) { defaultVisit(x); }
void Visitor::visit(types::IntType *x) { defaultVisit(x); }
void Visitor::visit(types::FloatType *x) { defaultVisit(x); }
void Visitor::visit(types::Float32Type *x) { defaultVisit(x); }
void Visitor::visit(types::Float16Type *x) { defaultVisit(x); }
void Visitor::visit(types::BFloat16Type *x) { defaultVisit(x); }
void Visitor::visit(types::Float128Type *x) { defaultVisit(x); }
void Visitor::visit(types::BoolType *x) { defaultVisit(x); }
void Visitor::visit(types::ByteType *x) { defaultVisit(x); }
void Visitor::visit(types::VoidType *x) { defaultVisit(x); }
Expand Down Expand Up @@ -114,6 +117,9 @@ void ConstVisitor::visit(const types::PrimitiveType *x) { defaultVisit(x); }
void ConstVisitor::visit(const types::IntType *x) { defaultVisit(x); }
void ConstVisitor::visit(const types::FloatType *x) { defaultVisit(x); }
void ConstVisitor::visit(const types::Float32Type *x) { defaultVisit(x); }
void ConstVisitor::visit(const types::Float16Type *x) { defaultVisit(x); }
void ConstVisitor::visit(const types::BFloat16Type *x) { defaultVisit(x); }
void ConstVisitor::visit(const types::Float128Type *x) { defaultVisit(x); }
void ConstVisitor::visit(const types::BoolType *x) { defaultVisit(x); }
void ConstVisitor::visit(const types::ByteType *x) { defaultVisit(x); }
void ConstVisitor::visit(const types::VoidType *x) { defaultVisit(x); }
Expand Down
9 changes: 9 additions & 0 deletions codon/cir/util/visitor.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,9 @@ class PrimitiveType;
class IntType;
class FloatType;
class Float32Type;
class Float16Type;
class BFloat16Type;
class Float128Type;
class BoolType;
class ByteType;
class VoidType;
Expand Down Expand Up @@ -152,6 +155,9 @@ class Visitor {
VISIT(types::IntType);
VISIT(types::FloatType);
VISIT(types::Float32Type);
VISIT(types::Float16Type);
VISIT(types::BFloat16Type);
VISIT(types::Float128Type);
VISIT(types::BoolType);
VISIT(types::ByteType);
VISIT(types::VoidType);
Expand Down Expand Up @@ -229,6 +235,9 @@ class ConstVisitor {
CONST_VISIT(types::IntType);
CONST_VISIT(types::FloatType);
CONST_VISIT(types::Float32Type);
CONST_VISIT(types::Float16Type);
CONST_VISIT(types::BFloat16Type);
CONST_VISIT(types::Float128Type);
CONST_VISIT(types::BoolType);
CONST_VISIT(types::ByteType);
CONST_VISIT(types::VoidType);
Expand Down
6 changes: 4 additions & 2 deletions codon/parser/visitors/typecheck/call.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -709,6 +709,7 @@ ExprPtr TypecheckVisitor::transformArray(CallExpr *expr) {
/// `isinstance(obj, ByRef)` is True if `type(obj)` is a reference type
ExprPtr TypecheckVisitor::transformIsInstance(CallExpr *expr) {
expr->setType(unify(expr->type, ctx->getType("bool")));
expr->staticValue.type = StaticValue::INT; // prevent branching until this is resolved
transform(expr->args[0].value);
auto typ = expr->args[0].value->type->getClass();
if (!typ || !typ->canRealize())
Expand Down Expand Up @@ -947,10 +948,11 @@ ExprPtr TypecheckVisitor::transformStaticPrintFn(CallExpr *expr) {
auto &args = expr->args[0].value->getCall()->args;
for (size_t i = 0; i < args.size(); i++) {
realize(args[i].value->type);
fmt::print(stderr, "[static_print] {}: {} := {}{}\n", getSrcInfo(),
fmt::print(stderr, "[static_print] {}: {} := {}{} (iter: {})\n", getSrcInfo(),
FormatVisitor::apply(args[i].value),
args[i].value->type ? args[i].value->type->debugString(1) : "-",
args[i].value->isStatic() ? " [static]" : "");
args[i].value->isStatic() ? " [static]" : "",
ctx->getRealizationBase()->iteration);
}
return nullptr;
}
Expand Down
4 changes: 2 additions & 2 deletions codon/parser/visitors/typecheck/ctx.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -100,13 +100,13 @@ types::TypePtr TypeContext::instantiate(const SrcInfo &srcInfo,
if (auto l = i.second->getLink()) {
i.second->setSrcInfo(srcInfo);
if (l->defaultType) {
pendingDefaults.insert(i.second);
getRealizationBase()->pendingDefaults.insert(i.second);
}
}
}
if (t->getUnion() && !t->getUnion()->isSealed()) {
t->setSrcInfo(srcInfo);
pendingDefaults.insert(t);
getRealizationBase()->pendingDefaults.insert(t);
}
if (auto r = t->getRecord())
if (r->repeats && r->repeats->canRealize())
Expand Down
2 changes: 1 addition & 1 deletion codon/parser/visitors/typecheck/ctx.h
Original file line number Diff line number Diff line change
Expand Up @@ -50,12 +50,12 @@ struct TypeContext : public Context<TypecheckItem> {
types::TypePtr returnType = nullptr;
/// Typechecking iteration
int iteration = 0;
std::set<types::TypePtr> pendingDefaults;
};
std::vector<RealizationBase> realizationBases;

/// The current type-checking level (for type instantiation and generalization).
int typecheckLevel;
std::set<types::TypePtr> pendingDefaults;
int changedNodes;

/// The age of the currently parsed statement.
Expand Down
17 changes: 13 additions & 4 deletions codon/parser/visitors/typecheck/infer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -99,8 +99,9 @@ StmtPtr TypecheckVisitor::inferTypes(StmtPtr result, bool isToplevel) {
bool anotherRound = false;
// Special case: return type might have default as well (e.g., Union)
if (ctx->getRealizationBase()->returnType)
ctx->pendingDefaults.insert(ctx->getRealizationBase()->returnType);
for (auto &unbound : ctx->pendingDefaults) {
ctx->getRealizationBase()->pendingDefaults.insert(
ctx->getRealizationBase()->returnType);
for (auto &unbound : ctx->getRealizationBase()->pendingDefaults) {
if (auto tu = unbound->getUnion()) {
// Seal all dynamic unions after the iteration is over
if (!tu->isSealed()) {
Expand All @@ -113,7 +114,7 @@ StmtPtr TypecheckVisitor::inferTypes(StmtPtr result, bool isToplevel) {
anotherRound = true;
}
}
ctx->pendingDefaults.clear();
ctx->getRealizationBase()->pendingDefaults.clear();
if (anotherRound)
continue;

Expand Down Expand Up @@ -653,6 +654,12 @@ ir::types::Type *TypecheckVisitor::makeIRType(types::ClassType *t) {
handle = module->getFloatType();
} else if (t->name == "float32") {
handle = module->getFloat32Type();
} else if (t->name == "float16") {
handle = module->getFloat16Type();
} else if (t->name == "bfloat16") {
handle = module->getBFloat16Type();
} else if (t->name == "float128") {
handle = module->getFloat128Type();
} else if (t->name == "str") {
handle = module->getStringType();
} else if (t->name == "Int" || t->name == "UInt") {
Expand Down Expand Up @@ -936,7 +943,9 @@ TypecheckVisitor::generateSpecialAst(types::FuncType *type) {
N<ThrowStmt>(N<CallExpr>(N<IdExpr>("std.internal.types.error.TypeError"),
N<StringExpr>("invalid union call"))));
// suite->stmts.push_back(N<ReturnStmt>(N<NoneExpr>()));
unify(type->getRetType(), ctx->instantiate(ctx->getType("Union")));

auto ret = ctx->instantiate(ctx->getType("Union"));
unify(type->getRetType(), ret);
ast->suite = suite;
} else if (startswith(ast->name, "__internal__.get_union_first:0")) {
// def __internal__.get_union_first(union: Union[T0]):
Expand Down
4 changes: 3 additions & 1 deletion codon/parser/visitors/typecheck/typecheck.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ ExprPtr TypecheckVisitor::transform(ExprPtr &expr) {

auto typ = expr->type;
if (!expr->done) {
bool isIntStatic = expr->staticValue.type == StaticValue::INT;
TypecheckVisitor v(ctx, prependStmts);
v.setSrcInfo(expr->getSrcInfo());
ctx->pushSrcInfo(expr->getSrcInfo());
Expand All @@ -60,7 +61,8 @@ ExprPtr TypecheckVisitor::transform(ExprPtr &expr) {
expr = v.resultExpr;
}
seqassert(expr->type, "type not set for {}", expr);
unify(typ, expr->type);
if (!(isIntStatic && expr->type->is("bool")))
unify(typ, expr->type);
if (expr->done) {
ctx->changedNodes++;
}
Expand Down
21 changes: 21 additions & 0 deletions stdlib/internal/core.codon
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,27 @@ class float32:
MIN_10_EXP = -37
pass

@tuple
@__internal__
@__notuple__
class float16:
MIN_10_EXP = -4
pass

@tuple
@__internal__
@__notuple__
class bfloat16:
MIN_10_EXP = -37
pass

@tuple
@__internal__
@__notuple__
class float128:
MIN_10_EXP = -4931
pass

@tuple
@__internal__
class type:
Expand Down
2 changes: 1 addition & 1 deletion stdlib/internal/internal.codon
Original file line number Diff line number Diff line change
Expand Up @@ -618,7 +618,7 @@ class __magic__:

# @dataclass parameter: gpu=True
def from_gpu_new(other: T, T: type) -> T:
__internal__.class_from_gpu_new(other)
return __internal__.class_from_gpu_new(other)

# @dataclass parameter: repr=True
def repr(slf) -> str:
Expand Down
Loading