From 393ee5a353728fde5bc91d379159eee93159b7d6 Mon Sep 17 00:00:00 2001 From: ancestor-mithril Date: Thu, 5 Dec 2024 22:23:47 +0200 Subject: [PATCH 1/2] Breaking change for IncreaseBSOnPlateau Changed it to receive 'metrics' kwargs instead of 'metric', to be identical to ReduceLROnPlateau --- bs_scheduler/batch_size_schedulers.py | 7 +++---- tests/test_IncreaseBSOnPlateau.py | 4 ++-- 2 files changed, 5 insertions(+), 6 deletions(-) diff --git a/bs_scheduler/batch_size_schedulers.py b/bs_scheduler/batch_size_schedulers.py index 0e55923..0306e08 100644 --- a/bs_scheduler/batch_size_schedulers.py +++ b/bs_scheduler/batch_size_schedulers.py @@ -952,8 +952,7 @@ class IncreaseBSOnPlateau(BSScheduler): Increases the batch size when a metric has stopped improving. Models often benefit from increasing the batch size by a factor once the learning stagnates. This scheduler receives a metric value and if no improvement is seen for a given number of epochs, the batch size is increased. - Unfortunately, this class is not compatible with the other batch size schedulers as its step() function needs to - receive the metric value. + The step() function needs to receive the metric value using the `metrics` keyword argument. Args: dataloader (DataLoader): Wrapped dataloader. @@ -1068,9 +1067,9 @@ def get_new_bs(self, **kwargs) -> int: if self.last_epoch == 0: # Don't do anything at initialization. return self.batch_size - metric = kwargs.pop('metric', None) + metric = kwargs.pop('metrics', None) if metric is None: - raise TypeError("IncreaseBSOnPlateau requires passing a 'metric' keyword argument in the step() function.") + raise TypeError("IncreaseBSOnPlateau requires passing a 'metrics' keyword argument in the step() function.") current = float(metric) if self.is_better(current, self.best, self.threshold): diff --git a/tests/test_IncreaseBSOnPlateau.py b/tests/test_IncreaseBSOnPlateau.py index 2ddcee5..fe475cb 100644 --- a/tests/test_IncreaseBSOnPlateau.py +++ b/tests/test_IncreaseBSOnPlateau.py @@ -16,7 +16,7 @@ def test_constant_metric(self): max_batch_size = 100 n_epochs = 100 - metrics = [{"metric": 0.1}] * n_epochs + metrics = [{"metrics": 0.1}] * n_epochs dataloader = create_dataloader(self.dataset, batch_size=base_batch_size) scheduler1 = IncreaseBSOnPlateau(dataloader, mode='min', threshold_mode='rel', max_batch_size=max_batch_size) @@ -69,7 +69,7 @@ def test_graphic(self): max_batch_size = 100 n_epochs = 100 - metrics = [{"metric": 0.1}] * n_epochs + metrics = [{"metrics": 0.1}] * n_epochs dataloader = create_dataloader(self.dataset, batch_size=base_batch_size) scheduler = IncreaseBSOnPlateau(dataloader, mode='min', threshold_mode='rel', max_batch_size=max_batch_size) From 0745c5ab750f2f326cbc942c30d56b25b0892902 Mon Sep 17 00:00:00 2001 From: ancestor-mithril Date: Thu, 5 Dec 2024 22:24:57 +0200 Subject: [PATCH 2/2] Updated tests --- tests/test_IncreaseBSOnPlateau.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_IncreaseBSOnPlateau.py b/tests/test_IncreaseBSOnPlateau.py index fe475cb..54c8dc4 100644 --- a/tests/test_IncreaseBSOnPlateau.py +++ b/tests/test_IncreaseBSOnPlateau.py @@ -55,7 +55,7 @@ def test_loading_and_unloading(self): self.reloading_scheduler(scheduler) self.torch_save_and_load(scheduler) - scheduler.step(metric=10) + scheduler.step(metrics=10) self.assertEqual(scheduler.mode, mode) self.assertEqual(scheduler.threshold_mode, threshold_mode)