Re-land: upsample_bilinear: fix output data-type. (#7168)
ysiraichi authored Jun 4, 2024
1 parent 376d645 commit e563cfe
Showing 2 changed files with 59 additions and 2 deletions.
53 changes: 53 additions & 0 deletions test/test_operations.py
@@ -2118,6 +2118,59 @@ def fn(inp, s):
    self.assertEqual(out, Xout.cpu())
    self.assertEqual("f16", torch_xla._XLAC._get_xla_tensor_shape_type(Xout))

  # We skip TPU for 2 reasons:
  # 1. upsample_bilinear on f64 tensors doesn't work on TPUs
  # 2. This issue only affects non-TPU and non-Neuron devices (i.e. both of
  #    those devices take a short-circuit that bypasses the buggy path)
  @skipOnTpu
  def test_upsample_bilinear_double(self):
    # Originally, the upsample_bilinear implementation (in resize_ops.cpp)
    # was copied from TF. There, the computation is intentionally done on F32
    # and not cast back [1]. However, that was not reflected in the returned
    # tensor. Basically, what would happen is:
    #
    # 1. A tensor of a data-type other than F32 is created:
    #    > a = torch.rand(..., dtype=torch.double)
    #
    # 2. upsample_bilinear is called on it:
    #    > r = torch.nn.functional.upsample_bilinear(a, scale_factor=2)
    #
    # 3. The result's data-type shows up as torch.float64, but its inner
    #    HLO representation is actually F32.
    #
    #    - This rarely surfaces as an error, since we do data-type
    #      promotion at the HLO level.
    #
    #    - However, when this result is the argument of a new HLO function,
    #      XLA expects an F64 tensor, since its torch.Tensor data-type "is"
    #      torch.float64. Since the actual HLO data-type is F32, XLA
    #      raises an error.
    #
    # See more details at [2].
    #
    # [1]: https://github.com/tensorflow/tensorflow/commit/f8b35e00afe09c8606bcb0441a51be8bd38168d2
    # [2]: https://github.com/pytorch/xla/issues/7095

    def foo(x, is_xla=False):
      # Compute upsample_bilinear.
      r = torch.nn.functional.upsample_bilinear(x, scale_factor=2)

      if is_xla:
        # Mark the end of the HLO graph.
        xm.mark_step()

      # Start a new HLO graph using the upsample_bilinear result as
      # one of its arguments.
      return r + 5

    inp = torch.rand(1, 3, 10, 10, dtype=torch.double)
    Xinp = inp.to(xm.xla_device())

    out = foo(inp)
    Xout = foo(Xinp, is_xla=True)

    self.assertEqual(out, Xout.cpu())


class MNISTComparator(nn.Module):

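The failure mode described in the test's comment can be observed directly. Below is a minimal sketch (not part of the commit) that inspects the underlying XLA shape type with the torch_xla._XLAC._get_xla_tensor_shape_type helper used earlier in this test file; the "f32" result noted in the comments is the assumed pre-fix behavior on a non-TPU, non-Neuron device:

import torch
import torch_xla
import torch_xla.core.xla_model as xm

a = torch.rand(1, 3, 10, 10, dtype=torch.double).to(xm.xla_device())
r = torch.nn.functional.upsample_bilinear(a, scale_factor=2)

# The torch.Tensor metadata advertises double...
print(r.dtype)  # torch.float64
# ...while the underlying HLO value was F32 before this fix:
print(torch_xla._XLAC._get_xla_tensor_shape_type(r))  # "f32" pre-fix, "f64" after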
8 changes: 6 additions & 2 deletions torch_xla/csrc/resize_ops.cpp
@@ -21,6 +21,10 @@ xla::XlaOp BuildResize(xla::XlaOp input, const xla::Shape& output_shape,
                       bool is_kernel_bilinear) {
  // Code copied from
  // https://github.com/tensorflow/tensorflow/blob/e51d6ab5730092775d516b18fa4ee85d49602cd8/tensorflow/compiler/tf2xla/kernels/image_resize_ops.cc#L477-L672
  //
  // Changes:
  // - Remove F32 data-type conversion when is_kernel_bilinear
  //   See: https://github.com/pytorch/xla/issues/7095

  // We implement bilinear interpolation and nearest neighbor with a Gather op.
  // For each output pixel, we gather the necessary slices of the input.
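As a rough illustration of the gather-based approach, here is a toy 1-D Python analogue (not the kernel's actual sampling scheme; the real code handles 2-D inputs, half-pixel centers, and batching):

import torch

def linear_resize_1d(x, out_size):
  # For each output position, compute the two neighboring input indices
  # and the interpolation weight between them.
  in_size = x.shape[-1]
  pos = (torch.arange(out_size, dtype=x.dtype) + 0.5) * (in_size / out_size) - 0.5
  lo = pos.floor().clamp(0, in_size - 1).long()
  hi = (lo + 1).clamp(max=in_size - 1)
  w = (pos - lo.to(x.dtype)).clamp(0, 1)
  # Gather the two neighboring samples for every output position,
  # then blend them linearly.
  left = x.gather(-1, lo.expand(x.shape[:-1] + (out_size,)))
  right = x.gather(-1, hi.expand(x.shape[:-1] + (out_size,)))
  return left * (1 - w) + right * w

print(linear_resize_1d(torch.tensor([[0.0, 1.0, 2.0, 3.0]]), 8))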
@@ -53,7 +57,7 @@ xla::XlaOp BuildResize(xla::XlaOp input, const xla::Shape& output_shape,
      << "input and output must have the same element type";

  xla::PrimitiveType original_input_type = input_type;
-  if (is_kernel_bilinear || xla::primitive_util::IsIntegralType(input_type)) {
+  if (xla::primitive_util::IsIntegralType(input_type)) {
    input = xla::ConvertElementType(input, xla::F32);
    input_type = xla::F32;
  }
@@ -210,7 +214,7 @@ xla::XlaOp BuildResize(xla::XlaOp input, const xla::Shape& output_shape,
  absl::InlinedVector<int64_t, 4> perm = {2, 0, 1, 3};
  input = xla::Transpose(input, perm);

-  if (!is_kernel_bilinear && original_input_type != input_type) {
+  if (original_input_type != input_type) {
    input = xla::ConvertElementType(input, original_input_type);
  }
  return input;
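After this change, only integral inputs take the promote-to-F32-and-back path; floating-point inputs (F16, F64, etc.) keep their data-type end to end. A rough Python analogue of the fixed casting logic (illustrative only; F.interpolate stands in for the gather-based kernel above):

import torch
import torch.nn.functional as F

def resize_cast_pattern(x):
  original_dtype = x.dtype
  if not x.is_floating_point():
    # Integral inputs are still promoted to F32 for the computation...
    x = x.to(torch.float32)
  out = F.interpolate(x, scale_factor=2, mode='bilinear')
  if out.dtype != original_dtype:
    # ...and demoted back afterwards. Floating-point inputs now stay in
    # their original data-type throughout.
    out = out.to(original_dtype)
  return out

print(resize_cast_pattern(torch.rand(1, 3, 4, 4, dtype=torch.double)).dtype)  # torch.float64
print(resize_cast_pattern(torch.arange(16, dtype=torch.int32).view(1, 1, 4, 4)).dtype)  # torch.int32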