
How to utilize timm's scheduler? #5555

Closed
sooperset opened this issue Jan 18, 2021 · 9 comments · Fixed by #10249
Labels
feature (Is an improvement or enhancement) · help wanted (Open to be worked on) · priority: 2 (Low priority task) · won't fix (This will not be worked on)

Comments

@sooperset commented Jan 18, 2021

🚀 Feature

Hi, I want to reproduce the results of an image classification network using the timm library, but I couldn't use timm.scheduler.create_scheduler because pytorch_lightning doesn't accept a custom class as a scheduler (timm's schedulers are not torch.optim.lr_scheduler classes).

from timm.scheduler import create_scheduler
from timm.optim import create_optimizer

def configure_optimizers(self):
    optimizer = create_optimizer(self.args, self)
    # create_scheduler returns a timm Scheduler, not a torch.optim.lr_scheduler
    scheduler, _ = create_scheduler(self.args, optimizer)

    return {
        'optimizer': optimizer,
        'lr_scheduler': scheduler,  # Lightning rejects this object
    }

This then raises the following error:

File "/opt/conda/lib/python3.8/site-packages/pytorch_lightning/trainer/optimizers.py", line 141, in configure_schedulers
raise ValueError(f'The provided lr scheduler "{scheduler}" is invalid')

Is there a plan for utilizing timm's scheduler?
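
For context, here is a minimal illustration of the API mismatch (this sketch is not from the issue; the model and optimizer are placeholders): torch schedulers step without arguments, while timm schedulers require the epoch index, which is why Lightning's generic scheduler handling cannot drive them.

import torch
from timm.scheduler import CosineLRScheduler

model = torch.nn.Linear(4, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

# torch scheduler: step() takes no required arguments,
# which is what Lightning's built-in scheduler handling expects
torch_sched = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10)
torch_sched.step()

# timm scheduler: step() requires the epoch index,
# so a plain scheduler.step() call fails
timm_sched = CosineLRScheduler(optimizer, t_initial=10)
timm_sched.step(epoch=1)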


@sooperset added the feature and help wanted labels on Jan 18, 2021
@DuinoDu commented Jan 18, 2021

One solution is to wrap it into a callback. I ran into a similar problem and created an MMCVLrCallback. Hope it helps.
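
A rough sketch of that callback idea (the class name is illustrative, not DuinoDu's MMCVLrCallback, and it assumes a Lightning version where on_train_epoch_end takes trainer and pl_module):

import pytorch_lightning as pl

class TimmLRCallback(pl.Callback):
    """Steps a timm scheduler once per epoch, bypassing Lightning's scheduler validation."""

    def __init__(self, scheduler):
        self.scheduler = scheduler

    def on_train_epoch_end(self, trainer, pl_module):
        # timm schedulers take the epoch index explicitly; pass the upcoming
        # epoch so its LR is in effect when that epoch starts
        self.scheduler.step(epoch=trainer.current_epoch + 1)

With this approach the scheduler stays out of configure_optimizers entirely and the callback is handed to the trainer, e.g. Trainer(callbacks=[TimmLRCallback(scheduler)]).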

@tchaton (Contributor) commented Jan 18, 2021

Dear @soomiles,

It shouldn't be hard to extend the conditions to support it.
Feel free to make a PR.

Best,
T.C

@tchaton added the priority: 2 label on Jan 18, 2021
@potipot commented Jan 25, 2021

Hello, has there been any update on this? @soomiles, have you perhaps worked it out?

@DuinoDu commented Jan 26, 2021

> Dear @soomiles,
>
> It shouldn't be hard to extend the conditions to support it.
> Feel free to make a PR.
>
> Best,
> T.C

I plan to make a PR in the next few days.

stale bot commented Feb 27, 2021

This issue has been automatically marked as stale because it hasn't had any recent activity. This issue will be closed in 7 days if no further activity occurs. Thank you for your contributions, Pytorch Lightning Team!

stale bot added the won't fix label on Feb 27, 2021
stale bot closed this as completed on Mar 7, 2021
@jinwon-samsung commented

Is the timm scheduler still not supported?

@dgmp88 commented Oct 4, 2021

This won't work with multiple parameter groups, but here's a quick fix if anyone else wants to use it:

import torch
from timm.scheduler import StepLRScheduler

class TimmStepLRScheduler(torch.optim.lr_scheduler.LambdaLR):
    """Wraps timm's StepLRScheduler so Lightning sees a standard torch LambdaLR."""

    def __init__(self, optim, **kwargs):
        self.init_lr = optim.param_groups[0]["lr"]
        self.timmsteplr = StepLRScheduler(optim, **kwargs)
        # Pass self as the lr_lambda; LambdaLR will call __call__(epoch)
        super().__init__(optim, self)

    def __call__(self, epoch):
        # LambdaLR expects a multiplier on the initial LR, so convert the
        # absolute LR that timm computes into a ratio
        desired_lr = self.timmsteplr.get_epoch_values(epoch)[0]
        mult = desired_lr / self.init_lr
        return mult
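
A hypothetical usage sketch for the wrapper above (decay_t and decay_rate are StepLRScheduler's own arguments; the optimizer here is just an example):

def configure_optimizers(self):
    optimizer = torch.optim.SGD(self.parameters(), lr=0.1)
    # Lightning accepts this because the wrapper is a torch LambdaLR underneath
    scheduler = TimmStepLRScheduler(optimizer, decay_t=30, decay_rate=0.1)
    return {"optimizer": optimizer, "lr_scheduler": scheduler}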

@timothylimyl commented

The issue still exists for me (after updating to the latest Lightning version) upon running trainer.fit(model, data):

line 505, in _update_learning_rates
    lr_scheduler["scheduler"].step()
TypeError: step() missing 1 required positional argument: 'epoch'

The fix, according to the PR I followed:

    def configure_optimizers(self):
        return [self.optimizer], [{"scheduler": self.scheduler, "interval": "epoch"}]

    def lr_scheduler_step(self, scheduler, optimizer_idx, metric):
        print(self.current_epoch)
        scheduler.step(
            epoch=self.current_epoch
        )  # timm's scheduler needs the epoch value

The scheduler object was <timm.scheduler.cosine_lr.CosineLRScheduler object at 0x7f6d1dc29be0>. When debugging, execution never seems to reach lr_scheduler_step. Training runs successfully for a single epoch and then the error above is raised (which makes sense, since the LR is scheduled every epoch), which prevents it from ever reaching the validation set.
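
For anyone stuck on a pre-1.6 release where lr_scheduler_step is never invoked, one workaround (my assumption, not from the thread) is to keep the scheduler out of configure_optimizers and step it manually in an epoch-end hook:

def configure_optimizers(self):
    # Return only the optimizer; Lightning never sees (or validates) the timm scheduler
    return self.optimizer

def training_epoch_end(self, outputs):
    # training_epoch_end exists in Lightning 1.5.x; step the timm scheduler
    # manually with the epoch index it requires
    self.scheduler.step(epoch=self.current_epoch + 1)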

@timothylimyl commented

Upon further debugging, I realised that the latest Lightning release (1.5.10) does not yet include this merge. Silly me: I assumed the latest tag would track the master branch, since the practice is to merge only what is stable and necessary. I saw the fix was merged to master and immediately assumed it would ship with the next update, but reading the details more closely (looking at the dates after the merge), I see it was not included in any release. It is actually going to be tagged only in the upcoming 1.6 version.

Sharing this for learning purposes in case someone is as inexperienced as I was. Remember to check the PR milestone!
