diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py index 54004c379d52..0213dcc488fd 100644 --- a/python/tvm/relay/frontend/pytorch.py +++ b/python/tvm/relay/frontend/pytorch.py @@ -2328,6 +2328,21 @@ def broadcast_tensors(self, inputs, input_types): res_shape = list(torch.broadcast_tensors(*map(torch.empty, infer_shape_value))[0].shape) return [_op.broadcast_to(tensor, res_shape) for tensor in tensor_list] + def broadcast_to(self, inputs, input_types): + tensor = inputs[0] + new_shape = inputs[1] + import torch + + if not isinstance(new_shape, (list, tuple, torch.Size)): + msg = f"Data type {type(new_shape)} could not be parsed in broadcast_to op" + raise AssertionError(msg) + + for i, dim in enumerate(new_shape): + if not isinstance(dim, int): + new_shape[i] = int(_infer_value(dim, {}).numpy()) + + return _op.broadcast_to(tensor, new_shape) + def Bool(self, inputs, input_types): assert len(inputs) == 1 return inputs[0] @@ -4190,6 +4205,7 @@ def create_convert_map(self): "aten::upsample_nearest3d": self.make_upsample3d("nearest_neighbor"), "aten::expand_as": self.expand_as, "aten::broadcast_tensors": self.broadcast_tensors, + "aten::broadcast_to": self.broadcast_to, "aten::lt": self.make_elemwise("less"), "aten::gt": self.make_elemwise("greater"), "aten::le": self.make_elemwise("less_equal"), diff --git a/tests/python/frontend/pytorch/test_forward.py b/tests/python/frontend/pytorch/test_forward.py index 56afe72ecd3e..6178a58b6d13 100644 --- a/tests/python/frontend/pytorch/test_forward.py +++ b/tests/python/frontend/pytorch/test_forward.py @@ -2162,6 +2162,31 @@ def forward(self, x, y, z): verify_model(BroadCastTensors2().float().eval(), input_data=[x, y, z]) +@tvm.testing.uses_gpu +def test_forward_broadcast_to(): + """test_forward_broadcast_to""" + torch.set_grad_enabled(False) + + class BroadCastTo1(Module): + def forward(self, x): + return torch.broadcast_to(x, (3, 3)) + + x = torch.tensor([1, 2, 3]) + verify_model(BroadCastTo1().float().eval(), input_data=[x]) + + class BroadCastTo2(Module): + def __init__(self): + super().__init__() + self.y = torch.tensor(1) + self.z = torch.tensor(2) + + def forward(self, x): + return torch.broadcast_to(x, (self.y + self.z, 3)) + + x = torch.tensor([1, 2, 3]) + verify_model(BroadCastTo2().float().eval(), input_data=[x]) + + @tvm.testing.uses_gpu def test_forward_pow(): """test_forward_pow"""