[torchbench] moco fails to run with CUDA OpenXLA fallback. #7647

Open
ysiraichi opened this issue Jul 9, 2024 · 8 comments

🐛 Bug

Running the upstreamed benchmarking scripts with the following command results in an unexpected error. It does work when using the CPU OpenXLA fallback, though.

python xla/benchmarks/experiment_runner.py \
       --suite-name torchbench \
       --accelerator cuda \
       --xla PJRT \
       --dynamo None \
       --test eval \
       --repeat 30 --iterations-per-run 5 \
       --print-subprocess \
       --no-resume --filter moco
[rank0]: Traceback (most recent call last):
[rank0]:   File "xla/benchmarks/experiment_runner.py", line 1011, in <module>
[rank0]:     main()
[rank0]:   File "xla/benchmarks/experiment_runner.py", line 1007, in main
[rank0]:     runner.run()
[rank0]:   File "xla/benchmarks/experiment_runner.py", line 73, in run
[rank0]:     self.run_single_config()
[rank0]:   File "xla/benchmarks/experiment_runner.py", line 278, in run_single_config
[rank0]:     metrics, last_output = self.run_once_and_gather_metrics(
[rank0]:   File "xla/benchmarks/experiment_runner.py", line 374, in run_once_and_gather_metrics
[rank0]:     output, _ = loop(iter_fn=self._default_iter_fn)
[rank0]:   File "xla/benchmarks/experiment_runner.py", line 331, in loop
[rank0]:     output, timing, trace = iter_fn(benchmark_experiment, benchmark_model,
[rank0]:   File "xla/benchmarks/experiment_runner.py", line 244, in _default_iter_fn
[rank0]:     self._mark_step(benchmark_experiment, output)
[rank0]:   File "xla/benchmarks/experiment_runner.py", line 456, in _mark_step
[rank0]:     xm.mark_step()
[rank0]:   File "xla/torch_xla/core/xla_model.py", line 1056, in mark_step
[rank0]:     torch_xla._XLAC._xla_step_marker(
[rank0]: RuntimeError: ./torch_xla/csrc/runtime/pjrt_computation_client.h:192 : Check failed: HasValue()
[rank0]: *** Begin stack trace ***
[rank0]:        tsl::CurrentStackTrace[abi:cxx11]()
[rank0]:        torch_xla::runtime::PjRtComputationClient::PjRtData::GetHandle()
[rank0]:        torch::lazy::LazyGraphExecutor::RunPostOrder(std::vector<torch::lazy::Value, std::allocator<torch::lazy::Value> > const&, torch::lazy::LazyGraphExecutor::SyncTensorCollection*)
[rank0]:        torch_xla::XLAGraphExecutor::RunPostOrder(std::vector<torch::lazy::Value, std::allocator<torch::lazy::Value> > const&, torch::lazy::LazyGraphExecutor::SyncTensorCollection*)
[rank0]:        torch_xla::XLAGraphExecutor::SyncTensorsGraphInternal(std::vector<c10::intrusive_ptr<torch_xla::XLATensor, c10::detail::intrusive_target_default_null_type<torch_xla::XLATensor> >, std::allocator<c10::intrusive_ptr<torch_xla::XLATensor, c10::detail::intrusive_target_default_null_type<torch_xla::XLATensor> > > >*, absl::lts_20230802::Span<std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > const>, torch::lazy::LazyGraphExecutor::SyncTensorsConfig const&, bool)
[rank0]:        torch_xla::XLAGraphExecutor::SyncTensorsGraph(std::vector<c10::intrusive_ptr<torch_xla::XLATensor, c10::detail::intrusive_target_default_null_type<torch_xla::XLATensor> >, std::allocator<c10::intrusive_ptr<torch_xla::XLATensor, c10::detail::intrusive_target_default_null_type<torch_xla::XLATensor> > > >*, absl::lts_20230802::Span<std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > const>, bool, bool, bool)
[rank0]:        torch_xla::XLAGraphExecutor::SyncLiveTensorsGraph(torch::lazy::BackendDevice const*, c10::ArrayRef<std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > >, bool)
[rank0]:
[rank0]:
[rank0]:
[rank0]:
[rank0]:
[rank0]:
[rank0]:
[rank0]:
[rank0]:        _PyObject_MakeTpCall
[rank0]:        _PyEval_EvalFrameDefault
[rank0]:
[rank0]:        _PyEval_EvalFrameDefault
[rank0]:
[rank0]:        _PyEval_EvalFrameDefault
[rank0]:
[rank0]:
[rank0]:        _PyEval_EvalFrameDefault
[rank0]:
[rank0]:        _PyEval_EvalFrameDefault
[rank0]:
[rank0]:        _PyEval_EvalFrameDefault
[rank0]:
[rank0]:        _PyEval_EvalFrameDefault
[rank0]:
[rank0]:        _PyEval_EvalFrameDefault
[rank0]:
[rank0]:        _PyEval_EvalFrameDefault
[rank0]:
[rank0]:        PyEval_EvalCode
[rank0]:
[rank0]:
[rank0]:
[rank0]:        _PyRun_SimpleFileObject
[rank0]:        _PyRun_AnyFileObject
[rank0]:        Py_RunMain
[rank0]:        Py_BytesMain
[rank0]:        __libc_start_main
[rank0]:        _start
[rank0]: *** End stack trace ***
[rank0]: buffer with shape s64[1] on device CUDA:0 is deleted

Environment

  • Reproducible on XLA backend [CPU/TPU/CUDA]: CUDA
  • torch_xla version: c782e0d

cc @miladm @JackCaoG

@JackCaoG commented Jul 9, 2024

It seems like during mark_step() we found an XLATensor with an empty data handle.
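
For context, a rough sketch of what mark_step() does from the Python side (illustrative only, not taken from the failing run): it cuts the lazy trace and materializes every live XLA tensor, which is where the GetHandle() check in the stack trace fires.

import torch
import torch_xla.core.xla_model as xm

t = torch.ones(2, 2, device=xm.xla_device()) + 1  # recorded lazily; no device buffer yet
xm.mark_step()  # compiles and executes the pending graph; live tensors should now hold valid PjRt buffers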

@ysiraichi

From what I have investigated so far, the only fallback op we are running differently is aten::_local_scalar_dense.

@JackCaoG commented Jul 9, 2024

_local_scalar_dense should be run on CPU, I guess? This op usually shows up when we move a tensor to CPU, e.g. to print it.
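
A minimal sketch of how this op is usually reached (illustrative, not code from moco): reading a scalar out of an XLA tensor, e.g. via .item(), dispatches aten::_local_scalar_dense, which is exactly the op the OpenXLA fallback has to execute here.

import torch
import torch_xla.core.xla_model as xm

t = torch.tensor([3], dtype=torch.int64, device=xm.xla_device())  # s64[1], same shape as the deleted buffer
n = t.item()  # dispatches aten::_local_scalar_dense; with the CUDA OpenXLA fallback, it runs outside XLA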

@ysiraichi

Right. But I wonder whether this issue points to a problem in the CUDA OpenXLA fallback implementation, in the sense that, even if we run that op on CUDA, it should still work.

@ysiraichi

This is odd. I tried replacing the DLPack conversion with tensor.to("cpu").to("cuda") and tensor.to("cpu").to("xla"), and still got the same error.
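
For reference, the two variants amount to roughly this at the tensor level (illustrative stand-ins for the conversion inside the fallback, not the actual torch_xla code; the function names are made up):

def xla_to_cuda_via_cpu(t):
    # variant 1: XLA tensor -> host copy -> CUDA tensor, instead of the DLPack handoff
    return t.to("cpu").to("cuda")

def cuda_to_xla_via_cpu(t):
    # variant 2: CUDA tensor -> host copy -> XLA tensor, for moving the fallback result back
    return t.to("cpu").to("xla")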

@ysiraichi

Forcing CPU fallback on _local_scalar_dense did work, though.

@ysiraichi

@JackCaoG
I have been debugging this for a while now, and here's what I found out:

  • The PjRtData that was deleted is not the same one the fallback input holds. It was created in a later mark_step() call.
  • PjRtData instantiation: it is first created by a CreateDataPlaceholder call inside the ExtractIRAndPrepareXlaData_ function, (as far as I understand) when mark_step() is called.
  • PjRtStreamExecutorBuffer deletion: Delete calls Release after RunPostOrder finishes. However, I believe that, at that point, the buffer is already deleted (i.e. PjRtStreamExecutorBuffer::IsDeleted() == true), because PjRtStreamExecutorBuffer::ConfirmDonation was called earlier.

Basically, this is the timeline I am seeing:

...
CreateDataPlaceholder(tensor: 0x55a254171e70)
XLAData (ptr: 0x55a254142e60):
  Data Device: CUDA:0
  Data Shape: s64[1]
  Data Handle: None
...
PjRtData::Assign: Handle changes from None to 0x7fecfc0710a0
  >> Old: XLAData (0x55a254142e60):
  Data Device: CUDA:0
  Data Shape: s64[1]
  Data Handle: None

  >> New: XLAData (0x7fecfc677340):
  Data Device: CUDA:0
  Data Shape: s64[1]
  Data Handle: 0x7fecfc0710a0
...
PjRtStreamExecutorBuffer::GetBufferWithHold(Usage): 0x7fecfc0710a0
...
PjRtStreamExecutorBuffer::GetBufferWithHold(Donation): 0x7fecfc0710a0
...
PjRtStreamExecutorBuffer::ConfirmDonation: 0x7fecfc0710a0
  >> Resets the buffer, i.e. deletes it!
...
Could NOT get handle (0x55a254142e60): XLAData:
  Data Device: CUDA:0
  Data Shape: s64[1]
  Data Handle: Deleted

PjRtStreamExecutorBuffer::Delete: 0x7fecfc0710a0
  >> Delete is called, but buffer is already deleted, i.e. `PjRtStreamExecutorBuffer::device_buffer_ == nullptr`
...

Do you see anything strange? Any ideas on where to look?

@ysiraichi

In an external discussion, we decided to work around this issue for now by forcing aten::_local_scalar_dense to run on CPU. Since this isn't exactly fixed (i.e. it may actually be a symptom of a more complex hidden error), I won't close this issue.
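
For illustration only, the effect of the workaround at the op level is roughly the following sketch (not the actual dispatcher change in torch_xla): the scalar read goes through an explicit host copy, so aten::_local_scalar_dense executes on a CPU tensor instead of going through the CUDA OpenXLA fallback.

import torch
import torch_xla.core.xla_model as xm

t = torch.tensor([3], dtype=torch.int64, device=xm.xla_device())
value = t.cpu().item()  # explicit host transfer; _local_scalar_dense then runs as a regular CPU kernel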
