enable prepare_data from correct processes - clarify local vs global rank #2166

Merged · 23 commits · Jun 13, 2020
2 changes: 1 addition & 1 deletion pytorch_lightning/callbacks/model_checkpoint.py
@@ -228,7 +228,7 @@ def format_checkpoint_name(self, epoch, metrics, ver=None):
@rank_zero_only
def on_validation_end(self, trainer, pl_module):
# only run on main process
if trainer.proc_rank != 0:
if trainer.global_rank != 0:
return

metrics = trainer.callback_metrics
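The hunk above keeps both guards: the `@rank_zero_only` decorator and the explicit `trainer.global_rank != 0` check. For orientation, here is a minimal sketch of how such a decorator can work, assuming a module-level `rank` attribute that the trainer assigns once the global rank is known (as the `rank_zero_only.rank = self.global_rank` lines later in this diff do); this is an illustration, not the library's exact implementation.

```python
from functools import wraps


def rank_zero_only(fn):
    """Call the wrapped function only on the process with global rank 0."""
    @wraps(fn)
    def wrapped_fn(*args, **kwargs):
        if rank_zero_only.rank == 0:
            return fn(*args, **kwargs)
    return wrapped_fn


# the trainer overwrites this with the real global rank during setup
rank_zero_only.rank = 0
```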
10 changes: 5 additions & 5 deletions pytorch_lightning/core/lightning.py
@@ -99,7 +99,7 @@ def forward(self, x):
self.print(x, 'in forward')

"""
if self.trainer.proc_rank == 0:
if self.trainer.is_global_zero:
print(*args, **kwargs)

@abstractmethod
@@ -922,7 +922,7 @@ def _init_slurm_connection(self) -> None:

def init_ddp_connection(
self,
proc_rank: int,
global_rank: int,
world_size: int,
is_slurm_managing_tasks: bool = True
) -> None:
@@ -933,7 +933,7 @@ def init_ddp_connection(
for SLURM managed cluster.

Args:
proc_rank: The current process rank within the node.
global_rank: The global process idx.
world_size: Number of GPUs being use across all nodes. (num_nodes * num_gpus).
is_slurm_managing_tasks: is cluster managed by SLURM.

@@ -956,8 +956,8 @@ def init_ddp_connection(
f"is not equal to the computed world size ({world_size}). Ignored.")

torch_backend = "nccl" if self.trainer.on_gpu else "gloo"
log.info(f"initializing ddp: LOCAL_RANK: {proc_rank}/{world_size - 1} WORLD_SIZE:{world_size}")
torch_distrib.init_process_group(torch_backend, rank=proc_rank, world_size=world_size)
log.info(f"initializing ddp: GLOBAL_RANK: {global_rank}, MEMBER: {global_rank+1}/{world_size}")
torch_distrib.init_process_group(torch_backend, rank=global_rank, world_size=world_size)

def configure_apex(
self,
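To make the renamed argument concrete, here is a hedged sketch of how a global rank can be derived and handed to `torch.distributed.init_process_group`, using the same `node_rank * num_processes + process_idx` formula that `ddp_train` applies later in this diff; the master address/port values and the function name are placeholders.

```python
import os

import torch
import torch.distributed as torch_distrib


def init_ddp_connection_sketch(node_rank: int, process_idx: int,
                               num_nodes: int, num_processes: int) -> None:
    # local rank identifies a process within its node; global rank is unique across the job
    global_rank = node_rank * num_processes + process_idx
    world_size = num_nodes * num_processes

    # placeholder rendezvous settings; a real setup gets these from the cluster or launcher
    os.environ.setdefault('MASTER_ADDR', '127.0.0.1')
    os.environ.setdefault('MASTER_PORT', '12910')

    backend = 'nccl' if torch.cuda.is_available() else 'gloo'
    torch_distrib.init_process_group(backend, rank=global_rank, world_size=world_size)
```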
13 changes: 13 additions & 0 deletions pytorch_lightning/trainer/__init__.py
@@ -585,6 +585,19 @@ def on_train_end(self, trainer, pl_module):
--env=XLA_USE_BF16=1
-- python your_trainer_file.py
prepare_data_per_node
^^^^^^^^^^^^^^^^^^^^^
If True will call `prepare_data()` on LOCAL_RANK=0 for every node.
If False will only call from NODE_RANK=0, LOCAL_RANK=0
Example::
# default
Trainer(prepare_data_per_node=True)
# use only NODE_RANK=0, LOCAL_RANK=0
Trainer(prepare_data_per_node=False)
tpu_cores
^^^^^^^^^
- How many TPU cores to train on (1 or 8).
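A short usage sketch of the flag documented above; the `MNIST` download and path are illustrative only, and the module omits everything except `prepare_data()`:

```python
import pytorch_lightning as pl
from torchvision.datasets import MNIST


class MyModel(pl.LightningModule):
    def prepare_data(self):
        # download only; do not assign state here, since only some processes run it
        MNIST('/data', train=True, download=True)


# default: LOCAL_RANK=0 on every node downloads its own copy
trainer = pl.Trainer(prepare_data_per_node=True)

# shared filesystem: only NODE_RANK=0, LOCAL_RANK=0 downloads
trainer = pl.Trainer(prepare_data_per_node=False)
```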
4 changes: 2 additions & 2 deletions pytorch_lightning/trainer/data_loading.py
@@ -57,7 +57,7 @@ class TrainerDataLoadingMixin(ABC):

# this is just a summary on variables used in this abstract class,
# the proper values/initialisation should be done in child class
proc_rank: int
global_rank: int
use_ddp: bool
use_ddp2: bool
use_horovod: bool
@@ -147,7 +147,7 @@ def _get_distributed_sampler(self, dataloader):
'ddp_cpu': self.num_processes * self.num_nodes
}
assert self.distributed_backend is not None
kwargs = dict(num_replicas=world_size[self.distributed_backend], rank=self.proc_rank)
kwargs = dict(num_replicas=world_size[self.distributed_backend], rank=self.global_rank)
sampler = DistributedSampler(dataloader.dataset, **kwargs)
return sampler

35 changes: 27 additions & 8 deletions pytorch_lightning/trainer/distrib_data_parallel.py
@@ -165,6 +165,10 @@ class TrainerDDPMixin(ABC):
num_nodes: int
node_rank: int

@property
def is_global_zero(self) -> int:
"""Warning: this is just empty shell for code implemented in other class."""

@property
@abstractmethod
def num_gpus(self) -> int:
@@ -300,6 +304,13 @@ def configure_slurm_ddp(self, num_gpu_nodes):
if self.is_slurm_managing_tasks:
log.info('Multi-processing is handled by Slurm.')

def determine_local_rank(self):
if self.is_slurm_managing_tasks:
return int(os.environ['SLURM_LOCALID'])

else:
return int(os.environ.get('LOCAL_RANK', 0))

def determine_ddp_node_rank(self):
if self.is_slurm_managing_tasks:
return int(os.environ['SLURM_NODEID'])
@@ -423,21 +434,30 @@ def ddp_train(self, process_idx, model, is_master=False, proc_offset=0):

# determine which process we are and world size
if self.use_ddp:
self.proc_rank = self.node_rank * self.num_processes + process_idx
self.local_rank = process_idx
self.global_rank = self.node_rank * self.num_processes + process_idx
self.world_size = self.num_nodes * self.num_processes

elif self.use_ddp2:
self.proc_rank = self.node_rank
self.local_rank = self.node_rank
self.global_rank = self.node_rank
self.world_size = self.num_nodes

# set warning rank
rank_zero_only.rank = self.proc_rank
rank_zero_only.rank = self.global_rank

# set up server using proc 0's ip address
# try to init for 20 times at max in case ports are taken
# where to store ip_table
model.trainer = self
model.init_ddp_connection(self.proc_rank, self.world_size, self.is_slurm_managing_tasks)
model.init_ddp_connection(self.global_rank, self.world_size, self.is_slurm_managing_tasks)

# on world_size=0 let everyone know training is starting
if self.is_global_zero:
log.info('-' * 100)
log.info(f'distributed_backend={self.distributed_backend}')
log.info(f'All DDP processes registered. Starting ddp with {self.world_size} processes')
log.info('-' * 100)

# CHOOSE OPTIMIZER
# allow for lr schedulers as well
@@ -450,8 +470,7 @@ def ddp_train(self, process_idx, model, is_master=False, proc_offset=0):
if is_master:
# source of truth is cuda for gpu idx
gpus = os.environ['CUDA_VISIBLE_DEVICES'].split(',')
local_rank = int(os.environ['LOCAL_RANK'])
gpu_idx = int(gpus[local_rank])
gpu_idx = int(gpus[self.local_rank])

self.root_gpu = gpu_idx
torch.cuda.set_device(self.root_gpu)
@@ -488,7 +507,7 @@ def save_spawn_weights(self, model):
:param model:
:return:
"""
if self.proc_rank == 0:
if self.is_global_zero:
path = os.path.join(self.default_root_dir, '__temp_weight_ddp_end.ckpt')
self.save_checkpoint(path)

@@ -502,7 +521,7 @@ def load_spawn_weights(self, original_model):

loaded_model = original_model

if self.proc_rank == 0:
if self.is_global_zero:
# load weights saved in ddp
path = os.path.join(self.default_root_dir, '__temp_weight_ddp_end.ckpt')
loaded_model = original_model.__class__.load_from_checkpoint(path)
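For a concrete picture of the mapping `ddp_train` now sets up, consider a hypothetical job with 2 nodes and 4 processes per node:

```python
num_nodes, num_processes = 2, 4
world_size = num_nodes * num_processes  # 8

for node_rank in range(num_nodes):
    for process_idx in range(num_processes):
        local_rank = process_idx                               # 0..3, repeats on every node
        global_rank = node_rank * num_processes + process_idx  # 0..7, unique across the job
        print(f"node {node_rank}, proc {process_idx}: "
              f"local_rank={local_rank}, global_rank={global_rank}")

# only global_rank == 0 is "global zero"; with prepare_data_per_node=True,
# local_rank == 0 on every other node also prepares data
```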
10 changes: 5 additions & 5 deletions pytorch_lightning/trainer/distrib_parts.py
@@ -57,7 +57,7 @@ class TrainerDPMixin(ABC):
root_gpu: ...
amp_level: str
precision: ...
proc_rank: int
global_rank: int
tpu_local_core_rank: int
tpu_global_core_rank: int
use_tpu: bool
@@ -183,8 +183,8 @@ def tpu_train(self, tpu_core_idx, model):
if self.tpu_global_core_rank != 0 and self.progress_bar_callback is not None:
self.progress_bar_callback.disable()

self.proc_rank = self.tpu_local_core_rank
rank_zero_only.rank = self.proc_rank
self.global_rank = self.tpu_local_core_rank
rank_zero_only.rank = self.global_rank

# CHOOSE OPTIMIZER
# allow for lr schedulers as well
@@ -289,8 +289,8 @@ def filter_named_parameters(model, optimizer):

# Update logger rank info from Horovod to avoid race conditions from different ranks
# creating directories / writing files in the same locations.
self.proc_rank = hvd.rank()
rank_zero_only.rank = self.proc_rank
self.global_rank = hvd.rank()
rank_zero_only.rank = self.global_rank

with ExitStack() as stack:
for optimizer in self.optimizers:
10 changes: 3 additions & 7 deletions pytorch_lightning/trainer/evaluation_loop.py
@@ -130,7 +130,6 @@
from torch.utils.data import DataLoader

from pytorch_lightning.core.lightning import LightningModule
from pytorch_lightning.profiler.profilers import BaseProfiler
from pytorch_lightning.overrides.data_parallel import LightningDistributedDataParallel, LightningDataParallel
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities import rank_zero_warn
@@ -160,8 +159,6 @@ class TrainerEvaluationLoopMixin(ABC):
use_dp: bool
use_ddp2: bool
use_horovod: bool
use_amp: bool
use_native_amp: bool
single_gpu: bool
data_parallel_device_ids: ...
model: LightningModule
@@ -170,15 +167,14 @@ class TrainerEvaluationLoopMixin(ABC):
fast_dev_run: ...
process_output: ...
progress_bar_dict: ...
proc_rank: int
global_rank: int
current_epoch: int
callback_metrics: ...
test_dataloaders: DataLoader
val_dataloaders: DataLoader
use_tpu: bool
reload_dataloaders_every_epoch: ...
tpu_id: Optional[int]
profiler: BaseProfiler
tpu_id: int

# Callback system
on_validation_batch_start: Callable
@@ -379,7 +375,7 @@ def run_evaluation(self, test_mode: bool = False):
self.add_progress_bar_metrics(prog_bar_metrics)

# log results of test
if test_mode and self.proc_rank == 0:
if test_mode and self.is_global_zero:
print('-' * 80)
print('TEST RESULTS')
pprint(callback_metrics)
8 changes: 4 additions & 4 deletions pytorch_lightning/trainer/logging.py
@@ -1,5 +1,5 @@
from abc import ABC
from typing import Iterable, Optional
from typing import Union, Iterable

import torch

@@ -15,10 +15,10 @@ class TrainerLoggingMixin(ABC):
current_epoch: int
on_gpu: bool
log_gpu_memory: ...
logger: Optional[LightningLoggerBase]
logger: Union[LightningLoggerBase, bool]
progress_bar_metrics: ...
global_step: int
proc_rank: int
global_rank: int
use_dp: bool
use_ddp2: bool
default_root_dir: str
@@ -69,7 +69,7 @@ def log_metrics(self, metrics, grad_norm_dic, step=None):
scalar_metrics['epoch'] = self.current_epoch
step = step if step is not None else self.global_step
# log actual metrics
if self.proc_rank == 0 and self.logger is not None:
if self.is_global_zero and self.logger is not None:
self.logger.agg_and_log_metrics(scalar_metrics, step=step)
self.logger.save()

30 changes: 24 additions & 6 deletions pytorch_lightning/trainer/trainer.py
@@ -123,6 +123,7 @@ def __init__(
replace_sampler_ddp: bool = True,
terminate_on_nan: bool = False,
auto_scale_batch_size: Union[str, bool] = False,
prepare_data_per_node: bool = True,
amp_level: str = 'O1', # backward compatible, todo: remove in v1.0.0
num_tpu_cores: Optional[int] = None, # backward compatible, todo: remove in v0.9.0
use_amp=None, # backward compatible, todo: remove in v0.9.0
@@ -282,6 +283,9 @@ def __init__(
The result will be stored in self.batch_size in the LightningModule.
Additionally, can be set to either `power` that estimates the batch size through
a power search or `binsearch` that estimates the batch size through a binary search.

prepare_data_per_node: If True, each LOCAL_RANK=0 will call prepare data.
Otherwise only NODE_RANK=0, LOCAL_RANK=0 will prepare data
"""
super().__init__()

@@ -293,6 +297,7 @@ def __init__(
os.environ["HOROVOD_FUSION_THRESHOLD"] = str(0)

# Init callbacks
self.prepare_data_per_node = prepare_data_per_node
self.callbacks = callbacks or []
self.on_init_start()

@@ -439,11 +444,12 @@ def __init__(
self.init_tpu()

# init flags for SLURM+ddp to work
self.proc_rank = 0
self.world_size = 1
self.interactive_ddp_procs = []
self.configure_slurm_ddp(self.num_nodes)
self.node_rank = self.determine_ddp_node_rank()
self.local_rank = self.determine_local_rank()
self.global_rank = 0

# nvidia setup
self.set_nvidia_flags(self.is_slurm_managing_tasks, self.data_parallel_device_ids)
@@ -481,6 +487,10 @@ def __init__(
# Callback system
self.on_init_end()

@property
def is_global_zero(self):
return self.global_rank == 0

@property
def slurm_job_id(self) -> Optional[int]:
try:
@@ -532,6 +542,7 @@ def get_init_arguments_and_types(cls) -> List[Tuple[str, Tuple, Any]]:
('max_epochs', (<class 'int'>,), 1000),
...
('precision', (<class 'int'>,), 32),
('prepare_data_per_node', (<class 'bool'>,), True),
Review thread on this line:

Contributor: It is in the wrong position; it should go after auto_scale_batch_size, which is not listed here, so the solution is to remove it. Also, you removed `doctest: +ELLIPSIS +NORMALIZE_WHITESPACE`, which I am fairly confident is needed.

Contributor Author: What is it here? If I remove it, it still fails.

Contributor (@awaelchli, Jun 13, 2020): Well, maybe I was wrong and it is just the missing `+ELLIPSIS +NORMALIZE_WHITESPACE`. I cloned the branch and the doctest runs locally, so I think it is OK now.
('print_nan_grads', (<class 'bool'>,), False),
('process_position', (<class 'int'>,), 0),
('profiler',
@@ -773,10 +784,9 @@ def fit(
# check that model is configured correctly
self.check_model_configuration(model)

# download the data and do whatever transforms we need
# do before any spawn calls so that the model can assign properties
# only on proc 0 because no spawn has happened yet
if not self._is_data_prepared:
# on multi-gpu jobs we only want to manipulate (download, etc) on node_rank=0, local_rank=0
# or in the case where each node needs to do its own manipulation in which case just local_rank=0
if self.can_prepare_data():
model.prepare_data()
self._is_data_prepared = True

@@ -801,6 +811,7 @@ def fit(
# torchelastic or general non_slurm ddp2
elif 'WORLD_SIZE' in os.environ and ('GROUP_RANK' in os.environ or 'NODE_RANK' in os.environ):
task = int(os.environ['LOCAL_RANK'])

self.ddp_train(task, model)
elif self.use_ddp:
if self.is_slurm_managing_tasks:
@@ -872,6 +883,13 @@ def fit(
# used for testing or when we need to know that training succeeded
return 1

def can_prepare_data(self):
if self.prepare_data_per_node:
return self.local_rank == 0

else:
return self.node_rank == 0 and self.local_rank == 0

def __attach_dataloaders(self, model, train_dataloader=None, val_dataloaders=None, test_dataloaders=None):
# when dataloader is passed via fit, patch the train_dataloader
# functions to overwrite with these implementations
@@ -928,7 +946,7 @@ def run_pretrain_routine(self, model: LightningModule):

# print model summary
# TODO: remove self.testing condition because model.summarize() is wiping out the weights
if self.proc_rank == 0 and self.weights_summary is not None and not self.testing:
if self.is_global_zero and self.weights_summary is not None and not self.testing:
if self.weights_summary in ['full', 'top']:
ref_model.summarize(mode=self.weights_summary)
else:
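As a summary of the gating logic `fit()` now relies on, here is a condensed, standalone sketch of `can_prepare_data()`; the free-function form and the asserts are for illustration only:

```python
def can_prepare_data(prepare_data_per_node: bool, node_rank: int, local_rank: int) -> bool:
    """Decide whether this process should call model.prepare_data()."""
    if prepare_data_per_node:
        # one preparation per node: LOCAL_RANK=0 on every node
        return local_rank == 0
    # shared filesystem: only the first process of the first node prepares data
    return node_rank == 0 and local_rank == 0


assert can_prepare_data(True, node_rank=1, local_rank=0) is True
assert can_prepare_data(False, node_rank=1, local_rank=0) is False
assert can_prepare_data(False, node_rank=0, local_rank=0) is True
```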