Make reduce_on_plateau more flexible #5918

Closed
alexander-soare opened this issue Feb 11, 2021 · 5 comments
Labels
feature (Is an improvement or enhancement) · help wanted (Open to be worked on) · won't fix (This will not be worked on)

Comments

@alexander-soare

🚀 Feature

I want the freedom to use a custom reduce-on-plateau style scheduler.

Motivation

Right now, only torch.optim.lr_scheduler.ReduceLROnPlateau is supported out of the box. See here. This means even if I specify {'reduce_on_plateau': True} in:

def configure_optimizers(self):
  optimizer = TrainConf.optimizer(self.parameters(), **TrainConf.optimizer_params)
  scheduler = TrainConf.scheduler(optimizer, **TrainConf.scheduler_params)
  return {
      'optimizer': optimizer, 
      'lr_scheduler': {
          'scheduler': scheduler, 
          'interval': TrainConf.scheduler_interval,
          'monitor': 'val_lwlrap',
          'reduce_on_plateau': isinstance(scheduler,
                                          (WarmupReduceLROnPlateau, # this is my custom class
                                           ReduceLROnPlateau)),
      }
  }

this will be overridden downstream.

Pitch

At the very least, TrainerOptimizersMixin.configure_schedulers should check whether the key was provided and, if so, not overwrite it. So I propose changing

scheduler['reduce_on_plateau'] = isinstance(
    scheduler['scheduler'], optim.lr_scheduler.ReduceLROnPlateau
)

to

if 'reduce_on_plateau' not in scheduler:
  scheduler['reduce_on_plateau'] = isinstance(
      scheduler['scheduler'], optim.lr_scheduler.ReduceLROnPlateau
  )

But in the long run I'd suggest we allow the user to override a method of Trainer like this:

def step_scheduler(self):
  ...

That way we could access the variables needed and provide our own signature to the step method of the scheduler class.
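
As a rough sketch of what I mean (purely hypothetical: step_scheduler is not an existing Trainer method, and the signature here is just an illustration of the shape of the hook):

from pytorch_lightning import Trainer

class MyTrainer(Trainer):
    # Hypothetical hook: Lightning would call this instead of hard-coding
    # how scheduler.step() is invoked internally.
    def step_scheduler(self, scheduler, monitor_val=None):
        if monitor_val is not None:
            # plateau-style schedulers want the monitored metric
            scheduler.step(monitor_val)
        else:
            scheduler.step()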

Additional context

My particular custom scheduler looks like

from torch.optim.lr_scheduler import ReduceLROnPlateau

class WarmupReduceLROnPlateau(ReduceLROnPlateau):
    """
    Subclassing torch.optim.lr_scheduler.ReduceLROnPlateau
    added warmup parameters
    """
    def __init__(self, optimizer, mode='min', factor=0.1, patience=10,
                 threshold=1e-4, threshold_mode='rel', cooldown=0,
                 min_lr=0, eps=1e-8, warmup_itrs=0, warmup_type='lin',
                 start_lr=1e-16, verbose=False):
        super().__init__(optimizer, mode=mode, factor=factor, patience=patience,
                 threshold=threshold, threshold_mode=threshold_mode,
                 cooldown=cooldown, min_lr=min_lr, eps=eps, verbose=verbose)
        self.warmup_itrs = warmup_itrs
        self.warmup_type = warmup_type
        self.start_lr = start_lr
        self.default_lrs = []
        self.itr = 0
        for param_group in optimizer.param_groups:
            self.default_lrs.append(param_group['lr'])

    def step(self, metrics):
        if self.itr < self.warmup_itrs:
            # still warming up: interpolate from start_lr towards the default lr
            for i, param_group in enumerate(self.optimizer.param_groups):
                if self.warmup_type == 'exp':
                    new_lr = self.start_lr * \
                        (self.default_lrs[i] /
                         self.start_lr)**(self.itr / self.warmup_itrs)
                elif self.warmup_type == 'lin':
                    new_lr = self.start_lr + \
                        (self.default_lrs[i] - self.start_lr) * \
                        (self.itr / self.warmup_itrs)
                param_group['lr'] = new_lr
        elif self.itr == self.warmup_itrs:
            # warmup finished: restore the original learning rates
            for i, param_group in enumerate(self.optimizer.param_groups):
                param_group['lr'] = self.default_lrs[i]
        else:
            # past warmup: defer to the standard ReduceLROnPlateau behaviour
            super().step(metrics)
        self.itr += 1
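
For completeness, a minimal manual usage sketch (outside Lightning, with made-up values and a hypothetical train_one_epoch helper) would be:

import torch

model = torch.nn.Linear(10, 1)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
scheduler = WarmupReduceLROnPlateau(optimizer, warmup_itrs=5, warmup_type='lin')

for epoch in range(20):
    val_metric = train_one_epoch(model, optimizer)  # hypothetical training/validation step
    scheduler.step(val_metric)  # warms up for 5 steps, then falls back to ReduceLROnPlateau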

If this gets any traction I'd love to contribute a PR (would be my first to a widely used repo!)

@alexander-soare added the feature and help wanted labels on Feb 11, 2021
@SkafteNicki
Member

@alexander-soare I am not completely sure what the problem is, since the Lightning internals

scheduler['reduce_on_plateau'] = isinstance(
    scheduler['scheduler'], optim.lr_scheduler.ReduceLROnPlateau
)

should set scheduler['reduce_on_plateau'] = True for your custom class. At least, I get:

w = WarmupReduceLROnPlateau(torch.optim.Adam([torch.nn.Parameter(torch.ones(1))]))
isinstance(w, ReduceLROnPlateau)  # True

@alexander-soare
Author

@SkafteNicki ah yes, you're right, because in my particular example I subclass ReduceLROnPlateau. I found another way around my specific problem anyway. But more generally, if I don't subclass ReduceLROnPlateau, this won't work.
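
For instance (a made-up illustration, not code from anywhere in particular), a plateau-style scheduler that doesn't inherit from ReduceLROnPlateau still expects a metric in step() but fails the isinstance check:

class MyPlateauScheduler:
    # Not a subclass of torch.optim.lr_scheduler.ReduceLROnPlateau, so
    # isinstance(scheduler, ReduceLROnPlateau) is False even though
    # step() needs the monitored metric.
    def __init__(self, optimizer, factor=0.5, patience=3):
        self.optimizer = optimizer
        self.factor = factor
        self.patience = patience
        self.best = float('inf')
        self.bad_steps = 0

    def step(self, metrics):
        if metrics < self.best:
            self.best = metrics
            self.bad_steps = 0
        else:
            self.bad_steps += 1
        if self.bad_steps > self.patience:
            # reduce every learning rate when the metric stops improving
            for group in self.optimizer.param_groups:
                group['lr'] *= self.factor
            self.bad_steps = 0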

@SkafteNicki
Member

So historically, we have only supported PyTorch schedulers (meaning schedulers subclassing torch.optim.lr_scheduler._LRScheduler) and ReduceLROnPlateau, because we did not want to maintain code that supports every possible scheduler the user could come up with. However, I have seen more and more users request support for third-party schedulers (like in issue #5555). So maybe it is time to allow for a more general class.
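
One possible direction (just a sketch of the idea, not anything that exists in Lightning) would be to duck-type the check instead of requiring a specific base class, e.g.:

import inspect

def is_plateau_like(scheduler) -> bool:
    # Sketch: treat any scheduler whose step() accepts a metrics argument
    # as a reduce-on-plateau style scheduler, instead of checking
    # isinstance against ReduceLROnPlateau.
    return 'metrics' in inspect.signature(scheduler.step).parameters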

@stale

stale bot commented Mar 15, 2021

This issue has been automatically marked as stale because it hasn't had any recent activity. This issue will be closed in 7 days if no further activity occurs. Thank you for your contributions, Pytorch Lightning Team!

stale bot added the won't fix label on Mar 15, 2021
stale bot closed this as completed on Mar 22, 2021
@xichenpan

Hi, I have the same request for this feature, and it seems it has not been supported as of 1.4.1.
