diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py
index 56df39fdaa30..a0a837f92df9 100644
--- a/python/tvm/relay/frontend/pytorch.py
+++ b/python/tvm/relay/frontend/pytorch.py
@@ -20,6 +20,7 @@
 # pylint: disable=missing-function-docstring
 """PT: PyTorch frontend."""
 import itertools
+import functools
 import logging
 import math
 import sys
@@ -2763,6 +2764,16 @@ def lstm(self, inputs, input_types):
 
         return (output, _op.stack(hy, 0), _op.stack(cy, 0))
 
+    def all_any_common(self, op, inputs, input_types):
+        dim = inputs[1]
+        keepdim = inputs[2]
+        if self.infer_type(inputs[0]).dtype != "bool":
+            # The input dtype can be uint8.
+            inp = _op.cast(inputs[0], "bool")
+        else:
+            inp = inputs[0]
+        return op(inp, axis=dim, keepdims=keepdim)
+
     # Operator mappings
     def create_convert_map(self):
         self.convert_map = {
@@ -2986,6 +2997,8 @@ def create_convert_map(self):
             "aten::flip": self.flip,
             "aten::gru": self.gru,
             "aten::lstm": self.lstm,
+            "aten::all": functools.partial(self.all_any_common, _op.all),
+            "aten::any": functools.partial(self.all_any_common, _op.any),
         }
 
     def update_convert_map(self, custom_map):
diff --git a/tests/python/frontend/pytorch/test_forward.py b/tests/python/frontend/pytorch/test_forward.py
index 9238acd5f049..6f5eb1825dfc 100644
--- a/tests/python/frontend/pytorch/test_forward.py
+++ b/tests/python/frontend/pytorch/test_forward.py
@@ -3948,5 +3948,17 @@ def test_annotate_span():
     relay.transform.AnnotateSpans()(mod)
 
 
+@tvm.testing.uses_gpu
+def test_all_any():
+    def test_fn(f, dim=None, keepdim=False):
+        return lambda x: f(x, dim=dim, keepdim=keepdim)
+
+    for f in [torch.all, torch.any]:
+        verify_model(test_fn(f, 0), [torch.rand(1, 2).bool()])
+        verify_model(test_fn(f, 0), [torch.arange(0, 3).to(torch.uint8)])
+        verify_model(test_fn(f, 1), [torch.rand(4, 2).bool()])
+        verify_model(test_fn(f, 0, keepdim=True), [torch.rand(4, 2).bool()])
+
+
 if __name__ == "__main__":
     pytest.main([__file__])
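
A minimal end-to-end sketch (not part of the patch) of how the new "aten::any" mapping is exercised: a traced model using torch.any is imported through relay.frontend.from_pytorch, the standard TVM entry point. The module name AnyModel and the input name "input0" are illustrative assumptions.

import torch
from tvm import relay


class AnyModel(torch.nn.Module):
    def forward(self, x):
        # Traced to aten::any, which the converter above lowers to
        # _op.any via all_any_common.
        return torch.any(x, dim=0, keepdim=False)


inp = torch.rand(1, 2).bool()
traced = torch.jit.trace(AnyModel().eval(), inp)

# A "bool" input takes the no-cast path in all_any_common; a uint8
# input would first be cast to bool before the reduction.
mod, params = relay.frontend.from_pytorch(traced, [("input0", ((1, 2), "bool"))])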