From c0fd86e57328eaffdf49b201e64ba614d161115b Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Adrian=20W=C3=A4lchli?=
Date: Sun, 14 Feb 2021 16:36:19 +0100
Subject: [PATCH 01/15] add missing setup logic

---
 pytorch_lightning/trainer/training_loop.py | 2 +-
 pytorch_lightning/tuner/tuning.py | 6 ++++++
 2 files changed, 7 insertions(+), 1 deletion(-)

diff --git a/pytorch_lightning/trainer/training_loop.py b/pytorch_lightning/trainer/training_loop.py
index f727a15310a84..121bcab0e1a38 100644
--- a/pytorch_lightning/trainer/training_loop.py
+++ b/pytorch_lightning/trainer/training_loop.py
@@ -108,7 +108,7 @@ def on_train_start(self):
         # provide rank to profiler
         self.trainer.profile_connector.on_train_start(self.trainer)

-    def setup_fit(self, model, train_dataloader, val_dataloaders, datamodule):
+    def setup_fit(self, model, train_dataloader=None, val_dataloaders=None, datamodule=None):
         # clean hparams
         if hasattr(model, "hparams"):
             parsing.clean_namespace(model.hparams)
diff --git a/pytorch_lightning/tuner/tuning.py b/pytorch_lightning/tuner/tuning.py
index 314821bd81e02..8a98abfd951ec 100644
--- a/pytorch_lightning/tuner/tuning.py
+++ b/pytorch_lightning/tuner/tuning.py
@@ -101,6 +101,9 @@ def scale_batch_size(
             or datamodule.
         """
+        self.trainer.model_connector.copy_trainer_model_properties(model)
+        self.trainer.train_loop.setup_fit(model, **fit_kwargs)
+        self.trainer.data_connector.prepare_data(model)
         return scale_batch_size(
             self.trainer,
             model,
@@ -124,6 +127,9 @@ def lr_find(
         early_stop_threshold: float = 4.0,
         datamodule: Optional[LightningDataModule] = None
     ):
+        self.trainer.model_connector.copy_trainer_model_properties(model)
+        self.trainer.train_loop.setup_fit(model, train_dataloader, val_dataloaders, datamodule)
+        self.trainer.data_connector.prepare_data(model)
         return lr_find(
             self.trainer,
             model,

From 8d8b9ae268aeba0d190efb99d0d879eb4000dd97 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Adrian=20W=C3=A4lchli?=
Date: Fri, 5 Mar 2021 15:49:21 +0100
Subject: [PATCH 02/15] repro script

---
 pl_examples/repro.py | 90 ++++++++++++++++++++++++++++++++++++++++++++
 1 file changed, 90 insertions(+)
 create mode 100644 pl_examples/repro.py

diff --git a/pl_examples/repro.py b/pl_examples/repro.py
new file mode 100644
index 0000000000000..9a4e4b81804e4
--- /dev/null
+++ b/pl_examples/repro.py
@@ -0,0 +1,90 @@
+# Copyright The PyTorch Lightning team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from argparse import ArgumentParser
+
+import torch
+from pytorch_lightning import Trainer
+from pytorch_lightning.tuner.tuning import Tuner
+from torch.nn import functional as F
+
+import pytorch_lightning as pl
+from pl_examples.basic_examples.mnist_datamodule import MNISTDataModule
+
+
+class LitClassifier(pl.LightningModule):
+
+    def __init__(self, hidden_dim=128, learning_rate=1e-3):
+        super().__init__()
+        self.save_hyperparameters()
+
+        self.l1 = torch.nn.Linear(28 * 28, self.hparams.hidden_dim)
+        self.l2 = torch.nn.Linear(self.hparams.hidden_dim, 10)
+
+    def forward(self, x):
+        x = x.view(x.size(0), -1)
+        x = torch.relu(self.l1(x))
+        x = torch.relu(self.l2(x))
+        return x
+
+    def training_step(self, batch, batch_idx):
+        x, y = batch
+        y_hat = self(x)
+        loss = F.cross_entropy(y_hat, y)
+        return loss
+
+    def validation_step(self, batch, batch_idx):
+        x, y = batch
+        y_hat = self(x)
+        loss = F.cross_entropy(y_hat, y)
+        self.log('valid_loss', loss)
+
+    def test_step(self, batch, batch_idx):
+        x, y = batch
+        y_hat = self(x)
+        loss = F.cross_entropy(y_hat, y)
+        self.log('test_loss', loss)
+
+    def configure_optimizers(self):
+        return torch.optim.Adam(self.parameters(), lr=self.hparams.learning_rate)
+
+    @staticmethod
+    def add_model_specific_args(parent_parser):
+        parser = ArgumentParser(parents=[parent_parser], add_help=False)
+        parser.add_argument('--hidden_dim', type=int, default=128)
+        parser.add_argument('--learning_rate', type=float, default=0.0001)
+        return parser
+
+
+if __name__ == '__main__':
+    pl.seed_everything(1234)
+    parser = ArgumentParser()
+    parser = Trainer.add_argparse_args(parser)
+    parser = LitClassifier.add_model_specific_args(parser)
+    parser = MNISTDataModule.add_argparse_args(parser)
+    args = parser.parse_args()
+
+    dm = MNISTDataModule.from_argparse_args(args)
+    model = LitClassifier(args.hidden_dim, args.learning_rate)
+    trainer = Trainer(
+        gpus=1,
+        accelerator='dp',
+        auto_scale_batch_size='binsearch'
+    )
+
+    tuner = Tuner(trainer)
+    new_batch_size = tuner.scale_batch_size(model, mode="binsearch", init_val=128, max_trials=3, datamodule=dm)
+    model.hparams.batch_size = new_batch_size
+
+    trainer.fit(model, datamodule=dm)

From 4c6f4b9b2ad5ddf0f8bb2cd65a06d7804393a823 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Adrian=20W=C3=A4lchli?=
Date: Fri, 5 Mar 2021 16:35:05 +0100
Subject: [PATCH 03/15] add simple test

---
 tests/tuner/test_scale_batch_size.py | 50 ++++++++++++++++++++++++++++
 1 file changed, 50 insertions(+)
 create mode 100644 tests/tuner/test_scale_batch_size.py

diff --git a/tests/tuner/test_scale_batch_size.py b/tests/tuner/test_scale_batch_size.py
new file mode 100644
index 0000000000000..9f305d7b48420
--- /dev/null
+++ b/tests/tuner/test_scale_batch_size.py
@@ -0,0 +1,50 @@
+import pytest
+from pytorch_lightning import Trainer
+from pytorch_lightning.tuner.tuning import Tuner
+from tests.helpers import BoringModel, BoringDataModule
+from torch.utils.data import DataLoader
+
+
+class BatchSizeDataModule(BoringDataModule):
+
+    def __init__(self, batch_size=2):
+        super().__init__()
+        self.batch_size = batch_size
+
+    def train_dataloader(self):
+        return DataLoader(self.random_train, batch_size=self.batch_size)
+
+
+class BatchSizeModel(BoringModel):
+
+    def __init__(self, batch_size=2):
+        super().__init__()
+        self.save_hyperparameters()
+
+
+# @RunIf()
+@pytest.mark.parametrize("model,datamodule", [
+    (BatchSizeModel(2), None),
+    (BatchSizeModel(2), BatchSizeDataModule(2))
+])
+def test_scale_batch_size_method_with_model_or_datamodule(tmpdir, model, datamodule):
+    """ Test the tuner method `Tuner.scale_batch_size` with a datamodule. """
+    trainer = Trainer(
+        default_root_dir=tmpdir,
+        gpus=1,
+        limit_train_batches=1,
+        limit_val_batches=0,
+        max_epochs=1,
+    )
+    tuner = Tuner(trainer)
+    new_batch_size = tuner.scale_batch_size(
+        model=model,
+        mode="binsearch",
+        init_val=4,
+        max_trials=2,
+        datamodule=datamodule
+    )
+    assert new_batch_size == 16
+    assert model.hparams.batch_size == 16
+    if datamodule is not None:
+        assert datamodule.batch_size == 16

From de4b8ec941b8e5af9665de3b9ae42c3a39a29d67 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Adrian=20W=C3=A4lchli?=
Date: Fri, 5 Mar 2021 16:36:09 +0100
Subject: [PATCH 04/15] gpu requirement

---
 tests/tuner/test_scale_batch_size.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/tests/tuner/test_scale_batch_size.py b/tests/tuner/test_scale_batch_size.py
index 9f305d7b48420..56d7669c322f2 100644
--- a/tests/tuner/test_scale_batch_size.py
+++ b/tests/tuner/test_scale_batch_size.py
@@ -2,6 +2,7 @@
 from pytorch_lightning import Trainer
 from pytorch_lightning.tuner.tuning import Tuner
 from tests.helpers import BoringModel, BoringDataModule
+from tests.helpers.runif import RunIf
 from torch.utils.data import DataLoader


@@ -22,7 +23,7 @@ def __init__(self, batch_size=2):
         self.save_hyperparameters()


-# @RunIf()
+@RunIf(min_gpus=1)
 @pytest.mark.parametrize("model,datamodule", [
     (BatchSizeModel(2), None),
     (BatchSizeModel(2), BatchSizeDataModule(2))
 ])

From eaf65e6ff46829222bab53c4baa3f37766aec0d5 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Adrian=20W=C3=A4lchli?=
Date: Fri, 5 Mar 2021 16:52:37 +0100
Subject: [PATCH 05/15] test all combinations

---
 tests/tuner/test_scale_batch_size.py | 21 +++++++++++++--------
 1 file changed, 13 insertions(+), 8 deletions(-)

diff --git a/tests/tuner/test_scale_batch_size.py b/tests/tuner/test_scale_batch_size.py
index 56d7669c322f2..5c7b1f6685d83 100644
--- a/tests/tuner/test_scale_batch_size.py
+++ b/tests/tuner/test_scale_batch_size.py
@@ -8,25 +8,29 @@

 class BatchSizeDataModule(BoringDataModule):

-    def __init__(self, batch_size=2):
+    def __init__(self, batch_size=None):
         super().__init__()
-        self.batch_size = batch_size
+        if batch_size is not None:
+            self.batch_size = batch_size

     def train_dataloader(self):
-        return DataLoader(self.random_train, batch_size=self.batch_size)
+        return DataLoader(self.random_train, batch_size=getattr(self, "batch_size", 1))


 class BatchSizeModel(BoringModel):

-    def __init__(self, batch_size=2):
+    def __init__(self, batch_size=None):
         super().__init__()
-        self.save_hyperparameters()
+        if batch_size is not None:
+            self.batch_size = batch_size


 @RunIf(min_gpus=1)
 @pytest.mark.parametrize("model,datamodule", [
     (BatchSizeModel(2), None),
-    (BatchSizeModel(2), BatchSizeDataModule(2))
+    (BatchSizeModel(2), BatchSizeDataModule(2)),
+    (BatchSizeModel(2), BatchSizeDataModule(None)),
+    (BatchSizeModel(None), BatchSizeDataModule(2)),
 ])
 def test_scale_batch_size_method_with_model_or_datamodule(tmpdir, model, datamodule):
     """ Test the tuner method `Tuner.scale_batch_size` with a datamodule. """
     trainer = Trainer(
         default_root_dir=tmpdir,
         gpus=1,
         limit_train_batches=1,
         limit_val_batches=0,
         max_epochs=1,
     )
     tuner = Tuner(trainer)
     new_batch_size = tuner.scale_batch_size(
         model=model,
         mode="binsearch",
         init_val=4,
         max_trials=2,
         datamodule=datamodule
     )
     assert new_batch_size == 16
-    assert model.hparams.batch_size == 16
-    if datamodule is not None:
+    if hasattr(model, "batch_size"):
+        assert model.batch_size == 16
+    if datamodule is not None and hasattr(datamodule, "batch_size"):
         assert datamodule.batch_size == 16

From dbcb8bf5ba5a9e4726f5038e4996a2b3323f5988 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Adrian=20W=C3=A4lchli?=
Date: Fri, 5 Mar 2021 16:59:16 +0100
Subject: [PATCH 06/15] remove repro script

---
 pl_examples/repro.py | 90 --------------------------------------------
 1 file changed, 90 deletions(-)
 delete mode 100644 pl_examples/repro.py

diff --git a/pl_examples/repro.py b/pl_examples/repro.py
deleted file mode 100644
index 9a4e4b81804e4..0000000000000
--- a/pl_examples/repro.py
+++ /dev/null
@@ -1,90 +0,0 @@
-# Copyright The PyTorch Lightning team.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-from argparse import ArgumentParser
-
-import torch
-from pytorch_lightning import Trainer
-from pytorch_lightning.tuner.tuning import Tuner
-from torch.nn import functional as F
-
-import pytorch_lightning as pl
-from pl_examples.basic_examples.mnist_datamodule import MNISTDataModule
-
-
-class LitClassifier(pl.LightningModule):
-
-    def __init__(self, hidden_dim=128, learning_rate=1e-3):
-        super().__init__()
-        self.save_hyperparameters()
-
-        self.l1 = torch.nn.Linear(28 * 28, self.hparams.hidden_dim)
-        self.l2 = torch.nn.Linear(self.hparams.hidden_dim, 10)
-
-    def forward(self, x):
-        x = x.view(x.size(0), -1)
-        x = torch.relu(self.l1(x))
-        x = torch.relu(self.l2(x))
-        return x
-
-    def training_step(self, batch, batch_idx):
-        x, y = batch
-        y_hat = self(x)
-        loss = F.cross_entropy(y_hat, y)
-        return loss
-
-    def validation_step(self, batch, batch_idx):
-        x, y = batch
-        y_hat = self(x)
-        loss = F.cross_entropy(y_hat, y)
-        self.log('valid_loss', loss)
-
-    def test_step(self, batch, batch_idx):
-        x, y = batch
-        y_hat = self(x)
-        loss = F.cross_entropy(y_hat, y)
-        self.log('test_loss', loss)
-
-    def configure_optimizers(self):
-        return torch.optim.Adam(self.parameters(), lr=self.hparams.learning_rate)
-
-    @staticmethod
-    def add_model_specific_args(parent_parser):
-        parser = ArgumentParser(parents=[parent_parser], add_help=False)
-        parser.add_argument('--hidden_dim', type=int, default=128)
-        parser.add_argument('--learning_rate', type=float, default=0.0001)
-        return parser
-
-
-if __name__ == '__main__':
-    pl.seed_everything(1234)
-    parser = ArgumentParser()
-    parser = Trainer.add_argparse_args(parser)
-    parser = LitClassifier.add_model_specific_args(parser)
-    parser = MNISTDataModule.add_argparse_args(parser)
-    args = parser.parse_args()
-
-    dm = MNISTDataModule.from_argparse_args(args)
-    model = LitClassifier(args.hidden_dim, args.learning_rate)
-    trainer = Trainer(
-        gpus=1,
-        accelerator='dp',
-        auto_scale_batch_size='binsearch'
-    )
-
-    tuner = Tuner(trainer)
-    new_batch_size = tuner.scale_batch_size(model, mode="binsearch", init_val=128, max_trials=3, datamodule=dm)
-    model.hparams.batch_size = new_batch_size
-
-    trainer.fit(model, datamodule=dm)

From 38a0df6fead2f6cd21447c0b4749ecf450d818a8 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Adrian=20W=C3=A4lchli?=
Date: Fri, 5 Mar 2021 17:01:36 +0100
Subject: [PATCH 07/15] changelog

---
 CHANGELOG.md | 3 +++
 1 file changed, 3 insertions(+)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index 08081e1dd76aa..44f022394f966 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -107,6 +107,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Fixed DP reduction with collection ([#6324](https://github.com/PyTorchLightning/pytorch-lightning/pull/6324))


+- Fixed an issue with `Tuner.scale_batch_size` not finding the batch size attribute in the datamodule ([#5968](https://github.com/PyTorchLightning/pytorch-lightning/pull/5968))
+
+
 ## [1.2.1] - 2021-02-23

 ### Fixed

From 991d6d2da5e76136a8c018cbe40f5a3cb40873ff Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Adrian=20W=C3=A4lchli?=
Date: Fri, 5 Mar 2021 19:09:06 +0100
Subject: [PATCH 08/15] revert lr finder changes

---
 pytorch_lightning/tuner/tuning.py | 3 ---
 1 file changed, 3 deletions(-)

diff --git a/pytorch_lightning/tuner/tuning.py b/pytorch_lightning/tuner/tuning.py
index ddcfc857f40c2..01332e97e7d7d 100644
--- a/pytorch_lightning/tuner/tuning.py
+++ b/pytorch_lightning/tuner/tuning.py
@@ -128,9 +128,6 @@ def lr_find(
         datamodule: Optional[LightningDataModule] = None,
         update_attr: bool = False,
     ):
-        self.trainer.model_connector.copy_trainer_model_properties(model)
-        self.trainer.train_loop.setup_fit(model, train_dataloader, val_dataloaders, datamodule)
-        self.trainer.data_connector.prepare_data(model)
         return lr_find(
             self.trainer,
             model,

From 8c5249fface7a3e72a74d1e20a112bd3659f5251 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Adrian=20W=C3=A4lchli?=
Date: Sat, 6 Mar 2021 00:26:20 +0100
Subject: [PATCH 09/15] refactor

---
 pytorch_lightning/tuner/tuning.py | 12 +++++++-----
 1 file changed, 7 insertions(+), 5 deletions(-)

diff --git a/pytorch_lightning/tuner/tuning.py b/pytorch_lightning/tuner/tuning.py
index 01332e97e7d7d..9a577590fc827 100644
--- a/pytorch_lightning/tuner/tuning.py
+++ b/pytorch_lightning/tuner/tuning.py
@@ -32,13 +32,16 @@ def on_trainer_init(self, auto_lr_find, auto_scale_batch_size):
         self.trainer.auto_lr_find = auto_lr_find
         self.trainer.auto_scale_batch_size = auto_scale_batch_size

-    def tune(self, model, train_dataloader, val_dataloaders, datamodule):
+    def setup_trainer(self, model, train_dataloader=None, val_dataloaders=None, datamodule=None):
+        self.trainer.model_connector.copy_trainer_model_properties(model)
         # setup data, etc...
         self.trainer.train_loop.setup_fit(model, train_dataloader, val_dataloaders, datamodule)
-
         # hook
         self.trainer.data_connector.prepare_data(model)

+    def tune(self, model, train_dataloader, val_dataloaders, datamodule):
+        self.setup_trainer(model, train_dataloader, val_dataloaders, datamodule)
+
         # Run auto batch size scaling
         if self.trainer.auto_scale_batch_size:
             if isinstance(self.trainer.auto_scale_batch_size, bool):
@@ -101,9 +104,7 @@
             or datamodule.
         """
-        self.trainer.model_connector.copy_trainer_model_properties(model)
-        self.trainer.train_loop.setup_fit(model, **fit_kwargs)
-        self.trainer.data_connector.prepare_data(model)
+        self.setup_trainer(model, **fit_kwargs)
         return scale_batch_size(
             self.trainer,
             model,
@@ -128,6 +129,7 @@ def lr_find(
         datamodule: Optional[LightningDataModule] = None,
         update_attr: bool = False,
     ):
+        self.setup_trainer(model, train_dataloader, val_dataloaders, datamodule)
         return lr_find(
             self.trainer,
             model,

From 2699ea761602f811bed137b8f427e26439d06543 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Adrian=20W=C3=A4lchli?=
Date: Sat, 6 Mar 2021 00:29:56 +0100
Subject: [PATCH 10/15] added typing

---
 pytorch_lightning/tuner/tuning.py | 8 +++++++-
 1 file changed, 7 insertions(+), 1 deletion(-)

diff --git a/pytorch_lightning/tuner/tuning.py b/pytorch_lightning/tuner/tuning.py
index 9a577590fc827..205da351f1eb7 100644
--- a/pytorch_lightning/tuner/tuning.py
+++ b/pytorch_lightning/tuner/tuning.py
@@ -32,7 +32,13 @@ def on_trainer_init(self, auto_lr_find, auto_scale_batch_size):
         self.trainer.auto_lr_find = auto_lr_find
         self.trainer.auto_scale_batch_size = auto_scale_batch_size

-    def setup_trainer(self, model, train_dataloader=None, val_dataloaders=None, datamodule=None):
+    def setup_trainer(
+        self,
+        model: LightningModule,
+        train_dataloader: Optional[DataLoader] = None,
+        val_dataloaders: Optional[Union[DataLoader, List[DataLoader]]] = None,
+        datamodule: LightningDataModule = None,
+    ):
         self.trainer.model_connector.copy_trainer_model_properties(model)
         # setup data, etc...
         self.trainer.train_loop.setup_fit(model, train_dataloader, val_dataloaders, datamodule)

From b4812923e9592ab57c77fb9a6eeb48f83b22acd6 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Adrian=20W=C3=A4lchli?=
Date: Sat, 6 Mar 2021 00:30:22 +0100
Subject: [PATCH 11/15] formatting

---
 tests/tuner/test_scale_batch_size.py | 25 ++++++++++++-------------
 1 file changed, 12 insertions(+), 13 deletions(-)

diff --git a/tests/tuner/test_scale_batch_size.py b/tests/tuner/test_scale_batch_size.py
index 5c7b1f6685d83..40a22e6c1981b 100644
--- a/tests/tuner/test_scale_batch_size.py
+++ b/tests/tuner/test_scale_batch_size.py
@@ -1,9 +1,10 @@
 import pytest
+from torch.utils.data import DataLoader
+
 from pytorch_lightning import Trainer
 from pytorch_lightning.tuner.tuning import Tuner
-from tests.helpers import BoringModel, BoringDataModule
+from tests.helpers import BoringDataModule, BoringModel
 from tests.helpers.runif import RunIf
-from torch.utils.data import DataLoader


 class BatchSizeDataModule(BoringDataModule):
@@ -26,12 +27,14 @@ def __init__(self, batch_size=None):


 @RunIf(min_gpus=1)
-@pytest.mark.parametrize("model,datamodule", [
-    (BatchSizeModel(2), None),
-    (BatchSizeModel(2), BatchSizeDataModule(2)),
-    (BatchSizeModel(2), BatchSizeDataModule(None)),
-    (BatchSizeModel(None), BatchSizeDataModule(2)),
-])
+@pytest.mark.parametrize(
+    "model,datamodule", [
+        (BatchSizeModel(2), None),
+        (BatchSizeModel(2), BatchSizeDataModule(2)),
+        (BatchSizeModel(2), BatchSizeDataModule(None)),
+        (BatchSizeModel(None), BatchSizeDataModule(2)),
+    ]
+)
 def test_scale_batch_size_method_with_model_or_datamodule(tmpdir, model, datamodule):
     """ Test the tuner method `Tuner.scale_batch_size` with a datamodule. """
     trainer = Trainer(
@@ -43,11 +46,7 @@ def test_scale_batch_size_method_with_model_or_datamodule(tmpdir, model, datamod
     )
     tuner = Tuner(trainer)
     new_batch_size = tuner.scale_batch_size(
-        model=model,
-        mode="binsearch",
-        init_val=4,
-        max_trials=2,
-        datamodule=datamodule
+        model=model, mode="binsearch", init_val=4, max_trials=2, datamodule=datamodule
     )
     assert new_batch_size == 16
     if hasattr(model, "batch_size"):

From d7bc9dfde67b9e3297ae8e47a69e63c5558581bc Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Adrian=20W=C3=A4lchli?=
Date: Sat, 6 Mar 2021 19:39:42 +0100
Subject: [PATCH 12/15] remove gpu requirement

---
 tests/tuner/test_scale_batch_size.py | 2 --
 1 file changed, 2 deletions(-)

diff --git a/tests/tuner/test_scale_batch_size.py b/tests/tuner/test_scale_batch_size.py
index 40a22e6c1981b..4aab43d25cca5 100644
--- a/tests/tuner/test_scale_batch_size.py
+++ b/tests/tuner/test_scale_batch_size.py
@@ -26,7 +26,6 @@ def __init__(self, batch_size=None):
         self.batch_size = batch_size


-@RunIf(min_gpus=1)
 @pytest.mark.parametrize(
     "model,datamodule", [
         (BatchSizeModel(2), None),
@@ -39,7 +38,6 @@ def test_scale_batch_size_method_with_model_or_datamodule(tmpdir, model, datamod
     """ Test the tuner method `Tuner.scale_batch_size` with a datamodule. """
     trainer = Trainer(
         default_root_dir=tmpdir,
-        gpus=1,
         limit_train_batches=1,
         limit_val_batches=0,
         max_epochs=1,

From 4ff8dc2317fdece52bbab2d8fcc7bd85fa079263 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Adrian=20W=C3=A4lchli?=
Date: Sat, 6 Mar 2021 19:42:02 +0100
Subject: [PATCH 13/15] remove duplicate setup

---
 pytorch_lightning/tuner/tuning.py | 2 --
 1 file changed, 2 deletions(-)

diff --git a/pytorch_lightning/tuner/tuning.py b/pytorch_lightning/tuner/tuning.py
index 205da351f1eb7..c5256c6ddc65f 100644
--- a/pytorch_lightning/tuner/tuning.py
+++ b/pytorch_lightning/tuner/tuning.py
@@ -46,8 +46,6 @@ def setup_trainer(
         self.trainer.data_connector.prepare_data(model)

     def tune(self, model, train_dataloader, val_dataloaders, datamodule):
-        self.setup_trainer(model, train_dataloader, val_dataloaders, datamodule)
-
         # Run auto batch size scaling
         if self.trainer.auto_scale_batch_size:
             if isinstance(self.trainer.auto_scale_batch_size, bool):

From 63d4114171b76d363fe129389d96e7943a9614e6 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Adrian=20W=C3=A4lchli?=
Date: Sat, 6 Mar 2021 20:02:30 +0100
Subject: [PATCH 14/15] rm unused import

---
 tests/tuner/test_scale_batch_size.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/tests/tuner/test_scale_batch_size.py b/tests/tuner/test_scale_batch_size.py
index 4aab43d25cca5..01efd3003a64a 100644
--- a/tests/tuner/test_scale_batch_size.py
+++ b/tests/tuner/test_scale_batch_size.py
@@ -4,7 +4,6 @@
 from pytorch_lightning import Trainer
 from pytorch_lightning.tuner.tuning import Tuner
 from tests.helpers import BoringDataModule, BoringModel
-from tests.helpers.runif import RunIf


 class BatchSizeDataModule(BoringDataModule):

From 01d5e1ed66a04906a95ab6b4b2a58b6f31037b6a Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Adrian=20W=C3=A4lchli?=
Date: Sun, 7 Mar 2021 00:01:01 +0100
Subject: [PATCH 15/15] add license

---
 tests/tuner/test_scale_batch_size.py | 13 +++++++++++++
 1 file changed, 13 insertions(+)

diff --git a/tests/tuner/test_scale_batch_size.py b/tests/tuner/test_scale_batch_size.py
index 01efd3003a64a..ad7fc57092f32 100644
--- a/tests/tuner/test_scale_batch_size.py
+++ b/tests/tuner/test_scale_batch_size.py
@@ -1,3 +1,16 @@
+# Copyright The PyTorch Lightning team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
 import pytest
 from torch.utils.data import DataLoader