
Refactor: clean trainer device & distrib setters #5297

Merged · 25 commits · Jan 4, 2021
2 changes: 2 additions & 0 deletions CHANGELOG.md
@@ -33,6 +33,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Fixed

- Fixed distributed setting and `ddp_cpu` only with `num_processes>1` ([#5297](https://github.com/PyTorchLightning/pytorch-lightning/pull/5297))


## [1.1.0] - 2020-12-09

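The changelog entry above refers to CPU-only DDP requested with more than one process. As a minimal, hypothetical sketch (assuming the `BoringModel` test helper used elsewhere in this PR), the configuration that now resolves correctly looks like this:

```python
# Hypothetical sketch of the case this PR fixes: `ddp_cpu` with num_processes > 1
# should now resolve to DDP over CPU processes rather than a mis-set distributed mode.
from pytorch_lightning import Trainer
from tests.base.boring_model import BoringModel  # test helper, as imported in the benchmarks below

model = BoringModel()
trainer = Trainer(
    accelerator="ddp_cpu",   # CPU-only distributed training
    num_processes=2,         # the >1 process case addressed by this PR
    max_epochs=1,
)
trainer.fit(model)
```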
54 changes: 18 additions & 36 deletions benchmarks/test_sharded_parity.py
@@ -1,7 +1,7 @@
import os
import platform
import time
from typing import Union
from typing import Type, Union

import pytest
import torch
@@ -14,64 +14,48 @@
from tests.base.boring_model import BoringModel, RandomDataset


@pytest.mark.skipif(platform.system() == "Windows",
reason="Distributed training is not supported on Windows")
@pytest.mark.skipif(not _FAIRSCALE_AVAILABLE, reason="Fairscale is not available")
def test_ddp_sharded_plugin_correctness_one_device():
plugin_parity_test(
accelerator='ddp_cpu',
max_percent_speed_diff=0.15, # slower speed due to one CPU doing additional sequential memory saving calls
plugin=DDPShardedPlugin(),
model_cls=SeedTrainLoaderModel
)


@pytest.mark.skipif(not torch.cuda.is_available(), reason="requires GPU machine")
@pytest.mark.skipif(platform.system() == "Windows",
reason="Distributed training is not supported on Windows")
@pytest.mark.skipif(platform.system() == "Windows", reason="Distributed training is not supported on Windows")
@pytest.mark.skipif(not _FAIRSCALE_AVAILABLE, reason="Fairscale is not available")
def test_ddp_sharded_plugin_correctness_one_gpu():
plugin_parity_test(
gpus=1,
accelerator='ddp_spawn',
plugin=DDPShardedPlugin(),
model_cls=SeedTrainLoaderModel
model_cls=SeedTrainLoaderModel,
)


@pytest.mark.skipif(not _NATIVE_AMP_AVAILABLE, reason="Requires native AMP")
@pytest.mark.skipif(not torch.cuda.is_available(), reason="requires GPU machine")
@pytest.mark.skipif(platform.system() == "Windows",
reason="Distributed training is not supported on Windows")
@pytest.mark.skipif(platform.system() == "Windows", reason="Distributed training is not supported on Windows")
@pytest.mark.skipif(not _FAIRSCALE_AVAILABLE, reason="Fairscale is not available")
def test_ddp_sharded_plugin_correctness_amp_one_gpu():
plugin_parity_test(
gpus=1,
precision=16,
accelerator='ddp_spawn',
plugin=DDPShardedPlugin(),
model_cls=SeedTrainLoaderModel
model_cls=SeedTrainLoaderModel,
)


@pytest.mark.skip(reason="Not a critical test, skip till drone CI performance improves.")
@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine")
@pytest.mark.skipif(platform.system() == "Windows",
reason="Distributed training is not supported on Windows")
@pytest.mark.skipif(platform.system() == "Windows", reason="Distributed training is not supported on Windows")
@pytest.mark.skipif(not _FAIRSCALE_AVAILABLE, reason="Fairscale is not available")
def test_ddp_sharded_plugin_correctness_multi_gpu():
plugin_parity_test(
gpus=2,
accelerator='ddp_spawn',
plugin=DDPShardedPlugin(),
model_cls=SeedTrainLoaderModel,
max_percent_speed_diff=0.25
max_percent_speed_diff=0.25,
)


@pytest.mark.skipif(not _NATIVE_AMP_AVAILABLE, reason="Requires native AMP")
@pytest.mark.skipif(platform.system() == "Windows",
reason="Distributed training is not supported on Windows")
@pytest.mark.skipif(platform.system() == "Windows", reason="Distributed training is not supported on Windows")
@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine")
@pytest.mark.skipif(not _FAIRSCALE_AVAILABLE, reason="Fairscale is not available")
def test_ddp_sharded_plugin_correctness_amp_multi_gpu():
@@ -81,13 +65,12 @@ def test_ddp_sharded_plugin_correctness_amp_multi_gpu():
accelerator='ddp_spawn',
plugin=DDPShardedPlugin(),
model_cls=SeedTrainLoaderModel,
max_percent_speed_diff=0.25
max_percent_speed_diff=0.25,
)


@pytest.mark.skipif(not _NATIVE_AMP_AVAILABLE, reason="Requires native AMP")
@pytest.mark.skipif(platform.system() == "Windows",
reason="Distributed training is not supported on Windows")
@pytest.mark.skipif(platform.system() == "Windows", reason="Distributed training is not supported on Windows")
@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine")
@pytest.mark.skipif(not _FAIRSCALE_AVAILABLE, reason="Fairscale is not available")
def test_ddp_string_sharded_plugin_correctness_amp_multi_gpu():
Expand All @@ -97,7 +80,7 @@ def test_ddp_string_sharded_plugin_correctness_amp_multi_gpu():
accelerator='ddp_spawn',
plugin='ddp_sharded',
model_cls=SeedTrainLoaderModel,
max_percent_speed_diff=0.25
max_percent_speed_diff=0.25,
)


@@ -133,8 +116,7 @@ def test_ddp_sharded_plugin_correctness_amp_multi_gpu_ddp(tmpdir, args=None):

@pytest.mark.skip(reason="Current issue with multiple optimizers and FairScale.")
@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine")
@pytest.mark.skipif(platform.system() == "Windows",
reason="Distributed training is not supported on Windows")
@pytest.mark.skipif(platform.system() == "Windows", reason="Distributed training is not supported on Windows")
@pytest.mark.skipif(not _FAIRSCALE_AVAILABLE, reason="Fairscale is not available")
def test_ddp_sharded_plugin_correctness_multi_gpu_multi_optim():
"""
@@ -145,14 +127,13 @@ def test_ddp_sharded_plugin_correctness_multi_gpu_multi_optim():
gpus=2,
accelerator='ddp_spawn',
model_cls=SeedTrainLoaderMultipleOptimizersModel,
max_percent_speed_diff=0.25 # Increase speed diff since only 2 GPUs sharding 2 optimizers
max_percent_speed_diff=0.25, # Increase speed diff since only 2 GPUs sharding 2 optimizers
)


@pytest.mark.skip(reason="Current issue with multiple optimizers and FairScale.")
@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine")
@pytest.mark.skipif(platform.system() == "Windows",
reason="Distributed training is not supported on Windows")
@pytest.mark.skipif(platform.system() == "Windows", reason="Distributed training is not supported on Windows")
@pytest.mark.skipif(not _FAIRSCALE_AVAILABLE, reason="Fairscale is not available")
def test_ddp_sharded_plugin_correctness_multi_gpu_multi_optim_manual(tmpdir):
"""
@@ -163,7 +144,7 @@ def test_ddp_sharded_plugin_correctness_multi_gpu_multi_optim_manual(tmpdir):
gpus=2,
accelerator='ddp_spawn',
model_cls=SeedTrainLoaderManualModel,
max_percent_speed_diff=0.25 # Increase speed diff since only 2 GPUs sharding 2 optimizers
max_percent_speed_diff=0.25, # Increase speed diff since only 2 GPUs sharding 2 optimizers
)


@@ -259,13 +240,14 @@ def record_ddp_fit_model_stats(trainer, model, use_cuda):


def plugin_parity_test(
model_cls: SeedTrainLoaderModel,
model_cls: Type[SeedTrainLoaderModel],
plugin: Union[str, DDPPlugin],
seed: int = 42,
accelerator: str = 'ddp_spawn',
gpus: int = 0,
precision: int = 32,
max_percent_speed_diff: float = 0.1):
max_percent_speed_diff: float = 0.1,
):
"""
Ensures that the trained model is identical to the standard DDP implementation.
Also checks for speed/memory regressions; we should always expect less memory, but performance may fluctuate.
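For reference, here is a hypothetical call that mirrors the updated signature above: the class itself is passed for `model_cls` (matching the new `Type[SeedTrainLoaderModel]` annotation), and trailing commas follow the style applied throughout this file.

```python
# Illustrative only; plugin_parity_test, SeedTrainLoaderModel and DDPShardedPlugin
# are the names defined or imported in this benchmark module.
plugin_parity_test(
    model_cls=SeedTrainLoaderModel,   # pass the class, not an instance
    plugin=DDPShardedPlugin(),
    accelerator='ddp_spawn',
    gpus=2,
    precision=32,
    max_percent_speed_diff=0.25,
)
```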
141 changes: 74 additions & 67 deletions pytorch_lightning/accelerators/accelerator_connector.py
@@ -15,7 +15,7 @@

import torch

from pytorch_lightning.utilities import _HOROVOD_AVAILABLE
from pytorch_lightning.utilities import _HOROVOD_AVAILABLE, DeviceType, DistributedType
from pytorch_lightning import _logger as log
from pytorch_lightning import accelerators
from pytorch_lightning.accelerators.accelerator import Accelerator
@@ -81,10 +81,7 @@ def on_trainer_init(
# sync-bn backend
self.trainer.sync_batchnorm = sync_batchnorm

self.trainer.tpu_cores = device_parser.parse_tpu_cores(tpu_cores)
self.trainer.on_tpu = self.trainer.tpu_cores is not None

self.trainer.tpu_id = self.trainer.tpu_cores[0] if isinstance(self.trainer.tpu_cores, list) else None
self._parse_tpu_device_details(tpu_cores)

if num_processes != 1 and distributed_backend != "ddp_cpu":
rank_zero_warn("num_processes is only used for `accelerator='ddp_cpu'`. Ignoring it.")
@@ -100,23 +97,10 @@ def on_trainer_init(

self.trainer.data_parallel_device_ids = device_parser.parse_gpu_ids(self.trainer.gpus)
self.trainer.root_gpu = device_parser.determine_root_gpu_device(self.trainer.data_parallel_device_ids)
self.trainer.root_device = torch.device("cpu")

self.trainer.on_gpu = True if (self.trainer.data_parallel_device_ids and torch.cuda.is_available()) else False

# tpu state flags
self.trainer.use_tpu = False
self.trainer.tpu_local_core_rank = None
self.trainer.tpu_global_core_rank = None

# distributed backend choice
self.set_distributed_mode()

# override dist backend when using tpus
if self.trainer.on_tpu:
self.trainer.distributed_backend = "tpu"
self.trainer.use_tpu = True

# init flags for SLURM+DDP to work
self.trainer.world_size = 1
self.trainer.interactive_ddp_procs = []
@@ -135,10 +119,29 @@ def on_trainer_init(

self.trainer.replace_sampler_ddp = replace_sampler_ddp

def _parse_tpu_device_details(self, tpu_cores):
self.trainer.tpu_cores = device_parser.parse_tpu_cores(tpu_cores)
if self.trainer.tpu_cores is not None:
if _TPU_AVAILABLE:
self.trainer._device_type = DeviceType.TPU
self.trainer.distributed_backend = "tpu"
else:
raise MisconfigurationException(
f"You have requested {self.trainer.tpu_cores} TPU cores but none is available."
)

self.trainer.tpu_id = self.trainer.tpu_cores[0] if isinstance(self.trainer.tpu_cores, list) else None

# tpu state flags
self.trainer.tpu_local_core_rank = None
self.trainer.tpu_global_core_rank = None

def _map_deprecated_dist_backend(self, accelerator, distributed_backend):
if distributed_backend is not None:
rank_zero_warn(DeprecationWarning('distributed_backend has been renamed to accelerator. '
'Deprecated in 1.0.0, will be removed in 1.2.0'))
rank_zero_warn(
'`distributed_backend` has been renamed to accelerator. Deprecated in 1.0.0, will be removed in 1.2.0',
DeprecationWarning
)

# temporary mapping until we remove all the distributed_backend references
if accelerator is not None:
@@ -276,71 +279,75 @@ def select_accelerator(self):
accelerator_backend = accelerators.CPUAccelerator(self.trainer, cluster_env)
else:
raise MisconfigurationException(
f'Trainer(accelerator={self.trainer.distributed_backend} is not a supported backend'
f'`Trainer(accelerator={self.trainer.distributed_backend}, num_nodes={self.trainer.num_nodes},'
f' num_processes={self.trainer.num_processes}, ...)` is not a supported backend for'
f' num_gpus={self.trainer.num_gpus}'
)

return accelerator_backend

def set_distributed_mode(self):
self.trainer.use_dp = False
self.trainer.use_ddp = False
self.trainer.use_ddp2 = False
self.trainer.use_horovod = False
self.trainer.use_single_gpu = False

if self.trainer.distributed_backend is None:
if self.has_horovodrun():
self._set_horovod_backend()
elif self.trainer.num_gpus == 0:
if self.trainer.num_nodes > 1 or self.trainer.num_processes > 1:
self.trainer.use_ddp = True # ddp_cpu
elif self.trainer.num_gpus == 1:
self.trainer.use_single_gpu = True
elif self.trainer.num_gpus == 0 and (self.trainer.num_nodes > 1 or self.trainer.num_processes > 1):
self.trainer._distrib_type = DistributedType.DDP
elif self.trainer.num_gpus > 1:
rank_zero_warn(
'You requested multiple GPUs but did not specify a backend, e.g.'
' `Trainer(accelerator="dp"|"ddp"|"ddp2")`.'
' Setting `accelerator="ddp_spawn"` for you.'
' `Trainer(accelerator="dp"|"ddp"|"ddp2")`. Setting `accelerator="ddp_spawn"` for you.'
)
self.trainer.distributed_backend = "ddp_spawn"

if self.trainer.distributed_backend == "dp":
# do nothing if num_gpus == 0
if self.trainer.num_gpus == 1:
self.trainer.use_single_gpu = True
self.trainer.use_dp = True
elif self.trainer.num_gpus > 1:
self.trainer.use_dp = True

elif self.trainer.distributed_backend in ("ddp", "ddp_spawn"):
if self.trainer.num_gpus == 0:
if self.trainer.num_nodes > 1 or self.trainer.num_processes > 1:
self.trainer.use_ddp = True # ddp_cpu
elif self.trainer.num_gpus == 1:
self.trainer.use_single_gpu = True
self.trainer.use_ddp = True
elif self.trainer.num_gpus > 1:
self.trainer.use_ddp = True
self.trainer.num_processes = self.trainer.num_gpus

elif self.trainer.distributed_backend == "ddp2":
# do nothing if num_gpus == 0
if self.trainer.num_gpus >= 1:
self.trainer.use_ddp2 = True
elif self.trainer.distributed_backend == "ddp_cpu":
# special case with DDP on CPUs
if self.trainer.distributed_backend == "ddp_cpu":
self.trainer._distrib_type = DistributedType.DDP
self.trainer.data_parallel_device_ids = None
if self.trainer.num_gpus > 0:
rank_zero_warn(
'You requested one or more GPUs, but set the backend to `ddp_cpu`. Training will not use GPUs.'
)
self.trainer.use_ddp = True
self.trainer.data_parallel_device_ids = None
self.trainer.on_gpu = False
self.trainer.on_cpu = True
elif self.trainer.distributed_backend == "horovod":
if self.trainer.num_processes is None:
# define the max CPU available
self.trainer.num_processes = os.cpu_count()
# special case with TPUs
elif self.trainer.distributed_backend == 'tpu':
self.trainer._device_type = DeviceType.TPU
# set all other requested distrib. types if not already set
elif self.trainer.distributed_backend and self.trainer._distrib_type is None:
self.trainer._distrib_type = DistributedType(self.trainer.distributed_backend)

# unless CPU was requested explicitly, use GPUs when they are available
_on_cpu = self.trainer.distributed_backend and 'cpu' in self.trainer.distributed_backend
if (self.trainer.num_gpus > 0 and not _on_cpu):
self.trainer._device_type = DeviceType.GPU

_distrib_types = (DistributedType.DP, DistributedType.DDP, DistributedType.DDP_SPAWN, DistributedType.DDP2)
# DP and DDP2 cannot run without GPU
if (self.trainer.num_gpus == 0 and self.trainer._distrib_type in _distrib_types):
rank_zero_warn(
'You requested distributed training on GPUs, but none is available, so we set backend to `ddp_cpu`.'
)
# todo: in some cases this comparison can mix None and int
if ((self.trainer.num_nodes and self.trainer.num_nodes > 1)
or (self.trainer.num_processes and self.trainer.num_processes > 1)):
self.trainer._distrib_type = DistributedType.DDP
else:
rank_zero_warn('You are running on single node with no parallelization, so distributed has no effect.')
self.trainer._distrib_type = None

# for DDP overwrite nb processes by requested GPUs
if (self.trainer._device_type == DeviceType.GPU
and self.trainer._distrib_type in (DistributedType.DDP, DistributedType.DDP_SPAWN)):
self.trainer.num_processes = self.trainer.num_gpus

# Horovod is an extra case...
if self.trainer.distributed_backend == "horovod":
self._set_horovod_backend()

# throw error to force user ddp or ddp2 choice
if self.trainer.num_nodes > 1 and not (self.trainer.use_ddp2 or self.trainer.use_ddp):
if self.trainer.num_nodes > 1 and self.trainer._distrib_type not in (DistributedType.DDP2, DistributedType.DDP):
raise MisconfigurationException(
'DataParallel does not support num_nodes > 1. Switching to DistributedDataParallel for you. '
'To silence this warning set `accelerator="ddp"` or `accelerator="ddp2"`'
@@ -350,20 +357,20 @@ def set_distributed_mode(self):
num_cores = self.trainer.tpu_cores if self.trainer.tpu_cores is not None else 0
rank_zero_info(f'TPU available: {_TPU_AVAILABLE}, using: {num_cores} TPU cores')

if torch.cuda.is_available() and not self.trainer.on_gpu:
if torch.cuda.is_available() and self.trainer._device_type != DeviceType.GPU:
rank_zero_warn('GPU available but not used. Set the --gpus flag when calling the script.')

def _set_horovod_backend(self):
self.check_horovod()
self.trainer.use_horovod = True
self._check_horovod()
self.trainer._distrib_type = DistributedType.HOROVOD

# Initialize Horovod to get rank / size info
hvd.init()
if self.trainer.on_gpu:
# Horovod assigns one local GPU per process
self.trainer.root_gpu = hvd.local_rank()

def check_horovod(self):
def _check_horovod(self):
"""Raises a `MisconfigurationException` if the Trainer is not configured correctly for Horovod."""
if not _HOROVOD_AVAILABLE:
raise MisconfigurationException(
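To summarize the refactor above, the following is a simplified, standalone sketch (not the library code) of how the connector now resolves the device type and distributed type from the requested backend. The enum member names follow the diff; everything else is illustrative.

```python
from enum import Enum


class DeviceType(Enum):
    CPU = "cpu"
    GPU = "gpu"
    TPU = "tpu"


class DistributedType(Enum):
    DP = "dp"
    DDP = "ddp"
    DDP_SPAWN = "ddp_spawn"
    DDP2 = "ddp2"
    HOROVOD = "horovod"


def resolve_mode(backend=None, num_gpus=0, num_nodes=1, num_processes=1):
    """Roughly mirrors `set_distributed_mode` after this PR (simplified)."""
    device, distrib = DeviceType.CPU, None

    if backend is None:
        if num_gpus > 1:
            backend = "ddp_spawn"  # default when multiple GPUs but no backend is given
        elif num_nodes > 1 or num_processes > 1:
            distrib = DistributedType.DDP  # CPU DDP when several nodes/processes are requested

    if backend == "ddp_cpu":
        distrib = DistributedType.DDP  # DDP over CPU processes; requested GPUs are ignored
    elif backend == "tpu":
        device = DeviceType.TPU
    elif backend is not None and distrib is None:
        distrib = DistributedType(backend)  # map the remaining backend strings to the enum

    # unless CPU was requested explicitly, use GPUs when they are available
    if num_gpus > 0 and backend != "ddp_cpu":
        device = DeviceType.GPU

    # DP and DDP2 cannot run without GPUs: fall back to CPU DDP or to no parallelism
    if num_gpus == 0 and distrib in (DistributedType.DP, DistributedType.DDP2):
        distrib = DistributedType.DDP if (num_nodes > 1 or num_processes > 1) else None

    # for DDP/DDP_SPAWN on GPU, the number of processes follows the number of GPUs
    if device == DeviceType.GPU and distrib in (DistributedType.DDP, DistributedType.DDP_SPAWN):
        num_processes = num_gpus

    return device, distrib, num_processes


# e.g. resolve_mode("ddp_cpu", num_processes=2) -> (DeviceType.CPU, DistributedType.DDP, 2)
```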
4 changes: 2 additions & 2 deletions pytorch_lightning/plugins/plugin_connector.py
@@ -31,8 +31,8 @@ def __init__(self, trainer):
self.plugins = []
self.ddp_plugin = DDPPlugin()
self.cloud_environment = None
self.amp_plugin = NativeAMPPlugin(trainer)
self.apex_plugin = ApexPlugin(trainer)
# self.amp_plugin = NativeAMPPlugin(trainer)
# self.apex_plugin = ApexPlugin(trainer)

def on_trainer_init(self, plugins: Optional[Union[str, list]]):
self.plugins = plugins