mul: convert inputs to result type. #7130

Conversation
Hmm, I am surprised that bf16 and f64's …
This is a bit confusing, so let me try and clarify the cases, given …
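For illustration (my own example, not from the truncated comment above), `torch.result_type` shows how these dtype pairs combine:

```python
import torch

bf16 = torch.ones(2, dtype=torch.bfloat16)
f64 = torch.ones(2, dtype=torch.float64)

# Tensor-tensor: promotion picks the wider floating-point type.
print(torch.result_type(bf16, f64))  # torch.float64

# Tensor-scalar: a Python float is the same kind (floating-point),
# so the tensor's dtype wins and no promotion happens.
print(torch.result_type(bf16, 2.5))  # torch.bfloat16
```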
The tl;dr is: Scalars don't promote unless they are of a different kind.
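A quick demonstration of that rule in stock PyTorch (my own example):

```python
import torch

t = torch.ones(2, dtype=torch.float16)

# Same kind (both floating-point): the scalar does not promote.
print((t * 2.5).dtype)  # torch.float16

# Different kind (integer tensor, float scalar): the tensor is
# promoted to the default floating-point dtype.
i = torch.ones(2, dtype=torch.int32)
print((i * 2.5).dtype)  # torch.float32
```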
@ysiraichi This PR is failing the mul operation on TPUs. It fails this check: https://github.com/pytorch/xla/blob/master/torch_xla/csrc/aten_xla_type.cpp#L161
Apparently, you are not the only one: #7266
Let me take a look this afternoon.
I can't repro this issue, but I do see that half of our internal TPU tests crashed because of this. Let me revert this PR for now while figuring out what happened.
Fix: #7084
This PR fixes a data-type related problem for the `mul` operation. It does so by creating a structure `OpConfig` that behaves similarly to `DoBinaryOp`. The difference is that it takes care of pre/post-processing of inputs and outputs, casting them to the correct data-type.

Problem
In summary, this is what happens for an `f16` tensor multiplied (in-place) by an `f64` scalar `s`:

- `Tensor.mul_` is dispatched to its `CompositeExplicitAutograd` kernel: `torch.mul` (functional version)
- `DoBinaryOp` is called
    - It computes `at::result_type` (let's call it `common_dtype`) and passes it on to `bin_op`
    - `UnwrapNumber` does nothing, since `s` is a tensor with `is_wrapped_number_` unset
- `common_dtype` is passed on to `tensor_methods::mul`
    - The result tensor's dtype is set to `common_dtype`
- When `BuildMul` is called, we have 2 `XlaOp`s with different data-types: `f16` and `f64`
    - `BuildMul` promotes `f16` to `f64`
- In the end, the result tensor reports dtype `common_dtype` (`torch.float16`), but the actual `XlaOp` is `f64`
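A minimal repro sketch of the sequence above, assuming a working `torch_xla` install (the `f64` `XlaOp` is only visible in the dumped HLO, not from Python):

```python
import torch
import torch_xla.core.xla_model as xm

t = torch.ones(4, dtype=torch.float16, device=xm.xla_device())

# `2.5` is a double scalar. Per the walkthrough above, before this fix
# the second operand reached BuildMul as f64, so the lowered XlaOp was
# f64 even though the tensor still reports torch.float16.
t.mul_(2.5)
print(t.dtype)  # torch.float16
```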
Solution
Following PyTorch behavior [1, 2, 3], I created `OpConfig`: a structure that lets us specify common pre/post-processing of inputs and outputs.
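The real structure lives in the C++ lowering code; purely as a conceptual sketch (all names below are illustrative, not the actual API), the pre/post-processing it centralizes amounts to:

```python
import torch

def run_binary_op(op, lhs, rhs, out_dtype):
    # Hypothetical sketch of the casting discipline, not the real OpConfig.
    # Pre-processing: cast both inputs to the common (result) dtype, so
    # the backend never sees mixed operand types and cannot promote on
    # its own (the f16-vs-f64 divergence from the Problem section).
    common_dtype = torch.result_type(lhs, rhs)
    lhs = lhs.to(common_dtype)
    rhs = rhs.to(common_dtype) if isinstance(rhs, torch.Tensor) else rhs

    # Post-processing: cast the output to the dtype the caller expects,
    # keeping the logical dtype and the actual data in sync.
    return op(lhs, rhs).to(out_dtype)
```

This matches the PR title: inputs are converted to the result type before lowering, so `BuildMul` never has to promote behind the tensor's back.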
Affected Models

- `timm_nfnet` (training+nondynamo)

cc @miladm @JackCaoG @lezcano