Skip to content

Commit

Permalink
support Torch all and any op (#9185)
Browse files Browse the repository at this point in the history
  • Loading branch information
masahi authored Oct 4, 2021
1 parent b9f2284 commit 2f02b1e
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 0 deletions.
13 changes: 13 additions & 0 deletions python/tvm/relay/frontend/pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
# pylint: disable=missing-function-docstring
"""PT: PyTorch frontend."""
import itertools
import functools
import logging
import math
import sys
Expand Down Expand Up @@ -2763,6 +2764,16 @@ def lstm(self, inputs, input_types):

return (output, _op.stack(hy, 0), _op.stack(cy, 0))

def all_any_common(self, op, inputs, input_types):
dim = inputs[1]
keepdim = inputs[2]
if self.infer_type(inputs[0]).dtype != "bool":
# The input dtype can be uint8.
inp = _op.cast(inputs[0], "bool")
else:
inp = inputs[0]
return op(inp, axis=dim, keepdims=keepdim)

# Operator mappings
def create_convert_map(self):
self.convert_map = {
Expand Down Expand Up @@ -2986,6 +2997,8 @@ def create_convert_map(self):
"aten::flip": self.flip,
"aten::gru": self.gru,
"aten::lstm": self.lstm,
"aten::all": functools.partial(self.all_any_common, _op.all),
"aten::any": functools.partial(self.all_any_common, _op.any),
}

def update_convert_map(self, custom_map):
Expand Down
12 changes: 12 additions & 0 deletions tests/python/frontend/pytorch/test_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -3948,5 +3948,17 @@ def test_annotate_span():
relay.transform.AnnotateSpans()(mod)


@tvm.testing.uses_gpu
def test_all_any():
def test_fn(f, dim=None, keepdim=False):
return lambda x: f(x, dim=dim, keepdim=keepdim)

for f in [torch.all, torch.any]:
verify_model(test_fn(f, 0), [torch.rand(1, 2).bool()])
verify_model(test_fn(f, 0), [torch.arange(0, 3).to(torch.uint8)])
verify_model(test_fn(f, 1), [torch.rand(4, 2).bool()])
verify_model(test_fn(f, 0, keepdim=True), [torch.rand(4, 2).bool()])


if __name__ == "__main__":
pytest.main([__file__])

0 comments on commit 2f02b1e

Please sign in to comment.