From 9195d120be822a46b27b67e95dd2a689171535fe Mon Sep 17 00:00:00 2001
From: ancestor-mithril
Date: Thu, 5 Dec 2024 21:42:23 +0200
Subject: [PATCH 1/2] Moved Batch Size Managers into a separate file

Also renamed clip arguments as they were shadowing built-in functions
---
 bs_scheduler/__init__.py              |  10 +--
 bs_scheduler/batch_size_manager.py    |  90 ++++++++++++++++++++
 bs_scheduler/batch_size_schedulers.py | 117 +++----------------------
 bs_scheduler/utils.py                 |   3 +
 4 files changed, 111 insertions(+), 109 deletions(-)
 create mode 100644 bs_scheduler/batch_size_manager.py
 create mode 100644 bs_scheduler/utils.py

diff --git a/bs_scheduler/__init__.py b/bs_scheduler/__init__.py
index a9d92c0..ab37718 100644
--- a/bs_scheduler/__init__.py
+++ b/bs_scheduler/__init__.py
@@ -1,12 +1,10 @@
+from .batch_size_manager import BatchSizeManager, DefaultBatchSizeManager, CustomBatchSizeManager
 from .batch_size_schedulers import LambdaBS, MultiplicativeBS, StepBS, MultiStepBS, ConstantBS, LinearBS, ExponentialBS, \
     SequentialBS, PolynomialBS, CosineAnnealingBS, ChainedBSScheduler, IncreaseBSOnPlateau, CyclicBS, \
-    CosineAnnealingBSWithWarmRestarts, OneCycleBS, BSScheduler, BatchSizeManager
+    CosineAnnealingBSWithWarmRestarts, OneCycleBS, BSScheduler
 
-# We do not export DefaultBatchSizeManager and CustomBatchSizeManager because they are not needed. Users with custom
-# setups can create their own batch size managers.
 __all__ = ['LambdaBS', 'MultiplicativeBS', 'StepBS', 'MultiStepBS', 'ConstantBS', 'LinearBS', 'ExponentialBS',
            'SequentialBS', 'PolynomialBS', 'CosineAnnealingBS', 'ChainedBSScheduler', 'IncreaseBSOnPlateau', 'CyclicBS',
-           'CosineAnnealingBSWithWarmRestarts', 'OneCycleBS', 'BSScheduler', 'BatchSizeManager']
-
-del batch_size_schedulers  # noqa: F821
+           'CosineAnnealingBSWithWarmRestarts', 'OneCycleBS', 'BSScheduler', 'BatchSizeManager',
+           'DefaultBatchSizeManager', 'CustomBatchSizeManager']
 
diff --git a/bs_scheduler/batch_size_manager.py b/bs_scheduler/batch_size_manager.py
new file mode 100644
index 0000000..e86b76c
--- /dev/null
+++ b/bs_scheduler/batch_size_manager.py
@@ -0,0 +1,90 @@
+from torch.utils.data import DataLoader, Dataset
+
+from .utils import check_isinstance
+
+
+class BatchSizeManager:
+    """ Base class for all batch size managers, used for getting and setting the batch size. It is not mandatory to
+    inherit from this, but users must implement :meth:`get_current_batch_size` and :meth:`set_batch_size`.
+    """
+
+    def get_current_batch_size(self) -> int:
+        """ Returns the current batch size used by the dataloader as an :class:`int`.
+        """
+        raise NotImplementedError
+
+    def set_batch_size(self, new_bs: int):
+        """ Sets the new value of the batch size.
+
+        Args:
+            new_bs (int): The new batch size that needs to be set.
+        """
+        raise NotImplementedError
+
+
+class DefaultBatchSizeManager(BatchSizeManager):
+    """ The default batch size manager used when the dataloader has a batch sampler. The batch sampler controls the
+    batch size used by the dataloader, and it can be queried and changed. Changes are reflected in the number of samples
+    given to the dataloader. See
+    https://github.com/pytorch/pytorch/blob/772e104dfdfd70c74cbc9600cfc946dc7c378f68/torch/utils/data/sampler.py#L241.
+    """
+
+    def __init__(self, dataloader: DataLoader):
+        check_isinstance(dataloader, DataLoader)
+        if dataloader.batch_sampler is None:
+            raise ValueError("Dataloader must have a batch sampler.")
+        self.dataloader: DataLoader = dataloader
+
+    def get_current_batch_size(self) -> int:
+        """ Returns the current batch size used by the dataloader as an :class:`int`. The batch size member variable is
+        owned by the batch sampler.
+        """
+        return self.dataloader.batch_sampler.batch_size
+
+    def set_batch_size(self, new_bs: int):
+        """ Sets the new value of the batch size, which is owned by the batch sampler.
+
+        Args:
+            new_bs (int): The new batch size that needs to be set.
+        """
+        self.dataloader.batch_sampler.batch_size = new_bs
+
+
+class CustomBatchSizeManager(BatchSizeManager):
+    """ Custom batch size manager, used when the dataloader does not use a batch sampler. In this case, the batch size
+    is controlled by the dataset wrapped by the dataloader, so this class expects the dataset to provide a getter and
+    a setter for the batch size, named :meth:`get_batch_size` and :meth:`change_batch_size` respectively.
+    """
+
+    def __init__(self, dataset: Dataset):
+        check_isinstance(dataset, Dataset)
+        if not hasattr(dataset, 'change_batch_size'):
+            raise KeyError("Because the dataloader does not have a batch sampler, the dataset owns and controls the "
+                           "batch size. In order to change the batch size after dataloader creation we require our "
+                           "users to implement a Callable[[int],None] method named `change_batch_size` in their "
+                           "dataset which changes the batch size. Please see TODO for examples.")
+        if not hasattr(dataset, 'get_batch_size'):
+            raise KeyError("We require our users to implement a Callable[[], int] method named `get_batch_size` in "
+                           "their dataset which returns the current batch size. Please see TODO for examples.")
+        self.dataset = dataset
+
+    def get_current_batch_size(self) -> int:
+        """ Returns the current batch size used by the dataset as an :class:`int`.
+
+        In this case, the dataset controls the batch size, so we require our users to implement a
+        :class:`Callable[[], int]` method named :meth:`get_batch_size` in their dataset which returns the current value
+        of the batch size.
+        """
+        return self.dataset.get_batch_size()
+
+    def set_batch_size(self, new_bs: int):
+        """ Sets the new value of the batch size.
+
+        In this case, the dataset controls the batch size, so we require our users to implement a
+        :class:`Callable[[int],None]` method named :meth:`change_batch_size` in their dataset which modifies the batch
+        size to the given value.
+
+        Args:
+            new_bs (int): The new batch size that needs to be set.
+        """
+        self.dataset.change_batch_size(new_bs)
diff --git a/bs_scheduler/batch_size_schedulers.py b/bs_scheduler/batch_size_schedulers.py
index 6f379a5..aaaa922 100644
--- a/bs_scheduler/batch_size_schedulers.py
+++ b/bs_scheduler/batch_size_schedulers.py
@@ -8,11 +8,14 @@
 from typing import Callable, Union, Sequence, Tuple
 
 import torch
-from torch.utils.data import DataLoader, Dataset
+from torch.utils.data import DataLoader
 
 __all__ = ['LambdaBS', 'MultiplicativeBS', 'StepBS', 'MultiStepBS', 'ConstantBS', 'LinearBS', 'ExponentialBS',
            'SequentialBS', 'PolynomialBS', 'CosineAnnealingBS', 'ChainedBSScheduler', 'IncreaseBSOnPlateau', 'CyclicBS',
-           'CosineAnnealingBSWithWarmRestarts', 'OneCycleBS', 'BSScheduler', 'BatchSizeManager']
+           'CosineAnnealingBSWithWarmRestarts', 'OneCycleBS', 'BSScheduler']
+
+from .batch_size_manager import BatchSizeManager, DefaultBatchSizeManager, CustomBatchSizeManager
+from .utils import check_isinstance
 
 
 def rint(x: float) -> int:
@@ -21,108 +24,16 @@
     return int(round(x))
 
 
-def clip(x: int, min: int, max: int) -> int:
-    """ Clips x to [min, max] interval.
+def clip(x: int, min_x: int, max_x: int) -> int:
+    """ Clips x to [min_x, max_x] interval.
     """
-    if x < min:
-        return min
-    if x > max:
-        return max
+    if x < min_x:
+        return min_x
+    if x > max_x:
+        return max_x
     return x
 
 
-def check_isinstance(x, instance: type):
-    if not isinstance(x, instance):
-        raise TypeError(f"{type(x).__name__} is not a {instance.__name__}.")
-
-
-class BatchSizeManager:
-    """ Base class for all batch size managers, used for getting and setting the batch size. It is not mandatory to
-    inherit from this, but users must implement :meth:`get_current_batch_size` and :meth:`set_batch_size`.
-    """
-
-    def get_current_batch_size(self) -> int:
-        """ Returns the current batch size used by the dataloader as an :class:`int`.
-        """
-        raise NotImplementedError
-
-    def set_batch_size(self, new_bs: int):
-        """ Sets the new value of the batch size.
-
-        Args:
-            new_bs (int): The new batch sizes that needs to be set.
-        """
-        raise NotImplementedError
-
-
-class DefaultBatchSizeManager(BatchSizeManager):
-    """ The default batch size manager used when the dataloader has a batch sampler. The batch sampler controls the
-    batch size used by the dataloader, and it can be queried and changed. Changes are reflected in the number of samples
-    given to the dataloader. See
-    https://github.com/pytorch/pytorch/blob/772e104dfdfd70c74cbc9600cfc946dc7c378f68/torch/utils/data/sampler.py#L241.
-    """
-
-    def __init__(self, dataloader: DataLoader):
-        check_isinstance(dataloader, DataLoader)
-        if dataloader.batch_sampler is None:
-            raise ValueError("Dataloader must have a batch sampler.")
-        self.dataloader: DataLoader = dataloader
-
-    def get_current_batch_size(self) -> int:
-        """ Returns the current batch size used by the dataloader as an :class:`int`. The batch size member variable is
-        owned by the batch sampler.
-        """
-        return self.dataloader.batch_sampler.batch_size
-
-    def set_batch_size(self, new_bs: int):
-        """ Sets the new value of the batch size, which is owned by the batch sampler.
-
-        Args:
-            new_bs (int): The new batch sizes that needs to be set.
-        """
-        self.dataloader.batch_sampler.batch_size = new_bs
-
-
-class CustomBatchSizeManager(BatchSizeManager):
-    """ Custom batch size manager, used when the dataloader does not use a batch sampler. In this case, the batch size
-    is controlled by the dataset wrapped by the dataloader, so this class expects the dataset to provide a getter and
-    a setter for the batch size, named :meth:`get_batch_size` and :meth:`change_batch_size` respectively.
-    """
-
-    def __init__(self, dataset: Dataset):
-        check_isinstance(dataset, Dataset)
-        if not hasattr(dataset, 'change_batch_size'):
-            raise KeyError("Because the dataloader does not have a batch sampler, the dataset owns and controls the "
-                           "batch size. In order to change the batch size after dataloader creation we require our "
-                           "users to implement a Callable[[int],None] method named `change_batch_size` in their "
-                           "dataset which changes the batch size. Please see TODO for examples.")
-        if not hasattr(dataset, 'get_batch_size'):
-            raise KeyError("We require our users to implement a Callable[[], int] method named `get_batch_size` in "
-                           "their dataset which returns the current batch size. Please see TODO for examples. ")
-        self.dataset = dataset
-
-    def get_current_batch_size(self) -> int:
-        """ Returns the current batch size used by the dataset as an :class:`int`.
-
-        In this case, the dataset controls the batch size, so we require our users to implement a
-        :class:`Callable[[], int]` method named :meth:`get_batch_size` in their dataset which returns the current value
-        of the batch size.
-        """
-        return self.dataset.get_batch_size()
-
-    def set_batch_size(self, new_bs: int):
-        """ Sets the new value of the batch size.
-
-        In this case, the dataset controls the batch size, so we require our users to implement a
-        :class:`Callable[[int],None]` method named :meth:`change_batch_size` in their dataset which modifies the batch
-        size to the given value.
-
-        Args:
-            new_bs (int): The new batch sizes that needs to be set.
-        """
-        self.dataset.change_batch_size(new_bs)
-
-
 class BSScheduler:
     def __init__(self, dataloader: DataLoader, batch_size_manager: Union[BatchSizeManager, None],
                  max_batch_size: Union[int, None], min_batch_size: int, verbose: bool):
@@ -261,7 +172,7 @@ def step(self, **kwargs):
         new_bs = self._internal_get_new_bs(**kwargs)
         if not self.min_batch_size <= new_bs <= self.max_batch_size:
             self._finished = True
-            new_bs = clip(new_bs, min=self.min_batch_size, max=self.max_batch_size)
+            new_bs = clip(new_bs, min_x=self.min_batch_size, max_x=self.max_batch_size)
         if new_bs != self.batch_size:
             self.set_batch_size(new_bs)
             self.print_bs(new_bs)
@@ -943,7 +854,7 @@ def get_new_bs(self) -> int:
             self._float_batch_size - self.max_batch_size) + self.max_batch_size
 
         self._float_batch_size = new_bs
-        return clip(rint(new_bs), min=self.base_batch_size, max=self.max_batch_size)
+        return clip(rint(new_bs), min_x=self.base_batch_size, max_x=self.max_batch_size)
 
 
 class ChainedBSScheduler(BSScheduler):
@@ -1436,7 +1347,7 @@ def get_new_bs(self) -> int:
         new_bs = self.base_batch_size + (self.max_batch_size - self.base_batch_size) * (
                 1 + math.cos(math.pi + math.pi * self.t_cur / self.t_i)) / 2
 
-        return clip(rint(new_bs), min=self.base_batch_size, max=self.max_batch_size)
+        return clip(rint(new_bs), min_x=self.base_batch_size, max_x=self.max_batch_size)
 
 
 class OneCycleBS(BSScheduler):
@@ -1540,4 +1451,4 @@ def get_new_bs(self) -> int:
         if percentage == 1.0:
             self._finished = True
 
-        return clip(rint(new_bs), min=self.min_batch_size, max=self.max_batch_size)
+        return clip(rint(new_bs), min_x=self.min_batch_size, max_x=self.max_batch_size)
diff --git a/bs_scheduler/utils.py b/bs_scheduler/utils.py
new file mode 100644
index 0000000..9df4dfb
--- /dev/null
+++ b/bs_scheduler/utils.py
@@ -0,0 +1,3 @@
+def check_isinstance(x, instance: type):
+    if not isinstance(x, instance):
+        raise TypeError(f"{type(x).__name__} is not a {instance.__name__}.")

From a186a36ffa215c92ca9b49abcbe0174b6697dd2d Mon Sep 17 00:00:00 2001
From: ancestor-mithril
Date: Thu, 5 Dec 2024 21:45:46 +0200
Subject: [PATCH 2/2] Moved utils into separate file

---
 bs_scheduler/batch_size_schedulers.py | 18 +-----------------
 bs_scheduler/utils.py                 | 16 ++++++++++++++++
 2 files changed, 17 insertions(+), 17 deletions(-)

diff --git a/bs_scheduler/batch_size_schedulers.py b/bs_scheduler/batch_size_schedulers.py
index aaaa922..30eb4f1 100644
--- a/bs_scheduler/batch_size_schedulers.py
+++ b/bs_scheduler/batch_size_schedulers.py
@@ -15,23 +15,7 @@
            'CosineAnnealingBSWithWarmRestarts', 'OneCycleBS', 'BSScheduler']
 
 from .batch_size_manager import BatchSizeManager, DefaultBatchSizeManager, CustomBatchSizeManager
-from .utils import check_isinstance
-
-
-def rint(x: float) -> int:
-    """ Rounds to the nearest int and returns the value as int.
-    """
-    return int(round(x))
-
-
-def clip(x: int, min_x: int, max_x: int) -> int:
-    """ Clips x to [min_x, max_x] interval.
-    """
-    if x < min_x:
-        return min_x
-    if x > max_x:
-        return max_x
-    return x
+from .utils import check_isinstance, clip, rint
 
 
 class BSScheduler:
diff --git a/bs_scheduler/utils.py b/bs_scheduler/utils.py
index 9df4dfb..ccdf21d 100644
--- a/bs_scheduler/utils.py
+++ b/bs_scheduler/utils.py
@@ -1,3 +1,19 @@
 def check_isinstance(x, instance: type):
     if not isinstance(x, instance):
         raise TypeError(f"{type(x).__name__} is not a {instance.__name__}.")
+
+
+def rint(x: float) -> int:
+    """ Rounds to the nearest int and returns the value as int.
+    """
+    return int(round(x))
+
+
+def clip(x: int, min_x: int, max_x: int) -> int:
+    """ Clips x to [min_x, max_x] interval.
+    """
+    if x < min_x:
+        return min_x
+    if x > max_x:
+        return max_x
+    return x
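
Usage note (not part of the patches above; a minimal sketch using only names this series exports):
DefaultBatchSizeManager forwards to the dataloader's batch sampler, whose batch_size attribute is read again at the
start of each new epoch, so setting it between epochs changes the size of all subsequent batches.

    import torch
    from torch.utils.data import DataLoader, TensorDataset

    from bs_scheduler import DefaultBatchSizeManager

    dataset = TensorDataset(torch.arange(100, dtype=torch.float32))
    dataloader = DataLoader(dataset, batch_size=10)  # automatic batching, so a batch sampler exists

    manager = DefaultBatchSizeManager(dataloader)
    assert manager.get_current_batch_size() == 10

    manager.set_batch_size(25)  # mutates dataloader.batch_sampler.batch_size
    assert len(list(dataloader)) == 4  # the next epoch yields 4 batches of 25 samples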
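
For the CustomBatchSizeManager path, the dataset itself must expose get_batch_size and change_batch_size. Below is a
hypothetical dataset satisfying that protocol; the BatchedDataset name and the StepBS keyword arguments are
assumptions for illustration, not something this series defines.

    import torch
    from torch.utils.data import DataLoader, Dataset

    from bs_scheduler import CustomBatchSizeManager, StepBS


    class BatchedDataset(Dataset):
        """ Hypothetical dataset that owns its batch size and returns whole batches. """

        def __init__(self, data: torch.Tensor, batch_size: int):
            self.data = data
            self.batch_size = batch_size

        def __len__(self) -> int:
            # Number of batches under the current batch size, rounded up.
            return (len(self.data) + self.batch_size - 1) // self.batch_size

        def __getitem__(self, i: int) -> torch.Tensor:
            return self.data[i * self.batch_size:(i + 1) * self.batch_size]

        # The two methods CustomBatchSizeManager looks up by name:
        def get_batch_size(self) -> int:
            return self.batch_size

        def change_batch_size(self, new_bs: int) -> None:
            self.batch_size = new_bs


    dataset = BatchedDataset(torch.randn(1024, 8), batch_size=16)
    dataloader = DataLoader(dataset, batch_size=None)  # no batch sampler: the dataset batches itself
    manager = CustomBatchSizeManager(dataset)

    # Assuming StepBS forwards batch_size_manager to the BSScheduler base class shown in this series:
    scheduler = StepBS(dataloader, step_size=2, gamma=2.0, batch_size_manager=manager)

    for epoch in range(4):
        for batch in dataloader:
            pass  # training step goes here
        scheduler.step()  # with step_size=2 and gamma=2.0, the batch size doubles every 2 epochs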