Skip to content

Commit

Permalink
Merge 7185815 into 34b733b
Browse files Browse the repository at this point in the history
  • Loading branch information
awaelchli committed Mar 6, 2021
2 parents 34b733b + 7185815 commit 840f74b
Show file tree
Hide file tree
Showing 4 changed files with 80 additions and 3 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Fixed error thrown when using valid distributed mode in multi node ([#6297](https://github.com/PyTorchLightning/pytorch-lightning/pull/6297)


- Fixed an issue with `Tuner.scale_batch_size` not finding the batch size attribute in the datamodule ([#5968](https://github.com/PyTorchLightning/pytorch-lightning/pull/5968))


## [1.2.1] - 2021-02-23

### Fixed
Expand Down
2 changes: 1 addition & 1 deletion pytorch_lightning/trainer/training_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ def on_train_start(self):
# provide rank to profiler
self.trainer.profile_connector.on_train_start(self.trainer)

def setup_fit(self, model, train_dataloader, val_dataloaders, datamodule):
def setup_fit(self, model, train_dataloader=None, val_dataloaders=None, datamodule=None):
# clean hparams
if hasattr(model, "hparams"):
parsing.clean_namespace(model.hparams)
Expand Down
13 changes: 11 additions & 2 deletions pytorch_lightning/tuner/tuning.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,13 +33,20 @@ def on_trainer_init(self, auto_lr_find, auto_scale_batch_size):
self.trainer.auto_lr_find = auto_lr_find
self.trainer.auto_scale_batch_size = auto_scale_batch_size

def tune(self, model, train_dataloader, val_dataloaders, datamodule):
def setup_trainer(
self,
model: LightningModule,
train_dataloader: Optional[DataLoader] = None,
val_dataloaders: Optional[Union[DataLoader, List[DataLoader]]] = None,
datamodule: LightningDataModule = None,
):
self.trainer.model_connector.copy_trainer_model_properties(model)
# setup data, etc...
self.trainer.train_loop.setup_fit(model, train_dataloader, val_dataloaders, datamodule)

# hook
self.trainer.data_connector.prepare_data(model)

def tune(self, model, train_dataloader, val_dataloaders, datamodule):
# Run auto batch size scaling
if self.trainer.auto_scale_batch_size:
if isinstance(self.trainer.auto_scale_batch_size, bool):
Expand Down Expand Up @@ -104,6 +111,7 @@ def scale_batch_size(
or datamodule.
"""
self.setup_trainer(model, **fit_kwargs)
return scale_batch_size(
self.trainer,
model,
Expand All @@ -128,6 +136,7 @@ def lr_find(
datamodule: Optional[LightningDataModule] = None,
update_attr: bool = False,
):
self.setup_trainer(model, train_dataloader, val_dataloaders, datamodule)
return lr_find(
self.trainer,
model,
Expand Down
65 changes: 65 additions & 0 deletions tests/tuner/test_scale_batch_size.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import pytest
from torch.utils.data import DataLoader

from pytorch_lightning import Trainer
from pytorch_lightning.tuner.tuning import Tuner
from tests.helpers import BoringDataModule, BoringModel


class BatchSizeDataModule(BoringDataModule):

def __init__(self, batch_size=None):
super().__init__()
if batch_size is not None:
self.batch_size = batch_size

def train_dataloader(self):
return DataLoader(self.random_train, batch_size=getattr(self, "batch_size", 1))


class BatchSizeModel(BoringModel):

def __init__(self, batch_size=None):
super().__init__()
if batch_size is not None:
self.batch_size = batch_size


@pytest.mark.parametrize(
"model,datamodule", [
(BatchSizeModel(2), None),
(BatchSizeModel(2), BatchSizeDataModule(2)),
(BatchSizeModel(2), BatchSizeDataModule(None)),
(BatchSizeModel(None), BatchSizeDataModule(2)),
]
)
def test_scale_batch_size_method_with_model_or_datamodule(tmpdir, model, datamodule):
""" Test the tuner method `Tuner.scale_batch_size` with a datamodule. """
trainer = Trainer(
default_root_dir=tmpdir,
limit_train_batches=1,
limit_val_batches=0,
max_epochs=1,
)
tuner = Tuner(trainer)
new_batch_size = tuner.scale_batch_size(
model=model, mode="binsearch", init_val=4, max_trials=2, datamodule=datamodule
)
assert new_batch_size == 16
if hasattr(model, "batch_size"):
assert model.batch_size == 16
if datamodule is not None and hasattr(datamodule, "batch_size"):
assert datamodule.batch_size == 16

0 comments on commit 840f74b

Please sign in to comment.