From f228a658055b733c16ec0f81c89cb2aa68d6d79e Mon Sep 17 00:00:00 2001
From: Will Cromar
Date: Thu, 26 Oct 2023 12:19:59 -0700
Subject: [PATCH] Merge `--pjrt_distributed` flag with `--ddp` flag. (#5732)

---
 test/test_train_mp_imagenet.py | 10 +---------
 test/test_train_mp_mnist.py    |  8 +-------
 2 files changed, 2 insertions(+), 16 deletions(-)

diff --git a/test/test_train_mp_imagenet.py b/test/test_train_mp_imagenet.py
index 122e8513dbf..65a359ca9be 100644
--- a/test/test_train_mp_imagenet.py
+++ b/test/test_train_mp_imagenet.py
@@ -31,11 +31,6 @@
     '--ddp': {
         'action': 'store_true',
     },
-    # Use xla:// init_method instead of env:// for `torch.distributed`.
-    # Required for DDP on TPU v2/v3 when using PJRT.
-    '--pjrt_distributed': {
-        'action': 'store_true',
-    },
     '--profile': {
         'action': 'store_true',
     },
@@ -180,11 +175,8 @@ def _train_update(device, step, loss, tracker, epoch, writer):
 
 
 def train_imagenet():
-  if FLAGS.pjrt_distributed:
+  if FLAGS.ddp:
     dist.init_process_group('xla', init_method='xla://')
-  elif FLAGS.ddp:
-    dist.init_process_group(
-        'xla', world_size=xm.xrt_world_size(), rank=xm.get_ordinal())
 
   print('==> Preparing data..')
   img_dim = get_model_property('img_dim')
diff --git a/test/test_train_mp_mnist.py b/test/test_train_mp_mnist.py
index 22253fbea73..8ae434acfcb 100644
--- a/test/test_train_mp_mnist.py
+++ b/test/test_train_mp_mnist.py
@@ -5,9 +5,6 @@
     '--ddp': {
         'action': 'store_true',
     },
-    '--pjrt_distributed': {
-        'action': 'store_true',
-    },
 }
 
 FLAGS = args_parse.parse_common_options(
@@ -76,11 +73,8 @@ def _train_update(device, step, loss, tracker, epoch, writer):
 
 
 def train_mnist(flags, **kwargs):
-  if flags.pjrt_distributed:
+  if flags.ddp:
     dist.init_process_group('xla', init_method='xla://')
-  elif flags.ddp:
-    dist.init_process_group(
-        'xla', world_size=xm.xrt_world_size(), rank=xm.get_ordinal())
 
   torch.manual_seed(1)
 