Moved Batch Size Managers into a separate file #17

Merged: 2 commits, Dec 5, 2024
10 changes: 4 additions & 6 deletions bs_scheduler/__init__.py
@@ -1,12 +1,10 @@
+from .batch_size_manager import BatchSizeManager, DefaultBatchSizeManager, CustomBatchSizeManager
from .batch_size_schedulers import LambdaBS, MultiplicativeBS, StepBS, MultiStepBS, ConstantBS, LinearBS, ExponentialBS, \
    SequentialBS, PolynomialBS, CosineAnnealingBS, ChainedBSScheduler, IncreaseBSOnPlateau, CyclicBS, \
-    CosineAnnealingBSWithWarmRestarts, OneCycleBS, BSScheduler, BatchSizeManager
+    CosineAnnealingBSWithWarmRestarts, OneCycleBS, BSScheduler

-# We do not export DefaultBatchSizeManager and CustomBatchSizeManager because they are not needed. Users with custom
-# setups can create their own batch size managers.

__all__ = ['LambdaBS', 'MultiplicativeBS', 'StepBS', 'MultiStepBS', 'ConstantBS', 'LinearBS', 'ExponentialBS',
           'SequentialBS', 'PolynomialBS', 'CosineAnnealingBS', 'ChainedBSScheduler', 'IncreaseBSOnPlateau', 'CyclicBS',
-           'CosineAnnealingBSWithWarmRestarts', 'OneCycleBS', 'BSScheduler', 'BatchSizeManager']
+           'CosineAnnealingBSWithWarmRestarts', 'OneCycleBS', 'BSScheduler', 'BatchSizeManager',
+           'DefaultBatchSizeManager', 'CustomBatchSizeManager']

-del batch_size_schedulers  # noqa: F821
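
With the managers now exported, user code can import everything from the package root. A minimal sketch of the intended usage (the StepBS arguments and the training loop are illustrative assumptions, not taken from this diff):

import torch
from torch.utils.data import DataLoader, TensorDataset

from bs_scheduler import StepBS

dataset = TensorDataset(torch.arange(128, dtype=torch.float32))
dataloader = DataLoader(dataset, batch_size=16)
# Assumed scheduler signature: grows the batch size every 2 epochs.
scheduler = StepBS(dataloader, step_size=2)

for epoch in range(6):
    for (batch,) in dataloader:
        pass  # the forward/backward pass would go here
    scheduler.step()  # delegates the batch size update to a batch size manager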
90 changes: 90 additions & 0 deletions bs_scheduler/batch_size_manager.py
@@ -0,0 +1,90 @@
from torch.utils.data import DataLoader, Dataset

from .utils import check_isinstance


class BatchSizeManager:
""" Base class for all batch size managers, used for getting and setting the batch size. It is not mandatory to
inherit from this, but users must implement :meth:`get_current_batch_size` and :meth:`set_batch_size`.
"""

def get_current_batch_size(self) -> int:
""" Returns the current batch size used by the dataloader as an :class:`int`.
"""
raise NotImplementedError

def set_batch_size(self, new_bs: int):
""" Sets the new value of the batch size.

Args:
            new_bs (int): The new batch size that needs to be set.
"""
raise NotImplementedError


class DefaultBatchSizeManager(BatchSizeManager):
""" The default batch size manager used when the dataloader has a batch sampler. The batch sampler controls the
batch size used by the dataloader, and it can be queried and changed. Changes are reflected in the number of samples
given to the dataloader. See
https://github.com/pytorch/pytorch/blob/772e104dfdfd70c74cbc9600cfc946dc7c378f68/torch/utils/data/sampler.py#L241.
"""

def __init__(self, dataloader: DataLoader):
check_isinstance(dataloader, DataLoader)
if dataloader.batch_sampler is None:
raise ValueError("Dataloader must have a batch sampler.")
self.dataloader: DataLoader = dataloader

def get_current_batch_size(self) -> int:
""" Returns the current batch size used by the dataloader as an :class:`int`. The batch size member variable is
owned by the batch sampler.
"""
return self.dataloader.batch_sampler.batch_size

def set_batch_size(self, new_bs: int):
""" Sets the new value of the batch size, which is owned by the batch sampler.

Args:
            new_bs (int): The new batch size that needs to be set.
"""
self.dataloader.batch_sampler.batch_size = new_bs


class CustomBatchSizeManager(BatchSizeManager):
""" Custom batch size manager, used when the dataloader does not use a batch sampler. In this case, the batch size
is controlled by the dataset wrapped by the dataloader, so this class expects the dataset to provide a getter and
a setter for the batch size, named :meth:`get_batch_size` and :meth:`change_batch_size` respectively.
"""

def __init__(self, dataset: Dataset):
check_isinstance(dataset, Dataset)
if not hasattr(dataset, 'change_batch_size'):
raise KeyError("Because the dataloader does not have a batch sampler, the dataset owns and controls the "
"batch size. In order to change the batch size after dataloader creation we require our "
"users to implement a Callable[[int],None] method named `change_batch_size` in their "
"dataset which changes the batch size. Please see TODO for examples.")
if not hasattr(dataset, 'get_batch_size'):
raise KeyError("We require our users to implement a Callable[[], int] method named `get_batch_size` in "
"their dataset which returns the current batch size. Please see TODO for examples. ")
self.dataset = dataset

def get_current_batch_size(self) -> int:
""" Returns the current batch size used by the dataset as an :class:`int`.

In this case, the dataset controls the batch size, so we require our users to implement a
:class:`Callable[[], int]` method named :meth:`get_batch_size` in their dataset which returns the current value
of the batch size.
"""
return self.dataset.get_batch_size()

def set_batch_size(self, new_bs: int):
""" Sets the new value of the batch size.

In this case, the dataset controls the batch size, so we require our users to implement a
:class:`Callable[[int],None]` method named :meth:`change_batch_size` in their dataset which modifies the batch
size to the given value.

Args:
            new_bs (int): The new batch size that needs to be set.
"""
self.dataset.change_batch_size(new_bs)
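
To make the contract above concrete, a dataset compatible with CustomBatchSizeManager could look like the sketch below. Only the get_batch_size and change_batch_size method names are required by the code above; the rest of the class is a hypothetical example:

import torch
from torch.utils.data import Dataset

from bs_scheduler import CustomBatchSizeManager


class BatchedDataset(Dataset):
    """ Illustrative dataset that returns whole batches, so it owns the batch size. """

    def __init__(self, data: torch.Tensor, batch_size: int):
        self.data = data
        self.batch_size = batch_size

    def __len__(self) -> int:
        # Number of batches, counting a final partial batch.
        return (len(self.data) + self.batch_size - 1) // self.batch_size

    def __getitem__(self, i: int) -> torch.Tensor:
        return self.data[i * self.batch_size:(i + 1) * self.batch_size]

    def get_batch_size(self) -> int:
        # Getter required by CustomBatchSizeManager.
        return self.batch_size

    def change_batch_size(self, new_bs: int):
        # Setter required by CustomBatchSizeManager.
        self.batch_size = new_bs


manager = CustomBatchSizeManager(BatchedDataset(torch.arange(100), batch_size=10))
manager.set_batch_size(20)  # delegates to the dataset's change_batch_size
assert manager.get_current_batch_size() == 20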
121 changes: 8 additions & 113 deletions bs_scheduler/batch_size_schedulers.py
@@ -8,119 +8,14 @@
from typing import Callable, Union, Sequence, Tuple

import torch
-from torch.utils.data import DataLoader, Dataset
+from torch.utils.data import DataLoader

__all__ = ['LambdaBS', 'MultiplicativeBS', 'StepBS', 'MultiStepBS', 'ConstantBS', 'LinearBS', 'ExponentialBS',
'SequentialBS', 'PolynomialBS', 'CosineAnnealingBS', 'ChainedBSScheduler', 'IncreaseBSOnPlateau', 'CyclicBS',
-           'CosineAnnealingBSWithWarmRestarts', 'OneCycleBS', 'BSScheduler', 'BatchSizeManager']
+           'CosineAnnealingBSWithWarmRestarts', 'OneCycleBS', 'BSScheduler']


-def rint(x: float) -> int:
-    """ Rounds to the nearest int and returns the value as int.
-    """
-    return int(round(x))
-
-
-def clip(x: int, min: int, max: int) -> int:
-    """ Clips x to [min, max] interval.
-    """
-    if x < min:
-        return min
-    if x > max:
-        return max
-    return x
-
-
-def check_isinstance(x, instance: type):
-    if not isinstance(x, instance):
-        raise TypeError(f"{type(x).__name__} is not a {instance.__name__}.")
-
-
-class BatchSizeManager:
-    """ Base class for all batch size managers, used for getting and setting the batch size. It is not mandatory to
-    inherit from this, but users must implement :meth:`get_current_batch_size` and :meth:`set_batch_size`.
-    """
-
-    def get_current_batch_size(self) -> int:
-        """ Returns the current batch size used by the dataloader as an :class:`int`.
-        """
-        raise NotImplementedError
-
-    def set_batch_size(self, new_bs: int):
-        """ Sets the new value of the batch size.
-
-        Args:
-            new_bs (int): The new batch size that needs to be set.
-        """
-        raise NotImplementedError
-
-
-class DefaultBatchSizeManager(BatchSizeManager):
-    """ The default batch size manager used when the dataloader has a batch sampler. The batch sampler controls the
-    batch size used by the dataloader, and it can be queried and changed. Changes are reflected in the number of samples
-    given to the dataloader. See
-    https://github.com/pytorch/pytorch/blob/772e104dfdfd70c74cbc9600cfc946dc7c378f68/torch/utils/data/sampler.py#L241.
-    """
-
-    def __init__(self, dataloader: DataLoader):
-        check_isinstance(dataloader, DataLoader)
-        if dataloader.batch_sampler is None:
-            raise ValueError("Dataloader must have a batch sampler.")
-        self.dataloader: DataLoader = dataloader
-
-    def get_current_batch_size(self) -> int:
-        """ Returns the current batch size used by the dataloader as an :class:`int`. The batch size member variable is
-        owned by the batch sampler.
-        """
-        return self.dataloader.batch_sampler.batch_size
-
-    def set_batch_size(self, new_bs: int):
-        """ Sets the new value of the batch size, which is owned by the batch sampler.
-
-        Args:
-            new_bs (int): The new batch size that needs to be set.
-        """
-        self.dataloader.batch_sampler.batch_size = new_bs
-
-
-class CustomBatchSizeManager(BatchSizeManager):
-    """ Custom batch size manager, used when the dataloader does not use a batch sampler. In this case, the batch size
-    is controlled by the dataset wrapped by the dataloader, so this class expects the dataset to provide a getter and
-    a setter for the batch size, named :meth:`get_batch_size` and :meth:`change_batch_size` respectively.
-    """
-
-    def __init__(self, dataset: Dataset):
-        check_isinstance(dataset, Dataset)
-        if not hasattr(dataset, 'change_batch_size'):
-            raise KeyError("Because the dataloader does not have a batch sampler, the dataset owns and controls the "
-                           "batch size. In order to change the batch size after dataloader creation we require our "
-                           "users to implement a Callable[[int],None] method named `change_batch_size` in their "
-                           "dataset which changes the batch size. Please see TODO for examples.")
-        if not hasattr(dataset, 'get_batch_size'):
-            raise KeyError("We require our users to implement a Callable[[], int] method named `get_batch_size` in "
-                           "their dataset which returns the current batch size. Please see TODO for examples.")
-        self.dataset = dataset
-
-    def get_current_batch_size(self) -> int:
-        """ Returns the current batch size used by the dataset as an :class:`int`.
-
-        In this case, the dataset controls the batch size, so we require our users to implement a
-        :class:`Callable[[], int]` method named :meth:`get_batch_size` in their dataset which returns the current value
-        of the batch size.
-        """
-        return self.dataset.get_batch_size()
-
-    def set_batch_size(self, new_bs: int):
-        """ Sets the new value of the batch size.
-
-        In this case, the dataset controls the batch size, so we require our users to implement a
-        :class:`Callable[[int],None]` method named :meth:`change_batch_size` in their dataset which modifies the batch
-        size to the given value.
-
-        Args:
-            new_bs (int): The new batch size that needs to be set.
-        """
-        self.dataset.change_batch_size(new_bs)
+from .batch_size_manager import BatchSizeManager, DefaultBatchSizeManager, CustomBatchSizeManager
+from .utils import check_isinstance, clip, rint


class BSScheduler:
@@ -261,7 +156,7 @@ def step(self, **kwargs):
new_bs = self._internal_get_new_bs(**kwargs)
if not self.min_batch_size <= new_bs <= self.max_batch_size:
self._finished = True
-        new_bs = clip(new_bs, min=self.min_batch_size, max=self.max_batch_size)
+        new_bs = clip(new_bs, min_x=self.min_batch_size, max_x=self.max_batch_size)
if new_bs != self.batch_size:
self.set_batch_size(new_bs)
self.print_bs(new_bs)
@@ -943,7 +838,7 @@ def get_new_bs(self) -> int:
self._float_batch_size - self.max_batch_size) + self.max_batch_size

self._float_batch_size = new_bs
-        return clip(rint(new_bs), min=self.base_batch_size, max=self.max_batch_size)
+        return clip(rint(new_bs), min_x=self.base_batch_size, max_x=self.max_batch_size)


class ChainedBSScheduler(BSScheduler):
@@ -1436,7 +1331,7 @@ def get_new_bs(self) -> int:

new_bs = self.base_batch_size + (self.max_batch_size - self.base_batch_size) * (
1 + math.cos(math.pi + math.pi * self.t_cur / self.t_i)) / 2
-        return clip(rint(new_bs), min=self.base_batch_size, max=self.max_batch_size)
+        return clip(rint(new_bs), min_x=self.base_batch_size, max_x=self.max_batch_size)


class OneCycleBS(BSScheduler):
@@ -1540,4 +1435,4 @@ def get_new_bs(self) -> int:
if percentage == 1.0:
self._finished = True

-        return clip(rint(new_bs), min=self.min_batch_size, max=self.max_batch_size)
+        return clip(rint(new_bs), min_x=self.min_batch_size, max_x=self.max_batch_size)
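
Since the schedulers touch the batch size only through the BatchSizeManager interface, any storage scheme can be adapted with a small custom manager. A minimal sketch (the class and its dict-based storage are hypothetical; only the two overridden methods are required by the interface):

from bs_scheduler import BatchSizeManager


class StateBatchSizeManager(BatchSizeManager):
    """ Hypothetical manager that keeps the batch size in a mutable dict. """

    def __init__(self, state: dict):
        self.state = state  # e.g. {'batch_size': 32}

    def get_current_batch_size(self) -> int:
        return self.state['batch_size']

    def set_batch_size(self, new_bs: int):
        self.state['batch_size'] = new_bs


state = {'batch_size': 32}
manager = StateBatchSizeManager(state)
manager.set_batch_size(64)
assert state['batch_size'] == 64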
19 changes: 19 additions & 0 deletions bs_scheduler/utils.py
@@ -0,0 +1,19 @@
def check_isinstance(x, instance: type):
if not isinstance(x, instance):
raise TypeError(f"{type(x).__name__} is not a {instance.__name__}.")


def rint(x: float) -> int:
""" Rounds to the nearest int and returns the value as int.
"""
return int(round(x))


def clip(x: int, min_x: int, max_x: int) -> int:
    """ Clips x to the [min_x, max_x] interval.
    """
    if x < min_x:
        return min_x
    if x > max_x:
        return max_x
    return x
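
The min_x/max_x parameter names (used at the updated call sites above) also avoid shadowing Python's built-in min and max. A quick sanity check of both helpers:

from bs_scheduler.utils import clip, rint

assert clip(100, min_x=1, max_x=64) == 64  # clamped to the upper bound
assert clip(-5, min_x=1, max_x=64) == 1    # clamped to the lower bound
assert clip(32, min_x=1, max_x=64) == 32   # already inside the interval
assert rint(3.6) == 4                      # rounds to the nearest int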