Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Deprecate TrainerOptimizersMixin and move functionality to core/optimizer.py #11155

Merged
merged 38 commits into from
Dec 23, 2021
Merged
Show file tree
Hide file tree
Changes from 34 commits
Commits
Show all changes
38 commits
Select commit Hold shift + click to select a range
68d60fa
refactor init_optimizers
daniellepintz Dec 18, 2021
e7e3b59
fix tests
daniellepintz Dec 18, 2021
628e4a0
fix tests
daniellepintz Dec 18, 2021
68caa57
fix another test
daniellepintz Dec 19, 2021
0396959
attempt to fix deepspeed test
daniellepintz Dec 19, 2021
b645134
addr comments
daniellepintz Dec 19, 2021
47bdae1
deprecate optimizer mixin
daniellepintz Dec 19, 2021
aff9de3
addr comments
daniellepintz Dec 20, 2021
74a3033
Merge branch 'master' of https://github.com/PyTorchLightning/pytorch-…
daniellepintz Dec 20, 2021
3d1faaa
move _convert_to_lightning_optimizer to core/optimizer.py
daniellepintz Dec 20, 2021
8c1267f
fix typing and docstring
daniellepintz Dec 21, 2021
4a1ee1a
Merge branch 'master' of https://github.com/PyTorchLightning/pytorch-…
daniellepintz Dec 21, 2021
108c4a2
remove strategy refactor
daniellepintz Dec 21, 2021
a768823
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Dec 21, 2021
b27e2e1
small fix
daniellepintz Dec 21, 2021
e8d6cf6
Merge branch 'optimizers_mixin' of github.com:daniellepintz/pytorch-l…
daniellepintz Dec 21, 2021
7d72f56
fix mypy
daniellepintz Dec 21, 2021
4fda88f
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Dec 21, 2021
be56fcd
addr comment
daniellepintz Dec 21, 2021
0da03c8
Merge branch 'optimizers_mixin' of github.com:daniellepintz/pytorch-l…
daniellepintz Dec 21, 2021
779c0a4
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Dec 21, 2021
362a1b7
fix test
daniellepintz Dec 21, 2021
81c6788
Merge branch 'optimizers_mixin' of github.com:daniellepintz/pytorch-l…
daniellepintz Dec 21, 2021
493a582
addr comments
daniellepintz Dec 22, 2021
b20d6f4
Merge branch 'master' of https://github.com/PyTorchLightning/pytorch-…
daniellepintz Dec 22, 2021
1c30527
Merge branch 'master' of https://github.com/PyTorchLightning/pytorch-…
daniellepintz Dec 22, 2021
8b80304
fix merge conflicts from strategy
daniellepintz Dec 22, 2021
5397d2f
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Dec 22, 2021
5da0649
addr comments
daniellepintz Dec 22, 2021
f320402
Merge branch 'optimizers_mixin' of github.com:daniellepintz/pytorch-l…
daniellepintz Dec 22, 2021
119bc58
fix whitespace
daniellepintz Dec 22, 2021
ecb3c60
fix weakref test
daniellepintz Dec 22, 2021
fd1b14d
Merge branch 'master' of https://github.com/PyTorchLightning/pytorch-…
daniellepintz Dec 22, 2021
8ba6f37
add comment
daniellepintz Dec 22, 2021
d2c6b6e
addr comments
daniellepintz Dec 22, 2021
bb6ba78
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Dec 22, 2021
41bbfb8
fix mypy
daniellepintz Dec 22, 2021
126533d
Merge branch 'optimizers_mixin' of github.com:daniellepintz/pytorch-l…
daniellepintz Dec 22, 2021
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -210,6 +210,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Deprecated `Trainer.should_rank_save_checkpoint` Trainer property ([#11068](https://github.com/PyTorchLightning/pytorch-lightning/pull/11068))


- Deprecated `TrainerOptimizersMixin` and moved functionality to `core/optimizer.py`([#11155](https://github.com/PyTorchLightning/pytorch-lightning/pull/11155))


- Deprecated `TrainerCallbackHookMixin` ([#11148](https://github.com/PyTorchLightning/pytorch-lightning/pull/11148))

### Removed
Expand Down Expand Up @@ -348,6 +351,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Fixed an issue with the `TPUSpawnPlugin` handling the `XLA_USE_BF16` environment variable incorrectly ([#10990](https://github.com/PyTorchLightning/pytorch-lightning/pull/10990))


- Fixed wrong typehint for `Trainer.lightning_optimizers` ([#11155](https://github.com/PyTorchLightning/pytorch-lightning/pull/11155))


## [1.5.7] - 2021-12-21

Expand Down
2 changes: 1 addition & 1 deletion pytorch_lightning/callbacks/stochastic_weight_avg.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@

import pytorch_lightning as pl
from pytorch_lightning.callbacks.base import Callback
from pytorch_lightning.trainer.optimizers import _get_default_scheduler_config
from pytorch_lightning.core.optimizer import _get_default_scheduler_config
from pytorch_lightning.utilities import rank_zero_info, rank_zero_warn
from pytorch_lightning.utilities.exceptions import MisconfigurationException

Expand Down
237 changes: 234 additions & 3 deletions pytorch_lightning/core/optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,14 +11,17 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import weakref
from contextlib import contextmanager
from typing import Any, Callable, Generator, Optional
from typing import Any, Callable, Dict, Generator, List, Optional, Tuple, Union
from weakref import proxy

import torch
from torch import optim
from torch.optim import Optimizer

import pytorch_lightning as pl
from pytorch_lightning.utilities import AMPType
from pytorch_lightning.utilities import AMPType, rank_zero_warn
from pytorch_lightning.utilities.exceptions import MisconfigurationException


Expand Down Expand Up @@ -54,7 +57,11 @@ def optimizer(self) -> Optimizer:
return self._optimizer

def _on_trainer_init(self, trainer: "pl.Trainer") -> None:
self._trainer = proxy(trainer)
# check if trainer is already of type weakproxy since we can't call proxy on a weakproxy
if isinstance(trainer, weakref.ProxyType):
daniellepintz marked this conversation as resolved.
Show resolved Hide resolved
self._trainer = trainer
else:
self._trainer = proxy(trainer)
daniellepintz marked this conversation as resolved.
Show resolved Hide resolved
for opt_idx, opt in enumerate(trainer.optimizers):
if opt == self._optimizer:
self._optimizer_idx = opt_idx
Expand Down Expand Up @@ -162,3 +169,227 @@ def closure_dis():
assert trainer is not None
with trainer.profiler.profile(profiler_action):
trainer.strategy.optimizer_step(self._optimizer, self._optimizer_idx, closure, **kwargs)


def _init_optimizers_and_lr_schedulers(model: "pl.LightningModule") -> Tuple[List, List, List]:
ananthsub marked this conversation as resolved.
Show resolved Hide resolved
"""Calls `LightningModule.configure_optimizers` and parses and validates the output."""
model.trainer._lightning_optimizers = None
optim_conf = model.trainer._call_lightning_module_hook("configure_optimizers", pl_module=model)

if optim_conf is None:
rank_zero_warn(
"`LightningModule.configure_optimizers` returned `None`, this fit will run with no optimizer",
)
optim_conf = _MockOptimizer()

optimizers, lr_schedulers, optimizer_frequencies, monitor = _configure_optimizers(optim_conf)
lr_schedulers = _configure_schedulers(lr_schedulers, monitor, not model.automatic_optimization)
_validate_scheduler_optimizer(optimizers, lr_schedulers)
return optimizers, lr_schedulers, optimizer_frequencies


def _configure_optimizers(
optim_conf: Union[Dict[str, Any], List, Optimizer, Tuple]
) -> Tuple[List, List, List, Optional[str]]:
optimizers, lr_schedulers, optimizer_frequencies = [], [], []
monitor = None

# single output, single optimizer
if isinstance(optim_conf, Optimizer):
optimizers = [optim_conf]
# two lists, optimizer + lr schedulers
elif (
isinstance(optim_conf, (list, tuple))
and len(optim_conf) == 2
and isinstance(optim_conf[0], list)
and all(isinstance(opt, Optimizer) for opt in optim_conf[0])
):
opt, sch = optim_conf
optimizers = opt
lr_schedulers = sch if isinstance(sch, list) else [sch]
daniellepintz marked this conversation as resolved.
Show resolved Hide resolved
# single dictionary
elif isinstance(optim_conf, dict):
_validate_optim_conf(optim_conf)
optimizers = [optim_conf["optimizer"]]
monitor = optim_conf.get("monitor", None)
lr_schedulers = [optim_conf["lr_scheduler"]] if "lr_scheduler" in optim_conf else []
# multiple dictionaries
elif isinstance(optim_conf, (list, tuple)) and all(isinstance(d, dict) for d in optim_conf):
for opt_dict in optim_conf:
_validate_optim_conf(opt_dict)
optimizers = [opt_dict["optimizer"] for opt_dict in optim_conf]
scheduler_dict = (
lambda scheduler, opt_idx: dict(scheduler, opt_idx=opt_idx)
if isinstance(scheduler, dict)
else {"scheduler": scheduler, "opt_idx": opt_idx}
)

lr_schedulers = [
scheduler_dict(opt_dict["lr_scheduler"], opt_idx)
for opt_idx, opt_dict in enumerate(optim_conf)
if "lr_scheduler" in opt_dict
]
optimizer_frequencies = [
opt_dict["frequency"] for opt_dict in optim_conf if opt_dict.get("frequency", None) is not None
]
# assert that if frequencies are present, they are given for all optimizers
if optimizer_frequencies and len(optimizer_frequencies) != len(optimizers):
raise ValueError("A frequency must be given to each optimizer.")
# single list or tuple, multiple optimizer
elif isinstance(optim_conf, (list, tuple)) and all(isinstance(opt, Optimizer) for opt in optim_conf):
optimizers = list(optim_conf)
# unknown configuration
else:
raise MisconfigurationException(
"Unknown configuration for model optimizers."
" Output from `model.configure_optimizers()` should either be:\n"
daniellepintz marked this conversation as resolved.
Show resolved Hide resolved
" * `torch.optim.Optimizer`\n"
" * [`torch.optim.Optimizer`]\n"
" * ([`torch.optim.Optimizer`], [`torch.optim.lr_scheduler`])\n"
daniellepintz marked this conversation as resolved.
Show resolved Hide resolved
' * {"optimizer": `torch.optim.Optimizer`, (optional) "lr_scheduler": `torch.optim.lr_scheduler`}\n'
' * A list of the previously described dict format, with an optional "frequency" key (int)'
daniellepintz marked this conversation as resolved.
Show resolved Hide resolved
)
return optimizers, lr_schedulers, optimizer_frequencies, monitor


def _configure_schedulers(
schedulers: list, monitor: Optional[str], is_manual_optimization: bool
) -> List[Dict[str, Any]]:
"""Convert each scheduler into dict structure with relevant information."""
lr_schedulers = []
default_config = _get_default_scheduler_config()
# TODO: move is_manual_optimization check out of for loop
for scheduler in schedulers:
if is_manual_optimization:
daniellepintz marked this conversation as resolved.
Show resolved Hide resolved
if isinstance(scheduler, dict):
invalid_keys = {"interval", "frequency", "reduce_on_plateau", "monitor", "strict"}
keys_to_warn = [k for k in scheduler.keys() if k in invalid_keys]

if keys_to_warn:
rank_zero_warn(
f"The lr scheduler dict contains the key(s) {keys_to_warn}, but the keys will be ignored."
" You need to call `lr_scheduler.step()` manually in manual optimization.",
category=RuntimeWarning,
)

scheduler = {key: scheduler[key] for key in scheduler if key not in invalid_keys}
lr_schedulers.append({**default_config, **scheduler})
else:
lr_schedulers.append({**default_config, "scheduler": scheduler})
else:
if isinstance(scheduler, dict):
# check provided keys
extra_keys = [k for k in scheduler.keys() if k not in default_config.keys()]
if extra_keys:
rank_zero_warn(
f"Found unsupported keys in the lr scheduler dict: {extra_keys}", category=RuntimeWarning
)
if "scheduler" not in scheduler:
raise MisconfigurationException(
'The lr scheduler dict must have the key "scheduler" with its item being an lr scheduler'
)
if "interval" in scheduler and scheduler["interval"] not in ("step", "epoch"):
raise MisconfigurationException(
'The "interval" key in lr scheduler dict must be "step" or "epoch"'
f' but is "{scheduler["interval"]}"'
)
scheduler["reduce_on_plateau"] = isinstance(
scheduler["scheduler"], optim.lr_scheduler.ReduceLROnPlateau
)
if scheduler["reduce_on_plateau"] and scheduler.get("monitor", None) is None:
raise MisconfigurationException(
"The lr scheduler dict must include a monitor when a `ReduceLROnPlateau` scheduler is used."
' For example: {"optimizer": optimizer, "lr_scheduler":'
' {"scheduler": scheduler, "monitor": "your_loss"}}'
)
is_one_cycle = isinstance(scheduler["scheduler"], optim.lr_scheduler.OneCycleLR)
if is_one_cycle and scheduler.get("interval", "epoch") == "epoch":
rank_zero_warn(
"A `OneCycleLR` scheduler is using 'interval': 'epoch'."
" Are you sure you didn't mean 'interval': 'step'?",
category=RuntimeWarning,
)
lr_schedulers.append({**default_config, **scheduler})
elif isinstance(scheduler, optim.lr_scheduler.ReduceLROnPlateau):
if monitor is None:
raise MisconfigurationException(
"`configure_optimizers` must include a monitor when a `ReduceLROnPlateau`"
" scheduler is used. For example:"
' {"optimizer": optimizer, "lr_scheduler": scheduler, "monitor": "metric_to_track"}'
)
lr_schedulers.append(
{**default_config, "scheduler": scheduler, "reduce_on_plateau": True, "monitor": monitor}
)
elif isinstance(scheduler, optim.lr_scheduler._LRScheduler):
lr_schedulers.append({**default_config, "scheduler": scheduler})
else:
raise ValueError(f'The provided lr scheduler "{scheduler}" is invalid')
return lr_schedulers


def _get_default_scheduler_config() -> Dict[str, Any]:
return {
"scheduler": None,
"name": None, # no custom name
"interval": "epoch", # after epoch is over
"frequency": 1, # every epoch/batch
"reduce_on_plateau": False, # most often not ReduceLROnPlateau scheduler
"monitor": None, # value to monitor for ReduceLROnPlateau
"strict": True, # enforce that the monitor exists for ReduceLROnPlateau
"opt_idx": None, # necessary to store opt_idx when optimizer frequencies are specified
}


def _validate_scheduler_optimizer(optimizers: List[Any], lr_schedulers: List[Any]) -> None:
if any(sch["scheduler"].optimizer not in optimizers for sch in lr_schedulers):
raise MisconfigurationException(
"Some schedulers are attached with an optimizer that wasn't returned from `configure_optimizers`."
)


def _validate_optim_conf(optim_conf: Dict[str, Any]) -> None:
valid_keys = {"optimizer", "lr_scheduler", "frequency", "monitor"}
extra_keys = optim_conf.keys() - valid_keys
if extra_keys:
rank_zero_warn(
f"Found unsupported keys in the optimizer configuration: {set(extra_keys)}", category=RuntimeWarning
)


def _convert_to_lightning_optimizers(trainer: "pl.Trainer") -> None:
def _convert_to_lightning_optimizer(optimizer: Optimizer) -> LightningOptimizer:
if not isinstance(optimizer, LightningOptimizer):
optimizer = LightningOptimizer(optimizer) # type: ignore [assignment]
optimizer._on_trainer_init(trainer)
return optimizer # type: ignore [return-value]

trainer._lightning_optimizers = { # type: ignore [assignment]
opt_idx: _convert_to_lightning_optimizer(opt) for opt_idx, opt in enumerate(trainer.optimizers)
}


class _MockOptimizer(Optimizer):
"""The `_MockOptimizer` will be used inplace of an optimizer in the event that `None` is returned from
`configure_optimizers`."""
daniellepintz marked this conversation as resolved.
Show resolved Hide resolved

def __init__(self) -> None:
super().__init__([torch.zeros(1)], {})

def add_param_group(self, param_group: Dict[Any, Any]) -> None:
pass # Do Nothing

def load_state_dict(self, state_dict: Dict[Any, Any]) -> None:
pass # Do Nothing

def state_dict(self) -> Dict[Any, Any]:
daniellepintz marked this conversation as resolved.
Show resolved Hide resolved
return {} # Return Empty

def step(self, closure: Callable = None) -> None:
if closure is not None:
closure()

def zero_grad(self, set_to_none: Optional[bool] = False) -> None:
pass # Do Nothing

def __repr__(self) -> str:
return "No Optimizer"
4 changes: 2 additions & 2 deletions pytorch_lightning/strategies/ddp.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@
from torch.nn.parallel.distributed import DistributedDataParallel

import pytorch_lightning as pl
from pytorch_lightning.core.optimizer import LightningOptimizer
from pytorch_lightning.core.optimizer import _convert_to_lightning_optimizers, LightningOptimizer
from pytorch_lightning.overrides import LightningDistributedModule
from pytorch_lightning.overrides.distributed import prepare_for_backward
from pytorch_lightning.plugins.environments.cluster_environment import ClusterEnvironment
Expand Down Expand Up @@ -347,7 +347,7 @@ def _reinit_optimizers_with_post_localSGD(self, warmup_steps: int):
del optimizer
trainer = self.lightning_module.trainer
trainer.optimizers = optimizers
trainer.convert_to_lightning_optimizers()
_convert_to_lightning_optimizers(trainer)

def configure_ddp(self) -> None:
self.pre_configure_ddp()
Expand Down
6 changes: 2 additions & 4 deletions pytorch_lightning/strategies/deepspeed.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,11 +27,11 @@
from torch.optim.lr_scheduler import _LRScheduler

import pytorch_lightning as pl
from pytorch_lightning.core.optimizer import _get_default_scheduler_config, _init_optimizers_and_lr_schedulers
from pytorch_lightning.overrides.base import _LightningModuleWrapperBase
from pytorch_lightning.plugins.environments.cluster_environment import ClusterEnvironment
from pytorch_lightning.plugins.precision import PrecisionPlugin
from pytorch_lightning.strategies.ddp import DDPStrategy
from pytorch_lightning.trainer.optimizers import _get_default_scheduler_config
from pytorch_lightning.trainer.states import TrainerFn
from pytorch_lightning.utilities import GradClipAlgorithmType
from pytorch_lightning.utilities.apply_func import apply_to_collection
Expand Down Expand Up @@ -446,9 +446,7 @@ def init_deepspeed(self):
self._initialize_deepspeed_inference(model)

def _init_optimizers(self) -> Tuple[Optimizer, Optional[Union[LRSchedulerTypeTuple]], Optional[int]]:
optimizers, schedulers, optimizer_frequencies = self.lightning_module.trainer.init_optimizers(
self.lightning_module
)
optimizers, schedulers, optimizer_frequencies = _init_optimizers_and_lr_schedulers(self.lightning_module)
if len(optimizers) > 1 or len(schedulers) > 1:
raise MisconfigurationException(
"DeepSpeed currently only supports single optimizer, single optional scheduler."
Expand Down
4 changes: 2 additions & 2 deletions pytorch_lightning/strategies/sharded.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
from torch.optim import Optimizer

import pytorch_lightning as pl
from pytorch_lightning.core.optimizer import LightningOptimizer
from pytorch_lightning.core.optimizer import _convert_to_lightning_optimizers, LightningOptimizer
from pytorch_lightning.strategies.ddp import DDPStrategy
from pytorch_lightning.trainer.states import TrainerFn
from pytorch_lightning.utilities import _FAIRSCALE_AVAILABLE, _FAIRSCALE_OSS_FP16_BROADCAST_AVAILABLE, rank_zero_only
Expand Down Expand Up @@ -50,7 +50,7 @@ def configure_ddp(self) -> None:
optimizers=trainer.optimizers,
)
trainer.optimizers = optimizers
trainer.convert_to_lightning_optimizers()
_convert_to_lightning_optimizers(trainer)

def _setup_model_and_optimizers(self, model: Module, optimizers: List[Optimizer]) -> Tuple[Module, List[Optimizer]]:
"""Wraps the model and optimizers with fairscale components.
Expand Down
3 changes: 2 additions & 1 deletion pytorch_lightning/strategies/training_type_plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
from torch.utils.data import DataLoader

import pytorch_lightning as pl
from pytorch_lightning.core.optimizer import _init_optimizers_and_lr_schedulers
from pytorch_lightning.overrides.base import unwrap_lightning_module
from pytorch_lightning.plugins import TorchCheckpointIO
from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIO
Expand Down Expand Up @@ -377,7 +378,7 @@ def process_dataloader(self, dataloader: DataLoader) -> DataLoader:
return dataloader

def init_optimizers(self, trainer: "pl.Trainer", model: "pl.LightningModule"):
return trainer.init_optimizers(model)
return _init_optimizers_and_lr_schedulers(model)

@property
def restore_checkpoint_after_setup(self) -> bool:
Expand Down
Loading