diff --git a/CHANGELOG.md b/CHANGELOG.md
index 697ee6d835b5e..6b7d04fcfd5d4 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -41,6 +41,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 - Fixed Horovod distributed backend compatibility with native AMP ([#3404](https://github.com/PyTorchLightning/pytorch-lightning/pull/3404))
 
+- Fixed batch size auto scaling exceeding the size of the dataset ([#3271](https://github.com/PyTorchLightning/pytorch-lightning/pull/3271))
+
 ## [0.9.0] - YYYY-MM-DD
 
 ### Added
diff --git a/pytorch_lightning/trainer/training_tricks.py b/pytorch_lightning/trainer/training_tricks.py
index 3c1171f91140a..fc7b4d9540d28 100644
--- a/pytorch_lightning/trainer/training_tricks.py
+++ b/pytorch_lightning/trainer/training_tricks.py
@@ -12,10 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-import math
-import os
 from abc import ABC, abstractmethod
-from typing import Optional
 
 import torch
 from torch import Tensor
@@ -23,11 +20,6 @@
 from pytorch_lightning import _logger as log
 from pytorch_lightning.callbacks import GradientAccumulationScheduler
 from pytorch_lightning.core.lightning import LightningModule
-from pytorch_lightning.loggers.base import DummyLogger
-from pytorch_lightning.utilities import AMPType, rank_zero_warn
-from pytorch_lightning.utilities.exceptions import MisconfigurationException
-from pytorch_lightning.utilities.memory import is_oom_error, garbage_collection_cuda
-from pytorch_lightning.utilities.parsing import lightning_hasattr, lightning_getattr, lightning_setattr
 
 try:
     from apex import amp
diff --git a/pytorch_lightning/tuner/batch_size_scaling.py b/pytorch_lightning/tuner/batch_size_scaling.py
index 28b9798859abf..06ccf77bb4043 100644
--- a/pytorch_lightning/tuner/batch_size_scaling.py
+++ b/pytorch_lightning/tuner/batch_size_scaling.py
@@ -13,13 +13,14 @@
 # limitations under the License
 import os
 from pytorch_lightning.core.lightning import LightningModule
+from pytorch_lightning.utilities.data import has_len
 from pytorch_lightning.utilities.parsing import lightning_hasattr, lightning_getattr, lightning_setattr
 from pytorch_lightning.utilities import rank_zero_warn
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from pytorch_lightning.utilities.memory import is_oom_error, garbage_collection_cuda
 from pytorch_lightning.loggers.base import DummyLogger
 from pytorch_lightning import _logger as log
-from typing import Optional
+from typing import Optional, Tuple
 
 
 def scale_batch_size(trainer,
@@ -55,6 +56,13 @@ def scale_batch_size(trainer,
             algorithm is terminated
 
         batch_arg_name: name of the attribute that stores the batch size.
+            It is expected that the user has provided a model or datamodule that has a hyperparameter
+            with that name. We will look for this attribute name in the following places
+
+            - `model`
+            - `model.hparams`
+            - `model.datamodule`
+            - `trainer.datamodule` (the datamodule passed to the tune method)
 
         **fit_kwargs: remaining arguments to be passed to .fit(), e.g., dataloader
            or datamodule.
@@ -165,16 +173,19 @@ def _run_power_scaling(trainer, model, new_size, batch_arg_name, max_trials, **f
             # Try fit
             trainer.fit(model, **fit_kwargs)
             # Double in size
-            new_size = _adjust_batch_size(trainer, batch_arg_name, factor=2.0, desc='succeeded')
+            new_size, changed = _adjust_batch_size(trainer, batch_arg_name, factor=2.0, desc='succeeded')
         except RuntimeError as exception:
             # Only these errors should trigger an adjustment
             if is_oom_error(exception):
                 # If we fail in power mode, half the size and return
                 garbage_collection_cuda()
-                new_size = _adjust_batch_size(trainer, batch_arg_name, factor=0.5, desc='failed')
+                new_size, _ = _adjust_batch_size(trainer, batch_arg_name, factor=0.5, desc='failed')
                 break
             else:
                 raise  # some other error not memory related
+
+        if not changed:
+            break
     return new_size
 
 
@@ -199,9 +210,13 @@ def _run_binsearch_scaling(trainer, model, new_size, batch_arg_name, max_trials,
                 if high - low <= 1:
                     break
                 midval = (high + low) // 2
-                new_size = _adjust_batch_size(trainer, batch_arg_name, value=midval, desc='succeeded')
+                new_size, changed = _adjust_batch_size(trainer, batch_arg_name, value=midval, desc='succeeded')
             else:
-                new_size = _adjust_batch_size(trainer, batch_arg_name, factor=2.0, desc='succeeded')
+                new_size, changed = _adjust_batch_size(trainer, batch_arg_name, factor=2.0, desc='succeeded')
+
+            if not changed:
+                break
+
         except RuntimeError as exception:
             # Only these errors should trigger an adjustment
             if is_oom_error(exception):
@@ -209,11 +224,12 @@ def _run_binsearch_scaling(trainer, model, new_size, batch_arg_name, max_trials,
                 garbage_collection_cuda()
                 high = new_size
                 midval = (high + low) // 2
-                new_size = _adjust_batch_size(trainer, value=midval, desc='failed')
+                new_size, _ = _adjust_batch_size(trainer, value=midval, desc='failed')
                 if high - low <= 1:
                     break
             else:
                 raise  # some other error not memory related
+
     return new_size
 
 
@@ -221,17 +237,13 @@ def _adjust_batch_size(trainer,
                        batch_arg_name: str = 'batch_size',
                        factor: float = 1.0,
                        value: Optional[int] = None,
-                       desc: str = None):
-    """ Function for adjusting the batch size. It is expected that the user
-        has provided a model that has a hparam field called `batch_size` i.e.
-        `model.hparams.batch_size` should exist. Additionally there can be a
-        datamodule attached to either Trainer or model, in that case the attribute
-        also gets updated when present.
+                       desc: str = None) -> Tuple[int, bool]:
+    """ Helper function for adjusting the batch size.
 
     Args:
         trainer: instance of pytorch_lightning.Trainer
 
-        batch_arg_name: field where batch_size is stored in `model.hparams`
+        batch_arg_name: name of the field where batch_size is stored.
 
         factor: value which the old batch size is multiplied by to get the
             new batch size
@@ -241,11 +253,23 @@ def _adjust_batch_size(trainer,
         desc: either `succeeded` or `failed`. Used purely for logging
 
+    Returns:
+        The new batch size for the next trial and a bool that signals whether the
+        new value is different than the previous batch size.
""" model = trainer.get_model() batch_size = lightning_getattr(model, batch_arg_name) new_size = value if value is not None else int(batch_size * factor) if desc: log.info(f'Batch size {batch_size} {desc}, trying batch size {new_size}') + + if not _is_valid_batch_size(new_size, trainer.train_dataloader): + new_size = min(new_size, len(trainer.train_dataloader.dataset)) + + changed = new_size != batch_size lightning_setattr(model, batch_arg_name, new_size) - return new_size + return new_size, changed + + +def _is_valid_batch_size(current_size, dataloader): + return not has_len(dataloader) or current_size <= len(dataloader) diff --git a/tests/trainer/test_trainer_tricks.py b/tests/trainer/test_trainer_tricks.py index 50591acf34736..95b7fd2067d32 100755 --- a/tests/trainer/test_trainer_tricks.py +++ b/tests/trainer/test_trainer_tricks.py @@ -238,9 +238,10 @@ def dataloader(self, *args, **kwargs): trainer = Trainer(default_root_dir=tmpdir, max_epochs=1, auto_scale_batch_size=True) trainer.tune(model, datamodule_fit) - assert trainer.datamodule == datamodule_fit after_batch_size = model.hparams.batch_size if use_hparams else model.batch_size + assert trainer.datamodule == datamodule_fit assert before_batch_size != after_batch_size + assert after_batch_size <= len(trainer.train_dataloader.dataset) assert datamodule_fit.batch_size == after_batch_size # should be left unchanged, since it was not passed to .tune() assert datamodule_model.batch_size == 111