diff --git a/CHANGELOG.md b/CHANGELOG.md index 007ac3eadceae..77eca7f2daacc 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -158,6 +158,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). ### Deprecated +- Deprecated `DataModule` properties: `has_prepared_data`, `has_setup_fit`, `has_setup_validate`, `has_setup_test`, `has_setup_predict`, `has_teardown_fit`, `has_teardown_validate`, `has_teardown_test`, `has_teardown_predict` ([#7657](https://github.com/PyTorchLightning/pytorch-lightning/pull/7657/)) + + - Deprecated `TrainerModelHooksMixin` in favor of `pytorch_lightning.utilities.signature_utils` ([#7422](https://github.com/PyTorchLightning/pytorch-lightning/pull/7422)) diff --git a/pytorch_lightning/core/datamodule.py b/pytorch_lightning/core/datamodule.py index 84210e9d7b667..afa1238786490 100644 --- a/pytorch_lightning/core/datamodule.py +++ b/pytorch_lightning/core/datamodule.py @@ -20,8 +20,8 @@ from torch.utils.data import DataLoader, Dataset, IterableDataset from pytorch_lightning.core.hooks import CheckpointHooks, DataHooks -from pytorch_lightning.utilities import rank_zero_only from pytorch_lightning.utilities.argparse import add_argparse_args, from_argparse_args, get_init_arguments_and_types +from pytorch_lightning.utilities.distributed import rank_zero_deprecation, rank_zero_only class LightningDataModule(CheckpointHooks, DataHooks): @@ -160,7 +160,13 @@ def has_prepared_data(self) -> bool: Returns: bool: True if ``datamodule.prepare_data()`` has been called. False by default. + + .. deprecated:: v1.4 + Will be removed in v1.6.0. """ + rank_zero_deprecation( + 'DataModule property `has_prepared_data` was deprecated in v1.4 and will be removed in v1.6.' + ) return self._has_prepared_data @property @@ -169,7 +175,11 @@ def has_setup_fit(self) -> bool: Returns: bool: True ``if datamodule.setup(stage='fit')`` has been called. False by default. + + .. deprecated:: v1.4 + Will be removed in v1.6.0. """ + rank_zero_deprecation('DataModule property `has_setup_fit` was deprecated in v1.4 and will be removed in v1.6.') return self._has_setup_fit @property @@ -178,7 +188,13 @@ def has_setup_validate(self) -> bool: Returns: bool: True if ``datamodule.setup(stage='validate')`` has been called. False by default. + + .. deprecated:: v1.4 + Will be removed in v1.6.0. """ + rank_zero_deprecation( + 'DataModule property `has_setup_validate` was deprecated in v1.4 and will be removed in v1.6.' + ) return self._has_setup_validate @property @@ -187,7 +203,13 @@ def has_setup_test(self) -> bool: Returns: bool: True if ``datamodule.setup(stage='test')`` has been called. False by default. + + .. deprecated:: v1.4 + Will be removed in v1.6.0. """ + rank_zero_deprecation( + 'DataModule property `has_setup_test` was deprecated in v1.4 and will be removed in v1.6.' + ) return self._has_setup_test @property @@ -196,7 +218,13 @@ def has_setup_predict(self) -> bool: Returns: bool: True if ``datamodule.setup(stage='predict')`` has been called. False by default. + + .. deprecated:: v1.4 + Will be removed in v1.6.0. """ + rank_zero_deprecation( + 'DataModule property `has_setup_predict` was deprecated in v1.4 and will be removed in v1.6.' + ) return self._has_setup_predict @property @@ -205,7 +233,13 @@ def has_teardown_fit(self) -> bool: Returns: bool: True ``if datamodule.teardown(stage='fit')`` has been called. False by default. + + .. deprecated:: v1.4 + Will be removed in v1.6.0. """ + rank_zero_deprecation( + 'DataModule property `has_teardown_fit` was deprecated in v1.4 and will be removed in v1.6.' + ) return self._has_teardown_fit @property @@ -214,7 +248,13 @@ def has_teardown_validate(self) -> bool: Returns: bool: True if ``datamodule.teardown(stage='validate')`` has been called. False by default. + + .. deprecated:: v1.4 + Will be removed in v1.6.0. """ + rank_zero_deprecation( + 'DataModule property `has_teardown_validate` was deprecated in v1.4 and will be removed in v1.6.' + ) return self._has_teardown_validate @property @@ -223,7 +263,13 @@ def has_teardown_test(self) -> bool: Returns: bool: True if ``datamodule.teardown(stage='test')`` has been called. False by default. + + .. deprecated:: v1.4 + Will be removed in v1.6.0. """ + rank_zero_deprecation( + 'DataModule property `has_teardown_test` was deprecated in v1.4 and will be removed in v1.6.' + ) return self._has_teardown_test @property @@ -232,7 +278,13 @@ def has_teardown_predict(self) -> bool: Returns: bool: True if ``datamodule.teardown(stage='predict')`` has been called. False by default. + + .. deprecated:: v1.4 + Will be removed in v1.6.0. """ + rank_zero_deprecation( + 'DataModule property `has_teardown_predict` was deprecated in v1.4 and will be removed in v1.6.' + ) return self._has_teardown_predict @classmethod @@ -381,8 +433,13 @@ def wrapped_fn(*args: str, **kwargs: Optional[str]) -> Any: has_run = obj._has_prepared_data obj._has_prepared_data = True - if not has_run: - return fn(*args, **kwargs) + if has_run: + rank_zero_deprecation( + f"DataModule.{name} has already been called, so it will not be called again. " + f"In v1.6 this behavior will change to always call DataModule.{name}." + ) + else: + fn(*args, **kwargs) return wrapped_fn diff --git a/tests/core/test_datamodules.py b/tests/core/test_datamodules.py index e6500a15eeed1..d4e1a3ff0e3ae 100644 --- a/tests/core/test_datamodules.py +++ b/tests/core/test_datamodules.py @@ -524,46 +524,3 @@ def test_dm_init_from_datasets_dataloaders(iterable): call(test_dss[0], batch_size=4, shuffle=False, num_workers=0, pin_memory=True), call(test_dss[1], batch_size=4, shuffle=False, num_workers=0, pin_memory=True) ]) - - -def test_datamodule_hooks_calls(tmpdir): - """Test that repeated calls to DataHooks' hooks have no effect""" - - class TestDataModule(BoringDataModule): - setup_calls = [] - teardown_calls = [] - prepare_data_calls = 0 - - def setup(self, stage=None): - super().setup(stage=stage) - self.setup_calls.append(stage) - - def teardown(self, stage=None): - super().teardown(stage=stage) - self.teardown_calls.append(stage) - - def prepare_data(self): - super().prepare_data() - self.prepare_data_calls += 1 - - dm = TestDataModule() - dm.prepare_data() - dm.prepare_data() - dm.setup('fit') - dm.setup('fit') - dm.setup() - dm.setup() - dm.teardown('validate') - dm.teardown('validate') - - assert dm.prepare_data_calls == 1 - assert dm.setup_calls == ['fit', None] - assert dm.teardown_calls == ['validate'] - - trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=1) - trainer.test(BoringModel(), datamodule=dm) - - # same number of calls - assert dm.prepare_data_calls == 1 - assert dm.setup_calls == ['fit', None] - assert dm.teardown_calls == ['validate', 'test'] diff --git a/tests/deprecated_api/test_remove_1-6.py b/tests/deprecated_api/test_remove_1-6.py index 7a92501caee4a..1b4f6cacfef70 100644 --- a/tests/deprecated_api/test_remove_1-6.py +++ b/tests/deprecated_api/test_remove_1-6.py @@ -16,7 +16,7 @@ from pytorch_lightning import Trainer from pytorch_lightning.plugins.training_type import DDPPlugin, DDPSpawnPlugin -from tests.helpers import BoringModel +from tests.helpers import BoringDataModule, BoringModel def test_v1_6_0_trainer_model_hook_mixin(tmpdir): @@ -99,3 +99,76 @@ def training_step(self, *args): trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True) with pytest.deprecated_call(match=r"`self.log\(sync_dist_op='sum'\)` is deprecated"): trainer.fit(TestModel()) + + +def test_v1_6_0_datamodule_lifecycle_properties(tmpdir): + dm = BoringDataModule() + with pytest.deprecated_call(match=r"DataModule property `has_prepared_data` was deprecated in v1.4"): + dm.has_prepared_data + with pytest.deprecated_call(match=r"DataModule property `has_setup_fit` was deprecated in v1.4"): + dm.has_setup_fit + with pytest.deprecated_call(match=r"DataModule property `has_setup_validate` was deprecated in v1.4"): + dm.has_setup_validate + with pytest.deprecated_call(match=r"DataModule property `has_setup_test` was deprecated in v1.4"): + dm.has_setup_test + with pytest.deprecated_call(match=r"DataModule property `has_setup_predict` was deprecated in v1.4"): + dm.has_setup_predict + with pytest.deprecated_call(match=r"DataModule property `has_teardown_fit` was deprecated in v1.4"): + dm.has_teardown_fit + with pytest.deprecated_call(match=r"DataModule property `has_teardown_validate` was deprecated in v1.4"): + dm.has_teardown_validate + with pytest.deprecated_call(match=r"DataModule property `has_teardown_test` was deprecated in v1.4"): + dm.has_teardown_test + with pytest.deprecated_call(match=r"DataModule property `has_teardown_predict` was deprecated in v1.4"): + dm.has_teardown_predict + + +def test_v1_6_0_datamodule_hooks_calls(tmpdir): + """Test that repeated calls to DataHooks' hooks show a warning about the coming API change.""" + + class TestDataModule(BoringDataModule): + setup_calls = [] + teardown_calls = [] + prepare_data_calls = 0 + + def setup(self, stage=None): + super().setup(stage=stage) + self.setup_calls.append(stage) + + def teardown(self, stage=None): + super().teardown(stage=stage) + self.teardown_calls.append(stage) + + def prepare_data(self): + super().prepare_data() + self.prepare_data_calls += 1 + + dm = TestDataModule() + dm.prepare_data() + dm.prepare_data() + dm.setup('fit') + with pytest.deprecated_call( + match=r"DataModule.setup has already been called, so it will not be called again. " + "In v1.6 this behavior will change to always call DataModule.setup" + ): + dm.setup('fit') + dm.setup() + dm.setup() + dm.teardown('validate') + with pytest.deprecated_call( + match=r"DataModule.teardown has already been called, so it will not be called again. " + "In v1.6 this behavior will change to always call DataModule.teardown" + ): + dm.teardown('validate') + + assert dm.prepare_data_calls == 1 + assert dm.setup_calls == ['fit', None] + assert dm.teardown_calls == ['validate'] + + trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=1) + trainer.test(BoringModel(), datamodule=dm) + + # same number of calls + assert dm.prepare_data_calls == 1 + assert dm.setup_calls == ['fit', None] + assert dm.teardown_calls == ['validate', 'test']