How to find suitable lrs with FastaiLRFinder when the optimizer has multiple groups? #2703

Closed
xiaoye-hhh opened this issue Sep 7, 2022 · 5 comments · Fixed by #2704

xiaoye-hhh commented Sep 7, 2022

optimizer = optim.SGD([
    {'params': model.conv.parameters(), 'lr': 1},
    {'params': model.linear.parameters(), 'lr': 0.1},
], lr=3e-4, momentum=0.9)

As in this example, the optimizer has two different param groups. Can anyone give an example of how to use FastaiLRFinder in this case?

vfdev-5 (Collaborator) commented Sep 7, 2022

cc @KickItLikeShika

vfdev-5 (Collaborator) commented Sep 7, 2022

@xiaoye-hhh thanks for the question. You can check https://pytorch-ignite.ai/how-to-guides/04-fastai-lr-finder/#with-lr-finder and adapt it to multiple groups.

Our FastaiLRFinder can accept an optimizer with multiple param groups, but it sweeps a single lr range and does not respect the initial per-group lrs (1.0 and 0.1 in your case). This means it will most probably suggest the same lr for both groups.

import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torchvision.datasets import MNIST
from torchvision.models import resnet18
from torchvision.transforms import Compose, Normalize, ToTensor

from ignite.engine import create_supervised_trainer, create_supervised_evaluator
from ignite.metrics import Accuracy, Loss
from ignite.handlers import FastaiLRFinder

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()

        self.model = resnet18(num_classes=20)
        self.model.conv1 = nn.Conv2d(
            1, 64, kernel_size=3, padding=1, bias=False
        )
        self.linear = nn.Linear(20, 10)

    def forward(self, x):
        return self.model(x)


data_transform = Compose([ToTensor(), Normalize((0.1307,), (0.3081,))])

train_loader = DataLoader(
    MNIST(download=True, root=".", transform=data_transform, train=True),
    batch_size=128,
    shuffle=True,
)

test_loader = DataLoader(
    MNIST(download=True, root=".", transform=data_transform, train=False),
    batch_size=256,
    shuffle=False,
)


model = Net().to(device)
# optimizer = torch.optim.RMSprop(model.parameters(), lr=1e-06)
optimizer = torch.optim.SGD([
    {'params': model.model.parameters(), 'lr': 0.1},
    {'params': model.linear.parameters(), 'lr': 0.01},
], momentum=0.9)

criterion = nn.CrossEntropyLoss()

trainer = create_supervised_trainer(model, optimizer, criterion, device=device)
lr_finder = FastaiLRFinder()

# To restore the model's and optimizer's states after running the LR Finder
to_save = {"model": model, "optimizer": optimizer}

with lr_finder.attach(trainer, to_save, end_lr=1.0) as trainer_with_lr_finder:
    trainer_with_lr_finder.run(train_loader)

print("Suggested LR", lr_finder.lr_suggestion())
> Suggested LR [0.10451768106330113, 0.10451768106330113]
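If the single suggested value is good enough for both groups, it can be written back into the optimizer before the real training run. A minimal sketch (not part of the original comment; max_epochs is just illustrative), following the pattern from the linked how-to guide:

suggested_lrs = lr_finder.lr_suggestion()
# lr_suggestion() returned a list here because the optimizer has two param groups;
# both entries are identical, so each group ends up with the same lr.
for param_group, lr in zip(optimizer.param_groups, suggested_lrs):
    param_group["lr"] = lr

# model and optimizer states were restored via `to_save`, so training starts fresh
trainer.run(train_loader, max_epochs=3)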

xiaoye-hhh (Author) commented

Thanks for the quick reply. As your code shows, the suggested lrs for the two groups are both ~0.1045. However, when I fine-tune a pretrained model with some new layers added at the end, I want the lrs of the groups to be different: smaller for the pretrained part and larger for the new layers. How can I do that?

vfdev-5 (Collaborator) commented Sep 8, 2022

@xiaoye-hhh right now there is no way to perform a search over two param groups. We could extend the behaviour of our lr finder to accept a list of floats for start_lr and end_lr. It won't be an N-D search (where N is the number of groups), as that may be too expensive.

Here is how you can patch the current code to make it work with two groups:

# Builds on the snippet above (FastaiLRFinder, trainer, optimizer, to_save and train_loader are already defined)
from math import ceil
from typing import List, Union

from ignite.engine import Engine, Events
from ignite.handlers.lr_finder import _LRScheduler
from ignite.handlers.param_scheduler import LRScheduler, PiecewiseLinear


class UpdatedLRFinder(FastaiLRFinder):

    def _run(
        self,
        trainer,
        optimizer,
        output_transform,
        num_iter: int,
        start_lr: Union[None, float, List[float]],
        end_lr: Union[float, List[float]],
        step_mode: str,
        smooth_f: float,
        diverge_th: float,
    ) -> None:

        self._history = {"lr": [], "loss": []}
        self._best_loss = None
        self._diverge_flag = False

        # attach LRScheduler to trainer.
        if num_iter is None:
            num_iter = trainer.state.epoch_length * trainer.state.max_epochs
        else:
            max_iter = trainer.state.epoch_length * trainer.state.max_epochs  # type: ignore[operator]
            if max_iter < num_iter:
                max_iter = num_iter
                trainer.state.max_epochs = ceil(num_iter / trainer.state.epoch_length)  # type: ignore[operator]

        if not trainer.has_event_handler(self._reached_num_iterations):
            trainer.add_event_handler(Events.ITERATION_COMPLETED, self._reached_num_iterations, num_iter)

        # attach loss and lr logging
        if not trainer.has_event_handler(self._log_lr_and_loss):
            trainer.add_event_handler(
                Events.ITERATION_COMPLETED, self._log_lr_and_loss, output_transform, smooth_f, diverge_th
            )

        self.logger.debug(f"Running LR finder for {num_iter} iterations")
        # Initialize the proper learning rate policy
        if step_mode.lower() == "exp":
            if start_lr is None:
                start_lr_list = [optimizer.param_groups[i]["lr"] for i in range(len(optimizer.param_groups))]
            elif isinstance(start_lr, float):
                start_lr_list = [start_lr] * len(optimizer.param_groups)
            else:
                assert len(start_lr) == len(optimizer.param_groups)
                start_lr_list = start_lr

            if isinstance(end_lr, float):
                end_lr_list = [end_lr] * len(optimizer.param_groups)
            else:
                assert len(end_lr) == len(optimizer.param_groups)
                end_lr_list = end_lr
            
            self._lr_schedule = LRScheduler(_ExponentialLR(optimizer, start_lr_list, end_lr_list, num_iter))
        else:
            if isinstance(start_lr, list) or isinstance(end_lr, list):
                raise ValueError("per-group start_lr/end_lr lists are only supported with step_mode='exp'")
            self._lr_schedule = PiecewiseLinear(
                optimizer, param_name="lr", milestones_values=[(0, start_lr), (num_iter, end_lr)]
            )
        if not trainer.has_event_handler(self._lr_schedule):
            trainer.add_event_handler(Events.ITERATION_COMPLETED, self._lr_schedule, num_iter)


class _ExponentialLR(_LRScheduler): 
    def __init__(self, optimizer, start_lrs: List[float], end_lrs: List[float], num_iter: int, last_epoch: int = -1):
        self.end_lrs = end_lrs
        self.num_iter = num_iter
        super(_ExponentialLR, self).__init__(optimizer, last_epoch)

        # override base_lrs
        self.base_lrs = start_lrs

    def get_lr(self) -> List[float]:  # type: ignore
        curr_iter = self.last_epoch + 1
        r = curr_iter / self.num_iter
        return [base_lr * (end_lr / base_lr) ** r for end_lr, base_lr in zip(self.end_lrs, self.base_lrs)]


lr_finder = UpdatedLRFinder()
with lr_finder.attach(trainer, to_save, start_lr=[0.001, 0.01], end_lr=[0.1, 2.0]) as trainer_with_lr_finder:
    trainer_with_lr_finder.run(train_loader)

print("Suggested LR", lr_finder.lr_suggestion())
> Suggested LR [0.0011473647220774028, 0.01171352105579639]

[image: plot attached in the original comment]
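For intuition about what the patched finder sweeps, here is a small standalone check (a sketch added for illustration, independent of ignite) of the per-group exponential schedule used in _ExponentialLR.get_lr above, i.e. lr = start_lr * (end_lr / start_lr) ** r with r = curr_iter / num_iter:

start_lrs = [0.001, 0.01]
end_lrs = [0.1, 2.0]
num_iter = 100

for curr_iter in (0, 50, 100):
    r = curr_iter / num_iter
    lrs = [s * (e / s) ** r for s, e in zip(start_lrs, end_lrs)]
    print(curr_iter, [round(lr, 5) for lr in lrs])
# 0   [0.001, 0.01]    -> each group starts at its own start_lr
# 50  [0.01, 0.14142]  -> geometric midpoint of each group's range
# 100 [0.1, 2.0]       -> each group ends at its own end_lr

Each group thus moves through its own range on a log scale, which is why the suggested lrs above can differ between the groups.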

Maybe @Jacob208M will help to integrate a similar feature into the library.

xiaoye-hhh (Author) commented

Thanks a lot. It's helpful.
