From ac2b0f0f066967b5896a73591ce7aa25dfb75306 Mon Sep 17 00:00:00 2001
From: William Falcon
Date: Thu, 1 Oct 2020 09:25:33 -0400
Subject: [PATCH] ref: continue #3733 (#3767)

* ref: #3733 part 2

* ref: #3733 part 2
---
 pytorch_lightning/trainer/data_loading.py  | 19 ++--------
 pytorch_lightning/trainer/trainer.py       |  5 +--
 pytorch_lightning/trainer/training_loop.py |  9 +----
 tests/backends/__init__.py                 |  0
 tests/backends/ddp_model.py                | 43 ++++++++++++++++++++++
 5 files changed, 51 insertions(+), 25 deletions(-)
 create mode 100644 tests/backends/__init__.py
 create mode 100644 tests/backends/ddp_model.py

diff --git a/pytorch_lightning/trainer/data_loading.py b/pytorch_lightning/trainer/data_loading.py
index 35c1a12d4d9d5..6411c2b61fda7 100644
--- a/pytorch_lightning/trainer/data_loading.py
+++ b/pytorch_lightning/trainer/data_loading.py
@@ -17,10 +17,10 @@
 from abc import ABC, abstractmethod
 from typing import Union, List, Tuple, Callable, Optional
 
-import torch.distributed as torch_distrib
 from torch.utils.data import DataLoader, RandomSampler, SequentialSampler
 from torch.utils.data.distributed import DistributedSampler
 
+from pytorch_lightning.accelerators.base_backend import Accelerator
 from pytorch_lightning.accelerators.base_backend import BackendType
 from pytorch_lightning.core import LightningModule
 from pytorch_lightning.utilities import rank_zero_warn
@@ -75,6 +75,7 @@ class TrainerDataLoadingMixin(ABC):
     limit_val_batches: Union[int, float]
     limit_test_batches: Union[int, float]
     replace_sampler_ddp: bool
+    accelerator_backend: Accelerator
     num_nodes: int
     num_processes: int
     distributed_backend: Optional[str]
@@ -337,18 +338,6 @@ def request_dataloader(self, dataloader_fx: Callable) -> DataLoader:
         """
         dataloader = dataloader_fx()
 
-        # get the function we'll use to get data
-        if self.use_ddp or self.use_ddp2:
-            # all processes wait until data download has happened
-            torch_distrib.barrier()
-
-        # data download/load on TPU
-        elif self.use_tpu and XLA_AVAILABLE:
-            # all processes wait until data download has happened
-            torch_xla.core.xla_model.rendezvous('pl.TrainerDataLoadingMixin.get_dataloaders')
-
-        elif self.use_horovod:
-            # all processes wait until data download has happened
-            hvd.join()
-
+        if self.accelerator_backend is not None:
+            self.accelerator_backend.barrier('get_dataloaders')
         return dataloader
diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py
index 138bdf92c3eab..446591a436c39 100644
--- a/pytorch_lightning/trainer/trainer.py
+++ b/pytorch_lightning/trainer/trainer.py
@@ -696,9 +696,6 @@ def test(
         # --------------------
         self.verbose_test = verbose
 
-        if self.global_rank != 0:
-            return
-
         # If you supply a datamodule you can't supply train_dataloader or val_dataloaders
         if test_dataloaders and datamodule:
             raise MisconfigurationException(
@@ -738,6 +735,8 @@ def __test_using_best_weights(self, ckpt_path, test_dataloaders):
                     f'specify a path for a checkpoint .test(ckpt_path=PATH)'
                 )
                 return {}
+            if self.accelerator_backend is not None:
+                self.accelerator_backend.barrier()
 
             ckpt = pl_load(ckpt_path, map_location=lambda storage, loc: storage)
             model.load_state_dict(ckpt['state_dict'])
diff --git a/pytorch_lightning/trainer/training_loop.py b/pytorch_lightning/trainer/training_loop.py
index e670c01f04156..99318b9f34324 100644
--- a/pytorch_lightning/trainer/training_loop.py
+++ b/pytorch_lightning/trainer/training_loop.py
@@ -181,13 +181,8 @@ def on_train_end(self):
         if self.trainer.global_rank == 0:
             self.trainer.profiler.describe()
 
-        if self.trainer.global_rank == 0:
-            for proc in self.trainer.interactive_ddp_procs:
-                subprocess.Popen.kill(proc)
-
-        # clean up dist group
-        if self.trainer.use_ddp or self.trainer.use_ddp2:
-            torch_distrib.destroy_process_group()
+        # give accelerators a chance to finish
+        self.trainer.accelerator_backend.on_train_end()
 
         # clear mem
         if self.trainer.on_gpu:
diff --git a/tests/backends/__init__.py b/tests/backends/__init__.py
new file mode 100644
index 0000000000000..e69de29bb2d1d
diff --git a/tests/backends/ddp_model.py b/tests/backends/ddp_model.py
new file mode 100644
index 0000000000000..9f75415fe686d
--- /dev/null
+++ b/tests/backends/ddp_model.py
@@ -0,0 +1,43 @@
+"""
+Runs either `.fit()` or `.test()` on a single node across multiple gpus.
+"""
+from argparse import ArgumentParser
+
+from pytorch_lightning import Trainer, seed_everything
+from tests.base import EvalModelTemplate
+import os
+import torch
+
+
+def main():
+    seed_everything(1234)
+    parser = ArgumentParser(add_help=False)
+    parser = Trainer.add_argparse_args(parser)
+    parser.add_argument('--trainer_method', default='fit')
+    parser.add_argument('--tmpdir')
+    parser.set_defaults(gpus=2)
+    parser.set_defaults(distributed_backend="ddp")
+    args = parser.parse_args()
+
+    model = EvalModelTemplate()
+    trainer = Trainer.from_argparse_args(args)
+
+    result = {}
+    if args.trainer_method == 'fit':
+        trainer.fit(model)
+        result = {'status': 'complete', 'method': args.trainer_method, 'result': None}
+    if args.trainer_method == 'test':
+        result = trainer.test(model)
+        result = {'status': 'complete', 'method': args.trainer_method, 'result': result}
+    if args.trainer_method == 'fit_test':
+        trainer.fit(model)
+        result = trainer.test(model)
+        result = {'status': 'complete', 'method': args.trainer_method, 'result': result}
+
+    if len(result) > 0:
+        file_path = os.path.join(args.tmpdir, 'ddp.result')
+        torch.save(result, file_path)
+
+
+if __name__ == '__main__':
+    main()
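
Note on usage (not part of the patch above): tests/backends/ddp_model.py is designed to run as its own interpreter process so the DDP children never share state with the test runner, and it reports back by torch.save()-ing a small dict to <tmpdir>/ddp.result. The sketch below shows one way a test could drive it. The helper name run_ddp_script, the PYTHONPATH handling, and the extra --max_epochs flag are illustrative assumptions; only --trainer_method, --tmpdir, and the ddp.result file come from the script itself.

# Hypothetical driver for tests/backends/ddp_model.py (illustrative sketch, not from this patch).
import os
import subprocess
import sys

import torch


def run_ddp_script(tmpdir, trainer_method='fit_test'):
    """Run the DDP script in a fresh interpreter and return the result it saved."""
    cmd = [
        sys.executable, 'tests/backends/ddp_model.py',
        '--trainer_method', trainer_method,
        '--tmpdir', str(tmpdir),
        '--max_epochs', '1',          # any flag added by Trainer.add_argparse_args is accepted
    ]
    env = os.environ.copy()
    env['PYTHONPATH'] = os.getcwd()   # assumes we run from the repo root so `tests.base` imports
    subprocess.run(cmd, env=env, check=True)

    # the script writes {'status': ..., 'method': ..., 'result': ...} to <tmpdir>/ddp.result
    result = torch.load(os.path.join(str(tmpdir), 'ddp.result'))
    assert result['status'] == 'complete'
    return result['result']

Writing the result to a file rather than returning it is what lets the parent test assert on work done inside the spawned DDP processes.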