
PoC: Accelerator refactor [wip] [skip ci] #5616

Closed
wants to merge 168 commits into from
Changes from all commits
Commits
168 commits
fddeee3
move to old package
justusschock Nov 9, 2020
f9c1e8d
add initial draft of new accelerators
justusschock Nov 9, 2020
28ae403
add initial data parallel draft
justusschock Nov 9, 2020
fe7573f
add initial precision draft
justusschock Nov 9, 2020
9fd48a1
scheduler helper functions
justusschock Nov 9, 2020
b961aaf
define base plugin api
justusschock Nov 11, 2020
532ad5d
base plugin integration
justusschock Nov 11, 2020
f52ad64
continue ddp plugin
justusschock Nov 11, 2020
bcfb4e7
minor changes precision plugin
justusschock Nov 11, 2020
bf8a87a
start ddp plugin
justusschock Nov 11, 2020
8482c0b
initial version ddp spawn
justusschock Nov 12, 2020
12d2c59
remove deprecated implementation
justusschock Nov 12, 2020
8d83db8
add comment on whats missing
justusschock Nov 12, 2020
22e1e31
latest state
justusschock Nov 20, 2020
eac87c3
update accelerator for model to live in traintype plugin
justusschock Nov 30, 2020
d111471
add general plugin interface
justusschock Nov 30, 2020
3d6c4b8
add model properties
justusschock Nov 30, 2020
51740e9
Trainer integration part 1 for CPU accelerator
awaelchli Dec 4, 2020
9e48568
test single gpu trainer integration
awaelchli Dec 6, 2020
5da773a
make device changes a bit less hardcoded
justusschock Dec 7, 2020
42e53be
properly resolve attributes
justusschock Dec 7, 2020
4c8d24f
add properties for accelerator forwarding
justusschock Dec 7, 2020
6faebfa
correct optimizer_step calls
justusschock Dec 7, 2020
29568e1
call train or test
awaelchli Dec 7, 2020
33561d7
make calls to trainstep (and fix bugs)
justusschock Dec 7, 2020
ef94755
remove gradient_clip_val from accelerator
awaelchli Dec 7, 2020
c5e9892
add back the step end methods
awaelchli Dec 7, 2020
c02baad
add precision todo comment
awaelchli Dec 7, 2020
ce4eafa
ddp
awaelchli Dec 8, 2020
e6ba009
clean up
awaelchli Dec 8, 2020
fa4d844
connect
awaelchli Dec 8, 2020
8be82a4
clean up
awaelchli Dec 8, 2020
08ce7d3
post
awaelchli Dec 8, 2020
ffbcd4f
disable progress bar on rank > 0
awaelchli Dec 9, 2020
4be76bf
precision test
justusschock Dec 10, 2020
098f665
fix native amp
justusschock Dec 10, 2020
ea85633
a
awaelchli Dec 12, 2020
846dc92
ddp spawn
awaelchli Dec 12, 2020
0d0c3d7
spawn
awaelchli Dec 12, 2020
3fb8b4d
finish ddp plugin integration
awaelchli Dec 13, 2020
0f5298e
remove logger from plugins
awaelchli Dec 13, 2020
434e30e
setup
awaelchli Dec 13, 2020
3fb31c8
remove logger arg
awaelchli Dec 13, 2020
e7a7a87
module
awaelchli Dec 13, 2020
1e8aa44
clean up
awaelchli Dec 13, 2020
628fdc3
ddp_cpu integration
awaelchli Dec 14, 2020
9f369cc
cuda context manager for emptying cache
awaelchli Dec 14, 2020
a8e8306
args
awaelchli Dec 14, 2020
71cbd33
move "log_gpu_memory" to logger connector
awaelchli Dec 14, 2020
1a9ad4f
fix imports
justusschock Dec 14, 2020
7b874cc
typo
justusschock Dec 14, 2020
bc2460a
remove todo
justusschock Dec 14, 2020
506c446
add rpc_enabled flag
justusschock Dec 14, 2020
19d19d5
remove unused self arg
justusschock Dec 14, 2020
dd4d148
comment out unnecessary amp part
justusschock Dec 14, 2020
f2fffc6
fix model connector
justusschock Dec 14, 2020
c6b3aeb
fix import
justusschock Dec 14, 2020
55fc952
copy properties only once
justusschock Dec 14, 2020
177a634
add cluster env
awaelchli Dec 22, 2020
7290e99
move slurm configuration
awaelchli Dec 22, 2020
1b9c095
resolve importerrors
awaelchli Dec 22, 2020
e50aea9
handle distributed_sampler_kwargs
awaelchli Dec 22, 2020
2e8f944
move emptying cache to accelerator
awaelchli Dec 22, 2020
bcc7a72
fix a few tests
awaelchli Dec 22, 2020
259c7f7
restoring the result from subprocess
awaelchli Dec 22, 2020
dfab52a
fix queue.get() order for results
awaelchli Dec 22, 2020
6742488
add missing "block_backward_sync" context manager
awaelchli Dec 22, 2020
8c89932
add missing "block_backward_sync" context manager
awaelchli Dec 22, 2020
0186a0f
fix sync_batchnorm
awaelchli Dec 22, 2020
b2ac1f4
fix supported gpu-ids for tuple
awaelchli Dec 22, 2020
07a41ce
fix clip gradients and inf recursion
awaelchli Dec 22, 2020
63b7eaf
accelerator selection: added cluster_environment plugin
awaelchli Dec 23, 2020
f8344c5
fix torchelastic test
awaelchli Dec 23, 2020
34e3c15
fix reduce early stopping decision for DDP
awaelchli Dec 24, 2020
27a4cff
fix tests: callbacks, conversion to lightning optimizer
awaelchli Dec 24, 2020
df5ac30
fix lightning optimizer does not pickle
awaelchli Dec 24, 2020
dcf917a
fix setting benchmark and deterministic option
awaelchli Dec 24, 2020
272f088
fix slurm amp test
awaelchli Dec 24, 2020
4529476
fix prepare_data test and determine node_rank
awaelchli Dec 27, 2020
5319b0f
fix retrieving last path when testing
awaelchli Dec 27, 2020
3b54cfb
remove obsolete plugin argument
awaelchli Dec 27, 2020
6540b87
fix test: test_trainer_config
awaelchli Dec 27, 2020
6b450e1
fix torchscript tests
awaelchli Dec 27, 2020
4ef539f
fix trainer.model access
awaelchli Dec 27, 2020
1001ccf
move properties
awaelchli Dec 27, 2020
38a1d0f
fix test_transfer_batch_hook
awaelchli Dec 27, 2020
46cf7ef
fix auto_select_gpus
awaelchli Dec 27, 2020
258f50e
fix omegaconf test
awaelchli Dec 27, 2020
a5d69b9
fix test that needs to simulate slurm ddp
awaelchli Dec 27, 2020
88a7ed5
add horovod plugin
awaelchli Dec 29, 2020
40daa41
fix test with named arguments
awaelchli Dec 29, 2020
96fc074
clean up whitespace
awaelchli Dec 29, 2020
210831a
fix datamodules test
awaelchli Dec 29, 2020
98b6dd4
remove old accelerators
justusschock Jan 6, 2021
dfcbba6
fix naming
justusschock Jan 6, 2021
348a1b0
move old plugins
justusschock Jan 6, 2021
14f2f6e
move to plugins
justusschock Jan 6, 2021
2f779c6
create precision subpackage
justusschock Jan 6, 2021
58536f6
create training_type subpackage
justusschock Jan 6, 2021
ee53c90
fix all new import errors
awaelchli Jan 7, 2021
894e604
fix wrong arguments order passed to test
awaelchli Jan 7, 2021
2bdc836
fix LR finder
awaelchli Jan 10, 2021
48b9882
Added sharded training type and amp plugin
Jan 11, 2021
38452b6
Move clip grad to precision plugin
Jan 11, 2021
173b22c
Added sharded spawn, select accelerators based on distributed_backend…
Jan 12, 2021
79803f6
Fix import issue, attempting to fix tests
Jan 12, 2021
a7c0d8f
Fix initial test
Jan 12, 2021
02df0ad
Reflect hook logic from master, should wrap model after move to device
Jan 14, 2021
d0ebcba
Optional state consolidation, since master has optimizers not wrapped
justusschock Jan 22, 2021
319c3e8
change attribute for instance test
justusschock Jan 22, 2021
a34cd15
reset optimizers
justusschock Jan 22, 2021
c95b06a
legacy
Borda Jan 22, 2021
9ff0c64
imports in accel
Borda Jan 22, 2021
67d4e47
legacy2
Borda Jan 22, 2021
577b00d
trainer imports
Borda Jan 22, 2021
aa4858b
fix import errors after rebase
awaelchli Jan 25, 2021
f81a44f
move hook to new setup location
awaelchli Jan 25, 2021
a285665
provide unwrapping logic
awaelchli Jan 25, 2021
bf78d70
fix trainer callback system
awaelchli Jan 25, 2021
34947cf
added ddp2 implementation
awaelchli Jan 25, 2021
49bec53
fix imports .legacy
Borda Jan 25, 2021
ba1c986
move plugins
Borda Jan 25, 2021
45dfbb7
restore legacy
Borda Jan 25, 2021
9b7326a
drop test.py from root
Borda Jan 25, 2021
96bc05d
add tpu accelerator and plugins
justusschock Jan 26, 2021
c5994e5
Merge branch 'release/1.2-dev' into accelerator-refactor-sharted-4
awaelchli Jan 30, 2021
9e46624
fixes
awaelchli Jan 30, 2021
22d2ae8
Merge branch 'release/1.2-dev' into accelerator-refactor-sharted-4
awaelchli Jan 30, 2021
901d392
Merge branch 'release/1.2-dev' into accelerator-refactor-sharted-4
awaelchli Jan 31, 2021
e174b8d
fix lightning optimizer merge
awaelchli Jan 31, 2021
98660de
reset bugreportmodel
awaelchli Jan 31, 2021
4d95b6c
unwrapping
awaelchli Jan 31, 2021
b69d013
step routing forward
awaelchli Jan 31, 2021
cb6676d
model access
awaelchli Jan 31, 2021
a33d27f
unwrap
awaelchli Jan 31, 2021
f7486e2
opt
awaelchli Jan 31, 2021
117f16d
Merge branch 'release/1.2-dev' into accelerator-refactor-sharted-4
awaelchli Jan 31, 2021
3792b72
integrate distrib_type
awaelchli Jan 31, 2021
ef85b81
sync changes
awaelchli Jan 31, 2021
9d9a940
sync
awaelchli Feb 1, 2021
f017a39
Merge branch 'release/1.2-dev' into accelerator-refactor-sharted-4
awaelchli Feb 1, 2021
a190a56
fixes
awaelchli Feb 1, 2021
73bb607
add forgotten generators
awaelchli Feb 1, 2021
c8c74f3
Merge branch 'release/1.2-dev' into accelerator-refactor-sharted-4
awaelchli Feb 1, 2021
ae71997
add missing logic
awaelchli Feb 1, 2021
d89847b
Merge branch 'release/1.2-dev' into accelerator-refactor-sharted-4
awaelchli Feb 1, 2021
0e686c3
update
awaelchli Feb 1, 2021
d6a43ea
import
awaelchli Feb 1, 2021
ceb8f75
missed imports
awaelchli Feb 1, 2021
fbb7c20
import fixes
awaelchli Feb 1, 2021
b610999
isort
awaelchli Feb 1, 2021
9b79924
mv f
awaelchli Feb 1, 2021
9afe54d
changelog
awaelchli Feb 1, 2021
3b63e82
Merge branch 'release/1.2-dev' into ref/update-plugins
awaelchli Feb 1, 2021
ca8cb68
format
awaelchli Feb 1, 2021
0633745
move helper to parallel plugin
awaelchli Feb 1, 2021
a622e0b
d
awaelchli Feb 1, 2021
18c682f
Merge branch 'ref/update-plugins' into accelerator-refactor-sharted-4
awaelchli Feb 1, 2021
f275803
add world size
awaelchli Feb 1, 2021
4ae008b
clean up
awaelchli Feb 1, 2021
3b3918b
Merge branch 'release/1.2-dev' into accelerator-refactor-sharted-4
awaelchli Feb 1, 2021
d4c6308
duplicate
awaelchli Feb 1, 2021
7eef4a0
Merge branch 'release/1.2-dev' into accelerator-refactor-sharted-4
awaelchli Feb 2, 2021
9949164
activate ddp_sharded and tpu
awaelchli Feb 2, 2021
6d47357
set nvidia flags
awaelchli Feb 2, 2021
a6864ec
remove unused colab var
awaelchli Feb 2, 2021
b4b9724
use_tpu <-> on_tpu attrs
awaelchli Feb 2, 2021
81001e3
make some ddp_cpu and clusterplugin tests pass
awaelchli Feb 2, 2021
45 changes: 12 additions & 33 deletions benchmarks/test_sharded_parity.py
@@ -15,7 +15,7 @@
import os
import platform
import time
from typing import Type, Union
from typing import Type

import pytest
import torch
@@ -32,10 +32,8 @@
@pytest.mark.skipif(platform.system() == "Windows", reason="Distributed training is not supported on Windows")
@pytest.mark.skipif(not _FAIRSCALE_AVAILABLE, reason="Fairscale is not available")
def test_ddp_sharded_plugin_correctness_one_gpu():
plugin_parity_test(
sharded_parity_test(
gpus=1,
accelerator='ddp_spawn',
plugin=DDPShardedPlugin(),
model_cls=SeedTrainLoaderModel,
)

@@ -45,11 +43,9 @@ def test_ddp_sharded_plugin_correctness_one_gpu():
@pytest.mark.skipif(platform.system() == "Windows", reason="Distributed training is not supported on Windows")
@pytest.mark.skipif(not _FAIRSCALE_AVAILABLE, reason="Fairscale is not available")
def test_ddp_sharded_plugin_correctness_amp_one_gpu():
plugin_parity_test(
sharded_parity_test(
gpus=1,
precision=16,
accelerator='ddp_spawn',
plugin=DDPShardedPlugin(),
model_cls=SeedTrainLoaderModel,
)

@@ -59,10 +55,8 @@ def test_ddp_sharded_plugin_correctness_amp_one_gpu():
@pytest.mark.skipif(platform.system() == "Windows", reason="Distributed training is not supported on Windows")
@pytest.mark.skipif(not _FAIRSCALE_AVAILABLE, reason="Fairscale is not available")
def test_ddp_sharded_plugin_correctness_multi_gpu():
plugin_parity_test(
sharded_parity_test(
gpus=2,
accelerator='ddp_spawn',
plugin=DDPShardedPlugin(),
model_cls=SeedTrainLoaderModel,
max_percent_speed_diff=0.25, # todo: Increase speed diff since only 2 GPUs sharding 2 optimizers
)
@@ -73,11 +67,9 @@ def test_ddp_sharded_plugin_correctness_multi_gpu():
@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine")
@pytest.mark.skipif(not _FAIRSCALE_AVAILABLE, reason="Fairscale is not available")
def test_ddp_sharded_plugin_correctness_amp_multi_gpu():
plugin_parity_test(
sharded_parity_test(
gpus=2,
precision=16,
accelerator='ddp_spawn',
plugin=DDPShardedPlugin(),
model_cls=SeedTrainLoaderModel,
max_percent_speed_diff=0.25, # todo: Increase speed diff since only 2 GPUs sharding 2 optimizers
)
@@ -88,11 +80,9 @@ def test_ddp_sharded_plugin_correctness_amp_multi_gpu():
@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine")
@pytest.mark.skipif(not _FAIRSCALE_AVAILABLE, reason="Fairscale is not available")
def test_ddp_string_sharded_plugin_correctness_amp_multi_gpu():
plugin_parity_test(
sharded_parity_test(
gpus=2,
precision=16,
accelerator='ddp_spawn',
plugin='ddp_sharded',
model_cls=SeedTrainLoaderModel,
max_percent_speed_diff=0.25, # todo: Increase speed diff since only 2 GPUs sharding 2 optimizers
)
@@ -105,11 +95,9 @@ def test_ddp_string_sharded_plugin_correctness_amp_multi_gpu():
)
@DDPLauncher.run("--accelerator ddp --gpus 2 --precision 32")
def test_ddp_sharded_plugin_correctness_multi_gpu_ddp(tmpdir, args=None):
plugin_parity_test(
sharded_parity_test(
gpus=args.gpus,
precision=args.precision,
accelerator=args.accelerator,
plugin=DDPShardedPlugin(),
model_cls=SeedTrainLoaderModel,
)

@@ -121,11 +109,9 @@ def test_ddp_sharded_plugin_correctness_multi_gpu_ddp(tmpdir, args=None):
)
@DDPLauncher.run("--accelerator ddp --gpus 2 --precision 16")
def test_ddp_sharded_plugin_correctness_amp_multi_gpu_ddp(tmpdir, args=None):
plugin_parity_test(
sharded_parity_test(
gpus=args.gpus,
precision=args.precision,
accelerator=args.accelerator,
plugin=DDPShardedPlugin(),
model_cls=SeedTrainLoaderModel,
)

@@ -138,10 +124,8 @@ def test_ddp_sharded_plugin_correctness_multi_gpu_multi_optim():
"""
Ensures same results using multiple optimizers across multiple GPUs
"""
plugin_parity_test(
plugin=DDPShardedPlugin(),
sharded_parity_test(
gpus=2,
accelerator='ddp_spawn',
model_cls=SeedTrainLoaderMultipleOptimizersModel,
max_percent_speed_diff=0.25, # todo: Increase speed diff since only 2 GPUs sharding 2 optimizers
)
@@ -155,10 +139,8 @@ def test_ddp_sharded_plugin_correctness_multi_gpu_multi_optim_manual(tmpdir):
"""
Ensures using multiple optimizers across multiple GPUs with manual optimization
"""
plugin_parity_test(
plugin=DDPShardedPlugin(),
sharded_parity_test(
gpus=2,
accelerator='ddp_spawn',
model_cls=SeedTrainLoaderManualModel,
max_percent_speed_diff=0.25, # todo: Increase speed diff since only 2 GPUs sharding 2 optimizers
)
@@ -273,9 +255,7 @@ def plugin_parity_test(

Args:
model_cls: Model class to use for test.
plugin: Plugin to parity test.
seed: Seed for generators. Note that this does not handle the seed for data-loading on multi-process.
accelerator: Accelerator type for test.
gpus: Number of GPUS to enable.
precision: Whether to use AMP or normal FP32 training.
max_percent_speed_diff: The maximum speed difference compared to normal DDP training.
@@ -293,7 +273,7 @@
max_epochs=1,
gpus=gpus,
precision=precision,
accelerator=accelerator,
accelerator='ddp_spawn',
)

max_memory_ddp, ddp_time = record_ddp_fit_model_stats(trainer=trainer, model=ddp_model, use_cuda=use_cuda)
@@ -307,8 +287,7 @@
max_epochs=1,
gpus=gpus,
precision=precision,
accelerator=accelerator,
plugins=[plugin],
accelerator='ddp_sharded_spawn',
)

max_memory_custom, custom_model_time = record_ddp_fit_model_stats(
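For orientation, here is a minimal sketch of how the slimmed-down parity helper reads after this diff: the `plugin` and `accelerator` arguments are gone and the two runs are selected purely by accelerator string (`'ddp_spawn'` for the baseline, `'ddp_sharded_spawn'` for the sharded run). The `sharded_parity_test` name follows the renamed call sites above; the body below is illustrative only and replaces the real helper's memory bookkeeping (`record_ddp_fit_model_stats`) with a bare wall-clock and parameter check.

```python
import time

import torch
from pytorch_lightning import Trainer, seed_everything


def sharded_parity_test(model_cls, gpus=1, precision=32, max_percent_speed_diff=0.1, seed=42):
    """Train the same model with plain DDP spawn and with sharded DDP spawn, then compare."""
    # baseline: regular ddp_spawn
    seed_everything(seed)
    ddp_model = model_cls()
    trainer = Trainer(fast_dev_run=True, max_epochs=1, gpus=gpus, precision=precision,
                      accelerator='ddp_spawn')
    start = time.perf_counter()
    trainer.fit(ddp_model)
    ddp_time = time.perf_counter() - start

    # candidate: sharded variant, selected purely by the accelerator string
    seed_everything(seed)
    sharded_model = model_cls()
    trainer = Trainer(fast_dev_run=True, max_epochs=1, gpus=gpus, precision=precision,
                      accelerator='ddp_sharded_spawn')
    start = time.perf_counter()
    trainer.fit(sharded_model)
    sharded_time = time.perf_counter() - start

    # parity: same parameters and a bounded slowdown
    for p_ddp, p_sharded in zip(ddp_model.parameters(), sharded_model.parameters()):
        assert torch.allclose(p_ddp.cpu(), p_sharded.cpu(), atol=1e-3)
    assert (sharded_time - ddp_time) / ddp_time < max_percent_speed_diff
```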
29 changes: 4 additions & 25 deletions pytorch_lightning/accelerators/__init__.py
@@ -1,25 +1,4 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from pytorch_lightning.accelerators.legacy.accelerator import Accelerator # noqa: F401
from pytorch_lightning.accelerators.legacy.cpu_accelerator import CPUAccelerator # noqa: F401
from pytorch_lightning.accelerators.legacy.ddp2_accelerator import DDP2Accelerator # noqa: F401
from pytorch_lightning.accelerators.legacy.ddp_accelerator import DDPAccelerator # noqa: F401
from pytorch_lightning.accelerators.legacy.ddp_cpu_hpc_accelerator import DDPCPUHPCAccelerator # noqa: F401
from pytorch_lightning.accelerators.legacy.ddp_cpu_spawn_accelerator import DDPCPUSpawnAccelerator # noqa: F401
from pytorch_lightning.accelerators.legacy.ddp_hpc_accelerator import DDPHPCAccelerator # noqa: F401
from pytorch_lightning.accelerators.legacy.ddp_spawn_accelerator import DDPSpawnAccelerator # noqa: F401
from pytorch_lightning.accelerators.legacy.dp_accelerator import DataParallelAccelerator # noqa: F401
from pytorch_lightning.accelerators.legacy.gpu_accelerator import GPUAccelerator # noqa: F401
from pytorch_lightning.accelerators.legacy.horovod_accelerator import HorovodAccelerator # noqa: F401
from pytorch_lightning.accelerators.legacy.tpu_accelerator import TPUAccelerator # noqa: F401
from pytorch_lightning.accelerators.accelerator import Accelerator
from pytorch_lightning.accelerators.cpu import CPUAccelerator
from pytorch_lightning.accelerators.gpu import GPUAccelerator
from pytorch_lightning.accelerators.tpu import TPUAccelerator
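With the legacy re-exports dropped, downstream code imports the new accelerator classes directly from the subpackage. A hedged sketch follows; the `VerboseCPUAccelerator` subclass and its body are illustrative, not part of the PR.

```python
# new flat import surface (previously re-exported from accelerators.legacy.*)
from pytorch_lightning.accelerators import Accelerator, CPUAccelerator, GPUAccelerator, TPUAccelerator


class VerboseCPUAccelerator(CPUAccelerator):
    """Hypothetical subclass: the refactored accelerator is a thin device wrapper, so
    customizing it mostly means hooking setup()/teardown(), while training-type and
    precision plugins carry the distributed and AMP logic."""

    def setup(self, trainer, model):
        print(f"Setting up {type(model).__name__} on CPU")
        return super().setup(trainer, model)
```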
59 changes: 26 additions & 33 deletions pytorch_lightning/accelerators/accelerator_connector.py
@@ -34,7 +34,7 @@
SingleDevicePlugin,
SingleTPUPlugin,
TPUHalfPrecisionPlugin,
TPUSpawnPlugin,
TPUSpawnPlugin, DDPShardedPlugin, DDPSpawnShardedPlugin,
)
from pytorch_lightning.plugins.environments import SLURMEnvironment, TorchElasticEnvironment
from pytorch_lightning.tuner.auto_gpu_select import pick_multiple_gpus
@@ -116,16 +116,12 @@ def __init__(
# override dist backend when using tpus
if self.on_tpu:
self.distributed_backend = "tpu"
self.use_tpu = True

# init flags for SLURM+DDP to work
self.world_size = 1
self.interactive_ddp_procs = []
self.global_rank = 0

# NVIDIA setup
# self.set_nvidia_flags(self.trainer.is_slurm_managing_tasks, self.trainer.data_parallel_device_ids)

# benchmarking
# TODO: should this be moved to GPU accelerator?
torch.backends.cudnn.benchmark = self.benchmark
@@ -138,9 +134,6 @@ def __init__(
# https://github.com/PyTorchLightning/pytorch-lightning/pull/1572/files#r420279383
os.environ["HOROVOD_FUSION_THRESHOLD"] = str(0)

# TODO: move this to TPU accelerator/plugin
self.on_colab_kaggle = os.getenv("COLAB_GPU") or os.getenv("KAGGLE_URL_BASE")

self.replace_sampler_ddp = replace_sampler_ddp

@property
@@ -256,23 +249,21 @@ def select_training_type_plugin(self):
use_ddp_cpu_spawn = self.use_ddp and self.on_cpu
use_ddp_cpu_torch_elastic = use_ddp_cpu_spawn and self.is_using_torchelastic
use_ddp_cpu_slurm = use_ddp_cpu_spawn and self.is_slurm_managing_tasks
# use_ddp_sharded = self.distributed_backend == "ddp_sharded"
# use_ddp_sharded_spawn = self.distributed_backend == "ddp_sharded_spawn"
use_ddp_sharded = self.distributed_backend == "ddp_sharded"
use_ddp_sharded_spawn = self.distributed_backend == "ddp_sharded_spawn"

if self.on_tpu:
ddp_plugin_cls = TPUSpawnPlugin

# ddp script mode uses the same flags as TE
# TODO: decouple from TE
# ddp script mode uses the same flags as TE
if os.environ.get("PL_IN_DDP_SUBPROCESS", False):
use_torchelastic_ddp = False

# fixme
# if use_ddp_sharded:
# ddp_plugin_cls = DDPShardedPlugin
# elif use_ddp_sharded_spawn:
# ddp_plugin_cls = DDPSpawnShardedPlugin
if use_ddp_cpu_slurm or use_slurm_ddp or use_ddp_cpu_torch_elastic or use_torchelastic_ddp:
if self.on_tpu:
ddp_plugin_cls = TPUSpawnPlugin
elif use_ddp_sharded:
ddp_plugin_cls = DDPShardedPlugin
elif use_ddp_sharded_spawn:
ddp_plugin_cls = DDPSpawnShardedPlugin
elif use_ddp_cpu_slurm or use_slurm_ddp or use_ddp_cpu_torch_elastic or use_torchelastic_ddp:
ddp_plugin_cls = DDPPlugin
elif use_ddp_spawn or use_ddp_cpu_spawn:
ddp_plugin_cls = DDPSpawnPlugin
@@ -328,6 +319,8 @@ def select_cluster_environment(self):
return env

def set_distributed_mode(self):
if isinstance(self.distributed_backend, Accelerator):
return

if self.distributed_backend is None:
if self.has_horovodrun():
@@ -355,27 +348,27 @@ def set_distributed_mode(self):
# special case with TPUs
elif self.distributed_backend == 'tpu':
self._device_type = DeviceType.TPU
# set all other requested distrib. types adn if it was not set in the
# set all other requested distrib. types and if it was not set in the
elif self.distributed_backend and self._distrib_type is None:
self._distrib_type = DistributedType(self.distributed_backend)

# unless you request explicitly for CPU and some GPU are available use them
_on_cpu = self.distributed_backend and 'cpu' in self.distributed_backend
if (self.num_gpus > 0 and not _on_cpu):
if self.num_gpus > 0 and not _on_cpu:
self._device_type = DeviceType.GPU

_distrib_types = (DistributedType.DP, DistributedType.DDP, DistributedType.DDP_SPAWN, DistributedType.DDP2)
# _distrib_types = (DistributedType.DP, DistributedType.DDP, DistributedType.DDP_SPAWN, DistributedType.DDP2)
# DP and DDP2 cannot run without GPU
if (self.num_gpus == 0 and self._distrib_type in _distrib_types):
rank_zero_warn(
'You requested distributed training on GPUs, but none is available, so we set backend to `ddp_cpu`.'
)
# todo: in some cases it yield in comarison None and int
if ((self.num_nodes and self.num_nodes > 1) or (self.num_processes and self.num_processes > 1)):
self._distrib_type = DistributedType.DDP
else:
rank_zero_warn('You are running on single node with no parallelization, so distributed has no effect.')
self._distrib_type = None
# if (self.num_gpus == 0 and self._distrib_type in _distrib_types):
# rank_zero_warn(
# 'You requested distributed training on GPUs, but none is available, so we set backend to `ddp_cpu`.'
# )
# # todo: in some cases it yield in comarison None and int
# if ((self.num_nodes and self.num_nodes > 1) or (self.num_processes and self.num_processes > 1)):
# self._distrib_type = DistributedType.DDP
# else:
# rank_zero_warn('You are running on single node with no parallelization, so distributed has no effect.')
# self._distrib_type = None

# for DDP overwrite nb processes by requested GPUs
if (
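To make the new resolution order in `select_training_type_plugin` easier to scan, here is a standalone restatement of the branch introduced above. The flag names follow the hunk; the string return values stand in for the actual plugin classes, and the final fallback is an assumption since the closing else-branch sits outside the shown hunk.

```python
def pick_ddp_plugin_cls(on_tpu, use_ddp_sharded, use_ddp_sharded_spawn,
                        use_cluster_ddp, use_spawn):
    """Return the name of the training-type plugin the connector would pick."""
    if on_tpu:
        return "TPUSpawnPlugin"
    if use_ddp_sharded:
        return "DDPShardedPlugin"
    if use_ddp_sharded_spawn:
        return "DDPSpawnShardedPlugin"
    if use_cluster_ddp:          # SLURM / torchelastic launched processes
        return "DDPPlugin"
    if use_spawn:                # processes spawned by the Trainer itself
        return "DDPSpawnPlugin"
    return "DDPPlugin"           # fallback (assumption; not shown in the hunk)


# 'ddp_sharded' resolves to DDPShardedPlugin, 'ddp_sharded_spawn' to DDPSpawnShardedPlugin:
assert pick_ddp_plugin_cls(False, True, False, False, False) == "DDPShardedPlugin"
assert pick_ddp_plugin_cls(False, False, True, False, False) == "DDPSpawnShardedPlugin"
```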
15 changes: 14 additions & 1 deletion pytorch_lightning/accelerators/gpu.py
@@ -1,17 +1,22 @@
import logging
import os

import torch

from pytorch_lightning.accelerators.accelerator import Accelerator
from pytorch_lightning.utilities.exceptions import MisconfigurationException

log = logging.getLogger(__name__)


class GPUAccelerator(Accelerator):

def setup(self, trainer, model):
if "cuda" not in str(self.root_device):
raise MisconfigurationException(f"Device should be GPU, got {self.root_device} instead")
self.set_nvidia_flags()
torch.cuda.set_device(self.root_device)
model.to(self.root_device)

return super().setup(trainer, model)

def on_train_start(self):
@@ -25,3 +30,11 @@ def on_train_end(self):
# clean up memory
with torch.cuda.device(self.root_device):
torch.cuda.empty_cache()

@staticmethod
def set_nvidia_flags():
# set the correct cuda visible devices (using pci order)
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
all_gpu_ids = ",".join([str(x) for x in range(torch.cuda.device_count())])
devices = os.getenv("CUDA_VISIBLE_DEVICES", all_gpu_ids)
log.info(f"LOCAL_RANK: {os.getenv('LOCAL_RANK', 0)} - CUDA_VISIBLE_DEVICES: [{devices}]")
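The same setup flow, restated as a small standalone sketch; `setup_on_gpu` is an illustrative helper, not a method added by the PR.

```python
import logging
import os

import torch

log = logging.getLogger(__name__)


def set_nvidia_flags() -> None:
    # order devices by PCI bus id so indices are stable across tools
    os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
    all_gpu_ids = ",".join(str(i) for i in range(torch.cuda.device_count()))
    devices = os.getenv("CUDA_VISIBLE_DEVICES", all_gpu_ids)
    log.info(f"LOCAL_RANK: {os.getenv('LOCAL_RANK', 0)} - CUDA_VISIBLE_DEVICES: [{devices}]")


def setup_on_gpu(model: torch.nn.Module, root_device: torch.device) -> torch.nn.Module:
    # mirrors GPUAccelerator.setup: validate the device, set flags,
    # pin the process to its GPU, then move the model there
    if "cuda" not in str(root_device):
        raise ValueError(f"Device should be GPU, got {root_device} instead")
    set_nvidia_flags()
    torch.cuda.set_device(root_device)
    return model.to(root_device)
```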
3 changes: 2 additions & 1 deletion pytorch_lightning/callbacks/early_stopping.py
@@ -188,6 +188,7 @@ def _run_early_stopping_check(self, trainer, pl_module):
return # short circuit if metric not present

current = logs.get(self.monitor)
should_stop = False

# when in dev debugging
trainer.dev_debugger.track_early_stopping_history(self, current)
@@ -204,5 +205,5 @@ trainer.should_stop = True
trainer.should_stop = True

# stop every ddp process if any world process decides to stop
should_stop = trainer.accelerator_backend.early_stopping_should_stop(pl_module)
should_stop = trainer.training_type_plugin.reduce_early_stopping_decision(should_stop)
trainer.should_stop = should_stop
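The call above delegates the cross-rank agreement to the training-type plugin. A hedged sketch of how a DDP-style `reduce_early_stopping_decision` could be implemented (not necessarily the PR's exact code): any rank that wants to stop makes every rank stop.

```python
import torch
import torch.distributed as dist


def reduce_early_stopping_decision(should_stop: bool, device: torch.device) -> bool:
    if not (dist.is_available() and dist.is_initialized()):
        return should_stop  # single-process fallback
    flag = torch.tensor(int(should_stop), device=device)
    dist.all_reduce(flag, op=dist.ReduceOp.SUM)  # sum the per-rank votes
    return bool(flag.item() > 0)  # stop everywhere if any rank voted to stop
```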
2 changes: 1 addition & 1 deletion pytorch_lightning/callbacks/model_checkpoint.py
@@ -458,7 +458,7 @@ def __resolve_ckpt_dir(self, trainer, pl_module):
else f"version_{trainer.logger.version}"
)

version, name = trainer.accelerator_backend.broadcast((version, trainer.logger.name))
version, name = trainer.training_type_plugin.broadcast((version, trainer.logger.name))

ckpt_path = os.path.join(
save_dir, str(name), version, "checkpoints"
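The broadcast keeps every rank writing checkpoints to the same resolved directory. A minimal sketch of broadcasting a Python object from rank 0, assuming `torch.distributed.broadcast_object_list` is available; the actual plugin may use its own pickling path.

```python
import torch.distributed as dist


def broadcast_from_rank_zero(obj, src: int = 0):
    if not (dist.is_available() and dist.is_initialized()):
        return obj  # single-process fallback
    buffer = [obj if dist.get_rank() == src else None]
    dist.broadcast_object_list(buffer, src=src)  # pickles on src, unpickles on the others
    return buffer[0]


# usage mirroring the diff:
# version, name = broadcast_from_rank_zero((version, logger_name))
```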
3 changes: 2 additions & 1 deletion pytorch_lightning/core/lightning.py
@@ -276,6 +276,7 @@ def log(
f"Logged key: {name} should not contain information about dataloader_idx.")

accelerator = self.trainer.accelerator_backend
training_type_plugin = self.trainer.training_type_plugin

self._results.log(
name,
@@ -291,7 +292,7 @@
sync_dist,
sync_dist_op,
sync_dist_group,
accelerator.sync_tensor,
training_type_plugin.reduce,
self._current_dataloader_idx,
self.device,
)
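`self.log(..., sync_dist=True)` now receives the training-type plugin's `reduce` as its sync function instead of `accelerator.sync_tensor`. A hedged sketch of the kind of callable that slot expects, an all-reduce that averages across processes (reduce-op and process-group handling omitted):

```python
import torch
import torch.distributed as dist


def reduce_mean(tensor: torch.Tensor) -> torch.Tensor:
    """Average a logged tensor across all processes (no-op when not distributed)."""
    if not (dist.is_available() and dist.is_initialized()):
        return tensor
    tensor = tensor.clone()                        # don't mutate the caller's tensor
    dist.all_reduce(tensor, op=dist.ReduceOp.SUM)
    return tensor / dist.get_world_size()
```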
5 changes: 3 additions & 2 deletions pytorch_lightning/core/optimizer.py
@@ -137,8 +137,9 @@ def __optimizer_step(self, *args, closure: Optional[Callable] = None, profiler_n
with trainer.profiler.profile(profiler_name):
xm.optimizer_step(optimizer, optimizer_args={'closure': closure, **kwargs})

elif trainer.amp_backend is not None:
trainer.precision_connector.backend.optimizer_step(trainer, optimizer, closure)
# elif trainer.amp_backend is not None:
# # TODO: Adapt for new optimizer structure
# trainer.precision_connector.backend.optimizer_step(trainer, optimizer, closure)

else:
with trainer.profiler.profile(profiler_name):
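The commented-out branch and its TODO point at routing the AMP step through the new plugin structure instead of `trainer.precision_connector.backend`. A speculative sketch of that shape follows; the class and method names below are illustrative, not the PR's API.

```python
class PrecisionPlugin:
    """Illustrative stand-in: wraps the optimizer step (AMP scaling/unscaling would go here)."""

    def pre_optimizer_step(self, optimizer, closure):
        closure()  # run forward/backward; a native-AMP plugin would unscale gradients afterwards

    def post_optimizer_step(self, optimizer):
        pass


class Accelerator:
    """Illustrative stand-in: owns optimizer_step and brackets it with precision hooks."""

    def __init__(self, precision_plugin: PrecisionPlugin):
        self.precision_plugin = precision_plugin

    def optimizer_step(self, optimizer, closure):
        self.precision_plugin.pre_optimizer_step(optimizer, closure)
        optimizer.step()
        self.precision_plugin.post_optimizer_step(optimizer)
```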