mul: convert inputs to result type. #7130

Merged: 3 commits into master on Jun 3, 2024

Conversation

ysiraichi
Collaborator

Fix: #7084

This PR fixes a data-type-related problem in the mul operation. It does so by creating a structure, OpConfig, that behaves similarly to DoBinaryOp. The difference is that it also takes care of pre/post-processing the inputs and outputs, casting them to the correct data type.

Problem

import torch
import torch_xla.core.xla_model as xm

t = torch.rand(10, dtype=torch.half).to(xm.xla_device())
s = torch.tensor(10, dtype=torch.double).to(xm.xla_device())
out = t.mul_(s)
  • Tensor.mul_ is dispatched to its CompositeExplicitAutograd kernel
    • It wraps the scalar into a tensor, and calls torch.mul (functional version)
  • DoBinaryOp is called
    • Computes at::result_type (let's call it common_dtype) and passes it on to bin_op
    • Note that UnwrapNumber does nothing, since s is a tensor with is_wrapped_number_ unset
  • The computed common_dtype is passed on to tensor_methods::mul
    • Creates an IR node with data-type common_dtype
    • Does nothing with its inputs
  • Later, when BuildMul is called, we have two XlaOps with different data-types: f16 and f64
    • BuildMul promotes f16 to f64
  • The output's declared dtype is common_dtype (torch.float16), but the actual XlaOp is f64 (the CPU-only sketch below shows the dtype PyTorch itself expects)
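
For reference, the expected dtype can be checked on CPU with torch.result_type alone (a minimal sketch for illustration; it does not exercise the XLA lowering):

import torch

t = torch.rand(10, dtype=torch.half)      # dimensioned f16 tensor
s = torch.tensor(10, dtype=torch.double)  # zero-dim f64 tensor

# Both operands are of floating-point kind, so the zero-dim tensor does not
# promote the result: PyTorch expects the output dtype to be torch.float16.
print(torch.result_type(t, s))  # torch.float16

# A dimensioned f64 tensor, by contrast, does promote the result to f64.
print(torch.result_type(t, torch.rand(10, dtype=torch.double)))  # torch.float64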

Solution

Following PyTorch behavior [1, 2, 3], I created OpConfig: a structure that lets us specify common pre/post-processing on inputs and outputs.
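
As a rough Python analogue of that idea (the actual OpConfig is a C++ structure; the helper below is hypothetical and only sketches the pre-processing step):

import torch

def mul_converting_inputs(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    # Hypothetical sketch of the approach, not the PR's C++ API: cast the
    # inputs to the expected result dtype up front (pre-processing), so the
    # lowered op never performs its own, different promotion.
    common_dtype = torch.result_type(a, b)
    return torch.mul(a.to(common_dtype), b.to(common_dtype))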

Affected Models

  • timm_nfnet (training + non-dynamo)

cc @miladm @JackCaoG @lezcano

@JackCaoG
Collaborator

Hmm, I am surprised that bf16 and f64's at::result_type is f64...

@ysiraichi
Collaborator Author

Hmm, I am surprised that bf16 and f64's at::result_type is f64.

This is a bit confusing, so let me try to clarify the cases, given op(bf16, f64):

bf16     f64      at::result_type   PromoteType
tensor   scalar   bf16              f64
scalar   tensor   f64               f64
tensor   tensor   f64               f64
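
The at::result_type column can be reproduced from Python with torch.result_type (a sketch for illustration; Python floats stand in for the wrapped scalars, since Python has no bf16 scalar type):

import torch

bf16_t = torch.rand(4, dtype=torch.bfloat16)
f64_t = torch.rand(4, dtype=torch.float64)

print(torch.result_type(bf16_t, 1.5))    # tensor, scalar -> torch.bfloat16
print(torch.result_type(1.5, f64_t))     # scalar, tensor -> torch.float64
print(torch.result_type(bf16_t, f64_t))  # tensor, tensor -> torch.float64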

@ysiraichi force-pushed the ysiraichi/fix-mul-dtype-promotion branch from 7221e69 to 7e28074 on May 29, 2024 at 15:39
@lezcano
Collaborator

lezcano commented May 29, 2024

The tl;dr is: Scalars don't promote unless they are of a different kind.
Here are the exact rules: https://pytorch.org/docs/stable/tensor_attributes.html
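
A minimal CPU sketch of that rule, using torch.result_type:

import torch

bf16_t = torch.rand(3, dtype=torch.bfloat16)
int_t = torch.tensor([1, 2, 3])  # int64 tensor

# Same kind (floating point): the scalar does not promote the tensor's dtype.
print(torch.result_type(bf16_t, 2.5))  # torch.bfloat16

# Different kind (integral tensor, floating-point scalar): promotes to the
# default floating-point dtype.
print(torch.result_type(int_t, 2.5))   # torch.float32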

@ysiraichi force-pushed the ysiraichi/fix-mul-dtype-promotion branch from 7e28074 to a1b67f2 on May 29, 2024 at 20:36
@ysiraichi merged commit 7938bb5 into master on Jun 3, 2024
19 checks passed
@bhavya01
Collaborator

@ysiraichi This PR is breaking the mul operation on TPUs.

It fails this check: https://github.com/pytorch/xla/blob/master/torch_xla/csrc/aten_xla_type.cpp#L161

>>> import torch
>>> import torch_xla
>>> x = torch.tensor([1,2,3]).to('xla')
>>> y = torch.tensor([2,4,5]).to('xla')
>>> x*y
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
RuntimeError: torch_xla/csrc/aten_xla_type.cpp:161 : Check failed: it != inputs_.end() 
*** Begin stack trace ***
	tsl::CurrentStackTrace[abi:cxx11]()
	torch_xla::XLANativeFunctions::mul(at::Tensor const&, at::Tensor const&)
	
	at::_ops::mul_Tensor::redispatch(c10::DispatchKeySet, at::Tensor const&, at::Tensor const&)
	
	
	at::_ops::mul_Tensor::call(at::Tensor const&, at::Tensor const&)
	
	PyNumber_Multiply
	_PyEval_EvalFrameDefault
	
	PyEval_EvalCode
	
	_PyRun_InteractiveLoopObject
	
	PyRun_AnyFileExFlags
	
	Py_BytesMain
	
	__libc_start_main
	
*** End stack trace ***

@ysiraichi
Collaborator Author

Apparently, you are not the only one: #7266

@vanbasten23
Collaborator

This PR is also impacting DDP:
[screenshot of the DDP failure]

@JackCaoG
Collaborator

Let me take a look this afternoon

@JackCaoG
Collaborator

I can't repro this issue, but I do see that half of our internal TPU tests crashed because of this. Let me revert this PR for now while figuring out what happened.

Successfully merging this pull request may close these issues.

[torchbench] timm_nfnet training failing on non-dynamo.