
fix tuner.scale_batch_size not finding batch size attribute when using datamodule #5968

Merged · 17 commits · Mar 14, 2021
3 changes: 3 additions & 0 deletions CHANGELOG.md
@@ -107,6 +107,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Fixed DP reduction with collection ([#6324](https://github.com/PyTorchLightning/pytorch-lightning/pull/6324))


- Fixed an issue with `Tuner.scale_batch_size` not finding the batch size attribute in the datamodule ([#5968](https://github.com/PyTorchLightning/pytorch-lightning/pull/5968))


## [1.2.1] - 2021-02-23

### Fixed
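For context, the entry above covers the user-facing scenario where `batch_size` lives on the datamodule rather than on the model. A minimal sketch of that call path, assuming a hypothetical `LitModel`/`LitDataModule` pair in which the datamodule stores `self.batch_size` and builds its train dataloader from it (these names are illustrative, not part of this PR):

```python
from pytorch_lightning import Trainer
from pytorch_lightning.tuner.tuning import Tuner

# LitModel / LitDataModule are hypothetical user classes; the datamodule
# defines `self.batch_size` and uses it in `train_dataloader()`.
model = LitModel()
datamodule = LitDataModule(batch_size=4)

trainer = Trainer(max_epochs=1)
tuner = Tuner(trainer)

# Before this fix, the scan could fail to locate `batch_size` because the
# datamodule had not yet been attached to the trainer when the search ran.
new_batch_size = tuner.scale_batch_size(model, datamodule=datamodule)
datamodule.batch_size = new_batch_size
```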
2 changes: 1 addition & 1 deletion pytorch_lightning/trainer/training_loop.py
@@ -106,7 +106,7 @@ def on_train_start(self):
        # provide rank to profiler
        self.trainer.profile_connector.on_train_start(self.trainer)

-    def setup_fit(self, model, train_dataloader, val_dataloaders, datamodule):
+    def setup_fit(self, model, train_dataloader=None, val_dataloaders=None, datamodule=None):
        # clean hparams
        if hasattr(model, "hparams"):
            parsing.clean_namespace(model.hparams)
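The new keyword defaults mean `setup_fit` no longer requires every data argument to be passed explicitly when only a model and, say, a datamodule are available. A tiny standalone illustration of the signature change (a stub, not the real `TrainLoop.setup_fit`):

```python
# Stub mirroring the new signature; not the real TrainLoop.setup_fit.
def setup_fit(model, train_dataloader=None, val_dataloaders=None, datamodule=None):
    return model, train_dataloader, val_dataloaders, datamodule

# Old signature: all four arguments had to be supplied positionally.
# New signature: unused ones can simply be omitted.
setup_fit("model", datamodule="dm")
```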
13 changes: 11 additions & 2 deletions pytorch_lightning/tuner/tuning.py
@@ -32,13 +32,20 @@ def on_trainer_init(self, auto_lr_find, auto_scale_batch_size):
        self.trainer.auto_lr_find = auto_lr_find
        self.trainer.auto_scale_batch_size = auto_scale_batch_size

-    def tune(self, model, train_dataloader, val_dataloaders, datamodule):
+    def setup_trainer(
+        self,
+        model: LightningModule,
+        train_dataloader: Optional[DataLoader] = None,
+        val_dataloaders: Optional[Union[DataLoader, List[DataLoader]]] = None,
+        datamodule: LightningDataModule = None,
+    ):
        self.trainer.model_connector.copy_trainer_model_properties(model)
        # setup data, etc...
        self.trainer.train_loop.setup_fit(model, train_dataloader, val_dataloaders, datamodule)

        # hook
        self.trainer.data_connector.prepare_data(model)

+    def tune(self, model, train_dataloader, val_dataloaders, datamodule):
        # Run auto batch size scaling
        if self.trainer.auto_scale_batch_size:
            if isinstance(self.trainer.auto_scale_batch_size, bool):
@@ -101,6 +108,7 @@ def scale_batch_size(
                or datamodule.

        """
+        self.setup_trainer(model, **fit_kwargs)
        return scale_batch_size(
            self.trainer,
            model,
@@ -125,6 +133,7 @@ def lr_find(
        datamodule: Optional[LightningDataModule] = None,
        update_attr: bool = False,
    ):
+        self.setup_trainer(model, train_dataloader, val_dataloaders, datamodule)
        return lr_find(
            self.trainer,
            model,
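The two added `self.setup_trainer(...)` calls are the core of the fix: they run the same model and data wiring that `fit` would perform, so by the time the batch-size scan (or the LR finder) starts, the datamodule is attached to the trainer and its `batch_size` attribute can be resolved. A rough sketch of the lookup order the scaler relies on, loosely mirroring `lightning_hasattr`/`lightning_setattr` from `pytorch_lightning.utilities.parsing` (an illustrative stub; details may differ from the library code):

```python
def find_batch_size_owner(trainer, model, attribute="batch_size"):
    """Illustrative stub, not library code: where a batch-size attribute is looked up."""
    # 1. directly on the LightningModule
    if hasattr(model, attribute):
        return model
    # 2. inside the module's hparams
    if hasattr(model, "hparams") and hasattr(model.hparams, attribute):
        return model.hparams
    # 3. on the datamodule attached to the trainer; before setup_trainer()
    #    attached it, a lookup like this branch could never succeed
    if getattr(trainer, "datamodule", None) is not None and hasattr(trainer.datamodule, attribute):
        return trainer.datamodule
    return None
```

With `setup_trainer` in place, `scale_batch_size` can read and write `batch_size` on whichever of these objects defines it, which is what the new test below asserts.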
52 changes: 52 additions & 0 deletions tests/tuner/test_scale_batch_size.py
@@ -0,0 +1,52 @@
import pytest
from torch.utils.data import DataLoader

from pytorch_lightning import Trainer
from pytorch_lightning.tuner.tuning import Tuner
from tests.helpers import BoringDataModule, BoringModel


class BatchSizeDataModule(BoringDataModule):

    def __init__(self, batch_size=None):
        super().__init__()
        if batch_size is not None:
            self.batch_size = batch_size

    def train_dataloader(self):
        return DataLoader(self.random_train, batch_size=getattr(self, "batch_size", 1))


class BatchSizeModel(BoringModel):

    def __init__(self, batch_size=None):
        super().__init__()
        if batch_size is not None:
            self.batch_size = batch_size


@pytest.mark.parametrize(
    "model,datamodule", [
        (BatchSizeModel(2), None),
        (BatchSizeModel(2), BatchSizeDataModule(2)),
        (BatchSizeModel(2), BatchSizeDataModule(None)),
        (BatchSizeModel(None), BatchSizeDataModule(2)),
    ]
)
def test_scale_batch_size_method_with_model_or_datamodule(tmpdir, model, datamodule):
    """ Test the tuner method `Tuner.scale_batch_size` with a datamodule. """
    trainer = Trainer(
        default_root_dir=tmpdir,
        limit_train_batches=1,
        limit_val_batches=0,
        max_epochs=1,
    )
    tuner = Tuner(trainer)
    new_batch_size = tuner.scale_batch_size(
        model=model, mode="binsearch", init_val=4, max_trials=2, datamodule=datamodule
    )
    assert new_batch_size == 16
    if hasattr(model, "batch_size"):
        assert model.batch_size == 16
    if datamodule is not None and hasattr(datamodule, "batch_size"):
        assert datamodule.batch_size == 16
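For reference, the expected value of 16 follows from the parametrization: starting at `init_val=4`, each of the two allowed trials doubles the batch size (4 → 8 → 16), and the result is written back to the model and/or datamodule attribute, as the final asserts check. The same scenario can also be driven through the higher-level `trainer.tune` path instead of calling the `Tuner` directly; a hedged sketch reusing the helpers defined in this test:

```python
# Sketch only: exercises the same fix via Trainer.tune with auto_scale_batch_size,
# reusing BatchSizeModel / BatchSizeDataModule from the test above.
trainer = Trainer(
    default_root_dir=tmpdir,
    limit_train_batches=1,
    limit_val_batches=0,
    max_epochs=1,
    auto_scale_batch_size="binsearch",
)
trainer.tune(BatchSizeModel(2), datamodule=BatchSizeDataModule(2))
```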