From e9ccdc2310a1786b7ee8066c0c2edf534e5efec2 Mon Sep 17 00:00:00 2001
From: iefgnoix
Date: Wed, 25 Oct 2023 09:21:17 -0700
Subject: [PATCH] Fix the missing parameter error when running mp_imagenet with torchrun (#5729)

* Fix the missing parameter error when running mp_imagenet with torchrun

* made it local rank
---
 test/test_train_mp_imagenet.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/test/test_train_mp_imagenet.py b/test/test_train_mp_imagenet.py
index 13af80a2108..122e8513dbf 100644
--- a/test/test_train_mp_imagenet.py
+++ b/test/test_train_mp_imagenet.py
@@ -85,6 +85,7 @@
 import torch_xla.distributed.parallel_loader as pl
 import torch_xla.debug.profiler as xp
 import torch_xla.utils.utils as xu
+import torch_xla.core.xla_env_vars as xenv
 import torch_xla.core.xla_model as xm
 import torch_xla.distributed.xla_multiprocessing as xmp
 import torch_xla.test.test_utils as test_utils
@@ -375,6 +376,6 @@ def _mp_fn(index, flags):
 
 if __name__ == '__main__':
   if dist.is_torchelastic_launched():
-    _mp_fn(FLAGS)
+    _mp_fn(xu.getenv_as(xenv.LOCAL_RANK, int), FLAGS)
   else:
     xmp.spawn(_mp_fn, args=(FLAGS,), nprocs=FLAGS.num_cores)
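
Note on the change: when launched via torchrun (torchelastic), the script runs
once per worker process rather than through xmp.spawn, so the previous call
`_mp_fn(FLAGS)` bound FLAGS to the `index` parameter and left `flags` missing,
producing the error named in the subject. torchrun exports each worker's local
rank in the LOCAL_RANK environment variable, and the patch reads it through
`xu.getenv_as(xenv.LOCAL_RANK, int)`.

A minimal sketch of what that lookup does, assuming (as in torch_xla around
this commit) that `xenv.LOCAL_RANK` names the string 'LOCAL_RANK' and that
`getenv_as(name, type, defval=None)` coerces the variable into the given type:

    import os

    LOCAL_RANK = 'LOCAL_RANK'  # assumed value of xenv.LOCAL_RANK

    def getenv_as(name, type, defval=None):
      # torchrun exports LOCAL_RANK=0..nproc_per_node-1 for each worker it
      # launches; parse it into the requested type, or fall back to defval.
      value = os.environ.get(name)
      return defval if value is None else type(value)

    local_rank = getenv_as(LOCAL_RANK, int)  # e.g. 0 on the first worker

With the fix applied, a launch such as the following (the --fake_data flag is
illustrative) should reach `_mp_fn` with a proper integer rank as `index`:

    torchrun --nproc_per_node=4 test/test_train_mp_imagenet.py --fake_data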