diff --git a/test/test_operations.py b/test/test_operations.py
index 33059f80915..6fefc77165b 100644
--- a/test/test_operations.py
+++ b/test/test_operations.py
@@ -2118,6 +2118,59 @@ def fn(inp, s):
     self.assertEqual(out, Xout.cpu())
     self.assertEqual("f16", torch_xla._XLAC._get_xla_tensor_shape_type(Xout))
 
+  # We skip TPU for 2 reasons:
+  # 1. upsample_bilinear on f64 tensors doesn't work on TPUs
+  # 2. This issue only affects non-TPU and non-Neuron devices (i.e. both
+  #    devices take a short-circuit that doesn't go through the buggy path)
+  @skipOnTpu
+  def test_upsample_bilinear_double(self):
+    # Originally, the upsample_bilinear implementation (in resize_ops.cpp)
+    # was copied from TF. The computation was intentionally done on F32 and
+    # not cast back[1]. However, that was not reflected in the returned
+    # tensor's metadata. Basically, what would happen is:
+    #
+    # 1. A tensor of a data-type other than F32 is created:
+    #    > a = torch.rand(..., dtype=torch.double)
+    #
+    # 2. upsample_bilinear is called on it:
+    #    > r = torch.nn.functional.upsample_bilinear(a, scale_factor=2)
+    #
+    # 3. The result's data-type would show as torch.float64, but its inner
+    #    HLO representation would actually be F32.
+    #
+    #    - This would rarely surface as an error, since we do data-type
+    #      promotion at the HLO level.
+    #
+    #    - However, when this result is used as the argument of a new HLO
+    #      function, XLA expects an F64 tensor, since its torch.Tensor
+    #      data-type "is" torch.float64. Since the actual HLO data-type
+    #      is F32, XLA raises an error.
+    #
+    #    See more details at [2].
+    #
+    # [1]: https://github.com/tensorflow/tensorflow/commit/f8b35e00afe09c8606bcb0441a51be8bd38168d2
+    # [2]: https://github.com/pytorch/xla/issues/7095
+
+    def foo(x, is_xla=False):
+      # Compute upsample_bilinear.
+      r = torch.nn.functional.upsample_bilinear(x, scale_factor=2)
+
+      if is_xla:
+        # Mark the end of the HLO graph.
+        xm.mark_step()
+
+      # Start a new HLO graph using the upsample_bilinear result as
+      # one of its arguments.
+      return r + 5
+
+    inp = torch.rand(1, 3, 10, 10, dtype=torch.double)
+    Xinp = inp.to(xm.xla_device())
+
+    out = foo(inp)
+    Xout = foo(Xinp, is_xla=True)
+
+    self.assertEqual(out, Xout.cpu())
+
 
 class MNISTComparator(nn.Module):
diff --git a/torch_xla/csrc/resize_ops.cpp b/torch_xla/csrc/resize_ops.cpp
index 90f1c6851ea..97fa335d9d6 100644
--- a/torch_xla/csrc/resize_ops.cpp
+++ b/torch_xla/csrc/resize_ops.cpp
@@ -21,6 +21,10 @@ xla::XlaOp BuildResize(xla::XlaOp input, const xla::Shape& output_shape,
                        bool is_kernel_bilinear) {
   // Code copied from
   // https://github.com/tensorflow/tensorflow/blob/e51d6ab5730092775d516b18fa4ee85d49602cd8/tensorflow/compiler/tf2xla/kernels/image_resize_ops.cc#L477-L672
+  //
+  // Changes:
+  //   - Remove the F32 data-type conversion when is_kernel_bilinear.
+  //     See: https://github.com/pytorch/xla/issues/7095
 
   // We implement bilinear interpolation and nearest neighbor with a Gather op.
   // For each output pixel, we gather the necessary slices of the input.
@@ -53,7 +57,7 @@ xla::XlaOp BuildResize(xla::XlaOp input, const xla::Shape& output_shape,
       << "input and output must have the same element type";
 
   xla::PrimitiveType original_input_type = input_type;
-  if (is_kernel_bilinear || xla::primitive_util::IsIntegralType(input_type)) {
+  if (xla::primitive_util::IsIntegralType(input_type)) {
     input = xla::ConvertElementType(input, xla::F32);
     input_type = xla::F32;
   }
@@ -210,7 +214,7 @@ xla::XlaOp BuildResize(xla::XlaOp input, const xla::Shape& output_shape,
   absl::InlinedVector<int64_t, 4> perm = {2, 0, 1, 3};
   input = xla::Transpose(input, perm);
 
-  if (!is_kernel_bilinear && original_input_type != input_type) {
+  if (original_input_type != input_type) {
     input = xla::ConvertElementType(input, original_input_type);
   }
   return input;
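
For reference, a minimal sketch (not part of the diff) of how the dtype mismatch could be observed outside the test suite. It assumes an XLA device is available and that torch_xla._XLAC._get_xla_tensor_shape_type reports the element type of the tensor's underlying XLA representation, as it is used in the surrounding tests; under those assumptions it would be expected to print "f32" before this change and "f64" after it, while r.dtype reads torch.float64 in both cases.

    import torch
    import torch_xla
    import torch_xla.core.xla_model as xm

    # Create a non-F32 tensor on the XLA device.
    a = torch.rand(1, 3, 10, 10, dtype=torch.double).to(xm.xla_device())

    # Upsample it. The torch.Tensor dtype stays torch.float64 either way.
    r = torch.nn.functional.upsample_bilinear(a, scale_factor=2)
    print(r.dtype)  # torch.float64

    # Element type of the underlying XLA representation: expected "f32"
    # before the fix (the bilinear kernel computed in F32 and never cast
    # back), "f64" after it.
    print(torch_xla._XLAC._get_xla_tensor_shape_type(r))

Dropping the F32 round trip for the bilinear path keeps the HLO element type consistent with the torch.Tensor dtype, which is what the new graph built after xm.mark_step() relies on.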