diff --git a/README.md b/README.md index bcc989727b6..b2a5110b5bb 100644 --- a/README.md +++ b/README.md @@ -78,7 +78,7 @@ If you're using `DistributedDataParallel`, make the following changes: +import torch_xla.distributed.xla_multiprocessing as xmp +import torch_xla.distributed.xla_backend - def _mp_fn(rank, world_size): + def _mp_fn(rank): ... - os.environ['MASTER_ADDR'] = 'localhost'