update batch size in DataModule when auto scaling batch size (#3266)
* fix datamodule hasattr

* fix patch check

* fix setattr

* update docs

* revert patch fix

* changelog

* fix datamodule passed in as fit arg

* docs

* set datamodule batch size in lightning_setattr

* fix merge

* check with has_attr

* access datamodule via trainer

* pass fit args down to tuner

* docs

* fix typos in docs

Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com>

awaelchli and rohitgr7 authored Sep 3, 2020
1 parent 4ad5a78 commit 48c22c8
Showing 5 changed files with 69 additions and 30 deletions.
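
For context, a minimal sketch of the behaviour this commit enables, written against the 0.9-era API shown in the diffs below. `ToyModel` and `ToyDataModule` are illustrative stand-ins, not part of this commit:

```python
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset

import pytorch_lightning as pl
from pytorch_lightning.core.datamodule import LightningDataModule


class ToyDataModule(LightningDataModule):
    """Hypothetical datamodule that owns its batch size."""

    def __init__(self, batch_size: int = 2):
        super().__init__()
        self.batch_size = batch_size

    def train_dataloader(self):
        dataset = TensorDataset(torch.randn(512, 32), torch.randint(0, 2, (512,)))
        return DataLoader(dataset, batch_size=self.batch_size)


class ToyModel(pl.LightningModule):
    """Hypothetical model; the batch size finder reads and writes `self.batch_size`."""

    def __init__(self):
        super().__init__()
        self.batch_size = 2
        self.layer = torch.nn.Linear(32, 2)

    def training_step(self, batch, batch_idx):
        x, y = batch
        return {'loss': F.cross_entropy(self.layer(x), y)}

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)


dm = ToyDataModule(batch_size=2)
model = ToyModel()
trainer = pl.Trainer(max_epochs=1, auto_scale_batch_size='power')

# The datamodule passed to tune() is now forwarded to the batch size finder,
# and the scaled value is written back to the datamodule as well.
trainer.tune(model, datamodule=dm)
assert dm.batch_size == model.batch_size
```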
2 changes: 2 additions & 0 deletions CHANGELOG.md
@@ -35,6 +35,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

- Fixed `GpuUsageLogger` to work on different platforms ([#3008](https://github.com/PyTorchLightning/pytorch-lightning/pull/3008))

- Fixed setting batch size in `LightningModule.datamodule` when using `auto_scale_batch_size` ([#3266](https://github.com/PyTorchLightning/pytorch-lightning/pull/3266))

## [0.9.0] - YYYY-MM-DD

### Added
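The changelog entry refers specifically to the case where the datamodule is attached to the model rather than passed to `.tune()`. A variant of the sketch above (reusing the illustrative `ToyModel`/`ToyDataModule`) that should also be covered by this fix:

```python
# Variant: attach the datamodule to the model instead of passing it to tune().
# The trainer picks it up during the trial fits, and with this commit the scaled
# batch size is written back to it as well.
dm = ToyDataModule(batch_size=2)
model = ToyModel()
model.datamodule = dm

trainer = pl.Trainer(max_epochs=1, auto_scale_batch_size=True)  # True maps to 'power'
trainer.tune(model)
assert dm.batch_size == model.batch_size
```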
8 changes: 7 additions & 1 deletion pytorch_lightning/trainer/trainer.py
@@ -951,7 +951,13 @@ def tune(
if self.auto_scale_batch_size:
if isinstance(self.auto_scale_batch_size, bool):
self.auto_scale_batch_size = 'power'
self.scale_batch_size(model, mode=self.auto_scale_batch_size)
self.scale_batch_size(
model,
mode=self.auto_scale_batch_size,
train_dataloader=train_dataloader,
val_dataloaders=val_dataloaders,
datamodule=datamodule,
)
model.logger = self.logger # reset logger binding

# Run learning rate finder:
37 changes: 19 additions & 18 deletions pytorch_lightning/trainer/training_tricks.py
@@ -134,7 +134,8 @@ def scale_batch_size(self,
steps_per_trial: int = 3,
init_val: int = 2,
max_trials: int = 25,
batch_arg_name: str = 'batch_size'):
batch_arg_name: str = 'batch_size',
**fit_kwargs):
r"""
Will iteratively try to find the largest batch size for a given model
that does not give an out of memory (OOM) error.
@@ -158,6 +159,10 @@ def scale_batch_size(self,
max_trials: max number of increases in batch size done before the
algorithm is terminated
batch_arg_name: name of the attribute that stores the batch size.
**fit_kwargs: remaining arguments to be passed to .fit(), e.g., dataloader
or datamodule.
"""
if not lightning_hasattr(model, batch_arg_name):
raise MisconfigurationException(
@@ -190,9 +195,9 @@
# Initially we just double in size until an OOM is encountered
new_size = _adjust_batch_size(self, value=init_val) # initially set to init_val
if mode == 'power':
new_size = _run_power_scaling(self, model, new_size, batch_arg_name, max_trials)
new_size = _run_power_scaling(self, model, new_size, batch_arg_name, max_trials, **fit_kwargs)
elif mode == 'binsearch':
new_size = _run_binsearch_scaling(self, model, new_size, batch_arg_name, max_trials)
new_size = _run_binsearch_scaling(self, model, new_size, batch_arg_name, max_trials, **fit_kwargs)
else:
raise ValueError('mode in method `scale_batch_size` can only be `power` or `binsearch`')

@@ -259,7 +264,9 @@ def _adjust_batch_size(trainer,
desc: str = None):
""" Function for adjusting the batch size. It is expected that the user
has provided a model that has a hparam field called `batch_size` i.e.
`model.hparams.batch_size` should exist.
`model.hparams.batch_size` should exist. Additionally, there can be a
datamodule attached to either the Trainer or the model; in that case the
attribute also gets updated when present.
Args:
trainer: instance of pytorch_lightning.Trainer
@@ -277,28 +284,22 @@
"""
model = trainer.get_model()
batch_size = lightning_getattr(model, batch_arg_name)
if value:
lightning_setattr(model, batch_arg_name, value)
new_size = value
if desc:
log.info(f'Batch size {batch_size} {desc}, trying batch size {new_size}')
else:
new_size = int(batch_size * factor)
if desc:
log.info(f'Batch size {batch_size} {desc}, trying batch size {new_size}')
lightning_setattr(model, batch_arg_name, new_size)
new_size = value if value is not None else int(batch_size * factor)
if desc:
log.info(f'Batch size {batch_size} {desc}, trying batch size {new_size}')
lightning_setattr(model, batch_arg_name, new_size)
return new_size


def _run_power_scaling(trainer, model, new_size, batch_arg_name, max_trials):
def _run_power_scaling(trainer, model, new_size, batch_arg_name, max_trials, **fit_kwargs):
""" Batch scaling mode where the size is doubled at each iteration until an
OOM error is encountered. """
for _ in range(max_trials):
garbage_collection_cuda()
trainer.global_step = 0 # reset after each try
try:
# Try fit
trainer.fit(model)
trainer.fit(model, **fit_kwargs)
# Double in size
new_size = _adjust_batch_size(trainer, batch_arg_name, factor=2.0, desc='succeeded')
except RuntimeError as exception:
@@ -313,7 +314,7 @@ def _run_power_scaling(trainer, model, new_size, batch_arg_name, max_trials):
return new_size


def _run_binsearch_scaling(trainer, model, new_size, batch_arg_name, max_trials):
def _run_binsearch_scaling(trainer, model, new_size, batch_arg_name, max_trials, **fit_kwargs):
""" Batch scaling mode where the size is initially is doubled at each iteration
until an OOM error is encountered. Hereafter, the batch size is further
refined using a binary search """
@@ -324,7 +325,7 @@ def _run_binsearch_scaling(trainer, model, new_size, batch_arg_name, max_trials)
trainer.global_step = 0 # reset after each try
try:
# Try fit
trainer.fit(model)
trainer.fit(model, **fit_kwargs)
count += 1
if count > max_trials:
break
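With `**fit_kwargs` in place, the finder can also be invoked directly with the same arguments one would pass to `.fit()`. A sketch, again reusing the illustrative `ToyModel`/`ToyDataModule` from the first example:

```python
import pytorch_lightning as pl

dm = ToyDataModule(batch_size=2)
model = ToyModel()
trainer = pl.Trainer(max_epochs=1)

new_size = trainer.scale_batch_size(
    model,
    mode='binsearch',        # refine with a binary search after the first failure
    init_val=2,
    max_trials=10,
    batch_arg_name='batch_size',
    datamodule=dm,           # forwarded to trainer.fit() inside each trial
)
print(new_size, model.batch_size, dm.batch_size)
```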
41 changes: 31 additions & 10 deletions pytorch_lightning/utilities/parsing.py
@@ -175,8 +175,10 @@ def __repr__(self):


def lightning_hasattr(model, attribute):
""" Special hasattr for lightning. Checks for attribute in model namespace
and the old hparams namespace/dict """
""" Special hasattr for lightning. Checks for attribute in model namespace,
the old hparams namespace/dict, and the datamodule. """
trainer = model.trainer

# Check if attribute in model
if hasattr(model, attribute):
attr = True
@@ -186,15 +188,20 @@ def lightning_hasattr(model, attribute):
attr = attribute in model.hparams
else:
attr = hasattr(model.hparams, attribute)
# Check if the attribute in datamodule (datamodule gets registered in Trainer)
elif trainer is not None and trainer.datamodule is not None and hasattr(trainer.datamodule, attribute):
attr = getattr(trainer.datamodule, attribute)
else:
attr = False

return attr


def lightning_getattr(model, attribute):
""" Special getattr for lightning. Checks for attribute in model namespace
and the old hparams namespace/dict """
""" Special getattr for lightning. Checks for attribute in model namespace,
the old hparams namespace/dict, and the datamodule. """
trainer = model.trainer

# Check if attribute in model
if hasattr(model, attribute):
attr = getattr(model, attribute)
@@ -204,24 +211,38 @@
attr = model.hparams[attribute]
else:
attr = getattr(model.hparams, attribute)

# Check if the attribute in datamodule (datamodule gets registered in Trainer)
elif trainer is not None and trainer.datamodule is not None and hasattr(trainer.datamodule, attribute):
attr = getattr(trainer.datamodule, attribute)
else:
raise ValueError(f'{attribute} is not stored in the model namespace'
' or the `hparams` namespace/dict.')
raise ValueError(f'{attribute} is neither stored in the model namespace'
' nor the `hparams` namespace/dict, nor the datamodule.')
return attr


def lightning_setattr(model, attribute, value):
""" Special setattr for lightning. Checks for attribute in model namespace
and the old hparams namespace/dict """
and the old hparams namespace/dict.
Will also set the attribute on datamodule, if it exists.
"""
if not lightning_hasattr(model, attribute):
raise ValueError(f'{attribute} is neither stored in the model namespace'
' nor the `hparams` namespace/dict, nor the datamodule.')

trainer = model.trainer

# Check if attribute in model
if hasattr(model, attribute):
setattr(model, attribute, value)

# Check if attribute in model.hparams, either namespace or dict
elif hasattr(model, 'hparams'):
if isinstance(model.hparams, dict):
model.hparams[attribute] = value
else:
setattr(model.hparams, attribute, value)
else:
raise ValueError(f'{attribute} is not stored in the model namespace'
' or the `hparams` namespace/dict.')

# Check if the attribute in datamodule (datamodule gets registered in Trainer)
if trainer is not None and trainer.datamodule is not None and hasattr(trainer.datamodule, attribute):
setattr(trainer.datamodule, attribute, value)
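
A small sketch of the lookup order these helpers now implement (model attribute, then `hparams`, then `trainer.datamodule`), using duck-typed stand-ins rather than real Lightning objects:

```python
from pytorch_lightning.utilities.parsing import (
    lightning_getattr, lightning_hasattr, lightning_setattr)


class FakeDataModule:
    batch_size = 8


class FakeTrainer:
    datamodule = FakeDataModule()


class FakeModel:
    # no `batch_size` and no `hparams` here; the helpers reach the
    # datamodule through `model.trainer.datamodule`
    trainer = FakeTrainer()


model = FakeModel()
assert lightning_hasattr(model, 'batch_size')        # found on the datamodule
assert lightning_getattr(model, 'batch_size') == 8
lightning_setattr(model, 'batch_size', 16)           # writes through to the datamodule
assert model.trainer.datamodule.batch_size == 16
```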
11 changes: 10 additions & 1 deletion tests/trainer/test_trainer_tricks.py
@@ -8,6 +8,7 @@
from pytorch_lightning.utilities import AMPType, NATIVE_AMP_AVALAIBLE
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from tests.base import EvalModelTemplate
from tests.base.datamodules import MNISTDataModule


def test_num_training_batches(tmpdir):
@@ -228,13 +229,21 @@ def dataloader(self, *args, **kwargs):
del self.batch_size
return dataloader

datamodule_model = MNISTDataModule(data_dir=tmpdir, batch_size=111) # this datamodule should get ignored!
datamodule_fit = MNISTDataModule(data_dir=tmpdir, batch_size=before_batch_size)

model_class = HparamsEvalModelTemplate if use_hparams else EvalModelTemplate
model = model_class(**hparams)
model.datamodule = datamodule_model # unused when another module gets passed to .tune() / .fit()

trainer = Trainer(default_root_dir=tmpdir, max_epochs=1, auto_scale_batch_size=True)
trainer.tune(model)
trainer.tune(model, datamodule_fit)
assert trainer.datamodule == datamodule_fit
after_batch_size = model.hparams.batch_size if use_hparams else model.batch_size
assert before_batch_size != after_batch_size
assert datamodule_fit.batch_size == after_batch_size
# should be left unchanged, since it was not passed to .tune()
assert datamodule_model.batch_size == 111


def test_auto_scale_batch_size_duplicate_attribute_warning(tmpdir):
