diff --git a/pytorch_lightning/accelerators/__init__.py b/pytorch_lightning/accelerators/__init__.py
index c0ff298b6c7c6..a3676f2aea2e3 100644
--- a/pytorch_lightning/accelerators/__init__.py
+++ b/pytorch_lightning/accelerators/__init__.py
@@ -11,3 +11,4 @@
from pytorch_lightning.accelerators.ddp_torchelastic_backend import DDPTorchElasticBackend
from pytorch_lightning.accelerators.ddp_cpu_torchelastic_backend import DDPCPUTorchElasticBackend
from pytorch_lightning.accelerators.ddp_cpu_slurm_backend import DDPCPUSLURMBackend
+from pytorch_lightning.accelerators.base_accelerator import Accelerator
diff --git a/pytorch_lightning/accelerators/accelerator_connector.py b/pytorch_lightning/accelerators/accelerator_connector.py
index 4965134501597..02deb281ee5ce 100644
--- a/pytorch_lightning/accelerators/accelerator_connector.py
+++ b/pytorch_lightning/accelerators/accelerator_connector.py
@@ -9,6 +9,7 @@
from pytorch_lightning import _logger as log
from pytorch_lightning.cluster_environments.slurm_environment import SLURMEnvironment
from pytorch_lightning.cluster_environments.torchelastic_environment import TorchElasticEnvironment
+from pytorch_lightning.accelerators.base_accelerator import Accelerator
try:
import torch_xla
@@ -29,11 +30,13 @@ class AcceleratorConnector:
def __init__(self, trainer):
self.trainer = trainer
+ self.accelerator = None
def on_trainer_init(
self,
num_processes,
tpu_cores,
+ accelerator,
distributed_backend,
auto_select_gpus,
gpus,
@@ -44,6 +47,15 @@ def on_trainer_init(
replace_sampler_ddp,
deterministic,
):
+ # temporary mapping until we remove all the distributed_backend references
+ if accelerator is not None:
+ self.accelerator = accelerator
+ if isinstance(accelerator, Accelerator):
+ self.accelerator.trainer = self
+ distributed_backend = self.accelerator.nickname
+ else:
+ distributed_backend = accelerator
+
self.trainer.deterministic = deterministic
torch.backends.cudnn.deterministic = self.trainer.deterministic
@@ -145,7 +157,18 @@ def select_accelerator(self):
if self.trainer.accelerator_backend is not None:
return self.trainer.accelerator_backend
- # SLURM ddp
+ # ----------------------------------
+ # Use the user provided accelerator
+ # ----------------------------------
+ # use the one the user passed in
+ if self.accelerator is not None and isinstance(self.accelerator, Accelerator):
+ self.accelerator.trainer = self.trainer
+ acc = self.accelerator
+ return acc
+
+ # ----------------------------------
+ # choose an accelerator for the user
+ # ----------------------------------
use_slurm_ddp = self.trainer.use_ddp and self.trainer.is_slurm_managing_tasks
# torchelastic or general non_slurm ddp
diff --git a/pytorch_lightning/accelerators/base_accelerator.py b/pytorch_lightning/accelerators/base_accelerator.py
index 7c5d4c1216543..59a441040d579 100644
--- a/pytorch_lightning/accelerators/base_accelerator.py
+++ b/pytorch_lightning/accelerators/base_accelerator.py
@@ -23,13 +23,16 @@
class Accelerator(object):
- def __init__(self, trainer, cluster_environment=None):
+ def __init__(self, trainer=None, cluster_environment=None):
self.trainer = trainer
+ self.nickname = None
self.cluster_environment = cluster_environment
self.dist = AttributeDict(rank=0, device=None)
- self.train_loop = self.trainer.train
- self.validation_loop = self.trainer.run_evaluation
- self.test_loop = self.trainer.run_evaluation
+
+ if trainer is not None:
+ self.train_loop = self.trainer.train
+ self.validation_loop = self.trainer.run_evaluation
+ self.test_loop = self.trainer.run_evaluation
def setup(self, model):
pass
diff --git a/pytorch_lightning/accelerators/cpu_backend.py b/pytorch_lightning/accelerators/cpu_backend.py
index b615cc9a47bee..77402fd893264 100644
--- a/pytorch_lightning/accelerators/cpu_backend.py
+++ b/pytorch_lightning/accelerators/cpu_backend.py
@@ -22,6 +22,7 @@ class CPUBackend(Accelerator):
def __init__(self, trainer, cluster_environment=None):
super().__init__(trainer, cluster_environment)
+ self.nickname = None
def setup(self, model):
# run through amp wrapper
diff --git a/pytorch_lightning/accelerators/ddp2_backend.py b/pytorch_lightning/accelerators/ddp2_backend.py
index ca64f64165ba1..62f01020817c2 100644
--- a/pytorch_lightning/accelerators/ddp2_backend.py
+++ b/pytorch_lightning/accelerators/ddp2_backend.py
@@ -42,6 +42,7 @@ def __init__(self, trainer, cluster_environment=None):
super().__init__(trainer, cluster_environment)
self.task_idx = None
self.dist = LightningDistributed()
+ self.nickname = 'ddp2'
def setup(self, model):
self._resolve_task_idx()
diff --git a/pytorch_lightning/accelerators/ddp_backend.py b/pytorch_lightning/accelerators/ddp_backend.py
index 83fabce1b975e..25007355b1b06 100644
--- a/pytorch_lightning/accelerators/ddp_backend.py
+++ b/pytorch_lightning/accelerators/ddp_backend.py
@@ -52,6 +52,7 @@ def __init__(self, trainer, cluster_environment=None):
self._has_spawned_children = False
self.interactive_ddp_procs = []
self.dist = LightningDistributed()
+ self.nickname = 'ddp'
def setup(self, model):
# first track model
diff --git a/pytorch_lightning/accelerators/ddp_cpu_slurm_backend.py b/pytorch_lightning/accelerators/ddp_cpu_slurm_backend.py
index c62e12298f247..a2f7b049005b8 100644
--- a/pytorch_lightning/accelerators/ddp_cpu_slurm_backend.py
+++ b/pytorch_lightning/accelerators/ddp_cpu_slurm_backend.py
@@ -48,6 +48,7 @@ def __init__(self, trainer, cluster_environment=None):
self.task_idx = None
self._has_spawned_children = False
self.dist = LightningDistributed()
+ self.nickname = 'ddp_cpu'
def setup(self, model):
self.trainer.model = model
diff --git a/pytorch_lightning/accelerators/ddp_cpu_spawn_backend.py b/pytorch_lightning/accelerators/ddp_cpu_spawn_backend.py
index b0ab1f396984b..bd5ac410ad60c 100644
--- a/pytorch_lightning/accelerators/ddp_cpu_spawn_backend.py
+++ b/pytorch_lightning/accelerators/ddp_cpu_spawn_backend.py
@@ -46,6 +46,7 @@ def __init__(self, trainer, nprocs, cluster_environment=None):
self.mp_queue = None
self.nprocs = nprocs
self.dist = LightningDistributed()
+ self.nickname = 'ddp_cpu'
def setup(self, model):
os.environ['MASTER_PORT'] = os.environ.get('MASTER_PORT', str(find_free_network_port()))
diff --git a/pytorch_lightning/accelerators/ddp_cpu_torchelastic_backend.py b/pytorch_lightning/accelerators/ddp_cpu_torchelastic_backend.py
index 472d79fd033d8..51247d7b5abe9 100644
--- a/pytorch_lightning/accelerators/ddp_cpu_torchelastic_backend.py
+++ b/pytorch_lightning/accelerators/ddp_cpu_torchelastic_backend.py
@@ -47,6 +47,7 @@ def __init__(self, trainer, cluster_environment=None):
self.task_idx = None
self._has_spawned_children = False
self.dist = LightningDistributed()
+ self.nickname = 'ddp_cpu'
def setup(self, model):
self.trainer.model = model
diff --git a/pytorch_lightning/accelerators/ddp_slurm_backend.py b/pytorch_lightning/accelerators/ddp_slurm_backend.py
index b45e56ca5c79d..7b2f6b9e3d7e8 100644
--- a/pytorch_lightning/accelerators/ddp_slurm_backend.py
+++ b/pytorch_lightning/accelerators/ddp_slurm_backend.py
@@ -47,6 +47,7 @@ def __init__(self, trainer, cluster_environment=None):
self.task_idx = None
self._has_spawned_children = False
self.dist = LightningDistributed()
+ self.nickname = 'ddp'
def setup(self, model):
self.trainer.model = model
diff --git a/pytorch_lightning/accelerators/ddp_spawn_backend.py b/pytorch_lightning/accelerators/ddp_spawn_backend.py
index 6d6c249fcae08..04202373b9df9 100644
--- a/pytorch_lightning/accelerators/ddp_spawn_backend.py
+++ b/pytorch_lightning/accelerators/ddp_spawn_backend.py
@@ -47,6 +47,7 @@ def __init__(self, trainer, nprocs, cluster_environment=None):
self.mp_queue = None
self.nprocs = nprocs
self.dist = LightningDistributed()
+ self.nickname = 'ddp'
def setup(self, model):
os.environ['MASTER_PORT'] = os.environ.get('MASTER_PORT', str(find_free_network_port()))
diff --git a/pytorch_lightning/accelerators/ddp_torchelastic_backend.py b/pytorch_lightning/accelerators/ddp_torchelastic_backend.py
index 6fff9621c7f0f..e32d2e18c0824 100644
--- a/pytorch_lightning/accelerators/ddp_torchelastic_backend.py
+++ b/pytorch_lightning/accelerators/ddp_torchelastic_backend.py
@@ -48,6 +48,7 @@ def __init__(self, trainer, cluster_environment=None):
self.task_idx = None
self._has_spawned_children = False
self.dist = LightningDistributed()
+ self.nickname = 'ddp'
def setup(self, model):
self.trainer.model = model
diff --git a/pytorch_lightning/accelerators/dp_backend.py b/pytorch_lightning/accelerators/dp_backend.py
index 8fb1455fcec9d..670719d1c2d46 100644
--- a/pytorch_lightning/accelerators/dp_backend.py
+++ b/pytorch_lightning/accelerators/dp_backend.py
@@ -29,6 +29,7 @@ def __init__(self, trainer, cluster_environment=None):
super().__init__(trainer, cluster_environment)
self.model_autocast_original_forward = None
self.dist = LightningDistributed()
+ self.nickname = 'dp'
def setup(self, model):
# call setup after the ddp process has connected
diff --git a/pytorch_lightning/accelerators/gpu_backend.py b/pytorch_lightning/accelerators/gpu_backend.py
index 71d938413751f..c253b34823cbd 100644
--- a/pytorch_lightning/accelerators/gpu_backend.py
+++ b/pytorch_lightning/accelerators/gpu_backend.py
@@ -25,6 +25,7 @@ class GPUBackend(Accelerator):
def __init__(self, trainer, cluster_environment=None):
super().__init__(trainer, cluster_environment)
self.dist = LightningDistributed()
+ self.nickname = None
def setup(self, model):
diff --git a/pytorch_lightning/accelerators/horovod_backend.py b/pytorch_lightning/accelerators/horovod_backend.py
index bbd27cc547eba..be52453d894ff 100644
--- a/pytorch_lightning/accelerators/horovod_backend.py
+++ b/pytorch_lightning/accelerators/horovod_backend.py
@@ -33,6 +33,7 @@ class HorovodBackend(Accelerator):
def __init__(self, trainer, cluster_environment=None):
super().__init__(trainer, cluster_environment)
+ self.nickname = 'horovod'
def setup(self, model):
# call setup after the ddp process has connected
diff --git a/pytorch_lightning/accelerators/tpu_backend.py b/pytorch_lightning/accelerators/tpu_backend.py
index d48f2f6370d1c..95d0d78bb8660 100644
--- a/pytorch_lightning/accelerators/tpu_backend.py
+++ b/pytorch_lightning/accelerators/tpu_backend.py
@@ -42,6 +42,7 @@ def __init__(self, trainer, cluster_environment=None):
super().__init__(trainer, cluster_environment)
self.start_method = None
self.mp_queue = None
+ self.nickname = None
def setup(self, model):
rank_zero_info(f'training on {self.trainer.tpu_cores} TPU cores')
diff --git a/pytorch_lightning/trainer/__init__.py b/pytorch_lightning/trainer/__init__.py
index ec081a4162032..caf2e85e034f8 100644
--- a/pytorch_lightning/trainer/__init__.py
+++ b/pytorch_lightning/trainer/__init__.py
@@ -165,6 +165,57 @@ def forward(self, x):
Trainer flags
-------------
+accelerator
+^^^^^^^^^^^
+
+.. raw:: html
+
+
+
+|
+
+The accelerator backend to use (previously known as distributed_backend).
+
+- (```dp```) is DataParallel (split batch among GPUs of same machine)
+- (```ddp```) is DistributedDataParallel (each gpu on each node trains, and syncs grads)
+- (```ddp_cpu```) is DistributedDataParallel on CPU (same as `ddp`, but does not use GPUs.
+ Useful for multi-node CPU training or single-node debugging. Note that this will **not** give
+  a speedup on a single node, since Torch already makes efficient use of multiple CPUs on a single
+ machine.)
+- (```ddp2```) dp on node, ddp across nodes. Useful for things like increasing
+ the number of negative samples
+
+.. testcode::
+
+    # default used by the Trainer
+    trainer = Trainer(accelerator=None)
+
+Example::
+
+    # dp = DataParallel
+    trainer = Trainer(gpus=2, accelerator='dp')
+
+    # ddp = DistributedDataParallel
+    trainer = Trainer(gpus=2, num_nodes=2, accelerator='ddp')
+
+    # ddp2 = DistributedDataParallel + dp
+    trainer = Trainer(gpus=2, num_nodes=2, accelerator='ddp2')
+
+.. note:: this option does not apply to TPU. TPUs use ```ddp``` by default (over each core)
+
+You can also modify hardware behavior by subclassing an existing accelerator to adjust for your needs.
+
+Example::
+
+ class MyOwnDDP(DDPBackend):
+ ...
+
+ Trainer(accelerator=MyOwnDDP())
+
+.. warning:: Passing in custom accelerators is experimental but work is in progress to enable full compatibility.
+
accumulate_grad_batches
^^^^^^^^^^^^^^^^^^^^^^^
@@ -486,47 +537,7 @@ def on_train_end(self, trainer, pl_module):
distributed_backend
^^^^^^^^^^^^^^^^^^^
-
-.. raw:: html
-
-
-
-|
-
-The distributed backend to use.
-
-- (```dp```) is DataParallel (split batch among GPUs of same machine)
-- (```ddp```) is DistributedDataParallel (each gpu on each node trains, and syncs grads)
-- (```ddp_cpu```) is DistributedDataParallel on CPU (same as `ddp`, but does not use GPUs.
- Useful for multi-node CPU training or single-node debugging. Note that this will **not** give
- a speedup on a single node, since Torch already makes effient use of multiple CPUs on a single
- machine.)
-- (```ddp2```) dp on node, ddp across nodes. Useful for things like increasing
- the number of negative samples
-
-.. testcode::
-
- # default used by the Trainer
- trainer = Trainer(distributed_backend=None)
-
-Example::
-
- # dp = DataParallel
- trainer = Trainer(gpus=2, distributed_backend='dp')
-
- # ddp = DistributedDataParallel
- trainer = Trainer(gpus=2, num_nodes=2, distributed_backend='ddp')
-
- # ddp2 = DistributedDataParallel + dp
- trainer = Trainer(gpus=2, num_nodes=2, distributed_backend='ddp2')
-
-.. note:: this option does not apply to TPU. TPUs use ```ddp``` by default (over each core)
-
-See Also:
-    - :ref:`Multi-GPU training guide <multi_gpu>`.
-    - :ref:`Multi-node (SLURM) guide <slurm>`.
+This has been renamed "accelerator".
fast_dev_run
^^^^^^^^^^^^
diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py
index 47e3286fe78fd..bac58e96ddccd 100644
--- a/pytorch_lightning/trainer/trainer.py
+++ b/pytorch_lightning/trainer/trainer.py
@@ -58,6 +58,8 @@
from pytorch_lightning.utilities.model_utils import is_overridden
from pytorch_lightning.trainer.properties import TrainerProperties
from pytorch_lightning.plugins.plugin_connector import PluginConnector
+from pytorch_lightning.accelerators.base_accelerator import Accelerator
+from pytorch_lightning.accelerators.cpu_backend import CPUBackend
# warnings to ignore in trainer
warnings.filterwarnings(
@@ -111,7 +113,7 @@ def __init__(
val_check_interval: Union[int, float] = 1.0,
flush_logs_every_n_steps: int = 100,
log_every_n_steps: int = 50,
- distributed_backend: Optional[str] = None,
+ accelerator: Optional[Union[str, Accelerator]] = None,
sync_batchnorm: bool = False,
precision: int = 32,
weights_summary: Optional[str] = 'top',
@@ -131,12 +133,16 @@ def __init__(
plugins: list = None,
amp_backend: str = 'native',
amp_level: str = 'O2',
+ distributed_backend: Optional[str] = None,
):
r"""
Customize every aspect of training via flags
Args:
+ accelerator: Previously known as distributed_backend (dp, ddp, ddp2, etc...).
+ Can also take in an accelerator object for custom hardware.
+
accumulate_grad_batches: Accumulates grads every k batches or as set up in the dict.
amp_backend: The mixed precision backend to use ("native" or "apex")
@@ -173,7 +179,7 @@ def __init__(
deterministic: If true enables cudnn.deterministic.
- distributed_backend: The distributed backend to use (dp, ddp, ddp2, ddp_spawn, ddp_cpu)
+ distributed_backend: deprecated. Please use 'accelerator'
fast_dev_run: runs 1 batch of train, test and val to find any bugs (ie: a sort of unit test).
@@ -318,6 +324,7 @@ def __init__(
self.accelerator_connector.on_trainer_init(
num_processes,
tpu_cores,
+ accelerator,
distributed_backend,
auto_select_gpus,
gpus,
diff --git a/pytorch_lightning/utilities/argparse_utils.py b/pytorch_lightning/utilities/argparse_utils.py
index f2626959bb073..79ba5507bb131 100644
--- a/pytorch_lightning/utilities/argparse_utils.py
+++ b/pytorch_lightning/utilities/argparse_utils.py
@@ -105,9 +105,11 @@ def get_init_arguments_and_types(cls) -> List[Tuple[str, Tuple, Any]]:
>>> args = get_init_arguments_and_types(Trainer)
>>> import pprint
>>> pprint.pprint(sorted(args)) # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE
-    [('accumulate_grad_batches',
-      (<class 'int'>, typing.Dict[int, int], typing.List[list]),
-      1),
+    [('accelerator',
+      (<class 'str'>,
+       <class 'pytorch_lightning.accelerators.base_accelerator.Accelerator'>,
+       <class 'NoneType'>),
+      None),
...
('callbacks',
(typing.List[pytorch_lightning.callbacks.base.Callback],
diff --git a/tests/backends/test_accelerator_connector.py b/tests/backends/test_accelerator_connector.py
index 99fd232aa369e..4ae711d4fdca7 100644
--- a/tests/backends/test_accelerator_connector.py
+++ b/tests/backends/test_accelerator_connector.py
@@ -18,6 +18,7 @@
from pytorch_lightning.callbacks import Callback
from pytorch_lightning import accelerators, Trainer
from pytorch_lightning.cluster_environments import SLURMEnvironment, TorchElasticEnvironment, ClusterEnvironment
+from pytorch_lightning.accelerators import Accelerator
from unittest import mock
@@ -297,3 +298,61 @@ def on_fit_start(self, trainer, pl_module):
with pytest.raises(SystemExit):
trainer.fit(model)
+
+
+@mock.patch.dict(os.environ, {
+ "SLURM_NTASKS": "1",
+ "SLURM_JOB_NAME": "SOME_NAME",
+ "SLURM_NODEID": "0",
+ "LOCAL_RANK": "0",
+ "SLURM_LOCALID": "0"
+})
+@mock.patch('torch.cuda.device_count', return_value=0)
+def test_custom_accelerator(tmpdir):
+ class Accel(Accelerator):
+ def init_ddp_connection(
+ self, global_rank: int, world_size: int, is_slurm_managing_tasks: bool = True
+ ) -> None:
+ pass
+
+ class CB(Callback):
+ def on_fit_start(self, trainer, pl_module):
+ assert isinstance(trainer.accelerator_backend, Accel)
+ raise SystemExit()
+
+ model = BoringModel()
+ trainer = Trainer(
+ fast_dev_run=True,
+ accelerator=Accel(),
+ num_processes=1,
+ callbacks=[CB()]
+ )
+
+ with pytest.raises(SystemExit):
+ trainer.fit(model)
+
+
+@mock.patch.dict(os.environ, {
+ "SLURM_NTASKS": "1",
+ "SLURM_JOB_NAME": "SOME_NAME",
+ "SLURM_NODEID": "0",
+ "LOCAL_RANK": "0",
+ "SLURM_LOCALID": "0"
+})
+@mock.patch('torch.cuda.device_count', return_value=0)
+def test_dist_backend_accelerator_mapping(tmpdir):
+ class CB(Callback):
+ def on_fit_start(self, trainer, pl_module):
+ assert isinstance(trainer.accelerator_backend, accelerators.DDPCPUSLURMBackend)
+ raise SystemExit()
+
+ model = BoringModel()
+ trainer = Trainer(
+ fast_dev_run=True,
+ accelerator='ddp_cpu',
+ num_processes=1,
+ callbacks=[CB()]
+ )
+
+ with pytest.raises(SystemExit):
+ trainer.fit(model)