[accelerator][BugFix] Resolve some test for 1 gpu #5863

Merged Feb 8, 2021 (47 commits)

Changes from all commits:
02f4818  update (Feb 7, 2021)
18b4b25  revert init (tchaton, Feb 7, 2021)
6cdf71d  resolve a bug (Feb 7, 2021)
95870b9  Merge branch 'resolve_tests' of https://github.com/PyTorchLightning/p… (Feb 7, 2021)
ffdddb9  update (Feb 8, 2021)
6f9830a  resolve flake8 (tchaton, Feb 8, 2021)
b02b7b0  update (Feb 8, 2021)
67a8cb3  Merge branch 'resolve_tests' of https://github.com/PyTorchLightning/… (Feb 8, 2021)
701539f  update (Feb 8, 2021)
b8a8d81  update (Feb 7, 2021)
eea223f  revert init (tchaton, Feb 7, 2021)
e85e213  resolve a bug (Feb 7, 2021)
337f723  update (Feb 8, 2021)
b41fc9f  resolve flake8 (tchaton, Feb 8, 2021)
951cc4d  update (Feb 8, 2021)
e8cc904  update (Feb 8, 2021)
3e79a6d  update (Feb 8, 2021)
f9666f1  update (Feb 8, 2021)
6ac21c5  Merge branch 'resolve_tests' of https://github.com/PyTorchLightning/p… (Feb 8, 2021)
5890da3  update (Feb 7, 2021)
83ff23f  revert init (tchaton, Feb 7, 2021)
cde3781  resolve a bug (Feb 7, 2021)
0f6eeb4  update (Feb 8, 2021)
47ef8e0  resolve flake8 (tchaton, Feb 8, 2021)
35a6f53  update (Feb 8, 2021)
f7689b4  update (Feb 8, 2021)
e411983  update (Feb 7, 2021)
60082d7  revert init (tchaton, Feb 7, 2021)
8153efd  update (Feb 8, 2021)
f53aa29  resolve flake8 (tchaton, Feb 8, 2021)
4bfc621  update (Feb 8, 2021)
77b5e87  update (Feb 8, 2021)
9f7e41f  update (Feb 8, 2021)
d96b249  Merge branch 'resolve_tests' of https://github.com/PyTorchLightning/p… (Feb 8, 2021)
3b1e784  update (Feb 8, 2021)
f2214ef  update (Feb 8, 2021)
c5029f7  all_gather (justusschock, Feb 8, 2021)
af791a7  update (Feb 8, 2021)
7378e2e  make plugins work, add misconfig for RPC (justusschock, Feb 8, 2021)
b2812c2  Merge branch 'resolve_tests' of github.com:PytorchLightning/pytorch-l… (justusschock, Feb 8, 2021)
28c8005  update (Feb 8, 2021)
1f96f00  Merge branch 'resolve_tests' of https://github.com/PyTorchLightning/p… (Feb 8, 2021)
13972e7  update (Feb 8, 2021)
b77003e  remove breaking test (Feb 8, 2021)
0c7e10d  resolve some tests (Feb 8, 2021)
1c247dc  resolve flake8 (tchaton, Feb 8, 2021)
c3594b0  revert to ddp_spawn (tchaton, Feb 8, 2021)
4 changes: 3 additions & 1 deletion .drone.yml
@@ -47,7 +47,9 @@ steps:
- unzip -o legacy/checkpoints.zip -d legacy/
- ls -l legacy/checkpoints/
# testing...
- python -m coverage run --source pytorch_lightning -m pytest pytorch_lightning tests -v --durations=25 # --flake8
- python -m coverage run --source pytorch_lightning -m pytest pytorch_lightning tests --ignore tests/plugins/test_sharded_plugin.py --ignore tests/trainer/test_dataloaders.py -v --durations=25 # --flake8
# Todo: Find why those tests are failing when run in the main pytest.
- python -m coverage run -a --source pytorch_lightning -m pytest tests/plugins/test_sharded_plugin.py tests/trainer/test_dataloaders.py -v --durations=25 # --flake8
# Running special tests
- sh tests/special_tests.sh
- coverage report
1 change: 1 addition & 0 deletions .gitignore
@@ -151,3 +151,4 @@ wandb

# dataset generated from bolts in examples.
cifar-10-batches-py
*.pt
13 changes: 13 additions & 0 deletions pytorch_lightning/accelerators/accelerator.py
@@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from pytorch_lightning.utilities.distributed import all_gather_ddp_if_available
from typing import Any, Callable, Iterable, Optional, TYPE_CHECKING, Union

import torch
@@ -374,3 +375,15 @@ def on_save(self, checkpoint):

def barrier(self, name: Optional[str] = None) -> None:
self.training_type_plugin.barrier(name=name)

def all_gather(self, tensor: Union[torch.Tensor], group: Optional[Any] = None, sync_grads: bool = False):
"""
Function to gather a tensor from several distributed processes
Args:
tensor: tensor of shape (batch, ...)
group: the process group to gather results from. Defaults to all processes (world)
sync_grads: flag that allows users to synchronize gradients for all_gather op
Return:
A tensor of shape (world_size, batch, ...)
"""
return all_gather_ddp_if_available(tensor, group=group, sync_grads=sync_grads)
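
For orientation, the new hook can be exercised from user code roughly as follows. This is a minimal sketch, not part of the diff; it assumes the trainer exposes the accelerator as trainer.accelerator_backend, as elsewhere in this PR.

# Hypothetical usage sketch (not from this PR): gathering a per-process tensor
# through the new Accelerator.all_gather hook.
import torch
import pytorch_lightning as pl

class GatherExample(pl.LightningModule):
    def validation_step(self, batch, batch_idx):
        value = torch.ones(1, device=self.device) * self.global_rank
        # Delegates to all_gather_ddp_if_available: with world_size N the result
        # has shape (N, 1); without an initialized process group the input tensor
        # is returned unchanged.
        gathered = self.trainer.accelerator_backend.all_gather(value)
        self.log("mean_rank", gathered.float().mean())
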
31 changes: 22 additions & 9 deletions pytorch_lightning/accelerators/accelerator_connector.py
100644 → 100755
@@ -33,7 +33,6 @@
HorovodPlugin,
NativeMixedPrecisionPlugin,
PrecisionPlugin,
RPCPlugin,
ShardedNativeMixedPrecisionPlugin,
SingleDevicePlugin,
SingleTPUPlugin,
@@ -116,11 +115,11 @@ def __init__(
self.parallel_device_ids = device_parser.parse_gpu_ids(self.gpus)
self.root_gpu = device_parser.determine_root_gpu_device(self.parallel_device_ids)

self.handle_given_plugins(plugins)

self.set_distributed_mode()
self.configure_slurm_ddp()

self.handle_given_plugins(plugins)

self.accelerator = self.select_accelerator()

# override dist backend when using tpus
@@ -147,8 +146,10 @@ def __init__(
self.replace_sampler_ddp = replace_sampler_ddp

def handle_given_plugins(self, plugins: Optional[Sequence]):
if plugins is None:
return
plugins = plugins if plugins is not None else []

if isinstance(plugins, str):
plugins = [plugins]

if not isinstance(plugins, Sequence):
plugins = [plugins]
@@ -158,7 +159,10 @@ def handle_given_plugins(self, plugins: Optional[Sequence]):
cluster_environment = None

for plug in plugins:
if isinstance(plug, TrainingTypePlugin):
if isinstance(plug, str):
self.set_distributed_mode(plug)

elif isinstance(plug, TrainingTypePlugin):
if training_type is None:
training_type = plug

@@ -191,6 +195,7 @@ def handle_given_plugins(self, plugins: Optional[Sequence]):
)

self._training_type_plugin = training_type
self._training_type_plugin = self.training_type_plugin
self._precision_plugin = precision
self._cluster_environment = cluster_environment or self.select_cluster_environment()

@@ -206,6 +211,7 @@ def training_type_plugin(self) -> TrainingTypePlugin:
self._training_type_plugin = self.select_training_type_plugin()
else:
self._training_type_plugin = self.resolve_training_type_plugin(self._training_type_plugin)

return self._training_type_plugin

@property
@@ -327,7 +333,7 @@ def select_precision_plugin(self):

def select_training_type_plugin(self):
if self.use_ddp2:
plugin = DDP2Plugin(parallel_devices=self.parallel_devices, cluster_environment=self._cluster_environment)
plugin = DDP2Plugin(parallel_devices=self.parallel_devices, cluster_environment=self.cluster_environment)
elif self.use_ddp:
use_slurm_ddp = self.use_ddp and self.is_slurm_managing_tasks
use_torchelastic_ddp = self.use_ddp and self.is_using_torchelastic
@@ -359,7 +365,7 @@ def select_training_type_plugin(self):
plugin = ddp_plugin_cls(
parallel_devices=self.parallel_devices,
num_nodes=self.num_nodes,
cluster_environment=self.select_cluster_environment(),
cluster_environment=self.cluster_environment,
sync_batchnorm=self.sync_batchnorm,
)
elif self.use_dp:
@@ -425,7 +431,11 @@ def select_cluster_environment(self):
env = TorchElasticEnvironment()
return env

def set_distributed_mode(self):
def set_distributed_mode(self, distributed_backend: Optional[str] = None):

if distributed_backend is not None:
self.distributed_backend = distributed_backend

if isinstance(self.distributed_backend, Accelerator):
return

@@ -484,6 +494,9 @@ def set_distributed_mode(self):
):
self.num_processes = self.num_gpus

if (self._device_type == DeviceType.GPU and self._distrib_type == DistributedType.DDP2):
self.num_processes = self.num_nodes

# Horovod is an extra case...
if self.distributed_backend == "horovod":
self._set_horovod_backend()
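
A rough illustration of the behaviour change in handle_given_plugins: a plugin given as a string is now routed through set_distributed_mode instead of being matched against the plugin classes. The Trainer arguments below are an assumed example, not taken from this PR.

# Hypothetical sketch: passing a plugin by name. handle_given_plugins() sees the
# string and calls set_distributed_mode("ddp_sharded"), which overrides
# self.distributed_backend before the training type plugin is selected.
from pytorch_lightning import Trainer

trainer = Trainer(gpus=2, accelerator="ddp", plugins="ddp_sharded")
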
15 changes: 14 additions & 1 deletion pytorch_lightning/accelerators/tpu.py
@@ -1,4 +1,5 @@
from typing import Callable
from typing import Any, Callable, Optional, Union
import torch

from torch.optim import Optimizer

@@ -28,3 +29,15 @@ def setup(self, trainer, model):

def run_optimizer_step(self, optimizer: Optimizer, optimizer_idx: int, lambda_closure: Callable, **kwargs):
xm.optimizer_step(optimizer, optimizer_args={'closure': lambda_closure, **kwargs})

def all_gather(self, tensor: Union[torch.Tensor], group: Optional[Any] = None, sync_grads: bool = False):
"""
Function to gather a tensor from several distributed processes
Args:
tensor: tensor of shape (batch, ...)
group: the process group to gather results from. Defaults to all processes (world)
sync_grads: flag that allows users to synchronize gradients for all_gather op
Return:
A tensor of shape (world_size, batch, ...)
"""
return xm.all_gather(tensor, group=group, sync_grads=sync_grads)
4 changes: 2 additions & 2 deletions pytorch_lightning/callbacks/model_checkpoint.py
@@ -540,9 +540,9 @@ def _save_last_checkpoint(self, trainer, pl_module, ckpt_name_metrics):

accelerator_backend = trainer.accelerator_backend

if accelerator_backend is not None and accelerator_backend.rpc_enabled:
if accelerator_backend.training_type_plugin.rpc_enabled:
# RPCPlugin manages saving all model states
accelerator_backend.ddp_plugin.rpc_save_model(self._save_model, last_filepath, trainer, pl_module)
accelerator_backend.training_type_plugin.rpc_save_model(self._save_model, last_filepath, trainer, pl_module)
else:
self._save_model(last_filepath, trainer, pl_module)
if (
16 changes: 14 additions & 2 deletions pytorch_lightning/plugins/precision/apex_amp.py
@@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Callable, List, Tuple
from typing import List, Tuple, Callable

import torch
from torch.optim import Optimizer
@@ -38,6 +38,8 @@ def connect(self, model: torch.nn.Module, optimizers, lr_schedulers):
"""Connects the precision plugin to the training process,
configures apex and reinits the schedulers
"""
if model.device.type != "cuda":
return model, optimizers, lr_schedulers
model, optimizers = self.configure_apex(amp, model, optimizers, self.amp_level)
self.reinit_scheduler_properties(optimizers, lr_schedulers)
return model, optimizers, lr_schedulers
@@ -71,7 +73,7 @@ def backward(
# do backward pass
# TODO: not entirely sure, why we need this
if model is not None and isinstance(model, LightningModule):
model.backward(closure_loss, optimizer, opt_idx)
model.backward(closure_loss, optimizer, opt_idx, **kwargs)

# TODO: avoid dev_debugger and track these calls with mock
model.trainer.dev_debugger.track_event('AMP', str(AMPType.APEX))
@@ -90,6 +92,16 @@ def backward(
closure_loss = closure_loss.detach()
return closure_loss

def pre_optimizer_step(
self, pl_module: LightningModule, optimizer: Optimizer, optimizer_idx: int, closure: Callable, **kwargs
) -> bool:
"""Hook to do something before each optimizer step."""
# Apex: Amp does not support closure use with optimizers
closure()
optimizer.step()
return False


def configure_apex(
self,
amp: object,
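
The False return value of pre_optimizer_step matters because the caller only runs the optimizer step itself when the hook asks for it. The sketch below states that assumed contract in code; the function and variable names are illustrative, not the trainer's actual source.

# Illustrative contract (assumed, not quoted from the codebase): how a False
# return from pre_optimizer_step signals that the step was already taken.
def optimizer_step(precision_plugin, pl_module, optimizer, opt_idx, closure):
    make_step = precision_plugin.pre_optimizer_step(pl_module, optimizer, opt_idx, closure)
    if make_step:
        # plugins that tolerate closures let the optimizer drive the closure itself
        optimizer.step(closure=closure)
    # ApexMixedPrecisionPlugin returns False: it already called closure() and
    # optimizer.step(), because apex amp cannot run the closure inside step().
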
4 changes: 2 additions & 2 deletions pytorch_lightning/plugins/training_type/ddp.py
@@ -29,7 +29,7 @@
from pytorch_lightning.overrides.distributed import prepare_for_backward
from pytorch_lightning.plugins.environments.cluster_environment import ClusterEnvironment
from pytorch_lightning.plugins.training_type.parallel import ParallelPlugin
from pytorch_lightning.utilities import _HYDRA_AVAILABLE, _PYTORCH_GREATER_EQUAL_THAN_1_7_0, rank_zero_warn
from pytorch_lightning.utilities import _HYDRA_AVAILABLE, _PYTORCH_GREATER_EQUAL_1_7_0, rank_zero_warn
from pytorch_lightning.utilities.distributed import (
find_free_network_port,
rank_zero_only,
@@ -181,7 +181,7 @@ def set_world_ranks(self):

def pre_configure_ddp(self):
# todo: PyTorch 1.7.0 DDP introduces ``self.reducer._rebuild_buckets()``` breaking manual_optimization
if _PYTORCH_GREATER_EQUAL_THAN_1_7_0 and not self.lightning_module.automatic_optimization:
if _PYTORCH_GREATER_EQUAL_1_7_0 and not self.lightning_module.automatic_optimization:
rank_zero_warn(
"From PyTorch 1.7.0, Lightning ``manual_optimization`` needs to set ``find_unused_parameters=True`` "
"to properly work with DDP."
5 changes: 3 additions & 2 deletions pytorch_lightning/plugins/training_type/ddp_spawn.py
@@ -27,7 +27,7 @@
from pytorch_lightning.overrides.distributed import prepare_for_backward
from pytorch_lightning.plugins.environments.cluster_environment import ClusterEnvironment
from pytorch_lightning.plugins.training_type.parallel import ParallelPlugin
from pytorch_lightning.utilities import _PYTORCH_GREATER_EQUAL_THAN_1_7_0
from pytorch_lightning.utilities import _PYTORCH_GREATER_EQUAL_1_7_0
from pytorch_lightning.utilities.cloud_io import atomic_save
from pytorch_lightning.utilities.cloud_io import load as pl_load
from pytorch_lightning.utilities.distributed import (
@@ -91,6 +91,7 @@ def setup(self, model):
def set_world_ranks(self, process_idx):
self.local_rank = process_idx
self.node_rank = self.cluster_environment.node_rank()
self.task_idx = self.cluster_local_rank
self.global_rank = self.node_rank * self.num_processes + self.local_rank
self.world_size = self.num_nodes * self.num_processes

@@ -164,7 +165,7 @@ def post_training(self):

def pre_configure_ddp(self):
# todo: PyTorch 1.7.0 DDP introduces ``self.reducer._rebuild_buckets()``` breaking manual_optimization
if _PYTORCH_GREATER_EQUAL_THAN_1_7_0 and not self.lightning_module.automatic_optimization:
if _PYTORCH_GREATER_EQUAL_1_7_0 and not self.lightning_module.automatic_optimization:
rank_zero_warn(
"From PyTorch 1.7.0, Lightning ``manual_optimization`` needs to set ``find_unused_parameters=True`` "
"to properly work with DDP."
9 changes: 8 additions & 1 deletion pytorch_lightning/plugins/training_type/parallel.py
@@ -36,10 +36,17 @@ def __init__(
):
super().__init__()
self.parallel_devices = parallel_devices
self.local_rank = 0
self.world_size = 1
self.local_rank = 0
self.cluster_environment = cluster_environment

@property
def cluster_local_rank(self):
try:
return self.cluster_environment.local_rank()
except KeyError:
return 0

@property
@abstractmethod
def root_device(self):
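
The try/except in cluster_local_rank guards against cluster environments that derive the local rank from an environment variable the launcher did not set. A minimal sketch of that assumed failure mode follows; the LOCAL_RANK name is an assumption about TorchElastic-style environments, not something this diff specifies.

import os

# Assumed behaviour: a cluster environment that reads the local rank from the
# process environment raises KeyError when the variable is absent, which
# cluster_local_rank() converts into a default of 0.
class EnvLocalRank:
    def local_rank(self) -> int:
        return int(os.environ["LOCAL_RANK"])  # KeyError if the launcher did not set it
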
16 changes: 11 additions & 5 deletions pytorch_lightning/plugins/training_type/rpc_sequential.py
@@ -190,6 +190,8 @@ def _find_and_init_pipe_module(self, model):
model.sequential_module.module.model.trainer = model.trainer
model.sequential_module.module.model.configure_optimizers = model.configure_optimizers

self.model = model

else:
raise MisconfigurationException(
'Could not find a PipeLightningModule within the model. '
@@ -261,11 +263,14 @@ def _check_arguments(self, trainer):
'DDPSequentialPlugin is currently not supported in Automatic Mixed Precision'
)

def configure_ddp(self, model: LightningModule, device_ids: List[int]) -> DistributedDataParallel:
ddp_plugin = RPCPlugin(process_group=mpu.get_data_parallel_group()).configure_ddp(model, device_ids)
def pre_backward(self, closure_loss: torch.Tensor, should_accumulate: bool, optimizer: Optimizer, opt_idx: int):
"""Run before precision plugin executes backward"""

def configure_ddp(self) -> None:
# process_group=mpu.get_data_parallel_group()
super().configure_ddp()
# Plugin handle backwards across processes. Currently not supported for DDP + pipe parallel
ddp_plugin.PREPARE_FOR_BACKWARDS = False
return ddp_plugin
self._model.require_backward_grad_sync = False

@rank_zero_only
def rpc_save_model(self, save_model_fn, last_filepath, trainer, pl_module) -> None:
@@ -289,7 +294,8 @@ def worker_optimizer_step(self, model: LightningModule, opt_idx: int, *args, **k
}, include_self=False
)

def distributed_sampler_kwargs(self, distributed_sampler_kwargs):
@property
def distributed_sampler_kwargs(self):
return dict(
num_replicas=mpu.get_data_parallel_world_size(),
rank=mpu.get_data_parallel_rank(),
2 changes: 1 addition & 1 deletion pytorch_lightning/trainer/trainer.py
100644 → 100755
@@ -458,6 +458,7 @@ def fit(
# ----------------------------
# SET UP TRAINING
# ----------------------------
self.call_hook("on_before_accelerator_backend_setup", model)
self.accelerator_backend.setup(self, model)
self.setup_trainer(model)

@@ -469,7 +470,6 @@

# plugin will setup training (e.g. ddp will launch child processes)
# TODO: the old setup is now called "pre_training", where should this hook be called now?
self.call_hook("on_before_accelerator_backend_setup", model)
self.training_type_plugin.pre_training()
self.precision_plugin.pre_training()

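
With the call_hook line moved above accelerator_backend.setup(), a callback implementing this hook would now run before any plugin or child-process setup. A minimal hedged sketch, assuming the hook is dispatched to callbacks as the call_hook invocation suggests; the callback name is made up.

# Hypothetical callback (not part of the diff): after this change it fires before
# accelerator_backend.setup(self, model) rather than after.
from pytorch_lightning.callbacks import Callback

class SetupProbe(Callback):
    def on_before_accelerator_backend_setup(self, trainer, pl_module):
        print("accelerator backend is about to be set up")
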
2 changes: 1 addition & 1 deletion pytorch_lightning/utilities/__init__.py
@@ -35,7 +35,7 @@
_module_available,
_NATIVE_AMP_AVAILABLE,
_OMEGACONF_AVAILABLE,
_PYTORCH_GREATER_EQUAL_THAN_1_7_0,
_PYTORCH_GREATER_EQUAL_1_7_0,
_PYTORCH_PRUNE_AVAILABLE,
_RPC_AVAILABLE,
_TORCHTEXT_AVAILABLE,
2 changes: 1 addition & 1 deletion pytorch_lightning/utilities/imports.py
@@ -59,5 +59,5 @@ def _module_available(module_path: str) -> bool:
) <= LooseVersion("0.1.3")
_BOLTS_AVAILABLE = _module_available('pl_bolts')
_PYTORCH_PRUNE_AVAILABLE = _module_available('torch.nn.utils.prune')
_PYTORCH_GREATER_EQUAL_THAN_1_7_0 = LooseVersion(pkg_resources.get_distribution('torch').version) >= LooseVersion("1.7.0")
_PYTORCH_GREATER_EQUAL_1_7_0 = LooseVersion(pkg_resources.get_distribution('torch').version) >= LooseVersion("1.7.0")
_TORCHVISION_AVAILABLE = _module_available('torchvision')