Commit

sync accelerator connector changes from dev1.2
awaelchli committed Feb 1, 2021
1 parent df0900c commit 3019414
Showing 1 changed file with 93 additions and 81 deletions.
174 changes: 93 additions & 81 deletions pytorch_lightning/accelerators/accelerator_connector.py
@@ -39,7 +39,16 @@
from pytorch_lightning.plugins.environments.slurm_environment import SLURMEnvironment
from pytorch_lightning.plugins.environments.torchelastic_environment import TorchElasticEnvironment
from pytorch_lightning.tuner.auto_gpu_select import pick_multiple_gpus
from pytorch_lightning.utilities import _APEX_AVAILABLE, _NATIVE_AMP_AVAILABLE, AMPType, device_parser, rank_zero_only
from pytorch_lightning.utilities import (
_APEX_AVAILABLE,
_NATIVE_AMP_AVAILABLE,
_TPU_AVAILABLE,
AMPType,
device_parser,
DeviceType,
DistributedType,
rank_zero_only,
)
from pytorch_lightning.utilities.distributed import rank_zero_info, rank_zero_warn
from pytorch_lightning.utilities.exceptions import MisconfigurationException

@@ -77,13 +86,9 @@ def __init__(
amp_level,
cluster_environment,
):

# initialization
self.use_dp = False
self.use_ddp = False
self.use_ddp2 = False
self.use_horovod = False
self.use_single_gpu = False
self._device_type = DeviceType.CPU
self._distrib_type = None

self.num_processes = num_processes
self.tpu_cores = device_parser.parse_tpu_cores(tpu_cores)
@@ -149,6 +154,10 @@ def __init__(

self.replace_sampler_ddp = replace_sampler_ddp

@property
def on_cpu(self):
return self._device_type == DeviceType.CPU

@property
def on_tpu(self):
return self.tpu_cores is not None
@@ -165,6 +174,22 @@ def on_gpu(self):
gpus = self.parallel_device_ids
return gpus is not None and len(gpus) > 0 and torch.cuda.is_available()

@property
def use_dp(self):
return self._distrib_type == DistributedType.DP

@property
def use_ddp(self):
return self._distrib_type in (DistributedType.DDP, DistributedType.DDP_SPAWN)

@property
def use_ddp2(self):
return self._distrib_type == DistributedType.DDP2

@property
def use_horovod(self):
return self._distrib_type == DistributedType.HOROVOD

@property
def num_gpus(self) -> int:
gpus = self.parallel_device_ids
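The `use_dp`/`use_ddp`/`use_ddp2`/`use_horovod` flags introduced above are now read-only views over the single `_distrib_type` enum rather than independently toggled booleans. A minimal sketch of the same membership checks, using the `DistributedType` enum imported at the top of this file (plain variables stand in for the properties; this is an illustration, not part of the commit):

    from pytorch_lightning.utilities import DistributedType

    distrib_type = DistributedType.DDP_SPAWN

    use_dp = distrib_type == DistributedType.DP                                  # False
    use_ddp = distrib_type in (DistributedType.DDP, DistributedType.DDP_SPAWN)   # True
    use_ddp2 = distrib_type == DistributedType.DDP2                              # False
    use_horovod = distrib_type == DistributedType.HOROVOD                        # False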
@@ -236,8 +261,8 @@ def select_training_type_plugin(self):
elif self.use_ddp:
use_slurm_ddp = self.use_ddp and self.is_slurm_managing_tasks
use_torchelastic_ddp = self.use_ddp and self.is_using_torchelastic
use_ddp_spawn = self.use_ddp and self.distributed_backend == "ddp_spawn"
use_ddp_cpu_spawn = self.use_ddp and self.distributed_backend == "ddp_cpu"
use_ddp_spawn = self._distrib_type == DistributedType.DDP_SPAWN
use_ddp_cpu_spawn = self.use_ddp and self.on_cpu
use_ddp_cpu_torch_elastic = use_ddp_cpu_spawn and self.is_using_torchelastic
use_ddp_cpu_slurm = use_ddp_cpu_spawn and self.is_slurm_managing_tasks
# use_ddp_sharded = self.distributed_backend == "ddp_sharded"
@@ -273,11 +298,10 @@ def select_training_type_plugin(self):
plugin = DataParallelPlugin(parallel_devices=self.parallel_devices)
elif self.use_horovod:
plugin = HorovodPlugin(parallel_devices=self.parallel_devices)
elif self.on_tpu:
plugin = SingleTPUPlugin(self.tpu_id)
else:
if self.on_tpu:
plugin = SingleTPUPlugin(self.tpu_id)
else:
plugin = SingleDevicePlugin(device=torch.device(f"cuda:{self.root_gpu}" if self.on_gpu else "cpu"))
plugin = SingleDevicePlugin(device=torch.device(f"cuda:{self.root_gpu}" if self.on_gpu else "cpu"))
return plugin
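The single-device fallback in `select_training_type_plugin` is now a flat branch: the TPU case gets its own `elif` instead of being nested inside the final `else`. A condensed, hedged sketch of that ordering with a hypothetical helper name, using only the conditions visible in the diff:

    import torch

    def _single_device_fallback(on_tpu: bool, on_gpu: bool, tpu_id=None, root_gpu: int = 0):
        # mirrors the reordered branches: TPU first, then one CUDA device, else CPU
        if on_tpu:
            return ("SingleTPUPlugin", tpu_id)
        device = torch.device(f"cuda:{root_gpu}" if on_gpu else "cpu")
        return ("SingleDevicePlugin", device)

    print(_single_device_fallback(on_tpu=False, on_gpu=True))  # ('SingleDevicePlugin', device(type='cuda', index=0))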

def select_accelerator(self):
@@ -287,7 +311,7 @@ def select_accelerator(self):

if self.on_gpu:
acc_cls = GPUAccelerator
elif self.on_gpu:
elif self.on_tpu:
acc_cls = TPUAccelerator
else:
acc_cls = CPUAccelerator
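Note the one-line fix above: the removed branch re-checked `self.on_gpu`, so the TPU accelerator could never be chosen; the added branch checks `self.on_tpu`. A small, self-contained sketch of the corrected precedence (class names returned as strings so the example needs no accelerator imports):

    def pick_accelerator(on_gpu: bool, on_tpu: bool) -> str:
        # GPU wins if available, then TPU, otherwise fall back to CPU
        if on_gpu:
            return "GPUAccelerator"
        elif on_tpu:
            return "TPUAccelerator"
        return "CPUAccelerator"

    assert pick_accelerator(on_gpu=False, on_tpu=True) == "TPUAccelerator"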
@@ -313,96 +337,84 @@ def select_cluster_environment(self):
return env

def set_distributed_mode(self):
# No distributed backend

if self.distributed_backend is None:
# horovod multi GPU
if self.has_horovodrun():
self._set_horovod_backend()

# DDP CPU
elif self.num_gpus == 0:
if self.num_nodes > 1 or self.num_processes > 1:
self.use_ddp = True

# Single GPU
elif self.num_gpus == 1:
self.use_single_gpu = True

# Default: DDP-Spawn
elif self.num_gpus == 0 and (self.num_nodes > 1 or self.num_processes > 1):
self._distrib_type = DistributedType.DDP
elif self.num_gpus > 1:
rank_zero_warn(
"You requested multiple GPUs but did not specify a backend, e.g."
' (distributed_backend="dp"|"ddp"|"ddp2").'
' Setting distributed_backend="ddp_spawn" for you.'
'You requested multiple GPUs but did not specify a backend, e.g.'
' `Trainer(accelerator="dp"|"ddp"|"ddp2")`. Setting `accelerator="ddp_spawn"` for you.'
)
self.distributed_backend = "ddp_spawn"

# DP
if self.distributed_backend == "dp":
# do nothing if num_gpus == 0
if self.num_gpus == 1:
self.use_single_gpu = True
self.use_dp = True
elif self.num_gpus > 1:
self.use_dp = True

# DDP, DDP-Spawn
elif self.distributed_backend in ("ddp", "ddp_spawn"):
if self.num_gpus == 0:
# DDP CPU
if self.num_nodes > 1 or self.num_processes > 1:
self.use_ddp = True

# DDP Single GPU
elif self.num_gpus == 1:
self.use_single_gpu = True
self.use_ddp = True

# DDP Multi GPU
elif self.num_gpus > 1:
self.use_ddp = True
self.num_processes = self.num_gpus

# DDP2
elif self.distributed_backend == "ddp2":
# do nothing if num_gpus == 0
if self.num_gpus >= 1:
self.use_ddp2 = True

# DDP CPU
elif self.distributed_backend == "ddp_cpu":
# special case with DDP on CPUs
if self.distributed_backend == "ddp_cpu":
self._distrib_type = DistributedType.DDP
self.data_parallel_device_ids = None
if self.num_gpus > 0:
rank_zero_warn(
"You requested one or more GPUs, but set the backend to `ddp_cpu`. Training will not use GPUs."
'You requested one or more GPUs, but set the backend to `ddp_cpu`. Training will not use GPUs.'
)
self.parallel_device_ids = None
self.use_ddp = True

# Sharded DDP
elif self.distributed_backend in ("ddp_sharded", "ddp_sharded_spawn"):
self.use_ddp = True

# HOROVOD
elif self.distributed_backend == "horovod":
if self.num_processes is None:
# define the max CPU available
self.num_processes = os.cpu_count()
# special case with TPUs
elif self.distributed_backend == 'tpu':
self._device_type = DeviceType.TPU
# set all other requested distrib. types and, if not already set, derive the type from the backend name
elif self.distributed_backend and self._distrib_type is None:
self._distrib_type = DistributedType(self.distributed_backend)

# unless CPU was explicitly requested, use any available GPUs
_on_cpu = self.distributed_backend and 'cpu' in self.distributed_backend
if (self.num_gpus > 0 and not _on_cpu):
self._device_type = DeviceType.GPU

_distrib_types = (DistributedType.DP, DistributedType.DDP, DistributedType.DDP_SPAWN, DistributedType.DDP2)
# DP and DDP2 cannot run without GPU
if (self.num_gpus == 0 and self._distrib_type in _distrib_types):
rank_zero_warn(
'You requested distributed training on GPUs, but none is available, so we set backend to `ddp_cpu`.'
)
# todo: in some cases this comparison can be between None and an int
if ((self.num_nodes and self.num_nodes > 1) or (self.num_processes and self.num_processes > 1)):
self._distrib_type = DistributedType.DDP
else:
rank_zero_warn('You are running on a single node with no parallelization, so distributed training has no effect.')
self._distrib_type = None

# for DDP, overwrite the number of processes with the number of requested GPUs
if (
self._device_type == DeviceType.GPU
and self._distrib_type in (DistributedType.DDP, DistributedType.DDP_SPAWN)
):
self.num_processes = self.num_gpus

# Horovod is an extra case...
if self.distributed_backend == "horovod":
self._set_horovod_backend()

# throw an error to force the user to choose ddp or ddp2
if self.num_nodes > 1 and not (self.use_ddp2 or self.use_ddp):
_ddp = (DistributedType.DDP, DistributedType.DDP_SPAWN, DistributedType.DDP2)
if (self.num_nodes > 1 and self._distrib_type not in _ddp):
raise MisconfigurationException(
"DataParallel does not support num_nodes > 1. Switching to DistributedDataParallel for you. "
"To silence this warning set distributed_backend=ddp or distributed_backend=ddp2"
'DataParallel does not support num_nodes > 1. Switching to DistributedDataParallel for you. '
'To silence this warning set `accelerator="ddp"` or `accelerator="ddp2"`'
)

rank_zero_info(f"GPU available: {torch.cuda.is_available()}, used: {self.on_gpu}")
rank_zero_info(f'GPU available: {torch.cuda.is_available()}, used: {self._device_type == DeviceType.GPU}')
num_cores = self.tpu_cores if self.tpu_cores is not None else 0
rank_zero_info(f"TPU available: {XLA_AVAILABLE}, using: {num_cores} TPU cores")
rank_zero_info(f'TPU available: {_TPU_AVAILABLE}, using: {num_cores} TPU cores')

if torch.cuda.is_available() and not self.on_gpu:
if torch.cuda.is_available() and self._device_type != DeviceType.GPU:
rank_zero_warn("GPU available but not used. Set the --gpus flag when calling the script.")
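To make the rewritten inference concrete, a hedged sketch of what `set_distributed_mode` is expected to derive for a few common argument combinations, read off the branches above rather than executed against this commit (it assumes the requested GPUs are actually visible):

    from pytorch_lightning.utilities import DeviceType, DistributedType

    # (accelerator, gpus, num_processes) -> expected (_device_type, _distrib_type)
    expected = {
        (None, 2, 1): (DeviceType.GPU, DistributedType.DDP_SPAWN),    # warns, falls back to ddp_spawn
        ("dp", 2, 1): (DeviceType.GPU, DistributedType.DP),
        ("ddp_cpu", 0, 2): (DeviceType.CPU, DistributedType.DDP),     # DDP over CPU processes
    }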

def _set_horovod_backend(self):
self.check_horovod()
self.use_horovod = True
self._distrib_type = DistributedType.HOROVOD

# Initialize Horovod to get rank / size info
hvd.init()
