From 149f420b0f53ee8828b10bf4345d035cc21d88c8 Mon Sep 17 00:00:00 2001
From: Yukio Siraichi
Date: Wed, 29 May 2024 17:41:51 -0300
Subject: [PATCH 1/2] Re-land: `upsample_bilinear`: fix output data-type.

---
 test/test_operations.py       | 48 +++++++++++++++++++++++++++++++++++
 torch_xla/csrc/resize_ops.cpp |  8 ++++--
 2 files changed, 54 insertions(+), 2 deletions(-)

diff --git a/test/test_operations.py b/test/test_operations.py
index 33059f80915..6b9e457f6be 100644
--- a/test/test_operations.py
+++ b/test/test_operations.py
@@ -2118,6 +2118,54 @@ def fn(inp, s):
     self.assertEqual(out, Xout.cpu())
     self.assertEqual("f16", torch_xla._XLAC._get_xla_tensor_shape_type(Xout))
 
+  def test_upsample_bilinear_double(self):
+    # Originally, the upsample_bilinear implementation (in resize_ops.cpp)
+    # was copied from TF. The computation was done intentionally on F32 and
+    # not cast back[1]. However, that wasn't reflected in the returned tensor.
+    # Basically, what would happen is:
+    #
+    # 1. A tensor of a data-type other than F32 is created:
+    #    > a = torch.rand(..., dtype=torch.double)
+    #
+    # 2. Call upsample_bilinear on it:
+    #    > r = torch.nn.functional.upsample_bilinear(a, scale_factor=2)
+    #
+    # 3. The result's data-type would show as torch.float64, but its inner
+    #    HLO representation would actually be F32.
+    #
+    #    - It would rarely surface as an error, since we do data-type
+    #      promotion at the HLO level.
+    #
+    #    - When this result is the argument of a new HLO function, XLA
+    #      would actually expect an F64 tensor, since its torch.Tensor
+    #      data-type "is" torch.float64. However, since the actual HLO
+    #      data-type is F32, XLA raises an error.
+    #
+    # See more details at [2].
+    #
+    # [1]: https://github.com/tensorflow/tensorflow/commit/f8b35e00afe09c8606bcb0441a51be8bd38168d2
+    # [2]: https://github.com/pytorch/xla/issues/7095
+
+    def foo(x, is_xla=False):
+      # Compute upsample_bilinear.
+      r = torch.nn.functional.upsample_bilinear(x, scale_factor=2)
+
+      if is_xla:
+        # Mark the end of the HLO graph.
+        xm.mark_step()
+
+      # Start a new HLO graph using the upsample_bilinear result as
+      # one of its arguments.
+      return r + 5
+
+    inp = torch.rand(1, 3, 10, 10, dtype=torch.double)
+    Xinp = inp.to(xm.xla_device())
+
+    out = foo(inp)
+    Xout = foo(Xinp, is_xla=True)
+
+    self.assertEqual(out, Xout.cpu())
+
 
 
 class MNISTComparator(nn.Module):
diff --git a/torch_xla/csrc/resize_ops.cpp b/torch_xla/csrc/resize_ops.cpp
index 90f1c6851ea..97fa335d9d6 100644
--- a/torch_xla/csrc/resize_ops.cpp
+++ b/torch_xla/csrc/resize_ops.cpp
@@ -21,6 +21,10 @@ xla::XlaOp BuildResize(xla::XlaOp input, const xla::Shape& output_shape,
                        bool is_kernel_bilinear) {
   // Code copied from
   // https://github.com/tensorflow/tensorflow/blob/e51d6ab5730092775d516b18fa4ee85d49602cd8/tensorflow/compiler/tf2xla/kernels/image_resize_ops.cc#L477-L672
+  //
+  // Changes:
+  // - Remove F32 data-type conversion when is_kernel_bilinear
+  //   See: https://github.com/pytorch/xla/issues/7095
 
   // We implement bilinear interpolation and nearest neighbor with a Gather op.
   // For each output pixel, we gather the necessary slices of the input.
@@ -53,7 +57,7 @@ xla::XlaOp BuildResize(xla::XlaOp input, const xla::Shape& output_shape,
       << "input and output must have the same element type";
 
   xla::PrimitiveType original_input_type = input_type;
-  if (is_kernel_bilinear || xla::primitive_util::IsIntegralType(input_type)) {
+  if (xla::primitive_util::IsIntegralType(input_type)) {
     input = xla::ConvertElementType(input, xla::F32);
     input_type = xla::F32;
   }
@@ -210,7 +214,7 @@ xla::XlaOp BuildResize(xla::XlaOp input, const xla::Shape& output_shape,
   absl::InlinedVector<int64_t, 4> perm = {2, 0, 1, 3};
   input = xla::Transpose(input, perm);
 
-  if (!is_kernel_bilinear && original_input_type != input_type) {
+  if (original_input_type != input_type) {
     input = xla::ConvertElementType(input, original_input_type);
   }
   return input;

From 215c753bd6143d652fe94d6429afacd026017b87 Mon Sep 17 00:00:00 2001
From: Yukio Siraichi
Date: Mon, 3 Jun 2024 10:52:41 -0300
Subject: [PATCH 2/2] Skip test on TPU.

---
 test/test_operations.py | 5 +++++
 1 file changed, 5 insertions(+)

diff --git a/test/test_operations.py b/test/test_operations.py
index 6b9e457f6be..6fefc77165b 100644
--- a/test/test_operations.py
+++ b/test/test_operations.py
@@ -2118,6 +2118,11 @@ def fn(inp, s):
     self.assertEqual(out, Xout.cpu())
     self.assertEqual("f16", torch_xla._XLAC._get_xla_tensor_shape_type(Xout))
 
+  # We skip this test on TPU for 2 reasons:
+  # 1. upsample_bilinear on f64 tensors doesn't work on TPUs
+  # 2. This issue only affects non-TPU and non-Neuron devices (i.e. both
+  #    devices take a short-circuit that skips the buggy path)
+  @skipOnTpu
   def test_upsample_bilinear_double(self):
     # Originally, the upsample_bilinear implementation (in resize_ops.cpp)
     # was copied from TF. The computation was done intentionally on F32 and