
[TIR, Relay] improve bfloat16 support #10112

Merged: 16 commits, Feb 24, 2022
16 changes: 12 additions & 4 deletions include/tvm/tir/op.h
@@ -835,10 +835,18 @@ TVM_DLL PrimExpr q_multiply_shift(PrimExpr x, PrimExpr y, PrimExpr q, PrimExpr s
Span span = Span());

// Intrinsic operators
#define TVM_DECLARE_INTRIN_UNARY(OpName) \
inline PrimExpr OpName(PrimExpr x, Span span = Span()) { \
static const Op& op = Op::Get("tir." #OpName); \
return tir::Call(x.dtype(), op, {x}, span); \
#define TVM_DECLARE_INTRIN_UNARY(OpName) \
inline PrimExpr OpName(PrimExpr x, Span span = Span()) { \
static const Op& op = Op::Get("tir." #OpName); \
if (x.dtype().is_bfloat16()) { \
DataType srcType = x.dtype(); \
DataType dstType(kDLFloat, 32, srcType.lanes()); \
PrimExpr castX = tir::Cast(dstType, {x}, span); \
PrimExpr result = tir::Call(dstType, op, {castX}, span); \
return tir::Cast(srcType, {result}, span); \
} else { \
return tir::Call(x.dtype(), op, {x}, span); \
} \
}

Member (on the tir::Cast(srcType, {result}, span) line):
Just do tir::Cast("bfloat16", {result}, span). We use camel_case.
Can be fixed in a follow up.

Contributor Author:
OK, thanks. I will refine the code style in the next PR.

TVM_DECLARE_INTRIN_UNARY(exp);
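
Note: with this change, a unary intrinsic applied to a bfloat16 operand is computed in float32 and cast back. A minimal sketch of the resulting expression, assuming the usual TVM headers and namespaces (x is a made-up variable, not from this PR):

#include <tvm/tir/op.h>

using namespace tvm;

void SketchBf16Intrinsic() {
  tir::Var x("x", DataType::BFloat(16));
  // exp(x) now expands to roughly Cast(bf16, tir.exp(Cast(fp32, x))):
  PrimExpr y = exp(x);  // the inner call is typed float32; the outer cast restores bfloat16
}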
2 changes: 1 addition & 1 deletion src/arith/rewrite_simplify.cc
@@ -461,7 +461,7 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const DivNode* op) {

// x / 2.0 = x * 0.5
if (const FloatImmNode* ptr = op->b.as<FloatImmNode>()) {
ICHECK(op->dtype.is_float() ||
ICHECK(op->dtype.is_float() || op->dtype.is_bfloat16() ||
datatype::Registry::Global()->GetTypeRegistered(op->dtype.code()));
return op->a * make_const(op->b.dtype(), 1.0 / ptr->value);
}
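
Note: the relaxed ICHECK lets the existing strength reduction fire for bfloat16 operands too. A sketch of the intended effect (hypothetical usage, assuming <tvm/arith/analyzer.h>):

using namespace tvm;

void SketchBf16DivRewrite() {
  tir::Var x("x", DataType::BFloat(16));
  arith::Analyzer analyzer;
  // Division by a float constant is rewritten to multiplication:
  PrimExpr y = analyzer.rewrite_simplify(x / make_const(x.dtype(), 2.0));
  // y is equivalent to x * 0.5, still in bfloat16
}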
20 changes: 10 additions & 10 deletions src/auto_scheduler/feature.cc
@@ -246,14 +246,14 @@ int64_t GetLoopExtent(const ForNode* node) {
// Count math ops in an expr
class MathOpCounter : public StmtExprVisitor {
public:
#define VisitBinary(Type, float_ct, int_ct) \
void VisitExpr_(const Type* op) final { \
if (op->a.dtype().is_float()) { \
float_ct++; \
} else { \
int_ct++; \
} \
StmtExprVisitor::VisitExpr_(op); \
#define VisitBinary(Type, float_ct, int_ct) \
void VisitExpr_(const Type* op) final { \
if (op->a.dtype().is_float() || op->a.dtype().is_bfloat16()) { \
float_ct++; \
} else { \
int_ct++; \
} \
StmtExprVisitor::VisitExpr_(op); \
}

VisitBinary(AddNode, float_addsub, int_addsub);
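
Note: bfloat16 arithmetic is now tallied as floating-point work. A conceptual sketch (MathOpCounter and its counters are local to feature.cc, so this is illustrative only):

void SketchBf16Count() {
  tvm::tir::Var a("a", tvm::DataType::BFloat(16));
  MathOpCounter counter;
  counter.VisitExpr(tvm::tir::Add(a, a));  // increments float_addsub rather than int_addsub
}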
@@ -299,13 +299,13 @@ class MathOpCounter : public StmtExprVisitor {
effect_kind == CallEffectKind::kPure || effect_kind == CallEffectKind::kExprAnnotation;

if (is_pure) {
if (op->dtype.is_float()) {
if (op->dtype.is_float() || op->dtype.is_bfloat16()) {
float_math_func++;
} else {
int_math_func++;
}
} else {
if (op->dtype.is_float()) {
if (op->dtype.is_float() || op->dtype.is_bfloat16()) {
float_other_func++;
} else {
int_other_func++;
20 changes: 15 additions & 5 deletions src/autotvm/touch_extractor.h
@@ -87,27 +87,37 @@ class TouchExtractor : public FeatureVisitor {

// arithmetic stats
void VisitExpr_(const AddNode* op) final {
if (op->dtype.is_float()) itervar_map[itervar_stack_.back()].add_ct++;
if (op->dtype.is_float() || op->dtype.is_bfloat16()) {
itervar_map[itervar_stack_.back()].add_ct++;
}
FeatureVisitor::VisitExpr_(op);
}

void VisitExpr_(const SubNode* op) final {
if (op->dtype.is_float()) itervar_map[itervar_stack_.back()].add_ct++;
if (op->dtype.is_float() || op->dtype.is_bfloat16()) {
itervar_map[itervar_stack_.back()].add_ct++;
}
FeatureVisitor::VisitExpr_(op);
}

void VisitExpr_(const MulNode* op) final {
if (op->dtype.is_float()) itervar_map[itervar_stack_.back()].mul_ct++;
if (op->dtype.is_float() || op->dtype.is_bfloat16()) {
itervar_map[itervar_stack_.back()].mul_ct++;
}
FeatureVisitor::VisitExpr_(op);
}

void VisitExpr_(const DivNode* op) final {
if (op->dtype.is_float()) itervar_map[itervar_stack_.back()].div_ct++;
if (op->dtype.is_float() || op->dtype.is_bfloat16()) {
itervar_map[itervar_stack_.back()].div_ct++;
}
FeatureVisitor::VisitExpr_(op);
}

void VisitExpr_(const ModNode* op) final {
if (op->dtype.is_float()) itervar_map[itervar_stack_.back()].div_ct++;
if (op->dtype.is_float() || op->dtype.is_bfloat16()) {
itervar_map[itervar_stack_.back()].div_ct++;
}
FeatureVisitor::VisitExpr_(op);
}

3 changes: 3 additions & 0 deletions src/contrib/hybrid/codegen_hybrid.cc
@@ -69,6 +69,9 @@ void CodeGenHybrid::PrintType(DataType t, std::ostream& os) {
} else if (t.is_int()) {
os << "int";
ICHECK(t.bits() == 8 || t.bits() == 16 || t.bits() == 32 || t.bits() == 64);
} else if (t.is_bfloat16()) {
os << "bfloat";
ICHECK(t.bits() == 16);
} else {
ICHECK(t.is_uint()) << "Unsupported type " << t;
os << "uint";
2 changes: 2 additions & 0 deletions src/relay/backend/contrib/codegen_c/codegen_c.h
@@ -363,6 +363,8 @@ class CodegenCBase {
dtype = "float";
} else if (runtime::TypeMatch(ttype->dtype, kDLFloat, 16)) {
dtype = "half";
} else if (runtime::TypeMatch(ttype->dtype, kDLBfloat, 16)) {
dtype = "bfloat";
} else if (runtime::TypeMatch(ttype->dtype, kDLInt, 32)) {
dtype = "int";
} else if (runtime::TypeMatch(ttype->dtype, kDLInt, 64)) {
2 changes: 2 additions & 0 deletions src/relay/backend/utils.h
@@ -302,6 +302,8 @@ inline std::string DType2String(const tvm::DataType dtype) {
os << "int";
} else if (dtype.is_uint()) {
os << "uint";
} else if (dtype.is_bfloat16()) {
os << "bfloat";
Member:
Suggested change:
- os << "bfloat";
+ os << "bfloat16";

} else if ((*GetPackedFunc("runtime._datatype_get_type_registered"))(dtype.code())) {
os << "custom["
<< (*GetPackedFunc("runtime._datatype_get_type_name"))(dtype.code()).operator std::string()
3 changes: 2 additions & 1 deletion src/relay/op/nn/nn.cc
@@ -1177,7 +1177,8 @@ bool NLLLossRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
<< ", weights shape = " << weights->shape);
return false;
}
if (!(predictions->dtype == weights->dtype && predictions->dtype.is_float())) {
if (!(predictions->dtype == weights->dtype &&
(predictions->dtype.is_float() || predictions->dtype.is_bfloat16()))) {
Member:
Shall we let is_float() be true for bfloat16 exprs?

Contributor Author (@yangulei, Feb 10, 2022):
I prefer this way too, since they are all floating-point datatypes. There are some practical inconsistencies so far, though: for example, if we let is_float() == true for bfloat16, we can no longer distinguish bfloat16 from float16, as both satisfy the condition is_float() == true && bits() == 16.

reporter->GetDiagCtx().EmitFatal(Diagnostic::Error(reporter->GetSpan())
<< "NLLLossRel: predictions and weights should"
<< " be of the same floating type.");
26 changes: 26 additions & 0 deletions src/relay/transforms/pattern_utils.h
@@ -63,6 +63,9 @@ namespace relay {
} else if (type == DataType::Float(16)) { \
typedef uint16_t DType; \
{ __VA_ARGS__ } \
} else if (type == DataType::BFloat(16)) { \
typedef uint16_t DType; \
{ __VA_ARGS__ } \
} else if (type == DataType::Int(64)) { \
typedef int64_t DType; \
{ __VA_ARGS__ } \
@@ -259,6 +262,11 @@ inline Constant MakeConstantScalar(DataType dtype, T value) {
// storage is uint16_t
*static_cast<DType*>(arr->data) =
__truncXfYf2__<float, uint32_t, 23, uint16_t, uint16_t, 10>(static_cast<float>(value));
} else if (dtype == DataType::BFloat(16)) {
// convert to bfloat16
// storage is uint16_t
*static_cast<DType*>(arr->data) =
__truncXfYf2__<float, uint32_t, 23, uint16_t, uint16_t, 7>(static_cast<float>(value));
} else {
*static_cast<DType*>(arr->data) = value;
}
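
Note: the trailing template argument 7 is the destination significand width. bfloat16 keeps float32's sign bit and 8 exponent bits but only the top 7 mantissa bits, i.e. the upper half of a float32. A self-contained sketch of the same narrowing with round-to-nearest-even (hypothetical helper, not TVM code):

#include <cstdint>
#include <cstring>

uint16_t FloatToBFloat16(float f) {
  uint32_t bits;
  std::memcpy(&bits, &f, sizeof(bits));
  // Round to nearest even, then keep the upper 16 bits.
  uint32_t rounding_bias = 0x7FFFu + ((bits >> 16) & 1u);
  return static_cast<uint16_t>((bits + rounding_bias) >> 16);
}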
@@ -286,6 +294,12 @@ static inline Constant MakeConstantTensor(DataType dtype, std::vector<int64_t> s
*(static_cast<DType*>(arr->data) + i) =
__truncXfYf2__<float, uint32_t, 23, uint16_t, uint16_t, 10>(
static_cast<float>(value[i]));
} else if (dtype == DataType::BFloat(16)) {
// convert to bfloat16
// storage is uint16_t
*(static_cast<DType*>(arr->data) + i) =
__truncXfYf2__<float, uint32_t, 23, uint16_t, uint16_t, 7>(
static_cast<float>(value[i]));
} else {
*(static_cast<DType*>(arr->data) + i) = value[i];
}
@@ -314,6 +328,12 @@ static inline Constant MakeConstantTensor(DataType dtype, std::vector<int64_t> s
*(static_cast<DType*>(arr->data) + i) =
__truncXfYf2__<float, uint32_t, 23, uint16_t, uint16_t, 10>(
static_cast<float>(value[i]));
} else if (dtype == DataType::BFloat(16)) {
// convert to bfloat16
// storage is uint16_t
*(static_cast<DType*>(arr->data) + i) =
__truncXfYf2__<float, uint32_t, 23, uint16_t, uint16_t, 7>(
static_cast<float>(value[i]));
} else {
*(static_cast<DType*>(arr->data) + i) = value[i];
}
@@ -417,6 +437,12 @@ static inline dmlc::optional<long double> TryToScalar(const runtime::NDArray& ar
} else if (array->dtype.bits == 64) {
return dmlc::optional<long double>(reinterpret_cast<double*>(array->data)[i]);
}
} else if (array->dtype.code == kDLBfloat) {
if (array->dtype.bits == 16) {
return dmlc::optional<long double>(
__extendXfYf2__<uint16_t, uint16_t, 7, float, uint32_t, 23>(
reinterpret_cast<uint16_t*>(array->data)[i]));
}
}
return dmlc::optional<long double>();
}
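
Note: the reverse direction used here is simpler. Since bfloat16 occupies the top half of a float32, widening is just a 16-bit left shift; a sketch mirroring the __extendXfYf2__ call (hypothetical helper):

#include <cstdint>
#include <cstring>

float BFloat16ToFloat(uint16_t h) {
  uint32_t bits = static_cast<uint32_t>(h) << 16;  // sign/exponent/mantissa already aligned
  float f;
  std::memcpy(&f, &bits, sizeof(f));
  return f;
}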
3 changes: 3 additions & 0 deletions src/runtime/crt/common/packed_func.c
@@ -49,6 +49,9 @@ DLDataType String2DLDataType(const char* s) {
} else if (!strncmp(s, "float", 5)) {
t.code = kDLFloat;
scan = s + 5;
} else if (!strncmp(s, "bfloat", 6)) {
t.code = kDLBfloat;
scan = s + 6;
} else if (!strncmp(s, "handle", 6)) {
t.code = kTVMOpaqueHandle;
t.bits = 64; // handle uses 64 bit by default.
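
Note: a usage sketch of the new branch (expected values inferred from the float and handle branches above):

DLDataType t = String2DLDataType("bfloat16");
// t.code == kDLBfloat, t.bits == 16, t.lanes == 1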
3 changes: 3 additions & 0 deletions src/runtime/vm/bytecode.cc
@@ -497,6 +497,9 @@ void DLDatatypePrint(std::ostream& os, const DLDataType& dtype) {
case kDLFloat:
os << "float";
break;
case kDLBfloat:
os << "bfloat";
break;
}

os << int(dtype.bits);
14 changes: 14 additions & 0 deletions src/tir/op/op.cc
@@ -128,6 +128,16 @@ void BinaryOpMatchTypes(PrimExpr& lhs, PrimExpr& rhs, Span span) {  // NOLINT(*)
!rtype.is_float()) {
// Cast int->float when the other operand is a float
rhs = cast(ltype, rhs);
} else if (!ltype.is_bfloat16() &&
(rtype.is_bfloat16() ||
datatype::Registry::Global()->GetTypeRegistered(rtype.code()))) {
// Cast int->bfloat16 when the other operand is a bfloat16
lhs = cast(rtype, lhs);
} else if ((ltype.is_bfloat16() ||
datatype::Registry::Global()->GetTypeRegistered(ltype.code())) &&
!rtype.is_bfloat16()) {
// Cast int->bfloat16 when the other operand is a bfloat16
rhs = cast(ltype, rhs);
} else if ((ltype.is_int() && rtype.is_int()) || (ltype.is_uint() && rtype.is_uint())) {
// Promote int to higher bits e.g. int8 + int16 --> int16 + int16
if (ltype.bits() < rtype.bits()) {
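
Note: with the two new branches, a binary op with exactly one bfloat16 operand promotes the other side to bfloat16 before the node is built. A sketch of the resulting behavior (assumes the usual TVM headers and namespaces):

void SketchBf16Promotion() {
  tir::Var a("a", DataType::BFloat(16));
  PrimExpr b = IntImm(DataType::Int(32), 2);
  PrimExpr c = a + b;  // BinaryOpMatchTypes casts b to bfloat16; c has dtype bfloat16
}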
@@ -186,6 +196,8 @@ PrimExpr max_value(const DataType& dtype, Span span) {
} else if (dtype.bits() == 16) {
return FloatImm(dtype, 65504.0, span);
}
} else if (dtype.is_bfloat16()) {
return FloatImm(dtype, std::numeric_limits<float>::max(), span);
}
LOG(FATAL) << "Cannot decide max_value for type" << dtype;
return PrimExpr();
@@ -219,6 +231,8 @@ PrimExpr min_value(const DataType& dtype, Span span) {
} else if (dtype.bits() == 16) {
return FloatImm(dtype, -65504.0, span);
}
} else if (dtype.is_bfloat16()) {
return FloatImm(dtype, std::numeric_limits<float>::lowest(), span);
}
LOG(FATAL) << "Cannot decide min_value for type" << dtype;
return PrimExpr();
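
Note: float32's extremes are a reasonable stand-in because bfloat16 reuses float32's 8-bit exponent, so the magnitudes line up; the exact largest finite bfloat16 value is slightly smaller than float32's max. A sketch using the hypothetical BFloat16ToFloat helper above:

float bf16_max = BFloat16ToFloat(0x7F7F);  // sign 0, exponent 0xFE, mantissa 0x7F: ~3.3895e38
// For comparison, std::numeric_limits<float>::max() is ~3.4028e38.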
3 changes: 2 additions & 1 deletion src/tir/transforms/arg_binder.cc
@@ -169,7 +169,8 @@ void ArgBinder::BindDLTensor(const Buffer& buffer, const PrimExpr& device_type,
IntImm(DataType::UInt(8), dtype.bits()) &&
TVMArrayGet(DataType::UInt(16), handle, builtin::kArrTypeLanes) ==
IntImm(DataType::UInt(16), dtype.lanes()));
if (!(dtype == DataType::Int(4) || dtype == DataType::UInt(4) || dtype == DataType::Int(1))) {
if (!(dtype == DataType::Int(1) || dtype == DataType::Int(4) || dtype == DataType::UInt(4) ||
dtype == DataType::UInt(16))) {
auto type_msg = tvm::tir::StringImm(type_err_msg.str());
asserts_.emplace_back(AssertStmt(a_ndim == v_ndim, msg, nop));
asserts_.emplace_back(AssertStmt(cond, type_msg, nop));