From 74b401d538e33e931a25b4b1f6b6e98677514db5 Mon Sep 17 00:00:00 2001 From: masahi Date: Thu, 28 Jan 2021 04:30:08 +0900 Subject: [PATCH] [Torch] Various updates for PyTorch frontend (#7348) * add conversion for detr * remove explicit broadcast_to before batched matmul * use take with wrap mode * add test for transformer and negative indices * add sort and argsort * add logical_and * support masked_select * add gpu targets to masked_select test * improve sort conversion --- python/tvm/relay/frontend/pytorch.py | 63 ++++++++--- tests/python/frontend/pytorch/test_forward.py | 101 +++++++++++++++++- 2 files changed, 150 insertions(+), 14 deletions(-) diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py index 991e3a8a0032..68e68fdbeed2 100644 --- a/python/tvm/relay/frontend/pytorch.py +++ b/python/tvm/relay/frontend/pytorch.py @@ -399,10 +399,7 @@ def slice(self, inputs, input_types): begin = [0] * ndim dim = int(inputs[1]) stride = int(inputs[4]) - if isinstance(inputs[2], _expr.Call): - begin[dim], _ = try_infer_value(inputs[2], lambda ret: np.asscalar(ret.astype(np.int))) - else: - begin[dim] = int(inputs[2]) + begin[dim], _ = try_infer_value(inputs[2], lambda ret: np.asscalar(ret.astype(np.int))) # Process begin if not isinstance(begin[dim], int): @@ -518,13 +515,13 @@ def select(self, inputs, input_types): data = inputs[0] dim = int(inputs[1]) index = _wrap_const(inputs[2]) - return _op.transform.take(data, index, axis=dim) + return _op.transform.take(data, index, axis=dim, mode="wrap") def take(self, inputs, input_types): data = inputs[0] indices = _op.cast(inputs[1], "int32") - return _op.transform.take(data, indices=indices) + return _op.transform.take(data, indices=indices, mode="wrap") def topk(self, inputs, input_types): data = inputs[0] @@ -551,7 +548,13 @@ def reciprocal(self, inputs, input_types): def repeat(self, inputs, input_types): data = inputs[0] - reps = inputs[1] + reps = [] + for r in inputs[1]: + if isinstance(r, int): + reps.append(r) + else: + reps.append(int(_infer_value(r, {}).asnumpy())) + return _op.transform.tile(data, reps=reps) def repeat_interleave(self, inputs, input_types): @@ -1520,12 +1523,6 @@ def matmul(self, inputs, input_types): # Convert a and b into 3 dimensional tensors. a = _op.reshape(inputs_0, [-1, a_shape[-2], a_shape[-1]]) b = _op.reshape(inputs_1, [-1, b_shape[-2], b_shape[-1]]) - # Broadcast b to match batch size of a - new_b_shape = list(self.infer_shape_with_prelude(b)) - new_a_shape = self.infer_shape_with_prelude(a) - if new_a_shape[0] > new_b_shape[0]: - new_b_shape[0] = new_a_shape[0] - b = _op.broadcast_to(b, new_b_shape) # Transpose matrix dimensions of b. b = _op.transpose(b, [0, 2, 1]) # Perform a batch matmul. @@ -2070,6 +2067,40 @@ def scatter_add(self, inputs, input_types): src = inputs[3] return _op.scatter_add(data, index, src, axis=axis) + def cumsum(self, inputs, input_types): + data = inputs[0] + dim = inputs[1] + dtype = inputs[2] + + if inputs[2] is not None: + dtype = _convert_dtype_value(inputs[2]) + + return _op.cumsum(data, axis=dim, dtype=dtype) + + def masked_fill(self, inputs, input_types): + mask = inputs[1] + value = _op.cast(_wrap_const(inputs[2]), input_types[0]) + return _op.where(mask, value, inputs[0]) + + def masked_select(self, inputs, input_types): + mask = inputs[1] + indices = self.nonzero([mask], input_types, is_numpy_style=True) + return _op.adv_index([inputs[0]] + [indices[i] for i in range(indices.size)]) + + def sort(self, inputs, input_types): + data = inputs[0] + dim = inputs[1] + is_descending = inputs[2] + # pytorch sort returns both sorted indices and values + indices = _op.argsort(data, dim, not is_descending) + return _op.gather(data, dim, indices), indices + + def argsort(self, inputs, input_types): + data = inputs[0] + dim = inputs[1] + is_descending = inputs[2] + return _op.argsort(data, dim, not is_descending) + def is_floating_point(self, inputs, input_types): assert len(inputs) == 1 @@ -2263,6 +2294,7 @@ def create_convert_map(self): "torchvision::roi_align": self.roi_align, "aten::unbind": self.unbind, "aten::__and__": self.logical_and, + "aten::logical_and": self.logical_and, "aten::_shape_as_tensor": self.shape_as_tensor, "aten::nonzero": self.nonzero, "aten::nonzero_numpy": self.nonzero_numpy, @@ -2278,6 +2310,11 @@ def create_convert_map(self): "aten::__not__": self.logical_not, "aten::hardswish_": self.hard_swish, "aten::hardswish": self.hard_swish, + "aten::cumsum": self.cumsum, + "aten::masked_fill": self.masked_fill, + "aten::masked_select": self.masked_select, + "aten::argsort": self.argsort, + "aten::sort": self.sort, } def update_convert_map(self, custom_map): diff --git a/tests/python/frontend/pytorch/test_forward.py b/tests/python/frontend/pytorch/test_forward.py index 7cdd450448ca..6d9b559c6ba1 100644 --- a/tests/python/frontend/pytorch/test_forward.py +++ b/tests/python/frontend/pytorch/test_forward.py @@ -1147,7 +1147,7 @@ def forward(self, *args): @tvm.testing.uses_gpu def test_forward_select(): torch.set_grad_enabled(False) - input_shape = [1, 3, 10, 10] + input_shape = [5, 3, 10, 10] class Select1(Module): def forward(self, *args): @@ -1167,6 +1167,9 @@ def forward(self, index): input_data = torch.rand(input_shape).float() verify_model(Select1().float().eval(), input_data=input_data) + # test negative indexing + verify_model(lambda x: x[-1], input_data=input_data) + x = torch.randn(3, 4) indices = torch.tensor([0, 2]) verify_model(IndexedSelect(x, 0).eval(), input_data=indices) @@ -2653,6 +2656,8 @@ def forward(self, *args): verify_model(Take1().float().eval(), input_data=input_data) indices = torch.tensor([[0, 0], [1, 0]]) verify_model(Take2().float().eval(), input_data=[input_data, indices]) + indices = torch.tensor([0, -1]) + verify_model(Take2().float().eval(), input_data=[input_data, indices]) @tvm.testing.uses_gpu @@ -3452,6 +3457,93 @@ def test_hard_swish(): verify_model(torch.nn.Hardswish(inplace=True).eval(), input_data=input) +def test_cumsum(): + def test_fn(dim, dtype=None): + return lambda x: torch.cumsum(x, dim=dim, dtype=dtype) + + inp = torch.randint(0, 100, (10000,), dtype=torch.int32) + verify_model(test_fn(0), [inp]) + verify_model(test_fn(0), [inp.to(torch.int64)]) + verify_model(test_fn(0, dtype=torch.int64), [inp.to(torch.int64)]) + + inp = torch.randn((100, 100), dtype=torch.float32) + verify_model(test_fn(dim=0, dtype=torch.float64), [inp]) + verify_model(test_fn(dim=1), [inp]) + + inp = torch.randn((100, 100), dtype=torch.float32) > 0.5 + verify_model(test_fn(dim=0, dtype=torch.int32), [inp]) + + +def test_masked_fill(): + def test_fn(x, mask): + return torch.masked_fill(x, mask, 0.0) + + inp = torch.randn(100, 100) + verify_model(test_fn, [inp, inp > 0.5]) + verify_model(test_fn, [inp.to(torch.float64), inp > 0.5]) + + +def test_transformer(): + model = torch.nn.Transformer(d_model=256, nhead=8, num_encoder_layers=6, num_decoder_layers=6) + model = model.eval() + src = torch.rand((10, 32, 256)) + tgt = torch.rand((20, 32, 256)) + verify_model(model.eval(), input_data=[src, tgt]) + + +def test_argsort(): + def test_fn(dim, descending): + return lambda x: torch.argsort(x, dim=dim, descending=descending) + + inp = torch.randn(100) + verify_model(test_fn(0, True), [inp]) + verify_model(test_fn(0, False), [inp]) + + inp = torch.randn(100, 100) + verify_model(test_fn(0, True), [inp]) + verify_model(test_fn(0, False), [inp]) + verify_model(test_fn(1, True), [inp]) + verify_model(test_fn(1, False), [inp]) + + +def test_sort(): + def test_fn(dim, descending): + return lambda x: torch.sort(x, dim=dim, descending=descending) + + inp = torch.randn(100) + verify_model(test_fn(0, True), [inp]) + verify_model(test_fn(0, False), [inp]) + + inp = torch.randn(100, 100) + verify_model(test_fn(0, True), [inp]) + verify_model(test_fn(0, False), [inp]) + verify_model(test_fn(1, True), [inp]) + verify_model(test_fn(1, False), [inp]) + + +def test_logical_and(): + def test_fn(x, y): + return torch.logical_and(x, y) + + a = torch.tensor([0, 1, 10, 0], dtype=torch.int8) + b = torch.tensor([4, 0, 1, 0], dtype=torch.int8) + verify_model(test_fn, [a, b]) + + a = torch.tensor([True, False, True]) + b = torch.tensor([True, False, False]) + verify_model(test_fn, [a, b]) + + +def test_masked_select(): + def test_fn(x, mask): + return torch.masked_select(x, mask) + + for shape in [(10,), (3, 4), (16, 32, 64)]: + x = torch.randn(*shape) + mask = x.ge(0.5) + verify_trace_model(test_fn, [x, mask], ["llvm", "cuda", "nvptx"]) + + if __name__ == "__main__": # some structural tests test_forward_traced_function() @@ -3580,6 +3672,13 @@ def test_hard_swish(): test_forward_scatter() test_numel() test_bincount() + test_cumsum() + test_masked_fill() + test_transformer() + test_sort() + test_argsort() + test_logical_and() + test_masked_select() # Model tests test_resnet18()