upsample_bilinear2d HLO returns unexpected data-type. #7095

Closed
ysiraichi opened this issue May 22, 2024 · 6 comments · Fixed by #7111
@ysiraichi
Collaborator

ysiraichi commented May 22, 2024

🐛 Bug

At first, upsample_bilinear2d seems to work: the returned data-type is torch.float16, as expected. However, when its result is used together with another torch.float16 tensor, it breaks unexpectedly.

In the example below, foo stacks the result of an upsample_bilinear call with another torch.float16 tensor. The function fails under PyTorch/XLA because stack (lowered to concatenate) expects all inputs to have the same data-type (note that this behavior is being fixed in #7091). However, as the error message shows, we end up calling concatenate(f32[...], f16[...]), meaning the result of upsample_bilinear wasn't really f16.

In summary: upsample_bilinear2d returns a torch.float16 tensor, even though its HLO representation is f32. The expected data-type is f16.

import torch
import torch_xla.core.xla_model as xm

def foo(x, y):
    return torch.stack([torch.nn.functional.upsample_bilinear(x, scale_factor=2), y])

a = torch.rand(1, 3, 10, 10, dtype=torch.half)
b = torch.rand(1, 3, 20, 20, dtype=torch.half)

Xa = a.to(xm.xla_device())
Xb = b.to(xm.xla_device())

out = foo(a, b)
print(out.dtype)  # torch.float16

Xout = foo(Xa, Xb)
print(Xout.dtype)  # torch.float16

# Fails with the error below.
Xout.cpu()
Non-OK-status: status.status() status: INTERNAL: during context [Unknown]: Seen floating point types of different precisions in %concatenate.82 = f32[2,1,3,20,20]{4,3,2,1,0} concatenate(f32[1,1,3,20,20]{4,3,2,1,0} %reshape.80, f16[1,1,3,20,20]{4,3,2,1,0} %reshape.81), dimensions={0}, but mixed precision is disallowed.
*** Begin stack trace ***
        tsl::CurrentStackTrace[abi:cxx11]()
        std::unique_ptr<xla::PjRtLoadedExecutable, std::default_delete<xla::PjRtLoadedExecutable> > ConsumeValue<std::unique_ptr<xla::PjRtLoadedExecutable, std::default_delete<xla::PjRtLoadedExecutable> > >(absl::lts_20230802::StatusOr<std::unique_ptr<xla::PjRtLoadedExecutable, std::default_delete<xla::PjRtLoadedExecutable> > >&&)
        torch_xla::runtime::PjRtComputationClient::Compile(std::vector<torch_xla::runtime::ComputationClient::CompileInstance, std::allocator<torch_xla::runtime::ComputationClient::CompileInstance> >)
        torch_xla::XLAGraphExecutor::Compile(std::vector<c10::intrusive_ptr<torch_xla::XLATensor, c10::detail::intrusive_target_default_null_type<torch_xla::XLATensor> >, std::allocator<c10::intrusive_ptr<torch_xla::XLATensor, c10::detail::intrusive_target_default_null_type<torch_xla::XLATensor> > > >&, absl::lts_20230802::Span<std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > const>, torch::lazy::LazyGraphExecutor::SyncTensorCollection const&, torch::lazy::LazyGraphExecutor::PostOrderData*, std::vector<torch::lazy::Value, std::allocator<torch::lazy::Value> > const&)
        torch_xla::XLAGraphExecutor::SyncTensorsGraphInternal(std::vector<c10::intrusive_ptr<torch_xla::XLATensor, c10::detail::intrusive_target_default_null_type<torch_xla::XLATensor> >, std::allocator<c10::intrusive_ptr<torch_xla::XLATensor, c10::detail::intrusive_target_default_null_type<torch_xla::XLATensor> > > >*, absl::lts_20230802::Span<std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > const>, torch::lazy::LazyGraphExecutor::SyncTensorsConfig const&, bool)
        torch_xla::XLAGraphExecutor::SyncTensorsGraph(std::vector<c10::intrusive_ptr<torch_xla::XLATensor, c10::detail::intrusive_target_default_null_type<torch_xla::XLATensor> >, std::allocator<c10::intrusive_ptr<torch_xla::XLATensor, c10::detail::intrusive_target_default_null_type<torch_xla::XLATensor> > > >*, absl::lts_20230802::Span<std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > const>, bool, bool, bool)
        torch_xla::XLATensor::ApplyPendingGraph()
        torch_xla::XLATensor::GetXlaData()
        torch_xla::XLATensor::ToTensor(bool)
        torch_xla::XLANativeFunctions::_to_copy(at::Tensor const&, std::optional<c10::ScalarType>, std::optional<c10::Layout>, std::optional<c10::Device>, std::optional<bool>, bool, std::optional<c10::MemoryFormat>)




        at::_ops::_to_copy::redispatch(c10::DispatchKeySet, at::Tensor const&, std::optional<c10::ScalarType>, std::optional<c10::Layout>, std::optional<c10::Device>, std::optional<bool>, bool, std::optional<c10::MemoryFormat>)



        at::_ops::_to_copy::call(at::Tensor const&, std::optional<c10::ScalarType>, std::optional<c10::Layout>, std::optional<c10::Device>, std::optional<bool>, bool, std::optional<c10::MemoryFormat>)





        at::_ops::_to_copy::redispatch(c10::DispatchKeySet, at::Tensor const&, std::optional<c10::ScalarType>, std::optional<c10::Layout>, std::optional<c10::Device>, std::optional<bool>, bool, std::optional<c10::MemoryFormat>)





        at::_ops::_to_copy::call(at::Tensor const&, std::optional<c10::ScalarType>, std::optional<c10::Layout>, std::optional<c10::Device>, std::optional<bool>, bool, std::optional<c10::MemoryFormat>)


        at::native::to(at::Tensor const&, std::optional<c10::ScalarType>, std::optional<c10::Layout>, std::optional<c10::Device>, std::optional<bool>, bool, bool, std::optional<c10::MemoryFormat>)



        at::_ops::to_dtype_layout::call(at::Tensor const&, std::optional<c10::ScalarType>, std::optional<c10::Layout>, std::optional<c10::Device>, std::optional<bool>, bool, bool, std::optional<c10::MemoryFormat>)
        at::Tensor::to(c10::TensorOptions, bool, bool, std::optional<c10::MemoryFormat>) const



        _PyEval_EvalFrameDefault

        PyEval_EvalCode



        _PyRun_SimpleFileObject
        _PyRun_AnyFileObject
        Py_RunMain
        Py_BytesMain
        __libc_start_main
        _start
*** End stack trace ***
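One way to confirm that the mismatch originates in the lowered HLO, and not in the PyTorch-level dtype, is to dump the pending graph for the upsampled tensor before synchronizing. Below is a minimal sketch, assuming the internal torch_xla._XLAC._get_xla_tensors_hlo helper (availability may vary across versions):

import torch
import torch_xla
import torch_xla.core.xla_model as xm

x = torch.rand(1, 3, 10, 10, dtype=torch.half).to(xm.xla_device())
up = torch.nn.functional.upsample_bilinear(x, scale_factor=2)

print(up.dtype)  # torch.float16 at the PyTorch level
# The dumped HLO text, however, shows the resize result as f32[...].
print(torch_xla._XLAC._get_xla_tensors_hlo([up]))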

Environment

  • Reproducible on XLA backend [CPU/TPU/CUDA]: CUDA
  • torch_xla version: 8d35eb0

Additional context

This seems to happen because the lowering computes in F32 regardless of the original input data-type:

if (is_kernel_bilinear || xla::primitive_util::IsIntegralType(input_type)) {
  input = xla::ConvertElementType(input, xla::F32);
  input_type = xla::F32;
}

cc @miladm @JackCaoG

@ysiraichi ysiraichi changed the title upsample_bilinear2d returns unexpected data-type. upsample_bilinear2d HLO returns unexpected data-type. May 22, 2024
@ysiraichi
Collaborator Author

Note: the offending code was extracted from elsewhere:

// Code copied from
// https://github.com/tensorflow/tensorflow/blob/e51d6ab5730092775d516b18fa4ee85d49602cd8/tensorflow/compiler/tf2xla/kernels/image_resize_ops.cc#L477-L672

@ysiraichi
Collaborator Author

@JackCaoG I would imagine that not only should we return a tensor cast to the input data-type, but we should also do the computation in f16, since we are not using AMP. Let me know what you think.
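As a concrete acceptance check for that behavior, a regression test along these lines could be added once the lowering is updated (a hedged sketch, not an existing test):

import torch
import torch_xla.core.xla_model as xm

def test_upsample_bilinear_keeps_half():
    x = torch.rand(1, 3, 10, 10, dtype=torch.half).to(xm.xla_device())
    y = torch.rand(1, 3, 20, 20, dtype=torch.half).to(xm.xla_device())
    out = torch.stack(
        [torch.nn.functional.upsample_bilinear(x, scale_factor=2), y])
    # Once the resize is computed in the input dtype, concatenate sees two f16
    # operands and materializing the result no longer raises.
    assert out.cpu().dtype == torch.half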

@JackCaoG
Collaborator

I think this op is copied from TF, and from the comment in tensorflow/tensorflow@f8b35e0, the TF output type is always f32. I guess we don't need to follow that rule.

@ysiraichi
Collaborator Author

I'm thinking that the way to go here is to compute and return using the input data-type, and not f32.
@JackCaoG how does that sound?

@JackCaoG
Collaborator

sgtm!
