Fix tuner.scale_batch_size not finding batch size attribute when using datamodule (#5968)

(cherry picked from commit b2bcad1)
awaelchli authored and lexierule committed Mar 16, 2021
1 parent 73ef543 commit 9fc733f
Showing 4 changed files with 94 additions and 8 deletions.
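
To make the user-facing behaviour concrete, here is a minimal sketch (not part of this diff) of batch-size scaling when the `batch_size` attribute lives on the datamodule rather than the model. `TinyModel` and `TinyDataModule` are illustrative stand-ins; the `Tuner.scale_batch_size` call mirrors the new test added below.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

from pytorch_lightning import LightningDataModule, LightningModule, Trainer
from pytorch_lightning.tuner.tuning import Tuner


class TinyModel(LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)

    def training_step(self, batch, batch_idx):
        x = batch[0]
        return self.layer(x).sum()

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)


class TinyDataModule(LightningDataModule):
    def __init__(self, batch_size=2):
        super().__init__()
        # the attribute the tuner needs to find and overwrite lives here, not on the model
        self.batch_size = batch_size

    def train_dataloader(self):
        return DataLoader(TensorDataset(torch.randn(256, 32)), batch_size=self.batch_size)


trainer = Trainer(max_epochs=1, limit_train_batches=1, limit_val_batches=0)
tuner = Tuner(trainer)
# Before this fix the tuner could miss `batch_size` when it was defined only on the datamodule.
new_batch_size = tuner.scale_batch_size(
    TinyModel(), mode="binsearch", init_val=4, max_trials=2, datamodule=TinyDataModule()
)
print(new_batch_size)
```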
22 changes: 17 additions & 5 deletions CHANGELOG.md
@@ -11,6 +11,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

- Added a way to print to terminal without breaking up the progress bar ([#5470](https://github.com/PyTorchLightning/pytorch-lightning/pull/5470))

- Added support to checkpoint after training steps in `ModelCheckpoint` callback ([#6146](https://github.com/PyTorchLightning/pytorch-lightning/pull/6146))

- Added `checkpoint` parameter to callback's `on_save_checkpoint` hook ([#6072](https://github.com/PyTorchLightning/pytorch-lightning/pull/6072))

@@ -21,9 +22,15 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Added `TrainerState.{FITTING,VALIDATING,TESTING,PREDICTING,TUNING}` ([#4945](https://github.com/PyTorchLightning/pytorch-lightning/pull/4945))


- Added `Trainer.validate()` method to perform one evaluation epoch over the validation set ([#4948](https://github.com/PyTorchLightning/pytorch-lightning/pull/4948))


- Added `LightningEnvironment` for Lightning-specific DDP ([#5915](https://github.com/PyTorchLightning/pytorch-lightning/pull/5915))


- Added `auto_insert_metric_name` parameter to `ModelCheckpoint` ([#6277](https://github.com/PyTorchLightning/pytorch-lightning/pull/6277))


- Added arg to `self.log` that enables users to give custom names when dealing with multiple dataloaders ([#6274](https://github.com/PyTorchLightning/pytorch-lightning/pull/6274))


@@ -49,6 +56,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Deprecated

- `period` has been deprecated in favor of `every_n_val_epochs` in the `ModelCheckpoint` callback ([#6146](https://github.com/PyTorchLightning/pytorch-lightning/pull/6146))


- Deprecated `trainer.running_sanity_check` in favor of `trainer.sanity_checking` ([#4945](https://github.com/PyTorchLightning/pytorch-lightning/pull/4945))

@@ -122,15 +131,18 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Fixed logger creating directory structure too early in DDP ([#6380](https://github.com/PyTorchLightning/pytorch-lightning/pull/6380))


- Fixed LightningModule `all_gather` on cpu tensors ([#6416](https://github.com/PyTorchLightning/pytorch-lightning/pull/6416))
- Fixed DeepSpeed additional memory use on rank 0 when default device not set early enough ([#6460](https://github.com/PyTorchLightning/pytorch-lightning/pull/6460))


- Fixed DeepSpeed additional memory use on rank 0 when default device not set early enough ([#6460](https://github.com/PyTorchLightning/pytorch-lightning/pull/6460))
- Fixed LightningModule `all_gather` on cpu tensors ([#6416](https://github.com/PyTorchLightning/pytorch-lightning/pull/6416))


- Fixed `DummyLogger.log_hyperparams` raising a `TypeError` when running with `fast_dev_run=True` ([#6398](https://github.com/PyTorchLightning/pytorch-lightning/pull/6398))


- Fixed an issue with `Tuner.scale_batch_size` not finding the batch size attribute in the datamodule ([#5968](https://github.com/PyTorchLightning/pytorch-lightning/pull/5968))


## [1.2.3] - 2021-03-09

### Fixed
@@ -148,6 +160,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Fixed `Trainer` not resetting `lightning_optimizers` when calling `Trainer.fit()` multiple times ([#6372](https://github.com/PyTorchLightning/pytorch-lightning/pull/6372))


- Fixed `DummyLogger.log_hyperparams` raising a `TypeError` when running with `fast_dev_run=True` ([#6398](https://github.com/PyTorchLightning/pytorch-lightning/pull/6398))


## [1.2.2] - 2021-03-02

### Added
@@ -169,9 +184,6 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Fixed error thrown when using valid distributed mode in multi node ([#6297](https://github.com/PyTorchLightning/pytorch-lightning/pull/6297))


- Fixed DP reduction with collection ([#6324](https://github.com/PyTorchLightning/pytorch-lightning/pull/6324))


## [1.2.1] - 2021-02-23

### Fixed
2 changes: 1 addition & 1 deletion pytorch_lightning/trainer/training_loop.py
@@ -108,7 +108,7 @@ def on_train_start(self):
        # provide rank to profiler
        self.trainer.profile_connector.on_train_start(self.trainer)

    def setup_fit(self, model, train_dataloader, val_dataloaders, datamodule):
    def setup_fit(self, model, train_dataloader=None, val_dataloaders=None, datamodule=None):
        # clean hparams
        if hasattr(model, "hparams"):
            parsing.clean_namespace(model.hparams)
13 changes: 11 additions & 2 deletions pytorch_lightning/tuner/tuning.py
@@ -32,13 +32,20 @@ def on_trainer_init(self, auto_lr_find, auto_scale_batch_size):
        self.trainer.auto_lr_find = auto_lr_find
        self.trainer.auto_scale_batch_size = auto_scale_batch_size

    def tune(self, model, train_dataloader, val_dataloaders, datamodule):
    def setup_trainer(
        self,
        model: LightningModule,
        train_dataloader: Optional[DataLoader] = None,
        val_dataloaders: Optional[Union[DataLoader, List[DataLoader]]] = None,
        datamodule: LightningDataModule = None,
    ):
        self.trainer.model_connector.copy_trainer_model_properties(model)
        # setup data, etc...
        self.trainer.train_loop.setup_fit(model, train_dataloader, val_dataloaders, datamodule)

        # hook
        self.trainer.data_connector.prepare_data(model)

    def tune(self, model, train_dataloader, val_dataloaders, datamodule):
        # Run auto batch size scaling
        if self.trainer.auto_scale_batch_size:
            if isinstance(self.trainer.auto_scale_batch_size, bool):
@@ -101,6 +108,7 @@ def scale_batch_size(
                or datamodule.
        """
        self.setup_trainer(model, **fit_kwargs)
        return scale_batch_size(
            self.trainer,
            model,
@@ -125,6 +133,7 @@ def lr_find(
        datamodule: Optional[LightningDataModule] = None,
        update_attr: bool = False,
    ):
        self.setup_trainer(model, train_dataloader, val_dataloaders, datamodule)
        return lr_find(
            self.trainer,
            model,
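The same `setup_trainer` call is now shared by the batch-size scaler and the learning-rate finder, so `lr_find` is set up against the datamodule as well. A hedged sketch of what that looks like for a caller, reusing the illustrative `TinyModel`/`TinyDataModule` classes from the sketch above:

```python
from pytorch_lightning import Trainer
from pytorch_lightning.tuner.tuning import Tuner

# Illustrative only: the learning-rate finder can also be pointed at a datamodule,
# because Tuner.lr_find now runs setup_trainer(...) before delegating to lr_find().
trainer = Trainer(max_epochs=1)
tuner = Tuner(trainer)
lr_finder = tuner.lr_find(TinyModel(), datamodule=TinyDataModule())
print(lr_finder.suggestion())  # the returned object exposes a suggestion() helper
```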
65 changes: 65 additions & 0 deletions tests/tuner/test_scale_batch_size.py
@@ -0,0 +1,65 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import pytest
from torch.utils.data import DataLoader

from pytorch_lightning import Trainer
from pytorch_lightning.tuner.tuning import Tuner
from tests.helpers import BoringDataModule, BoringModel


class BatchSizeDataModule(BoringDataModule):

    def __init__(self, batch_size=None):
        super().__init__()
        if batch_size is not None:
            self.batch_size = batch_size

    def train_dataloader(self):
        return DataLoader(self.random_train, batch_size=getattr(self, "batch_size", 1))


class BatchSizeModel(BoringModel):

    def __init__(self, batch_size=None):
        super().__init__()
        if batch_size is not None:
            self.batch_size = batch_size


@pytest.mark.parametrize(
    "model,datamodule", [
        (BatchSizeModel(2), None),
        (BatchSizeModel(2), BatchSizeDataModule(2)),
        (BatchSizeModel(2), BatchSizeDataModule(None)),
        (BatchSizeModel(None), BatchSizeDataModule(2)),
    ]
)
def test_scale_batch_size_method_with_model_or_datamodule(tmpdir, model, datamodule):
    """ Test the tuner method `Tuner.scale_batch_size` with a datamodule. """
    trainer = Trainer(
        default_root_dir=tmpdir,
        limit_train_batches=1,
        limit_val_batches=0,
        max_epochs=1,
    )
    tuner = Tuner(trainer)
    new_batch_size = tuner.scale_batch_size(
        model=model, mode="binsearch", init_val=4, max_trials=2, datamodule=datamodule
    )
    assert new_batch_size == 16
    if hasattr(model, "batch_size"):
        assert model.batch_size == 16
    if datamodule is not None and hasattr(datamodule, "batch_size"):
        assert datamodule.batch_size == 16
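
The assertions in the test check that the scaled value is written back to whichever object actually owns `batch_size`. As a rough illustration of that lookup order only (an assumption for exposition, not the library's actual helper), the resolution could be sketched like this:

```python
def find_batch_size_owner(model, batch_arg_name="batch_size"):
    """Illustrative lookup order: the model itself, then its hparams, then the attached datamodule.

    This is a sketch for exposition; it is not the helper PyTorch Lightning uses internally.
    """
    candidates = (model, getattr(model, "hparams", None), getattr(model, "datamodule", None))
    for obj in candidates:
        if obj is not None and hasattr(obj, batch_arg_name):
            return obj
    raise AttributeError(f"No `{batch_arg_name}` attribute found on the model, its hparams, or the datamodule")


# After tuning, the new value would be written back to the owner, e.g.:
# setattr(find_batch_size_owner(model), "batch_size", new_batch_size)
```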
