[torchbench] hf_Reformer fails to run on dynamo+openxla training. #6009
hmm interesting
ysiraichi added a commit to pytorch/pytorch that referenced this issue on Mar 4, 2024:
….new_empty` method."

Fix: pytorch/xla#6009

This PR adds another case to the `TensorVariable.method_new` special case, where it re-dispatches `new` into `new_empty`. Since we are using fake tensors, the `new` call doesn't actually get to the corresponding backend (e.g. XLA). So, things like the following might happen:

```python
@torch.compile(backend="openxla")
def foo(x):
    new_x = x.new(*x.size())
    # new_x.device() == "xla"
    # x.device() == "xla:0"
    return new_x + x

a = torch.arange(10)
foo(a.to(xm.xla_device()))
```

Resulting in the following error:

```python
Traceback (most recent call last):
  ...
  File "torch/_dynamo/utils.py", line 1654, in get_fake_value
    ret_val = wrap_fake_exception(
  File "torch/_dynamo/utils.py", line 1190, in wrap_fake_exception
    return fn()
  File "torch/_dynamo/utils.py", line 1655, in <lambda>
    lambda: run_node(tx.output, node, args, kwargs, nnmodule)
  File "torch/_dynamo/utils.py", line 1776, in run_node
    raise RuntimeError(make_error_message(e)).with_traceback(
  File "torch/_dynamo/utils.py", line 1758, in run_node
    return node.target(*args, **kwargs)
  File "torch/utils/_stats.py", line 20, in wrapper
    return fn(*args, **kwargs)
  File "torch/_subclasses/fake_tensor.py", line 885, in __torch_dispatch__
    return self.dispatch(func, types, args, kwargs)
  File "torch/_subclasses/fake_tensor.py", line 1224, in dispatch
    return self._cached_dispatch_impl(func, types, args, kwargs)
  File "torch/_subclasses/fake_tensor.py", line 955, in _cached_dispatch_impl
    output = self._dispatch_impl(func, types, args, kwargs)
  File "torch/_subclasses/fake_tensor.py", line 1445, in _dispatch_impl
    return self.wrap_meta_outputs_with_default_device_logic(
  File "torch/_subclasses/fake_tensor.py", line 1575, in wrap_meta_outputs_with_default_device_logic
    return tree_map(wrap, r)
  File "torch/utils/_pytree.py", line 900, in tree_map
    return treespec.unflatten(map(func, *flat_args))
  File "torch/utils/_pytree.py", line 736, in unflatten
    leaves = list(leaves)
  File "torch/_subclasses/fake_tensor.py", line 1550, in wrap
    ) = FakeTensor._find_common_device(func, flat_args)
  File "torch/_subclasses/fake_tensor.py", line 625, in _find_common_device
    merge_devices(arg)
  File "torch/_subclasses/fake_tensor.py", line 620, in merge_devices
    raise RuntimeError(
torch._dynamo.exc.TorchRuntimeError: Failed running call_function <built-in function add>(*(FakeTensor(..., device='xla', size=(10,), dtype=torch.int64), FakeTensor(..., device='xla:0', size=(10,), dtype=torch.int64)), **{}): Unhandled FakeTensor Device Propagation for aten.add.Tensor, found two different devices xla, xla:0
```

Using `new_empty` instead fixes this error, because it takes the device from the source tensor instead of inferring it from the current dispatch key set.

cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx chenyang78 kadeng miladm JackCaoG alanwaketan lezcano

[ghstack-poisoned]
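For context, here is a minimal sketch (not part of the PR) of the device-preserving behavior that makes `new_empty` the safer re-dispatch target: it inherits the device and dtype from the source tensor. Plain CPU tensors stand in for XLA tensors, since the behavior is device-agnostic:

```python
import torch

x = torch.arange(10)

# Unlike `Tensor.new`, `new_empty` takes its device and dtype from the
# source tensor itself, so no device needs to be inferred from the
# current dispatch key set.
y = x.new_empty(x.size())
assert y.device == x.device
assert y.dtype == x.dtype
```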
🐛 Bug
Running the upstreamed benchmarking scripts with the following command results in an unexpected error.
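The exact benchmark command does not appear above; as a stand-in, the reproducer from the commit message can be run directly, assuming `torch_xla` is installed and an XLA device is available. On an affected build, this fails with the `TorchRuntimeError` shown above:

```python
import torch
import torch_xla.core.xla_model as xm

@torch.compile(backend="openxla")
def foo(x):
    # Under fake tensors, `Tensor.new` loses the device index,
    # producing "xla" vs. "xla:0" in device propagation.
    new_x = x.new(*x.size())
    return new_x + x

a = torch.arange(10)
foo(a.to(xm.xla_device()))
```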
Environment