[dist] log node's hostname plus rank in the exception message #63174

Closed
stas00 opened this issue Aug 12, 2021 · 7 comments
Labels: enhancement, module: c10d, oncall: distributed

Comments


stas00 commented Aug 12, 2021

🚀 Feature

Would it be possible to log (1) the hostname and (2) the rank with the exceptions?

Motivation

Currently it's very difficult to diagnose which node is faulty so that it can be removed from the Slurm pool.

For example, a multi-node training crashed with:

    torch.autograd.backward(tensors=(outputs, ), grad_tensors=(grad_tensors, ))
  File "/gpfswork/rech/six/commun/conda/tr1-13B/lib/python3.8/site-packages/torch/autograd/__init__.py", line 145, in backward
    Variable._execution_engine.run_backward(
RuntimeError: transform: failed to synchronize: cudaErrorECCUncorrectable: uncorrectable ECC error encountered
terminate called after throwing an instance of 'c10::Error'
  what():  CUDA error: uncorrectable ECC error encountered
Exception raised from create_event_internal at /opt/conda/conda-bld/pytorch_1616554793803/work/c10/cuda/CUDACachingAllocator.cpp:733 (most recent call first):
frame #0: c10::Error::Error(c10::SourceLocation, std::string) + 0x42 (0x1500fb4d42f2 in /gpfswork/rech/six/commun/conda/tr1-13B/lib/python3.8/site-packages/torch/lib/libc10.so)
frame #1: c10::detail::torchCheckFail(char const*, char const*, unsigned int, std::string const&) + 0x5b (0x1500fb4d167b in /gpfswork/rech/six/commun/conda/tr1-13B/lib/python3.8/site-packages/torch/lib/libc10.so)
frame #2: c10::cuda::CUDACachingAllocator::raw_delete(void*) + 0x809 (0x1500fb72d219 in /gpfswork/rech/six/commun/conda/tr1-13B/lib/python3.8/site-packages/torch/lib/libc10_cuda.so)
frame #3: c10::TensorImpl::release_resources() + 0x54 (0x1500fb4bc3a4 in /gpfswork/rech/six/commun/conda/tr1-13B/lib/python3.8/site-packages/torch/lib/libc10.so)
frame #4: <unknown function> + 0x6e0e5a (0x150152432e5a in /gpfswork/rech/six/commun/conda/tr1-13B/lib/python3.8/site-packages/torch/lib/libtorch_python.so)
frame #5: <unknown function> + 0x6e0ef1 (0x150152432ef1 in /gpfswork/rech/six/commun/conda/tr1-13B/lib/python3.8/site-packages/torch/lib/libtorch_python.so)
frame #6: <unknown function> + 0x1a6b5a (0x56434fce9b5a in /gpfswork/rech/six/commun/conda/tr1-13B/bin/python)
frame #7: <unknown function> + 0x110b7c (0x56434fc53b7c in /gpfswork/rech/six/commun/conda/tr1-13B/bin/python)
frame #8: <unknown function> + 0x1105b9 (0x56434fc535b9 in /gpfswork/rech/six/commun/conda/tr1-13B/bin/python)
frame #9: <unknown function> + 0x1105a3 (0x56434fc535a3 in /gpfswork/rech/six/commun/conda/tr1-13B/bin/python)
frame #10: <unknown function> + 0x1105a3 (0x56434fc535a3 in /gpfswork/rech/six/commun/conda/tr1-13B/bin/python)
frame #11: <unknown function> + 0x177917 (0x56434fcba917 in /gpfswork/rech/six/commun/conda/tr1-13B/bin/python)
frame #12: PyDict_SetItemString + 0x4c (0x56434fcbd86c in /gpfswork/rech/six/commun/conda/tr1-13B/bin/python)
frame #13: PyImport_Cleanup + 0xac (0x56434fd2f0ec in /gpfswork/rech/six/commun/conda/tr1-13B/bin/python)
frame #14: Py_FinalizeEx + 0x79 (0x56434fd95589 in /gpfswork/rech/six/commun/conda/tr1-13B/bin/python)
frame #15: Py_RunMain + 0x1bc (0x56434fd988fc in /gpfswork/rech/six/commun/conda/tr1-13B/bin/python)
frame #16: Py_BytesMain + 0x39 (0x56434fd98ce9 in /gpfswork/rech/six/commun/conda/tr1-13B/bin/python)
frame #17: __libc_start_main + 0xf3 (0x150183467873 in /lib64/libc.so.6)
frame #18: <unknown function> + 0x1f7847 (0x56434fd3a847 in /gpfswork/rech/six/commun/conda/tr1-13B/bin/python)

After restarting the training, it failed again in a different code path, this time with:

    torch.distributed.barrier()
  File "/gpfswork/rech/six/commun/conda/tr1-13B/lib/python3.8/site-packages/torch/distributed/distributed_c10d.py", line 2420, in barrier
    work = default_pg.barrier(opts=opts)
RuntimeError: CUDA error: out of memory

I suspect that one GPU on one of the nodes went bonkers at the hardware level, which crashed the training. And of course, since the node hadn't been rebooted, it was still unusable. So the next training most likely hit the same node (this is a Slurm environment), and of course it was still broken, hence the breakage again, just at a different code path.

In this circumstance it would have been good to know whether the exception happened on the same hostname + rank, as that would help us exclude that node from future trainings (or request its reboot) and not hit it again.

Otherwise, we are very likely to hit that node again and again.
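For illustration, here is a minimal sketch of the kind of tag we'd like to see attached to worker exceptions (not a proposal for PyTorch's internals; the decorator name is made up and it assumes the launcher exports RANK):

```python
# Hypothetical workaround sketch: re-raise worker errors with hostname + rank attached.
# Assumes RANK is set by the launcher (torch.distributed launchers set it); not PyTorch code.
import functools
import os
import socket


def tag_errors(fn):
    @functools.wraps(fn)
    def wrapper(*args, **kwargs):
        try:
            return fn(*args, **kwargs)
        except Exception as e:
            host = socket.gethostname()
            rank = os.environ.get("RANK", "unknown")
            # Prepend the identifying info so it survives in the combined log
            raise RuntimeError(f"[host: {host}, rank: {rank}] {e}") from e
    return wrapper


@tag_errors
def main():
    ...  # training loop that may raise
```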

Thank you!

P.S. I wonder if some of the info I'm asking for would have shown up in 1.9's elastic version of the launcher.

cc @pietern @mrshenli @pritamdamania87 @zhaojuanmao @satgera @rohan-varma @gqchen @aazzolini @osalpekar @jiayisuse @agolynski @SciPioneer @H-Huang @mrzzd @cbalioglu @gcramer23

@bdhirsh added the oncall: distributed label Aug 12, 2021
@mrshenli added the module: c10d and enhancement labels Aug 15, 2021

cbalioglu (Contributor) commented:

Hi @stas00, thanks for the suggestion. As I mentioned a while ago, I am planning to write an RFC soon (hopefully this week) aiming to improve our current multiprocessing logic and failure handling. Just curious regarding this: do you use Slurm's I/O redirection flags to write your workers' stdout/stderr to a specific location? How do you retrieve this stack trace information from your training job right now?


stas00 commented Aug 16, 2021

We are currently logging all nodes to the same log file:

#SBATCH --output=%x-%j.out

I know I can add %N to create a separate log file per node (see the sketch below), but that would make monitoring much more difficult, as we have 64 to 256 nodes. Currently I just run tail -f on that single log file to watch the whole "battlefield".
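For reference, a rough sketch of the per-node variant mentioned above, assuming the workers are launched with srun inside the sbatch script (Slurm filename patterns: %x = job name, %j = job id, %N = node hostname; the job name, id, and script name are made up):

```
# one log file per node instead of a single combined file
srun --output=%x-%j-%N.out python train.py ...

# follow all per-node logs at once
tail -f myjob-1234567-*.out
```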

If you have a better strategy I'm all ears.

Thank you, @cbalioglu.


stas00 commented Sep 14, 2021

Pinging @cbalioglu, as a month has passed. Thank you!


stas00 commented Sep 30, 2021

What are the chances this can be added to 1.10, as it's really close now? We badly need this feature for huge multi-node trainings on HPC.

It's really about adding a short hostname + rank to the elastic logging when it dumps the error; are there any obstacles to adding this functionality? I think the rank is already there, so only the short hostname is needed.

Thank you!

cbalioglu (Contributor) commented:

cc @kiukchung @aivanou

kiukchung (Collaborator) commented:

@stas00 here's the PR (#66182) to add hostnames to the error summary. Note that you need to add the @torch.distributed.elastic.multiprocessing.errors.record annotation to your main function, otherwise the traceback will not be available (this is because there is no IPC exception handling built into Python).
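For clarity, a minimal sketch of what that annotation looks like in a training script (the script body itself is just a placeholder):

```python
# Annotate the entrypoint so that exceptions raised in the worker process are
# recorded and can be surfaced in the launcher's error summary.
from torch.distributed.elastic.multiprocessing.errors import record


@record
def main():
    ...  # training / script logic that may raise


if __name__ == "__main__":
    main()
```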

Here's the sample output:

============================================================
run_script_path FAILED
------------------------------------------------------------
Failures:
  <NO_OTHER_FAILURES>
------------------------------------------------------------
Root Cause (first observed failure):
[0]:
  time      : 2021-10-05_17:37:22
  host      : devvm4955.prn0.facebook.com
  rank      : 0 (local_rank: 0)
  exitcode  : 1 (pid: 3296201)
  error_file: /home/kiuk/tmp/elastic/none_3_lsytqe/attempt_0/0/error.json
  traceback :
  Traceback (most recent call last):
    File "/tmp/jetter.xr3_x6qq/torch/distributed/elastic/multiprocessing/errors/__init__.py", line 372, in wrapper
      return f(*args, **kwargs)
    File "main.py", line 28, in main
      raise RuntimeError(args.throws)
  RuntimeError: foobar

============================================================

Based on your use case (tailing all the logs into one console), I've also taken the liberty of moving the "Root Cause" section to the bottom so that the root-cause errors appear at the end of the logs.

kiukchung (Collaborator) commented:

Closing since we have a PR out there. Feel free to comment on the PR; this will make it into torch-1.10 if I merge the PR by the end of the week.

kiukchung pushed a commit to kiukchung/pytorch that referenced this issue Oct 7, 2021
Summary:
Pull Request resolved: pytorch#66182

closes pytorch#63174

Does a few things:

1. adds hostname to the error report
2. moves the "root cause" section to the end (presumably since the logs are being "tailed" we want the root cause to appear at the end)
3. moves redundant error info logging to debug
4. makes the border a max of 60 chars long and left-justifies the header

NOTE: YOU HAVE TO annotate your main function with torch.distributed.elastic.multiprocessing.errors.record, otherwise no traceback is printed (this is because Python exception propagation does NOT work out of the box for IPC, hence the extra record annotation).

Test Plan:
Sample

```
============================================================
run_script_path FAILED
------------------------------------------------------------
Failures:
  <NO_OTHER_FAILURES>
------------------------------------------------------------
Root Cause (first observed failure):
[0]:
  time      : 2021-10-05_17:37:22
  host      : devvm4955.prn0.facebook.com
  rank      : 0 (local_rank: 0)
  exitcode  : 1 (pid: 3296201)
  error_file: /home/kiuk/tmp/elastic/none_3_lsytqe/attempt_0/0/error.json
  traceback :
  Traceback (most recent call last):
    File "/tmp/jetter.xr3_x6qq/torch/distributed/elastic/multiprocessing/errors/__init__.py", line 372, in wrapper
      return f(*args, **kwargs)
    File "main.py", line 28, in main
      raise RuntimeError(args.throws)
  RuntimeError: foobar

============================================================
```

Reviewed By: cbalioglu, aivanou

Differential Revision: D31416492

fbshipit-source-id: 7490e1a90b8083fd38329f321cc09ab8b8713b26
facebook-github-bot pushed a commit that referenced this issue Oct 7, 2021 (same commit message as above)
kiukchung pushed a commit that referenced this issue Oct 14, 2021 (same commit message as above)
malfet pushed a commit that referenced this issue Oct 15, 2021 (same commit message as above)