
Trainer(precision=16) fails with optim.lr_scheduler.ReduceLROnPlateau #2078

Closed
naokishibuya opened this issue Jun 5, 2020 · 5 comments · Fixed by #2356
Labels
help wanted Open to be worked on

Comments

@naokishibuya

🐛 Bug

To Reproduce

Steps to reproduce the behavior:

  1. Create a pl.LightningModule that returns your optimizer along with an optim.lr_scheduler.ReduceLROnPlateau scheduler from configure_optimizers (a minimal sketch follows the traceback below)
  2. Create a pl.Trainer with precision=16
  3. Run your training (i.e., trainer.fit(model))
  4. See error
Traceback (most recent call last):
  File "main.py", line 65, in <module>
    main()
  File "main.py", line 61, in main
    trainer.fit(model)
  File "/workspace/pytorch-lightning/pytorch_lightning/trainer/trainer.py", line 889, in fit
    self.dp_train(model)
  File "/workspace/pytorch-lightning/pytorch_lightning/trainer/distrib_parts.py", line 223, in dp_train
    self.reinit_scheduler_properties(optimizers, self.lr_schedulers)
  File "/workspace/pytorch-lightning/pytorch_lightning/trainer/optimizers.py", line 122, in reinit_scheduler_properties
    scheduler.__class__.__mro__[idx].__init__(scheduler, optimizer)
UnboundLocalError: local variable 'idx' referenced before assignment
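
For reference, a minimal repro along the lines of steps 1–3 might look like the sketch below (the module, data, and trainer arguments besides precision=16 are illustrative, not taken from the original report):

import torch
from torch import nn, optim
from torch.utils.data import DataLoader, TensorDataset
import pytorch_lightning as pl


class PlateauModel(pl.LightningModule):  # hypothetical module name
    def __init__(self):
        super().__init__()
        self.layer = nn.Linear(32, 1)

    def forward(self, x):
        return self.layer(x)

    def training_step(self, batch, batch_idx):
        x, y = batch
        loss = nn.functional.mse_loss(self(x), y)
        return {"loss": loss}

    def train_dataloader(self):
        x, y = torch.randn(64, 32), torch.randn(64, 1)
        return DataLoader(TensorDataset(x, y), batch_size=8)

    def configure_optimizers(self):
        optimizer = optim.SGD(self.parameters(), lr=0.1)
        scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, patience=2)
        return [optimizer], [scheduler]


model = PlateauModel()
trainer = pl.Trainer(precision=16, gpus=1, max_epochs=1)  # gpus/max_epochs illustrative; 16-bit needs a GPU backend
trainer.fit(model)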

The error occurs in pytorch-lightning/pytorch_lightning/trainer/optimizers.py, line 122:

def reinit_scheduler_properties(self, optimizers: list, schedulers: list):
    # Reinitialize optimizer.step properties added by schedulers
    for scheduler in schedulers:
        for optimizer in optimizers:
            scheduler = scheduler['scheduler']
            # check that we dont mix users optimizers and schedulers
            if scheduler.optimizer == optimizer:
                # Find the mro belonging to the base lr scheduler class
                for i, mro in enumerate(scheduler.__class__.__mro__):
                    if mro == optim.lr_scheduler._LRScheduler:
                        idx = i
                scheduler.__class__.__mro__[idx].__init__(scheduler, optimizer)

The idx local variable is unassigned because optim.lr_scheduler.ReduceLROnPlateau is not a subclass of optim.lr_scheduler._LRScheduler.
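
This can be verified directly (a quick check, not part of the original report; behavior as of the PyTorch versions current when this issue was filed):

from torch import optim

# ReduceLROnPlateau does not inherit from the base _LRScheduler class,
# so the loop over __mro__ above never assigns idx.
print(issubclass(optim.lr_scheduler.ReduceLROnPlateau,
                 optim.lr_scheduler._LRScheduler))   # False
print(optim.lr_scheduler.ReduceLROnPlateau.__mro__)
# (<class 'torch.optim.lr_scheduler.ReduceLROnPlateau'>, <class 'object'>)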

I could work around the error by adding a specific check for optim.lr_scheduler.ReduceLROnPlateau but I'm not sure if this is a good solution.

def reinit_scheduler_properties(self, optimizers: list, schedulers: list):
    # Reinitialize optimizer.step properties added by schedulers
    for scheduler in schedulers:
        for optimizer in optimizers:
            scheduler = scheduler['scheduler']
            # check that we dont mix users optimizers and schedulers
            if scheduler.optimizer == optimizer:
                # Find the mro belonging to the base lr scheduler class
                for i, mro in enumerate(scheduler.__class__.__mro__):
                    if mro == optim.lr_scheduler._LRScheduler:
                        idx = i
                    elif mro == optim.lr_scheduler.ReduceLROnPlateau:
                        idx = i
                scheduler.__class__.__mro__[idx].__init__(scheduler, optimizer)

Related issue in PyTorch:

ReduceLROnPlateau parent class is not _LRScheduler #21981
pytorch/pytorch#21981

naokishibuya added the help wanted (Open to be worked on) label on Jun 5, 2020
@github-actions
Contributor

github-actions bot commented Jun 5, 2020

Hi! Thanks for your contribution, great first issue!

@SkafteNicki
Member

@naokishibuya good catch. It seems like a problem that should be solved upstream in pytorch, but for now we can solve this locally. Would you be up for a PR?

@Anjum48

Anjum48 commented Jun 9, 2020

When I tried this fix, it solved the error, but unfortunately ReduceLROnPlateau stopped working for me (i.e., there was no indication of the LR decreasing with verbose=True or on TensorBoard). If I switched back to precision=32, it worked normally again.

@SkafteNicki
Member

I think the fix is actually working; however, calling only __init__(scheduler, optimizer) will reset all other arguments (patience, mode, etc.) to their default values for the ReduceLROnPlateau scheduler. A solution to this is to copy over these properties:

__init__(scheduler, optimizer, patience=scheduler.patience, mode=scheduler.mode, ...)

Again I think this is a bit hacky, and a proper solution upstream in pytorch is better.
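
A sketch of what that property-copying could look like for the ReduceLROnPlateau branch inside reinit_scheduler_properties (the keyword list is illustrative and assumes the constructor arguments ReduceLROnPlateau exposes; it is not the final fix):

# Hypothetical sketch: re-init ReduceLROnPlateau while preserving its settings.
if isinstance(scheduler, optim.lr_scheduler.ReduceLROnPlateau):
    optim.lr_scheduler.ReduceLROnPlateau.__init__(
        scheduler,
        optimizer,
        mode=scheduler.mode,
        factor=scheduler.factor,
        patience=scheduler.patience,
        threshold=scheduler.threshold,
        cooldown=scheduler.cooldown,
        min_lr=scheduler.min_lrs,  # stored internally as a per-group list
        eps=scheduler.eps,
    )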

@Anjum48

Anjum48 commented Jun 11, 2020

I think this does the trick for me:

def reinit_scheduler_properties(self, optimizers: list, schedulers: list):
    # Reinitialize optimizer.step properties added by schedulers
    for scheduler in schedulers:
        for optimizer in optimizers:
            scheduler = scheduler["scheduler"]
            # check that we dont mix users optimizers and schedulers
            if scheduler.optimizer == optimizer:
                # Find the mro belonging to the base lr scheduler class
                state = None
                for i, mro in enumerate(scheduler.__class__.__mro__):
                    if (
                        mro == optim.lr_scheduler._LRScheduler
                        or mro == optim.lr_scheduler.ReduceLROnPlateau
                    ):
                        idx = i
                        state = scheduler.state_dict()
                scheduler.__class__.__mro__[idx].__init__(scheduler, optimizer)
                if state is not None:
                    scheduler.load_state_dict(state)

Happy to open a PR if it looks ok to you guys
