From 5fd01b0e68a6087908ac0bcefd4edaeddfb0e248 Mon Sep 17 00:00:00 2001 From: William Falcon Date: Sat, 13 Jun 2020 12:00:14 -0400 Subject: [PATCH] Finish Ananthsub patch 1 (enable prepare_data from correct processes). clarify local vs global rank (#2166) * [trainer] Call prepare_data once per node in DDP/DDP2 training * refactored DDP routes * renamed proc_rank to local_rank * renamed proc_rank to local_rank * renamed proc_rank to local_rank * renamed proc_rank to local_rank * renamed proc_rank to local_rank * renamed proc_rank to local_rank * renamed proc_rank to local_rank * renamed proc_rank to local_rank * renamed proc_rank to local_rank * renamed proc_rank to local_rank * renamed proc_rank to local_rank * renamed proc_rank to local_rank * spawn message * spawn message * spawn message * fixes * fixes * fixes * fixes * fixes * Update trainer.py Co-authored-by: ananthsub --- .../callbacks/model_checkpoint.py | 2 +- pytorch_lightning/core/lightning.py | 10 +++--- pytorch_lightning/trainer/__init__.py | 13 +++++++ pytorch_lightning/trainer/data_loading.py | 4 +-- .../trainer/distrib_data_parallel.py | 35 ++++++++++++++----- pytorch_lightning/trainer/distrib_parts.py | 10 +++--- pytorch_lightning/trainer/evaluation_loop.py | 10 ++---- pytorch_lightning/trainer/logging.py | 8 ++--- pytorch_lightning/trainer/trainer.py | 30 ++++++++++++---- pytorch_lightning/trainer/training_io.py | 6 ++-- pytorch_lightning/trainer/training_loop.py | 9 +++-- 11 files changed, 91 insertions(+), 46 deletions(-) diff --git a/pytorch_lightning/callbacks/model_checkpoint.py b/pytorch_lightning/callbacks/model_checkpoint.py index 03080fbb2b7b6..672670da72b55 100644 --- a/pytorch_lightning/callbacks/model_checkpoint.py +++ b/pytorch_lightning/callbacks/model_checkpoint.py @@ -229,7 +229,7 @@ def format_checkpoint_name(self, epoch, metrics, ver=None): @rank_zero_only def on_validation_end(self, trainer, pl_module): # only run on main process - if trainer.proc_rank != 0: + if trainer.global_rank != 0: return metrics = trainer.callback_metrics diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py index ae94b080f4863..af7527f550d0e 100644 --- a/pytorch_lightning/core/lightning.py +++ b/pytorch_lightning/core/lightning.py @@ -99,7 +99,7 @@ def forward(self, x): self.print(x, 'in forward') """ - if self.trainer.proc_rank == 0: + if self.trainer.is_global_zero: print(*args, **kwargs) @abstractmethod @@ -922,7 +922,7 @@ def _init_slurm_connection(self) -> None: def init_ddp_connection( self, - proc_rank: int, + global_rank: int, world_size: int, is_slurm_managing_tasks: bool = True ) -> None: @@ -933,7 +933,7 @@ def init_ddp_connection( for SLURM managed cluster. Args: - proc_rank: The current process rank within the node. + global_rank: The global process idx. world_size: Number of GPUs being use across all nodes. (num_nodes * num_gpus). is_slurm_managing_tasks: is cluster managed by SLURM. @@ -956,8 +956,8 @@ def init_ddp_connection( f"is not equal to the computed world size ({world_size}). Ignored.") torch_backend = "nccl" if self.trainer.on_gpu else "gloo" - log.info(f"initializing ddp: LOCAL_RANK: {proc_rank}/{world_size - 1} WORLD_SIZE:{world_size}") - torch_distrib.init_process_group(torch_backend, rank=proc_rank, world_size=world_size) + log.info(f"initializing ddp: GLOBAL_RANK: {global_rank}, MEMBER: {global_rank+1}/{world_size}") + torch_distrib.init_process_group(torch_backend, rank=global_rank, world_size=world_size) def configure_apex( self, diff --git a/pytorch_lightning/trainer/__init__.py b/pytorch_lightning/trainer/__init__.py index b9737b0557b75..24787c61061d7 100644 --- a/pytorch_lightning/trainer/__init__.py +++ b/pytorch_lightning/trainer/__init__.py @@ -585,6 +585,19 @@ def on_train_end(self, trainer, pl_module): --env=XLA_USE_BF16=1 -- python your_trainer_file.py +prepare_data_per_node +^^^^^^^^^^^^^^^^^^^^^ +If True will call `prepare_data()` on LOCAL_RANK=0 for every node. +If False will only call from NODE_RANK=0, LOCAL_RANK=0 + +Example:: + + # default + Trainer(prepare_data_per_node=True) + + # use only NODE_RANK=0, LOCAL_RANK=0 + Trainer(prepare_data_per_node=False) + tpu_cores ^^^^^^^^^ - How many TPU cores to train on (1 or 8). diff --git a/pytorch_lightning/trainer/data_loading.py b/pytorch_lightning/trainer/data_loading.py index 71983addb9e4f..380e8257c1b43 100644 --- a/pytorch_lightning/trainer/data_loading.py +++ b/pytorch_lightning/trainer/data_loading.py @@ -57,7 +57,7 @@ class TrainerDataLoadingMixin(ABC): # this is just a summary on variables used in this abstract class, # the proper values/initialisation should be done in child class - proc_rank: int + global_rank: int use_ddp: bool use_ddp2: bool use_horovod: bool @@ -160,7 +160,7 @@ def _get_distributed_sampler(self, dataloader): 'ddp_cpu': self.num_processes * self.num_nodes } assert self.distributed_backend is not None - kwargs = dict(num_replicas=world_size[self.distributed_backend], rank=self.proc_rank) + kwargs = dict(num_replicas=world_size[self.distributed_backend], rank=self.global_rank) sampler = DistributedSampler(dataloader.dataset, **kwargs) return sampler diff --git a/pytorch_lightning/trainer/distrib_data_parallel.py b/pytorch_lightning/trainer/distrib_data_parallel.py index 316c16fa531a0..d2b190ca06279 100644 --- a/pytorch_lightning/trainer/distrib_data_parallel.py +++ b/pytorch_lightning/trainer/distrib_data_parallel.py @@ -165,6 +165,10 @@ class TrainerDDPMixin(ABC): num_nodes: int node_rank: int + @property + def is_global_zero(self) -> int: + """Warning: this is just empty shell for code implemented in other class.""" + @property @abstractmethod def num_gpus(self) -> int: @@ -300,6 +304,13 @@ def configure_slurm_ddp(self, num_gpu_nodes): if self.is_slurm_managing_tasks: rank_zero_info('Multi-processing is handled by Slurm.') + def determine_local_rank(self): + if self.is_slurm_managing_tasks: + return int(os.environ['SLURM_LOCALID']) + + else: + return int(os.environ.get('LOCAL_RANK', 0)) + def determine_ddp_node_rank(self): if self.is_slurm_managing_tasks: return int(os.environ['SLURM_NODEID']) @@ -423,21 +434,30 @@ def ddp_train(self, process_idx, model, is_master=False, proc_offset=0): # determine which process we are and world size if self.use_ddp: - self.proc_rank = self.node_rank * self.num_processes + process_idx + self.local_rank = process_idx + self.global_rank = self.node_rank * self.num_processes + process_idx self.world_size = self.num_nodes * self.num_processes elif self.use_ddp2: - self.proc_rank = self.node_rank + self.local_rank = self.node_rank + self.global_rank = self.node_rank self.world_size = self.num_nodes # set warning rank - rank_zero_only.rank = self.proc_rank + rank_zero_only.rank = self.global_rank # set up server using proc 0's ip address # try to init for 20 times at max in case ports are taken # where to store ip_table model.trainer = self - model.init_ddp_connection(self.proc_rank, self.world_size, self.is_slurm_managing_tasks) + model.init_ddp_connection(self.global_rank, self.world_size, self.is_slurm_managing_tasks) + + # on world_size=0 let everyone know training is starting + if self.is_global_zero: + log.info('-' * 100) + log.info(f'distributed_backend={self.distributed_backend}') + log.info(f'All DDP processes registered. Starting ddp with {self.world_size} processes') + log.info('-' * 100) # CHOOSE OPTIMIZER # allow for lr schedulers as well @@ -450,8 +470,7 @@ def ddp_train(self, process_idx, model, is_master=False, proc_offset=0): if is_master: # source of truth is cuda for gpu idx gpus = os.environ['CUDA_VISIBLE_DEVICES'].split(',') - local_rank = int(os.environ['LOCAL_RANK']) - gpu_idx = int(gpus[local_rank]) + gpu_idx = int(gpus[self.local_rank]) self.root_gpu = gpu_idx torch.cuda.set_device(self.root_gpu) @@ -488,7 +507,7 @@ def save_spawn_weights(self, model): :param model: :return: """ - if self.proc_rank == 0: + if self.is_global_zero: path = os.path.join(self.default_root_dir, '__temp_weight_ddp_end.ckpt') self.save_checkpoint(path) @@ -502,7 +521,7 @@ def load_spawn_weights(self, original_model): loaded_model = original_model - if self.proc_rank == 0: + if self.is_global_zero: # load weights saved in ddp path = os.path.join(self.default_root_dir, '__temp_weight_ddp_end.ckpt') loaded_model = original_model.__class__.load_from_checkpoint(path) diff --git a/pytorch_lightning/trainer/distrib_parts.py b/pytorch_lightning/trainer/distrib_parts.py index 3e09db101f047..8807fd1cfc879 100644 --- a/pytorch_lightning/trainer/distrib_parts.py +++ b/pytorch_lightning/trainer/distrib_parts.py @@ -57,7 +57,7 @@ class TrainerDPMixin(ABC): root_gpu: ... amp_level: str precision: ... - proc_rank: int + global_rank: int tpu_local_core_rank: int tpu_global_core_rank: int use_tpu: bool @@ -183,8 +183,8 @@ def tpu_train(self, tpu_core_idx, model): if self.tpu_global_core_rank != 0 and self.progress_bar_callback is not None: self.progress_bar_callback.disable() - self.proc_rank = self.tpu_local_core_rank - rank_zero_only.rank = self.proc_rank + self.global_rank = self.tpu_local_core_rank + rank_zero_only.rank = self.global_rank # CHOOSE OPTIMIZER # allow for lr schedulers as well @@ -289,8 +289,8 @@ def filter_named_parameters(model, optimizer): # Update logger rank info from Horovod to avoid race conditions from different ranks # creating directories / writing files in the same locations. - self.proc_rank = hvd.rank() - rank_zero_only.rank = self.proc_rank + self.global_rank = hvd.rank() + rank_zero_only.rank = self.global_rank with ExitStack() as stack: for optimizer in self.optimizers: diff --git a/pytorch_lightning/trainer/evaluation_loop.py b/pytorch_lightning/trainer/evaluation_loop.py index c15ddec68128f..b6c490d6eb4ec 100644 --- a/pytorch_lightning/trainer/evaluation_loop.py +++ b/pytorch_lightning/trainer/evaluation_loop.py @@ -130,7 +130,6 @@ from torch.utils.data import DataLoader from pytorch_lightning.core.lightning import LightningModule -from pytorch_lightning.profiler.profilers import BaseProfiler from pytorch_lightning.overrides.data_parallel import LightningDistributedDataParallel, LightningDataParallel from pytorch_lightning.utilities.exceptions import MisconfigurationException from pytorch_lightning.utilities import rank_zero_warn @@ -160,8 +159,6 @@ class TrainerEvaluationLoopMixin(ABC): use_dp: bool use_ddp2: bool use_horovod: bool - use_amp: bool - use_native_amp: bool single_gpu: bool data_parallel_device_ids: ... model: LightningModule @@ -170,15 +167,14 @@ class TrainerEvaluationLoopMixin(ABC): fast_dev_run: ... process_output: ... progress_bar_dict: ... - proc_rank: int + global_rank: int current_epoch: int callback_metrics: ... test_dataloaders: DataLoader val_dataloaders: DataLoader use_tpu: bool reload_dataloaders_every_epoch: ... - tpu_id: Optional[int] - profiler: BaseProfiler + tpu_id: int # Callback system on_validation_batch_start: Callable @@ -379,7 +375,7 @@ def run_evaluation(self, test_mode: bool = False): self.add_progress_bar_metrics(prog_bar_metrics) # log results of test - if test_mode and self.proc_rank == 0: + if test_mode and self.is_global_zero: print('-' * 80) print('TEST RESULTS') pprint(callback_metrics) diff --git a/pytorch_lightning/trainer/logging.py b/pytorch_lightning/trainer/logging.py index f1f93853cb223..5349849e09b89 100644 --- a/pytorch_lightning/trainer/logging.py +++ b/pytorch_lightning/trainer/logging.py @@ -1,5 +1,5 @@ from abc import ABC -from typing import Iterable, Optional +from typing import Union, Iterable import torch @@ -15,10 +15,10 @@ class TrainerLoggingMixin(ABC): current_epoch: int on_gpu: bool log_gpu_memory: ... - logger: Optional[LightningLoggerBase] + logger: Union[LightningLoggerBase, bool] progress_bar_metrics: ... global_step: int - proc_rank: int + global_rank: int use_dp: bool use_ddp2: bool default_root_dir: str @@ -69,7 +69,7 @@ def log_metrics(self, metrics, grad_norm_dic, step=None): scalar_metrics['epoch'] = self.current_epoch step = step if step is not None else self.global_step # log actual metrics - if self.proc_rank == 0 and self.logger is not None: + if self.is_global_zero and self.logger is not None: self.logger.agg_and_log_metrics(scalar_metrics, step=step) self.logger.save() diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index 9aa97b102eb9d..b6cdbb0cf130d 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -123,6 +123,7 @@ def __init__( replace_sampler_ddp: bool = True, terminate_on_nan: bool = False, auto_scale_batch_size: Union[str, bool] = False, + prepare_data_per_node: bool = True, amp_level: str = 'O1', # backward compatible, todo: remove in v1.0.0 num_tpu_cores: Optional[int] = None, # backward compatible, todo: remove in v0.9.0 use_amp=None, # backward compatible, todo: remove in v0.9.0 @@ -282,6 +283,9 @@ def __init__( The result will be stored in self.batch_size in the LightningModule. Additionally, can be set to either `power` that estimates the batch size through a power search or `binsearch` that estimates the batch size through a binary search. + + prepare_data_per_node: If True, each LOCAL_RANK=0 will call prepare data. + Otherwise only NODE_RANK=0, LOCAL_RANK=0 will prepare data """ super().__init__() @@ -293,6 +297,7 @@ def __init__( os.environ["HOROVOD_FUSION_THRESHOLD"] = str(0) # Init callbacks + self.prepare_data_per_node = prepare_data_per_node self.callbacks = callbacks or [] self.on_init_start() @@ -439,11 +444,12 @@ def __init__( self.init_tpu() # init flags for SLURM+ddp to work - self.proc_rank = 0 self.world_size = 1 self.interactive_ddp_procs = [] self.configure_slurm_ddp(self.num_nodes) self.node_rank = self.determine_ddp_node_rank() + self.local_rank = self.determine_local_rank() + self.global_rank = 0 # nvidia setup self.set_nvidia_flags(self.is_slurm_managing_tasks, self.data_parallel_device_ids) @@ -481,6 +487,10 @@ def __init__( # Callback system self.on_init_end() + @property + def is_global_zero(self): + return self.global_rank == 0 + @property def slurm_job_id(self) -> Optional[int]: try: @@ -532,6 +542,7 @@ def get_init_arguments_and_types(cls) -> List[Tuple[str, Tuple, Any]]: ('max_epochs', (,), 1000), ... ('precision', (,), 32), + ('prepare_data_per_node', (,), True), ('print_nan_grads', (,), False), ('process_position', (,), 0), ('profiler', @@ -773,10 +784,9 @@ def fit( # check that model is configured correctly self.check_model_configuration(model) - # download the data and do whatever transforms we need - # do before any spawn calls so that the model can assign properties - # only on proc 0 because no spawn has happened yet - if not self._is_data_prepared: + # on multi-gpu jobs we only want to manipulate (download, etc) on node_rank=0, local_rank=0 + # or in the case where each node needs to do its own manipulation in which case just local_rank=0 + if self.can_prepare_data(): model.prepare_data() self._is_data_prepared = True @@ -801,6 +811,7 @@ def fit( # torchelastic or general non_slurm ddp2 elif 'WORLD_SIZE' in os.environ and ('GROUP_RANK' in os.environ or 'NODE_RANK' in os.environ): task = int(os.environ['LOCAL_RANK']) + self.ddp_train(task, model) elif self.use_ddp: if self.is_slurm_managing_tasks: @@ -872,6 +883,13 @@ def fit( # used for testing or when we need to know that training succeeded return 1 + def can_prepare_data(self): + if self.prepare_data_per_node: + return self.local_rank == 0 + + else: + return self.node_rank == 0 and self.local_rank == 0 + def __attach_dataloaders(self, model, train_dataloader=None, val_dataloaders=None, test_dataloaders=None): # when dataloader is passed via fit, patch the train_dataloader # functions to overwrite with these implementations @@ -928,7 +946,7 @@ def run_pretrain_routine(self, model: LightningModule): # print model summary # TODO: remove self.testing condition because model.summarize() is wiping out the weights - if self.proc_rank == 0 and self.weights_summary is not None and not self.testing: + if self.is_global_zero and self.weights_summary is not None and not self.testing: if self.weights_summary in ['full', 'top']: ref_model.summarize(mode=self.weights_summary) else: diff --git a/pytorch_lightning/trainer/training_io.py b/pytorch_lightning/trainer/training_io.py index d19dd8d4019e3..6f4e85d5b28e4 100644 --- a/pytorch_lightning/trainer/training_io.py +++ b/pytorch_lightning/trainer/training_io.py @@ -132,7 +132,7 @@ class TrainerIOMixin(ABC): use_ddp2: bool use_horovod: bool checkpoint_callback: ... - proc_rank: int + global_rank: int weights_save_path: str logger: LightningLoggerBase early_stop_callback: ... @@ -213,7 +213,7 @@ def register_slurm_signal_handlers(self): signal.signal(signal.SIGTERM, self.term_handler) def sig_handler(self, signum, frame): # pragma: no-cover - if self.proc_rank == 0: + if self.is_global_zero: # save weights log.info('handling SIGUSR1') self.hpc_save(self.weights_save_path, self.logger) @@ -262,7 +262,7 @@ def _atomic_save(self, checkpoint, filepath: str): def save_checkpoint(self, filepath, weights_only: bool = False): checkpoint = self.dump_checkpoint(weights_only) - if self.proc_rank == 0: + if self.is_global_zero: # do the actual save try: self._atomic_save(checkpoint, filepath) diff --git a/pytorch_lightning/trainer/training_loop.py b/pytorch_lightning/trainer/training_loop.py index c77bd5bafd86d..dac73c6747bac 100644 --- a/pytorch_lightning/trainer/training_loop.py +++ b/pytorch_lightning/trainer/training_loop.py @@ -145,7 +145,7 @@ def training_step(self, batch, batch_idx): import signal from abc import ABC, abstractmethod from typing import Callable -from typing import Union, List, Optional +from typing import Union, List import numpy as np import torch @@ -214,7 +214,7 @@ class TrainerTrainLoopMixin(ABC): global_step: int testing: bool log_save_interval: float - proc_rank: int + global_rank: int row_log_interval: float truncated_bptt_steps: ... optimizers: ... @@ -236,8 +236,7 @@ class TrainerTrainLoopMixin(ABC): total_batch_idx: int checkpoint_callback: ... terminate_on_nan: bool - tpu_id: Optional[int] - interactive_ddp_procs: List + tpu_id: int # Callback system callbacks: List[Callback] @@ -481,7 +480,7 @@ def run_training_epoch(self): # when logs should be saved should_save_log = (batch_idx + 1) % self.log_save_interval == 0 or early_stop_epoch if should_save_log or self.fast_dev_run: - if self.proc_rank == 0 and self.logger is not None: + if self.is_global_zero and self.logger is not None: self.logger.save() # when metrics should be logged