
[torchbench] hf_Reformer fails to run on dynamo+openxla training. #6009

Closed
ysiraichi opened this issue Dec 4, 2023 · 1 comment
🐛 Bug

Running the upstreamed benchmarking scripts with the following command results in an unexpected error.

python xla/benchmarks/experiment_runner.py \
       --suite-name torchbench \
       --accelerator cuda \
       --xla PJRT \
       --dynamo openxla \
       --test train \
       --repeat 30 --iterations-per-run 5 \
       --print-subprocess \
       --no-resume -k hf_Reformer
Traceback (most recent call last):
  File "xla/benchmarks/experiment_runner.py", line 601, in <module>
    main()
  File "xla/benchmarks/experiment_runner.py", line 597, in main
    runner.run()
  File "xla/benchmarks/experiment_runner.py", line 65, in run
    self.run_single_experiment(experiment_config, model_config)
  File "xla/benchmarks/experiment_runner.py", line 161, in run_single_experiment
    run_metrics, output = self.timed_run(benchmark_experiment,
  File "xla/benchmarks/experiment_runner.py", line 328, in timed_run
    output = loop()
  File "xla/benchmarks/experiment_runner.py", line 310, in loop
    output = benchmark_model.model_iter_fn(
  File "torch/_dynamo/eval_frame.py", line 489, in _fn
    return fn(*args, **kwargs)
  File "xla/benchmarks/torchbench_model.py", line 247, in train
    super().train(inputs, collect_full_output=collect_full_output)
  File "xla/benchmarks/benchmark_model.py", line 142, in train
    self._optimizer_zero_grad()
  File "xla/benchmarks/benchmark_model.py", line 143, in resume_in_train
    pred = self.module(*inputs)
  File "torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/lib/python3.8/site-packages/transformers/models/reformer/modeling_reformer.py", line 2404, in forward
    reformer_outputs = self.reformer(
  File "torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/lib/python3.8/site-packages/transformers/models/reformer/modeling_reformer.py", line 2004, in forward
    @add_start_docstrings_to_model_forward(REFORMER_INPUTS_DOCSTRING)
  File "/lib/python3.8/site-packages/transformers/models/reformer/modeling_reformer.py", line 2061, in resume_in_forward
    input_shape[-1] % least_common_mult_chunk_length != 0
  File "/lib/python3.8/site-packages/transformers/models/reformer/modeling_reformer.py", line 2101, in resume_in_forward
    encoder_outputs = self.encoder(
  File "torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/lib/python3.8/site-packages/transformers/models/reformer/modeling_reformer.py", line 1727, in forward
    hidden_states = _ReversibleFunction.apply(
  File "torch/autograd/function.py", line 553, in apply
    return super().apply(*args, **kwargs)  # type: ignore[misc]
  File "/lib/python3.8/site-packages/transformers/models/reformer/modeling_reformer.py", line 1615, in forward
    layer_outputs = layer(
  File "torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/lib/python3.8/site-packages/transformers/models/reformer/modeling_reformer.py", line 1478, in forward
    self._init_attention_seed()
  File "torch/_dynamo/eval_frame.py", line 655, in catch_errors
    return callback(frame, cache_entry, hooks, frame_state)
  File "torch/_dynamo/convert_frame.py", line 721, in _convert_frame
    result = inner_convert(frame, cache_entry, hooks, frame_state)
  File "torch/_dynamo/convert_frame.py", line 383, in _convert_frame_assert
    compiled_product = _compile(
  File "torch/_dynamo/convert_frame.py", line 645, in _compile
    guarded_code = compile_inner(code, one_graph, hooks, transform)
  File "torch/_dynamo/utils.py", line 244, in time_wrapper
    r = func(*args, **kwargs)
  File "torch/_dynamo/convert_frame.py", line 562, in compile_inner
    out_code = transform_code_object(code, transform)
  File "torch/_dynamo/bytecode_transformation.py", line 1033, in transform_code_object
    transformations(instructions, code_options)
  File "torch/_dynamo/convert_frame.py", line 151, in _fn
    return fn(*args, **kwargs)
  File "torch/_dynamo/convert_frame.py", line 527, in transform
    tracer.run()
  File "torch/_dynamo/symbolic_convert.py", line 2123, in run
    super().run()
  File "torch/_dynamo/symbolic_convert.py", line 818, in run
    and self.step()
  File "torch/_dynamo/symbolic_convert.py", line 781, in step
    getattr(self, inst.opname)(inst)
  File "torch/_dynamo/symbolic_convert.py", line 470, in wrapper
    return inner_fn(self, inst)
  File "torch/_dynamo/symbolic_convert.py", line 1264, in CALL_FUNCTION_KW
    self.call_function(fn, args, kwargs)
  File "torch/_dynamo/symbolic_convert.py", line 652, in call_function
    self.push(fn.call_function(self, args, kwargs))
  File "torch/_dynamo/variables/nn_module.py", line 328, in call_function
    return tx.inline_user_function_return(
  File "torch/_dynamo/symbolic_convert.py", line 688, in inline_user_function_return
    return InliningInstructionTranslator.inline_call(self, fn, args, kwargs)
  File "torch/_dynamo/symbolic_convert.py", line 2256, in inline_call
    return cls.inline_call_(parent, func, args, kwargs)
  File "torch/_dynamo/symbolic_convert.py", line 2371, in inline_call_
    tracer.run()
  File "torch/_dynamo/symbolic_convert.py", line 818, in run
    and self.step()
  File "torch/_dynamo/symbolic_convert.py", line 781, in step
    getattr(self, inst.opname)(inst)
  File "torch/_dynamo/symbolic_convert.py", line 470, in wrapper
    return inner_fn(self, inst)
  File "torch/_dynamo/symbolic_convert.py", line 1252, in CALL_FUNCTION_EX
    self.call_function(fn, argsvars.items, kwargsvars.items)
  File "torch/_dynamo/symbolic_convert.py", line 652, in call_function
    self.push(fn.call_function(self, args, kwargs))
  File "torch/_dynamo/variables/functions.py", line 294, in call_function
    return super().call_function(tx, args, kwargs)
  File "torch/_dynamo/variables/functions.py", line 248, in call_function
    return super().call_function(tx, args, kwargs)
  File "torch/_dynamo/variables/functions.py", line 81, in call_function
    return tx.inline_user_function_return(
  File "torch/_dynamo/symbolic_convert.py", line 688, in inline_user_function_return
    return InliningInstructionTranslator.inline_call(self, fn, args, kwargs)
  File "torch/_dynamo/symbolic_convert.py", line 2256, in inline_call
    return cls.inline_call_(parent, func, args, kwargs)
  File "torch/_dynamo/symbolic_convert.py", line 2371, in inline_call_
    tracer.run()
  File "torch/_dynamo/symbolic_convert.py", line 818, in run
    and self.step()
  File "torch/_dynamo/symbolic_convert.py", line 781, in step
    getattr(self, inst.opname)(inst)
  File "torch/_dynamo/symbolic_convert.py", line 470, in wrapper
    return inner_fn(self, inst)
  File "torch/_dynamo/symbolic_convert.py", line 1264, in CALL_FUNCTION_KW
    self.call_function(fn, args, kwargs)
  File "torch/_dynamo/symbolic_convert.py", line 652, in call_function
    self.push(fn.call_function(self, args, kwargs))
  File "torch/_dynamo/variables/nn_module.py", line 328, in call_function
    return tx.inline_user_function_return(
  File "torch/_dynamo/symbolic_convert.py", line 688, in inline_user_function_return
    return InliningInstructionTranslator.inline_call(self, fn, args, kwargs)
  File "torch/_dynamo/symbolic_convert.py", line 2256, in inline_call
    return cls.inline_call_(parent, func, args, kwargs)
  File "torch/_dynamo/symbolic_convert.py", line 2371, in inline_call_
    tracer.run()
  File "torch/_dynamo/symbolic_convert.py", line 818, in run
    and self.step()
  File "torch/_dynamo/symbolic_convert.py", line 781, in step
    getattr(self, inst.opname)(inst)
  File "torch/_dynamo/symbolic_convert.py", line 470, in wrapper
    return inner_fn(self, inst)
  File "torch/_dynamo/symbolic_convert.py", line 1252, in CALL_FUNCTION_EX
    self.call_function(fn, argsvars.items, kwargsvars.items)
  File "torch/_dynamo/symbolic_convert.py", line 652, in call_function
    self.push(fn.call_function(self, args, kwargs))
  File "torch/_dynamo/variables/functions.py", line 294, in call_function
    return super().call_function(tx, args, kwargs)
  File "torch/_dynamo/variables/functions.py", line 248, in call_function
    return super().call_function(tx, args, kwargs)
  File "torch/_dynamo/variables/functions.py", line 81, in call_function
    return tx.inline_user_function_return(
  File "torch/_dynamo/symbolic_convert.py", line 688, in inline_user_function_return
    return InliningInstructionTranslator.inline_call(self, fn, args, kwargs)
  File "torch/_dynamo/symbolic_convert.py", line 2256, in inline_call
    return cls.inline_call_(parent, func, args, kwargs)
  File "torch/_dynamo/symbolic_convert.py", line 2371, in inline_call_
    tracer.run()
  File "torch/_dynamo/symbolic_convert.py", line 818, in run
    and self.step()
  File "torch/_dynamo/symbolic_convert.py", line 781, in step
    getattr(self, inst.opname)(inst)
  File "torch/_dynamo/symbolic_convert.py", line 470, in wrapper
    return inner_fn(self, inst)
  File "torch/_dynamo/symbolic_convert.py", line 1213, in CALL_FUNCTION
    self.call_function(fn, args, {})
  File "torch/_dynamo/symbolic_convert.py", line 652, in call_function
    self.push(fn.call_function(self, args, kwargs))
  File "torch/_dynamo/variables/functions.py", line 294, in call_function
    return super().call_function(tx, args, kwargs)
  File "torch/_dynamo/variables/functions.py", line 248, in call_function
    return super().call_function(tx, args, kwargs)
  File "torch/_dynamo/variables/functions.py", line 81, in call_function
    return tx.inline_user_function_return(
  File "torch/_dynamo/symbolic_convert.py", line 688, in inline_user_function_return
    return InliningInstructionTranslator.inline_call(self, fn, args, kwargs)
  File "torch/_dynamo/symbolic_convert.py", line 2256, in inline_call
    return cls.inline_call_(parent, func, args, kwargs)
  File "torch/_dynamo/symbolic_convert.py", line 2371, in inline_call_
    tracer.run()
  File "torch/_dynamo/symbolic_convert.py", line 818, in run
    and self.step()
  File "torch/_dynamo/symbolic_convert.py", line 781, in step
    getattr(self, inst.opname)(inst)
  File "torch/_dynamo/symbolic_convert.py", line 470, in wrapper
    return inner_fn(self, inst)
  File "torch/_dynamo/symbolic_convert.py", line 1213, in CALL_FUNCTION
    self.call_function(fn, args, {})
  File "torch/_dynamo/symbolic_convert.py", line 652, in call_function
    self.push(fn.call_function(self, args, kwargs))
  File "torch/_dynamo/variables/misc.py", line 643, in call_function
    return self.obj.call_method(tx, self.name, args, kwargs)
  File "torch/_dynamo/variables/tensor.py", line 748, in call_method
    return wrap_fx_proxy(
  File "torch/_dynamo/variables/builder.py", line 1283, in wrap_fx_proxy
    return wrap_fx_proxy_cls(target_cls=TensorVariable, **kwargs)
  File "torch/_dynamo/variables/builder.py", line 1368, in wrap_fx_proxy_cls
    example_value = get_fake_value(proxy.node, tx, allow_non_graph_fake=True)
  File "torch/_dynamo/utils.py", line 1524, in get_fake_value
    raise TorchRuntimeError(str(e)).with_traceback(e.__traceback__) from None
  File "torch/_dynamo/utils.py", line 1485, in get_fake_value
    ret_val = wrap_fake_exception(
  File "torch/_dynamo/utils.py", line 1026, in wrap_fake_exception
    return fn()
  File "torch/_dynamo/utils.py", line 1486, in <lambda>
    lambda: run_node(tx.output, node, args, kwargs, nnmodule)
  File "torch/_dynamo/utils.py", line 1591, in run_node
    raise RuntimeError(fn_str + str(e)).with_traceback(e.__traceback__) from e
  File "torch/_dynamo/utils.py", line 1572, in run_node
    return getattr(args[0], node.target)(*args[1:], **kwargs)
  File "torch/utils/_stats.py", line 20, in wrapper
    return fn(*args, **kwargs)
  File "torch/_subclasses/fake_tensor.py", line 1392, in __torch_dispatch__
    return self.dispatch(func, types, args, kwargs)
  File "torch/_subclasses/fake_tensor.py", line 1707, in dispatch
    return self.wrap_meta_outputs_with_default_device_logic(
  File "torch/_subclasses/fake_tensor.py", line 1810, in wrap_meta_outputs_with_default_device_logic
    return tree_map(wrap, r)
  File "torch/utils/_pytree.py", line 531, in tree_map
    return tree_unflatten([func(i) for i in flat_args], spec)
  File "torch/utils/_pytree.py", line 531, in <listcomp>
    return tree_unflatten([func(i) for i in flat_args], spec)
  File "torch/_subclasses/fake_tensor.py", line 1785, in wrap
    ) = FakeTensor._find_common_device(func, flat_args)
  File "torch/_subclasses/fake_tensor.py", line 1271, in _find_common_device
    merge_devices(arg)
  File "torch/_subclasses/fake_tensor.py", line 1266, in merge_devices
    raise RuntimeError(
torch._dynamo.exc.TorchRuntimeError: Failed running call_method scatter_(*(FakeTensor(..., device='xla', size=(8, 12, 4096), dtype=torch.int64), -1, FakeTensor(..., device='xla:0', size=(8, 12, 4096), dtype=torch.int64), FakeTensor(..., device='xla:0', size=(8, 12, 4096), dtype=torch.int64)), **{}):
Unhandled FakeTensor Device Propagation for aten.scatter_.src, found two different devices xla, xla:0

from user code:
   File "/lib/python3.8/site-packages/transformers/models/reformer/modeling_reformer.py", line 1480, in resume_in_forward
    attn_outputs = self.attention(
  File "torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/lib/python3.8/site-packages/transformers/models/reformer/modeling_reformer.py", line 1313, in forward
    self_attention_outputs = self.self_attention(
  File "torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/lib/python3.8/site-packages/transformers/models/reformer/modeling_reformer.py", line 482, in forward
    sorted_bucket_idx, undo_sorted_bucket_idx = self._get_sorted_bucket_idx_and_undo_sorted_bucket_idx(
  File "/lib/python3.8/site-packages/transformers/models/reformer/modeling_reformer.py", line 702, in _get_sorted_bucket_idx_and_undo_sorted_bucket_idx
    undo_sorted_bucket_idx.scatter_(-1, sorted_bucket_idx, indices)

Set TORCH_LOGS="+dynamo" and TORCHDYNAMO_VERBOSE=1 for more information


You can suppress this exception and fall back to eager by setting:
    import torch._dynamo
    torch._dynamo.config.suppress_errors = True
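For context, the `scatter_` call that triggers the error builds an inverse permutation, mapping sorted bucket positions back to their original ones. A minimal CPU sketch of that pattern, with hypothetical variable names:

```python
import torch

# sorted_idx is a per-row permutation; scatter_ builds its inverse so that
# undo[0, sorted_idx[0, i]] == i, i.e. undo maps sorted positions back.
sorted_idx = torch.tensor([[2, 0, 1]])
indices = torch.arange(3).expand_as(sorted_idx)
undo = torch.zeros_like(sorted_idx)
undo.scatter_(-1, sorted_idx, indices)
assert torch.equal(undo, torch.tensor([[1, 2, 0]]))
```

The failure is not in the semantics of `scatter_` itself but in fake-tensor device propagation over its arguments.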

Environment

JackCaoG commented Dec 4, 2023

Hmm, interesting: `found two different devices xla, xla:0`. The device should never be just `xla`; there must be a bug somewhere.
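The distinction here is that `torch.device` treats an index-less `xla` and `xla:0` as different devices, which is exactly what `FakeTensor._find_common_device` trips over. A quick illustration (no `torch_xla` needed, since `xla` is a device type recognized by core PyTorch):

```python
import torch

# An index-less device compares unequal to the same device type with index 0,
# so fake-tensor device propagation sees "two different devices".
assert torch.device("xla") != torch.device("xla:0")
assert torch.device("xla").index is None
assert torch.device("xla:0").index == 0
```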

ysiraichi added a commit to pytorch/pytorch that referenced this issue Mar 4, 2024
….new_empty` method."


Fix: pytorch/xla#6009

This PR adds another case to the `TensorVariable.method_new` special-casing, re-dispatching `new` to `new_empty`.

Since we are using fake tensors, the `new` call doesn't actually get to the corresponding
backend (e.g. XLA). So, things like the following might happen:

```python
@torch.compile(backend="openxla")
def foo(x):
    new_x = x.new(*x.size())

    # new_x.device == "xla"
    # x.device == "xla:0"

    return new_x + x

a = torch.arange(10)
foo(a.to(xm.xla_device()))
```

Resulting in the following error:

```python
Traceback (most recent call last):
  ...
  File "torch/_dynamo/utils.py", line 1654, in get_fake_value
    ret_val = wrap_fake_exception(
  File "torch/_dynamo/utils.py", line 1190, in wrap_fake_exception
    return fn()
  File "torch/_dynamo/utils.py", line 1655, in <lambda>
    lambda: run_node(tx.output, node, args, kwargs, nnmodule)
  File "torch/_dynamo/utils.py", line 1776, in run_node
    raise RuntimeError(make_error_message(e)).with_traceback(
  File "torch/_dynamo/utils.py", line 1758, in run_node
    return node.target(*args, **kwargs)
  File "torch/utils/_stats.py", line 20, in wrapper
    return fn(*args, **kwargs)
  File "torch/_subclasses/fake_tensor.py", line 885, in __torch_dispatch__
    return self.dispatch(func, types, args, kwargs)
  File "torch/_subclasses/fake_tensor.py", line 1224, in dispatch
    return self._cached_dispatch_impl(func, types, args, kwargs)
  File "torch/_subclasses/fake_tensor.py", line 955, in _cached_dispatch_impl
    output = self._dispatch_impl(func, types, args, kwargs)
  File "torch/_subclasses/fake_tensor.py", line 1445, in _dispatch_impl
    return self.wrap_meta_outputs_with_default_device_logic(
  File "torch/_subclasses/fake_tensor.py", line 1575, in wrap_meta_outputs_with_default_device_logic
    return tree_map(wrap, r)
  File "torch/utils/_pytree.py", line 900, in tree_map
    return treespec.unflatten(map(func, *flat_args))
  File "torch/utils/_pytree.py", line 736, in unflatten
    leaves = list(leaves)
  File "torch/_subclasses/fake_tensor.py", line 1550, in wrap
    ) = FakeTensor._find_common_device(func, flat_args)
  File "torch/_subclasses/fake_tensor.py", line 625, in _find_common_device
    merge_devices(arg)
  File "torch/_subclasses/fake_tensor.py", line 620, in merge_devices
    raise RuntimeError(
torch._dynamo.exc.TorchRuntimeError: Failed running call_function <built-in function add>(*(FakeTensor(..., device='xla', size=(10,), dtype=torch.int64), FakeTensor(..., device='xla:0', size=(10,), dtype=torch.int64)), **{}):
Unhandled FakeTensor Device Propagation for aten.add.Tensor, found two different devices xla, xla:0
```

Using `new_empty` instead fixes this error, because it takes the device from the source
tensor rather than inferring it from the current dispatch key set.
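The behavioral difference between the two allocation paths can be sketched even on CPU (illustrative only; in eager CPU mode both end up on the same device, but the shape/dtype propagation shown is the contract `new_empty` guarantees):

```python
import torch

x = torch.arange(10)
# Tensor.new infers dtype from x but its device from the active dispatch
# key set; Tensor.new_empty takes both dtype and device directly from x,
# which is why re-dispatching `new` to `new_empty` preserves the full
# `xla:0` device under fake tensors.
a = x.new(*x.size())
b = x.new_empty(x.size())
assert a.shape == b.shape == x.shape
assert a.dtype == b.dtype == torch.int64
assert b.device == x.device
```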

cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx chenyang78 kadeng miladm JackCaoG alanwaketan lezcano 

[ghstack-poisoned]