Fix user warning produced by apex + scheduler combination (#1873)
* fix user error produced by apex + scheduler combination

* add changelog

* added reinit to every configure_apex call

* fix styling

Co-authored-by: Nicki Skafte <nugginea@gmail.com>
SkafteNicki and Nicki Skafte committed May 22, 2020
1 parent d610f3b commit 8f6b7a2
Showing 4 changed files with 19 additions and 0 deletions.
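
Background on the warning (context, not part of the diff): PyTorch's _LRScheduler wraps optimizer.step at construction time so it can check that optimizer.step() is called before scheduler.step(). Apex's amp.initialize later patches optimizer.step again, so the scheduler no longer recognises its own wrapper and emits a UserWarning about optimizer.step() having been overridden after scheduler initialization. A minimal sketch of the combination that triggers it, assuming NVIDIA apex and a CUDA device are available; the model, optimizer, and scheduler here are illustrative, not taken from the repository:

    import torch
    from torch import nn, optim
    from apex import amp  # NVIDIA apex

    model = nn.Linear(10, 2).cuda()
    optimizer = optim.SGD(model.parameters(), lr=0.1)
    # The scheduler wraps optimizer.step here ...
    scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=1)

    # ... and amp.initialize patches optimizer.step afterwards, which is the
    # combination this commit works around inside the Trainer.
    model, optimizer = amp.initialize(model, optimizer, opt_level='O1')

    loss = model(torch.randn(4, 10).cuda()).sum()
    with amp.scale_loss(loss, optimizer) as scaled_loss:
        scaled_loss.backward()
    optimizer.step()
    scheduler.step()  # emits the UserWarning without the fix in this commit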
2 changes: 2 additions & 0 deletions CHANGELOG.md
@@ -28,6 +28,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 ### Fixed
 
+- Fixed user warning when apex was used together with learning rate schedulers ([#1873](https://github.com/PyTorchLightning/pytorch-lightning/pull/1873))
+
 ## [0.7.6] - 2020-05-16
 
 ### Added
1 change: 1 addition & 0 deletions pytorch_lightning/trainer/distrib_data_parallel.py
@@ -372,6 +372,7 @@ def ddp_train(self, process_idx, model):
         if self.use_amp and not self.use_native_amp:
             model, optimizers = model.configure_apex(amp, model, self.optimizers, self.amp_level)
             self.optimizers = optimizers
+            self.reinit_scheduler_properties(self.optimizers, self.lr_schedulers)
 
         # DDP2 uses all GPUs on the machine
         if self.distributed_backend == 'ddp':
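
For reference, configure_apex is a LightningModule hook whose default implementation essentially defers to apex. A paraphrased sketch, not part of this diff (the real hook lives in pytorch_lightning/core/lightning.py):

    # Paraphrased default hook; `amp` is the apex.amp module passed in by the Trainer.
    def configure_apex(self, amp, model, optimizers, amp_level):
        model, optimizers = amp.initialize(model, optimizers, opt_level=amp_level)
        return model, optimizers

Because amp.initialize replaces the optimizers' step methods, each call site touched by this commit now follows the hook with reinit_scheduler_properties so the already-constructed schedulers re-attach to the patched optimizers.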
3 changes: 3 additions & 0 deletions pytorch_lightning/trainer/distrib_parts.py
@@ -497,6 +497,7 @@ def single_gpu_train(self, model):
             # An example
             model, optimizers = model.configure_apex(amp, model, self.optimizers, self.amp_level)
             self.optimizers = optimizers
+            self.reinit_scheduler_properties(self.optimizers, self.lr_schedulers)
 
         self.run_pretrain_routine(model)
 
@@ -559,6 +560,7 @@ def dp_train(self, model):
                     f' We recommend you switch to ddp if you want to use amp')
             else:
                 model, optimizers = model.configure_apex(amp, model, self.optimizers, self.amp_level)
+                self.reinit_scheduler_properties(optimizers, self.lr_schedulers)
 
         # create list of device ids
         device_ids = self.data_parallel_device_ids
@@ -599,6 +601,7 @@ def horovod_train(self, model):
             # An example
             model, optimizers = model.configure_apex(amp, model, self.optimizers, self.amp_level)
             self.optimizers = optimizers
+            self.reinit_scheduler_properties(self.optimizers, self.lr_schedulers)
 
         # Horovod: broadcast parameters & optimizer state to ensure consistent initialization
         hvd.broadcast_parameters(model.state_dict(), root_rank=0)
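
The warning itself is raised inside torch.optim.lr_scheduler._LRScheduler.step, which checks whether optimizer.step still carries the counter wrapper the scheduler installed when it was built. The effect can be reproduced on CPU without apex by patching optimizer.step the way a wrapper library would; this is an illustrative simulation, not code from the repository, and the exact warning text depends on the installed PyTorch version:

    import types
    import warnings
    import torch
    from torch import optim

    optimizer = optim.SGD([torch.nn.Parameter(torch.zeros(1))], lr=0.1)
    scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=1)  # wraps optimizer.step

    # Replace optimizer.step after the scheduler was constructed, roughly what
    # apex's amp.initialize does when it patches stepping for mixed precision.
    original_step = optimizer.step

    def patched_step(self, *args, **kwargs):
        return original_step(*args, **kwargs)

    optimizer.step = types.MethodType(patched_step, optimizer)

    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter('always')
        optimizer.step()
        scheduler.step()
    # Expected: a UserWarning saying optimizer.step() was overridden after
    # learning rate scheduler initialization.
    print([str(w.message) for w in caught])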
13 changes: 13 additions & 0 deletions pytorch_lightning/trainer/optimizers.py
@@ -108,6 +108,19 @@ def configure_schedulers(self, schedulers: list):
                                  'is a invalid input.')
         return lr_schedulers
 
+    def reinit_scheduler_properties(self, optimizers: list, schedulers: list):
+        # Reinitialize optimizer.step properties added by schedulers
+        for scheduler in schedulers:
+            for optimizer in optimizers:
+                scheduler = scheduler['scheduler']
+                # check that we dont mix users optimizers and schedulers
+                if scheduler.optimizer == optimizer:
+                    # Find the mro belonging to the base lr scheduler class
+                    for i, mro in enumerate(scheduler.__class__.__mro__):
+                        if mro == optim.lr_scheduler._LRScheduler:
+                            idx = i
+                    scheduler.__class__.__mro__[idx].__init__(scheduler, optimizer)
+
 
 class _MockOptimizer(Optimizer):
     """The `_MockOptimizer` will be used inplace of an optimizer in the event that `None`
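
A standalone, CPU-only sketch of the idea behind reinit_scheduler_properties: walk the scheduler's MRO to find the base LR-scheduler class and re-run its __init__ against the (patched) optimizer, so the scheduler re-installs its step wrapper and the warning no longer fires. This mirrors the new Trainer method rather than importing it, looks the base class up by name so it also works on newer PyTorch where _LRScheduler was renamed, and should be treated as illustrative only:

    import types
    import warnings
    import torch
    from torch import optim

    optimizer = optim.SGD([torch.nn.Parameter(torch.zeros(1))], lr=0.1)
    scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=1)

    # Simulate an apex-style patch of optimizer.step after scheduler construction.
    original_step = optimizer.step

    def patched_step(self, *args, **kwargs):
        return original_step(*args, **kwargs)

    optimizer.step = types.MethodType(patched_step, optimizer)

    # Re-run the base class constructor, as reinit_scheduler_properties does above.
    base_cls = getattr(optim.lr_scheduler, 'LRScheduler', None) or optim.lr_scheduler._LRScheduler
    for klass in type(scheduler).__mro__:
        if klass is base_cls:
            klass.__init__(scheduler, optimizer)
            break

    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter('always')
        optimizer.step()
        scheduler.step()
    # Expected: no "overridden after ... initialization" warning this time.
    print([str(w.message) for w in caught])

Re-running the base __init__ rather than the subclass's refreshes only the attributes the base class owns (the step wrapper, base_lrs, last_epoch, and the step counters) while leaving subclass-specific settings such as StepLR's step_size untouched.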
