Commit fb7041e

Merge branch 'master' into feat_artifacts

awaelchli committed Feb 27, 2021
2 parents 3fb55f6 + ee5032a
Showing 27 changed files with 245 additions and 115 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md
@@ -26,6 +26,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Removed support for passing a bool value to `profiler` argument of Trainer ([#6164](https://github.com/PyTorchLightning/pytorch-lightning/pull/6164))


- Removed passing a `ModelCheckpoint` instance to `Trainer(checkpoint_callback)` ([#6166](https://github.com/PyTorchLightning/pytorch-lightning/pull/6166))


- Removed deprecated Trainer arguments `enable_pl_optimizer` and `automatic_optimization` ([#6163](https://github.com/PyTorchLightning/pytorch-lightning/pull/6163))


2 changes: 1 addition & 1 deletion pytorch_lightning/accelerators/accelerator.py
@@ -366,7 +366,7 @@ def amp_backend(self) -> Optional[LightningEnum]:
return None

@property
def precision(self) -> int:
def precision(self) -> Union[str, int]:
return self.precision_plugin.precision

@property
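For context (not part of this diff), the widened return type reflects that a precision plugin may report the string "mixed" (see pytorch_lightning/plugins/precision/mixed.py further down) instead of an int such as 16 or 32. A minimal sketch of the relationship, with simplified class bodies:

from typing import Union

class PrecisionPlugin:
    precision: Union[str, int] = 32

class MixedPrecisionPlugin(PrecisionPlugin):
    precision: Union[str, int] = "mixed"

class Accelerator:
    def __init__(self, precision_plugin: PrecisionPlugin) -> None:
        self.precision_plugin = precision_plugin

    @property
    def precision(self) -> Union[str, int]:
        # mirrors the change above: may be "mixed" or an int like 16/32
        return self.precision_plugin.precision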
6 changes: 6 additions & 0 deletions pytorch_lightning/loggers/comet.py
@@ -115,6 +115,12 @@ class CometLogger(LightningLoggerBase):
prefix: A string to put at the beginning of metric keys.
\**kwargs: Additional arguments like `workspace`, `log_code`, etc. used by
:class:`CometExperiment` can be passed as keyword arguments in this logger.
Raises:
ImportError:
If the required Comet package is not installed on the device.
MisconfigurationException:
If neither ``api_key`` nor ``save_dir`` are passed as arguments.
"""

LOGGER_JOIN_CHAR = '-'
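A hedged usage sketch of the behaviour documented above (only api_key and save_dir are taken from the docstring; the comet_ml package must be installed, or the documented ImportError is raised):

from pytorch_lightning.loggers import CometLogger

# Online mode: requires a valid Comet API key.
online_logger = CometLogger(api_key="YOUR_COMET_API_KEY")

# Offline mode: experiments are written to save_dir instead.
offline_logger = CometLogger(save_dir="./comet_logs")

# Passing neither api_key nor save_dir raises MisconfigurationException.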
3 changes: 3 additions & 0 deletions pytorch_lightning/loggers/mlflow.py
@@ -79,6 +79,9 @@ def any_lightning_module_function_or_hook(self):
Has no effect if `tracking_uri` is provided.
prefix: A string to put at the beginning of metric keys.
Raises:
ImportError:
If the required MLFlow package is not installed on the device.
"""

LOGGER_JOIN_CHAR = '-'
4 changes: 4 additions & 0 deletions pytorch_lightning/loggers/neptune.py
@@ -171,6 +171,10 @@ def any_lightning_module_function_or_hook(self):
prefix: A string to put at the beginning of metric keys.
\**kwargs: Additional arguments like `params`, `tags`, `properties`, etc. used by
:func:`neptune.Session.create_experiment` can be passed as keyword arguments in this logger.
Raises:
ImportError:
If the required Neptune package is not installed on the device.
"""

LOGGER_JOIN_CHAR = '-'
4 changes: 4 additions & 0 deletions pytorch_lightning/loggers/test_tube.py
@@ -74,6 +74,10 @@ def any_lightning_module_function_or_hook(self):
the user has defined the `self.example_input_array` attribute in their
model.
prefix: A string to put at the beginning of metric keys.
Raises:
ImportError:
If the required TestTube package is not installed on the device.
"""

__test__ = False
6 changes: 6 additions & 0 deletions pytorch_lightning/loggers/wandb.py
@@ -64,6 +64,12 @@ class WandbLogger(LightningLoggerBase):
\**kwargs: Additional arguments like `entity`, `group`, `tags`, etc. used by
:func:`wandb.init` can be passed as keyword arguments in this logger.
Raises:
ImportError:
If the required WandB package is not installed on the device.
MisconfigurationException:
If both ``log_model`` and ``offline`` are set to ``True``.
Example::
from pytorch_lightning.loggers import WandbLogger
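A hedged sketch of the documented constraint (log_model and offline come from the docstring above; the project argument is assumed):

from pytorch_lightning.loggers import WandbLogger

# Allowed: offline run, checkpoints are not uploaded.
logger = WandbLogger(project="my-project", offline=True)

# Documented to raise MisconfigurationException: checkpoints cannot be
# uploaded to W&B while running offline.
# bad_logger = WandbLogger(project="my-project", log_model=True, offline=True)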
17 changes: 2 additions & 15 deletions pytorch_lightning/plugins/base_plugin.py
@@ -12,26 +12,13 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import contextlib
from abc import ABC, abstractmethod
from typing import Generator, Optional, Sequence, Tuple

from torch.nn import Module
from abc import ABC
from typing import Generator


class Plugin(ABC):
"""Basic Plugin class to derive precision and training type plugins from."""

@abstractmethod
def connect(
self,
model: Module,
*args: Sequence,
**kwargs: Sequence,
) -> Optional[Tuple[Module, Sequence, Sequence]]:
"""Connects the plugin with the accelerator (and thereby with trainer and model).
Will be called by the accelerator.
"""

def pre_dispatch(self) -> None:
"""Hook to do something before the training/evaluation/prediction starts."""

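For context (not part of this diff), removing the abstract connect() means a Plugin subclass only overrides the hooks it needs, and each plugin family now defines connect() with its own signature (as apex_amp.py does below). A minimal hedged sketch:

from pytorch_lightning.plugins.base_plugin import Plugin

class MyPlugin(Plugin):
    # connect() is no longer required by the base class; override only the
    # hooks that matter for this plugin.
    def pre_dispatch(self) -> None:
        print("setting up before training/evaluation/prediction starts")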
40 changes: 24 additions & 16 deletions pytorch_lightning/plugins/precision/apex_amp.py
@@ -11,10 +11,9 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Callable, List, Tuple
from typing import Any, Callable, Generator, List, Sequence, Tuple, Type, TYPE_CHECKING

import torch
from torch.optim import Optimizer

from pytorch_lightning.core import LightningModule
from pytorch_lightning.plugins.precision.mixed import MixedPrecisionPlugin
@@ -23,37 +22,41 @@
if _APEX_AVAILABLE:
from apex import amp

if TYPE_CHECKING:
from torch.optim import Optimizer


class ApexMixedPrecisionPlugin(MixedPrecisionPlugin):
"""Mixed Precision Plugin based on Nvidia/Apex (https://github.com/NVIDIA/apex)"""

def __init__(self, amp_level: str):
def __init__(self, amp_level: str) -> None:
self.backend = AMPType.APEX
self.amp_level = amp_level

def master_params(self, optimizer: torch.optim.Optimizer):
def master_params(self, optimizer: 'Optimizer') -> Generator[torch.Tensor, None, None]:
return amp.master_params(optimizer)

def connect(self, model: torch.nn.Module, optimizers, lr_schedulers):
def connect(self, model: torch.nn.Module, optimizers: Sequence['Optimizer'],
lr_schedulers: Sequence[Any]) -> Tuple[torch.nn.Module, Sequence['Optimizer'], Sequence[Any]]:
"""Connects the precision plugin to the training process,
configures apex and reinits the schedulers
"""
if model.device.type != "cuda":
return model, optimizers, lr_schedulers
model, optimizers = self.configure_apex(amp, model, optimizers, self.amp_level)
model, optimizers = self.configure_apex(amp, model, list(optimizers), self.amp_level)
self.reinit_scheduler_properties(optimizers, lr_schedulers)
return model, optimizers, lr_schedulers

def backward(
self,
model: LightningModule,
closure_loss: torch.Tensor,
optimizer: torch.optim.Optimizer,
optimizer: 'Optimizer',
opt_idx: int,
should_accumulate: bool,
*args,
**kwargs,
):
*args: Any,
**kwargs: Any,
) -> torch.Tensor:
"""performs the actual backpropagation
Args:
@@ -94,11 +97,11 @@ def backward(

def configure_apex(
self,
amp: object,
amp: Type,
model: LightningModule,
optimizers: List[Optimizer],
optimizers: List['Optimizer'],
amp_level: str,
) -> Tuple[LightningModule, List[Optimizer]]:
) -> Tuple[LightningModule, List['Optimizer']]:
r"""
Override to init AMP your own way.
Must return a model and list of optimizers.
@@ -127,7 +130,7 @@ def configure_apex(self, amp, model, optimizers, amp_level):
return model, optimizers

@staticmethod
def reinit_scheduler_properties(optimizers: list, schedulers: list):
def reinit_scheduler_properties(optimizers: Sequence['Optimizer'], schedulers: Sequence[Any]) -> None:
"""Reinitializes schedulers with correct properties"""
# Reinitialize optimizer.step properties added by schedulers
for scheduler in schedulers:
@@ -149,7 +152,12 @@ def reinit_scheduler_properties(optimizers: list, schedulers: list):
break

def pre_optimizer_step(
self, pl_module: LightningModule, optimizer: Optimizer, optimizer_idx: int, lambda_closure: Callable, **kwargs
self,
pl_module: LightningModule,
optimizer: 'Optimizer',
optimizer_idx: int,
lambda_closure: Callable,
**kwargs: Any,
) -> bool:
"""
always called before the optimizer step.
@@ -160,6 +168,6 @@ def pre_optimizer_step(
if not pl_module.automatic_optimization:
pl_module.trainer.call_hook("on_after_backward")

optimizer.step()
optimizer.step(**kwargs)

return False
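A hedged sketch of how this plugin is typically selected (the Trainer argument names are assumed from the PL 1.2-era API and are not shown in this diff):

import pytorch_lightning as pl

# Choosing the Apex backend makes the Trainer build an
# ApexMixedPrecisionPlugin(amp_level="O2") internally.
trainer = pl.Trainer(gpus=1, precision=16, amp_backend="apex", amp_level="O2")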
51 changes: 37 additions & 14 deletions pytorch_lightning/plugins/precision/deepspeed_precision.py
@@ -1,24 +1,45 @@
from typing import Callable, Union
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Callable, TYPE_CHECKING, Union

import torch
from torch.optim import Optimizer

from pytorch_lightning.core.lightning import LightningModule
from pytorch_lightning.plugins.precision.precision_plugin import PrecisionPlugin
from pytorch_lightning.utilities.model_helpers import is_overridden
from pytorch_lightning.utilities.warnings import WarningCache

if TYPE_CHECKING:
from torch.optim import Optimizer

from pytorch_lightning.core.lightning import LightningModule

warning_cache = WarningCache()


class DeepSpeedPrecisionPlugin(PrecisionPlugin):

def __init__(self, precision):
def __init__(self, precision: int) -> None:
super().__init__()
self.precision = precision

def pre_optimizer_step(
self, pl_module: LightningModule, optimizer: Optimizer, optimizer_idx: int, lambda_closure: Callable, **kwargs
self,
pl_module: 'LightningModule',
optimizer: 'Optimizer',
optimizer_idx: int,
lambda_closure: Callable,
**kwargs: Any,
) -> bool:
deepspeed_engine = pl_module.trainer.model
# DeepSpeed does not support closures.
@@ -33,28 +54,30 @@ def pre_optimizer_step(

def backward(
self,
lightning_module: LightningModule,
model: 'LightningModule',
closure_loss: torch.Tensor,
optimizer: torch.optim.Optimizer,
optimizer: 'Optimizer',
opt_idx: int,
should_accumulate: bool,
*args,
**kwargs,
):
if is_overridden('backward', lightning_module):
*args: Any,
**kwargs: Any,
) -> torch.Tensor:
if is_overridden('backward', model):
warning_cache.warn(
"Overridden backward hook in the LightningModule will be ignored since DeepSpeed handles"
"backward logic outside of the LightningModule"
)
# todo: hack around for deepspeed engine to call backward
deepspeed_engine = lightning_module.trainer.model
deepspeed_engine.backward(closure_loss, **kwargs)
deepspeed_engine = model.trainer.model
deepspeed_engine.backward(closure_loss, *args, **kwargs)
# once backward has been applied, release graph
closure_loss = closure_loss.detach()

return closure_loss

def clip_gradients(self, optimizer: Optimizer, clip_val: Union[int, float], norm_type: float = float(2.0)):
def clip_gradients(
self, optimizer: 'Optimizer', clip_val: Union[int, float], norm_type: float = 2.0
) -> None:
"""
DeepSpeed handles clipping gradients via the training type plugin.
"""
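For context (not part of this diff), both this file and apex_amp.py above move their Optimizer import behind typing.TYPE_CHECKING, so the import is only evaluated by static type checkers and the runtime annotations are written as string literals. A minimal sketch of the pattern:

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    from torch.optim import Optimizer

def zero_all(optimizer: 'Optimizer') -> None:
    # the quoted annotation avoids importing torch.optim at runtime
    optimizer.zero_grad()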
12 changes: 8 additions & 4 deletions pytorch_lightning/plugins/precision/mixed.py
@@ -11,13 +11,17 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING, Union

from pytorch_lightning.plugins.precision.precision_plugin import PrecisionPlugin
from pytorch_lightning.utilities import AMPType

if TYPE_CHECKING:
from pytorch_lightning.utilities import AMPType


class MixedPrecisionPlugin(PrecisionPlugin):
"""Base Class for mixed precision"""

EPSILON = 1e-5
backend: AMPType
precision = "mixed"
EPSILON: float = 1e-5
backend: 'AMPType'
precision: Union[str, int] = "mixed"