Skip to content

Commit

Permalink
Fix the missing parameter error when running mp_imagenet with torchrun (pytorch#5729)
Browse files Browse the repository at this point in the history

* Fix the missing parameter error when running mp_imagenet with torchrun

* made it local rank
  • Loading branch information
vanbasten23 authored and mbzomowski committed Nov 16, 2023
1 parent 9c8108c commit e9ccdc2
Showing 1 changed file with 2 additions and 1 deletion.
3 changes: 2 additions & 1 deletion test/test_train_mp_imagenet.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,7 @@
 import torch_xla.distributed.parallel_loader as pl
 import torch_xla.debug.profiler as xp
 import torch_xla.utils.utils as xu
+import torch_xla.core.xla_env_vars as xenv
 import torch_xla.core.xla_model as xm
 import torch_xla.distributed.xla_multiprocessing as xmp
 import torch_xla.test.test_utils as test_utils
Expand Down Expand Up @@ -375,6 +376,6 @@ def _mp_fn(index, flags):

 if __name__ == '__main__':
   if dist.is_torchelastic_launched():
-    _mp_fn(FLAGS)
+    _mp_fn(xu.getenv_as(xenv.LOCAL_RANK, int), FLAGS)
   else:
     xmp.spawn(_mp_fn, args=(FLAGS,), nprocs=FLAGS.num_cores)

0 comments on commit e9ccdc2

Please sign in to comment.