# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from abc import ABC
from typing import Any, Dict, List, Optional, Tuple

import torch
from torch import optim
from torch.optim.optimizer import Optimizer

from pytorch_lightning.core.lightning import LightningModule
from pytorch_lightning.core.optimizer import LightningOptimizer
from pytorch_lightning.utilities import rank_zero_warn
from pytorch_lightning.utilities.exceptions import MisconfigurationException


class TrainerOptimizersMixin(ABC):

    _lightning_optimizers: Optional[List[LightningOptimizer]]

    def init_optimizers(self, model: LightningModule) -> Tuple[List, List, List]:
        self._lightning_optimizers = None
        optim_conf = model.configure_optimizers()
        if optim_conf is None:
            rank_zero_warn(
                '`LightningModule.configure_optimizers` returned `None`, this fit will run with no optimizer',
                UserWarning,
            )
            optim_conf = _MockOptimizer()

        optimizers, lr_schedulers, optimizer_frequencies = [], [], []
        monitor = None

        # single output, single optimizer
        if isinstance(optim_conf, Optimizer):
            optimizers = [optim_conf]
        # two lists, optimizer + lr schedulers
        elif isinstance(optim_conf, (list, tuple)) and len(optim_conf) == 2 and isinstance(optim_conf[0], list):
            opt, sch = optim_conf
            optimizers = opt
            lr_schedulers = sch if isinstance(sch, list) else [sch]
        # single dictionary
        elif isinstance(optim_conf, dict):
            optimizers = [optim_conf["optimizer"]]
            monitor = optim_conf.get('monitor', None)
            lr_schedulers = [optim_conf["lr_scheduler"]] if "lr_scheduler" in optim_conf else []
        # multiple dictionaries
        elif isinstance(optim_conf, (list, tuple)) and all(isinstance(d, dict) for d in optim_conf):
            optimizers = [opt_dict["optimizer"] for opt_dict in optim_conf]
            lr_schedulers = [opt_dict["lr_scheduler"] for opt_dict in optim_conf if "lr_scheduler" in opt_dict]
            optimizer_frequencies = [
                opt_dict["frequency"] for opt_dict in optim_conf if opt_dict.get("frequency", None) is not None
            ]
            # assert that if frequencies are present, they are given for all optimizers
            if optimizer_frequencies and len(optimizer_frequencies) != len(optimizers):
                raise ValueError("A frequency must be given to each optimizer.")
        # single list or tuple, multiple optimizer
        elif isinstance(optim_conf, (list, tuple)):
            optimizers = list(optim_conf)
        # unknown configuration
        else:
            raise MisconfigurationException(
                'Unknown configuration for model optimizers.'
                ' Output from `model.configure_optimizers()` should either be:\n'
                ' * `torch.optim.Optimizer`\n'
                ' * [`torch.optim.Optimizer`]\n'
                ' * ([`torch.optim.Optimizer`], [`torch.optim.lr_scheduler`])\n'
                ' * {"optimizer": `torch.optim.Optimizer`, (optional) "lr_scheduler": `torch.optim.lr_scheduler`}\n'
                ' * A list of the previously described dict format, with an optional "frequency" key (int)'
            )

        lr_schedulers = self.configure_schedulers(lr_schedulers, monitor=monitor)
        _validate_scheduler_optimizer(optimizers, lr_schedulers)

        return optimizers, lr_schedulers, optimizer_frequencies
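
    # Illustrative sketch (not from the original file): return formats accepted
    # from `LightningModule.configure_optimizers` and the tuples this method
    # produces. `opt`, `sch`, and 'val_loss' are placeholder names.
    #
    #     return opt                                    -> ([opt], [], [])
    #     return [opt1, opt2]                           -> ([opt1, opt2], [], [])
    #     return [opt], [sch]                           -> ([opt], [sch_dict], [])
    #     return {'optimizer': opt,
    #             'lr_scheduler': sch,
    #             'monitor': 'val_loss'}                -> ([opt], [sch_dict], [])
    #     return [{'optimizer': opt1, 'frequency': 5},
    #             {'optimizer': opt2, 'frequency': 10}] -> ([opt1, opt2], [], [5, 10])
    #
    # where `sch_dict` is the scheduler merged with the defaults from
    # `_get_default_scheduler_config` by `configure_schedulers`.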

    def convert_to_lightning_optimizers(self):

        def _convert_to_lightning_optimizer(trainer, optimizer):
            if not isinstance(optimizer, LightningOptimizer):
                optimizer = LightningOptimizer(optimizer)
            optimizer._on_trainer_init(trainer)
            return optimizer

        self._lightning_optimizers = {
            opt_idx: _convert_to_lightning_optimizer(self, opt)
            for opt_idx, opt in enumerate(self.optimizers)
        }
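
    # Illustrative note (not from the original file): after this call,
    # `self._lightning_optimizers` maps each optimizer index to its wrapped
    # counterpart, e.g. {0: LightningOptimizer(opt0), 1: LightningOptimizer(opt1)}.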

    def configure_schedulers(self, schedulers: list, monitor: Optional[str] = None):
        # Convert each scheduler into dict structure with relevant information
        lr_schedulers = []
        default_config = _get_default_scheduler_config()
        for scheduler in schedulers:
            if isinstance(scheduler, dict):
                # check provided keys
                extra_keys = [k for k in scheduler.keys() if k not in default_config.keys()]
                if extra_keys:
                    rank_zero_warn(f'Found unsupported keys in the lr scheduler dict: {extra_keys}', RuntimeWarning)
                if 'scheduler' not in scheduler:
                    raise MisconfigurationException(
                        'The lr scheduler dict must have the key "scheduler" with its item being an lr scheduler'
                    )
                if 'interval' in scheduler and scheduler['interval'] not in ('step', 'epoch'):
                    raise MisconfigurationException(
                        f'The "interval" key in lr scheduler dict must be "step" or "epoch"'
                        f' but is "{scheduler["interval"]}"'
                    )
                scheduler['reduce_on_plateau'] = isinstance(
                    scheduler['scheduler'], optim.lr_scheduler.ReduceLROnPlateau
                )
                if scheduler['reduce_on_plateau'] and scheduler.get('monitor', None) is None:
                    raise MisconfigurationException(
                        'The lr scheduler dict must include a monitor when a `ReduceLROnPlateau` scheduler is used.'
                        ' For example: {"optimizer": optimizer, "lr_scheduler":'
                        ' {"scheduler": scheduler, "monitor": "your_loss"}}'
                    )
                lr_schedulers.append({**default_config, **scheduler})
            elif isinstance(scheduler, optim.lr_scheduler.ReduceLROnPlateau):
                if monitor is None:
                    raise MisconfigurationException(
                        '`configure_optimizers` must include a monitor when a `ReduceLROnPlateau` scheduler is used.'
                        ' For example:'
                        ' {"optimizer": optimizer, "lr_scheduler": scheduler, "monitor": "metric_to_track"}'
                    )
                lr_schedulers.append({
                    **default_config, 'scheduler': scheduler,
                    'reduce_on_plateau': True,
                    'monitor': monitor
                })
            elif isinstance(scheduler, optim.lr_scheduler._LRScheduler):
                lr_schedulers.append({**default_config, 'scheduler': scheduler})
            else:
                raise ValueError(f'The provided lr scheduler "{scheduler}" is invalid')
        return lr_schedulers
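
    # Illustrative sketch (not from the original file): a bare `_LRScheduler`
    # passed through this method comes back merged with the defaults from
    # `_get_default_scheduler_config`, roughly:
    #
    #     {'scheduler': sch, 'name': None, 'interval': 'epoch', 'frequency': 1,
    #      'reduce_on_plateau': False, 'monitor': None, 'strict': True}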


class _MockOptimizer(Optimizer):
    """The `_MockOptimizer` will be used in place of an optimizer in the event that `None`
    is returned from `configure_optimizers`.
    """

    def __init__(self):
        super().__init__([torch.zeros(1)], {})

    def add_param_group(self, param_group):
        pass  # Do Nothing

    def load_state_dict(self, state_dict):
        pass  # Do Nothing

    def state_dict(self):
        return {}  # Return Empty

    def step(self, closure=None):
        if closure is not None:
            closure()

    def zero_grad(self):
        pass  # Do Nothing

    def __repr__(self):
        return 'No Optimizer'
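
    # Illustrative sketch (not from the original file): the mock behaves as a
    # no-op drop-in, so the training loop can call it like a real optimizer:
    #
    #     opt = _MockOptimizer()
    #     opt.step(closure=lambda: None)  # runs the closure, updates nothing
    #     opt.zero_grad()                 # no-op
    #     repr(opt)                       # 'No Optimizer'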


def _validate_scheduler_optimizer(optimizers, lr_schedulers):
    if any(sch['scheduler'].optimizer not in optimizers for sch in lr_schedulers):
        raise MisconfigurationException(
            "Some schedulers are attached to an optimizer that wasn't returned from `configure_optimizers`."
        )
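
# Illustrative sketch (not from the original file): the check above fires when a
# scheduler wraps an optimizer that `configure_optimizers` did not return, e.g.
#
#     opt_a, opt_b = optim.SGD(params_a, lr=0.1), optim.SGD(params_b, lr=0.1)
#     sch = optim.lr_scheduler.StepLR(opt_b, step_size=1)
#     _validate_scheduler_optimizer([opt_a], [{'scheduler': sch}])  # raises MisconfigurationException
#
# (`params_a` / `params_b` stand in for model parameter iterables.)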


def _get_default_scheduler_config() -> Dict[str, Any]:
    return {
        'scheduler': None,
        'name': None,  # no custom name
        'interval': 'epoch',  # after epoch is over
        'frequency': 1,  # every epoch/batch
        'reduce_on_plateau': False,  # most often not ReduceLROnPlateau scheduler
        'monitor': None,  # value to monitor for ReduceLROnPlateau
        'strict': True,  # enforce that the monitor exists for ReduceLROnPlateau
    }