From 586b68dc1c568ded0ffb19cf6bd7b1b34b93ace3 Mon Sep 17 00:00:00 2001 From: Wenjie Du Date: Wed, 27 Sep 2023 19:45:30 +0800 Subject: [PATCH 1/4] feat: add learning rate schedulers; --- pypots/optim/lr_scheduler/__init__.py | 29 ++++ pypots/optim/lr_scheduler/base.py | 159 ++++++++++++++++++ pypots/optim/lr_scheduler/constant_lrs.py | 84 +++++++++ pypots/optim/lr_scheduler/exponential_lrs.py | 55 ++++++ pypots/optim/lr_scheduler/lambda_lrs.py | 79 +++++++++ pypots/optim/lr_scheduler/linear_lrs.py | 115 +++++++++++++ .../optim/lr_scheduler/multiplicative_lrs.py | 77 +++++++++ pypots/optim/lr_scheduler/multistep_lrs.py | 75 +++++++++ pypots/optim/lr_scheduler/step_lrs.py | 70 ++++++++ 9 files changed, 743 insertions(+) create mode 100644 pypots/optim/lr_scheduler/__init__.py create mode 100644 pypots/optim/lr_scheduler/base.py create mode 100644 pypots/optim/lr_scheduler/constant_lrs.py create mode 100644 pypots/optim/lr_scheduler/exponential_lrs.py create mode 100644 pypots/optim/lr_scheduler/lambda_lrs.py create mode 100644 pypots/optim/lr_scheduler/linear_lrs.py create mode 100644 pypots/optim/lr_scheduler/multiplicative_lrs.py create mode 100644 pypots/optim/lr_scheduler/multistep_lrs.py create mode 100644 pypots/optim/lr_scheduler/step_lrs.py diff --git a/pypots/optim/lr_scheduler/__init__.py b/pypots/optim/lr_scheduler/__init__.py new file mode 100644 index 00000000..c0688a51 --- /dev/null +++ b/pypots/optim/lr_scheduler/__init__.py @@ -0,0 +1,29 @@ +""" +Learning rate schedulers available in PyPOTS. Their functionalities are the same with those in PyTorch, +the only difference that is also why we implement them is that you don't have to pass according optimizers +into them immediately while initializing them. Instead, you can pass them into pypots.optim.Optimizer +after initialization and call their `init_scheduler()` method in Optimizer.init_optimizer() to initialize +schedulers together with optimizers. +""" + +# Created by Wenjie Du +# License: GLP-v3 + +from .lambda_lrs import LambdaLR +from .multiplicative_lrs import MultiplicativeLR +from .step_lrs import StepLR +from .multistep_lrs import MultiStepLR +from .constant_lrs import ConstantLR +from .exponential_lrs import ExponentialLR +from .linear_lrs import LinearLR + + +__all__ = [ + "LambdaLR", + "MultiplicativeLR", + "StepLR", + "MultiStepLR", + "ConstantLR", + "ExponentialLR", + "LinearLR", +] diff --git a/pypots/optim/lr_scheduler/base.py b/pypots/optim/lr_scheduler/base.py new file mode 100644 index 00000000..3c5af3b7 --- /dev/null +++ b/pypots/optim/lr_scheduler/base.py @@ -0,0 +1,159 @@ +""" +The base class for learning rate schedulers. This class is adapted from PyTorch, +please refer to torch.optim.lr_scheduler for more details. +""" + +# Created by Wenjie Du +# License: GLP-v3 + +import weakref +from abc import ABC, abstractmethod +from functools import wraps + +from torch.optim import Optimizer + +from ...utils.logging import logger + + +class LRScheduler(ABC): + """Base class for PyPOTS learning rate schedulers. + + Parameters + ---------- + last_epoch: int + The index of last epoch. Default: -1. + + verbose: If ``True``, prints a message to stdout for + each update. Default: ``False``. + + """ + + def __init__(self, last_epoch=-1, verbose=False): + self.last_epoch = last_epoch + self.verbose = verbose + self.optimizer = None + self.base_lrs = None + self._last_lr = None + self._step_count = 0 + + def init_scheduler(self, optimizer): + """Initialize the scheduler. 
This method should be called in pypots.optim.Optimizer.init_optimizer() + to initialize the scheduler together with the optimizer. + + Parameters + ---------- + optimizer: torch.optim.Optimizer, + The optimizer to be scheduled. + + """ + + # Attach optimizer + if not isinstance(optimizer, Optimizer): + raise TypeError("{} is not an Optimizer".format(type(optimizer).__name__)) + self.optimizer = optimizer + + # Initialize epoch and base learning rates + if self.last_epoch == -1: + for group in optimizer.param_groups: + group.setdefault("initial_lr", group["lr"]) + else: + for i, group in enumerate(optimizer.param_groups): + if "initial_lr" not in group: + raise KeyError( + "param 'initial_lr' is not specified " + "in param_groups[{}] when resuming an optimizer".format(i) + ) + self.base_lrs = [group["initial_lr"] for group in optimizer.param_groups] + + # Following https://github.com/pytorch/pytorch/issues/20124 + # We would like to ensure that `lr_scheduler.step()` is called after + # `optimizer.step()` + def with_counter(method): + if getattr(method, "_with_counter", False): + # `optimizer.step()` has already been replaced, return. + return method + + # Keep a weak reference to the optimizer instance to prevent + # cyclic references. + instance_ref = weakref.ref(method.__self__) + # Get the unbound method for the same purpose. + func = method.__func__ + cls = instance_ref().__class__ + del method + + @wraps(func) + def wrapper(*args, **kwargs): + instance = instance_ref() + instance._step_count += 1 + wrapped = func.__get__(instance, cls) + return wrapped(*args, **kwargs) + + # Note that the returned function here is no longer a bound method, + # so attributes like `__func__` and `__self__` no longer exist. + wrapper._with_counter = True + return wrapper + + self.optimizer.step = with_counter(self.optimizer.step) + self.optimizer._step_count = 0 + + @abstractmethod + def get_lr(self): + """Compute learning rate.""" + # Compute learning rate using chainable form of the scheduler + raise NotImplementedError + + def get_last_lr(self): + """Return last computed learning rate by current scheduler.""" + return self._last_lr + + @staticmethod + def print_lr(is_verbose, group, lr): + """Display the current learning rate.""" + if is_verbose: + logger.info(f"Adjusting learning rate of group {group} to {lr:.4e}.") + + def step(self): + # Raise a warning if old pattern is detected + # https://github.com/pytorch/pytorch/issues/20124 + if self._step_count == 1: + if not hasattr(self.optimizer.step, "_with_counter"): + logger.warning( + "Seems like `optimizer.step()` has been overridden after learning rate scheduler " + "initialization. Please, make sure to call `optimizer.step()` before " + "`lr_scheduler.step()`. See more details at " + "https://pytorch.org/docs/stable/optim.html#how-to-adjust-learning-rate", + ) + + # Just check if there were two first lr_scheduler.step() calls before optimizer.step() + elif self.optimizer._step_count < 1: + logger.warning.warn( + "Detected call of `lr_scheduler.step()` before `optimizer.step()`. " + "In PyTorch 1.1.0 and later, you should call them in the opposite order: " + "`optimizer.step()` before `lr_scheduler.step()`. Failure to do this " + "will result in PyTorch skipping the first value of the learning rate schedule. 
" + "See more details at " + "https://pytorch.org/docs/stable/optim.html#how-to-adjust-learning-rate", + ) + self._step_count += 1 + + class _enable_get_lr_call: + def __init__(self, o): + self.o = o + + def __enter__(self): + self.o._get_lr_called_within_step = True + return self + + def __exit__(self, type, value, traceback): + self.o._get_lr_called_within_step = False + + with _enable_get_lr_call(self): + self.last_epoch += 1 + values = self.get_lr() + + for i, data in enumerate(zip(self.optimizer.param_groups, values)): + param_group, lr = data + param_group["lr"] = lr + self.print_lr(self.verbose, i, lr) + + self._last_lr = [group["lr"] for group in self.optimizer.param_groups] diff --git a/pypots/optim/lr_scheduler/constant_lrs.py b/pypots/optim/lr_scheduler/constant_lrs.py new file mode 100644 index 00000000..3f6ae1a3 --- /dev/null +++ b/pypots/optim/lr_scheduler/constant_lrs.py @@ -0,0 +1,84 @@ +""" +Constant learning rate scheduler. +""" + +# Created by Wenjie Du +# License: GLP-v3 + +from .base import LRScheduler, logger + + +class ConstantLR(LRScheduler): + """Decays the learning rate of each parameter group by a small constant factor until the number of epoch reaches + a pre-defined milestone: total_iters. Notice that such decay can happen simultaneously with other changes + to the learning rate from outside this scheduler. When last_epoch=-1, sets initial lr as lr. + + Parameters + ---------- + factor: float, default=1./3. + The number we multiply learning rate until the milestone. + + total_iters: int, default=5, + The number of steps that the scheduler decays the learning rate. + + last_epoch: int + The index of last epoch. Default: -1. + + verbose: bool + If ``True``, prints a message to stdout for each update. Default: ``False``. + + Notes + ----- + This class works the same with ``torch.optim.lr_scheduler.ConstantLR``. + The only difference that is also why we implement them is that you don't have to pass according optimizers + into them immediately while initializing them. + + Example + ------- + >>> # Assuming optimizer uses lr = 0.05 for all groups + >>> # lr = 0.025 if epoch == 0 + >>> # lr = 0.025 if epoch == 1 + >>> # lr = 0.025 if epoch == 2 + >>> # lr = 0.025 if epoch == 3 + >>> # lr = 0.05 if epoch >= 4 + >>> # xdoctest: +SKIP + >>> scheduler = ConstantLR(factor=0.5, total_iters=4) + >>> adam = pypots.optim.Adam(lr=1e-3, lr_scheduler=scheduler) + + """ + + def __init__(self, factor=1.0 / 3, total_iters=5, last_epoch=-1, verbose=False): + super().__init__(last_epoch, verbose) + if factor > 1.0 or factor < 0: + raise ValueError( + "Constant multiplicative factor expected to be between 0 and 1." 
+ ) + + self.factor = factor + self.total_iters = total_iters + + def get_lr(self): + if not self._get_lr_called_within_step: + logger.warning( + "To get the last learning rate computed by the scheduler, " + "please use `get_last_lr()`.", + ) + + if self.last_epoch == 0: + return [group["lr"] * self.factor for group in self.optimizer.param_groups] + + if self.last_epoch > self.total_iters or (self.last_epoch != self.total_iters): + return [group["lr"] for group in self.optimizer.param_groups] + + if self.last_epoch == self.total_iters: + return [ + group["lr"] * (1.0 / self.factor) + for group in self.optimizer.param_groups + ] + + def _get_closed_form_lr(self): + return [ + base_lr + * (self.factor + (self.last_epoch >= self.total_iters) * (1 - self.factor)) + for base_lr in self.base_lrs + ] diff --git a/pypots/optim/lr_scheduler/exponential_lrs.py b/pypots/optim/lr_scheduler/exponential_lrs.py new file mode 100644 index 00000000..ed7e960f --- /dev/null +++ b/pypots/optim/lr_scheduler/exponential_lrs.py @@ -0,0 +1,55 @@ +""" +Exponential learning rate scheduler. +""" + +# Created by Wenjie Du +# License: GLP-v3 + +from .base import LRScheduler, logger + + +class ExponentialLR(LRScheduler): + """Decays the learning rate of each parameter group by gamma every epoch. When last_epoch=-1, sets initial lr as lr. + + Parameters + ---------- + gamma: float, + Multiplicative factor of learning rate decay. + + last_epoch: int + The index of last epoch. Default: -1. + + verbose: bool + If ``True``, prints a message to stdout for each update. Default: ``False``. + + Notes + ----- + This class works the same with ``torch.optim.lr_scheduler.ExponentialLR``. + The only difference that is also why we implement them is that you don't have to pass according optimizers + into them immediately while initializing them. + + Example + ------- + >>> scheduler = ExponentialLR(gamma=0.1) + >>> adam = pypots.optim.Adam(lr=1e-3, lr_scheduler=scheduler) + + """ + + def __init__(self, gamma, last_epoch=-1, verbose=False): + super().__init__(last_epoch, verbose) + + self.gamma = gamma + + def get_lr(self): + if not self._get_lr_called_within_step: + logger.warning( + "To get the last learning rate computed by the scheduler, " + "please use `get_last_lr()`.", + ) + + if self.last_epoch == 0: + return [group["lr"] for group in self.optimizer.param_groups] + return [group["lr"] * self.gamma for group in self.optimizer.param_groups] + + def _get_closed_form_lr(self): + return [base_lr * self.gamma**self.last_epoch for base_lr in self.base_lrs] diff --git a/pypots/optim/lr_scheduler/lambda_lrs.py b/pypots/optim/lr_scheduler/lambda_lrs.py new file mode 100644 index 00000000..5471cee6 --- /dev/null +++ b/pypots/optim/lr_scheduler/lambda_lrs.py @@ -0,0 +1,79 @@ +""" +Lambda learning rate scheduler. +""" + +# Created by Wenjie Du +# License: GLP-v3 + +from typing import Callable, Union + +from .base import LRScheduler, logger + + +class LambdaLR(LRScheduler): + """Sets the learning rate of each parameter group to the initial lr times a given function. + When last_epoch=-1, sets initial lr as lr. + + Parameters + ---------- + lr_lambda: Callable or list, + A function which computes a multiplicative factor given an integer parameter epoch, or a list of such + functions, one for each group in optimizer.param_groups. + + last_epoch: int, + The index of last epoch. Default: -1. + + verbose: bool, + If ``True``, prints a message to stdout for each update. Default: ``False``. 
+ + Notes + ----- + This class works the same with ``torch.optim.lr_scheduler.LambdaLR``. + The only difference that is also why we implement them is that you don't have to pass according optimizers + into them immediately while initializing them. + + Example + ------- + >>> lambda1 = lambda epoch: epoch // 30 + >>> scheduler = LambdaLR(lr_lambda=lambda1) + >>> adam = pypots.optim.Adam(lr=1e-3, lr_scheduler=scheduler) + + """ + + def __init__( + self, + lr_lambda: Union[Callable, list], + last_epoch: int = -1, + verbose: bool = False, + ): + super().__init__(last_epoch, verbose) + self.lr_lambda = lr_lambda + self.lr_lambdas = None + + def init_scheduler(self, optimizer): + if not isinstance(self.lr_lambda, list) and not isinstance( + self.lr_lambda, tuple + ): + self.lr_lambdas = [self.lr_lambda] * len(optimizer.param_groups) + else: + if len(self.lr_lambda) != len(optimizer.param_groups): + raise ValueError( + "Expected {} lr_lambdas, but got {}".format( + len(optimizer.param_groups), len(self.lr_lambda) + ) + ) + self.lr_lambdas = list(self.lr_lambda) + + super().init_scheduler(optimizer) + + def get_lr(self): + if not self._get_lr_called_within_step: + logger.warning( + "To get the last learning rate computed by the scheduler, " + "please use `get_last_lr()`." + ) + + return [ + base_lr * lmbda(self.last_epoch) + for lmbda, base_lr in zip(self.lr_lambdas, self.base_lrs) + ] diff --git a/pypots/optim/lr_scheduler/linear_lrs.py b/pypots/optim/lr_scheduler/linear_lrs.py new file mode 100644 index 00000000..a1e8e1e6 --- /dev/null +++ b/pypots/optim/lr_scheduler/linear_lrs.py @@ -0,0 +1,115 @@ +""" +Linear learning rate scheduler. +""" + +# Created by Wenjie Du +# License: GLP-v3 + +from .base import LRScheduler, logger + + +class LinearLR(LRScheduler): + """Decays the learning rate of each parameter group by linearly changing small multiplicative factor until + the number of epoch reaches a pre-defined milestone: total_iters. Notice that such decay can happen simultaneously + with other changes to the learning rate from outside this scheduler. When last_epoch=-1, sets initial lr as lr. + + Parameters + ---------- + start_factor: float, default=1.0 / 3, + The number we multiply learning rate in the first epoch. The multiplication factor changes towards + end_factor in the following epochs. + + end_factor: float, default=1.0, + The number we multiply learning rate at the end of linear changing process. + + total_iters: int, default=5, + The number of iterations that multiplicative factor reaches to 1. + + last_epoch: int + The index of last epoch. Default: -1. + + verbose: bool + If ``True``, prints a message to stdout for each update. Default: ``False``. + + Notes + ----- + This class works the same with ``torch.optim.lr_scheduler.LinearLR``. + The only difference that is also why we implement them is that you don't have to pass according optimizers + into them immediately while initializing them. 
+ + Example + ------- + >>> # Assuming optimizer uses lr = 0.05 for all groups + >>> # lr = 0.025 if epoch == 0 + >>> # lr = 0.03125 if epoch == 1 + >>> # lr = 0.0375 if epoch == 2 + >>> # lr = 0.04375 if epoch == 3 + >>> # lr = 0.05 if epoch >= 4 + >>> # xdoctest: +SKIP + >>> scheduler = LinearLR(start_factor=0.5, total_iters=4) + >>> adam = pypots.optim.Adam(lr=1e-3, lr_scheduler=scheduler) + + """ + + def __init__( + self, + start_factor=1.0 / 3, + end_factor=1.0, + total_iters=5, + last_epoch=-1, + verbose=False, + ): + super().__init__(last_epoch, verbose) + if start_factor > 1.0 or start_factor < 0: + raise ValueError( + "Starting multiplicative factor expected to be between 0 and 1." + ) + + if end_factor > 1.0 or end_factor < 0: + raise ValueError( + "Ending multiplicative factor expected to be between 0 and 1." + ) + + self.start_factor = start_factor + self.end_factor = end_factor + self.total_iters = total_iters + + def get_lr(self): + if not self._get_lr_called_within_step: + logger.warning( + "To get the last learning rate computed by the scheduler, " + "please use `get_last_lr()`.", + ) + + if self.last_epoch == 0: + return [ + group["lr"] * self.start_factor for group in self.optimizer.param_groups + ] + + if self.last_epoch > self.total_iters: + return [group["lr"] for group in self.optimizer.param_groups] + + return [ + group["lr"] + * ( + 1.0 + + (self.end_factor - self.start_factor) + / ( + self.total_iters * self.start_factor + + (self.last_epoch - 1) * (self.end_factor - self.start_factor) + ) + ) + for group in self.optimizer.param_groups + ] + + def _get_closed_form_lr(self): + return [ + base_lr + * ( + self.start_factor + + (self.end_factor - self.start_factor) + * min(self.total_iters, self.last_epoch) + / self.total_iters + ) + for base_lr in self.base_lrs + ] diff --git a/pypots/optim/lr_scheduler/multiplicative_lrs.py b/pypots/optim/lr_scheduler/multiplicative_lrs.py new file mode 100644 index 00000000..5dbc18ea --- /dev/null +++ b/pypots/optim/lr_scheduler/multiplicative_lrs.py @@ -0,0 +1,77 @@ +""" +Multiplicative learning rate scheduler. +""" + +# Created by Wenjie Du +# License: GLP-v3 + + +from .base import LRScheduler, logger + + +class MultiplicativeLR(LRScheduler): + """Multiply the learning rate of each parameter group by the factor given in the specified function. + When last_epoch=-1, sets initial lr as lr. + + Parameters + ---------- + lr_lambda: Callable or list, + A function which computes a multiplicative factor given an integer parameter epoch, or a list of such + functions, one for each group in optimizer.param_groups. + + last_epoch: int, + The index of last epoch. Default: -1. + + verbose: bool, + If ``True``, prints a message to stdout for each update. Default: ``False``. + + Notes + ----- + This class works the same with ``torch.optim.lr_scheduler.MultiplicativeLR``. + The only difference that is also why we implement them is that you don't have to pass according optimizers + into them immediately while initializing them. 
+ + Example + ------- + >>> lmbda = lambda epoch: 0.95 + >>> # xdoctest: +SKIP + >>> scheduler = MultiplicativeLR(lr_lambda=lmbda) + >>> adam = pypots.optim.Adam(lr=1e-3, lr_scheduler=scheduler) + + """ + + def __init__(self, lr_lambda, last_epoch=-1, verbose=False): + super().__init__(last_epoch, verbose) + self.lr_lambda = lr_lambda + self.lr_lambdas = None + + def init_scheduler(self, optimizer): + if not isinstance(self.lr_lambda, list) and not isinstance( + self.lr_lambda, tuple + ): + self.lr_lambdas = [self.lr_lambda] * len(optimizer.param_groups) + else: + if len(self.lr_lambda) != len(optimizer.param_groups): + raise ValueError( + "Expected {} lr_lambdas, but got {}".format( + len(optimizer.param_groups), len(self.lr_lambda) + ) + ) + self.lr_lambdas = list(self.lr_lambda) + + super().init_scheduler(optimizer) + + def get_lr(self): + if not self._get_lr_called_within_step: + logger.warning( + "To get the last learning rate computed by the scheduler, " + "please use `get_last_lr()`.", + ) + + if self.last_epoch > 0: + return [ + group["lr"] * lmbda(self.last_epoch) + for lmbda, group in zip(self.lr_lambdas, self.optimizer.param_groups) + ] + else: + return [group["lr"] for group in self.optimizer.param_groups] diff --git a/pypots/optim/lr_scheduler/multistep_lrs.py b/pypots/optim/lr_scheduler/multistep_lrs.py new file mode 100644 index 00000000..567570e9 --- /dev/null +++ b/pypots/optim/lr_scheduler/multistep_lrs.py @@ -0,0 +1,75 @@ +""" +Multistep learning rate scheduler. +""" + +# Created by Wenjie Du +# License: GLP-v3 + +from bisect import bisect_right +from collections import Counter + +from .base import LRScheduler, logger + + +class MultiStepLR(LRScheduler): + """Decays the learning rate of each parameter group by gamma once the number of epoch reaches one of the milestones. + Notice that such decay can happen simultaneously with other changes to the learning rate from outside this + scheduler. When last_epoch=-1, sets initial lr as lr. + + Parameters + ---------- + milestones: list, + List of epoch indices. Must be increasing. + + gamma: float, default=0.1, + Multiplicative factor of learning rate decay. + + last_epoch: int + The index of last epoch. Default: -1. + + verbose: bool + If ``True``, prints a message to stdout for each update. Default: ``False``. + + Notes + ----- + This class works the same with ``torch.optim.lr_scheduler.MultiStepLR``. + The only difference that is also why we implement them is that you don't have to pass according optimizers + into them immediately while initializing them. 
+ + Example + ------- + >>> # Assuming optimizer uses lr = 0.05 for all groups + >>> # lr = 0.05 if epoch < 30 + >>> # lr = 0.005 if 30 <= epoch < 80 + >>> # lr = 0.0005 if epoch >= 80 + >>> # xdoctest: +SKIP + >>> scheduler = MultiStepLR(milestones=[30,80], gamma=0.1) + >>> adam = pypots.optim.Adam(lr=1e-3, lr_scheduler=scheduler) + + """ + + def __init__(self, milestones, gamma=0.1, last_epoch=-1, verbose=False): + super().__init__(last_epoch, verbose) + self.milestones = Counter(milestones) + self.gamma = gamma + + def get_lr(self): + if not self._get_lr_called_within_step: + logger.warning( + "To get the last learning rate computed by the scheduler, " + "please use `get_last_lr()`.", + ) + + if self.last_epoch not in self.milestones: + return [group["lr"] for group in self.optimizer.param_groups] + return [ + group["lr"] * self.gamma ** self.milestones[self.last_epoch] + for group in self.optimizer.param_groups + ] + + def _get_closed_form_lr(self): + milestones = list(sorted(self.milestones.elements())) + return [ + base_lr * self.gamma ** bisect_right(milestones, self.last_epoch) + for base_lr in self.base_lrs + ] diff --git a/pypots/optim/lr_scheduler/step_lrs.py b/pypots/optim/lr_scheduler/step_lrs.py new file mode 100644 index 00000000..29f72bb8 --- /dev/null +++ b/pypots/optim/lr_scheduler/step_lrs.py @@ -0,0 +1,70 @@ +""" +Step learning rate scheduler. +""" + +# Created by Wenjie Du +# License: GLP-v3 + +from .base import LRScheduler, logger + + +class StepLR(LRScheduler): + """Decays the learning rate of each parameter group by gamma every step_size epochs. Notice that such decay can + happen simultaneously with other changes to the learning rate from outside this scheduler. + When last_epoch=-1, sets initial lr as lr. + + Parameters + ---------- + step_size: int, + Period of learning rate decay. + + gamma: float, default=0.1, + Multiplicative factor of learning rate decay. + + last_epoch: int + The index of last epoch. Default: -1. + + verbose: bool + If ``True``, prints a message to stdout for each update. Default: ``False``. + + Notes + ----- + This class works the same with ``torch.optim.lr_scheduler.StepLR``. + The only difference that is also why we implement them is that you don't have to pass according optimizers + into them immediately while initializing them. + + Example + ------- + >>> # Assuming optimizer uses lr = 0.05 for all groups + >>> # lr = 0.05 if epoch < 30 + >>> # lr = 0.005 if 30 <= epoch < 60 + >>> # lr = 0.0005 if 60 <= epoch < 90 + >>> # ... 
+ >>> # xdoctest: +SKIP + >>> scheduler = StepLR(step_size=30, gamma=0.1) + >>> adam = pypots.optim.Adam(lr=1e-3, lr_scheduler=scheduler) + + """ + + def __init__(self, step_size, gamma=0.1, last_epoch=-1, verbose=False): + super().__init__(last_epoch, verbose) + + self.step_size = step_size + self.gamma = gamma + + def get_lr(self): + if not self._get_lr_called_within_step: + logger.warning( + "To get the last learning rate computed by the scheduler, " + "please use `get_last_lr()`.", + ) + + if (self.last_epoch == 0) or (self.last_epoch % self.step_size != 0): + return [group["lr"] for group in self.optimizer.param_groups] + return [group["lr"] * self.gamma for group in self.optimizer.param_groups] + + def _get_closed_form_lr(self): + return [ + base_lr * self.gamma ** (self.last_epoch // self.step_size) + for base_lr in self.base_lrs + ] From dd626b312bad9fae2f00171b6e10bb1124de3b28 Mon Sep 17 00:00:00 2001 From: Wenjie Du Date: Wed, 27 Sep 2023 19:49:25 +0800 Subject: [PATCH 2/4] feat: integrate schedulers into optimizers; --- pypots/optim/adadelta.py | 9 +++++++-- pypots/optim/adagrad.py | 9 +++++++-- pypots/optim/adam.py | 9 +++++++-- pypots/optim/adamw.py | 9 +++++++-- pypots/optim/base.py | 8 +++++++- pypots/optim/rmsprop.py | 9 +++++++-- pypots/optim/sgd.py | 9 +++++++-- 7 files changed, 49 insertions(+), 13 deletions(-) diff --git a/pypots/optim/adadelta.py b/pypots/optim/adadelta.py index ac4726d3..59e98f2a 100644 --- a/pypots/optim/adadelta.py +++ b/pypots/optim/adadelta.py @@ -6,11 +6,12 @@ # Created by Wenjie Du # License: GLP-v3 -from typing import Iterable +from typing import Iterable, Optional from torch.optim import Adadelta as torch_Adadelta from .base import Optimizer +from .lr_scheduler.base import LRScheduler class Adadelta(Optimizer): @@ -39,8 +40,9 @@ def __init__( rho: float = 0.9, eps: float = 1e-08, weight_decay: float = 0.01, + lr_scheduler: Optional[LRScheduler] = None, ): - super().__init__(lr) + super().__init__(lr, lr_scheduler) self.rho = rho self.eps = eps self.weight_decay = weight_decay @@ -61,3 +63,6 @@ def init_optimizer(self, params: Iterable) -> None: eps=self.eps, weight_decay=self.weight_decay, ) + + if self.lr_scheduler is not None: + self.lr_scheduler.init_scheduler(self.torch_optimizer) diff --git a/pypots/optim/adagrad.py b/pypots/optim/adagrad.py index e4374244..8a10f06c 100644 --- a/pypots/optim/adagrad.py +++ b/pypots/optim/adagrad.py @@ -6,11 +6,12 @@ # Created by Wenjie Du # License: GLP-v3 -from typing import Iterable +from typing import Iterable, Optional from torch.optim import Adagrad as torch_Adagrad from .base import Optimizer +from .lr_scheduler.base import LRScheduler class Adagrad(Optimizer): @@ -43,8 +44,9 @@ def __init__( weight_decay: float = 0.01, initial_accumulator_value: float = 0.01, # it is set as 0 in the torch implementation, but delta shouldn't be 0 eps: float = 1e-08, + lr_scheduler: Optional[LRScheduler] = None, ): - super().__init__(lr) + super().__init__(lr, lr_scheduler) self.lr_decay = lr_decay self.weight_decay = weight_decay self.initial_accumulator_value = initial_accumulator_value @@ -67,3 +69,6 @@ def init_optimizer(self, params: Iterable) -> None: initial_accumulator_value=self.initial_accumulator_value, eps=self.eps, ) + + if self.lr_scheduler is not None: + self.lr_scheduler.init_scheduler(self.torch_optimizer) diff --git a/pypots/optim/adam.py b/pypots/optim/adam.py index d308b27e..c5e0e1af 100644 --- a/pypots/optim/adam.py +++ b/pypots/optim/adam.py @@ -6,11 +6,12 @@ # Created by Wenjie Du # License: 
GLP-v3 -from typing import Iterable, Tuple +from typing import Iterable, Tuple, Optional from torch.optim import Adam as torch_Adam from .base import Optimizer +from .lr_scheduler.base import LRScheduler class Adam(Optimizer): @@ -42,8 +43,9 @@ def __init__( eps: float = 1e-08, weight_decay: float = 0, amsgrad: bool = False, + lr_scheduler: Optional[LRScheduler] = None, ): - super().__init__(lr) + super().__init__(lr, lr_scheduler) self.betas = betas self.eps = eps self.weight_decay = weight_decay @@ -66,3 +68,6 @@ def init_optimizer(self, params: Iterable) -> None: weight_decay=self.weight_decay, amsgrad=self.amsgrad, ) + + if self.lr_scheduler is not None: + self.lr_scheduler.init_scheduler(self.torch_optimizer) diff --git a/pypots/optim/adamw.py b/pypots/optim/adamw.py index b93d8a74..6a5191e4 100644 --- a/pypots/optim/adamw.py +++ b/pypots/optim/adamw.py @@ -6,11 +6,12 @@ # Created by Wenjie Du # License: GLP-v3 -from typing import Iterable, Tuple +from typing import Iterable, Tuple, Optional from torch.optim import AdamW as torch_AdamW from .base import Optimizer +from .lr_scheduler.base import LRScheduler class AdamW(Optimizer): @@ -42,8 +43,9 @@ def __init__( eps: float = 1e-08, weight_decay: float = 0.01, amsgrad: bool = False, + lr_scheduler: Optional[LRScheduler] = None, ): - super().__init__(lr) + super().__init__(lr, lr_scheduler) self.betas = betas self.eps = eps self.weight_decay = weight_decay @@ -66,3 +68,6 @@ def init_optimizer(self, params: Iterable) -> None: weight_decay=self.weight_decay, amsgrad=self.amsgrad, ) + + if self.lr_scheduler is not None: + self.lr_scheduler.init_scheduler(self.torch_optimizer) diff --git a/pypots/optim/base.py b/pypots/optim/base.py index f1bb9637..db09fb3a 100644 --- a/pypots/optim/base.py +++ b/pypots/optim/base.py @@ -19,6 +19,8 @@ from abc import ABC, abstractmethod from typing import Callable, Iterable, Optional +from .lr_scheduler.base import LRScheduler + class Optimizer(ABC): """The base wrapper for PyTorch optimizers, also is the base class for all optimizers in pypots.optim. @@ -35,9 +37,10 @@ class Optimizer(ABC): """ - def __init__(self, lr): + def __init__(self, lr, lr_scheduler: Optional[LRScheduler] = None): self.lr = lr self.torch_optimizer = None + self.lr_scheduler = lr_scheduler @abstractmethod def init_optimizer(self, params: Iterable) -> None: @@ -97,6 +100,9 @@ def step(self, closure: Optional[Callable] = None) -> None: """ self.torch_optimizer.step(closure) + if self.lr_scheduler is not None: + self.lr_scheduler.step() + def zero_grad(self, set_to_none: bool = True) -> None: """Sets the gradients of all optimized ``torch.Tensor`` to zero. 
diff --git a/pypots/optim/rmsprop.py b/pypots/optim/rmsprop.py index 65a817ca..f00da68d 100644 --- a/pypots/optim/rmsprop.py +++ b/pypots/optim/rmsprop.py @@ -6,11 +6,12 @@ # Created by Wenjie Du # License: GLP-v3 -from typing import Iterable +from typing import Iterable, Optional from torch.optim import RMSprop as torch_RMSprop from .base import Optimizer +from .lr_scheduler.base import LRScheduler class RMSprop(Optimizer): @@ -47,8 +48,9 @@ def __init__( eps: float = 1e-08, centered: bool = False, weight_decay: float = 0, + lr_scheduler: Optional[LRScheduler] = None, ): - super().__init__(lr) + super().__init__(lr, lr_scheduler) self.momentum = momentum self.alpha = alpha self.eps = eps @@ -73,3 +75,6 @@ def init_optimizer(self, params: Iterable) -> None: centered=self.centered, weight_decay=self.weight_decay, ) + + if self.lr_scheduler is not None: + self.lr_scheduler.init_scheduler(self.torch_optimizer) diff --git a/pypots/optim/sgd.py b/pypots/optim/sgd.py index 4696db91..34cd07f0 100644 --- a/pypots/optim/sgd.py +++ b/pypots/optim/sgd.py @@ -6,11 +6,12 @@ # Created by Wenjie Du # License: GLP-v3 -from typing import Iterable +from typing import Iterable, Optional from torch.optim import SGD as torch_SGD from .base import Optimizer +from .lr_scheduler.base import LRScheduler class SGD(Optimizer): @@ -43,8 +44,9 @@ def __init__( weight_decay: float = 0, dampening: float = 0, nesterov: bool = False, + lr_scheduler: Optional[LRScheduler] = None, ): - super().__init__(lr) + super().__init__(lr, lr_scheduler) self.momentum = momentum self.weight_decay = weight_decay self.dampening = dampening @@ -67,3 +69,6 @@ def init_optimizer(self, params: Iterable) -> None: dampening=self.dampening, nesterov=self.nesterov, ) + + if self.lr_scheduler is not None: + self.lr_scheduler.init_scheduler(self.torch_optimizer) From f7fc32c949f89a1d91985dd7187eb4056395bf38 Mon Sep 17 00:00:00 2001 From: Wenjie Du Date: Wed, 27 Sep 2023 20:17:12 +0800 Subject: [PATCH 3/4] feat: add testing cases of learn rate schedulers; --- tests/optim/lr_schedulers.py | 249 +++++++++++++++++++++++++++++++++++ 1 file changed, 249 insertions(+) create mode 100644 tests/optim/lr_schedulers.py diff --git a/tests/optim/lr_schedulers.py b/tests/optim/lr_schedulers.py new file mode 100644 index 00000000..e7748f91 --- /dev/null +++ b/tests/optim/lr_schedulers.py @@ -0,0 +1,249 @@ +""" +Test cases for the learning rate schedulers. 
+""" + +# Created by Wenjie Du +# License: GLP-v3 + +import unittest + +import numpy as np +import pytest + +from pypots.imputation import SAITS +from pypots.optim import Adam, AdamW, Adadelta, Adagrad, RMSprop, SGD +from pypots.optim.lr_scheduler import ( + LambdaLR, + ConstantLR, + ExponentialLR, + LinearLR, + StepLR, + MultiStepLR, + MultiplicativeLR, +) +from pypots.utils.logging import logger +from pypots.utils.metrics import cal_mae +from tests.global_test_config import DATA +from tests.optim.config import EPOCHS, TEST_SET, TRAIN_SET, VAL_SET + + +class TestLRSchedulers(unittest.TestCase): + logger.info("Running tests for learning rate schedulers...") + + # init lambda_lrs + lambda_lrs = LambdaLR(lr_lambda=lambda epoch: epoch // 30, verbose=True) + + # init multiplicative_lrs + multiplicative_lrs = MultiplicativeLR(lr_lambda=lambda epoch: 0.95, verbose=True) + + # init step_lrs + step_lrs = StepLR(step_size=30, gamma=0.1, verbose=True) + + # init multistep_lrs + multistep_lrs = MultiStepLR(milestones=[30, 80], gamma=0.1, verbose=True) + + # init constant_lrs + constant_lrs = ConstantLR(factor=0.5, total_iters=4, verbose=True) + + # init linear_lrs + linear_lrs = LinearLR(start_factor=0.5, total_iters=4, verbose=True) + + # init exponential_lrs + exponential_lrs = ExponentialLR(gamma=0.9, verbose=True) + + @pytest.mark.xdist_group(name="lrs-lambda") + def test_0_lambda_lrs(self): + logger.info("Running tests for Adam + LambdaLRS...") + + adam = Adam(lr=0.001, weight_decay=1e-5, lr_scheduler=self.lambda_lrs) + saits = SAITS( + DATA["n_steps"], + DATA["n_features"], + n_layers=1, + d_model=128, + d_inner=64, + n_heads=2, + d_k=64, + d_v=64, + dropout=0.1, + optimizer=adam, + epochs=EPOCHS, + ) + saits.fit(TRAIN_SET, VAL_SET) + imputed_X = saits.impute(TEST_SET) + assert not np.isnan( + imputed_X + ).any(), "Output still has missing values after running impute()." + test_MAE = cal_mae( + imputed_X, DATA["test_X_intact"], DATA["test_X_indicating_mask"] + ) + logger.info(f"SAITS test_MAE: {test_MAE}") + + @pytest.mark.xdist_group(name="lrs-multiplicative") + def test_1_multiplicative_lrs(self): + logger.info("Running tests for Adamw + MultiplicativeLRS...") + + adamw = AdamW(lr=0.001, weight_decay=1e-5, lr_scheduler=self.multiplicative_lrs) + saits = SAITS( + DATA["n_steps"], + DATA["n_features"], + n_layers=1, + d_model=128, + d_inner=64, + n_heads=2, + d_k=64, + d_v=64, + dropout=0.1, + optimizer=adamw, + epochs=EPOCHS, + ) + saits.fit(TRAIN_SET, VAL_SET) + imputed_X = saits.impute(TEST_SET) + assert not np.isnan( + imputed_X + ).any(), "Output still has missing values after running impute()." + test_MAE = cal_mae( + imputed_X, DATA["test_X_intact"], DATA["test_X_indicating_mask"] + ) + logger.info(f"SAITS test_MAE: {test_MAE}") + + @pytest.mark.xdist_group(name="lrs-step") + def test_2_step_lrs(self): + logger.info("Running tests for Adadelta + StepLRS...") + + adamw = Adadelta(lr=0.001, lr_scheduler=self.step_lrs) + saits = SAITS( + DATA["n_steps"], + DATA["n_features"], + n_layers=1, + d_model=128, + d_inner=64, + n_heads=2, + d_k=64, + d_v=64, + dropout=0.1, + optimizer=adamw, + epochs=EPOCHS, + ) + saits.fit(TRAIN_SET, VAL_SET) + imputed_X = saits.impute(TEST_SET) + assert not np.isnan( + imputed_X + ).any(), "Output still has missing values after running impute()." 
+ test_MAE = cal_mae( + imputed_X, DATA["test_X_intact"], DATA["test_X_indicating_mask"] + ) + logger.info(f"SAITS test_MAE: {test_MAE}") + + @pytest.mark.xdist_group(name="lrs-multistep") + def test_3_multistep_lrs(self): + logger.info("Running tests for Adadelta + MultiStepLRS...") + + adagrad = Adagrad(lr=0.001, lr_scheduler=self.multistep_lrs) + saits = SAITS( + DATA["n_steps"], + DATA["n_features"], + n_layers=1, + d_model=128, + d_inner=64, + n_heads=2, + d_k=64, + d_v=64, + dropout=0.1, + optimizer=adagrad, + epochs=EPOCHS, + ) + saits.fit(TRAIN_SET, VAL_SET) + imputed_X = saits.impute(TEST_SET) + assert not np.isnan( + imputed_X + ).any(), "Output still has missing values after running impute()." + test_MAE = cal_mae( + imputed_X, DATA["test_X_intact"], DATA["test_X_indicating_mask"] + ) + logger.info(f"SAITS test_MAE: {test_MAE}") + + @pytest.mark.xdist_group(name="lrs-constant") + def test_4_constant_lrs(self): + logger.info("Running tests for RMSprop + ConstantLRS...") + + # initialize a SAITS model for testing DatasetForMIT and BaseDataset + rmsprop = RMSprop(lr=0.001, lr_scheduler=self.constant_lrs) + saits = SAITS( + DATA["n_steps"], + DATA["n_features"], + n_layers=1, + d_model=128, + d_inner=64, + n_heads=2, + d_k=64, + d_v=64, + dropout=0.1, + optimizer=rmsprop, + epochs=EPOCHS, + ) + saits.fit(TRAIN_SET, VAL_SET) + imputed_X = saits.impute(TEST_SET) + assert not np.isnan( + imputed_X + ).any(), "Output still has missing values after running impute()." + test_MAE = cal_mae( + imputed_X, DATA["test_X_intact"], DATA["test_X_indicating_mask"] + ) + logger.info(f"SAITS test_MAE: {test_MAE}") + + @pytest.mark.xdist_group(name="lrs-linear") + def test_5_linear_lrs(self): + logger.info("Running tests for SGD + MultiStepLRS...") + + sgd = SGD(lr=0.001, lr_scheduler=self.linear_lrs) + saits = SAITS( + DATA["n_steps"], + DATA["n_features"], + n_layers=1, + d_model=128, + d_inner=64, + n_heads=2, + d_k=64, + d_v=64, + dropout=0.1, + optimizer=sgd, + epochs=EPOCHS, + ) + saits.fit(TRAIN_SET, VAL_SET) + imputed_X = saits.impute(TEST_SET) + assert not np.isnan( + imputed_X + ).any(), "Output still has missing values after running impute()." + test_MAE = cal_mae( + imputed_X, DATA["test_X_intact"], DATA["test_X_indicating_mask"] + ) + logger.info(f"SAITS test_MAE: {test_MAE}") + + @pytest.mark.xdist_group(name="lrs-exponential") + def test_6_exponential_lrs(self): + logger.info("Running tests for SGD + ExponentialLRS...") + + sgd = SGD(lr=0.001, lr_scheduler=self.exponential_lrs) + saits = SAITS( + DATA["n_steps"], + DATA["n_features"], + n_layers=1, + d_model=128, + d_inner=64, + n_heads=2, + d_k=64, + d_v=64, + dropout=0.1, + optimizer=sgd, + epochs=EPOCHS, + ) + saits.fit(TRAIN_SET, VAL_SET) + imputed_X = saits.impute(TEST_SET) + assert not np.isnan( + imputed_X + ).any(), "Output still has missing values after running impute()." 
+ test_MAE = cal_mae( + imputed_X, DATA["test_X_intact"], DATA["test_X_indicating_mask"] + ) + logger.info(f"SAITS test_MAE: {test_MAE}") From 1798ecf31eb056112dad3f2728e2126adf7eff16 Mon Sep 17 00:00:00 2001 From: Wenjie Du Date: Wed, 27 Sep 2023 20:18:10 +0800 Subject: [PATCH 4/4] docs: update docs for learning rate schedulers; --- docs/pypots.optim.rst | 9 +++++++++ pypots/optim/lr_scheduler/__init__.py | 2 +- pypots/optim/lr_scheduler/base.py | 3 +++ 3 files changed, 13 insertions(+), 1 deletion(-) diff --git a/docs/pypots.optim.rst b/docs/pypots.optim.rst index 8badeb1c..2bcc93f2 100644 --- a/docs/pypots.optim.rst +++ b/docs/pypots.optim.rst @@ -54,3 +54,12 @@ pypots.optim.base module :undoc-members: :show-inheritance: :inherited-members: + +pypots.optim.lr_scheduler module +------------------------------ + +.. automodule:: pypots.optim.lr_scheduler + :members: + :undoc-members: + :show-inheritance: + :inherited-members: diff --git a/pypots/optim/lr_scheduler/__init__.py b/pypots/optim/lr_scheduler/__init__.py index c0688a51..ddb14350 100644 --- a/pypots/optim/lr_scheduler/__init__.py +++ b/pypots/optim/lr_scheduler/__init__.py @@ -2,7 +2,7 @@ Learning rate schedulers available in PyPOTS. Their functionalities are the same with those in PyTorch, the only difference that is also why we implement them is that you don't have to pass according optimizers into them immediately while initializing them. Instead, you can pass them into pypots.optim.Optimizer -after initialization and call their `init_scheduler()` method in Optimizer.init_optimizer() to initialize +after initialization and call their `init_scheduler()` method in pypots.optim.Optimizer.init_optimizer() to initialize schedulers together with optimizers. """ diff --git a/pypots/optim/lr_scheduler/base.py b/pypots/optim/lr_scheduler/base.py index 3c5af3b7..0aeffd8b 100644 --- a/pypots/optim/lr_scheduler/base.py +++ b/pypots/optim/lr_scheduler/base.py @@ -113,6 +113,9 @@ def print_lr(is_verbose, group, lr): logger.info(f"Adjusting learning rate of group {group} to {lr:.4e}.") def step(self): + """Step could be called after every batch update. This should be called in ``pypots.optim.Optimizer.step()`` + after ``pypots.optim.Optimizer.torch_optimizer.step()``. + """ # Raise a warning if old pattern is detected # https://github.com/pytorch/pytorch/issues/20124 if self._step_count == 1:
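
The patches above only show the wiring; a short end-to-end sketch of how the pieces are meant to be used together may help. This is a minimal sketch assembled from the docstrings and tests in this series, not code from the patch itself: the class names and keyword arguments (MultiStepLR, pypots.optim.Adam(lr=..., lr_scheduler=...), init_optimizer(), step(), zero_grad(), get_last_lr()) all appear in the diffs, but the toy parameter list and squared-norm loss are hypothetical stand-ins for a real PyPOTS model such as the SAITS instances in tests/optim/lr_schedulers.py, which normally call init_optimizer() internally.

    import torch

    from pypots.optim import Adam
    from pypots.optim.lr_scheduler import MultiStepLR

    # Create the scheduler first, with no optimizer attached -- the point of this
    # re-implementation is that the optimizer is bound later via init_scheduler().
    scheduler = MultiStepLR(milestones=[2, 4], gamma=0.1, verbose=True)

    # Hand the scheduler to a PyPOTS optimizer wrapper; nothing is initialized yet.
    adam = Adam(lr=1e-3, weight_decay=1e-5, lr_scheduler=scheduler)

    # Hypothetical stand-in for model parameters; in practice a PyPOTS model
    # passes its own parameters in during training setup.
    params = [torch.nn.Parameter(torch.zeros(8))]

    # init_optimizer() builds the underlying torch.optim.Adam and, per PATCH 2/4,
    # calls scheduler.init_scheduler() on it.
    adam.init_optimizer(params)

    # Optimizer.step() steps the torch optimizer first and then the scheduler
    # (see pypots/optim/base.py), which is the order PyTorch expects.
    for epoch in range(5):
        adam.zero_grad()
        loss = (params[0] ** 2).sum()  # toy loss, only to produce gradients
        loss.backward()
        adam.step()
        print(epoch, scheduler.get_last_lr())

The deferred binding is the design choice motivating the whole series: a scheduler can be configured before any model parameters exist, mirroring how PyPOTS optimizers already defer init_optimizer(), and the optimizer wrapper then keeps the PyTorch-required ordering (optimizer.step() before lr_scheduler.step()) out of user code.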