Skip to content

Commit

Permalink
Changing initialization of max_batch_size
Browse files Browse the repository at this point in the history
Previous behavior: if None or greater than the dataset length, it is set to the dataset length
Current behavior: If None, it is set to the dataset length + 1, if the dataset has a length. Otherwise, it is set to 1
  • Loading branch information
ancestor-mithril committed Dec 6, 2024
1 parent 58028b4 commit c8aec06
Showing 1 changed file with 30 additions and 28 deletions.
58 changes: 30 additions & 28 deletions bs_scheduler/batch_size_schedulers.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,11 +35,13 @@ def __init__(self, dataloader: DataLoader, batch_size_manager: Optional[BatchSiz
assert max_batch_size is None or isinstance(max_batch_size, int)
assert isinstance(min_batch_size, int)
if max_batch_size is None:
max_batch_size = len(self.dataloader.dataset)
if dataloader is not None and hasattr(dataloader, 'dataset') and hasattr(dataloader.dataset, '__len__'):
max_batch_size = len(dataloader.dataset) + 1
else:
max_batch_size = 1
else:
if max_batch_size < 0:
raise ValueError(f"Maximum batch size must be greater than 0, but is {max_batch_size}.")
max_batch_size = min(len(self.dataloader.dataset), max_batch_size)
self.max_batch_size: int = max_batch_size

if min_batch_size < 0:
Expand Down Expand Up @@ -173,8 +175,8 @@ class LambdaBS(BSScheduler):
batch_size_manager (Optional[BatchSizeManager]): If not None, a custom class which manages the batch size,
which provides a getter and setter for the batch size. Default: None.
max_batch_size (Optional[int]): Upper limit for the batch size so that a batch of size max_batch_size fits
in the memory. If None or greater than the lenght of the dataset wrapped by the dataloader, max_batch_size
is set to `len(self.dataloader.dataset)`. Default: None.
in the memory. If None, max_batch_size is set to `len(self.dataloader.dataset) if available else 0 + 1`.
Default: None.
min_batch_size (int): Lower limit for the batch size which must be greater than 0. Default: 1.
verbose (bool): If ``True``, prints a message to stdout for each update. Default: ``False``.
Expand Down Expand Up @@ -237,8 +239,8 @@ class MultiplicativeBS(BSScheduler):
batch_size_manager (Optional[BatchSizeManager]): If not None, a custom class which manages the batch size,
which provides a getter and setter for the batch size. Default: None.
max_batch_size (Optional[int]): Upper limit for the batch size so that a batch of size max_batch_size fits
in the memory. If None or greater than the lenght of the dataset wrapped by the dataloader, max_batch_size
is set to `len(self.dataloader.dataset)`. Default: None.
in the memory. If None, max_batch_size is set to `len(self.dataloader.dataset) if available else 0 + 1`.
Default: None.
min_batch_size (int): Lower limit for the batch size which must be greater than 0. Default: 1.
verbose (bool): If ``True``, prints a message to stdout for each update. Default: ``False``.
Expand Down Expand Up @@ -299,8 +301,8 @@ class StepBS(BSScheduler):
batch_size_manager (Optional[BatchSizeManager]): If not None, a custom class which manages the batch size,
which provides a getter and setter for the batch size. Default: None.
max_batch_size (Optional[int]): Upper limit for the batch size so that a batch of size max_batch_size fits
in the memory. If None or greater than the lenght of the dataset wrapped by the dataloader, max_batch_size
is set to `len(self.dataloader.dataset)`. Default: None.
in the memory. If None, max_batch_size is set to `len(self.dataloader.dataset) if available else 0 + 1`.
Default: None.
min_batch_size (int): Lower limit for the batch size which must be greater than 0. Default: 1.
verbose (bool): If ``True``, prints a message to stdout for each update. Default: ``False``.
Expand Down Expand Up @@ -346,8 +348,8 @@ class MultiStepBS(BSScheduler):
batch_size_manager (Optional[BatchSizeManager]): If not None, a custom class which manages the batch size,
which provides a getter and setter for the batch size. Default: None.
max_batch_size (Optional[int]): Upper limit for the batch size so that a batch of size max_batch_size fits
in the memory. If None or greater than the lenght of the dataset wrapped by the dataloader, max_batch_size
is set to `len(self.dataloader.dataset)`. Default: None.
in the memory. If None, max_batch_size is set to `len(self.dataloader.dataset) if available else 0 + 1`.
Default: None.
min_batch_size (int): Lower limit for the batch size which must be greater than 0. Default: 1.
verbose (bool): If ``True``, prints a message to stdout for each update. Default: ``False``.
Expand Down Expand Up @@ -410,8 +412,8 @@ class ConstantBS(BSScheduler):
batch_size_manager (Optional[BatchSizeManager]): If not None, a custom class which manages the batch size,
which provides a getter and setter for the batch size. Default: None.
max_batch_size (Optional[int]): Upper limit for the batch size so that a batch of size max_batch_size fits
in the memory. If None or greater than the lenght of the dataset wrapped by the dataloader, max_batch_size
is set to `len(self.dataloader.dataset)`. Default: None.
in the memory. If None, max_batch_size is set to `len(self.dataloader.dataset) if available else 0 + 1`.
Default: None.
min_batch_size (int): Lower limit for the batch size which must be greater than 0. Default: 1.
verbose (bool): If ``True``, prints a message to stdout for each update. Default: ``False``.
Expand Down Expand Up @@ -475,8 +477,8 @@ class LinearBS(BSScheduler):
batch_size_manager (Optional[BatchSizeManager]): If not None, a custom class which manages the batch size,
which provides a getter and setter for the batch size. Default: None.
max_batch_size (Optional[int]): Upper limit for the batch size so that a batch of size max_batch_size fits
in the memory. If None or greater than the lenght of the dataset wrapped by the dataloader, max_batch_size
is set to `len(self.dataloader.dataset)`. Default: None.
in the memory. If None, max_batch_size is set to `len(self.dataloader.dataset) if available else 0 + 1`.
Default: None.
min_batch_size (int): Lower limit for the batch size which must be greater than 0. Default: 1.
verbose (bool): If ``True``, prints a message to stdout for each update. Default: ``False``.
Expand Down Expand Up @@ -534,8 +536,8 @@ class ExponentialBS(BSScheduler):
batch_size_manager (Optional[BatchSizeManager]): If not None, a custom class which manages the batch size,
which provides a getter and setter for the batch size. Default: None.
max_batch_size (Optional[int]): Upper limit for the batch size so that a batch of size max_batch_size fits
in the memory. If None or greater than the lenght of the dataset wrapped by the dataloader, max_batch_size
is set to `len(self.dataloader.dataset)`. Default: None.
in the memory. If None, max_batch_size is set to `len(self.dataloader.dataset) if available else 0 + 1`.
Default: None.
min_batch_size (int): Lower limit for the batch size which must be greater than 0. Default: 1.
verbose (bool): If ``True``, prints a message to stdout for each update. Default: ``False``.
Expand Down Expand Up @@ -720,8 +722,8 @@ class PolynomialBS(BSScheduler):
batch_size_manager (Optional[BatchSizeManager]): If not None, a custom class which manages the batch size,
which provides a getter and setter for the batch size. Default: None.
max_batch_size (Optional[int]): Upper limit for the batch size so that a batch of size max_batch_size fits
in the memory. If None or greater than the lenght of the dataset wrapped by the dataloader, max_batch_size
is set to `len(self.dataloader.dataset)`. Default: None.
in the memory. If None, max_batch_size is set to `len(self.dataloader.dataset) if available else 0 + 1`.
Default: None.
min_batch_size (int): Lower limit for the batch size which must be greater than 0. Default: 1.
verbose (bool): If ``True``, prints a message to stdout for each update. Default: ``False``.
Expand Down Expand Up @@ -777,8 +779,8 @@ class CosineAnnealingBS(BSScheduler):
batch_size_manager (Optional[BatchSizeManager]): If not None, a custom class which manages the batch size,
which provides a getter and setter for the batch size. Default: None.
max_batch_size (Optional[int]): Upper limit for the batch size so that a batch of size max_batch_size fits
in the memory. If None or greater than the lenght of the dataset wrapped by the dataloader, max_batch_size
is set to `len(self.dataloader.dataset)`. Default: None.
in the memory. If None, max_batch_size is set to `len(self.dataloader.dataset) if available else 0 + 1`.
Default: None.
min_batch_size (int): Lower limit for the batch size which must be greater than 0. Default: 1.
verbose (bool): If ``True``, prints a message to stdout for each update. Default: ``False``.
Expand Down Expand Up @@ -971,8 +973,8 @@ class IncreaseBSOnPlateau(BSScheduler):
batch_size_manager (Optional[BatchSizeManager]): If not None, a custom class which manages the batch size,
which provides a getter and setter for the batch size. Default: None.
max_batch_size (Optional[int]): Upper limit for the batch size so that a batch of size max_batch_size fits
in the memory. If None or greater than the lenght of the dataset wrapped by the dataloader, max_batch_size
is set to `len(self.dataloader.dataset)`. Default: None.
in the memory. If None, max_batch_size is set to `len(self.dataloader.dataset) if available else 0 + 1`.
Default: None.
min_batch_size (int): Lower limit for the batch size which must be greater than 0. Default: 1.
verbose (bool): If ``True``, prints a message to stdout for each update. Default: ``False``.
Expand Down Expand Up @@ -1138,8 +1140,8 @@ class CyclicBS(BSScheduler):
max_batch_size (Optional[int]): Upper batch size boundary in the cycle. Functionally, it defines the cycle
amplitude (upper_batch_size_bound - base_batch_size). The batch size at any cycle is the sum of
base_batch_size and some scaling of the amplitude; therefore, upper_batch_size_bound may not actually be
reached depending on scaling function. If None or greater than the lenght of the dataset wrapped by the
dataloader, max_batch_size is set to `len(self.dataloader.dataset)`. Default: None.
reached depending on scaling function. If None, max_batch_size is set to
`len(self.dataloader.dataset) if available else 0 + 1`. Default: None.
min_batch_size (int): Lower limit for the batch size which must be greater than 0. Default: 1.
verbose (bool): If ``True``, prints a message to stdout for each update. Default: ``False``.
Expand Down Expand Up @@ -1274,8 +1276,8 @@ class CosineAnnealingBSWithWarmRestarts(BSScheduler):
batch_size_manager (Optional[BatchSizeManager]): If not None, a custom class which manages the batch size,
which provides a getter and setter for the batch size. Default: None.
max_batch_size (Optional[int]): Upper limit for the batch size so that a batch of size max_batch_size fits
in the memory. If None or greater than the lenght of the dataset wrapped by the dataloader, max_batch_size
is set to `len(self.dataloader.dataset)`. Default: None.
in the memory. If None, max_batch_size is set to `len(self.dataloader.dataset) if available else 0 + 1`.
Default: None.
min_batch_size (int): Lower limit for the batch size which must be greater than 0. Default: 1.
verbose (bool): If ``True``, prints a message to stdout for each update. Default: ``False``.
Expand Down Expand Up @@ -1359,8 +1361,8 @@ class OneCycleBS(BSScheduler):
batch_size_manager (Optional[BatchSizeManager]): If not None, a custom class which manages the batch size,
which provides a getter and setter for the batch size. Default: None.
max_batch_size (Optional[int]): Upper limit for the batch size so that a batch of size max_batch_size fits
in the memory. If None or greater than the lenght of the dataset wrapped by the dataloader, max_batch_size
is set to `len(self.dataloader.dataset)`. Default: None.
in the memory. If None, max_batch_size is set to `len(self.dataloader.dataset) if available else 0 + 1`.
Default: None.
min_batch_size (int): Lower limit for the batch size which must be greater than 0. Default: 1.
verbose (bool): If ``True``, prints a message to stdout for each update. Default: ``False``.
Expand Down

0 comments on commit c8aec06

Please sign in to comment.