diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py index 40b95d1d1f71..3cf0e080da67 100644 --- a/python/tvm/relay/frontend/pytorch.py +++ b/python/tvm/relay/frontend/pytorch.py @@ -1808,7 +1808,7 @@ def get_upsample_out_size(self, inputs, method): else: out_size.append(size) else: - scale_index = 3 if method == "linear" else 2 + scale_index = 3 if method != "nearest_neighbor" else 2 scales = inputs[scale_index] assert scales is not None, "neither out size nor scale provided" assert isinstance(scales, list) @@ -1823,7 +1823,7 @@ def upsample(inputs, input_types): data = inputs[0] out_size = self.get_upsample_out_size(inputs, method) - if len(inputs) > 2 and method == "linear": + if len(inputs) > 2 and method != "nearest_neighbor": align_corners = inputs[2] else: align_corners = False @@ -1836,7 +1836,9 @@ def upsample(inputs, input_types): coord_trans = "half_pixel" def func(x): - return _op.image.resize2d(x, out_size, "NCHW", method, coord_trans) + return _op.image.resize2d( + x, out_size, "NCHW", method, coord_trans, cubic_alpha=-0.75 + ) if self.is_quantized_tensor(data): # input qparams are manually appended by us @@ -2212,7 +2214,7 @@ def interpolate(self, inputs, input_types): else: coord_trans = "half_pixel" - return _op.image.resize2d(data, out_size, "NCHW", method, coord_trans) + return _op.image.resize2d(data, out_size, "NCHW", method, coord_trans, cubic_alpha=-0.75) def numel(self, inputs, input_types): return _op.ndarray_size(inputs[0]) @@ -2780,6 +2782,7 @@ def create_convert_map(self): "aten::clamp_": self.clamp, "aten::detach": self.identity, "aten::upsample_bilinear2d": self.make_upsample("linear"), + "aten::upsample_bicubic2d": self.make_upsample("cubic"), "aten::upsample_nearest2d": self.make_upsample("nearest_neighbor"), "aten::upsample_trilinear3d": self.make_upsample3d("linear"), "aten::upsample_nearest3d": self.make_upsample3d("nearest_neighbor"), diff --git a/tests/python/frontend/pytorch/test_forward.py b/tests/python/frontend/pytorch/test_forward.py index e58575266414..2e6828f693b6 100644 --- a/tests/python/frontend/pytorch/test_forward.py +++ b/tests/python/frontend/pytorch/test_forward.py @@ -1761,6 +1761,9 @@ def forward(self, x): verify_model(Upsample(size=(64, 64), mode="bilinear", align_corners=True), inp) verify_model(Upsample(scale=2, mode="bilinear", align_corners=True), inp) verify_model(Upsample(size=(50, 50), mode="bilinear", align_corners=True), inp) + verify_model(Upsample(size=(64, 64), mode="bicubic", align_corners=True), inp) + verify_model(Upsample(scale=2, mode="bicubic", align_corners=True), inp) + verify_model(Upsample(size=(50, 50), mode="bicubic", align_corners=True), inp) @tvm.testing.uses_gpu