Skip to content

Commit

Permalink
Merge --pjrt_distributed flag with --ddp flag. (#5732)
Browse files Browse the repository at this point in the history
  • Loading branch information
will-cromar authored and bhavya01 committed Apr 22, 2024
1 parent 49db89d commit f228a65
Show file tree
Hide file tree
Showing 2 changed files with 2 additions and 16 deletions.
10 changes: 1 addition & 9 deletions test/test_train_mp_imagenet.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,11 +31,6 @@
'--ddp': {
'action': 'store_true',
},
- # Use xla:// init_method instead of env:// for `torch.distributed`.
- # Required for DDP on TPU v2/v3 when using PJRT.
- '--pjrt_distributed': {
- 'action': 'store_true',
- },
'--profile': {
'action': 'store_true',
},
Expand Down Expand Up @@ -180,11 +175,8 @@ def _train_update(device, step, loss, tracker, epoch, writer):


def train_imagenet():
- if FLAGS.pjrt_distributed:
+ if FLAGS.ddp:
  dist.init_process_group('xla', init_method='xla://')
- elif FLAGS.ddp:
- dist.init_process_group(
- 'xla', world_size=xm.xrt_world_size(), rank=xm.get_ordinal())

print('==> Preparing data..')
img_dim = get_model_property('img_dim')
Expand Down
8 changes: 1 addition & 7 deletions test/test_train_mp_mnist.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,6 @@
'--ddp': {
'action': 'store_true',
},
- '--pjrt_distributed': {
- 'action': 'store_true',
- },
}

FLAGS = args_parse.parse_common_options(
Expand Down Expand Up @@ -76,11 +73,8 @@ def _train_update(device, step, loss, tracker, epoch, writer):


def train_mnist(flags, **kwargs):
- if flags.pjrt_distributed:
+ if flags.ddp:
  dist.init_process_group('xla', init_method='xla://')
- elif flags.ddp:
- dist.init_process_group(
- 'xla', world_size=xm.xrt_world_size(), rank=xm.get_ordinal())

torch.manual_seed(1)

Expand Down

0 comments on commit f228a65

Please sign in to comment.