Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ddp fix for trainer.test() + add basic ddp tests #2997

Merged
merged 33 commits into from
Aug 16, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

- Fixed adding val step argument to metrics ([#2986](https://github.com/PyTorchLightning/pytorch-lightning/pull/2986))

- Fixed an issue that caused `Trainer.test()` to stall in ddp mode ([#2997](https://github.com/PyTorchLightning/pytorch-lightning/pull/2997))

## [0.8.5] - 2020-07-09

### Added
Expand Down
12 changes: 7 additions & 5 deletions docs/source/multi_gpu.rst
Original file line number Diff line number Diff line change
Expand Up @@ -286,17 +286,19 @@ variables:
MASTER_ADDR=localhost MASTER_PORT=random() WORLD_SIZE=3 NODE_RANK=1 LOCAL_RANK=0 python my_file.py --gpus 3 --etc
MASTER_ADDR=localhost MASTER_PORT=random() WORLD_SIZE=3 NODE_RANK=2 LOCAL_RANK=0 python my_file.py --gpus 3 --etc

If your code does not support this (ie: jupyter notebook, colab, or a nested script without a root package),
use `dp` or `ddp_spawn`.
We use DDP this way because `ddp_spawn` has a few limitations (due to Python and PyTorch):

1. Since `.spawn()` trains the model in subprocesses, the model on the main process does not get updated.

2. Dataloader(num_workers=N), where N is large, bottlenecks training with DDP... ie: it will be VERY slow or won't work at all. This is a PyTorch limitation.

3. Forces everything to be picklable.

However, if you don't mind these limitations, you can use `ddp_spawn`.
There are cases in which it is not possible to use DDP. Examples are:

- Jupyter Notebook, Google COLAB, Kaggle, etc.
- You have a nested script without a root package
- Your script needs to invoke `.fit` or `.test` multiple times

In these situations you should use `dp` or `ddp_spawn` instead.

Distributed Data Parallel 2
^^^^^^^^^^^^^^^^^^^^^^^^^^^
Expand Down
26 changes: 16 additions & 10 deletions pytorch_lightning/accelerators/ddp_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@

from pytorch_lightning import _logger as log
from pytorch_lightning.utilities import AMPType
from pytorch_lightning.utilities.distributed import rank_zero_only
from pytorch_lightning.utilities.distributed import rank_zero_only, find_free_network_port

try:
from hydra.utils import to_absolute_path, get_original_cwd
Expand All @@ -45,6 +45,7 @@ class DDPBackend(object):
def __init__(self, trainer):
self.trainer = trainer
self.task_idx = None
self._has_spawned_children = False

def slurm_setup(self):
self.task_idx = int(os.environ['SLURM_LOCALID'])
Expand All @@ -56,19 +57,17 @@ def train(self, model):
self.ddp_train(process_idx=self.task_idx, mp_queue=None, model=model)

def spawn_ddp_children(self, model):
port = os.environ['MASTER_PORT']
assert self.trainer.global_rank == 0
self._check_can_spawn_children()
self._has_spawned_children = True

master_address = '127.0.0.1' if 'MASTER_ADDR' not in os.environ else os.environ['MASTER_ADDR']
os.environ['MASTER_PORT'] = f'{port}'
os.environ['MASTER_ADDR'] = f'{master_address}'
os.environ['MASTER_ADDR'] = os.environ.get('MASTER_ADDR', '127.0.0.1')
os.environ['MASTER_PORT'] = os.environ.get('MASTER_PORT', str(find_free_network_port()))

# allow the user to pass the node rank
node_rank = '0'
if 'NODE_RANK' in os.environ:
node_rank = os.environ['NODE_RANK']
if 'GROUP_RANK' in os.environ:
node_rank = os.environ['GROUP_RANK']

node_rank = os.environ.get('NODE_RANK', node_rank)
node_rank = os.environ.get('GROUP_RANK', node_rank)
os.environ['NODE_RANK'] = node_rank
os.environ['LOCAL_RANK'] = '0'

Expand Down Expand Up @@ -235,3 +234,10 @@ def ddp_train(self, process_idx, mp_queue, model, is_master=False, proc_offset=0

if self.trainer.global_rank == 0 and self.trainer.distributed_backend not in ['ddp_spawn', 'ddp_cpu']:
return results

def _check_can_spawn_children(self):
if self._has_spawned_children:
raise RuntimeError(
"You tried to run `.fit` or `.test` multiple times in the same script."
" This is not supported in DDP mode, switch to `distributed_backend='ddp_spawn'` instead."
)
5 changes: 3 additions & 2 deletions pytorch_lightning/accelerators/ddp_spawn_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,13 +11,14 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License
import os

import torch
import torch.multiprocessing as mp

from pytorch_lightning import _logger as log
from pytorch_lightning.utilities import AMPType
from pytorch_lightning.utilities.distributed import rank_zero_only
from pytorch_lightning.utilities.distributed import rank_zero_only, find_free_network_port

try:
from apex import amp
Expand All @@ -32,7 +33,7 @@ def __init__(self, trainer):
self.mp_queue = None

def setup(self):
self.trainer.set_random_port()
os.environ['MASTER_PORT'] = os.environ.get('MASTER_PORT', str(find_free_network_port()))

# pass in a state q
smp = mp.get_context('spawn')
Expand Down
4 changes: 0 additions & 4 deletions pytorch_lightning/core/lightning.py
Original file line number Diff line number Diff line change
Expand Up @@ -893,10 +893,6 @@ def init_ddp_connection(self, global_rank: int, world_size: int, is_slurm_managi
log.info(f"initializing ddp: GLOBAL_RANK: {global_rank}, MEMBER: {global_rank+1}/{world_size}")
torch_distrib.init_process_group(torch_backend, rank=global_rank, world_size=world_size)

"""
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why did we remove this?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It looks like it does not belong here. My question would be why is it here in the first place?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

configure_sync_batchnorm
^^^^^^^^^^^^^^^^^^^^^^^^
"""
def configure_sync_batchnorm(self, model: 'LightningModule') -> 'LightningModule':
"""
Add global batchnorm for a model spread across multiple GPUs and nodes.
Expand Down
1 change: 1 addition & 0 deletions pytorch_lightning/core/saving.py
Original file line number Diff line number Diff line change
Expand Up @@ -353,6 +353,7 @@ def save_hparams_to_yaml(config_yaml, hparams: Union[dict, Namespace]) -> None:
with open(config_yaml, 'w', newline='') as fp:
yaml.dump(hparams, fp)


def convert(val: str) -> Union[int, float, bool, str]:
try:
return ast.literal_eval(val)
Expand Down
21 changes: 0 additions & 21 deletions pytorch_lightning/trainer/distrib_data_parallel.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,6 @@ def train_fx(trial_hparams, cluster_manager, _):
from abc import ABC, abstractmethod
from typing import Union, List, Optional, Tuple

import numpy as np
import torch

from pytorch_lightning import _logger as log
Expand Down Expand Up @@ -163,10 +162,6 @@ def train_fx(trial_hparams, cluster_manager, _):
else:
XLA_AVAILABLE = True

PID = os.getpid()
RNG1 = np.random.RandomState(PID)
RANDOM_PORTS = RNG1.randint(10000, 19999, 1000)


class TrainerDDPMixin(ABC):

Expand Down Expand Up @@ -389,22 +384,6 @@ def set_nvidia_flags(self, is_slurm_managing_tasks, data_parallel_device_ids):
# don't make this debug... this is good UX
rank_zero_info(f'CUDA_VISIBLE_DEVICES: [{os.environ["CUDA_VISIBLE_DEVICES"]}]')

def set_random_port(self, force=False):
"""
When running DDP NOT managed by SLURM, the ports might collide
"""
# pick a random port first
assert self.num_nodes == 1, 'random port can only be called from single node training'
global RANDOM_PORTS
default_port = RANDOM_PORTS[-1]
RANDOM_PORTS = RANDOM_PORTS[:-1]

# when not forced, use the user port
if not force:
default_port = os.environ.get('MASTER_PORT', default_port)

os.environ['MASTER_PORT'] = str(default_port)

def transfer_distrib_spawn_state_on_fit_end(self, model, mp_queue, results):
if self.distributed_backend.lower() not in ['ddp_spawn', 'ddp_cpu', 'tpu']:
return
Expand Down
3 changes: 0 additions & 3 deletions pytorch_lightning/trainer/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -1039,7 +1039,6 @@ def fit(

# ddp
elif self.distributed_backend == 'ddp':
self.set_random_port()
self.accelerator_backend = DDPBackend(self)
results = self.accelerator_backend.spawn_ddp_children(model)

Expand Down Expand Up @@ -1377,7 +1376,6 @@ def __test_using_best_weights(self, ckpt_path, test_dataloaders):

# run tests
self.tested_ckpt_path = ckpt_path
self.set_random_port(force=True)
self.testing = True
os.environ['PL_TESTING_MODE'] = '1'
self.model = model
Expand All @@ -1400,7 +1398,6 @@ def __test_given_model(self, model, test_dataloaders):

# run test
# sets up testing so we short circuit to eval
self.set_random_port(force=True)
self.testing = True
self.model = model
results = self.fit(model)
Expand Down
15 changes: 15 additions & 0 deletions pytorch_lightning/utilities/distributed.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,3 +34,18 @@ def _debug(*args, **kwargs):
rank_zero_debug = rank_zero_only(_debug)
rank_zero_info = rank_zero_only(_info)
rank_zero_warn = rank_zero_only(_warn)


def find_free_network_port() -> int:
"""
Finds a free port on localhost.
It is useful in single-node training when we don't want to connect to a real master node but
have to set the `MASTER_PORT` environment variable.
"""
import socket
s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
s.bind(("", 0))
s.listen(1)
port = s.getsockname()[1]
s.close()
return port
1 change: 0 additions & 1 deletion tests/base/model_valid_epoch_ends.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@ def _mean(res, key):
# recursive mean for multilevel dicts
return torch.stack([x[key] if isinstance(x, dict) else _mean(x, key) for x in res]).mean()

print('in validation epoch end')
val_loss_mean = _mean(outputs, 'val_loss')
val_acc_mean = _mean(outputs, 'val_acc')

Expand Down
44 changes: 44 additions & 0 deletions tests/models/data/ddp/train_test_variations.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
"""
Runs either `.fit()` or `.test()` on a single node across multiple gpus.
"""
from argparse import ArgumentParser

from pytorch_lightning import Trainer, seed_everything
from tests.base import EvalModelTemplate


def variation_fit(trainer, model):
trainer.fit(model)


def variation_test(trainer, model):
trainer.test(model)


def get_variations():
variations = [
"variation_fit",
"variation_test",
]
return variations


def main():
seed_everything(1234)
parser = ArgumentParser(add_help=False)
parser = Trainer.add_argparse_args(parser)
parser.add_argument('--variation', default=variation_fit.__name__)
parser.set_defaults(gpus=2)
parser.set_defaults(distributed_backend="ddp")
args = parser.parse_args()

model = EvalModelTemplate()
trainer = Trainer.from_argparse_args(args)

# run the chosen variation
run_variation = globals()[args.variation]
run_variation(trainer, model)


if __name__ == '__main__':
main()
34 changes: 34 additions & 0 deletions tests/models/test_gpu.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,23 @@
import os
import subprocess
import sys
from collections import namedtuple
from pathlib import Path
from unittest.mock import patch

import pytest
import torch
from torchtext.data import Batch, Dataset, Example, Field, LabelField

import pytorch_lightning
import tests.base.develop_pipelines as tpipes
import tests.base.develop_utils as tutils
from pytorch_lightning import Trainer
from pytorch_lightning.core import memory
from pytorch_lightning.trainer.distrib_parts import _parse_gpu_ids, determine_root_gpu_device
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from tests.base import EvalModelTemplate
from tests.models.data.ddp import train_test_variations

PRETEND_N_OF_GPUS = 16

Expand Down Expand Up @@ -94,6 +100,34 @@ def test_multi_gpu_model_dp(tmpdir):
memory.get_memory_profile('min_max')


@pytest.mark.parametrize('cli_args', [
pytest.param('--max_epochs 1 --gpus 2 --distributed_backend ddp'),
])
@pytest.mark.parametrize('variation', train_test_variations.get_variations())
@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine")
def test_multi_gpu_model_ddp(tmpdir, cli_args, variation):
""" Runs a basic training and test run with distributed_backend=ddp. """
file = Path(train_test_variations.__file__).absolute()
cli_args = cli_args.split(' ') if cli_args else []
cli_args += ['--default_root_dir', str(tmpdir)]
cli_args += ['--variation', variation]
command = [sys.executable, str(file)] + cli_args

# need to set the PYTHONPATH in case pytorch_lightning was not installed into the environment
env = os.environ.copy()
env['PYTHONPATH'] = f'{pytorch_lightning.__file__}:' + env.get('PYTHONPATH', '')

# for running in ddp mode, we need to lauch it's own process or pytest will get stuck
p = subprocess.Popen(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, env=env)

std, err = p.communicate(timeout=60)
std = std.decode('utf-8').strip()
err = err.decode('utf-8').strip()
assert std, f"{variation} produced no output"
if p.returncode > 0:
pytest.fail(err)


@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine")
def test_multi_gpu_model_ddp_spawn(tmpdir):
tutils.set_random_master_port()
Expand Down