deprecate enable_pl_optimizer as it is not restored properly #5244

Merged: 115 commits, merged on Jan 8, 2021
Changes from all commits (115 commits)
fbebccb
update
tchaton Dec 17, 2020
f84085c
clean test
tchaton Dec 17, 2020
a309878
still in progress
tchaton Dec 17, 2020
ae08761
udpdate test
tchaton Dec 17, 2020
3ef910f
Merge branch 'master' into bugfix/5165_enable_pl_optimizer
tchaton Dec 17, 2020
f5a5d1e
update
tchaton Dec 17, 2020
7edec88
Merge branch 'bugfix/5165_enable_pl_optimizer' of https://github.com/…
tchaton Dec 17, 2020
b4181ea
update
tchaton Dec 17, 2020
be48064
resolve flake
tchaton Dec 17, 2020
379d2be
add test for zero_grad
tchaton Dec 17, 2020
fd51f32
update
tchaton Dec 17, 2020
05a838e
works without accumulated_grad
tchaton Dec 17, 2020
82c2602
update
tchaton Dec 17, 2020
386f6d4
update
tchaton Dec 17, 2020
b007c9d
resolve amp
tchaton Dec 17, 2020
5007e68
Merge branch 'master' into bugfix/5165_enable_pl_optimizer
tchaton Dec 17, 2020
88c5c63
revert back to True
tchaton Dec 17, 2020
3accce3
Merge branch 'bugfix/5165_enable_pl_optimizer' of https://github.com/…
tchaton Dec 17, 2020
7fc56ee
update
tchaton Dec 18, 2020
8d13893
clean tests
tchaton Dec 18, 2020
e7abee6
cleaned out
tchaton Dec 18, 2020
14475e7
typo
tchaton Dec 18, 2020
b47db5e
update test
tchaton Dec 18, 2020
6a79921
git repare bug
tchaton Dec 18, 2020
c106828
remove print
tchaton Dec 18, 2020
85e4e96
udpate
tchaton Dec 18, 2020
40f7c54
Fix formatting/optimizer imports
Dec 18, 2020
e6f9945
Refactor the test for cleanliness
Dec 18, 2020
9d4fd68
Add vanilla model to the test, better var names
Dec 18, 2020
f71ce5d
Fixed var names, let's clean up these mock tests
Dec 18, 2020
5c98b0f
repare test
Dec 19, 2020
cfd63ea
update test
Dec 19, 2020
ca6c184
resolve flake8
tchaton Dec 19, 2020
a8c0c20
Merge branch 'master' into bugfix/5165_enable_pl_optimizer_refactor
tchaton Dec 19, 2020
6b5af8b
add manual_optimization
tchaton Dec 19, 2020
feaa861
Merge branch 'bugfix/5165_enable_pl_optimizer_refactor' of https://gi…
tchaton Dec 19, 2020
c1e9d14
update tests
Dec 19, 2020
1352a49
resolve flake8
tchaton Dec 19, 2020
c0afb3b
add random accumulate_grad_batches
Dec 19, 2020
2d8b9bb
Merge branch 'bugfix/5165_enable_pl_optimizer_refactor' of https://gi…
tchaton Dec 19, 2020
9a43d8e
improve test
tchaton Dec 19, 2020
12b3554
Update tests/trainer/optimization/test_parity_automatic_optimization.py
tchaton Dec 19, 2020
a126e56
Update tests/trainer/optimization/test_parity_automatic_optimization.py
tchaton Dec 19, 2020
9d083d5
update
tchaton Dec 19, 2020
7126b2d
clean tests
tchaton Dec 19, 2020
b6c7ad0
correct bug
Dec 19, 2020
f5ec5f5
Apply suggestions from code review
Borda Dec 19, 2020
a9c1f7e
format
Borda Dec 19, 2020
151790d
Merge branch 'master' into bugfix/5165_enable_pl_optimizer_refactor
tchaton Dec 20, 2020
b33ee49
adress comments
tchaton Dec 20, 2020
196d8b4
Merge branch 'master' into bugfix/5165_enable_pl_optimizer_refactor
tchaton Dec 20, 2020
1677b6c
Merge branch 'bugfix/5165_enable_pl_optimizer_refactor' of https://gi…
tchaton Dec 20, 2020
02ded96
update on comments
tchaton Dec 21, 2020
94d3b4b
Merge branch 'master' into bugfix/5165_enable_pl_optimizer_refactor
tchaton Dec 21, 2020
1e8a11e
Merge branch 'master' into bugfix/5165_enable_pl_optimizer_refactor
tchaton Dec 21, 2020
6e68e31
Merge branch 'master' into bugfix/5165_enable_pl_optimizer_refactor
tchaton Dec 21, 2020
47d047c
Merge branch 'master' into bugfix/5165_enable_pl_optimizer_refactor
tchaton Dec 22, 2020
05678f5
Merge branch 'master' into bugfix/5165_enable_pl_optimizer_refactor
tchaton Dec 23, 2020
c764bee
wip
tchaton Dec 23, 2020
55536ab
typo
tchaton Dec 23, 2020
1dfe521
depreceate enable_pl_optimizer
tchaton Dec 23, 2020
134cf0e
resolve latest bugs
Dec 23, 2020
09ea317
update
tchaton Dec 23, 2020
a42cb3a
Merge branch 'master' into bugfix/5224_not_restored_properly
tchaton Dec 23, 2020
0f81944
resolve merge
tchaton Dec 23, 2020
a61297b
add comment
tchaton Dec 23, 2020
dcc4897
Update pytorch_lightning/core/lightning.py
tchaton Dec 28, 2020
e62e4fb
Update tests/deprecated_api/test_remove_1-3.py
tchaton Dec 28, 2020
de5b0cd
Update pytorch_lightning/trainer/connectors/optimizer_connector.py
tchaton Dec 28, 2020
0899e56
Update pytorch_lightning/trainer/trainer.py
tchaton Dec 28, 2020
704b47a
Update pytorch_lightning/trainer/trainer.py
tchaton Dec 28, 2020
bd3a5f4
Update tests/trainer/optimization/test_parity_automatic_optimization.py
tchaton Dec 28, 2020
36a5cfb
update on comments
tchaton Dec 28, 2020
624b560
Merge branch 'bugfix/5224_not_restored_properly' of https://github.co…
tchaton Dec 28, 2020
4f18365
Merge branch 'master' into bugfix/5224_not_restored_properly
tchaton Dec 28, 2020
bf6e529
update restore
tchaton Dec 28, 2020
b3ac147
add a property
tchaton Dec 28, 2020
65da7b7
remove setstate as not needed anymore
tchaton Dec 28, 2020
a7c3d44
update test
tchaton Dec 28, 2020
c9e4ffc
provide optimizer to on_before_zero_grad
tchaton Dec 28, 2020
4c10879
Merge branch 'master' into bugfix/5224_not_restored_properly
tchaton Dec 28, 2020
9ef42bb
Merge branch 'master' into bugfix/5224_not_restored_properly
tchaton Dec 29, 2020
eecde3d
Merge branch 'master' into bugfix/5224_not_restored_properly
tchaton Jan 4, 2021
7f356d4
update on comments
tchaton Jan 4, 2021
b971445
Merge branch 'bugfix/5224_not_restored_properly' of https://github.co…
tchaton Jan 4, 2021
5a97d80
Merge branch 'master' into bugfix/5224_not_restored_properly
tchaton Jan 4, 2021
c5fdbea
update on comments
tchaton Jan 4, 2021
07cb66e
Merge branch 'bugfix/5224_not_restored_properly' of https://github.co…
tchaton Jan 4, 2021
1f4f255
Update pytorch_lightning/trainer/trainer.py
tchaton Jan 4, 2021
302b005
Update tests/trainer/optimization/test_parity_automatic_optimization.py
tchaton Jan 4, 2021
5922dc6
Update tests/trainer/optimization/test_parity_automatic_optimization.py
tchaton Jan 4, 2021
7835edc
Update tests/trainer/optimization/test_parity_automatic_optimization.py
tchaton Jan 4, 2021
1201c31
Merge branch 'master' into bugfix/5224_not_restored_properly
tchaton Jan 4, 2021
3f75ae0
Merge branch 'master' into bugfix/5224_not_restored_properly
tchaton Jan 5, 2021
13d4119
Merge branch 'master' into bugfix/5224_not_restored_properly
tchaton Jan 5, 2021
d45b501
Merge branch 'master' into bugfix/5224_not_restored_properly
tchaton Jan 5, 2021
70aef31
Merge branch 'master' into bugfix/5224_not_restored_properly
tchaton Jan 5, 2021
5eb6b1f
mofidy import
tchaton Jan 5, 2021
ab660fe
Merge branch 'master' into bugfix/5224_not_restored_properly
SeanNaren Jan 5, 2021
8342314
Merge branch 'master' into bugfix/5224_not_restored_properly
tchaton Jan 5, 2021
9102217
Merge branch 'master' into bugfix/5224_not_restored_properly
tchaton Jan 6, 2021
3d38b60
Merge branch 'master' into bugfix/5224_not_restored_properly
tchaton Jan 6, 2021
637e423
update changelog
tchaton Jan 6, 2021
1122cec
Merge branch 'master' into bugfix/5224_not_restored_properly
tchaton Jan 6, 2021
6a8a65b
resolve flake8
tchaton Jan 7, 2021
538019f
Merge branch 'bugfix/5224_not_restored_properly' of https://github.co…
tchaton Jan 7, 2021
36e09d5
Merge branch 'master' into bugfix/5224_not_restored_properly
tchaton Jan 7, 2021
ca41a18
update
tchaton Jan 8, 2021
db5e102
update
tchaton Jan 8, 2021
f6b272c
Merge branch 'master' into bugfix/5224_not_restored_properly
tchaton Jan 8, 2021
0542c5f
clean doc
tchaton Jan 8, 2021
57a06c1
Merge branch 'bugfix/5224_not_restored_properly' of https://github.co…
tchaton Jan 8, 2021
3374578
Merge branch 'master' into bugfix/5224_not_restored_properly
tchaton Jan 8, 2021
2dcfd2a
update
tchaton Jan 8, 2021
b1e847b
Merge branch 'bugfix/5224_not_restored_properly' of https://github.co…
tchaton Jan 8, 2021
2 changes: 2 additions & 0 deletions CHANGELOG.md
@@ -12,6 +12,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Changed

- Changed deprecated `enable_pl_optimizer=True` ([#5244](https://github.com/PyTorchLightning/pytorch-lightning/pull/5244))


### Deprecated

3 changes: 2 additions & 1 deletion README.md
@@ -225,7 +225,8 @@ with tempfile.NamedTemporaryFile(suffix='.onnx', delete=False) as tmpfile:
```python
class LitAutoEncoder(pl.LightningModule):
def training_step(self, batch, batch_idx, opt_idx):
(opt_a, opt_b) = self.optimizers()
# access your optimizers with use_pl_optimizer=False. Default is True
(opt_a, opt_b) = self.optimizers(use_pl_optimizer=True)

Contributor: why do we need it if it's the default? I think it's more confusing than helpful.

Contributor (author): So people are aware they can opt out. Do you think it should be removed?

loss_a = ...
self.manual_backward(loss_a, opt_a)
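To make the discussion above concrete, here is a minimal sketch of the two access modes in a manual-optimization module. The module, layer size and loss are illustrative assumptions, not code from this PR; only `optimizers(use_pl_optimizer=...)`, the `optimizer` property and `manual_backward(loss, opt)` come from the changes shown here, and the `automatic_optimization` property override reflects the manual-optimization docs of this era.

```python
import torch
from torch import nn
import pytorch_lightning as pl


class IllustrativeModule(pl.LightningModule):
    """Hypothetical module, used only to show both values of `use_pl_optimizer`."""

    def __init__(self):
        super().__init__()
        self.layer = nn.Linear(32, 2)

    @property
    def automatic_optimization(self) -> bool:
        # opt out of automatic optimization, as in the README snippet above
        return False

    def training_step(self, batch, batch_idx):
        # Default: a LightningOptimizer whose `.step()` also handles AMP, TPU
        # and accumulate_grad_batches for you.
        opt = self.optimizers(use_pl_optimizer=True)

        # Opting out returns the raw torch.optim.Adam built in configure_optimizers
        # (also reachable through the new `optimizer` property on the wrapper).
        raw_opt = self.optimizers(use_pl_optimizer=False)

        loss = self.layer(batch).sum()
        self.manual_backward(loss, opt)
        opt.step()
        raw_opt.zero_grad()

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=1e-3)
```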
3 changes: 2 additions & 1 deletion benchmarks/test_sharded_parity.py
@@ -186,7 +186,8 @@ def train_dataloader(self):
class SeedTrainLoaderManualModel(SeedTrainLoaderModel):
def training_step(self, batch, batch_idx, optimizer_idx):
# manual
(opt_a, opt_b) = self.optimizers()
# access your optimizers with use_pl_optimizer=False. Default is True
(opt_a, opt_b) = self.optimizers(use_pl_optimizer=True)
loss_1 = self.step(batch)

self.manual_backward(loss_1, opt_a)
3 changes: 2 additions & 1 deletion docs/source/new-project.rst
@@ -268,7 +268,8 @@ Now you own the train loop!
.. code-block:: python
def training_step(self, batch, batch_idx, opt_idx):
(opt_a, opt_b, opt_c) = self.optimizers()
# access your optimizers with use_pl_optimizer=False. Default is True
(opt_a, opt_b, opt_c) = self.optimizers(use_pl_optimizer=True)
loss_a = self.generator(batch[0])
31 changes: 24 additions & 7 deletions docs/source/optimizers.rst
@@ -28,8 +28,15 @@ to manually manage the optimization process. To do so, do the following:
.. code-block:: python
def training_step(self, batch, batch_idx, optimizer_idx):
# ignore optimizer_idx
(opt_g, opt_d) = self.optimizers()
# 1. ignore optimizer_idx
# 2. `use_pl_optimizer=True` means `opt_g` and `opt_d` will be of type `LightningOptimizer`
# `LightningOptimizer` simply wraps your optimizer and behaves the same way!
# When calling `optimizer.step`, `LightningOptimizer` handles TPU, AMP, accumulate_grad_batches, etc. for you.
# access your optimizers with `use_pl_optimizer=False`, or via `optimizer.optimizer` when `use_pl_optimizer=True`
# `use_pl_optimizer=True` is the default
(opt_g, opt_d) = self.optimizers(use_pl_optimizer=True)
# do anything you want
loss_a = ...
@@ -242,19 +249,29 @@ Here we add a learning-rate warm up
# update params
optimizer.step(closure=closure)

The default ``optimizer_step`` is relying on the internal ``LightningOptimizer`` to properly perform a step.
.. note:: The default ``optimizer_step`` relies on the internal ``LightningOptimizer`` to properly perform a step. It handles TPUs, AMP, accumulate_grad_batches, zero_grad, and much more.

.. testcode::

from pytorch_lightning.core.optimizer import LightningOptimizer
# function hook in LightningModule
def optimizer_step(self, current_epoch, batch_nb, optimizer, optimizer_idx, closure, on_tpu=False, using_native_amp=False, using_lbfgs=False):
optimizer.step(closure=closure)

.. note:: To access your wrapped Optimizer from ``LightningOptimizer``, do as follows.

.. testcode::

# function hook in LightningModule
def optimizer_step(self, current_epoch, batch_nb, optimizer, optimizer_idx, closure, on_tpu=False, using_native_amp=False, using_lbfgs=False):
if not isinstance(optimizer, LightningOptimizer):
# wrap into LightningOptimizer only for running the step
optimizer = LightningOptimizer.to_lightning_optimizer(optimizer, self.trainer)

# `optimizer` is now a ``LightningOptimizer`` wrapping your optimizer.
# To access the wrapped optimizer, do as follows:
optimizer = optimizer.optimizer

# run the step on the raw optimizer; note that this bypasses TPU, AMP, etc.
optimizer.step(closure=closure)


----------

Using the closure functions for optimization
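For readers wondering what "simply wraps your optimizer and behaves the same way" means in practice, below is a deliberately simplified, stand-alone sketch of the wrapping idea. It is not the real `LightningOptimizer` (whose `step()` routes through the trainer to handle TPU, AMP and gradient accumulation); the class name and delegation strategy are our own illustration.

```python
from typing import Callable, Optional

from torch.optim.optimizer import Optimizer


class ToyWrappedOptimizer:
    """Illustrative stand-in for LightningOptimizer; simplified, not the real class."""

    def __init__(self, optimizer: Optimizer):
        self._optimizer = optimizer

    @property
    def optimizer(self) -> Optimizer:
        # mirrors the `optimizer` property this PR adds to LightningOptimizer
        return self._optimizer

    def step(self, closure: Optional[Callable] = None) -> None:
        # The real wrapper asks the trainer to run precision scaling, TPU stepping
        # and accumulate_grad_batches bookkeeping around this call.
        if closure is None:
            self._optimizer.step()
        else:
            self._optimizer.step(closure=closure)

    def __getattr__(self, name):
        # Everything else (zero_grad, state_dict, param_groups, ...) is looked up
        # on the wrapped optimizer, so the wrapper "behaves the same way".
        return getattr(self._optimizer, name)
```

The point of the indirection is that user code keeps calling `step()` as usual while the trainer stays in control of when, and how, the underlying step actually happens.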
6 changes: 4 additions & 2 deletions docs/source/trainer.rst
@@ -335,7 +335,8 @@ optimizer behavior
Example::

def training_step(self, batch, batch_idx):
opt = self.optimizers()
# access your optimizers with use_pl_optimizer=False. Default is True
opt = self.optimizers(use_pl_optimizer=True)

loss = ...
self.manual_backward(loss, opt)
@@ -350,7 +351,8 @@ In the multi-optimizer case, ignore the optimizer_idx flag and use the optimizer
Example::

def training_step(self, batch, batch_idx, optimizer_idx):
(opt_a, opt_b) = self.optimizers()
# access your optimizers with use_pl_optimizer=False. Default is True
(opt_a, opt_b) = self.optimizers(use_pl_optimizer=True)

gen_loss = ...
self.manual_backward(gen_loss, opt_a)
4 changes: 1 addition & 3 deletions pytorch_lightning/accelerators/cpu_accelerator.py
@@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Optional, Union, Callable
from typing import Any, Callable, Optional, Union

import torch

@@ -48,8 +48,6 @@ def setup(self, model):
# allow for lr schedulers as well
self.setup_optimizers(model)

self.trainer.convert_to_lightning_optimizers()

self.trainer.model = model

def train(self):
2 changes: 0 additions & 2 deletions pytorch_lightning/accelerators/ddp2_accelerator.py
@@ -192,8 +192,6 @@ def ddp_train(self, process_idx, mp_queue, model):
# 16-bit
model = self.trainer.precision_connector.connect(model)

self.trainer.convert_to_lightning_optimizers()

# device ids change depending on the DDP setup
device_ids = self.get_device_ids()

4 changes: 1 addition & 3 deletions pytorch_lightning/accelerators/ddp_accelerator.py
@@ -12,9 +12,9 @@
# See the License for the specific language governing permissions and
# limitations under the License
import os
from os.path import abspath
import subprocess
import sys
from os.path import abspath
from time import sleep
from typing import Any, List, Optional, Union

@@ -291,8 +291,6 @@ def ddp_train(self, process_idx, model):
# 16-bit
model = self.trainer.precision_connector.connect(model)

self.trainer.convert_to_lightning_optimizers()

# device ids change depending on the DDP setup
device_ids = self.get_device_ids()

2 changes: 0 additions & 2 deletions pytorch_lightning/accelerators/ddp_cpu_spawn_accelerator.py
@@ -152,8 +152,6 @@ def ddp_train(self, process_idx, mp_queue, model):
# 16-bit
model = self.trainer.precision_connector.connect(model)

self.trainer.convert_to_lightning_optimizers()

# DDP spawn already spawned off each process... no need to do anything
device_ids = self.get_device_ids()

4 changes: 1 addition & 3 deletions pytorch_lightning/accelerators/ddp_hpc_accelerator.py
@@ -15,8 +15,8 @@
from typing import Any, List, Optional, Union

import torch
import torch.distributed as torch_distrib
import torch.distributed as dist
import torch.distributed as torch_distrib
from torch.nn.parallel import DistributedDataParallel

from pytorch_lightning import _logger as log
@@ -183,8 +183,6 @@ def ddp_train(self, process_idx, model):
# 16-bit
model = self.trainer.precision_connector.connect(model)

self.trainer.convert_to_lightning_optimizers()

# device ids change depending on the DDP setup
device_ids = self.get_device_ids()

2 changes: 0 additions & 2 deletions pytorch_lightning/accelerators/ddp_spawn_accelerator.py
@@ -167,8 +167,6 @@ def ddp_train(self, process_idx, mp_queue, model, is_master=False, proc_offset=0
# 16-bit
model = self.trainer.precision_connector.connect(model)

self.trainer.convert_to_lightning_optimizers()

Member: not related to this PR, but since this call appears in every accelerator, it should live in the base class rather than in all the children...

Contributor (author): I couldn't, because `self.trainer.precision_connector` isn't called at the same place in every accelerator. It would have required a clean-up of the accelerators.

Contributor: we have some cool ideas for accelerator refactoring in #4510, more on that soon :)

# device ids change depending on the DDP setup
device_ids = self.get_device_ids()

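The review thread above suggests hoisting the duplicated call into the accelerator base class. A hypothetical sketch of that shape is below; the class names and hook names are ours, not Lightning's, and the actual accelerator refactor is the one referenced in #4510.

```python
class BaseAccelerator:
    """Hypothetical sketch of the 'put it in the base, not in all children' suggestion."""

    def __init__(self, trainer):
        self.trainer = trainer

    def setup(self, model):
        # subclass-specific wiring (DDP/TPU wiring, precision connector, ...)
        self._setup_impl(model)
        # shared post-setup work lives here exactly once, e.g. the old
        # `trainer.convert_to_lightning_optimizers()` call this PR removes from each subclass
        self._finalize_optimizers()
        self.trainer.model = model

    def _setup_impl(self, model):
        raise NotImplementedError

    def _finalize_optimizers(self):
        pass


class ToyGPUAccelerator(BaseAccelerator):
    def _setup_impl(self, model):
        # device placement and 16-bit connection would go here
        pass
```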
2 changes: 0 additions & 2 deletions pytorch_lightning/accelerators/dp_accelerator.py
@@ -65,8 +65,6 @@ def setup(self, model):
if self.trainer.amp_backend:
model = self.__init_half_precision(model)

self.trainer.convert_to_lightning_optimizers()

self.trainer.model = model

def __init_torch_data_parallel(self, model):
2 changes: 0 additions & 2 deletions pytorch_lightning/accelerators/gpu_accelerator.py
@@ -54,8 +54,6 @@ def setup(self, model):
# 16-bit
model = self.trainer.precision_connector.connect(model)

self.trainer.convert_to_lightning_optimizers()

self.trainer.model = model

def train(self):
6 changes: 2 additions & 4 deletions pytorch_lightning/accelerators/horovod_accelerator.py
@@ -12,15 +12,15 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from contextlib import ExitStack
from typing import Any, Optional, Union, Callable
from typing import Any, Callable, Optional, Union

import torch
from torch.optim.lr_scheduler import _LRScheduler

from pytorch_lightning import _logger as log
from pytorch_lightning.accelerators.accelerator import Accelerator, ReduceOp
from pytorch_lightning.cluster_environments import ClusterEnvironment
from pytorch_lightning.utilities import HOROVOD_AVAILABLE, AMPType
from pytorch_lightning.utilities import AMPType, HOROVOD_AVAILABLE
from pytorch_lightning.utilities.distributed import rank_zero_only

if HOROVOD_AVAILABLE:
@@ -91,8 +91,6 @@ def _filter_named_parameters(model, optimizer):
# 16-bit
model = self.trainer.precision_connector.connect(model)

self.trainer.convert_to_lightning_optimizers()

# Update logger rank info from Horovod to avoid race conditions from different ranks
# creating directories / writing files in the same locations.
self.trainer.global_rank = hvd.rank()
4 changes: 1 addition & 3 deletions pytorch_lightning/accelerators/tpu_accelerator.py
@@ -26,11 +26,11 @@
from pytorch_lightning.core import LightningModule
from pytorch_lightning.core.optimizer import LightningOptimizer
from pytorch_lightning.utilities import (
TPU_AVAILABLE,
move_data_to_device,
rank_zero_info,
rank_zero_only,
rank_zero_warn,
TPU_AVAILABLE,
)
from pytorch_lightning.utilities.cloud_io import atomic_save
from pytorch_lightning.utilities.exceptions import MisconfigurationException
@@ -230,8 +230,6 @@ def __setup_tpu_training(self, model: LightningModule, trainer):
f' global rank: {trainer.tpu_global_core_rank}'
f' with XLA_USE_BF16={os.environ.get("XLA_USE_BF16")}')

self.trainer.convert_to_lightning_optimizers()

def backward(self, closure_loss, optimizer, opt_idx, *args, **kwargs):
# do backward pass
if self.trainer.train_loop.automatic_optimization:
7 changes: 5 additions & 2 deletions pytorch_lightning/core/lightning.py
@@ -112,8 +112,11 @@ def __init__(self, *args, **kwargs):
self._current_hook_fx_name = None
self._current_dataloader_idx = None

def optimizers(self):
opts = self.trainer.optimizers
def optimizers(self, use_pl_optimizer: bool = True) -> Union[Optimizer, List[Optimizer], List[LightningOptimizer]]:
if use_pl_optimizer:
opts = list(self.trainer.lightning_optimizers.values())
else:
opts = self.trainer.optimizers

# single optimizer
if isinstance(opts, list) and len(opts) == 1 and isinstance(opts[0], Optimizer):
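The hunk above is cut off by the diff view right after the single-optimizer check. A plausible reading of the whole accessor, reconstructed from the visible lines, is sketched below; the final `return` and the stub class are our assumptions, not the verbatim source.

```python
from typing import List, Union

from torch.optim.optimizer import Optimizer

from pytorch_lightning.core.optimizer import LightningOptimizer


class _AccessorSketch:
    """Sketch only: mirrors the `LightningModule.optimizers` hunk shown above."""

    trainer = None  # the real method lives on LightningModule, which always has a trainer

    def optimizers(self, use_pl_optimizer: bool = True) -> Union[Optimizer, List[Optimizer], List[LightningOptimizer]]:
        if use_pl_optimizer:
            # LightningOptimizer wrappers kept by the trainer, keyed by optimizer index
            opts = list(self.trainer.lightning_optimizers.values())
        else:
            # the raw optimizers returned from configure_optimizers
            opts = self.trainer.optimizers

        # single optimizer: hand it back directly instead of a one-element list
        if isinstance(opts, list) and len(opts) == 1 and isinstance(opts[0], Optimizer):
            return opts[0]

        # multiple optimizers (assumed tail; the diff is truncated here)
        return opts
```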
20 changes: 13 additions & 7 deletions pytorch_lightning/core/optimizer.py
@@ -17,7 +17,7 @@

from torch.optim.optimizer import Optimizer

from pytorch_lightning.utilities import TPU_AVAILABLE
from pytorch_lightning.utilities import AMPType, TPU_AVAILABLE
from pytorch_lightning.utilities.exceptions import MisconfigurationException

if TPU_AVAILABLE:
@@ -62,6 +62,10 @@ def __init__(self,
self._accumulate_grad_batches = accumulate_grad_batches
self._optimizer_idx = None

@property
def optimizer(self):
return self._optimizer

@property
def defaults(self):
return self._optimizer.defaults
@@ -102,11 +106,13 @@ def _on_trainer_init(self, trainer):
break

@classmethod
def to_lightning_optimizer(cls, optimizer, trainer):
if isinstance(optimizer, LightningOptimizer):
return optimizer
optimizer = cls(optimizer)
optimizer._on_trainer_init(trainer)
def _to_lightning_optimizer(cls, optimizer, trainer, opt_idx):

Member: this is a breaking API change; we need to add meta methods with a deprecation warning...

  • name change
  • added required arg

Contributor (author): `_to_lightning_optimizer` is only for internal use. It shouldn't break any API.

Member: it was a public API, so you never know who was using it... we should always be careful about what we make public vs. protected...

Contributor: either way is fine IMO. LightningOptimizer is an experimental feature anyway, and the user shouldn't even know about it at this point :)

# apex overrides the .step function and needs to be wrapped on each step
if trainer.amp_backend == AMPType.APEX:
optimizer = cls(optimizer)
optimizer._on_trainer_init(trainer)
else:
optimizer = trainer.lightning_optimizers[opt_idx]
return optimizer

def _accumulated_batches_reached(self):
@@ -148,7 +154,7 @@ def __optimizer_step(self, *args, closure: Optional[Callable] = None, profiler_n
**kwargs
)

trainer.train_loop.on_before_zero_grad(self)
trainer.train_loop.on_before_zero_grad(optimizer)

model.optimizer_zero_grad(
trainer.current_epoch,
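Borda's comment above asks for a deprecation path when a public classmethod is renamed and gains a required argument. A sketch of the kind of shim he means is below; it is hypothetical, not part of this PR, and the fallback to `opt_idx=0` is our assumption.

```python
import warnings

from pytorch_lightning.core.optimizer import LightningOptimizer


def to_lightning_optimizer(optimizer, trainer):
    """Hypothetical backward-compatibility wrapper for the old public name."""
    warnings.warn(
        "`to_lightning_optimizer(optimizer, trainer)` is deprecated; the conversion is now"
        " handled internally by `LightningOptimizer._to_lightning_optimizer(optimizer, trainer, opt_idx)`.",
        DeprecationWarning,
    )
    # The old signature had no optimizer index, so default to the first optimizer.
    return LightningOptimizer._to_lightning_optimizer(optimizer, trainer, opt_idx=0)
```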
5 changes: 2 additions & 3 deletions pytorch_lightning/plugins/ddp_sequential_plugin.py
@@ -15,8 +15,8 @@
from typing import Any, List, Optional

import torch
import torch.distributed as torch_distrib
from torch import nn
import torch.distributed as torch_distrib
from torch.nn.parallel import DistributedDataParallel

from pytorch_lightning import _logger as log
@@ -27,8 +27,8 @@
from pytorch_lightning.utilities.exceptions import MisconfigurationException

if FAIRSCALE_PIPE_AVAILABLE:
import fairscale.nn.model_parallel as mpu
from fairscale.nn import PipeRPCWrapper
import fairscale.nn.model_parallel as mpu
from fairscale.nn.pipe import balance as pipe_balance
from fairscale.nn.pipe import rpc as rpc_pipe
from fairscale.nn.pipe.pipeline import PipelineStyle
@@ -380,7 +380,6 @@ def register_optimizers(ctx, model):
model.trainer.optimizers = optimizers
model.trainer.lr_schedulers = lr_schedulers
model.trainer.optimizer_frequencies = optimizer_frequencies
model.trainer.convert_to_lightning_optimizers()


def run_optimizer(ctx, model):
2 changes: 1 addition & 1 deletion pytorch_lightning/plugins/native_amp.py
@@ -54,7 +54,7 @@ def backward(self, closure_loss, optimizer, opt_idx, *args, **kwargs):
# unscale gradient to allow analyze within `on_after_backward`
if not self.trainer.train_loop.should_accumulate() and automatic_optimization:
if isinstance(optimizer, LightningOptimizer):
self.trainer.scaler.unscale_(optimizer._optimizer)
self.trainer.scaler.unscale_(optimizer.optimizer)
else:
self.trainer.scaler.unscale_(optimizer)

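Both this hunk and the sharded plugin below follow the same unwrap-before-use idiom; a tiny helper makes the pattern explicit. The helper name is ours for illustration, not a Lightning API.

```python
from torch.optim.optimizer import Optimizer

from pytorch_lightning.core.optimizer import LightningOptimizer


def unwrap_optimizer(optimizer) -> Optimizer:
    # Code that needs the raw torch.optim object (GradScaler.unscale_, fairscale's OSS
    # re-wrapping, ...) should work on `optimizer.optimizer` rather than on the wrapper.
    if isinstance(optimizer, LightningOptimizer):
        return optimizer.optimizer
    return optimizer
```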
3 changes: 1 addition & 2 deletions pytorch_lightning/plugins/sharded_plugin.py
@@ -63,7 +63,7 @@ def _reinit_with_fairscale_oss(self, trainer):
optimizers = trainer.optimizers
for x, optimizer in enumerate(optimizers):
if is_lightning_optimizer(optimizer):
optimizer = optimizer._optimizer
optimizer = optimizer.optimizer
if not isinstance(optimizer, OSS):
optim_class = type(optimizer)
zero_optimizer = OSS(
@@ -73,7 +73,6 @@ def _reinit_with_fairscale_oss(self, trainer):
)
optimizers[x] = zero_optimizer
del optimizer
trainer.convert_to_lightning_optimizers()

def get_model_from_plugin(
self,