diff --git a/docs/frontend/tensorflow.rst b/docs/frontend/tensorflow.rst
index 45db9e43b922..a158db9c5589 100644
--- a/docs/frontend/tensorflow.rst
+++ b/docs/frontend/tensorflow.rst
@@ -162,6 +162,7 @@ Supported Ops
 - Identity
 - IsFinite
 - IsInf
+- IsNan
 - LeakyRelu
 - LeftShift
 - Less
diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py
index 18868cf8491c..38a811d1d558 100644
--- a/python/tvm/relay/frontend/pytorch.py
+++ b/python/tvm/relay/frontend/pytorch.py
@@ -1118,12 +1118,44 @@ def _impl(inputs, input_types):
         return _op.tensor.sqrt(data)
     return _impl
 
+
+def _rsqrt():
+    def _impl(inputs, input_types):
+        data = inputs[0]
+        return _op.tensor.rsqrt(data)
+    return _impl
+
+
+def _ceil():
+    def _impl(inputs, input_types):
+        data = inputs[0]
+        return _op.ceil(data)
+    return _impl
+
+
+def _clamp():
+    def _impl(inputs, input_types):
+        data = inputs[0]
+        amin = inputs[1] if inputs[1] is not None else np.finfo(np.float32).min
+        amax = inputs[2] if inputs[2] is not None else np.finfo(np.float32).max
+        return _op.clip(data, amin, amax)
+    return _impl
+
+
 def _floor():
     def _impl(inputs, input_types):
         data = inputs[0]
         return _op.floor(data)
     return _impl
 
+
+def _round():
+    def _impl(inputs, input_types):
+        data = inputs[0]
+        return _op.round(data)
+    return _impl
+
+
 def _to():
     def _impl(inputs, input_types):
         data = inputs[0]
@@ -1232,6 +1264,18 @@ def _impl(inputs, input_types):
     return _impl
 
 
+def _isfinite():
+    def _impl(inputs, input_types):
+        return _op.isfinite(inputs[0])
+    return _impl
+
+
+def _isnan():
+    def _impl(inputs, input_types):
+        return _op.isnan(inputs[0])
+    return _impl
+
+
 def _list_getitem(prelude):
     def _impl(inputs, input_types):
         return prelude.nth(inputs[0], _wrap_const(inputs[1]))
@@ -1429,7 +1473,11 @@ def _get_convert_map(prelude):
         "aten::std" : _std(),
         "aten::var" : _variance(),
         "aten::sqrt" : _sqrt(),
-        'aten::floor' : _floor(),
+        "aten::rsqrt" : _rsqrt(),
+        "aten::ceil" : _ceil(),
+        "aten::clamp" : _clamp(),
+        "aten::floor" : _floor(),
+        "aten::round" : _round(),
         "aten::detach" : _identity(),
         "aten::upsample_bilinear2d" : _upsample("bilinear"),
         "aten::upsample_nearest2d" : _upsample("nearest_neighbor"),
@@ -1439,6 +1487,9 @@ def _get_convert_map(prelude):
         "aten::le" : _elemwise("less_equal"),
         "aten::ge" : _elemwise("greater_equal"),
         "aten::ne" : _elemwise("not_equal"),
+        "aten::eq" : _elemwise("equal"),
+        "aten::isfinite" : _isfinite(),
+        "aten::isnan" : _isnan(),
         "aten::Bool" : _Bool(),
         "aten::Float" : _Float(),
         "aten::neg" : _neg(),
diff --git a/python/tvm/relay/op/_tensor.py b/python/tvm/relay/op/_tensor.py
index a607a4724584..79a623d34c4a 100644
--- a/python/tvm/relay/op/_tensor.py
+++ b/python/tvm/relay/op/_tensor.py
@@ -66,6 +66,7 @@
 register_broadcast_schedule("less_equal")
 register_broadcast_schedule("greater")
 register_broadcast_schedule("greater_equal")
+register_broadcast_schedule("isnan")
 register_broadcast_schedule("isfinite")
 register_broadcast_schedule("isinf")
 register_injective_schedule("maximum")
diff --git a/python/tvm/relay/op/tensor.py b/python/tvm/relay/op/tensor.py
index 1f481eefd475..f6024075d925 100644
--- a/python/tvm/relay/op/tensor.py
+++ b/python/tvm/relay/op/tensor.py
@@ -1010,6 +1010,22 @@ def ndarray_size(data, dtype="int32"):
     return _make.ndarray_size(data, dtype)
 
 
+def isnan(data):
+    """Check NaN in input data element-wise.
+
+    Parameters
+    ----------
+    data : relay.Expr
+        The input data
+
+    Returns
+    -------
+    result : relay.Expr
+        The computed result.
+ """ + return _make.isnan(data) + + def isfinite(data): """Compute element-wise finiteness of data. diff --git a/src/relay/op/tensor/unary.cc b/src/relay/op/tensor/unary.cc index 4cca8b0b07cd..10da11d8c7ac 100644 --- a/src/relay/op/tensor/unary.cc +++ b/src/relay/op/tensor/unary.cc @@ -426,6 +426,15 @@ ElemwiseArbitraryLayout) .set_support_level(10) .set_attr("FTVMCompute", NdarraySizeCompute); +RELAY_REGISTER_UNARY_OP("isnan") +.describe(R"code(Returns whether the input contains any NaN, computed element-wise. +.. math:: + isnan(x) +)code" TVM_ADD_FILELINE) +.set_support_level(3) +.add_type_rel("IdentityCompRel", IdentityCompRel) +.set_attr("FTVMCompute", RELAY_UNARY_COMPUTE(topi::isnan)); + RELAY_REGISTER_UNARY_OP("isfinite") .describe(R"code(Returns the finiteness of input, computed element-wise. .. math:: @@ -438,7 +447,7 @@ RELAY_REGISTER_UNARY_OP("isfinite") RELAY_REGISTER_UNARY_OP("isinf") .describe(R"code(Returns the infiniteness of input, computed element-wise. .. math:: - isfinite(x) + isinf(x) )code" TVM_ADD_FILELINE) .set_support_level(3) .add_type_rel("IdentityCompRel", IdentityCompRel) diff --git a/src/target/intrin_rule.cc b/src/target/intrin_rule.cc index 5d393ab8ebb2..51f65b416c7a 100644 --- a/src/target/intrin_rule.cc +++ b/src/target/intrin_rule.cc @@ -96,6 +96,14 @@ TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.sigmoid") *rv = one / (one + exp(-call->args[0])); }); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.nan") +.set_body([](const TVMArgs& args, TVMRetValue* rv){ + PrimExpr e = args[0]; + const CallNode* call = e.as(); + CHECK(call != nullptr); + *rv = isnan(call->args[0]); + }); + TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.isfinite") .set_body([](const TVMArgs& args, TVMRetValue* rv){ PrimExpr e = args[0]; diff --git a/tests/python/frontend/pytorch/test_forward.py b/tests/python/frontend/pytorch/test_forward.py index 91e14c697f35..d9d280f25a70 100644 --- a/tests/python/frontend/pytorch/test_forward.py +++ b/tests/python/frontend/pytorch/test_forward.py @@ -1441,6 +1441,110 @@ def forward(self, *args): verify_model(Variance5().float().eval(), input_data=input_data) + +def test_forward_isfinite(): + torch.set_grad_enabled(False) + + class IsFinite1(Module): + def forward(self, *args): + return torch.isfinite(args[0]) + + input_data = torch.tensor([1, float('inf'), 2, float('-inf'), float('nan')]).float() + verify_model(IsFinite1().float().eval(), input_data=input_data) + + +def test_forward_isnan(): + torch.set_grad_enabled(False) + + class IsNan1(Module): + def forward(self, *args): + return torch.isnan(args[0]) + + input_data = torch.tensor([1, float('inf'), 2, float('-inf'), float('nan')]).float() + verify_model(IsNan1().float().eval(), input_data=input_data) + + +def test_forward_isinf(): + torch.set_grad_enabled(False) + + class IsInf1(Module): + def forward(self, *args): + return torch.isinf(args[0]) + + input_data = torch.tensor([1, float('inf'), 2, float('-inf'), float('nan')]).float() + verify_model(IsInf1().float().eval(), input_data=input_data) + + +def test_forward_rsqrt(): + torch.set_grad_enabled(False) + input_shape = [1, 3, 10, 10] + + class Rsqrt1(Module): + def forward(self, *args): + return torch.rsqrt(args[0]) + + input_data = torch.rand(input_shape).float() + verify_model(Rsqrt1().float().eval(), input_data=input_data) + + +def test_forward_ceil(): + torch.set_grad_enabled(False) + input_shape = [1, 3, 10, 10] + + class Ceil1(Module): + def forward(self, *args): + return torch.ceil(args[0]) + + input_data = torch.rand(input_shape).float() + 
+    verify_model(Ceil1().float().eval(), input_data=input_data)
+
+
+def test_forward_clamp():
+    torch.set_grad_enabled(False)
+    input_shape = [1, 3, 10, 10]
+
+    class Clamp1(Module):
+        def forward(self, *args):
+            return torch.clamp(args[0], min=-0.5, max=0.5)
+
+    class Clamp2(Module):
+        def forward(self, *args):
+            return torch.clamp(args[0], min=-0.3)
+
+    class Clamp3(Module):
+        def forward(self, *args):
+            return torch.clamp(args[0], max=1.0)
+
+    input_data = torch.rand(input_shape).float()
+    verify_model(Clamp1().float().eval(), input_data=input_data)
+    verify_model(Clamp2().float().eval(), input_data=input_data)
+    verify_model(Clamp3().float().eval(), input_data=input_data)
+
+
+def test_forward_floor():
+    torch.set_grad_enabled(False)
+    input_shape = [1, 3, 10, 10]
+
+    class Floor1(Module):
+        def forward(self, *args):
+            return torch.floor(args[0])
+
+    input_data = torch.rand(input_shape).float()
+    verify_model(Floor1().float().eval(), input_data=input_data)
+
+
+def test_forward_round():
+    torch.set_grad_enabled(False)
+    input_shape = [1, 3, 10, 10]
+
+    class Round1(Module):
+        def forward(self, *args):
+            return torch.round(args[0])
+
+    input_data = torch.rand(input_shape).float()
+    verify_model(Round1().float().eval(), input_data=input_data)
+
 
 
 if __name__ == "__main__":
     # Single operator tests
     test_forward_add()
@@ -1497,6 +1601,14 @@ def forward(self, *args):
     test_forward_expand()
     test_forward_pow()
     test_forward_abs()
+    test_forward_rsqrt()
+    test_forward_ceil()
+    test_forward_clamp()
+    test_forward_floor()
+    test_forward_round()
+    test_forward_isfinite()
+    test_forward_isnan()
+    test_forward_isinf()
     test_forward_arange()
     test_forward_chunk()
     test_forward_split()
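
A quick way to sanity-check the new Relay-level isnan outside the PyTorch frontend, once the patch is applied. This is a reviewer sketch, not part of the patch: the "graph" executor kind, tvm.cpu(0) context, and llvm target are illustrative assumptions about the local build.

# Not part of the patch: minimal smoke test for relay.isnan.
import numpy as np
import tvm
from tvm import relay

x = relay.var("x", shape=(5,), dtype="float32")
mod = tvm.IRModule.from_expr(relay.Function([x], relay.isnan(x)))

data = np.array([1.0, np.inf, 2.0, -np.inf, np.nan], dtype="float32")
intrp = relay.create_executor("graph", mod=mod, ctx=tvm.cpu(0), target="llvm")
res = intrp.evaluate()(data).asnumpy()

# Only the NaN entry should be flagged, matching np.isnan on the same input.
assert (res == np.isnan(data)).all()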