Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

enable passing in custom accelerators #4050

Merged
merged 4 commits into from
Oct 10, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions pytorch_lightning/accelerators/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,3 +11,4 @@
from pytorch_lightning.accelerators.ddp_torchelastic_backend import DDPTorchElasticBackend
from pytorch_lightning.accelerators.ddp_cpu_torchelastic_backend import DDPCPUTorchElasticBackend
from pytorch_lightning.accelerators.ddp_cpu_slurm_backend import DDPCPUSLURMBackend
from pytorch_lightning.accelerators.base_accelerator import Accelerator
25 changes: 24 additions & 1 deletion pytorch_lightning/accelerators/accelerator_connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from pytorch_lightning import _logger as log
from pytorch_lightning.cluster_environments.slurm_environment import SLURMEnvironment
from pytorch_lightning.cluster_environments.torchelastic_environment import TorchElasticEnvironment
from pytorch_lightning.accelerators.base_accelerator import Accelerator

try:
import torch_xla
Expand All @@ -29,11 +30,13 @@ class AcceleratorConnector:

def __init__(self, trainer):
self.trainer = trainer
self.accelerator = None

def on_trainer_init(
self,
num_processes,
tpu_cores,
accelerator,
distributed_backend,
auto_select_gpus,
gpus,
Expand All @@ -44,6 +47,15 @@ def on_trainer_init(
replace_sampler_ddp,
deterministic,
):
# temporary mapping until we remove all the distributed_backend references
if accelerator is not None:
self.accelerator = accelerator
if isinstance(accelerator, Accelerator):
self.accelerator.trainer = self
distributed_backend = self.accelerator.nickname
else:
distributed_backend = accelerator

self.trainer.deterministic = deterministic

torch.backends.cudnn.deterministic = self.trainer.deterministic
Expand Down Expand Up @@ -145,7 +157,18 @@ def select_accelerator(self):
if self.trainer.accelerator_backend is not None:
return self.trainer.accelerator_backend

# SLURM ddp
# ----------------------------------
# Use the user provided accelerator
# ----------------------------------
# use the one the user passed in
if self.accelerator is not None and isinstance(self.accelerator, Accelerator):
self.accelerator.trainer = self.trainer
acc = self.accelerator
return acc

# ----------------------------------
# choose an accelerator for the user
# ----------------------------------
use_slurm_ddp = self.trainer.use_ddp and self.trainer.is_slurm_managing_tasks

# torchelastic or general non_slurm ddp
Expand Down
11 changes: 7 additions & 4 deletions pytorch_lightning/accelerators/base_accelerator.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,13 +23,16 @@

class Accelerator(object):

def __init__(self, trainer, cluster_environment=None):
def __init__(self, trainer=None, cluster_environment=None):
self.trainer = trainer
self.nickname = None
self.cluster_environment = cluster_environment
self.dist = AttributeDict(rank=0, device=None)
self.train_loop = self.trainer.train
self.validation_loop = self.trainer.run_evaluation
self.test_loop = self.trainer.run_evaluation

if trainer is not None:
self.train_loop = self.trainer.train
self.validation_loop = self.trainer.run_evaluation
self.test_loop = self.trainer.run_evaluation

def setup(self, model):
pass
Expand Down
1 change: 1 addition & 0 deletions pytorch_lightning/accelerators/cpu_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ class CPUBackend(Accelerator):

def __init__(self, trainer, cluster_environment=None):
super().__init__(trainer, cluster_environment)
self.nickname = None

def setup(self, model):
# run through amp wrapper
Expand Down
1 change: 1 addition & 0 deletions pytorch_lightning/accelerators/ddp2_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ def __init__(self, trainer, cluster_environment=None):
super().__init__(trainer, cluster_environment)
self.task_idx = None
self.dist = LightningDistributed()
self.nickname = 'ddp2'

def setup(self, model):
self._resolve_task_idx()
Expand Down
1 change: 1 addition & 0 deletions pytorch_lightning/accelerators/ddp_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ def __init__(self, trainer, cluster_environment=None):
self._has_spawned_children = False
self.interactive_ddp_procs = []
self.dist = LightningDistributed()
self.nickname = 'ddp'

def setup(self, model):
# first track model
Expand Down
1 change: 1 addition & 0 deletions pytorch_lightning/accelerators/ddp_cpu_slurm_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ def __init__(self, trainer, cluster_environment=None):
self.task_idx = None
self._has_spawned_children = False
self.dist = LightningDistributed()
self.nickname = 'ddp_cpu'

def setup(self, model):
self.trainer.model = model
Expand Down
1 change: 1 addition & 0 deletions pytorch_lightning/accelerators/ddp_cpu_spawn_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ def __init__(self, trainer, nprocs, cluster_environment=None):
self.mp_queue = None
self.nprocs = nprocs
self.dist = LightningDistributed()
self.nickname = 'ddp_cpu'

def setup(self, model):
os.environ['MASTER_PORT'] = os.environ.get('MASTER_PORT', str(find_free_network_port()))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ def __init__(self, trainer, cluster_environment=None):
self.task_idx = None
self._has_spawned_children = False
self.dist = LightningDistributed()
self.nickname = 'ddp_cpu'

def setup(self, model):
self.trainer.model = model
Expand Down
1 change: 1 addition & 0 deletions pytorch_lightning/accelerators/ddp_slurm_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ def __init__(self, trainer, cluster_environment=None):
self.task_idx = None
self._has_spawned_children = False
self.dist = LightningDistributed()
self.nickname = 'ddp'

def setup(self, model):
self.trainer.model = model
Expand Down
1 change: 1 addition & 0 deletions pytorch_lightning/accelerators/ddp_spawn_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ def __init__(self, trainer, nprocs, cluster_environment=None):
self.mp_queue = None
self.nprocs = nprocs
self.dist = LightningDistributed()
self.nickname = 'ddp'

def setup(self, model):
os.environ['MASTER_PORT'] = os.environ.get('MASTER_PORT', str(find_free_network_port()))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ def __init__(self, trainer, cluster_environment=None):
self.task_idx = None
self._has_spawned_children = False
self.dist = LightningDistributed()
self.nickname = 'ddp'

def setup(self, model):
self.trainer.model = model
Expand Down
1 change: 1 addition & 0 deletions pytorch_lightning/accelerators/dp_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ def __init__(self, trainer, cluster_environment=None):
super().__init__(trainer, cluster_environment)
self.model_autocast_original_forward = None
self.dist = LightningDistributed()
self.nickname = 'dp'

def setup(self, model):
# call setup after the ddp process has connected
Expand Down
1 change: 1 addition & 0 deletions pytorch_lightning/accelerators/gpu_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ class GPUBackend(Accelerator):
def __init__(self, trainer, cluster_environment=None):
super().__init__(trainer, cluster_environment)
self.dist = LightningDistributed()
self.nickname = None

def setup(self, model):

Expand Down
1 change: 1 addition & 0 deletions pytorch_lightning/accelerators/horovod_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ class HorovodBackend(Accelerator):

def __init__(self, trainer, cluster_environment=None):
super().__init__(trainer, cluster_environment)
self.nickname = 'horovod'

def setup(self, model):
# call setup after the ddp process has connected
Expand Down
1 change: 1 addition & 0 deletions pytorch_lightning/accelerators/tpu_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ def __init__(self, trainer, cluster_environment=None):
super().__init__(trainer, cluster_environment)
self.start_method = None
self.mp_queue = None
self.nickname = None

def setup(self, model):
rank_zero_info(f'training on {self.trainer.tpu_cores} TPU cores')
Expand Down
93 changes: 52 additions & 41 deletions pytorch_lightning/trainer/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,6 +165,57 @@ def forward(self, x):
Trainer flags
-------------

accelerator
^^^^^^^^^^^

.. raw:: html

<video width="50%" max-width="400px" controls
poster="https://pl-bolts-doc-images.s3.us-east-2.amazonaws.com/pl_docs/trainer_flags/thumb/distributed_backend.jpg"
src="https://pl-bolts-doc-images.s3.us-east-2.amazonaws.com/pl_docs/trainer_flags/distributed_backend.mp4"></video>

|

The accelerator backend to use (previously known as distributed_backend).

- (```dp```) is DataParallel (split batch among GPUs of same machine)
- (```ddp```) is DistributedDataParallel (each gpu on each node trains, and syncs grads)
- (```ddp_cpu```) is DistributedDataParallel on CPU (same as `ddp`, but does not use GPUs.
Useful for multi-node CPU training or single-node debugging. Note that this will **not** give
a speedup on a single node, since Torch already makes efficient use of multiple CPUs on a single
machine.)
- (```ddp2```) dp on node, ddp across nodes. Useful for things like increasing
the number of negative samples

.. testcode::

# default used by the Trainer
trainer = Trainer(accelerator=None)

Example::

# dp = DataParallel
trainer = Trainer(gpus=2, accelerator='dp')

# ddp = DistributedDataParallel
trainer = Trainer(gpus=2, num_nodes=2, accelerator='ddp')

# ddp2 = DistributedDataParallel + dp
trainer = Trainer(gpus=2, num_nodes=2, accelerator='ddp2')

.. note:: this option does not apply to TPU. TPUs use ```ddp``` by default (over each core)

You can also modify hardware behavior by subclassing an existing accelerator to adjust for your needs.

Example::

class MyOwnDDP(DDPBackend):
...

Trainer(accelerator=MyOwnDDP())

.. warning:: Passing in custom accelerators is experimental but work is in progress to enable full compatibility.

accumulate_grad_batches
^^^^^^^^^^^^^^^^^^^^^^^

Expand Down Expand Up @@ -486,47 +537,7 @@ def on_train_end(self, trainer, pl_module):

distributed_backend
^^^^^^^^^^^^^^^^^^^

.. raw:: html

<video width="50%" max-width="400px" controls
poster="https://pl-bolts-doc-images.s3.us-east-2.amazonaws.com/pl_docs/trainer_flags/thumb/distributed_backend.jpg"
src="https://pl-bolts-doc-images.s3.us-east-2.amazonaws.com/pl_docs/trainer_flags/distributed_backend.mp4"></video>

|

The distributed backend to use.

- (```dp```) is DataParallel (split batch among GPUs of same machine)
- (```ddp```) is DistributedDataParallel (each gpu on each node trains, and syncs grads)
- (```ddp_cpu```) is DistributedDataParallel on CPU (same as `ddp`, but does not use GPUs.
Useful for multi-node CPU training or single-node debugging. Note that this will **not** give
a speedup on a single node, since Torch already makes efficient use of multiple CPUs on a single
machine.)
- (```ddp2```) dp on node, ddp across nodes. Useful for things like increasing
the number of negative samples

.. testcode::

# default used by the Trainer
trainer = Trainer(distributed_backend=None)

Example::

# dp = DataParallel
trainer = Trainer(gpus=2, distributed_backend='dp')

# ddp = DistributedDataParallel
trainer = Trainer(gpus=2, num_nodes=2, distributed_backend='ddp')

# ddp2 = DistributedDataParallel + dp
trainer = Trainer(gpus=2, num_nodes=2, distributed_backend='ddp2')

.. note:: this option does not apply to TPU. TPUs use ```ddp``` by default (over each core)

See Also:
- :ref:`Multi-GPU training guide <multi_gpu>`.
- :ref:`Multi-node (SLURM) guide <slurm>`.
This has been renamed "accelerator".

fast_dev_run
^^^^^^^^^^^^
Expand Down
11 changes: 9 additions & 2 deletions pytorch_lightning/trainer/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,8 @@
from pytorch_lightning.utilities.model_utils import is_overridden
from pytorch_lightning.trainer.properties import TrainerProperties
from pytorch_lightning.plugins.plugin_connector import PluginConnector
from pytorch_lightning.accelerators.base_accelerator import Accelerator
from pytorch_lightning.accelerators.cpu_backend import CPUBackend

# warnings to ignore in trainer
warnings.filterwarnings(
Expand Down Expand Up @@ -111,7 +113,7 @@ def __init__(
val_check_interval: Union[int, float] = 1.0,
flush_logs_every_n_steps: int = 100,
log_every_n_steps: int = 50,
distributed_backend: Optional[str] = None,
accelerator: Optional[Union[str, Accelerator]] = None,
sync_batchnorm: bool = False,
precision: int = 32,
weights_summary: Optional[str] = 'top',
Expand All @@ -131,12 +133,16 @@ def __init__(
plugins: list = None,
amp_backend: str = 'native',
amp_level: str = 'O2',
distributed_backend: Optional[str] = None,
):
r"""
Customize every aspect of training via flags

Args:

accelerator: Previously known as distributed_backend (dp, ddp, ddp2, etc...).
Can also take in an accelerator object for custom hardware.

accumulate_grad_batches: Accumulates grads every k batches or as set up in the dict.

amp_backend: The mixed precision backend to use ("native" or "apex")
Expand Down Expand Up @@ -173,7 +179,7 @@ def __init__(

deterministic: If true enables cudnn.deterministic.

distributed_backend: The distributed backend to use (dp, ddp, ddp2, ddp_spawn, ddp_cpu)
distributed_backend: deprecated. Please use 'accelerator'

fast_dev_run: runs 1 batch of train, test and val to find any bugs (ie: a sort of unit test).

Expand Down Expand Up @@ -318,6 +324,7 @@ def __init__(
self.accelerator_connector.on_trainer_init(
num_processes,
tpu_cores,
accelerator,
distributed_backend,
auto_select_gpus,
gpus,
Expand Down
8 changes: 5 additions & 3 deletions pytorch_lightning/utilities/argparse_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,9 +105,11 @@ def get_init_arguments_and_types(cls) -> List[Tuple[str, Tuple, Any]]:
>>> args = get_init_arguments_and_types(Trainer)
>>> import pprint
>>> pprint.pprint(sorted(args)) # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE
[('accumulate_grad_batches',
(<class 'int'>, typing.Dict[int, int], typing.List[list]),
1),
[('accelerator',
(<class 'str'>,
<class 'pytorch_lightning.accelerators.base_accelerator.Accelerator'>,
<class 'NoneType'>),
None),
...
('callbacks',
(typing.List[pytorch_lightning.callbacks.base.Callback],
Expand Down
Loading