`upsample_bilinear2d` HLO returns unexpected data-type. #7095
Comments
Note: the offending code was extracted from elsewhere: `xla/torch_xla/csrc/resize_ops.cpp`, lines 22 to 23, at f336317.
@JackCaoG I would imagine that not only should we return a tensor cast to the input data-type, but we should also do the computation in that data-type.
I think this op was copied from TF, and from the comment in tensorflow/tensorflow@f8b35e0, the TF output is always f32. I guess we don't need to follow that rule.
I'm thinking that the way to go here is to compute and return using the input data-type, and not `F32`.
sgtm!
🐛 Bug
At first, it seems to work, i.e. the returned data-type is `torch.float16`, as expected. However, when using it with another `torch.float16` tensor, it breaks unexpectedly.

In the example below, `foo` stacks the result of an `upsample_bilinear` call with another `torch.float16` tensor. The function fails when using PyTorch/XLA because `stack` (lowered to `concatenate`) expects all inputs to be of the same data-type (note that this behavior is being fixed in #7091). However, as we can see from the error message, we are trying to call `concatenate(f32[...], f16[...])`, meaning that the result of `upsample_bilinear` wasn't really `f16`.

In summary: `upsample_bilinear2d` returns a `torch.float16` tensor, even though its HLO representation is `f32`. The expected data-type is `f16`.
Environment
Additional context
This seems to happen because we are computing on `F32` regardless of what the original input data-type is. See `xla/torch_xla/csrc/resize_ops.cpp`, lines 56 to 59, at f336317.
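Until the lowering computes in (or at least converts back to) the input data-type, a user-side workaround is to make the round-trip cast explicit, so the `f32` intermediate is a real tensor and the cast back inserts a genuine convert. A sketch with a hypothetical helper name:

```python
import torch
import torch.nn.functional as F

def upsample_bilinear_keep_dtype(x, scale_factor):
    # Hypothetical workaround: compute in float32 (matching what the
    # current XLA lowering does anyway), then cast back so downstream
    # ops like stack/concatenate see a consistent dtype.
    out = F.interpolate(x.float(), scale_factor=scale_factor,
                        mode="bilinear", align_corners=False)
    return out.to(x.dtype)

x = torch.randn(1, 3, 4, 4, dtype=torch.float16)
out = upsample_bilinear_keep_dtype(x, 2)
```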
cc @miladm @JackCaoG