diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py index ebc0132435ba..c75bd2dd3c09 100644 --- a/python/tvm/relay/frontend/pytorch.py +++ b/python/tvm/relay/frontend/pytorch.py @@ -2067,7 +2067,7 @@ def is_floating_point(self, inputs, input_types): else: input_type = input_types[0] - is_float = input_type in ["float32", "float64", "float16"] + is_float = input_type in ["float32", "float64", "float16", "bfloat16"] return _expr.const(is_float) # Operator mappings