diff --git a/include/tvm/tir/op.h b/include/tvm/tir/op.h
index 9cf7d0a3cd1f..60ad55102029 100644
--- a/include/tvm/tir/op.h
+++ b/include/tvm/tir/op.h
@@ -835,10 +835,18 @@ TVM_DLL PrimExpr q_multiply_shift(PrimExpr x, PrimExpr y, PrimExpr q, PrimExpr s
                                   Span span = Span());
 
 // Intrinsic operators
-#define TVM_DECLARE_INTRIN_UNARY(OpName)                   \
-  inline PrimExpr OpName(PrimExpr x, Span span = Span()) { \
-    static const Op& op = Op::Get("tir." #OpName);         \
-    return tir::Call(x.dtype(), op, {x}, span);            \
+#define TVM_DECLARE_INTRIN_UNARY(OpName)                       \
+  inline PrimExpr OpName(PrimExpr x, Span span = Span()) {     \
+    static const Op& op = Op::Get("tir." #OpName);             \
+    if (x.dtype().is_bfloat16()) {                             \
+      DataType srcType = x.dtype();                            \
+      DataType dstType(kDLFloat, 32, srcType.lanes());         \
+      PrimExpr castX = tir::Cast(dstType, {x}, span);          \
+      PrimExpr result = tir::Call(dstType, op, {castX}, span); \
+      return tir::Cast(srcType, {result}, span);               \
+    } else {                                                   \
+      return tir::Call(x.dtype(), op, {x}, span);              \
+    }                                                          \
   }
 
 TVM_DECLARE_INTRIN_UNARY(exp);
diff --git a/src/arith/rewrite_simplify.cc b/src/arith/rewrite_simplify.cc
index 4a99e10211b7..55f0cf5f3929 100644
--- a/src/arith/rewrite_simplify.cc
+++ b/src/arith/rewrite_simplify.cc
@@ -461,7 +461,7 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const DivNode* op) {
 
   // x / 2.0 = x * 0.5
   if (const FloatImmNode* ptr = op->b.as<FloatImmNode>()) {
-    ICHECK(op->dtype.is_float() ||
+    ICHECK(op->dtype.is_float() || op->dtype.is_bfloat16() ||
           datatype::Registry::Global()->GetTypeRegistered(op->dtype.code()));
    return op->a * make_const(op->b.dtype(), 1.0 / ptr->value);
  }
diff --git a/src/auto_scheduler/feature.cc b/src/auto_scheduler/feature.cc
index aaf7d48b10c5..5809888543c6 100755
--- a/src/auto_scheduler/feature.cc
+++ b/src/auto_scheduler/feature.cc
@@ -246,14 +246,14 @@ int64_t GetLoopExtent(const ForNode* node) {
 // Count math ops in an expr
 class MathOpCounter : public StmtExprVisitor {
  public:
-#define VisitBinary(Type, float_ct, int_ct) \
-  void VisitExpr_(const Type* op) final {   \
-    if (op->a.dtype().is_float()) {         \
-      float_ct++;                           \
-    } else {                                \
-      int_ct++;                             \
-    }                                       \
-    StmtExprVisitor::VisitExpr_(op);        \
+#define VisitBinary(Type, float_ct, int_ct)                        \
+  void VisitExpr_(const Type* op) final {                          \
+    if (op->a.dtype().is_float() || op->a.dtype().is_bfloat16()) { \
+      float_ct++;                                                  \
+    } else {                                                       \
+      int_ct++;                                                    \
+    }                                                              \
+    StmtExprVisitor::VisitExpr_(op);                               \
   }
 
   VisitBinary(AddNode, float_addsub, int_addsub);
@@ -299,13 +299,13 @@ class MathOpCounter : public StmtExprVisitor {
         effect_kind == CallEffectKind::kPure || effect_kind == CallEffectKind::kExprAnnotation;
 
     if (is_pure) {
-      if (op->dtype.is_float()) {
+      if (op->dtype.is_float() || op->dtype.is_bfloat16()) {
         float_math_func++;
       } else {
         int_math_func++;
       }
     } else {
-      if (op->dtype.is_float()) {
+      if (op->dtype.is_float() || op->dtype.is_bfloat16()) {
         float_other_func++;
       } else {
         int_other_func++;
diff --git a/src/autotvm/touch_extractor.h b/src/autotvm/touch_extractor.h
index 313e4d78d6e1..83260e1e0633 100644
--- a/src/autotvm/touch_extractor.h
+++ b/src/autotvm/touch_extractor.h
@@ -87,27 +87,37 @@ class TouchExtractor : public FeatureVisitor {
 
   // arithmetic stats
   void VisitExpr_(const AddNode* op) final {
-    if (op->dtype.is_float()) itervar_map[itervar_stack_.back()].add_ct++;
+    if (op->dtype.is_float() || op->dtype.is_bfloat16()) {
+      itervar_map[itervar_stack_.back()].add_ct++;
+    }
     FeatureVisitor::VisitExpr_(op);
   }
 
   void VisitExpr_(const SubNode* op) final {
-    if (op->dtype.is_float()) itervar_map[itervar_stack_.back()].add_ct++;
+    if (op->dtype.is_float() || op->dtype.is_bfloat16()) {
+      itervar_map[itervar_stack_.back()].add_ct++;
+    }
     FeatureVisitor::VisitExpr_(op);
   }
 
   void VisitExpr_(const MulNode* op) final {
-    if (op->dtype.is_float()) itervar_map[itervar_stack_.back()].mul_ct++;
+    if (op->dtype.is_float() || op->dtype.is_bfloat16()) {
+      itervar_map[itervar_stack_.back()].mul_ct++;
+    }
     FeatureVisitor::VisitExpr_(op);
   }
 
   void VisitExpr_(const DivNode* op) final {
-    if (op->dtype.is_float()) itervar_map[itervar_stack_.back()].div_ct++;
+    if (op->dtype.is_float() || op->dtype.is_bfloat16()) {
+      itervar_map[itervar_stack_.back()].div_ct++;
+    }
     FeatureVisitor::VisitExpr_(op);
   }
 
   void VisitExpr_(const ModNode* op) final {
-    if (op->dtype.is_float()) itervar_map[itervar_stack_.back()].div_ct++;
+    if (op->dtype.is_float() || op->dtype.is_bfloat16()) {
+      itervar_map[itervar_stack_.back()].div_ct++;
+    }
     FeatureVisitor::VisitExpr_(op);
   }
 
diff --git a/src/contrib/hybrid/codegen_hybrid.cc b/src/contrib/hybrid/codegen_hybrid.cc
index 54edbaee35cd..5872a49968cb 100644
--- a/src/contrib/hybrid/codegen_hybrid.cc
+++ b/src/contrib/hybrid/codegen_hybrid.cc
@@ -69,6 +69,9 @@ void CodeGenHybrid::PrintType(DataType t, std::ostream& os) {
   } else if (t.is_int()) {
     os << "int";
     ICHECK(t.bits() == 8 || t.bits() == 16 || t.bits() == 32 || t.bits() == 64);
+  } else if (t.is_bfloat16()) {
+    os << "bfloat";
+    ICHECK(t.bits() == 16);
   } else {
     ICHECK(t.is_uint()) << "Unsupported type " << t;
     os << "uint";
diff --git a/src/relay/backend/contrib/codegen_c/codegen_c.h b/src/relay/backend/contrib/codegen_c/codegen_c.h
index 3c6f810534ea..49a5bca068d1 100644
--- a/src/relay/backend/contrib/codegen_c/codegen_c.h
+++ b/src/relay/backend/contrib/codegen_c/codegen_c.h
@@ -363,6 +363,8 @@ class CodegenCBase {
       dtype = "float";
     } else if (runtime::TypeMatch(ttype->dtype, kDLFloat, 16)) {
       dtype = "half";
+    } else if (runtime::TypeMatch(ttype->dtype, kDLBfloat, 16)) {
+      dtype = "bfloat";
     } else if (runtime::TypeMatch(ttype->dtype, kDLInt, 32)) {
       dtype = "int";
     } else if (runtime::TypeMatch(ttype->dtype, kDLInt, 64)) {
diff --git a/src/relay/backend/utils.h b/src/relay/backend/utils.h
index 658283b5dc36..02075616c6c8 100644
--- a/src/relay/backend/utils.h
+++ b/src/relay/backend/utils.h
@@ -302,6 +302,8 @@ inline std::string DType2String(const tvm::DataType dtype) {
     os << "int";
   } else if (dtype.is_uint()) {
     os << "uint";
+  } else if (dtype.is_bfloat16()) {
+    os << "bfloat";
   } else if ((*GetPackedFunc("runtime._datatype_get_type_registered"))(dtype.code())) {
     os << "custom["
        << (*GetPackedFunc("runtime._datatype_get_type_name"))(dtype.code()).operator std::string()
diff --git a/src/relay/op/nn/nn.cc b/src/relay/op/nn/nn.cc
index 89ef2708ff27..a959dd7e9915 100644
--- a/src/relay/op/nn/nn.cc
+++ b/src/relay/op/nn/nn.cc
@@ -1177,7 +1177,8 @@ bool NLLLossRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
                                      << ", weights shape = " << weights->shape);
     return false;
   }
-  if (!(predictions->dtype == weights->dtype && predictions->dtype.is_float())) {
+  if (!(predictions->dtype == weights->dtype &&
+        (predictions->dtype.is_float() || predictions->dtype.is_bfloat16()))) {
     reporter->GetDiagCtx().EmitFatal(Diagnostic::Error(reporter->GetSpan())
                                      << "NLLLossRel: predictions and weights should"
                                      << " be of the same floating type.");
diff --git a/src/relay/transforms/pattern_utils.h b/src/relay/transforms/pattern_utils.h
index 69ad20a7ceaf..4084553419df 100644
--- a/src/relay/transforms/pattern_utils.h
+++ b/src/relay/transforms/pattern_utils.h
@@ -63,6 +63,9 @@ namespace relay {
   } else if (type == DataType::Float(16)) {  \
     typedef uint16_t DType;                  \
     { __VA_ARGS__ }                          \
+  } else if (type == DataType::BFloat(16)) { \
+    typedef uint16_t DType;                  \
+    { __VA_ARGS__ }                          \
   } else if (type == DataType::Int(64)) {    \
     typedef int64_t DType;                   \
     { __VA_ARGS__ }                          \
@@ -259,6 +262,11 @@ inline Constant MakeConstantScalar(DataType dtype, T value) {
       // storage is uint16_t
       *static_cast<DType*>(arr->data) =
           __truncXfYf2__<float, uint32_t, 23, uint16_t, uint16_t, 10>(static_cast<float>(value));
+    } else if (dtype == DataType::BFloat(16)) {
+      // convert to bfloat16
+      // storage is uint16_t
+      *static_cast<DType*>(arr->data) =
+          __truncXfYf2__<float, uint32_t, 23, uint16_t, uint16_t, 7>(static_cast<float>(value));
     } else {
       *static_cast<DType*>(arr->data) = value;
     }
@@ -286,6 +294,12 @@ static inline Constant MakeConstantTensor(DataType dtype, std::vector<int64_t> s
        *(static_cast<DType*>(arr->data) + i) = __truncXfYf2__<float, uint32_t, 23, uint16_t, uint16_t, 10>(
            static_cast<float>(value[i]));
+      } else if (dtype == DataType::BFloat(16)) {
+        // convert to bfloat16
+        // storage is uint16_t
+        *(static_cast<DType*>(arr->data) + i) =
+            __truncXfYf2__<float, uint32_t, 23, uint16_t, uint16_t, 7>(
+                static_cast<float>(value[i]));
       } else {
         *(static_cast<DType*>(arr->data) + i) = value[i];
       }
@@ -314,6 +328,12 @@ static inline Constant MakeConstantTensor(DataType dtype, std::vector<int64_t> s
        *(static_cast<DType*>(arr->data) + i) = __truncXfYf2__<float, uint32_t, 23, uint16_t, uint16_t, 10>(
            static_cast<float>(value[i]));
+      } else if (dtype == DataType::BFloat(16)) {
+        // convert to bfloat16
+        // storage is uint16_t
+        *(static_cast<DType*>(arr->data) + i) =
+            __truncXfYf2__<float, uint32_t, 23, uint16_t, uint16_t, 7>(
+                static_cast<float>(value[i]));
       } else {
         *(static_cast<DType*>(arr->data) + i) = value[i];
       }
@@ -417,6 +437,12 @@ static inline dmlc::optional<long double> TryToScalar(const runtime::NDArray& ar
     } else if (array->dtype.bits == 64) {
       return dmlc::optional<long double>(reinterpret_cast<double*>(array->data)[i]);
     }
+  } else if (array->dtype.code == kDLBfloat) {
+    if (array->dtype.bits == 16) {
+      return dmlc::optional<long double>(
+          __extendXfYf2__<uint16_t, uint16_t, 7, float, uint32_t, 23>(
+              reinterpret_cast<uint16_t*>(array->data)[i]));
+    }
   }
   return dmlc::optional<long double>();
 }
diff --git a/src/runtime/crt/common/packed_func.c b/src/runtime/crt/common/packed_func.c
index e946cda9d9ae..645b22f3b255 100644
--- a/src/runtime/crt/common/packed_func.c
+++ b/src/runtime/crt/common/packed_func.c
@@ -49,6 +49,9 @@ DLDataType String2DLDataType(const char* s) {
   } else if (!strncmp(s, "float", 5)) {
     t.code = kDLFloat;
     scan = s + 5;
+  } else if (!strncmp(s, "bfloat", 6)) {
+    t.code = kDLBfloat;
+    scan = s + 6;
   } else if (!strncmp(s, "handle", 6)) {
     t.code = kTVMOpaqueHandle;
     t.bits = 64;  // handle uses 64 bit by default.
diff --git a/src/runtime/vm/bytecode.cc b/src/runtime/vm/bytecode.cc
index f83e27d2c11d..a2fa478ac6c8 100644
--- a/src/runtime/vm/bytecode.cc
+++ b/src/runtime/vm/bytecode.cc
@@ -497,6 +497,9 @@ void DLDatatypePrint(std::ostream& os, const DLDataType& dtype) {
     case kDLFloat:
       os << "float";
       break;
+    case kDLBfloat:
+      os << "bfloat";
+      break;
   }
 
   os << int(dtype.bits);
diff --git a/src/tir/op/op.cc b/src/tir/op/op.cc
index d08bef2ab91a..1a9a73e9dc94 100644
--- a/src/tir/op/op.cc
+++ b/src/tir/op/op.cc
@@ -128,6 +128,16 @@ void BinaryOpMatchTypes(PrimExpr& lhs, PrimExpr& rhs, Span span) {  // NOLINT(*)
              !rtype.is_float()) {
     // Cast int->float when the other operand is a float
     rhs = cast(ltype, rhs);
+  } else if (!ltype.is_bfloat16() &&
+             (rtype.is_bfloat16() ||
+              datatype::Registry::Global()->GetTypeRegistered(rtype.code()))) {
+    // Cast int->bfloat16 when the other operand is a bfloat16
+    lhs = cast(rtype, lhs);
+  } else if ((ltype.is_bfloat16() ||
+              datatype::Registry::Global()->GetTypeRegistered(ltype.code())) &&
+             !rtype.is_bfloat16()) {
+    // Cast int->bfloat16 when the other operand is a bfloat16
+    rhs = cast(ltype, rhs);
   } else if ((ltype.is_int() && rtype.is_int()) || (ltype.is_uint() && rtype.is_uint())) {
     // Promote int to higher bits e.g. int8 + int16 --> int16 + int16
     if (ltype.bits() < rtype.bits()) {
@@ -186,6 +196,8 @@ PrimExpr max_value(const DataType& dtype, Span span) {
     } else if (dtype.bits() == 16) {
       return FloatImm(dtype, 65504.0, span);
     }
+  } else if (dtype.is_bfloat16()) {
+    return FloatImm(dtype, std::numeric_limits<float>::max(), span);
   }
   LOG(FATAL) << "Cannot decide max_value for type" << dtype;
   return PrimExpr();
@@ -219,6 +231,8 @@ PrimExpr min_value(const DataType& dtype, Span span) {
     } else if (dtype.bits() == 16) {
       return FloatImm(dtype, -65504.0, span);
     }
+  } else if (dtype.is_bfloat16()) {
+    return FloatImm(dtype, std::numeric_limits<float>::lowest(), span);
   }
   LOG(FATAL) << "Cannot decide min_value for type" << dtype;
   return PrimExpr();
diff --git a/src/tir/transforms/arg_binder.cc b/src/tir/transforms/arg_binder.cc
index d3ab32cbd7f9..1e566a980463 100644
--- a/src/tir/transforms/arg_binder.cc
+++ b/src/tir/transforms/arg_binder.cc
@@ -169,7 +169,8 @@ void ArgBinder::BindDLTensor(const Buffer& buffer, const PrimExpr& device_type,
                      IntImm(DataType::UInt(8), dtype.bits()) &&
                  TVMArrayGet(DataType::UInt(16), handle, builtin::kArrTypeLanes) ==
                      IntImm(DataType::UInt(16), dtype.lanes()));
-  if (!(dtype == DataType::Int(4) || dtype == DataType::UInt(4) || dtype == DataType::Int(1))) {
+  if (!(dtype == DataType::Int(1) || dtype == DataType::Int(4) || dtype == DataType::UInt(4) ||
+        dtype == DataType::UInt(16))) {
     auto type_msg = tvm::tir::StringImm(type_err_msg.str());
     asserts_.emplace_back(AssertStmt(a_ndim == v_ndim, msg, nop));
     asserts_.emplace_back(AssertStmt(cond, type_msg, nop));
diff --git a/src/tir/transforms/bf16_legalize.cc b/src/tir/transforms/bf16_legalize.cc
index 76845cbebd2a..79c406818185 100644
--- a/src/tir/transforms/bf16_legalize.cc
+++ b/src/tir/transforms/bf16_legalize.cc
@@ -45,26 +45,6 @@ class BF16PromoteRewriter : public StmtExprMutator {
 
   Stmt operator()(Stmt s) { return VisitStmt(s); }
 
-  std::tuple<PrimExpr, PrimExpr> DoCast(PrimExpr orig_a, PrimExpr orig_b, bool* is_bfloat16) {
-    auto a = this->VisitExpr(orig_a);
-    auto b = this->VisitExpr(orig_b);
-    *is_bfloat16 = false;
-    if (a->dtype.is_bfloat16()) {
-      ICHECK(b->dtype.is_bfloat16());
-      *is_bfloat16 = true;
-    } else if (b->dtype.is_bfloat16()) {
-      ICHECK(a->dtype.is_bfloat16());
-      *is_bfloat16 = true;
-    }
-
-    if (*is_bfloat16) {
-      DataType fp32ty(kDLFloat, 32, 1);
-      a = Cast(fp32ty, a);
-      b = Cast(fp32ty, b);
-    }
-    return std::make_tuple(a, b);
-  }
-
   PrimExpr VisitExpr_(const AddNode* op) final;
   PrimExpr VisitExpr_(const SubNode* op) final;
   PrimExpr VisitExpr_(const MulNode* op) final;
@@ -77,45 +57,36 @@ class BF16PromoteRewriter : public StmtExprMutator {
   PrimExpr VisitExpr_(const GENode* op) final;
 };
 
-#define DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(OP, FUNC)  \
-  PrimExpr BF16PromoteRewriter::VisitExpr_(const OP* op) { \
-    PrimExpr a, b;                                         \
-    bool is_bfloat16;                                      \
-    std::tie(a, b) = DoCast(op->a, op->b, &is_bfloat16);   \
-    if (a.same_as(op->a) && b.same_as(op->b)) {            \
-      return GetRef<PrimExpr>(op);                         \
-    } else {                                               \
-      auto ret = FUNC(a, b);                               \
-      if (!is_bfloat16)                                    \
-        return ret;                                        \
-      else                                                 \
-        return Cast(DataType(kDLBfloat, 16, 1), ret);      \
-    }                                                      \
-  }
-
-#define DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH_NO_CAST(OP, FUNC) \
-  PrimExpr BF16PromoteRewriter::VisitExpr_(const OP* op) {        \
-    PrimExpr a, b;                                                \
-    bool is_bfloat16;                                             \
-    std::tie(a, b) = DoCast(op->a, op->b, &is_bfloat16);          \
-    if (a.same_as(op->a) && b.same_as(op->b)) {                   \
-      return GetRef<PrimExpr>(op);                                \
-    } else {                                                      \
-      auto ret = FUNC(a, b);                                      \
-      return ret;                                                 \
-    }                                                             \
+#define DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(OP, FUNC, NEEDCAST)                \
+  PrimExpr BF16PromoteRewriter::VisitExpr_(const OP* op) {                         \
+    PrimExpr origin_a = this->VisitExpr(op->a);                                    \
+    PrimExpr origin_b = this->VisitExpr(op->b);                                    \
+    bool a_is_bfloat16 = origin_a->dtype.is_bfloat16();                            \
+    bool b_is_bfloat16 = origin_b->dtype.is_bfloat16();                            \
+    bool both_bfloat16 = a_is_bfloat16 && b_is_bfloat16;                           \
+    bool none_bfloat16 = !(a_is_bfloat16 || b_is_bfloat16);                        \
+    if (none_bfloat16) {                                                           \
+      return GetRef<PrimExpr>(op);                                                 \
+    }                                                                              \
+    DataType float32_dtype(kDLFloat, 32, 1);                                       \
+    PrimExpr float32_a = a_is_bfloat16 ? Cast(float32_dtype, origin_a) : origin_a; \
+    PrimExpr float32_b = b_is_bfloat16 ? Cast(float32_dtype, origin_b) : origin_b; \
+    PrimExpr result = FUNC(float32_a, float32_b);                                  \
+    DataType bfloat16_dtype(kDLBfloat, 16, 1);                                     \
+    bool do_cast = both_bfloat16 && NEEDCAST;                                      \
+    return do_cast ? Cast(bfloat16_dtype, result) : result;                        \
   }
 
-DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(AddNode, operator+)
-DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(SubNode, operator-)
-DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(MulNode, operator*)
-DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(DivNode, div)
-DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(MinNode, min)
-DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(MaxNode, max)
-DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH_NO_CAST(LTNode, operator<)   // NOLINT(*)
-DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH_NO_CAST(LENode, operator<=)  // NOLINT(*)
-DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH_NO_CAST(GTNode, operator>)   // NOLINT(*)
-DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH_NO_CAST(GENode, operator>=)  // NOLINT(*)
+DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(AddNode, operator+, true)
+DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(SubNode, operator-, true)
+DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(MulNode, operator*, true)
+DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(DivNode, div, true)
+DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(MinNode, min, true)
+DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(MaxNode, max, true)
+DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(LTNode, operator<, false)
+DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(LENode, operator<=, false)
+DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(GTNode, operator>, false)
+DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(GENode, operator>=, false)
 
 /*
  * Eliminate verbose casting between fp32 and bf16
@@ -179,25 +150,23 @@ class BF16LowerRewriter : public StmtExprMutator {
   using StmtExprMutator::operator();
 
   PrimExpr VisitExpr_(const CastNode* op) final {
-    auto op_val = StmtExprMutator::VisitExpr(op->value);
-    if (op->value->dtype.is_bfloat16()) {
-      // if is cast_from_bf16, check if is to fp32
-      ICHECK(op->dtype.is_float() && op->dtype.bits() == 32);
-      auto uint32_dtype = DataType(kDLUInt, 32, op_val->dtype.lanes());
-      auto uint32_v = Cast(uint32_dtype, op_val);
-      // to be endian invariant.
-      return Call(op->dtype, builtin::reinterpret(), {uint32_v << 16});
-    } else if (op->dtype.is_bfloat16()) {
-      // if is cast_to_bf16, check if op->value is fp32
-      ICHECK(op->value->dtype.is_float() && op->value->dtype.bits() == 32);
-      auto uint32_dtype = DataType(kDLUInt, 32, op_val->dtype.lanes());
-      auto uint32_v = Call(uint32_dtype, builtin::reinterpret(), {op_val});
-      auto uint16_dtype = DataType(kDLUInt, 16, op_val->dtype.lanes());
+    PrimExpr op_val = StmtExprMutator::VisitExpr(op->value);
+    DataType uint32_dtype(kDLUInt, 32, op_val->dtype.lanes());
+    DataType float32_dtype(kDLFloat, 32, op_val->dtype.lanes());
+    if (op->value->dtype.is_bfloat16()) {  // cast from bf16
+      PrimExpr uint32_v = Cast(uint32_dtype, op_val);
+      PrimExpr float32_v = Call(float32_dtype, builtin::reinterpret(), {uint32_v << 16});
+      bool is_to_float32 = op->dtype.is_float() && op->dtype.bits() == 32;
+      return is_to_float32 ? float32_v : Cast(op->dtype, float32_v);
+    } else if (op->dtype.is_bfloat16()) {  // cast to bf16
+      bool is_from_float32 = op->value->dtype.is_float() && op->value->dtype.bits() == 32;
+      PrimExpr float32_v = is_from_float32 ? op_val : Cast(float32_dtype, op_val);
+      PrimExpr uint32_v = Call(uint32_dtype, builtin::reinterpret(), {float32_v});
+      DataType uint16_dtype(kDLUInt, 16, op_val->dtype.lanes());
       /* the following TIR is equivalent to the C++ code below:
       uint32_t rounding_bias = ((U32 >> 16) & 1) + UINT32_C(0x7FFF);
       return static_cast<uint16_t>((U32 + rounding_bias) >> 16);*/
-      auto rounding_bias = ((uint32_v >> 16) & 1) + make_const(uint16_dtype, 0x7FFF);
-      // to be endian invariant.
+      PrimExpr rounding_bias = ((uint32_v >> 16) & 1) + make_const(uint16_dtype, 0x7FFF);
       return Cast(uint16_dtype, {(uint32_v + rounding_bias) >> 16});
     }
     if (op->value.same_as(op_val)) return GetRef<PrimExpr>(op);
diff --git a/tests/python/relay/test_cpp_build_module.py b/tests/python/relay/test_cpp_build_module.py
index 23bc7ca95a34..ccf961fbe4db 100644
--- a/tests/python/relay/test_cpp_build_module.py
+++ b/tests/python/relay/test_cpp_build_module.py
@@ -93,6 +93,45 @@ def test_fp16_build():
     np.testing.assert_allclose(out.numpy(), X.numpy() + Y.numpy(), atol=1e-5, rtol=1e-5)
 
 
+@tvm.testing.requires_llvm
+def test_bf16_build():
+    data = relay.var("data", shape=(1, 3, 224, 224), dtype="float32")
+    weight = relay.var("weight", shape=(64, 3, 7, 7), dtype="float32")
+    bn_gamma = relay.var("gamma", shape=(64,), dtype="float32")
+    bn_beta = relay.var("beta", shape=(64,), dtype="float32")
+    bn_mean = relay.var("mean", shape=(64,), dtype="float32")
+    bn_var = relay.var("var", shape=(64,), dtype="float32")
+    params = {
+        "weight": np.random.uniform(-1, 1, size=(64, 3, 7, 7)).astype("float32"),
+        "gamma": np.random.uniform(-1, 1, size=(64,)).astype("float32"),
+        "beta": np.random.uniform(-1, 1, size=(64,)).astype("float32"),
+        "mean": np.random.uniform(-1, 1, size=(64,)).astype("float32"),
+        "var": np.random.uniform(-1, 1, size=(64,)).astype("float32"),
+    }
+    conv_bf16 = relay.nn.conv2d(
+        relay.cast(data, "bfloat16"),
+        relay.cast(weight, "bfloat16"),
+        strides=(2, 2),
+        padding=(3, 3, 3, 3),
+        channels=64,
+        kernel_size=(7, 7),
+        out_dtype="bfloat16",
+    )
+    bn_bf16 = relay.nn.batch_norm(
+        conv_bf16,
+        relay.cast(bn_gamma, "bfloat16"),
+        relay.cast(bn_beta, "bfloat16"),
+        relay.cast(bn_mean, "bfloat16"),
+        relay.cast(bn_var, "bfloat16"),
+    )
+    relu_bf16 = relay.nn.relu(bn_bf16[0])
+    maxpool_bf16 = relay.nn.max_pool2d(relu_bf16, pool_size=(2, 2), strides=(2, 2))
+    avgpool_bf16 = relay.nn.avg_pool2d(maxpool_bf16, pool_size=(2, 2), strides=(2, 2))
+    mod_bf16 = tvm.IRModule.from_expr(avgpool_bf16)
+    with tvm.transform.PassContext(opt_level=3):
+        relay.build(mod_bf16, target="llvm", params=params)
+
+
 @tvm.testing.parametrize_targets("llvm", "cuda")
 def test_fp16_conversion(target, dev):
     if target == "cuda" and not have_fp16(dev.compute_version):
@@ -126,3 +165,4 @@ def test_fp16_conversion(target, dev):
     test_basic_build()
     test_fp16_build()
     test_fp16_conversion()
+    test_bf16_build()
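
For reference, the fp32/bf16 conversion that BF16LowerRewriter lowers to TIR can be written as standalone C++. This is a minimal sketch only; FloatToBF16 and BF16ToFloat are illustrative names and are not part of this patch.

    #include <cstdint>
    #include <cstring>

    // fp32 -> bf16 with round-to-nearest-even, using the same rounding-bias
    // trick that the pass emits in TIR.
    static uint16_t FloatToBF16(float x) {
      uint32_t u32;
      std::memcpy(&u32, &x, sizeof(u32));  // reinterpret the float's bit pattern
      uint32_t rounding_bias = ((u32 >> 16) & 1) + UINT32_C(0x7FFF);
      return static_cast<uint16_t>((u32 + rounding_bias) >> 16);
    }

    // bf16 -> fp32 is a plain shift into the high 16 bits of a float32 pattern.
    static float BF16ToFloat(uint16_t x) {
      uint32_t u32 = static_cast<uint32_t>(x) << 16;
      float f;
      std::memcpy(&f, &u32, sizeof(f));
      return f;
    }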