From 586b68dc1c568ded0ffb19cf6bd7b1b34b93ace3 Mon Sep 17 00:00:00 2001 From: Wenjie Du Date: Wed, 27 Sep 2023 19:45:30 +0800 Subject: [PATCH 1/8] feat: add learning rate schedulers; --- pypots/optim/lr_scheduler/__init__.py | 29 ++++ pypots/optim/lr_scheduler/base.py | 159 ++++++++++++++++++ pypots/optim/lr_scheduler/constant_lrs.py | 84 +++++++++ pypots/optim/lr_scheduler/exponential_lrs.py | 55 ++++++ pypots/optim/lr_scheduler/lambda_lrs.py | 79 +++++++++ pypots/optim/lr_scheduler/linear_lrs.py | 115 +++++++++++++ .../optim/lr_scheduler/multiplicative_lrs.py | 77 +++++++++ pypots/optim/lr_scheduler/multistep_lrs.py | 75 +++++++++ pypots/optim/lr_scheduler/step_lrs.py | 70 ++++++++ 9 files changed, 743 insertions(+) create mode 100644 pypots/optim/lr_scheduler/__init__.py create mode 100644 pypots/optim/lr_scheduler/base.py create mode 100644 pypots/optim/lr_scheduler/constant_lrs.py create mode 100644 pypots/optim/lr_scheduler/exponential_lrs.py create mode 100644 pypots/optim/lr_scheduler/lambda_lrs.py create mode 100644 pypots/optim/lr_scheduler/linear_lrs.py create mode 100644 pypots/optim/lr_scheduler/multiplicative_lrs.py create mode 100644 pypots/optim/lr_scheduler/multistep_lrs.py create mode 100644 pypots/optim/lr_scheduler/step_lrs.py diff --git a/pypots/optim/lr_scheduler/__init__.py b/pypots/optim/lr_scheduler/__init__.py new file mode 100644 index 00000000..c0688a51 --- /dev/null +++ b/pypots/optim/lr_scheduler/__init__.py @@ -0,0 +1,29 @@ +""" +Learning rate schedulers available in PyPOTS. Their functionalities are the same with those in PyTorch, +the only difference that is also why we implement them is that you don't have to pass according optimizers +into them immediately while initializing them. Instead, you can pass them into pypots.optim.Optimizer +after initialization and call their `init_scheduler()` method in Optimizer.init_optimizer() to initialize +schedulers together with optimizers. +""" + +# Created by Wenjie Du +# License: GLP-v3 + +from .lambda_lrs import LambdaLR +from .multiplicative_lrs import MultiplicativeLR +from .step_lrs import StepLR +from .multistep_lrs import MultiStepLR +from .constant_lrs import ConstantLR +from .exponential_lrs import ExponentialLR +from .linear_lrs import LinearLR + + +__all__ = [ + "LambdaLR", + "MultiplicativeLR", + "StepLR", + "MultiStepLR", + "ConstantLR", + "ExponentialLR", + "LinearLR", +] diff --git a/pypots/optim/lr_scheduler/base.py b/pypots/optim/lr_scheduler/base.py new file mode 100644 index 00000000..3c5af3b7 --- /dev/null +++ b/pypots/optim/lr_scheduler/base.py @@ -0,0 +1,159 @@ +""" +The base class for learning rate schedulers. This class is adapted from PyTorch, +please refer to torch.optim.lr_scheduler for more details. +""" + +# Created by Wenjie Du +# License: GLP-v3 + +import weakref +from abc import ABC, abstractmethod +from functools import wraps + +from torch.optim import Optimizer + +from ...utils.logging import logger + + +class LRScheduler(ABC): + """Base class for PyPOTS learning rate schedulers. + + Parameters + ---------- + last_epoch: int + The index of last epoch. Default: -1. + + verbose: If ``True``, prints a message to stdout for + each update. Default: ``False``. + + """ + + def __init__(self, last_epoch=-1, verbose=False): + self.last_epoch = last_epoch + self.verbose = verbose + self.optimizer = None + self.base_lrs = None + self._last_lr = None + self._step_count = 0 + + def init_scheduler(self, optimizer): + """Initialize the scheduler. This method should be called in pypots.optim.Optimizer.init_optimizer() + to initialize the scheduler together with the optimizer. + + Parameters + ---------- + optimizer: torch.optim.Optimizer, + The optimizer to be scheduled. + + """ + + # Attach optimizer + if not isinstance(optimizer, Optimizer): + raise TypeError("{} is not an Optimizer".format(type(optimizer).__name__)) + self.optimizer = optimizer + + # Initialize epoch and base learning rates + if self.last_epoch == -1: + for group in optimizer.param_groups: + group.setdefault("initial_lr", group["lr"]) + else: + for i, group in enumerate(optimizer.param_groups): + if "initial_lr" not in group: + raise KeyError( + "param 'initial_lr' is not specified " + "in param_groups[{}] when resuming an optimizer".format(i) + ) + self.base_lrs = [group["initial_lr"] for group in optimizer.param_groups] + + # Following https://github.com/pytorch/pytorch/issues/20124 + # We would like to ensure that `lr_scheduler.step()` is called after + # `optimizer.step()` + def with_counter(method): + if getattr(method, "_with_counter", False): + # `optimizer.step()` has already been replaced, return. + return method + + # Keep a weak reference to the optimizer instance to prevent + # cyclic references. + instance_ref = weakref.ref(method.__self__) + # Get the unbound method for the same purpose. + func = method.__func__ + cls = instance_ref().__class__ + del method + + @wraps(func) + def wrapper(*args, **kwargs): + instance = instance_ref() + instance._step_count += 1 + wrapped = func.__get__(instance, cls) + return wrapped(*args, **kwargs) + + # Note that the returned function here is no longer a bound method, + # so attributes like `__func__` and `__self__` no longer exist. + wrapper._with_counter = True + return wrapper + + self.optimizer.step = with_counter(self.optimizer.step) + self.optimizer._step_count = 0 + + @abstractmethod + def get_lr(self): + """Compute learning rate.""" + # Compute learning rate using chainable form of the scheduler + raise NotImplementedError + + def get_last_lr(self): + """Return last computed learning rate by current scheduler.""" + return self._last_lr + + @staticmethod + def print_lr(is_verbose, group, lr): + """Display the current learning rate.""" + if is_verbose: + logger.info(f"Adjusting learning rate of group {group} to {lr:.4e}.") + + def step(self): + # Raise a warning if old pattern is detected + # https://github.com/pytorch/pytorch/issues/20124 + if self._step_count == 1: + if not hasattr(self.optimizer.step, "_with_counter"): + logger.warning( + "Seems like `optimizer.step()` has been overridden after learning rate scheduler " + "initialization. Please, make sure to call `optimizer.step()` before " + "`lr_scheduler.step()`. See more details at " + "https://pytorch.org/docs/stable/optim.html#how-to-adjust-learning-rate", + ) + + # Just check if there were two first lr_scheduler.step() calls before optimizer.step() + elif self.optimizer._step_count < 1: + logger.warning.warn( + "Detected call of `lr_scheduler.step()` before `optimizer.step()`. " + "In PyTorch 1.1.0 and later, you should call them in the opposite order: " + "`optimizer.step()` before `lr_scheduler.step()`. Failure to do this " + "will result in PyTorch skipping the first value of the learning rate schedule. " + "See more details at " + "https://pytorch.org/docs/stable/optim.html#how-to-adjust-learning-rate", + ) + self._step_count += 1 + + class _enable_get_lr_call: + def __init__(self, o): + self.o = o + + def __enter__(self): + self.o._get_lr_called_within_step = True + return self + + def __exit__(self, type, value, traceback): + self.o._get_lr_called_within_step = False + + with _enable_get_lr_call(self): + self.last_epoch += 1 + values = self.get_lr() + + for i, data in enumerate(zip(self.optimizer.param_groups, values)): + param_group, lr = data + param_group["lr"] = lr + self.print_lr(self.verbose, i, lr) + + self._last_lr = [group["lr"] for group in self.optimizer.param_groups] diff --git a/pypots/optim/lr_scheduler/constant_lrs.py b/pypots/optim/lr_scheduler/constant_lrs.py new file mode 100644 index 00000000..3f6ae1a3 --- /dev/null +++ b/pypots/optim/lr_scheduler/constant_lrs.py @@ -0,0 +1,84 @@ +""" +Constant learning rate scheduler. +""" + +# Created by Wenjie Du +# License: GLP-v3 + +from .base import LRScheduler, logger + + +class ConstantLR(LRScheduler): + """Decays the learning rate of each parameter group by a small constant factor until the number of epoch reaches + a pre-defined milestone: total_iters. Notice that such decay can happen simultaneously with other changes + to the learning rate from outside this scheduler. When last_epoch=-1, sets initial lr as lr. + + Parameters + ---------- + factor: float, default=1./3. + The number we multiply learning rate until the milestone. + + total_iters: int, default=5, + The number of steps that the scheduler decays the learning rate. + + last_epoch: int + The index of last epoch. Default: -1. + + verbose: bool + If ``True``, prints a message to stdout for each update. Default: ``False``. + + Notes + ----- + This class works the same with ``torch.optim.lr_scheduler.ConstantLR``. + The only difference that is also why we implement them is that you don't have to pass according optimizers + into them immediately while initializing them. + + Example + ------- + >>> # Assuming optimizer uses lr = 0.05 for all groups + >>> # lr = 0.025 if epoch == 0 + >>> # lr = 0.025 if epoch == 1 + >>> # lr = 0.025 if epoch == 2 + >>> # lr = 0.025 if epoch == 3 + >>> # lr = 0.05 if epoch >= 4 + >>> # xdoctest: +SKIP + >>> scheduler = ConstantLR(factor=0.5, total_iters=4) + >>> adam = pypots.optim.Adam(lr=1e-3, lr_scheduler=scheduler) + + """ + + def __init__(self, factor=1.0 / 3, total_iters=5, last_epoch=-1, verbose=False): + super().__init__(last_epoch, verbose) + if factor > 1.0 or factor < 0: + raise ValueError( + "Constant multiplicative factor expected to be between 0 and 1." + ) + + self.factor = factor + self.total_iters = total_iters + + def get_lr(self): + if not self._get_lr_called_within_step: + logger.warning( + "To get the last learning rate computed by the scheduler, " + "please use `get_last_lr()`.", + ) + + if self.last_epoch == 0: + return [group["lr"] * self.factor for group in self.optimizer.param_groups] + + if self.last_epoch > self.total_iters or (self.last_epoch != self.total_iters): + return [group["lr"] for group in self.optimizer.param_groups] + + if self.last_epoch == self.total_iters: + return [ + group["lr"] * (1.0 / self.factor) + for group in self.optimizer.param_groups + ] + + def _get_closed_form_lr(self): + return [ + base_lr + * (self.factor + (self.last_epoch >= self.total_iters) * (1 - self.factor)) + for base_lr in self.base_lrs + ] diff --git a/pypots/optim/lr_scheduler/exponential_lrs.py b/pypots/optim/lr_scheduler/exponential_lrs.py new file mode 100644 index 00000000..ed7e960f --- /dev/null +++ b/pypots/optim/lr_scheduler/exponential_lrs.py @@ -0,0 +1,55 @@ +""" +Exponential learning rate scheduler. +""" + +# Created by Wenjie Du +# License: GLP-v3 + +from .base import LRScheduler, logger + + +class ExponentialLR(LRScheduler): + """Decays the learning rate of each parameter group by gamma every epoch. When last_epoch=-1, sets initial lr as lr. + + Parameters + ---------- + gamma: float, + Multiplicative factor of learning rate decay. + + last_epoch: int + The index of last epoch. Default: -1. + + verbose: bool + If ``True``, prints a message to stdout for each update. Default: ``False``. + + Notes + ----- + This class works the same with ``torch.optim.lr_scheduler.ExponentialLR``. + The only difference that is also why we implement them is that you don't have to pass according optimizers + into them immediately while initializing them. + + Example + ------- + >>> scheduler = ExponentialLR(gamma=0.1) + >>> adam = pypots.optim.Adam(lr=1e-3, lr_scheduler=scheduler) + + """ + + def __init__(self, gamma, last_epoch=-1, verbose=False): + super().__init__(last_epoch, verbose) + + self.gamma = gamma + + def get_lr(self): + if not self._get_lr_called_within_step: + logger.warning( + "To get the last learning rate computed by the scheduler, " + "please use `get_last_lr()`.", + ) + + if self.last_epoch == 0: + return [group["lr"] for group in self.optimizer.param_groups] + return [group["lr"] * self.gamma for group in self.optimizer.param_groups] + + def _get_closed_form_lr(self): + return [base_lr * self.gamma**self.last_epoch for base_lr in self.base_lrs] diff --git a/pypots/optim/lr_scheduler/lambda_lrs.py b/pypots/optim/lr_scheduler/lambda_lrs.py new file mode 100644 index 00000000..5471cee6 --- /dev/null +++ b/pypots/optim/lr_scheduler/lambda_lrs.py @@ -0,0 +1,79 @@ +""" +Lambda learning rate scheduler. +""" + +# Created by Wenjie Du +# License: GLP-v3 + +from typing import Callable, Union + +from .base import LRScheduler, logger + + +class LambdaLR(LRScheduler): + """Sets the learning rate of each parameter group to the initial lr times a given function. + When last_epoch=-1, sets initial lr as lr. + + Parameters + ---------- + lr_lambda: Callable or list, + A function which computes a multiplicative factor given an integer parameter epoch, or a list of such + functions, one for each group in optimizer.param_groups. + + last_epoch: int, + The index of last epoch. Default: -1. + + verbose: bool, + If ``True``, prints a message to stdout for each update. Default: ``False``. + + Notes + ----- + This class works the same with ``torch.optim.lr_scheduler.LambdaLR``. + The only difference that is also why we implement them is that you don't have to pass according optimizers + into them immediately while initializing them. + + Example + ------- + >>> lambda1 = lambda epoch: epoch // 30 + >>> scheduler = LambdaLR(lr_lambda=lambda1) + >>> adam = pypots.optim.Adam(lr=1e-3, lr_scheduler=scheduler) + + """ + + def __init__( + self, + lr_lambda: Union[Callable, list], + last_epoch: int = -1, + verbose: bool = False, + ): + super().__init__(last_epoch, verbose) + self.lr_lambda = lr_lambda + self.lr_lambdas = None + + def init_scheduler(self, optimizer): + if not isinstance(self.lr_lambda, list) and not isinstance( + self.lr_lambda, tuple + ): + self.lr_lambdas = [self.lr_lambda] * len(optimizer.param_groups) + else: + if len(self.lr_lambda) != len(optimizer.param_groups): + raise ValueError( + "Expected {} lr_lambdas, but got {}".format( + len(optimizer.param_groups), len(self.lr_lambda) + ) + ) + self.lr_lambdas = list(self.lr_lambda) + + super().init_scheduler(optimizer) + + def get_lr(self): + if not self._get_lr_called_within_step: + logger.warning( + "To get the last learning rate computed by the scheduler, " + "please use `get_last_lr()`." + ) + + return [ + base_lr * lmbda(self.last_epoch) + for lmbda, base_lr in zip(self.lr_lambdas, self.base_lrs) + ] diff --git a/pypots/optim/lr_scheduler/linear_lrs.py b/pypots/optim/lr_scheduler/linear_lrs.py new file mode 100644 index 00000000..a1e8e1e6 --- /dev/null +++ b/pypots/optim/lr_scheduler/linear_lrs.py @@ -0,0 +1,115 @@ +""" +Linear learning rate scheduler. +""" + +# Created by Wenjie Du +# License: GLP-v3 + +from .base import LRScheduler, logger + + +class LinearLR(LRScheduler): + """Decays the learning rate of each parameter group by linearly changing small multiplicative factor until + the number of epoch reaches a pre-defined milestone: total_iters. Notice that such decay can happen simultaneously + with other changes to the learning rate from outside this scheduler. When last_epoch=-1, sets initial lr as lr. + + Parameters + ---------- + start_factor: float, default=1.0 / 3, + The number we multiply learning rate in the first epoch. The multiplication factor changes towards + end_factor in the following epochs. + + end_factor: float, default=1.0, + The number we multiply learning rate at the end of linear changing process. + + total_iters: int, default=5, + The number of iterations that multiplicative factor reaches to 1. + + last_epoch: int + The index of last epoch. Default: -1. + + verbose: bool + If ``True``, prints a message to stdout for each update. Default: ``False``. + + Notes + ----- + This class works the same with ``torch.optim.lr_scheduler.LinearLR``. + The only difference that is also why we implement them is that you don't have to pass according optimizers + into them immediately while initializing them. + + Example + ------- + >>> # Assuming optimizer uses lr = 0.05 for all groups + >>> # lr = 0.025 if epoch == 0 + >>> # lr = 0.03125 if epoch == 1 + >>> # lr = 0.0375 if epoch == 2 + >>> # lr = 0.04375 if epoch == 3 + >>> # lr = 0.05 if epoch >= 4 + >>> # xdoctest: +SKIP + >>> scheduler = LinearLR(start_factor=0.5, total_iters=4) + >>> adam = pypots.optim.Adam(lr=1e-3, lr_scheduler=scheduler) + + """ + + def __init__( + self, + start_factor=1.0 / 3, + end_factor=1.0, + total_iters=5, + last_epoch=-1, + verbose=False, + ): + super().__init__(last_epoch, verbose) + if start_factor > 1.0 or start_factor < 0: + raise ValueError( + "Starting multiplicative factor expected to be between 0 and 1." + ) + + if end_factor > 1.0 or end_factor < 0: + raise ValueError( + "Ending multiplicative factor expected to be between 0 and 1." + ) + + self.start_factor = start_factor + self.end_factor = end_factor + self.total_iters = total_iters + + def get_lr(self): + if not self._get_lr_called_within_step: + logger.warning( + "To get the last learning rate computed by the scheduler, " + "please use `get_last_lr()`.", + ) + + if self.last_epoch == 0: + return [ + group["lr"] * self.start_factor for group in self.optimizer.param_groups + ] + + if self.last_epoch > self.total_iters: + return [group["lr"] for group in self.optimizer.param_groups] + + return [ + group["lr"] + * ( + 1.0 + + (self.end_factor - self.start_factor) + / ( + self.total_iters * self.start_factor + + (self.last_epoch - 1) * (self.end_factor - self.start_factor) + ) + ) + for group in self.optimizer.param_groups + ] + + def _get_closed_form_lr(self): + return [ + base_lr + * ( + self.start_factor + + (self.end_factor - self.start_factor) + * min(self.total_iters, self.last_epoch) + / self.total_iters + ) + for base_lr in self.base_lrs + ] diff --git a/pypots/optim/lr_scheduler/multiplicative_lrs.py b/pypots/optim/lr_scheduler/multiplicative_lrs.py new file mode 100644 index 00000000..5dbc18ea --- /dev/null +++ b/pypots/optim/lr_scheduler/multiplicative_lrs.py @@ -0,0 +1,77 @@ +""" +Multiplicative learning rate scheduler. +""" + +# Created by Wenjie Du +# License: GLP-v3 + + +from .base import LRScheduler, logger + + +class MultiplicativeLR(LRScheduler): + """Multiply the learning rate of each parameter group by the factor given in the specified function. + When last_epoch=-1, sets initial lr as lr. + + Parameters + ---------- + lr_lambda: Callable or list, + A function which computes a multiplicative factor given an integer parameter epoch, or a list of such + functions, one for each group in optimizer.param_groups. + + last_epoch: int, + The index of last epoch. Default: -1. + + verbose: bool, + If ``True``, prints a message to stdout for each update. Default: ``False``. + + Notes + ----- + This class works the same with ``torch.optim.lr_scheduler.MultiplicativeLR``. + The only difference that is also why we implement them is that you don't have to pass according optimizers + into them immediately while initializing them. + + Example + ------- + >>> lmbda = lambda epoch: 0.95 + >>> # xdoctest: +SKIP + >>> scheduler = MultiplicativeLR(lr_lambda=lmbda) + >>> adam = pypots.optim.Adam(lr=1e-3, lr_scheduler=scheduler) + + """ + + def __init__(self, lr_lambda, last_epoch=-1, verbose=False): + super().__init__(last_epoch, verbose) + self.lr_lambda = lr_lambda + self.lr_lambdas = None + + def init_scheduler(self, optimizer): + if not isinstance(self.lr_lambda, list) and not isinstance( + self.lr_lambda, tuple + ): + self.lr_lambdas = [self.lr_lambda] * len(optimizer.param_groups) + else: + if len(self.lr_lambda) != len(optimizer.param_groups): + raise ValueError( + "Expected {} lr_lambdas, but got {}".format( + len(optimizer.param_groups), len(self.lr_lambda) + ) + ) + self.lr_lambdas = list(self.lr_lambda) + + super().init_scheduler(optimizer) + + def get_lr(self): + if not self._get_lr_called_within_step: + logger.warning( + "To get the last learning rate computed by the scheduler, " + "please use `get_last_lr()`.", + ) + + if self.last_epoch > 0: + return [ + group["lr"] * lmbda(self.last_epoch) + for lmbda, group in zip(self.lr_lambdas, self.optimizer.param_groups) + ] + else: + return [group["lr"] for group in self.optimizer.param_groups] diff --git a/pypots/optim/lr_scheduler/multistep_lrs.py b/pypots/optim/lr_scheduler/multistep_lrs.py new file mode 100644 index 00000000..567570e9 --- /dev/null +++ b/pypots/optim/lr_scheduler/multistep_lrs.py @@ -0,0 +1,75 @@ +""" +Multistep learning rate scheduler. +""" + +# Created by Wenjie Du +# License: GLP-v3 + +from bisect import bisect_right +from collections import Counter + +from .base import LRScheduler, logger + + +class MultiStepLR(LRScheduler): + """Decays the learning rate of each parameter group by gamma once the number of epoch reaches one of the milestones. + Notice that such decay can happen simultaneously with other changes to the learning rate from outside this + scheduler. When last_epoch=-1, sets initial lr as lr. + + Parameters + ---------- + milestones: list, + List of epoch indices. Must be increasing. + + gamma: float, default=0.1, + Multiplicative factor of learning rate decay. + + last_epoch: int + The index of last epoch. Default: -1. + + verbose: bool + If ``True``, prints a message to stdout for each update. Default: ``False``. + + Notes + ----- + This class works the same with ``torch.optim.lr_scheduler.MultiStepLR``. + The only difference that is also why we implement them is that you don't have to pass according optimizers + into them immediately while initializing them. + + Example + ------- + >>> # Assuming optimizer uses lr = 0.05 for all groups + >>> # lr = 0.05 if epoch < 30 + >>> # lr = 0.005 if 30 <= epoch < 80 + >>> # lr = 0.0005 if epoch >= 80 + >>> # xdoctest: +SKIP + >>> scheduler = MultiStepLR(milestones=[30,80], gamma=0.1) + >>> adam = pypots.optim.Adam(lr=1e-3, lr_scheduler=scheduler) + + """ + + def __init__(self, milestones, gamma=0.1, last_epoch=-1, verbose=False): + super().__init__(last_epoch, verbose) + self.milestones = Counter(milestones) + self.gamma = gamma + + def get_lr(self): + if not self._get_lr_called_within_step: + logger.warning( + "To get the last learning rate computed by the scheduler, " + "please use `get_last_lr()`.", + ) + + if self.last_epoch not in self.milestones: + return [group["lr"] for group in self.optimizer.param_groups] + return [ + group["lr"] * self.gamma ** self.milestones[self.last_epoch] + for group in self.optimizer.param_groups + ] + + def _get_closed_form_lr(self): + milestones = list(sorted(self.milestones.elements())) + return [ + base_lr * self.gamma ** bisect_right(milestones, self.last_epoch) + for base_lr in self.base_lrs + ] diff --git a/pypots/optim/lr_scheduler/step_lrs.py b/pypots/optim/lr_scheduler/step_lrs.py new file mode 100644 index 00000000..29f72bb8 --- /dev/null +++ b/pypots/optim/lr_scheduler/step_lrs.py @@ -0,0 +1,70 @@ +""" +Step learning rate scheduler. +""" + +# Created by Wenjie Du +# License: GLP-v3 + +from .base import LRScheduler, logger + + +class StepLR(LRScheduler): + """Decays the learning rate of each parameter group by gamma every step_size epochs. Notice that such decay can + happen simultaneously with other changes to the learning rate from outside this scheduler. + When last_epoch=-1, sets initial lr as lr. + + Parameters + ---------- + step_size: int, + Period of learning rate decay. + + gamma: float, default=0.1, + Multiplicative factor of learning rate decay. + + last_epoch: int + The index of last epoch. Default: -1. + + verbose: bool + If ``True``, prints a message to stdout for each update. Default: ``False``. + + Notes + ----- + This class works the same with ``torch.optim.lr_scheduler.StepLR``. + The only difference that is also why we implement them is that you don't have to pass according optimizers + into them immediately while initializing them. + + Example + ------- + >>> # Assuming optimizer uses lr = 0.05 for all groups + >>> # lr = 0.05 if epoch < 30 + >>> # lr = 0.005 if 30 <= epoch < 60 + >>> # lr = 0.0005 if 60 <= epoch < 90 + >>> # ... + >>> # xdoctest: +SKIP + >>> scheduler = StepLR(step_size=30, gamma=0.1) + >>> adam = pypots.optim.Adam(lr=1e-3, lr_scheduler=scheduler) + + """ + + def __init__(self, step_size, gamma=0.1, last_epoch=-1, verbose=False): + super().__init__(last_epoch, verbose) + + self.step_size = step_size + self.gamma = gamma + + def get_lr(self): + if not self._get_lr_called_within_step: + logger.warning( + "To get the last learning rate computed by the scheduler, " + "please use `get_last_lr()`.", + ) + + if (self.last_epoch == 0) or (self.last_epoch % self.step_size != 0): + return [group["lr"] for group in self.optimizer.param_groups] + return [group["lr"] * self.gamma for group in self.optimizer.param_groups] + + def _get_closed_form_lr(self): + return [ + base_lr * self.gamma ** (self.last_epoch // self.step_size) + for base_lr in self.base_lrs + ] From dd626b312bad9fae2f00171b6e10bb1124de3b28 Mon Sep 17 00:00:00 2001 From: Wenjie Du Date: Wed, 27 Sep 2023 19:49:25 +0800 Subject: [PATCH 2/8] feat: integrate schedulers into optimizers; --- pypots/optim/adadelta.py | 9 +++++++-- pypots/optim/adagrad.py | 9 +++++++-- pypots/optim/adam.py | 9 +++++++-- pypots/optim/adamw.py | 9 +++++++-- pypots/optim/base.py | 8 +++++++- pypots/optim/rmsprop.py | 9 +++++++-- pypots/optim/sgd.py | 9 +++++++-- 7 files changed, 49 insertions(+), 13 deletions(-) diff --git a/pypots/optim/adadelta.py b/pypots/optim/adadelta.py index ac4726d3..59e98f2a 100644 --- a/pypots/optim/adadelta.py +++ b/pypots/optim/adadelta.py @@ -6,11 +6,12 @@ # Created by Wenjie Du # License: GLP-v3 -from typing import Iterable +from typing import Iterable, Optional from torch.optim import Adadelta as torch_Adadelta from .base import Optimizer +from .lr_scheduler.base import LRScheduler class Adadelta(Optimizer): @@ -39,8 +40,9 @@ def __init__( rho: float = 0.9, eps: float = 1e-08, weight_decay: float = 0.01, + lr_scheduler: Optional[LRScheduler] = None, ): - super().__init__(lr) + super().__init__(lr, lr_scheduler) self.rho = rho self.eps = eps self.weight_decay = weight_decay @@ -61,3 +63,6 @@ def init_optimizer(self, params: Iterable) -> None: eps=self.eps, weight_decay=self.weight_decay, ) + + if self.lr_scheduler is not None: + self.lr_scheduler.init_scheduler(self.torch_optimizer) diff --git a/pypots/optim/adagrad.py b/pypots/optim/adagrad.py index e4374244..8a10f06c 100644 --- a/pypots/optim/adagrad.py +++ b/pypots/optim/adagrad.py @@ -6,11 +6,12 @@ # Created by Wenjie Du # License: GLP-v3 -from typing import Iterable +from typing import Iterable, Optional from torch.optim import Adagrad as torch_Adagrad from .base import Optimizer +from .lr_scheduler.base import LRScheduler class Adagrad(Optimizer): @@ -43,8 +44,9 @@ def __init__( weight_decay: float = 0.01, initial_accumulator_value: float = 0.01, # it is set as 0 in the torch implementation, but delta shouldn't be 0 eps: float = 1e-08, + lr_scheduler: Optional[LRScheduler] = None, ): - super().__init__(lr) + super().__init__(lr, lr_scheduler) self.lr_decay = lr_decay self.weight_decay = weight_decay self.initial_accumulator_value = initial_accumulator_value @@ -67,3 +69,6 @@ def init_optimizer(self, params: Iterable) -> None: initial_accumulator_value=self.initial_accumulator_value, eps=self.eps, ) + + if self.lr_scheduler is not None: + self.lr_scheduler.init_scheduler(self.torch_optimizer) diff --git a/pypots/optim/adam.py b/pypots/optim/adam.py index d308b27e..c5e0e1af 100644 --- a/pypots/optim/adam.py +++ b/pypots/optim/adam.py @@ -6,11 +6,12 @@ # Created by Wenjie Du # License: GLP-v3 -from typing import Iterable, Tuple +from typing import Iterable, Tuple, Optional from torch.optim import Adam as torch_Adam from .base import Optimizer +from .lr_scheduler.base import LRScheduler class Adam(Optimizer): @@ -42,8 +43,9 @@ def __init__( eps: float = 1e-08, weight_decay: float = 0, amsgrad: bool = False, + lr_scheduler: Optional[LRScheduler] = None, ): - super().__init__(lr) + super().__init__(lr, lr_scheduler) self.betas = betas self.eps = eps self.weight_decay = weight_decay @@ -66,3 +68,6 @@ def init_optimizer(self, params: Iterable) -> None: weight_decay=self.weight_decay, amsgrad=self.amsgrad, ) + + if self.lr_scheduler is not None: + self.lr_scheduler.init_scheduler(self.torch_optimizer) diff --git a/pypots/optim/adamw.py b/pypots/optim/adamw.py index b93d8a74..6a5191e4 100644 --- a/pypots/optim/adamw.py +++ b/pypots/optim/adamw.py @@ -6,11 +6,12 @@ # Created by Wenjie Du # License: GLP-v3 -from typing import Iterable, Tuple +from typing import Iterable, Tuple, Optional from torch.optim import AdamW as torch_AdamW from .base import Optimizer +from .lr_scheduler.base import LRScheduler class AdamW(Optimizer): @@ -42,8 +43,9 @@ def __init__( eps: float = 1e-08, weight_decay: float = 0.01, amsgrad: bool = False, + lr_scheduler: Optional[LRScheduler] = None, ): - super().__init__(lr) + super().__init__(lr, lr_scheduler) self.betas = betas self.eps = eps self.weight_decay = weight_decay @@ -66,3 +68,6 @@ def init_optimizer(self, params: Iterable) -> None: weight_decay=self.weight_decay, amsgrad=self.amsgrad, ) + + if self.lr_scheduler is not None: + self.lr_scheduler.init_scheduler(self.torch_optimizer) diff --git a/pypots/optim/base.py b/pypots/optim/base.py index f1bb9637..db09fb3a 100644 --- a/pypots/optim/base.py +++ b/pypots/optim/base.py @@ -19,6 +19,8 @@ from abc import ABC, abstractmethod from typing import Callable, Iterable, Optional +from .lr_scheduler.base import LRScheduler + class Optimizer(ABC): """The base wrapper for PyTorch optimizers, also is the base class for all optimizers in pypots.optim. @@ -35,9 +37,10 @@ class Optimizer(ABC): """ - def __init__(self, lr): + def __init__(self, lr, lr_scheduler: Optional[LRScheduler] = None): self.lr = lr self.torch_optimizer = None + self.lr_scheduler = lr_scheduler @abstractmethod def init_optimizer(self, params: Iterable) -> None: @@ -97,6 +100,9 @@ def step(self, closure: Optional[Callable] = None) -> None: """ self.torch_optimizer.step(closure) + if self.lr_scheduler is not None: + self.lr_scheduler.step() + def zero_grad(self, set_to_none: bool = True) -> None: """Sets the gradients of all optimized ``torch.Tensor`` to zero. diff --git a/pypots/optim/rmsprop.py b/pypots/optim/rmsprop.py index 65a817ca..f00da68d 100644 --- a/pypots/optim/rmsprop.py +++ b/pypots/optim/rmsprop.py @@ -6,11 +6,12 @@ # Created by Wenjie Du # License: GLP-v3 -from typing import Iterable +from typing import Iterable, Optional from torch.optim import RMSprop as torch_RMSprop from .base import Optimizer +from .lr_scheduler.base import LRScheduler class RMSprop(Optimizer): @@ -47,8 +48,9 @@ def __init__( eps: float = 1e-08, centered: bool = False, weight_decay: float = 0, + lr_scheduler: Optional[LRScheduler] = None, ): - super().__init__(lr) + super().__init__(lr, lr_scheduler) self.momentum = momentum self.alpha = alpha self.eps = eps @@ -73,3 +75,6 @@ def init_optimizer(self, params: Iterable) -> None: centered=self.centered, weight_decay=self.weight_decay, ) + + if self.lr_scheduler is not None: + self.lr_scheduler.init_scheduler(self.torch_optimizer) diff --git a/pypots/optim/sgd.py b/pypots/optim/sgd.py index 4696db91..34cd07f0 100644 --- a/pypots/optim/sgd.py +++ b/pypots/optim/sgd.py @@ -6,11 +6,12 @@ # Created by Wenjie Du # License: GLP-v3 -from typing import Iterable +from typing import Iterable, Optional from torch.optim import SGD as torch_SGD from .base import Optimizer +from .lr_scheduler.base import LRScheduler class SGD(Optimizer): @@ -43,8 +44,9 @@ def __init__( weight_decay: float = 0, dampening: float = 0, nesterov: bool = False, + lr_scheduler: Optional[LRScheduler] = None, ): - super().__init__(lr) + super().__init__(lr, lr_scheduler) self.momentum = momentum self.weight_decay = weight_decay self.dampening = dampening @@ -67,3 +69,6 @@ def init_optimizer(self, params: Iterable) -> None: dampening=self.dampening, nesterov=self.nesterov, ) + + if self.lr_scheduler is not None: + self.lr_scheduler.init_scheduler(self.torch_optimizer) From f7fc32c949f89a1d91985dd7187eb4056395bf38 Mon Sep 17 00:00:00 2001 From: Wenjie Du Date: Wed, 27 Sep 2023 20:17:12 +0800 Subject: [PATCH 3/8] feat: add testing cases of learn rate schedulers; --- tests/optim/lr_schedulers.py | 249 +++++++++++++++++++++++++++++++++++ 1 file changed, 249 insertions(+) create mode 100644 tests/optim/lr_schedulers.py diff --git a/tests/optim/lr_schedulers.py b/tests/optim/lr_schedulers.py new file mode 100644 index 00000000..e7748f91 --- /dev/null +++ b/tests/optim/lr_schedulers.py @@ -0,0 +1,249 @@ +""" +Test cases for the learning rate schedulers. +""" + +# Created by Wenjie Du +# License: GLP-v3 + +import unittest + +import numpy as np +import pytest + +from pypots.imputation import SAITS +from pypots.optim import Adam, AdamW, Adadelta, Adagrad, RMSprop, SGD +from pypots.optim.lr_scheduler import ( + LambdaLR, + ConstantLR, + ExponentialLR, + LinearLR, + StepLR, + MultiStepLR, + MultiplicativeLR, +) +from pypots.utils.logging import logger +from pypots.utils.metrics import cal_mae +from tests.global_test_config import DATA +from tests.optim.config import EPOCHS, TEST_SET, TRAIN_SET, VAL_SET + + +class TestLRSchedulers(unittest.TestCase): + logger.info("Running tests for learning rate schedulers...") + + # init lambda_lrs + lambda_lrs = LambdaLR(lr_lambda=lambda epoch: epoch // 30, verbose=True) + + # init multiplicative_lrs + multiplicative_lrs = MultiplicativeLR(lr_lambda=lambda epoch: 0.95, verbose=True) + + # init step_lrs + step_lrs = StepLR(step_size=30, gamma=0.1, verbose=True) + + # init multistep_lrs + multistep_lrs = MultiStepLR(milestones=[30, 80], gamma=0.1, verbose=True) + + # init constant_lrs + constant_lrs = ConstantLR(factor=0.5, total_iters=4, verbose=True) + + # init linear_lrs + linear_lrs = LinearLR(start_factor=0.5, total_iters=4, verbose=True) + + # init exponential_lrs + exponential_lrs = ExponentialLR(gamma=0.9, verbose=True) + + @pytest.mark.xdist_group(name="lrs-lambda") + def test_0_lambda_lrs(self): + logger.info("Running tests for Adam + LambdaLRS...") + + adam = Adam(lr=0.001, weight_decay=1e-5, lr_scheduler=self.lambda_lrs) + saits = SAITS( + DATA["n_steps"], + DATA["n_features"], + n_layers=1, + d_model=128, + d_inner=64, + n_heads=2, + d_k=64, + d_v=64, + dropout=0.1, + optimizer=adam, + epochs=EPOCHS, + ) + saits.fit(TRAIN_SET, VAL_SET) + imputed_X = saits.impute(TEST_SET) + assert not np.isnan( + imputed_X + ).any(), "Output still has missing values after running impute()." + test_MAE = cal_mae( + imputed_X, DATA["test_X_intact"], DATA["test_X_indicating_mask"] + ) + logger.info(f"SAITS test_MAE: {test_MAE}") + + @pytest.mark.xdist_group(name="lrs-multiplicative") + def test_1_multiplicative_lrs(self): + logger.info("Running tests for Adamw + MultiplicativeLRS...") + + adamw = AdamW(lr=0.001, weight_decay=1e-5, lr_scheduler=self.multiplicative_lrs) + saits = SAITS( + DATA["n_steps"], + DATA["n_features"], + n_layers=1, + d_model=128, + d_inner=64, + n_heads=2, + d_k=64, + d_v=64, + dropout=0.1, + optimizer=adamw, + epochs=EPOCHS, + ) + saits.fit(TRAIN_SET, VAL_SET) + imputed_X = saits.impute(TEST_SET) + assert not np.isnan( + imputed_X + ).any(), "Output still has missing values after running impute()." + test_MAE = cal_mae( + imputed_X, DATA["test_X_intact"], DATA["test_X_indicating_mask"] + ) + logger.info(f"SAITS test_MAE: {test_MAE}") + + @pytest.mark.xdist_group(name="lrs-step") + def test_2_step_lrs(self): + logger.info("Running tests for Adadelta + StepLRS...") + + adamw = Adadelta(lr=0.001, lr_scheduler=self.step_lrs) + saits = SAITS( + DATA["n_steps"], + DATA["n_features"], + n_layers=1, + d_model=128, + d_inner=64, + n_heads=2, + d_k=64, + d_v=64, + dropout=0.1, + optimizer=adamw, + epochs=EPOCHS, + ) + saits.fit(TRAIN_SET, VAL_SET) + imputed_X = saits.impute(TEST_SET) + assert not np.isnan( + imputed_X + ).any(), "Output still has missing values after running impute()." + test_MAE = cal_mae( + imputed_X, DATA["test_X_intact"], DATA["test_X_indicating_mask"] + ) + logger.info(f"SAITS test_MAE: {test_MAE}") + + @pytest.mark.xdist_group(name="lrs-multistep") + def test_3_multistep_lrs(self): + logger.info("Running tests for Adadelta + MultiStepLRS...") + + adagrad = Adagrad(lr=0.001, lr_scheduler=self.multistep_lrs) + saits = SAITS( + DATA["n_steps"], + DATA["n_features"], + n_layers=1, + d_model=128, + d_inner=64, + n_heads=2, + d_k=64, + d_v=64, + dropout=0.1, + optimizer=adagrad, + epochs=EPOCHS, + ) + saits.fit(TRAIN_SET, VAL_SET) + imputed_X = saits.impute(TEST_SET) + assert not np.isnan( + imputed_X + ).any(), "Output still has missing values after running impute()." + test_MAE = cal_mae( + imputed_X, DATA["test_X_intact"], DATA["test_X_indicating_mask"] + ) + logger.info(f"SAITS test_MAE: {test_MAE}") + + @pytest.mark.xdist_group(name="lrs-constant") + def test_4_constant_lrs(self): + logger.info("Running tests for RMSprop + ConstantLRS...") + + # initialize a SAITS model for testing DatasetForMIT and BaseDataset + rmsprop = RMSprop(lr=0.001, lr_scheduler=self.constant_lrs) + saits = SAITS( + DATA["n_steps"], + DATA["n_features"], + n_layers=1, + d_model=128, + d_inner=64, + n_heads=2, + d_k=64, + d_v=64, + dropout=0.1, + optimizer=rmsprop, + epochs=EPOCHS, + ) + saits.fit(TRAIN_SET, VAL_SET) + imputed_X = saits.impute(TEST_SET) + assert not np.isnan( + imputed_X + ).any(), "Output still has missing values after running impute()." + test_MAE = cal_mae( + imputed_X, DATA["test_X_intact"], DATA["test_X_indicating_mask"] + ) + logger.info(f"SAITS test_MAE: {test_MAE}") + + @pytest.mark.xdist_group(name="lrs-linear") + def test_5_linear_lrs(self): + logger.info("Running tests for SGD + MultiStepLRS...") + + sgd = SGD(lr=0.001, lr_scheduler=self.linear_lrs) + saits = SAITS( + DATA["n_steps"], + DATA["n_features"], + n_layers=1, + d_model=128, + d_inner=64, + n_heads=2, + d_k=64, + d_v=64, + dropout=0.1, + optimizer=sgd, + epochs=EPOCHS, + ) + saits.fit(TRAIN_SET, VAL_SET) + imputed_X = saits.impute(TEST_SET) + assert not np.isnan( + imputed_X + ).any(), "Output still has missing values after running impute()." + test_MAE = cal_mae( + imputed_X, DATA["test_X_intact"], DATA["test_X_indicating_mask"] + ) + logger.info(f"SAITS test_MAE: {test_MAE}") + + @pytest.mark.xdist_group(name="lrs-exponential") + def test_6_exponential_lrs(self): + logger.info("Running tests for SGD + ExponentialLRS...") + + sgd = SGD(lr=0.001, lr_scheduler=self.exponential_lrs) + saits = SAITS( + DATA["n_steps"], + DATA["n_features"], + n_layers=1, + d_model=128, + d_inner=64, + n_heads=2, + d_k=64, + d_v=64, + dropout=0.1, + optimizer=sgd, + epochs=EPOCHS, + ) + saits.fit(TRAIN_SET, VAL_SET) + imputed_X = saits.impute(TEST_SET) + assert not np.isnan( + imputed_X + ).any(), "Output still has missing values after running impute()." + test_MAE = cal_mae( + imputed_X, DATA["test_X_intact"], DATA["test_X_indicating_mask"] + ) + logger.info(f"SAITS test_MAE: {test_MAE}") From 1798ecf31eb056112dad3f2728e2126adf7eff16 Mon Sep 17 00:00:00 2001 From: Wenjie Du Date: Wed, 27 Sep 2023 20:18:10 +0800 Subject: [PATCH 4/8] docs: update docs for learning rate schedulers; --- docs/pypots.optim.rst | 9 +++++++++ pypots/optim/lr_scheduler/__init__.py | 2 +- pypots/optim/lr_scheduler/base.py | 3 +++ 3 files changed, 13 insertions(+), 1 deletion(-) diff --git a/docs/pypots.optim.rst b/docs/pypots.optim.rst index 8badeb1c..2bcc93f2 100644 --- a/docs/pypots.optim.rst +++ b/docs/pypots.optim.rst @@ -54,3 +54,12 @@ pypots.optim.base module :undoc-members: :show-inheritance: :inherited-members: + +pypots.optim.lr_scheduler module +------------------------------ + +.. automodule:: pypots.optim.lr_scheduler + :members: + :undoc-members: + :show-inheritance: + :inherited-members: diff --git a/pypots/optim/lr_scheduler/__init__.py b/pypots/optim/lr_scheduler/__init__.py index c0688a51..ddb14350 100644 --- a/pypots/optim/lr_scheduler/__init__.py +++ b/pypots/optim/lr_scheduler/__init__.py @@ -2,7 +2,7 @@ Learning rate schedulers available in PyPOTS. Their functionalities are the same with those in PyTorch, the only difference that is also why we implement them is that you don't have to pass according optimizers into them immediately while initializing them. Instead, you can pass them into pypots.optim.Optimizer -after initialization and call their `init_scheduler()` method in Optimizer.init_optimizer() to initialize +after initialization and call their `init_scheduler()` method in pypots.optim.Optimizer.init_optimizer() to initialize schedulers together with optimizers. """ diff --git a/pypots/optim/lr_scheduler/base.py b/pypots/optim/lr_scheduler/base.py index 3c5af3b7..0aeffd8b 100644 --- a/pypots/optim/lr_scheduler/base.py +++ b/pypots/optim/lr_scheduler/base.py @@ -113,6 +113,9 @@ def print_lr(is_verbose, group, lr): logger.info(f"Adjusting learning rate of group {group} to {lr:.4e}.") def step(self): + """Step could be called after every batch update. This should be called in ``pypots.optim.Optimizer.step()`` + after ``pypots.optim.Optimizer.torch_optimizer.step()``. + """ # Raise a warning if old pattern is detected # https://github.com/pytorch/pytorch/issues/20124 if self._step_count == 1: From 248dfeb27c49aa63d9866bd1654e03a412c04e28 Mon Sep 17 00:00:00 2001 From: Wenjie Du Date: Thu, 28 Sep 2023 14:44:45 +0800 Subject: [PATCH 5/8] feat: add testing case for CRLI with LSTM cells; --- tests/clustering/crli.py | 91 +++++++++++++++++++++++++++++++++------- 1 file changed, 76 insertions(+), 15 deletions(-) diff --git a/tests/clustering/crli.py b/tests/clustering/crli.py index 191f58c8..2385d1e5 100644 --- a/tests/clustering/crli.py +++ b/tests/clustering/crli.py @@ -43,12 +43,27 @@ class TestCRLI(unittest.TestCase): D_optimizer = Adam(lr=0.001, weight_decay=1e-5) # initialize a CRLI model - crli = CRLI( + crli_gru = CRLI( n_steps=DATA["n_steps"], n_features=DATA["n_features"], n_clusters=DATA["n_classes"], n_generator_layers=2, rnn_hidden_size=128, + rnn_cell_type="GRU", + epochs=EPOCHS, + saving_path=saving_path, + G_optimizer=G_optimizer, + D_optimizer=D_optimizer, + device=DEVICE, + ) + + crli_lstm = CRLI( + n_steps=DATA["n_steps"], + n_features=DATA["n_features"], + n_clusters=DATA["n_classes"], + n_generator_layers=2, + rnn_hidden_size=128, + rnn_cell_type="LSTM", epochs=EPOCHS, saving_path=saving_path, G_optimizer=G_optimizer, @@ -58,34 +73,80 @@ class TestCRLI(unittest.TestCase): @pytest.mark.xdist_group(name="clustering-crli") def test_0_fit(self): - self.crli.fit(TRAIN_SET) + logger.info("Training CRLI-GRU...") + self.crli_gru.fit(TRAIN_SET) + logger.info("Training CRLI-LSTM...") + self.crli_lstm.fit(TRAIN_SET) @pytest.mark.xdist_group(name="clustering-crli") def test_1_parameters(self): - assert hasattr(self.crli, "model") and self.crli.model is not None + # GRU cell + assert hasattr(self.crli_gru, "model") and self.crli_gru.model is not None - assert hasattr(self.crli, "G_optimizer") and self.crli.G_optimizer is not None - assert hasattr(self.crli, "D_optimizer") and self.crli.D_optimizer is not None + assert ( + hasattr(self.crli_gru, "G_optimizer") + and self.crli_gru.G_optimizer is not None + ) + assert ( + hasattr(self.crli_gru, "D_optimizer") + and self.crli_gru.D_optimizer is not None + ) - assert hasattr(self.crli, "best_loss") - self.assertNotEqual(self.crli.best_loss, float("inf")) + assert hasattr(self.crli_gru, "best_loss") + self.assertNotEqual(self.crli_gru.best_loss, float("inf")) assert ( - hasattr(self.crli, "best_model_dict") - and self.crli.best_model_dict is not None + hasattr(self.crli_gru, "best_model_dict") + and self.crli_gru.best_model_dict is not None + ) + + # LSTM cell + assert hasattr(self.crli_lstm, "model") and self.crli_lstm.model is not None + + assert ( + hasattr(self.crli_lstm, "G_optimizer") + and self.crli_lstm.G_optimizer is not None + ) + assert ( + hasattr(self.crli_lstm, "D_optimizer") + and self.crli_lstm.D_optimizer is not None + ) + + assert hasattr(self.crli_lstm, "best_loss") + self.assertNotEqual(self.crli_lstm.best_loss, float("inf")) + + assert ( + hasattr(self.crli_lstm, "best_model_dict") + and self.crli_lstm.best_model_dict is not None ) @pytest.mark.xdist_group(name="clustering-crli") def test_2_cluster(self): - clustering, latent_collector = self.crli.cluster(TEST_SET, return_latent=True) + # GRU cell + clustering, latent_collector = self.crli_gru.cluster( + TEST_SET, return_latent=True + ) + external_metrics = cal_external_cluster_validation_metrics( + clustering, DATA["test_y"] + ) + internal_metrics = cal_internal_cluster_validation_metrics( + latent_collector["clustering_latent"], DATA["test_y"] + ) + logger.info(f"CRLI-GRU: {external_metrics}") + logger.info(f"CRLI-GRU:{internal_metrics}") + + # LSTM cell + clustering, latent_collector = self.crli_lstm.cluster( + TEST_SET, return_latent=True + ) external_metrics = cal_external_cluster_validation_metrics( clustering, DATA["test_y"] ) internal_metrics = cal_internal_cluster_validation_metrics( latent_collector["clustering_latent"], DATA["test_y"] ) - logger.info(f"{external_metrics}") - logger.info(f"{internal_metrics}") + logger.info(f"CRLI-LSTM: {external_metrics}") + logger.info(f"CRLI-LSTM: {internal_metrics}") @pytest.mark.xdist_group(name="clustering-crli") def test_3_saving_path(self): @@ -95,16 +156,16 @@ def test_3_saving_path(self): ), f"file {self.saving_path} does not exist" # check if the tensorboard file and model checkpoints exist - check_tb_and_model_checkpoints_existence(self.crli) + check_tb_and_model_checkpoints_existence(self.crli_gru) # save the trained model into file, and check if the path exists - self.crli.save_model( + self.crli_gru.save_model( saving_dir=self.saving_path, file_name=self.model_save_name ) # test loading the saved model, not necessary, but need to test saved_model_path = os.path.join(self.saving_path, self.model_save_name) - self.crli.load_model(saved_model_path) + self.crli_gru.load_model(saved_model_path) if __name__ == "__main__": From 0fb3a83f4ee49fc102241d605cd0d49c153cecb9 Mon Sep 17 00:00:00 2001 From: Wenjie Du Date: Thu, 28 Sep 2023 14:52:29 +0800 Subject: [PATCH 6/8] fix: make CRLI work with LSTM cells; --- pypots/clustering/crli/modules.py | 34 +++++++++++++++++++++---------- 1 file changed, 23 insertions(+), 11 deletions(-) diff --git a/pypots/clustering/crli/modules.py b/pypots/clustering/crli/modules.py index d5413e37..f6837647 100644 --- a/pypots/clustering/crli/modules.py +++ b/pypots/clustering/crli/modules.py @@ -65,8 +65,11 @@ def forward(self, inputs: dict) -> Tuple[torch.Tensor, torch.Tensor]: ) output_collector = torch.empty((bz, n_steps, self.d_input), device=self.device) if self.cell_type == "LSTM": - # TODO: cell states should have different shapes - cell_states = torch.zeros((self.d_input, self.d_hidden), device=self.device) + cell_states = [ + torch.zeros((bz, self.d_hidden), device=self.device) + for i in range(self.n_layer) + ] + for step in range(n_steps): x = X[:, step, :] estimation = self.output_layer(hidden_state) @@ -76,13 +79,14 @@ def forward(self, inputs: dict) -> Tuple[torch.Tensor, torch.Tensor]: ) for i in range(self.n_layer): if i == 0: - hidden_state, cell_states = self.model[i]( - imputed_x, (hidden_state, cell_states) + hidden_state, cell_state = self.model[i]( + imputed_x, (hidden_state, cell_states[i]) ) else: - hidden_state, cell_states = self.model[i]( - hidden_state, (hidden_state, cell_states) + hidden_state, cell_state = self.model[i]( + hidden_state, (hidden_state, cell_states[i]) ) + hidden_state_collector[:, step, :] = hidden_state elif self.cell_type == "GRU": @@ -168,19 +172,27 @@ def forward(self, inputs: dict) -> torch.Tensor: ] hidden_state_collector = torch.empty((bz, n_steps, 32), device=self.device) if self.cell_type == "LSTM": - cell_states = torch.zeros((self.d_input, self.d_hidden), device=self.device) + cell_states = [ + torch.zeros((bz, 32), device=self.device), + torch.zeros((bz, 16), device=self.device), + torch.zeros((bz, 8), device=self.device), + torch.zeros((bz, 16), device=self.device), + torch.zeros((bz, 32), device=self.device), + ] for step in range(n_steps): x = imputed_X[:, step, :] for i, rnn_cell in enumerate(self.rnn_cell_module_list): if i == 0: - hidden_state, cell_states = rnn_cell( - x, (hidden_states[i], cell_states) + hidden_state, cell_state = rnn_cell( + x, (hidden_states[i], cell_states[i]) ) else: - hidden_state, cell_states = rnn_cell( - hidden_states[i - 1], (hidden_states[i], cell_states) + hidden_state, cell_state = rnn_cell( + hidden_states[i - 1], (hidden_states[i], cell_states[i]) ) + cell_states[i] = cell_state hidden_states[i] = hidden_state + hidden_state_collector[:, step, :] = hidden_state elif self.cell_type == "GRU": From ffebf371289d6adc37fac79e32430a74eff4f8aa Mon Sep 17 00:00:00 2001 From: Wenjie Du Date: Thu, 28 Sep 2023 16:45:40 +0800 Subject: [PATCH 7/8] docs: update the documentation; --- docs/pypots.forecasting.rst | 25 +------- pypots/classification/brits/model.py | 11 ++-- pypots/classification/grud/model.py | 10 +-- pypots/classification/raindrop/model.py | 12 ++-- pypots/clustering/crli/model.py | 12 ++-- pypots/clustering/vader/model.py | 12 ++-- pypots/forecasting/bttf/model.py | 7 +++ pypots/imputation/brits/model.py | 11 ++-- pypots/imputation/gpvae/model.py | 35 +++++------ pypots/imputation/mrnn/model.py | 12 ++-- pypots/imputation/saits/model.py | 11 ++-- pypots/imputation/transformer/model.py | 17 +++-- pypots/imputation/usgan/model.py | 83 +++++++++---------------- pypots/optim/adadelta.py | 14 +++-- pypots/optim/adagrad.py | 16 ++--- pypots/optim/adam.py | 16 ++--- pypots/optim/adamw.py | 17 ++--- pypots/optim/base.py | 9 ++- pypots/optim/lr_scheduler/__init__.py | 6 +- pypots/optim/lr_scheduler/base.py | 11 ++-- pypots/optim/rmsprop.py | 18 +++--- pypots/optim/sgd.py | 18 +++--- pypots/utils/metrics.py | 53 ++++++++++++++-- 23 files changed, 229 insertions(+), 207 deletions(-) diff --git a/docs/pypots.forecasting.rst b/docs/pypots.forecasting.rst index c4ac76b7..5cd6eaa1 100644 --- a/docs/pypots.forecasting.rst +++ b/docs/pypots.forecasting.rst @@ -1,31 +1,10 @@ pypots.forecasting package ========================== -Subpackages ------------ - -.. toctree:: - :maxdepth: 4 - - pypots.forecasting.bttf - pypots.forecasting.template - -Submodules ----------- - -pypots.forecasting.base module +pypots.forecasting.bttf module ------------------------------ -.. automodule:: pypots.forecasting.base - :members: - :undoc-members: - :show-inheritance: - :inherited-members: - -Module contents ---------------- - -.. automodule:: pypots.forecasting +.. automodule:: pypots.forecasting.bttf :members: :undoc-members: :show-inheritance: diff --git a/pypots/classification/brits/model.py b/pypots/classification/brits/model.py index bbccb7ce..c9419b4a 100644 --- a/pypots/classification/brits/model.py +++ b/pypots/classification/brits/model.py @@ -173,13 +173,12 @@ class BRITS(BaseNNClassifier): The "better" strategy will automatically save the model during training whenever the model performs better than in previous epochs. - Attributes + References ---------- - model : :class:`torch.nn.Module` - The underlying BRITS model. - - optimizer : :class:`pypots.optim.Optimizer` - The optimizer for model training. + .. [1] `Cao, Wei, Dong Wang, Jian Li, Hao Zhou, Lei Li, and Yitan Li. + "Brits: Bidirectional recurrent imputation for time series." + Advances in neural information processing systems 31 (2018). + `_ """ diff --git a/pypots/classification/grud/model.py b/pypots/classification/grud/model.py index a9e4f6e6..7d122d8f 100644 --- a/pypots/classification/grud/model.py +++ b/pypots/classification/grud/model.py @@ -169,13 +169,13 @@ class GRUD(BaseNNClassifier): The "better" strategy will automatically save the model during training whenever the model performs better than in previous epochs. - Attributes + References ---------- - model : :class:`torch.nn.Module` - The underlying GRU-D model. + .. [1] `Che, Zhengping, Sanjay Purushotham, Kyunghyun Cho, David Sontag, and Yan Liu. + "Recurrent neural networks for multivariate time series with missing values." + Scientific reports 8, no. 1 (2018): 6085. + `_ - optimizer : :class:`pypots.optim.Optimizer` - The optimizer for model training. """ def __init__( diff --git a/pypots/classification/raindrop/model.py b/pypots/classification/raindrop/model.py index 75bd1470..06d68e8b 100644 --- a/pypots/classification/raindrop/model.py +++ b/pypots/classification/raindrop/model.py @@ -367,14 +367,12 @@ class Raindrop(BaseNNClassifier): The "better" strategy will automatically save the model during training whenever the model performs better than in previous epochs. - Attributes + References ---------- - model : :class:`torch.nn.Module` - The underlying Raindrop model. - - optimizer : :class:`pypots.optim.Optimizer` - The optimizer for model training. - + .. [1] `Zhang, Xiang, Marko Zeman, Theodoros Tsiligkaridis, and Marinka Zitnik. + "Graph-guided network for irregularly sampled multivariate time series." + International Conference on Learning Representations (ICLR). 2022. + `_ """ def __init__( diff --git a/pypots/clustering/crli/model.py b/pypots/clustering/crli/model.py index 26f0a769..35be0034 100644 --- a/pypots/clustering/crli/model.py +++ b/pypots/clustering/crli/model.py @@ -196,13 +196,13 @@ class CRLI(BaseNNClusterer): The "better" strategy will automatically save the model during training whenever the model performs better than in previous epochs. - Attributes + References ---------- - model : :class:`torch.nn.Module` - The underlying CRLI model. - - optimizer : :class:`pypots.optim.Optimizer` - The optimizer for model training. + .. [1] `Ma, Qianli, Chuxin Chen, Sen Li, and Garrison W. Cottrell. 2021. + "Learning Representations for Incomplete Time Series Clustering". + Proceedings of the AAAI Conference on Artificial Intelligence 35 (10):8837-46. + https://doi.org/10.1609/aaai.v35i10.17070. + `_ """ diff --git a/pypots/clustering/vader/model.py b/pypots/clustering/vader/model.py index 7c85ad13..9e9198f5 100644 --- a/pypots/clustering/vader/model.py +++ b/pypots/clustering/vader/model.py @@ -328,13 +328,15 @@ class VaDER(BaseNNClusterer): The "better" strategy will automatically save the model during training whenever the model performs better than in previous epochs. - Attributes + References ---------- - model : :class:`torch.nn.Module` - The underlying VaDER model. + .. [1] `de Jong, Johann, Mohammad Asif Emon, Ping Wu, Reagon Karki, Meemansa Sood, Patrice Godard, + Ashar Ahmad, Henri Vrooman, Martin Hofmann-Apitius, and Holger Fröhlich. + "Deep learning for clustering of multivariate clinical patient trajectories with missing values." + GigaScience 8, no. 11 (2019): giz134. + `_ + - optimizer : :class:`pypots.optim.Optimizer` - The optimizer for model training. """ diff --git a/pypots/forecasting/bttf/model.py b/pypots/forecasting/bttf/model.py index 500412a9..57712ee1 100644 --- a/pypots/forecasting/bttf/model.py +++ b/pypots/forecasting/bttf/model.py @@ -311,6 +311,13 @@ class BTTF(BaseForecaster): 2). ``n_steps - pred_step`` must be larger than ``max(time_lags)``; + References + ---------- + .. [1] `Chen, Xinyu, and Lijun Sun. + "Bayesian temporal factorization for multidimensional time series prediction." + IEEE Transactions on Pattern Analysis and Machine Intelligence 44, no. 9 (2021): 4659-4673. + `_ + """ def __init__( diff --git a/pypots/imputation/brits/model.py b/pypots/imputation/brits/model.py index 0ce03f97..e53826f3 100644 --- a/pypots/imputation/brits/model.py +++ b/pypots/imputation/brits/model.py @@ -394,13 +394,12 @@ class BRITS(BaseNNImputer): The "better" strategy will automatically save the model during training whenever the model performs better than in previous epochs. - Attributes + References ---------- - model : :class:`torch.nn.Module` - The underlying BRITS model. - - optimizer : :class:`pypots.optim.Optimizer` - The optimizer for model training. + .. [1] `Cao, Wei, Dong Wang, Jian Li, Hao Zhou, Lei Li, and Yitan Li. + "Brits: Bidirectional recurrent imputation for time series." + Advances in neural information processing systems 31 (2018). + `_ """ diff --git a/pypots/imputation/gpvae/model.py b/pypots/imputation/gpvae/model.py index 350ff14e..b38332cc 100644 --- a/pypots/imputation/gpvae/model.py +++ b/pypots/imputation/gpvae/model.py @@ -173,7 +173,6 @@ def forward(self, inputs, training=True): @staticmethod def kl_divergence(a, b): - # TODO: different from the author's implementation return torch.distributions.kl.kl_divergence(a, b) def _init_prior(self): @@ -222,36 +221,36 @@ def _init_prior(self): class GPVAE(BaseNNImputer): - """The PyTorch implementation of the GPVAE model :cite:``. + """The PyTorch implementation of the GPVAE model :cite:`fortuin2020GPVAEDeep`. Parameters ---------- - beta: + beta: float The weight of KL divergence in EBLO. - kernel: + kernel: str The type of kernel function chosen in the Gaussain Process Proir. ["cauchy", "diffusion", "rbf", "matern"] - batch_size : + batch_size : int The batch size for training and evaluating the model. - epochs : + epochs : int The number of epochs for training the model. - patience : + patience : int The patience for the early-stopping mechanism. Given a positive integer, the training process will be stopped when the model does not perform better after that number of epochs. Leaving it default as None will disable the early-stopping. - optimizer : + optimizer : pypots.optim.base.Optimizer The optimizer for model training. If not given, will use a default Adam optimizer. - num_workers : + num_workers : int The number of subprocesses to use for data loading. `0` means data loading will be in the main process, i.e. there won't be subprocesses. - device : + device : :class:`torch.device` or list The device for the model to run on. It can be a string, a :class:`torch.device` object, or a list of them. If not given, will try to use CUDA devices first (will use the default CUDA device if there are multiple), then CPUs, considering CUDA and CPU are so far the main devices for people to train ML models. @@ -259,24 +258,24 @@ class GPVAE(BaseNNImputer): model will be parallely trained on the multiple devices (so far only support parallel training on CUDA devices). Other devices like Google TPU and Apple Silicon accelerator MPS may be added in the future. - saving_path : + saving_path : str The path for automatically saving model checkpoints and tensorboard files (i.e. loss values recorded during training into a tensorboard file). Will not save if not given. - model_saving_strategy : + model_saving_strategy : str The strategy to save model checkpoints. It has to be one of [None, "best", "better"]. No model will be saved when it is set as None. The "best" strategy will only automatically save the best model after the training finished. The "better" strategy will automatically save the model during training whenever the model performs better than in previous epochs. - Attributes + References ---------- - model : :class:`torch.nn.Module` - The underlying GPVAE model. - - optimizer : :class:`pypots.optim.Optimizer` - The optimizer for model training. + .. [1] `Fortuin, V., Baranchuk, D., Raetsch, G. & Mandt, S.. (2020). + "GP-VAE: Deep Probabilistic Time Series Imputation". + Proceedings of the Twenty Third International Conference on Artificial Intelligence and Statistics, + in Proceedings of Machine Learning Research 108:1651-1661 + `_ """ diff --git a/pypots/imputation/mrnn/model.py b/pypots/imputation/mrnn/model.py index 5d50cc32..536a14d2 100644 --- a/pypots/imputation/mrnn/model.py +++ b/pypots/imputation/mrnn/model.py @@ -152,13 +152,13 @@ class MRNN(BaseNNImputer): The "better" strategy will automatically save the model during training whenever the model performs better than in previous epochs. - Attributes + References ---------- - model : :class:`torch.nn.Module` - The underlying BRITS model. - - optimizer : :class:`pypots.optim.Optimizer` - The optimizer for model training. + .. [1] `J. Yoon, W. R. Zame and M. van der Schaar, + "Estimating Missing Data in Temporal Data Streams Using Multi-Directional Recurrent Neural Networks," + in IEEE Transactions on Biomedical Engineering, + vol. 66, no. 5, pp. 1477-1490, May 2019, doi: 10.1109/TBME.2018.2874712. + `_ """ diff --git a/pypots/imputation/saits/model.py b/pypots/imputation/saits/model.py index 85731df7..d336678d 100644 --- a/pypots/imputation/saits/model.py +++ b/pypots/imputation/saits/model.py @@ -280,13 +280,12 @@ class SAITS(BaseNNImputer): The "better" strategy will automatically save the model during training whenever the model performs better than in previous epochs. - Attributes + References ---------- - model : :class:`torch.nn.Module` - The underlying SAITS model. - - optimizer : :class:`pypots.optim.Optimizer` - The optimizer for model training. + .. [1] `Du, Wenjie, David Côté, and Yan Liu. + "Saits: Self-attention-based imputation for time series". + Expert Systems with Applications 219 (2023): 119619. + `_ """ diff --git a/pypots/imputation/transformer/model.py b/pypots/imputation/transformer/model.py index dfc925ad..72b13b41 100644 --- a/pypots/imputation/transformer/model.py +++ b/pypots/imputation/transformer/model.py @@ -201,13 +201,18 @@ class Transformer(BaseNNImputer): The "better" strategy will automatically save the model during training whenever the model performs better than in previous epochs. - Attributes + References ---------- - model : :class:`torch.nn.Module` - The underlying Transformer model. - - optimizer : :class:`pypots.optim.Optimizer` - The optimizer for model training. + .. [1] `Vaswani, Ashish, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N. Gomez, Łukasz Kaiser, + and Illia Polosukhin. + "Attention is all you need." + Advances in neural information processing systems 30 (2017). + `_ + + .. [2] `Du, Wenjie, David Côté, and Yan Liu. + "Saits: Self-attention-based imputation for time series". + Expert Systems with Applications 219 (2023): 119619. + `_ """ diff --git a/pypots/imputation/usgan/model.py b/pypots/imputation/usgan/model.py index c171d810..ff7fd785 100644 --- a/pypots/imputation/usgan/model.py +++ b/pypots/imputation/usgan/model.py @@ -101,33 +101,7 @@ def forward( class _USGAN(nn.Module): - """model USGAN: - USGAN consists of a generator, a discriminator, which are all built on bidirectional recurrent neural networks. - - Attributes - ---------- - n_steps : - sequence length (number of time steps) - - n_features : - number of features (input dimensions) - - rnn_hidden_size : - the hidden size of the RNN cell - - lambda_mse : - the weigth of the reconstruction loss - - hint_rate : - the hint rate for the discriminator - - dropout_rate : - the dropout rate for the last layer in Discriminator - - device : - specify running the model on which device, CPU/GPU - - """ + """USGAN model""" def __init__( self, @@ -192,58 +166,58 @@ def forward( class USGAN(BaseNNImputer): - """The PyTorch implementation of the CRLI model :cite:`ma2021CRLI`. + """The PyTorch implementation of the USGAN model. Refer to :cite:`miao2021SSGAN`. Parameters ---------- - n_steps : + n_steps : int The number of time steps in the time-series data sample. - n_features : + n_features : int The number of features in the time-series data sample. - rnn_hidden_size : - the hidden size of the RNN cell + rnn_hidden_size : int + The hidden size of the RNN cell - lambda_mse : - the weight of the reconstruction loss + lambda_mse : float + The weight of the reconstruction loss - hint_rate : - the hint rate for the discriminator + hint_rate : float + The hint rate for the discriminator - dropout_rate : - the dropout rate for the last layer in Discriminator + dropout_rate : float + The dropout rate for the last layer in Discriminator - G_steps : + G_steps : int The number of steps to train the generator in each iteration. - D_steps : + D_steps : int The number of steps to train the discriminator in each iteration. - batch_size : + batch_size : int The batch size for training and evaluating the model. - epochs : + epochs : int The number of epochs for training the model. - patience : + patience : int The patience for the early-stopping mechanism. Given a positive integer, the training process will be stopped when the model does not perform better after that number of epochs. Leaving it default as None will disable the early-stopping. - G_optimizer : + G_optimizer : :class:`pypots.optim.Optimizer` The optimizer for the generator training. If not given, will use a default Adam optimizer. - D_optimizer : + D_optimizer : :class:`pypots.optim.Optimizer` The optimizer for the discriminator training. If not given, will use a default Adam optimizer. - num_workers : + num_workers : int The number of subprocesses to use for data loading. `0` means data loading will be in the main process, i.e. there won't be subprocesses. - device : + device : Union[str, torch.device, list] The device for the model to run on. It can be a string, a :class:`torch.device` object, or a list of them. If not given, will try to use CUDA devices first (will use the default CUDA device if there are multiple), then CPUs, considering CUDA and CPU are so far the main devices for people to train ML models. @@ -251,24 +225,23 @@ class USGAN(BaseNNImputer): model will be parallely trained on the multiple devices (so far only support parallel training on CUDA devices). Other devices like Google TPU and Apple Silicon accelerator MPS may be added in the future. - saving_path : + saving_path : str The path for automatically saving model checkpoints and tensorboard files (i.e. loss values recorded during training into a tensorboard file). Will not save if not given. - model_saving_strategy : + model_saving_strategy : str The strategy to save model checkpoints. It has to be one of [None, "best", "better"]. No model will be saved when it is set as None. The "best" strategy will only automatically save the best model after the training finished. The "better" strategy will automatically save the model during training whenever the model performs better than in previous epochs. - Attributes + References ---------- - model : :class:`torch.nn.Module` - The underlying CRLI model. - - optimizer : :class:`pypots.optim.Optimizer` - The optimizer for model training. + .. [1] `Miao, Xiaoye, Yangyang Wu, Jun Wang, Yunjun Gao, Xudong Mao, and Jianwei Yin. 2021. + "Generative Semi-Supervised Learning for Multivariate Time Series Imputation". + Proceedings of the AAAI Conference on Artificial Intelligence 35 (10):8983-91. + `_ """ diff --git a/pypots/optim/adadelta.py b/pypots/optim/adadelta.py index 59e98f2a..4ff0037d 100644 --- a/pypots/optim/adadelta.py +++ b/pypots/optim/adadelta.py @@ -15,23 +15,25 @@ class Adadelta(Optimizer): - """The optimizer wrapper for PyTorch Adadelta. - https://pytorch.org/docs/stable/generated/torch.optim.Adadelta.html#torch.optim.Adadelta + """The optimizer wrapper for PyTorch Adadelta :class:`torch.optim.Adadelta`. Parameters ---------- - lr : + lr : float The learning rate of the optimizer. - rho : + rho : float Coefficient used for computing a running average of squared gradients. - eps : + eps : float Term added to the denominator to improve numerical stability. - weight_decay : + weight_decay : float Weight decay (L2 penalty). + lr_scheduler : pypots.optim.lr_scheduler.base.LRScheduler + The learning rate scheduler of the optimizer. + """ def __init__( diff --git a/pypots/optim/adagrad.py b/pypots/optim/adagrad.py index 8a10f06c..b25efbc1 100644 --- a/pypots/optim/adagrad.py +++ b/pypots/optim/adagrad.py @@ -15,26 +15,28 @@ class Adagrad(Optimizer): - """The optimizer wrapper for PyTorch Adagrad. - https://pytorch.org/docs/stable/generated/torch.optim.Adagrad.html#torch.optim.Adagrad + """The optimizer wrapper for PyTorch Adagrad :class:`torch.optim.Adagrad`. Parameters ---------- - lr : + lr : float The learning rate of the optimizer. - lr_decay : + lr_decay : float Learning rate decay. - weight_decay : + weight_decay : float Weight decay (L2 penalty). - eps : + eps : float Term added to the denominator to improve numerical stability. - initial_accumulator_value : + initial_accumulator_value : float A floating point value. Starting value for the accumulators, must be positive. + lr_scheduler : pypots.optim.lr_scheduler.base.LRScheduler + The learning rate scheduler of the optimizer. + """ def __init__( diff --git a/pypots/optim/adam.py b/pypots/optim/adam.py index c5e0e1af..9817b50e 100644 --- a/pypots/optim/adam.py +++ b/pypots/optim/adam.py @@ -15,25 +15,27 @@ class Adam(Optimizer): - """The optimizer wrapper for PyTorch Adam. - https://pytorch.org/docs/stable/generated/torch.optim.Adam.html#torch.optim.Adam + """The optimizer wrapper for PyTorch Adam :class:`torch.optim.Adam`. Parameters ---------- - lr : + lr : float The learning rate of the optimizer. - betas : + betas : Tuple[float, float] Coefficients used for computing running averages of gradient and its square. - eps : + eps : float Term added to the denominator to improve numerical stability. - weight_decay : + weight_decay : float Weight decay (L2 penalty). - amsgrad : + amsgrad : bool Whether to use the AMSGrad variant of this algorithm from the paper :cite:`reddi2018OnTheConvergence`. + + lr_scheduler : pypots.optim.lr_scheduler.base.LRScheduler + The learning rate scheduler of the optimizer. """ def __init__( diff --git a/pypots/optim/adamw.py b/pypots/optim/adamw.py index 6a5191e4..26887b2c 100644 --- a/pypots/optim/adamw.py +++ b/pypots/optim/adamw.py @@ -15,25 +15,28 @@ class AdamW(Optimizer): - """The optimizer wrapper for PyTorch AdamW. - https://pytorch.org/docs/stable/generated/torch.optim.AdamW.html#torch.optim.AdamW + """The optimizer wrapper for PyTorch AdamW :class:`torch.optim.AdamW`. Parameters ---------- - lr : + lr : float The learning rate of the optimizer. - betas : + betas : Tuple[float, float] Coefficients used for computing running averages of gradient and its square. - eps : + eps : float Term added to the denominator to improve numerical stability. - weight_decay : + weight_decay : float Weight decay (L2 penalty). - amsgrad : + amsgrad : bool Whether to use the AMSGrad variant of this algorithm from the paper :cite:`reddi2018OnTheConvergence`. + + lr_scheduler : pypots.optim.lr_scheduler.base.LRScheduler + The learning rate scheduler of the optimizer. + """ def __init__( diff --git a/pypots/optim/base.py b/pypots/optim/base.py index db09fb3a..6a57ab7e 100644 --- a/pypots/optim/base.py +++ b/pypots/optim/base.py @@ -23,13 +23,16 @@ class Optimizer(ABC): - """The base wrapper for PyTorch optimizers, also is the base class for all optimizers in pypots.optim. + """The base wrapper for PyTorch optimizers, also is the base class for all optimizers in PyPOTS. Parameters ---------- - lr : + lr : float The learning rate of the optimizer. + lr_scheduler : pypots.optim.lr_scheduler.base.LRScheduler + The learning rate scheduler of the optimizer. + Attributes ---------- torch_optimizer : @@ -95,7 +98,7 @@ def step(self, closure: Optional[Callable] = None) -> None: ---------- closure : A closure that reevaluates the model and returns the loss. Optional for most optimizers. - Refer to the torch.optim.Optimizer.step() docstring for more details. + Refer to the :class:`torch.optim.Optimizer.step()` docstring for more details. """ self.torch_optimizer.step(closure) diff --git a/pypots/optim/lr_scheduler/__init__.py b/pypots/optim/lr_scheduler/__init__.py index ddb14350..89015847 100644 --- a/pypots/optim/lr_scheduler/__init__.py +++ b/pypots/optim/lr_scheduler/__init__.py @@ -1,9 +1,9 @@ """ Learning rate schedulers available in PyPOTS. Their functionalities are the same with those in PyTorch, the only difference that is also why we implement them is that you don't have to pass according optimizers -into them immediately while initializing them. Instead, you can pass them into pypots.optim.Optimizer -after initialization and call their `init_scheduler()` method in pypots.optim.Optimizer.init_optimizer() to initialize -schedulers together with optimizers. +into them immediately while initializing them. Instead, you can pass them into :class:`pypots.optim.base.Optimizer` +after initialization and call their `init_scheduler()` method in :class:`pypots.optim.base.Optimizer.init_optimizer()` +to initialize schedulers together with optimizers. """ # Created by Wenjie Du diff --git a/pypots/optim/lr_scheduler/base.py b/pypots/optim/lr_scheduler/base.py index 0aeffd8b..9c787ae7 100644 --- a/pypots/optim/lr_scheduler/base.py +++ b/pypots/optim/lr_scheduler/base.py @@ -37,12 +37,12 @@ def __init__(self, last_epoch=-1, verbose=False): self._step_count = 0 def init_scheduler(self, optimizer): - """Initialize the scheduler. This method should be called in pypots.optim.Optimizer.init_optimizer() - to initialize the scheduler together with the optimizer. + """Initialize the scheduler. This method should be called in + :class:`pypots.optim.base.Optimizer.init_optimizer()` to initialize the scheduler together with the optimizer. Parameters ---------- - optimizer: torch.optim.Optimizer, + optimizer: torch.optim.Optimizer The optimizer to be scheduled. """ @@ -113,8 +113,9 @@ def print_lr(is_verbose, group, lr): logger.info(f"Adjusting learning rate of group {group} to {lr:.4e}.") def step(self): - """Step could be called after every batch update. This should be called in ``pypots.optim.Optimizer.step()`` - after ``pypots.optim.Optimizer.torch_optimizer.step()``. + """Step could be called after every batch update. + This should be called in :class:`pypots.optim.base.Optimizer.step()` after + :class:`pypots.optim.base.Optimizer.torch_optimizer.step()`. """ # Raise a warning if old pattern is detected # https://github.com/pytorch/pytorch/issues/20124 diff --git a/pypots/optim/rmsprop.py b/pypots/optim/rmsprop.py index f00da68d..9451c0a4 100644 --- a/pypots/optim/rmsprop.py +++ b/pypots/optim/rmsprop.py @@ -15,29 +15,31 @@ class RMSprop(Optimizer): - """The optimizer wrapper for PyTorch RMSprop. - https://pytorch.org/docs/stable/generated/torch.optim.RMSprop.html#torch.optim.RMSprop + """The optimizer wrapper for PyTorch RMSprop :class:`torch.optim.RMSprop`. Parameters ---------- - lr : + lr : float The learning rate of the optimizer. - momentum : + momentum : float Momentum factor. - alpha : + alpha : float Smoothing constant. - eps : + eps : float Term added to the denominator to improve numerical stability. - centered : + centered : bool If True, compute the centered RMSProp, the gradient is normalized by an estimation of its variance - weight_decay : + weight_decay : float Weight decay (L2 penalty). + lr_scheduler : pypots.optim.lr_scheduler.base.LRScheduler + The learning rate scheduler of the optimizer. + """ def __init__( diff --git a/pypots/optim/sgd.py b/pypots/optim/sgd.py index 34cd07f0..b31baf5f 100644 --- a/pypots/optim/sgd.py +++ b/pypots/optim/sgd.py @@ -1,5 +1,5 @@ """ -The optimizer wrapper for PyTorch SGD. +The optimizer wrapper for PyTorch SGD :class:`torch.optim.SGD`. """ @@ -15,26 +15,28 @@ class SGD(Optimizer): - """The optimizer wrapper for PyTorch SGD. - https://pytorch.org/docs/stable/generated/torch.optim.SGD.html#torch.optim.SGD + """The optimizer wrapper for PyTorch SGD :class:`torch.optim.SGD`. Parameters ---------- - lr : + lr : float The learning rate of the optimizer. - momentum : + momentum : float Momentum factor. - weight_decay : + weight_decay : float Weight decay (L2 penalty). - dampening : + dampening : float Dampening for momentum. - nesterov : + nesterov : bool Whether to enable Nesterov momentum. + lr_scheduler : pypots.optim.lr_scheduler.base.LRScheduler + The learning rate scheduler of the optimizer. + """ def __init__( diff --git a/pypots/utils/metrics.py b/pypots/utils/metrics.py index ac239648..1327ee7a 100644 --- a/pypots/utils/metrics.py +++ b/pypots/utils/metrics.py @@ -497,6 +497,15 @@ def cal_rand_index( RI : Rand index. + References + ---------- + .. L. Hubert and P. Arabie, Comparing Partitions, Journal of + Classification 1985 + https://link.springer.com/article/10.1007%2FBF01908075 + + .. https://en.wikipedia.org/wiki/Simple_matching_coefficient + + .. https://en.wikipedia.org/wiki/Rand_index """ # # detailed implementation # n = len(targets) @@ -523,7 +532,7 @@ def cal_adjusted_rand_index( class_predictions: np.ndarray, targets: np.ndarray, ) -> float: - """Calculate adjusted Rand Index. Refer to :cite:`hubert1985AdjustedRI`. + """Calculate adjusted Rand Index. Parameters ---------- @@ -538,6 +547,17 @@ def cal_adjusted_rand_index( aRI : Adjusted Rand index. + References + ---------- + .. [Hubert1985] L. Hubert and P. Arabie, Comparing Partitions, + Journal of Classification 1985 + https://link.springer.com/article/10.1007%2FBF01908075 + + .. [Steinley2004] D. Steinley, Properties of the Hubert-Arabie + adjusted Rand index, Psychological Methods 2004 + + .. [wk] https://en.wikipedia.org/wiki/Rand_index#Adjusted_Rand_index + """ aRI = metrics.adjusted_rand_score(targets, class_predictions) return aRI @@ -644,7 +664,17 @@ def cal_silhouette(X: np.ndarray, predicted_labels: np.ndarray) -> float: Returns ------- silhouette_score : float - Mean Silhouette Coefficient for all samples. + Mean Silhouette Coefficient for all samples. In short, the higher, the better. + + References + ---------- + .. [1] `Peter J. Rousseeuw (1987). "Silhouettes: a Graphical Aid to the + Interpretation and Validation of Cluster Analysis". Computational + and Applied Mathematics 20: 53-65. + `_ + + .. [2] `Wikipedia entry on the Silhouette Coefficient + `_ """ silhouette_score = metrics.silhouette_score(X, predicted_labels) @@ -659,10 +689,17 @@ def cal_chs(X: np.ndarray, predicted_labels: np.ndarray) -> float: predicted_labels : array-like of shape (n_samples) Predicted labels for each sample. + Returns ------- calinski_harabasz_score : float - The resulting Calinski-Harabasz score. + The resulting Calinski-Harabasz score. In short, the higher, the better. + + References + ---------- + .. [1] `T. Calinski and J. Harabasz, 1974. "A dendrite method for cluster + analysis". Communications in Statistics + `_ """ calinski_harabasz_score = metrics.calinski_harabasz_score(X, predicted_labels) @@ -683,7 +720,15 @@ def cal_dbs(X: np.ndarray, predicted_labels: np.ndarray) -> float: Returns ------- davies_bouldin_score : float - The resulting Davies-Bouldin score. + The resulting Davies-Bouldin score. In short, the lower, the better. + + References + ---------- + .. [1] Davies, David L.; Bouldin, Donald W. (1979). + `"A Cluster Separation Measure" + `__. + IEEE Transactions on Pattern Analysis and Machine Intelligence. + PAMI-1 (2): 224-227 """ davies_bouldin_score = metrics.davies_bouldin_score(X, predicted_labels) From c426cb23974157fe64e04990b2450541d2e534ee Mon Sep 17 00:00:00 2001 From: Wenjie Du Date: Thu, 28 Sep 2023 19:14:43 +0800 Subject: [PATCH 8/8] refactor: rename gene_incomplete_random_walk_dataset() to gene_random_walk(); --- pypots/data/__init__.py | 10 +-- pypots/data/generating.py | 130 ++++++++++++++++++++---------------- tests/global_test_config.py | 10 ++- 3 files changed, 86 insertions(+), 64 deletions(-) diff --git a/pypots/data/__init__.py b/pypots/data/__init__.py index a3a68be9..dc1bfbf8 100644 --- a/pypots/data/__init__.py +++ b/pypots/data/__init__.py @@ -8,8 +8,9 @@ from .base import BaseDataset from .generating import ( gene_complete_random_walk, - gene_random_walk_for_classification, - gene_incomplete_random_walk_dataset, + gene_complete_random_walk_for_anomaly_detection, + gene_complete_random_walk_for_classification, + gene_random_walk, gene_physionet2012, ) from .load_specific_datasets import ( @@ -29,8 +30,9 @@ "BaseDataset", # data generation "gene_complete_random_walk", - "gene_random_walk_for_classification", - "gene_incomplete_random_walk_dataset", + "gene_complete_random_walk_for_anomaly_detection", + "gene_complete_random_walk_for_classification", + "gene_random_walk", "gene_physionet2012", # list and load datasets "list_supported_datasets", diff --git a/pypots/data/generating.py b/pypots/data/generating.py index e80efe49..f0a20473 100644 --- a/pypots/data/generating.py +++ b/pypots/data/generating.py @@ -26,26 +26,26 @@ def gene_complete_random_walk( std: float = 1.0, random_state: Optional[int] = None, ) -> np.ndarray: - """Generate complete random walk time-series data. + """Generate complete random walk time-series data, i.e. having no missing values. Parameters ---------- - n_samples : + n_samples : int, default=1000 The number of training time-series samples to generate. n_steps: int, default=24 The number of time steps (length) of generated time-series samples. - n_features : + n_features : int, default=10 The number of features (dimensions) of generated time-series samples. - mu : + mu : float, default=0.0 Mean of the normal distribution, which random walk steps are sampled from. - std : + std : float, default=1.0 Standard deviation of the normal distribution, which random walk steps are sampled from. - random_state : + random_state : int, default=None Random seed for data generation. Returns @@ -63,7 +63,7 @@ def gene_complete_random_walk( return ts_samples -def gene_random_walk_for_classification( +def gene_complete_random_walk_for_classification( n_classes: int = 2, n_samples_each_class: int = 500, n_steps: int = 24, @@ -75,37 +75,39 @@ def gene_random_walk_for_classification( Parameters ---------- - n_classes : + n_classes : int, must >=1, default=2 Number of classes (types) of the generated data. - n_samples_each_class : + n_samples_each_class : int, default=500 Number of samples for each class to generate. - n_steps : + n_steps : int, default=24 Number of time steps in each sample. - n_features : + n_features : int, default=10 Number of features. - shuffle : + shuffle : bool, default=True Whether to shuffle generated samples. If not, you can separate samples of each class according to `n_samples_each_class`. For example, X_class0=X[:n_samples_each_class], X_class1=X[n_samples_each_class:n_samples_each_class*2] - random_state : + random_state : int, default=None Random seed for data generation. Returns ------- - X : + X : array, shape of [n_samples, n_steps, n_features] Generated time-series data. - y : + y : array, shape of [n_samples] Labels indicating classes of time-series samples. """ + assert n_classes > 1, f"n_classes should be >1, but got {n_classes}" + ts_collector = [] label_collector = [] @@ -149,39 +151,39 @@ def gene_complete_random_walk_for_anomaly_detection( Parameters ---------- - n_samples : + n_samples : int, default=1000 The number of training time-series samples to generate. - n_features : + n_features : int, default=10 The number of features (dimensions) of generated time-series samples. n_steps: int, default=24 The number of time steps (length) of generated time-series samples. - mu : + mu : float, default=0.0 Mean of the normal distribution, which random walk steps are sampled from. - std : + std : float, default=1.0 Standard deviation of the normal distribution, which random walk steps are sampled from. - anomaly_proportion : + anomaly_proportion : float, default=0.1 Proportion of anomaly samples in all samples. - anomaly_fraction : + anomaly_fraction : float, default=0.02 Fraction of anomaly points in each anomaly sample. - anomaly_scale_factor : + anomaly_scale_factor : float, default=2.0 Scale factor for value scaling to create anomaly points in time series samples. - random_state : + random_state : int, default=None Random seed for data generation. Returns ------- - X : + X : array, shape of [n_samples, n_steps, n_features] Generated time-series data. - y : + y : array, shape of [n_samples] Labels indicating if time-series samples are anomalies. """ assert ( @@ -225,35 +227,41 @@ def gene_complete_random_walk_for_anomaly_detection( return X, y -def gene_incomplete_random_walk_dataset( - n_steps=24, n_features=10, n_classes=2, n_samples_each_class=1000, missing_rate=0.1 +def gene_random_walk( + n_steps=24, + n_features=10, + n_classes=2, + n_samples_each_class=1000, + missing_rate=0.1, ) -> dict: """Generate a random-walk data. Parameters ---------- - n_steps : + n_steps : int, default=24 Number of time steps in each sample. - n_features : + n_features : int, default=10 Number of features. - n_classes : + n_classes : int, default=2 Number of classes (types) of the generated data. - n_samples_each_class : + n_samples_each_class : int, default=1000 Number of samples for each class to generate. - missing_rate : - The rate of randomly missing values to generate. + missing_rate : float, default=0.1 + The rate of randomly missing values to generate, should be in [0,1). Returns ------- data: dict, A dictionary containing the generated data. """ + assert 0 <= missing_rate < 1, "missing_rate must be in [0,1)" + # generate samples - X, y = gene_random_walk_for_classification( + X, y = gene_complete_random_walk_for_classification( n_classes=n_classes, n_samples_each_class=n_samples_each_class, n_steps=n_steps, @@ -262,12 +270,14 @@ def gene_incomplete_random_walk_dataset( # split into train/val/test sets train_X, test_X, train_y, test_y = train_test_split(X, y, test_size=0.2) train_X, val_X, train_y, val_y = train_test_split(train_X, train_y, test_size=0.2) - # create random missing values - _, train_X, missing_mask, _ = mcar(train_X, missing_rate) - train_X = masked_fill(train_X, 1 - missing_mask, torch.nan) - _, val_X, missing_mask, _ = mcar(val_X, missing_rate) - val_X = masked_fill(val_X, 1 - missing_mask, torch.nan) - # test set is left to mask after normalization + + if missing_rate > 0: + # create random missing values + _, train_X, missing_mask, _ = mcar(train_X, missing_rate) + train_X = masked_fill(train_X, 1 - missing_mask, torch.nan) + _, val_X, missing_mask, _ = mcar(val_X, missing_rate) + val_X = masked_fill(val_X, 1 - missing_mask, torch.nan) + # test set is left to mask after normalization train_X = train_X.reshape(-1, n_features) val_X = val_X.reshape(-1, n_features) @@ -281,19 +291,6 @@ def gene_incomplete_random_walk_dataset( train_X = train_X.reshape(-1, n_steps, n_features) val_X = val_X.reshape(-1, n_steps, n_features) test_X = test_X.reshape(-1, n_steps, n_features) - - # mask values in the validation set as ground truth - val_X_intact, val_X, val_X_missing_mask, val_X_indicating_mask = mcar( - val_X, missing_rate - ) - val_X = masked_fill(val_X, 1 - val_X_missing_mask, torch.nan) - - # mask values in the test set as ground truth - test_X_intact, test_X, test_X_missing_mask, test_X_indicating_mask = mcar( - test_X, 0.3 - ) - test_X = masked_fill(test_X, 1 - test_X_missing_mask, torch.nan) - data = { "n_classes": n_classes, "n_steps": n_steps, @@ -302,13 +299,30 @@ def gene_incomplete_random_walk_dataset( "train_y": train_y, "val_X": val_X, "val_y": val_y, - "val_X_intact": val_X_intact, - "val_X_indicating_mask": val_X_indicating_mask, "test_X": test_X, "test_y": test_y, - "test_X_intact": test_X_intact, - "test_X_indicating_mask": test_X_indicating_mask, + "scaler": scaler, } + + if missing_rate > 0: + # mask values in the validation set as ground truth + val_X_intact, val_X, val_X_missing_mask, val_X_indicating_mask = mcar( + val_X, missing_rate + ) + val_X = masked_fill(val_X, 1 - val_X_missing_mask, torch.nan) + + # mask values in the test set as ground truth + test_X_intact, test_X, test_X_missing_mask, test_X_indicating_mask = mcar( + test_X, 0.3 + ) + test_X = masked_fill(test_X, 1 - test_X_missing_mask, torch.nan) + + data["val_X"] = val_X + data["val_X_intact"] = val_X_intact + data["val_X_indicating_mask"] = val_X_indicating_mask + data["test_X"] = test_X + data["test_X_intact"] = test_X_intact + data["test_X_indicating_mask"] = test_X_indicating_mask return data @@ -317,7 +331,7 @@ def gene_physionet2012(artificially_missing_rate: float = 0.1): Parameters ---------- - artificially_missing_rate : + artificially_missing_rate : float, default=0.1 The rate of artificially missing values to generate for model evaluation. This ratio is calculated based on the number of observed values, i.e. if artificially_missing_rate = 0.1, then 10% of the observed values will be randomly masked as missing data and hold out for model evaluation. diff --git a/tests/global_test_config.py b/tests/global_test_config.py index 5e152734..62ad73bb 100644 --- a/tests/global_test_config.py +++ b/tests/global_test_config.py @@ -9,12 +9,18 @@ import torch -from pypots.data.generating import gene_incomplete_random_walk_dataset +from pypots.data.generating import gene_random_walk from pypots.utils.logging import logger # Generate the unified data for testing and cache it first, DATA here is a singleton # Otherwise, file lock will cause bug if running test parallely with pytest-xdist. -DATA = gene_incomplete_random_walk_dataset() +DATA = gene_random_walk( + n_steps=24, + n_features=10, + n_classes=2, + n_samples_each_class=1000, + missing_rate=0.1, +) # The directory for saving the dataset into files for testing DATA_SAVING_DIR = "h5data_for_tests"