diff --git a/test/test_train_mp_imagenet.py b/test/test_train_mp_imagenet.py index 13af80a2108..122e8513dbf 100644 --- a/test/test_train_mp_imagenet.py +++ b/test/test_train_mp_imagenet.py @@ -85,6 +85,7 @@ import torch_xla.distributed.parallel_loader as pl import torch_xla.debug.profiler as xp import torch_xla.utils.utils as xu +import torch_xla.core.xla_env_vars as xenv import torch_xla.core.xla_model as xm import torch_xla.distributed.xla_multiprocessing as xmp import torch_xla.test.test_utils as test_utils @@ -375,6 +376,6 @@ def _mp_fn(index, flags): if __name__ == '__main__': if dist.is_torchelastic_launched(): - _mp_fn(FLAGS) + _mp_fn(xu.getenv_as(xenv.LOCAL_RANK, int), FLAGS) else: xmp.spawn(_mp_fn, args=(FLAGS,), nprocs=FLAGS.num_cores)