From 89792d26c5b45db49ce2e4ebd8a1714b18f74468 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Sun, 30 Aug 2020 09:33:45 +0200 Subject: [PATCH 01/15] fix datamodule hasattr --- pytorch_lightning/utilities/parsing.py | 14 ++++++++++---- tests/trainer/test_trainer_tricks.py | 4 ++++ 2 files changed, 14 insertions(+), 4 deletions(-) diff --git a/pytorch_lightning/utilities/parsing.py b/pytorch_lightning/utilities/parsing.py index 6f57e63e48fc9..6da272761c161 100644 --- a/pytorch_lightning/utilities/parsing.py +++ b/pytorch_lightning/utilities/parsing.py @@ -186,6 +186,8 @@ def lightning_hasattr(model, attribute): attr = attribute in model.hparams else: attr = hasattr(model.hparams, attribute) + elif hasattr(model.datamodule, attribute): + attr = True else: attr = False @@ -204,9 +206,11 @@ def lightning_getattr(model, attribute): attr = model.hparams[attribute] else: attr = getattr(model.hparams, attribute) + elif hasattr(model.datamodule, attribute): + attr = getattr(model.datamodule, attribute) else: - raise ValueError(f'{attribute} is not stored in the model namespace' - ' or the `hparams` namespace/dict.') + raise ValueError(f'{attribute} is neither stored in the model namespace' + ' nor the `hparams` namespace/dict, nor the datamodule.') return attr @@ -222,6 +226,8 @@ def lightning_setattr(model, attribute, value): model.hparams[attribute] = value else: setattr(model.hparams, attribute, value) + if hasattr(model.datamodule, attribute): + setattr(model.datamodule, attribute, value) else: - raise ValueError(f'{attribute} is not stored in the model namespace' - ' or the `hparams` namespace/dict.') + raise ValueError(f'{attribute} is neither stored in the model namespace' + ' nor the `hparams` namespace/dict, nor the datamodule.') diff --git a/tests/trainer/test_trainer_tricks.py b/tests/trainer/test_trainer_tricks.py index 66f5b720016cc..ca3ec88210b55 100755 --- a/tests/trainer/test_trainer_tricks.py +++ b/tests/trainer/test_trainer_tricks.py @@ -8,6 +8,7 @@ from pytorch_lightning.utilities import AMPType, NATIVE_AMP_AVALAIBLE from pytorch_lightning.utilities.exceptions import MisconfigurationException from tests.base import EvalModelTemplate +from tests.base.datamodules import MNISTDataModule def test_num_training_batches(tmpdir): @@ -230,11 +231,14 @@ def dataloader(self, *args, **kwargs): model_class = HparamsEvalModelTemplate if use_hparams else EvalModelTemplate model = model_class(**hparams) + model.datamodule = MNISTDataModule(data_dir=tmpdir, batch_size=before_batch_size) + model.datamodule.setup() # TODO: why do I have to call this myself? trainer = Trainer(default_root_dir=tmpdir, max_epochs=1, auto_scale_batch_size=True) trainer.fit(model) after_batch_size = model.hparams.batch_size if use_hparams else model.batch_size assert before_batch_size != after_batch_size + assert model.datamodule.batch_size == after_batch_size def test_auto_scale_batch_size_duplicate_attribute_warning(tmpdir): From 5efe8bf31cf4ebfcd16300c1624027c86d2cd35d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Sun, 30 Aug 2020 09:34:12 +0200 Subject: [PATCH 02/15] fix patch check --- pytorch_lightning/trainer/training_tricks.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_lightning/trainer/training_tricks.py b/pytorch_lightning/trainer/training_tricks.py index 32d0c59434c7a..3c1b5f753435a 100644 --- a/pytorch_lightning/trainer/training_tricks.py +++ b/pytorch_lightning/trainer/training_tricks.py @@ -169,7 +169,7 @@ def scale_batch_size(self, f' If this is not the intended behavior, please remove either one.' ) - if hasattr(model.train_dataloader, 'patch_loader_code'): + if hasattr(model.train_dataloader(), 'patch_loader_code'): raise MisconfigurationException('The batch scaling feature cannot be used with dataloaders' ' passed directly to `.fit()`. Please disable the feature or' ' incorporate the dataloader into the model.') From 1874ca1d661174acb5403bbd453bbbbaf62b0e7e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Sun, 30 Aug 2020 14:40:32 +0200 Subject: [PATCH 03/15] fix setattr --- pytorch_lightning/utilities/parsing.py | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/pytorch_lightning/utilities/parsing.py b/pytorch_lightning/utilities/parsing.py index 6da272761c161..699cf3a889234 100644 --- a/pytorch_lightning/utilities/parsing.py +++ b/pytorch_lightning/utilities/parsing.py @@ -216,18 +216,26 @@ def lightning_getattr(model, attribute): def lightning_setattr(model, attribute, value): """ Special setattr for lightning. Checks for attribute in model namespace - and the old hparams namespace/dict """ + and the old hparams namespace/dict. + Will also set the attribute on datamodule, if it exists. + """ + found = False # Check if attribute in model if hasattr(model, attribute): setattr(model, attribute, value) + found = True # Check if attribute in model.hparams, either namespace or dict elif hasattr(model, 'hparams'): if isinstance(model.hparams, dict): model.hparams[attribute] = value else: setattr(model.hparams, attribute, value) + found = True + # Check if attribute in datamodule if hasattr(model.datamodule, attribute): setattr(model.datamodule, attribute, value) - else: + found = True + + if not found: raise ValueError(f'{attribute} is neither stored in the model namespace' ' nor the `hparams` namespace/dict, nor the datamodule.') From 361e63b7f1284c24739d915935be60e0b74441f6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Sun, 30 Aug 2020 14:40:41 +0200 Subject: [PATCH 04/15] update docs --- pytorch_lightning/utilities/parsing.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/pytorch_lightning/utilities/parsing.py b/pytorch_lightning/utilities/parsing.py index 699cf3a889234..8b9cc0e0479b8 100644 --- a/pytorch_lightning/utilities/parsing.py +++ b/pytorch_lightning/utilities/parsing.py @@ -175,8 +175,8 @@ def __repr__(self): def lightning_hasattr(model, attribute): - """ Special hasattr for lightning. Checks for attribute in model namespace - and the old hparams namespace/dict """ + """ Special hasattr for lightning. Checks for attribute in model namespace, + the old hparams namespace/dict, and the datamodule. """ # Check if attribute in model if hasattr(model, attribute): attr = True @@ -195,8 +195,8 @@ def lightning_hasattr(model, attribute): def lightning_getattr(model, attribute): - """ Special getattr for lightning. Checks for attribute in model namespace - and the old hparams namespace/dict """ + """ Special getattr for lightning. Checks for attribute in model namespace, + the old hparams namespace/dict, and the datamodule. """ # Check if attribute in model if hasattr(model, attribute): attr = getattr(model, attribute) From 4fdaac157574d9fb7d087c344c0c55c00769f38f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Sun, 30 Aug 2020 14:40:51 +0200 Subject: [PATCH 05/15] revert patch fix --- pytorch_lightning/trainer/training_tricks.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_lightning/trainer/training_tricks.py b/pytorch_lightning/trainer/training_tricks.py index 3c1b5f753435a..32d0c59434c7a 100644 --- a/pytorch_lightning/trainer/training_tricks.py +++ b/pytorch_lightning/trainer/training_tricks.py @@ -169,7 +169,7 @@ def scale_batch_size(self, f' If this is not the intended behavior, please remove either one.' ) - if hasattr(model.train_dataloader(), 'patch_loader_code'): + if hasattr(model.train_dataloader, 'patch_loader_code'): raise MisconfigurationException('The batch scaling feature cannot be used with dataloaders' ' passed directly to `.fit()`. Please disable the feature or' ' incorporate the dataloader into the model.') From 5aa826fe9ea7d2432f84fdd59e073cbd78e2e58f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Sun, 30 Aug 2020 16:30:40 +0200 Subject: [PATCH 06/15] changelog --- CHANGELOG.md | 2 ++ 1 file changed, 2 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index d18fac2668279..d3a862c624709 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -30,6 +30,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Fixed `GpuUsageLogger` to work on different platforms ([#3008](https://github.com/PyTorchLightning/pytorch-lightning/pull/3008)) +- Fixed setting batch size in `LightningModule.datamodule` when using `auto_scale_batch_size` ([#3266](https://github.com/PyTorchLightning/pytorch-lightning/pull/3266)) + ## [0.9.0] - YYYY-MM-DD ### Added From 9e07eb27d907d7ebf8ae985f2c6e7d680bfe3a7c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Sun, 30 Aug 2020 16:50:20 +0200 Subject: [PATCH 07/15] fix datamodule passed in as fit arg --- pytorch_lightning/trainer/training_tricks.py | 16 ++++++---------- tests/trainer/test_trainer_tricks.py | 7 +++---- 2 files changed, 9 insertions(+), 14 deletions(-) diff --git a/pytorch_lightning/trainer/training_tricks.py b/pytorch_lightning/trainer/training_tricks.py index 32d0c59434c7a..0b8083ca4e699 100644 --- a/pytorch_lightning/trainer/training_tricks.py +++ b/pytorch_lightning/trainer/training_tricks.py @@ -277,16 +277,12 @@ def _adjust_batch_size(trainer, """ model = trainer.get_model() batch_size = lightning_getattr(model, batch_arg_name) - if value: - lightning_setattr(model, batch_arg_name, value) - new_size = value - if desc: - log.info(f'Batch size {batch_size} {desc}, trying batch size {new_size}') - else: - new_size = int(batch_size * factor) - if desc: - log.info(f'Batch size {batch_size} {desc}, trying batch size {new_size}') - lightning_setattr(model, batch_arg_name, new_size) + new_size = value if value is not None else int(batch_size * factor) + if desc: + log.info(f'Batch size {batch_size} {desc}, trying batch size {new_size}') + lightning_setattr(model, batch_arg_name, new_size) + if trainer.datamodule is not None and hasattr(trainer.datamodule, batch_arg_name): + setattr(trainer.datamodule, batch_arg_name, new_size) return new_size diff --git a/tests/trainer/test_trainer_tricks.py b/tests/trainer/test_trainer_tricks.py index ca3ec88210b55..ff75563423ea2 100755 --- a/tests/trainer/test_trainer_tricks.py +++ b/tests/trainer/test_trainer_tricks.py @@ -231,14 +231,13 @@ def dataloader(self, *args, **kwargs): model_class = HparamsEvalModelTemplate if use_hparams else EvalModelTemplate model = model_class(**hparams) - model.datamodule = MNISTDataModule(data_dir=tmpdir, batch_size=before_batch_size) - model.datamodule.setup() # TODO: why do I have to call this myself? + datamodule = MNISTDataModule(data_dir=tmpdir, batch_size=before_batch_size) trainer = Trainer(default_root_dir=tmpdir, max_epochs=1, auto_scale_batch_size=True) - trainer.fit(model) + trainer.fit(model, datamodule) after_batch_size = model.hparams.batch_size if use_hparams else model.batch_size assert before_batch_size != after_batch_size - assert model.datamodule.batch_size == after_batch_size + assert datamodule.batch_size == after_batch_size def test_auto_scale_batch_size_duplicate_attribute_warning(tmpdir): From 130b599e5d433b209e8d33524378ebf334329246 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Sun, 30 Aug 2020 17:01:37 +0200 Subject: [PATCH 08/15] docs --- pytorch_lightning/trainer/training_tricks.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/pytorch_lightning/trainer/training_tricks.py b/pytorch_lightning/trainer/training_tricks.py index 0b8083ca4e699..7c982241626dd 100644 --- a/pytorch_lightning/trainer/training_tricks.py +++ b/pytorch_lightning/trainer/training_tricks.py @@ -259,7 +259,9 @@ def _adjust_batch_size(trainer, desc: str = None): """ Function for adjusting the batch size. It is expected that the user has provided a model that has a hparam field called `batch_size` i.e. - `model.hparams.batch_size` should exist. + `model.hparams.batch_size` should exist. Additionally there can be a + datamodule attached to either Trainer or model, in which case the attribute + also gets updated when present. Args: trainer: instance of pytorch_lightning.Trainer From b9ea152fe0eb24e98f5b94303681125c7627bd29 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Wed, 2 Sep 2020 18:33:24 +0200 Subject: [PATCH 09/15] set datamodule batch size in lightning_setattr --- pytorch_lightning/trainer/training_tricks.py | 2 -- pytorch_lightning/utilities/parsing.py | 6 ++++++ 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/pytorch_lightning/trainer/training_tricks.py b/pytorch_lightning/trainer/training_tricks.py index 7c982241626dd..01d2de24c7331 100644 --- a/pytorch_lightning/trainer/training_tricks.py +++ b/pytorch_lightning/trainer/training_tricks.py @@ -283,8 +283,6 @@ def _adjust_batch_size(trainer, if desc: log.info(f'Batch size {batch_size} {desc}, trying batch size {new_size}') lightning_setattr(model, batch_arg_name, new_size) - if trainer.datamodule is not None and hasattr(trainer.datamodule, batch_arg_name): - setattr(trainer.datamodule, batch_arg_name, new_size) return new_size diff --git a/pytorch_lightning/utilities/parsing.py b/pytorch_lightning/utilities/parsing.py index 8b9cc0e0479b8..c5bb347d4b843 100644 --- a/pytorch_lightning/utilities/parsing.py +++ b/pytorch_lightning/utilities/parsing.py @@ -220,6 +220,8 @@ def lightning_setattr(model, attribute, value): Will also set the attribute on datamodule, if it exists. """ found = False + trainer = model.trainer + # Check if attribute in model if hasattr(model, attribute): setattr(model, attribute, value) @@ -236,6 +238,10 @@ def lightning_setattr(model, attribute, value): setattr(model.datamodule, attribute, value) found = True + if trainer is not None and trainer.datamodule is not None and hasattr(trainer.datamodule, attribute): + setattr(trainer.datamodule, attribute, value) + found = True + if not found: raise ValueError(f'{attribute} is neither stored in the model namespace' ' nor the `hparams` namespace/dict, nor the datamodule.') From 4183026d651fb25c35c9a4b25e1b3d5fcb936645 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Wed, 2 Sep 2020 18:42:30 +0200 Subject: [PATCH 10/15] fix merge --- tests/trainer/test_trainer_tricks.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/trainer/test_trainer_tricks.py b/tests/trainer/test_trainer_tricks.py index 2cd07c914e1d7..008ed7598d2d0 100755 --- a/tests/trainer/test_trainer_tricks.py +++ b/tests/trainer/test_trainer_tricks.py @@ -234,7 +234,7 @@ def dataloader(self, *args, **kwargs): datamodule = MNISTDataModule(data_dir=tmpdir, batch_size=before_batch_size) trainer = Trainer(default_root_dir=tmpdir, max_epochs=1, auto_scale_batch_size=True) - trainer.tune(model) + trainer.tune(model, datamodule) after_batch_size = model.hparams.batch_size if use_hparams else model.batch_size assert before_batch_size != after_batch_size assert datamodule.batch_size == after_batch_size From 224c3c62cc93803fa15e7fec65cf56bdd840e38b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Wed, 2 Sep 2020 19:04:08 +0200 Subject: [PATCH 11/15] check with has_attr --- pytorch_lightning/utilities/parsing.py | 23 ++++++++++++++--------- 1 file changed, 14 insertions(+), 9 deletions(-) diff --git a/pytorch_lightning/utilities/parsing.py b/pytorch_lightning/utilities/parsing.py index c5bb347d4b843..0a92428e44d76 100644 --- a/pytorch_lightning/utilities/parsing.py +++ b/pytorch_lightning/utilities/parsing.py @@ -177,6 +177,8 @@ def __repr__(self): def lightning_hasattr(model, attribute): """ Special hasattr for lightning. Checks for attribute in model namespace, the old hparams namespace/dict, and the datamodule. """ + trainer = model.trainer + # Check if attribute in model if hasattr(model, attribute): attr = True @@ -188,6 +190,8 @@ def lightning_hasattr(model, attribute): attr = hasattr(model.hparams, attribute) elif hasattr(model.datamodule, attribute): attr = True + elif trainer is not None and trainer.datamodule is not None and hasattr(trainer.datamodule, attribute): + attr = getattr(trainer.datamodule, attribute) else: attr = False @@ -197,6 +201,8 @@ def lightning_hasattr(model, attribute): def lightning_getattr(model, attribute): """ Special getattr for lightning. Checks for attribute in model namespace, the old hparams namespace/dict, and the datamodule. """ + trainer = model.trainer + # Check if attribute in model if hasattr(model, attribute): attr = getattr(model, attribute) @@ -208,6 +214,8 @@ def lightning_getattr(model, attribute): attr = getattr(model.hparams, attribute) elif hasattr(model.datamodule, attribute): attr = getattr(model.datamodule, attribute) + elif trainer is not None and trainer.datamodule is not None and hasattr(trainer.datamodule, attribute): + attr = getattr(trainer.datamodule, attribute) else: raise ValueError(f'{attribute} is neither stored in the model namespace' ' nor the `hparams` namespace/dict, nor the datamodule.') @@ -219,29 +227,26 @@ def lightning_setattr(model, attribute, value): and the old hparams namespace/dict. Will also set the attribute on datamodule, if it exists. """ - found = False + if not lightning_hasattr(model, attribute): + raise ValueError(f'{attribute} is neither stored in the model namespace' + ' nor the `hparams` namespace/dict, nor the datamodule.') + trainer = model.trainer # Check if attribute in model if hasattr(model, attribute): setattr(model, attribute, value) - found = True + # Check if attribute in model.hparams, either namespace or dict elif hasattr(model, 'hparams'): if isinstance(model.hparams, dict): model.hparams[attribute] = value else: setattr(model.hparams, attribute, value) - found = True + # Check if attribute in datamodule if hasattr(model.datamodule, attribute): setattr(model.datamodule, attribute, value) - found = True if trainer is not None and trainer.datamodule is not None and hasattr(trainer.datamodule, attribute): setattr(trainer.datamodule, attribute, value) - found = True - - if not found: - raise ValueError(f'{attribute} is neither stored in the model namespace' - ' nor the `hparams` namespace/dict, nor the datamodule.') From 4066b786c28c85c2d3d9d7477c130a3f1f9a5b4b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Thu, 3 Sep 2020 16:25:10 +0200 Subject: [PATCH 12/15] access datamodule via trainer --- pytorch_lightning/utilities/parsing.py | 12 ++++-------- tests/trainer/test_trainer_tricks.py | 12 +++++++++--- 2 files changed, 13 insertions(+), 11 deletions(-) diff --git a/pytorch_lightning/utilities/parsing.py b/pytorch_lightning/utilities/parsing.py index 0a92428e44d76..d8387064e08e7 100644 --- a/pytorch_lightning/utilities/parsing.py +++ b/pytorch_lightning/utilities/parsing.py @@ -188,8 +188,7 @@ def lightning_hasattr(model, attribute): attr = attribute in model.hparams else: attr = hasattr(model.hparams, attribute) - elif hasattr(model.datamodule, attribute): - attr = True + # Check if attribute in datamodule (datamodule gets registerd in Trainer) elif trainer is not None and trainer.datamodule is not None and hasattr(trainer.datamodule, attribute): attr = getattr(trainer.datamodule, attribute) else: @@ -212,8 +211,8 @@ def lightning_getattr(model, attribute): attr = model.hparams[attribute] else: attr = getattr(model.hparams, attribute) - elif hasattr(model.datamodule, attribute): - attr = getattr(model.datamodule, attribute) + + # Check if attribute in datamodule (datamodule gets registerd in Trainer) elif trainer is not None and trainer.datamodule is not None and hasattr(trainer.datamodule, attribute): attr = getattr(trainer.datamodule, attribute) else: @@ -244,9 +243,6 @@ def lightning_setattr(model, attribute, value): else: setattr(model.hparams, attribute, value) - # Check if attribute in datamodule - if hasattr(model.datamodule, attribute): - setattr(model.datamodule, attribute, value) - + # Check if attribute in datamodule (datamodule gets registerd in Trainer) if trainer is not None and trainer.datamodule is not None and hasattr(trainer.datamodule, attribute): setattr(trainer.datamodule, attribute, value) diff --git a/tests/trainer/test_trainer_tricks.py b/tests/trainer/test_trainer_tricks.py index 008ed7598d2d0..85121f5946c20 100755 --- a/tests/trainer/test_trainer_tricks.py +++ b/tests/trainer/test_trainer_tricks.py @@ -229,15 +229,21 @@ def dataloader(self, *args, **kwargs): del self.batch_size return dataloader + datamodule_model = MNISTDataModule(data_dir=tmpdir, batch_size=111) # this datamodule should get ignored! + datamodule_fit = MNISTDataModule(data_dir=tmpdir, batch_size=before_batch_size) + model_class = HparamsEvalModelTemplate if use_hparams else EvalModelTemplate model = model_class(**hparams) - datamodule = MNISTDataModule(data_dir=tmpdir, batch_size=before_batch_size) + model.datamodule = datamodule_model # unused when another module gets passed to .tune() / .fit() trainer = Trainer(default_root_dir=tmpdir, max_epochs=1, auto_scale_batch_size=True) - trainer.tune(model, datamodule) + trainer.tune(model, datamodule_fit) + assert trainer.datamodule == datamodule_fit after_batch_size = model.hparams.batch_size if use_hparams else model.batch_size assert before_batch_size != after_batch_size - assert datamodule.batch_size == after_batch_size + assert datamodule_fit.batch_size == after_batch_size + # should be left unchanged, since it was not passed to .tune() + assert datamodule_model.batch_size == 111 def test_auto_scale_batch_size_duplicate_attribute_warning(tmpdir): From 8a5d27b76eeb0b13d421e410fcd7cf20d82f9f59 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Thu, 3 Sep 2020 16:37:12 +0200 Subject: [PATCH 13/15] pass fit args down to tuner --- pytorch_lightning/trainer/trainer.py | 8 +++++++- pytorch_lightning/trainer/training_tricks.py | 15 ++++++++------- 2 files changed, 15 insertions(+), 8 deletions(-) diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index 50cd43967421e..20a9cec516dd3 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -953,7 +953,13 @@ def tune( if self.auto_scale_batch_size: if isinstance(self.auto_scale_batch_size, bool): self.auto_scale_batch_size = 'power' - self.scale_batch_size(model, mode=self.auto_scale_batch_size) + self.scale_batch_size( + model, + mode=self.auto_scale_batch_size, + train_dataloader=train_dataloader, + val_dataloaders=val_dataloaders, + datamodule=datamodule, + ) model.logger = self.logger # reset logger binding # Run learning rate finder: diff --git a/pytorch_lightning/trainer/training_tricks.py b/pytorch_lightning/trainer/training_tricks.py index 01d2de24c7331..0fcceff819b4e 100644 --- a/pytorch_lightning/trainer/training_tricks.py +++ b/pytorch_lightning/trainer/training_tricks.py @@ -134,7 +134,8 @@ def scale_batch_size(self, steps_per_trial: int = 3, init_val: int = 2, max_trials: int = 25, - batch_arg_name: str = 'batch_size'): + batch_arg_name: str = 'batch_size', + **fit_kwargs): r""" Will iteratively try to find the largest batch size for a given model that does not give an out of memory (OOM) error. @@ -190,9 +191,9 @@ def scale_batch_size(self, # Initially we just double in size until an OOM is encountered new_size = _adjust_batch_size(self, value=init_val) # initially set to init_val if mode == 'power': - new_size = _run_power_scaling(self, model, new_size, batch_arg_name, max_trials) + new_size = _run_power_scaling(self, model, new_size, batch_arg_name, max_trials, **fit_kwargs) elif mode == 'binsearch': - new_size = _run_binsearch_scaling(self, model, new_size, batch_arg_name, max_trials) + new_size = _run_binsearch_scaling(self, model, new_size, batch_arg_name, max_trials, **fit_kwargs) else: raise ValueError('mode in method `scale_batch_size` can only be `power` or `binsearch') @@ -286,7 +287,7 @@ def _adjust_batch_size(trainer, return new_size -def _run_power_scaling(trainer, model, new_size, batch_arg_name, max_trials): +def _run_power_scaling(trainer, model, new_size, batch_arg_name, max_trials, **fit_kwargs): """ Batch scaling mode where the size is doubled at each iteration until an OOM error is encountered. """ for _ in range(max_trials): @@ -294,7 +295,7 @@ def _run_power_scaling(trainer, model, new_size, batch_arg_name, max_trials): trainer.global_step = 0 # reset after each try try: # Try fit - trainer.fit(model) + trainer.fit(model, **fit_kwargs) # Double in size new_size = _adjust_batch_size(trainer, batch_arg_name, factor=2.0, desc='succeeded') except RuntimeError as exception: @@ -309,7 +310,7 @@ def _run_power_scaling(trainer, model, new_size, batch_arg_name, max_trials): return new_size -def _run_binsearch_scaling(trainer, model, new_size, batch_arg_name, max_trials): +def _run_binsearch_scaling(trainer, model, new_size, batch_arg_name, max_trials, **fit_kwargs): """ Batch scaling mode where the size is initially is doubled at each iteration until an OOM error is encountered. Hereafter, the batch size is further refined using a binary search """ @@ -320,7 +321,7 @@ def _run_binsearch_scaling(trainer, model, new_size, batch_arg_name, max_trials) trainer.global_step = 0 # reset after each try try: # Try fit - trainer.fit(model) + trainer.fit(model, **fit_kwargs) count += 1 if count > max_trials: break From ed6f751fcb4203971fb4184dffcaebdd2d5fec64 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Thu, 3 Sep 2020 16:39:41 +0200 Subject: [PATCH 14/15] docs --- pytorch_lightning/trainer/training_tricks.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/pytorch_lightning/trainer/training_tricks.py b/pytorch_lightning/trainer/training_tricks.py index 0fcceff819b4e..bee9cc7668985 100644 --- a/pytorch_lightning/trainer/training_tricks.py +++ b/pytorch_lightning/trainer/training_tricks.py @@ -159,6 +159,10 @@ def scale_batch_size(self, max_trials: max number of increase in batch size done before algorithm is terminated + batch_arg_name: name of the attribute that stores the batch size. + + **fit_kwargs: remaining arguments to be passed to .fit() when, e.g., dataloader + or datamodule. """ if not lightning_hasattr(model, batch_arg_name): raise MisconfigurationException( From e4df06b7c096255079d8802c5f3c52ee626a9246 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Thu, 3 Sep 2020 21:28:40 +0200 Subject: [PATCH 15/15] fix typos in docs Co-authored-by: Rohit Gupta --- pytorch_lightning/trainer/training_tricks.py | 4 ++-- pytorch_lightning/utilities/parsing.py | 6 +++--- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/pytorch_lightning/trainer/training_tricks.py b/pytorch_lightning/trainer/training_tricks.py index bee9cc7668985..705abc6343d49 100644 --- a/pytorch_lightning/trainer/training_tricks.py +++ b/pytorch_lightning/trainer/training_tricks.py @@ -161,7 +161,7 @@ def scale_batch_size(self, batch_arg_name: name of the attribute that stores the batch size. - **fit_kwargs: remaining arguments to be passed to .fit() when, e.g., dataloader + **fit_kwargs: remaining arguments to be passed to .fit(), e.g., dataloader or datamodule. """ if not lightning_hasattr(model, batch_arg_name): @@ -265,7 +265,7 @@ def _adjust_batch_size(trainer, """ Function for adjusting the batch size. It is expected that the user has provided a model that has a hparam field called `batch_size` i.e. `model.hparams.batch_size` should exist. Additionally there can be a - datamodule attached to either Trainer or model, in which case the attribute + datamodule attached to either Trainer or model, in that case the attribute also gets updated when present. Args: diff --git a/pytorch_lightning/utilities/parsing.py b/pytorch_lightning/utilities/parsing.py index d8387064e08e7..dab1127579b87 100644 --- a/pytorch_lightning/utilities/parsing.py +++ b/pytorch_lightning/utilities/parsing.py @@ -188,7 +188,7 @@ def lightning_hasattr(model, attribute): attr = attribute in model.hparams else: attr = hasattr(model.hparams, attribute) - # Check if attribute in datamodule (datamodule gets registerd in Trainer) + # Check if the attribute in datamodule (datamodule gets registered in Trainer) elif trainer is not None and trainer.datamodule is not None and hasattr(trainer.datamodule, attribute): attr = getattr(trainer.datamodule, attribute) else: @@ -212,7 +212,7 @@ def lightning_getattr(model, attribute): else: attr = getattr(model.hparams, attribute) - # Check if attribute in datamodule (datamodule gets registerd in Trainer) + # Check if the attribute in datamodule (datamodule gets registered in Trainer) elif trainer is not None and trainer.datamodule is not None and hasattr(trainer.datamodule, attribute): attr = getattr(trainer.datamodule, attribute) else: @@ -243,6 +243,6 @@ def lightning_setattr(model, attribute, value): else: setattr(model.hparams, attribute, value) - # Check if attribute in datamodule (datamodule gets registerd in Trainer) + # Check if the attribute in datamodule (datamodule gets registered in Trainer) if trainer is not None and trainer.datamodule is not None and hasattr(trainer.datamodule, attribute): setattr(trainer.datamodule, attribute, value)