Re-land: upsample_bilinear: fix output data-type. #7168

Merged: 2 commits, Jun 4, 2024
53 changes: 53 additions & 0 deletions test/test_operations.py
@@ -2118,6 +2118,59 @@ def fn(inp, s):
self.assertEqual(out, Xout.cpu())
self.assertEqual("f16", torch_xla._XLAC._get_xla_tensor_shape_type(Xout))

# We skip TPU for 2 reasons:
# 1. upsample_bilinear on f64 tensors doesn't work on TPUs

Collaborator: Will f64 work prior to your change on TPU? I don't want this to become a regression.


Collaborator (Author): My change doesn't affect TPU. The lowering for upsample_bilinear is different for CUDA and TPU; this change only affects non-TPU and non-Neuron devices.

# 2. This issue only affects non-TPU and non-Neuron devices (i.e. there is
#    a short-circuit for both of those devices that skips the buggy path)
@skipOnTpu
def test_upsample_bilinear_double(self):
# Originally, the upsample_bilinear implementation (in resize_ops.cpp)
# was copied from TF. The computation was intentionally done in F32 and
# not cast back [1]. However, the returned tensor's data-type did not
# reflect that.
# Basically, what would happen is:
#
# 1. A tensor of data-type other than F32 is created:
# > a = torch.rand(..., dtype=torch.double)
#
# 2. Call upsample_bilinear on it
# > r = torch.nn.functional.upsample_bilinear(a, scale_factor=2)
#
# 3. The result's data-type would show as torch.float64, but its inner
#    HLO representation would actually be F32.
#
# - It would rarely surface as an error, since we do data-type
# promotion at the HLO level.
#
# - When this result is the argument of a new HLO function, XLA
#   would actually expect an F64 tensor, since its torch.Tensor
#   data-type "is" torch.float64. However, since the actual HLO
# data-type is F32, XLA raises an error.
#
# See more details at [2].
#
# [1]: https://github.com/tensorflow/tensorflow/commit/f8b35e00afe09c8606bcb0441a51be8bd38168d2
# [2]: https://github.com/pytorch/xla/issues/7095

def foo(x, is_xla=False):
# Compute upsample_bilinear.
r = torch.nn.functional.upsample_bilinear(x, scale_factor=2)

if is_xla:
# Mark the end of the HLO graph.
xm.mark_step()

# Start a new HLO graph using the upsample_bilinear result as
# one of its arguments.
return r + 5

inp = torch.rand(1, 3, 10, 10, dtype=torch.double)
Xinp = inp.to(xm.xla_device())

out = foo(inp)
Xout = foo(Xinp, is_xla=True)

self.assertEqual(out, Xout.cpu())


class MNISTComparator(nn.Module):

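For reference, the symptom the new test guards against can also be observed directly with the shape-type helper used earlier in test_operations.py. The snippet below is a minimal sketch, not part of this PR; it assumes an XLA device is available and that _get_xla_tensor_shape_type reports the HLO element type as a lowercase string (as in the "f16" assertion above):

import torch
import torch_xla
import torch_xla.core.xla_model as xm

a = torch.rand(1, 3, 10, 10, dtype=torch.double, device=xm.xla_device())
r = torch.nn.functional.upsample_bilinear(a, scale_factor=2)

# The torch-level data-type is reported as float64 either way.
print(r.dtype)                                         # torch.float64

# Before this change, the underlying HLO type was f32, which is the
# mismatch described in https://github.com/pytorch/xla/issues/7095.
# After this change it should agree with the torch data-type.
print(torch_xla._XLAC._get_xla_tensor_shape_type(r))   # expected: "f64"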
8 changes: 6 additions & 2 deletions torch_xla/csrc/resize_ops.cpp
@@ -21,6 +21,10 @@ xla::XlaOp BuildResize(xla::XlaOp input, const xla::Shape& output_shape,
bool is_kernel_bilinear) {
// Code copied from
// https://github.com/tensorflow/tensorflow/blob/e51d6ab5730092775d516b18fa4ee85d49602cd8/tensorflow/compiler/tf2xla/kernels/image_resize_ops.cc#L477-L672
//
// Changes:
// - Remove F32 data-type conversion when is_kernel_bilinear
// See: https://github.com/pytorch/xla/issues/7095

// We implement bilinear interpolation and nearest neighbor with a Gather op.
// For each output pixel, we gather the necessary slices of the input.
@@ -53,7 +57,7 @@ xla::XlaOp BuildResize(xla::XlaOp input, const xla::Shape& output_shape,
<< "input and output must have the same element type";

xla::PrimitiveType original_input_type = input_type;
-  if (is_kernel_bilinear || xla::primitive_util::IsIntegralType(input_type)) {
+  if (xla::primitive_util::IsIntegralType(input_type)) {
input = xla::ConvertElementType(input, xla::F32);
input_type = xla::F32;
}
@@ -210,7 +214,7 @@ xla::XlaOp BuildResize(xla::XlaOp input, const xla::Shape& output_shape,
absl::InlinedVector<int64_t, 4> perm = {2, 0, 1, 3};
input = xla::Transpose(input, perm);

-  if (!is_kernel_bilinear && original_input_type != input_type) {
+  if (original_input_type != input_type) {
input = xla::ConvertElementType(input, original_input_type);
}
return input;
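In short, after this change BuildResize converts only integral inputs to F32 (and back at the end); floating-point inputs such as F64 keep their element type, so the HLO output type agrees with the torch data-type. The snippet below is a rough Python-level analogue of that dtype handling, included only to illustrate the control flow (the authoritative logic is the C++ above; the helper name is made up for this sketch, and every non-floating dtype is treated as integral as a simplification):

import torch

def resize_dtype_handling(x: torch.Tensor) -> torch.Tensor:
    # Remember the original element type ...
    original_dtype = x.dtype

    # ... convert only integral (non-floating) inputs to F32 for the
    # computation; floating-point inputs are no longer force-converted ...
    if not x.dtype.is_floating_point:
        x = x.to(torch.float32)

    # ... the gather-based bilinear / nearest-neighbor resize would
    # happen here on x ...

    # ... and convert back whenever a conversion actually happened,
    # so the output dtype always matches the input dtype.
    if x.dtype != original_dtype:
        x = x.to(original_dtype)
    return x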