Skip to content

Commit

Permalink
deprecate enable_pl_optimizer as it is not restored properly (#5244)
Browse files Browse the repository at this point in the history
* update

* clean test

* still in progress

* udpdate test

* update

* update

* resolve flake

* add test for zero_grad

* update

* works without accumulated_grad

* update

* update

* resolve amp

* revert back to True

* update

* clean tests

* cleaned out

* typo

* update test

* git repare bug

* remove print

* udpate

* Fix formatting/optimizer imports

* Refactor the test for cleanliness

* Add vanilla model to the test, better var names

* Fixed var names, let's clean up these mock tests

* repare test

* update test

* resolve flake8

* add manual_optimization

* update tests

* resolve flake8

* add random accumulate_grad_batches

* improve test

* Update tests/trainer/optimization/test_parity_automatic_optimization.py

Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>

* Update tests/trainer/optimization/test_parity_automatic_optimization.py

Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>

* update

* clean tests

* correct bug

* Apply suggestions from code review

* format

* adress comments

* update on comments

* wip

* typo

* depreceate enable_pl_optimizer

* resolve latest bugs

* update

* resolve merge

* add comment

* Update pytorch_lightning/core/lightning.py

Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>

* Update tests/deprecated_api/test_remove_1-3.py

Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>

* Update pytorch_lightning/trainer/connectors/optimizer_connector.py

Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>

* Update pytorch_lightning/trainer/trainer.py

Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>

* Update pytorch_lightning/trainer/trainer.py

Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>

* Update tests/trainer/optimization/test_parity_automatic_optimization.py

Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>

* update on comments

* update restore

* add a property

* remove setstate as not needed anymore

* update test

* provide optimizer to on_before_zero_grad

* update on comments

* update on comments

* Update pytorch_lightning/trainer/trainer.py

Co-authored-by: Adrian Wälchli <aedu.waelchli@gmail.com>

* Update tests/trainer/optimization/test_parity_automatic_optimization.py

Co-authored-by: Adrian Wälchli <aedu.waelchli@gmail.com>

* Update tests/trainer/optimization/test_parity_automatic_optimization.py

Co-authored-by: Adrian Wälchli <aedu.waelchli@gmail.com>

* Update tests/trainer/optimization/test_parity_automatic_optimization.py

Co-authored-by: Adrian Wälchli <aedu.waelchli@gmail.com>

* mofidy import

* update changelog

* resolve flake8

* update

* update

* clean doc

Co-authored-by: SeanNaren <sean@grid.ai>
Co-authored-by: Ubuntu <ubuntu@ip-172-31-62-109.ec2.internal>
Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>
Co-authored-by: Jirka Borovec <jirka.borovec@seznam.cz>
Co-authored-by: Adrian Wälchli <aedu.waelchli@gmail.com>
Co-authored-by: Sean Naren <sean.narenthiran@gmail.com>

(cherry picked from commit f2e99d6)
  • Loading branch information
tchaton authored and SeanNaren committed Jan 13, 2021
1 parent 822cceb commit f28f0aa
Show file tree
Hide file tree
Showing 41 changed files with 156 additions and 213 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Changed

- Changed depreceated `enable_pl_optimizer=True` ([#5244](https://github.com/PyTorchLightning/pytorch-lightning/pull/5244))


### Deprecated

Expand Down
3 changes: 2 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -224,7 +224,8 @@ with tempfile.NamedTemporaryFile(suffix='.onnx', delete=False) as tmpfile:
```python
class LitAutoEncoder(pl.LightningModule):
def training_step(self, batch, batch_idx, opt_idx):
(opt_a, opt_b) = self.optimizers()
# access your optimizers with use_pl_optimizer=False. Default is True
(opt_a, opt_b) = self.optimizers(use_pl_optimizer=True)

loss_a = ...
self.manual_backward(loss_a, opt_a)
Expand Down
3 changes: 2 additions & 1 deletion benchmarks/test_sharded_parity.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,7 +175,8 @@ def train_dataloader(self):
class SeedTrainLoaderManualModel(SeedTrainLoaderModel):
def training_step(self, batch, batch_idx, optimizer_idx):
# manual
(opt_a, opt_b) = self.optimizers()
# access your optimizers with use_pl_optimizer=False. Default is True
(opt_a, opt_b) = self.optimizers(use_pl_optimizer=True)
loss_1 = self.step(batch)

self.manual_backward(loss_1, opt_a)
Expand Down
3 changes: 2 additions & 1 deletion docs/source/new-project.rst
Original file line number Diff line number Diff line change
Expand Up @@ -268,7 +268,8 @@ Now you own the train loop!
.. code-block:: python
def training_step(self, batch, batch_idx, opt_idx):
(opt_a, opt_b, opt_c) = self.optimizers()
# access your optimizers with use_pl_optimizer=False. Default is True
(opt_a, opt_b, opt_c) = self.optimizers(use_pl_optimizer=True)
loss_a = self.generator(batch[0])
Expand Down
31 changes: 24 additions & 7 deletions docs/source/optimizers.rst
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,15 @@ to manually manage the optimization process. To do so, do the following:
.. code-block:: python
def training_step(self, batch, batch_idx, optimizer_idx):
# ignore optimizer_idx
(opt_g, opt_d) = self.optimizers()
# 1. ignore optimizer_idx
# 2. `use_pl_optimizer=True` means `opt_g` and `opt_d` will be of type `LightingOptimizer`
# `LightingOptimizer` simply wrapped your optimizer and behave the same way !
# When calling `optimizer.step`, `LightingOptimizer` will just handle TPU, AMP, accumulate_grad_batches, etc ... for you.
# access your optimizers with `use_pl_optimizer=False` or `optimizer.optimizer` when using use_pl_optimizer=True
# use_pl_optimizer=True is the default
(opt_g, opt_d) = self.optimizers(use_pl_optimizer=True)
# do anything you want
loss_a = ...
Expand Down Expand Up @@ -242,19 +249,29 @@ Here we add a learning-rate warm up
# update params
optimizer.step(closure=closure)

The default ``optimizer_step`` is relying on the internal ``LightningOptimizer`` to properly perform a step.
.. note:: The default ``optimizer_step`` is relying on the internal ``LightningOptimizer`` to properly perform a step. It handles TPUs, AMP, accumulate_grad_batches, zero_grad, and much more ...

.. testcode::

from pytorch_lightning.core.optimizer import LightningOptimizer
# function hook in LightningModule
def optimizer_step(self, current_epoch, batch_nb, optimizer, optimizer_idx, closure, on_tpu=False, using_native_amp=False, using_lbfgs=False):
optimizer.step(closure=closure)

.. note:: To access your wrapped Optimizer from ``LightningOptimizer``, do as follow.

.. testcode::

# function hook in LightningModule
def optimizer_step(self, current_epoch, batch_nb, optimizer, optimizer_idx, closure, on_tpu=False, using_native_amp=False, using_lbfgs=False):
if not isinstance(optimizer, LightningOptimizer):
# wraps into LightingOptimizer only for running step
optimizer = LightningOptimizer.to_lightning_optimizer(optimizer, self.trainer)

# `optimizer is a ``LightningOptimizer`` wrapping the optimizer.
# To access it, do as follow:
optimizer = optimizer.optimizer

# run step. However, it won't work on TPU, AMP, etc...
optimizer.step(closure=closure)


----------

Using the closure functions for optimization
Expand Down
6 changes: 4 additions & 2 deletions docs/source/trainer.rst
Original file line number Diff line number Diff line change
Expand Up @@ -335,7 +335,8 @@ optimizer behavior
Example::
def training_step(self, batch, batch_idx):
opt = self.optimizers()
# access your optimizers with use_pl_optimizer=False. Default is True
opt = self.optimizers(use_pl_optimizer=True)
loss = ...
self.manual_backward(loss, opt)
Expand All @@ -350,7 +351,8 @@ In the multi-optimizer case, ignore the optimizer_idx flag and use the optimizer
Example::
def training_step(self, batch, batch_idx, optimizer_idx):
(opt_a, opt_b) = self.optimizers()
# access your optimizers with use_pl_optimizer=False. Default is True
(opt_a, opt_b) = self.optimizers(use_pl_optimizer=True)
gen_loss = ...
self.manual_backward(gen_loss, opt_a)
Expand Down
4 changes: 1 addition & 3 deletions pytorch_lightning/accelerators/cpu_accelerator.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Optional, Union, Callable
from typing import Any, Callable, Optional, Union

import torch

Expand Down Expand Up @@ -48,8 +48,6 @@ def setup(self, model):
# allow for lr schedulers as well
self.setup_optimizers(model)

self.trainer.convert_to_lightning_optimizers()

self.trainer.model = model

def train(self):
Expand Down
2 changes: 0 additions & 2 deletions pytorch_lightning/accelerators/ddp2_accelerator.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,8 +189,6 @@ def ddp_train(self, process_idx, mp_queue, model):
# 16-bit
model = self.trainer.precision_connector.connect(model)

self.trainer.convert_to_lightning_optimizers()

# device ids change depending on the DDP setup
device_ids = self.get_device_ids()

Expand Down
4 changes: 1 addition & 3 deletions pytorch_lightning/accelerators/ddp_accelerator.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,9 @@
# See the License for the specific language governing permissions and
# limitations under the License
import os
from os.path import abspath
import subprocess
import sys
from os.path import abspath
from time import sleep
from typing import Any, List, Optional, Union

Expand Down Expand Up @@ -292,8 +292,6 @@ def ddp_train(self, process_idx, model):
# 16-bit
model = self.trainer.precision_connector.connect(model)

self.trainer.convert_to_lightning_optimizers()

# device ids change depending on the DDP setup
device_ids = self.get_device_ids()

Expand Down
2 changes: 0 additions & 2 deletions pytorch_lightning/accelerators/ddp_cpu_spawn_accelerator.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,8 +148,6 @@ def ddp_train(self, process_idx, mp_queue, model):
# 16-bit
model = self.trainer.precision_connector.connect(model)

self.trainer.convert_to_lightning_optimizers()

# DDP spawn already spawned off each process... no need to do anything
device_ids = self.get_device_ids()

Expand Down
4 changes: 1 addition & 3 deletions pytorch_lightning/accelerators/ddp_hpc_accelerator.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,8 @@
from typing import Any, List, Optional, Union

import torch
import torch.distributed as torch_distrib
import torch.distributed as dist
import torch.distributed as torch_distrib
from torch.nn.parallel import DistributedDataParallel

from pytorch_lightning import _logger as log
Expand Down Expand Up @@ -177,8 +177,6 @@ def ddp_train(self, process_idx, model):
# 16-bit
model = self.trainer.precision_connector.connect(model)

self.trainer.convert_to_lightning_optimizers()

# device ids change depending on the DDP setup
device_ids = self.get_device_ids()

Expand Down
2 changes: 0 additions & 2 deletions pytorch_lightning/accelerators/ddp_spawn_accelerator.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,8 +163,6 @@ def ddp_train(self, process_idx, mp_queue, model, is_master: bool = False, proc_
# 16-bit
model = self.trainer.precision_connector.connect(model)

self.trainer.convert_to_lightning_optimizers()

# device ids change depending on the DDP setup
device_ids = self.get_device_ids()

Expand Down
2 changes: 0 additions & 2 deletions pytorch_lightning/accelerators/dp_accelerator.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,8 +64,6 @@ def setup(self, model):
if self.trainer.amp_backend:
model = self.__init_half_precision(model)

self.trainer.convert_to_lightning_optimizers()

self.trainer.model = model

def __init_torch_data_parallel(self, model):
Expand Down
2 changes: 0 additions & 2 deletions pytorch_lightning/accelerators/gpu_accelerator.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,8 +53,6 @@ def setup(self, model):
# 16-bit
model = self.trainer.precision_connector.connect(model)

self.trainer.convert_to_lightning_optimizers()

self.trainer.model = model

def train(self):
Expand Down
4 changes: 1 addition & 3 deletions pytorch_lightning/accelerators/horovod_accelerator.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from contextlib import ExitStack
from typing import Any, Optional, Union, Callable
from typing import Any, Callable, Optional, Union

import torch
from torch.optim.lr_scheduler import _LRScheduler
Expand Down Expand Up @@ -90,8 +90,6 @@ def _filter_named_parameters(model, optimizer):
# 16-bit
model = self.trainer.precision_connector.connect(model)

self.trainer.convert_to_lightning_optimizers()

# Update logger rank info from Horovod to avoid race conditions from different ranks
# creating directories / writing files in the same locations.
self.trainer.global_rank = hvd.rank()
Expand Down
3 changes: 1 addition & 2 deletions pytorch_lightning/accelerators/tpu_accelerator.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
rank_zero_info,
rank_zero_only,
rank_zero_warn,
TPU_AVAILABLE,
)
from pytorch_lightning.utilities.cloud_io import atomic_save
from pytorch_lightning.utilities.exceptions import MisconfigurationException
Expand Down Expand Up @@ -229,8 +230,6 @@ def __setup_tpu_training(self, model: LightningModule, trainer):
f' global rank: {trainer.tpu_global_core_rank}'
f' with XLA_USE_BF16={os.environ.get("XLA_USE_BF16")}')

self.trainer.convert_to_lightning_optimizers()

def backward(self, closure_loss, optimizer, opt_idx, *args, **kwargs):
# do backward pass
if self.trainer.train_loop.automatic_optimization:
Expand Down
7 changes: 5 additions & 2 deletions pytorch_lightning/core/lightning.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,8 +104,11 @@ def __init__(self, *args, **kwargs):
self._current_dataloader_idx = None
self._automatic_optimization: bool = True

def optimizers(self):
opts = self.trainer.optimizers
def optimizers(self, use_pl_optimizer: bool = True) -> Union[Optimizer, List[Optimizer], List[LightningOptimizer]]:
if use_pl_optimizer:
opts = list(self.trainer.lightning_optimizers.values())
else:
opts = self.trainer.optimizers

# single optimizer
if isinstance(opts, list) and len(opts) == 1 and isinstance(opts[0], Optimizer):
Expand Down
16 changes: 12 additions & 4 deletions pytorch_lightning/core/optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,10 @@ def __init__(self,
self._accumulate_grad_batches = accumulate_grad_batches
self._optimizer_idx = None

@property
def optimizer(self):
return self._optimizer

@property
def defaults(self):
return self._optimizer.defaults
Expand Down Expand Up @@ -103,9 +107,13 @@ def _on_trainer_init(self, trainer):
break

@classmethod
def to_lightning_optimizer(cls, optimizer, trainer):
optimizer = cls(optimizer)
optimizer._on_trainer_init(trainer)
def _to_lightning_optimizer(cls, optimizer, trainer, opt_idx):
# apex overrides .step function and need to be wrapped on each step
if trainer.amp_backend == AMPType.APEX:
optimizer = cls(optimizer)
optimizer._on_trainer_init(trainer)
else:
optimizer = trainer.lightning_optimizers[opt_idx]
return optimizer

def _accumulated_batches_reached(self):
Expand Down Expand Up @@ -147,7 +155,7 @@ def __optimizer_step(self, *args, closure: Optional[Callable] = None, profiler_n
**kwargs
)

trainer.train_loop.on_before_zero_grad(self)
trainer.train_loop.on_before_zero_grad(optimizer)

model.optimizer_zero_grad(
trainer.current_epoch,
Expand Down
4 changes: 2 additions & 2 deletions pytorch_lightning/plugins/ddp_sequential_plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,8 @@
from typing import Any, List, Optional

import torch
import torch.distributed as torch_distrib
from torch import nn
import torch.distributed as torch_distrib
from torch.nn.parallel import DistributedDataParallel

from pytorch_lightning import _logger as log
Expand All @@ -29,6 +29,7 @@
if _FAIRSCALE_PIPE_AVAILABLE:
import fairscale.nn.model_parallel as mpu
from fairscale.nn import PipeRPCWrapper
import fairscale.nn.model_parallel as mpu
from fairscale.nn.pipe import balance as pipe_balance
from fairscale.nn.pipe import rpc as rpc_pipe
from fairscale.nn.pipe.pipeline import PipelineStyle
Expand Down Expand Up @@ -380,7 +381,6 @@ def register_optimizers(ctx, model):
model.trainer.optimizers = optimizers
model.trainer.lr_schedulers = lr_schedulers
model.trainer.optimizer_frequencies = optimizer_frequencies
model.trainer.convert_to_lightning_optimizers()


def run_optimizer(ctx, model):
Expand Down
5 changes: 4 additions & 1 deletion pytorch_lightning/plugins/native_amp.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,10 @@ def backward(self, closure_loss, optimizer, opt_idx, *args, **kwargs):

# unscale gradient to allow analyze within `on_after_backward`
if not self.trainer.train_loop.should_accumulate() and automatic_optimization:
self.trainer.scaler.unscale_(optimizer)
if isinstance(optimizer, LightningOptimizer):
self.trainer.scaler.unscale_(optimizer.optimizer)
else:
self.trainer.scaler.unscale_(optimizer)

return closure_loss

Expand Down
3 changes: 1 addition & 2 deletions pytorch_lightning/plugins/sharded_plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ def _reinit_with_fairscale_oss(self, trainer):
optimizers = trainer.optimizers
for x, optimizer in enumerate(optimizers):
if is_lightning_optimizer(optimizer):
optimizer = optimizer._optimizer
optimizer = optimizer.optimizer
if not isinstance(optimizer, OSS):
optim_class = type(optimizer)
zero_optimizer = OSS(
Expand All @@ -73,7 +73,6 @@ def _reinit_with_fairscale_oss(self, trainer):
)
optimizers[x] = zero_optimizer
del optimizer
trainer.convert_to_lightning_optimizers()

def get_model_from_plugin(
self,
Expand Down
17 changes: 0 additions & 17 deletions pytorch_lightning/trainer/configuration_validator.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,17 +72,7 @@ def __verify_train_loop_configuration(self, model):

trainer.overriden_optimizer_step = is_overridden('optimizer_step', model)
trainer.overriden_optimizer_zero_grad = is_overridden('optimizer_zero_grad', model)

enable_pl_optimizer = trainer._enable_pl_optimizer
automatic_optimization = trainer.train_loop.automatic_optimization
if trainer.overriden_optimizer_step and not enable_pl_optimizer and automatic_optimization:
rank_zero_warn(
"When overriding `LightningModule` optimizer_step with"
" `Trainer(..., enable_pl_optimizer=False, ...)`,"
" we won't be calling `.zero_grad` we can't assume when you call your `optimizer.step()`."
" For Lightning to take care of it, please use `Trainer(enable_pl_optimizer=True)`."
)

going_to_accumulate_grad_batches = trainer.accumulation_scheduler.going_to_accumulate_grad_batches()

has_overriden_optimization_functions = trainer.overriden_optimizer_step or trainer.overriden_optimizer_zero_grad
Expand All @@ -93,13 +83,6 @@ def __verify_train_loop_configuration(self, model):
' It ensures optimizer_step or optimizer_zero_grad are called on every batch.'
)

if (enable_pl_optimizer) and trainer.overriden_optimizer_zero_grad and not automatic_optimization:
raise MisconfigurationException(
'When overriding `LightningModule` optimizer_zero_grad'
' and preserving model property `automatic_optimization` as True with'
' `Trainer(enable_pl_optimizer=True, ...) is not supported'
)

def __verify_eval_loop_configuration(self, model, eval_loop_name):
step_name = f'{eval_loop_name}_step'

Expand Down
6 changes: 5 additions & 1 deletion pytorch_lightning/trainer/connectors/optimizer_connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,11 @@ def __init__(self, trainer):
self.trainer = trainer

def on_trainer_init(self, enable_pl_optimizer):
self.trainer._enable_pl_optimizer = enable_pl_optimizer
if enable_pl_optimizer is not None:
rank_zero_warn(
"Trainer argument `enable_pl_optimizer` is deprecated in v1.1.3. It will be removed in v1.3.0",
DeprecationWarning
)
self.trainer.lr_schedulers = []
self.trainer.optimizers = []
self.trainer.optimizer_frequencies = []
Expand Down
Loading

0 comments on commit f28f0aa

Please sign in to comment.