From d1db604c61edb3142960346fbcbbfe74250baeb7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Carlos=20Mochol=C3=AD?= Date: Wed, 10 Mar 2021 20:16:09 +0100 Subject: [PATCH 1/8] Remove redundant test (#6466) --- tests/models/test_model_hooks.py | 49 -------------------------------- 1 file changed, 49 deletions(-) delete mode 100644 tests/models/test_model_hooks.py diff --git a/tests/models/test_model_hooks.py b/tests/models/test_model_hooks.py deleted file mode 100644 index 2e004584119f4..0000000000000 --- a/tests/models/test_model_hooks.py +++ /dev/null @@ -1,49 +0,0 @@ -# Copyright The PyTorch Lightning team. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -from unittest import mock - -from pytorch_lightning import Trainer -from tests.helpers.boring_model import BoringModel - - -@mock.patch('pytorch_lightning.core.hooks.ModelHooks.on_validation_model_eval') -@mock.patch('pytorch_lightning.core.hooks.ModelHooks.on_validation_model_train') -@mock.patch('pytorch_lightning.core.hooks.ModelHooks.on_test_model_eval') -@mock.patch('pytorch_lightning.core.hooks.ModelHooks.on_test_model_train') -def test_eval_train_calls(test_train_mock, test_eval_mock, val_train_mock, val_eval_mock, tmpdir): - """ - Tests that only training_step can be used - """ - model = BoringModel() - model.validation_epoch_end = None - - trainer = Trainer( - default_root_dir=tmpdir, - limit_train_batches=2, - limit_val_batches=2, - max_epochs=2, - log_every_n_steps=1, - weights_summary=None, - ) - - trainer.fit(model) - trainer.test() - - # sanity + 2 epochs - assert val_eval_mock.call_count == 3 - assert val_train_mock.call_count == 3 - - # test is called only once - assert test_eval_mock.call_count == 1 - assert test_train_mock.call_count == 1 From f4cc7451a94010a572480c43ad5f0af7ad52cd21 Mon Sep 17 00:00:00 2001 From: Elia Cereda Date: Thu, 11 Mar 2021 03:46:37 +0100 Subject: [PATCH 2/8] =?UTF-8?q?Add=20Trainer.validate(=E2=80=A6)=20method?= =?UTF-8?q?=20to=20run=20one=20validation=20epoch=20(#4948)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Carlos Mocholi Co-authored-by: chaton Co-authored-by: Adrian Wälchli --- CHANGELOG.md | 3 + docs/source/common/trainer.rst | 15 +- pytorch_lightning/callbacks/progress.py | 11 +- .../trainer/configuration_validator.py | 21 ++- .../trainer/connectors/data_connector.py | 10 +- pytorch_lightning/trainer/states.py | 2 +- pytorch_lightning/trainer/trainer.py | 140 ++++++++++------ tests/accelerators/test_common.py | 44 ++++++ tests/accelerators/test_dp.py | 32 ---- tests/callbacks/test_callbacks.py | 82 +++++++--- tests/callbacks/test_progress_bar.py | 22 ++- tests/checkpointing/test_model_checkpoint.py | 10 ++ tests/core/test_datamodules.py | 117 +------------- tests/models/test_hooks.py | 16 ++ tests/plugins/test_sharded_plugin.py | 26 +-- tests/trainer/optimization/test_optimizers.py | 21 ++- tests/trainer/test_config_validator.py | 62 +++++--- tests/trainer/test_dataloaders.py | 149 ++++++------------ tests/trainer/test_trainer.py | 
70 ++++---- tests/trainer/test_trainer_test_loop.py | 76 --------- 20 files changed, 446 insertions(+), 483 deletions(-) create mode 100644 tests/accelerators/test_common.py delete mode 100644 tests/trainer/test_trainer_test_loop.py diff --git a/CHANGELOG.md b/CHANGELOG.md index 6ff1ded491b15..7cdc8fd5acfb8 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -21,6 +21,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Added `TrainerState.{FITTING,VALIDATING,TESTING,PREDICTING,TUNING}` ([#4945](https://github.com/PyTorchLightning/pytorch-lightning/pull/4945)) +- Added `Trainer.validate()` method to perform one evaluation epoch over the validation set ([#4948](https://github.com/PyTorchLightning/pytorch-lightning/pull/4948)) + + - Added `LightningEnvironment` for Lightning-specific DDP ([#5915](https://github.com/PyTorchLightning/pytorch-lightning/pull/5915)) diff --git a/docs/source/common/trainer.rst b/docs/source/common/trainer.rst index 17cfc7eccbc20..6edf896ada01c 100644 --- a/docs/source/common/trainer.rst +++ b/docs/source/common/trainer.rst @@ -151,6 +151,19 @@ So you can run it like so: ------------ +Validation +---------- +You can perform an evaluation epoch over the validation set, outside of the training loop, +using :meth:`pytorch_lightning.trainer.trainer.Trainer.validate`. This might be +useful if you want to collect new metrics from a model right at its initialization +or after it has already been trained. + +.. code-block:: python + + trainer.validate(val_dataloaders=val_dataloaders) + +------------ + Testing ------- Once you're done training, feel free to run the test set! @@ -158,7 +171,7 @@ Once you're done training, feel free to run the test set! .. code-block:: python - trainer.test(test_dataloaders=test_dataloader) + trainer.test(test_dataloaders=test_dataloaders) ------------ diff --git a/pytorch_lightning/callbacks/progress.py b/pytorch_lightning/callbacks/progress.py index c382e67b21a64..74e57e2b5642e 100644 --- a/pytorch_lightning/callbacks/progress.py +++ b/pytorch_lightning/callbacks/progress.py @@ -355,9 +355,11 @@ def init_predict_tqdm(self) -> tqdm: def init_validation_tqdm(self) -> tqdm: """ Override this to customize the tqdm bar for validation. """ + # The main progress bar doesn't exist in `trainer.validate()` + has_main_bar = self.main_progress_bar is not None bar = tqdm( desc='Validating', - position=(2 * self.process_position + 1), + position=(2 * self.process_position + has_main_bar), disable=self.is_disabled, leave=False, dynamic_ncols=True, @@ -426,7 +428,8 @@ def on_validation_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, def on_validation_end(self, trainer, pl_module): super().on_validation_end(trainer, pl_module) - self.main_progress_bar.set_postfix(trainer.progress_bar_dict) + if self.main_progress_bar is not None: + self.main_progress_bar.set_postfix(trainer.progress_bar_dict) self.val_progress_bar.close() def on_train_end(self, trainer, pl_module): @@ -479,8 +482,10 @@ def print( def _should_update(self, current, total): return self.is_enabled and (current % self.refresh_rate == 0 or current == total) - def _update_bar(self, bar): + def _update_bar(self, bar: Optional[tqdm]) -> None: """ Updates the bar by the refresh rate without overshooting. 
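If the bar does not exist (``bar is None``, e.g. there is no main progress bar during ``trainer.validate()``), the call is a no-op.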
""" + if bar is None: + return if bar.total is not None: delta = min(self.refresh_rate, bar.total - bar.n) else: diff --git a/pytorch_lightning/trainer/configuration_validator.py b/pytorch_lightning/trainer/configuration_validator.py index 1bf38048ee159..8c539b5ff478d 100644 --- a/pytorch_lightning/trainer/configuration_validator.py +++ b/pytorch_lightning/trainer/configuration_validator.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. from pytorch_lightning.core.lightning import LightningModule +from pytorch_lightning.trainer.states import TrainerState from pytorch_lightning.utilities import rank_zero_warn from pytorch_lightning.utilities.exceptions import MisconfigurationException from pytorch_lightning.utilities.model_helpers import is_overridden @@ -22,7 +23,7 @@ class ConfigValidator(object): def __init__(self, trainer): self.trainer = trainer - def verify_loop_configurations(self, model: LightningModule): + def verify_loop_configurations(self, model: LightningModule) -> None: r""" Checks that the model is configured correctly before the run is started. @@ -30,10 +31,16 @@ def verify_loop_configurations(self, model: LightningModule): model: The model to check the configuration. """ - if self.trainer.training: + if self.trainer.state == TrainerState.FITTING: self.__verify_train_loop_configuration(model) - elif self.trainer.evaluating: - self.__verify_eval_loop_configuration(model) + self.__verify_eval_loop_configuration(model, 'val') + elif self.trainer.state == TrainerState.TUNING: + self.__verify_train_loop_configuration(model) + elif self.trainer.state == TrainerState.VALIDATING: + self.__verify_eval_loop_configuration(model, 'val') + elif self.trainer.state == TrainerState.TESTING: + self.__verify_eval_loop_configuration(model, 'test') + # TODO: add predict def __verify_train_loop_configuration(self, model): # ----------------------------------- @@ -81,11 +88,9 @@ def __verify_train_loop_configuration(self, model): ' It ensures optimizer_step or optimizer_zero_grad are called on every batch.' 
            )
 
-    def __verify_eval_loop_configuration(self, model):
-        stage = "val" if self.trainer.validating else "test"
-
+    def __verify_eval_loop_configuration(self, model: LightningModule, stage: str) -> None:
         loader_name = f'{stage}_dataloader'
-        step_name = f'{stage}_step'
+        step_name = 'validation_step' if stage == 'val' else 'test_step'
 
         has_loader = is_overridden(loader_name, model)
         has_step = is_overridden(step_name, model)
diff --git a/pytorch_lightning/trainer/connectors/data_connector.py b/pytorch_lightning/trainer/connectors/data_connector.py
index 7a0e0f39cadfc..fbe1cecdd837e 100644
--- a/pytorch_lightning/trainer/connectors/data_connector.py
+++ b/pytorch_lightning/trainer/connectors/data_connector.py
@@ -100,10 +100,10 @@ def __enforce_datamodule_dataloader_override(self, train_dataloader, val_dataloa
     def attach_dataloaders(
         self,
         model,
-        train_dataloader=None,
-        val_dataloaders=None,
-        test_dataloaders=None,
-        predict_dataloaders=None,
+        train_dataloader: Optional[Union[DataLoader, List[DataLoader]]] = None,
+        val_dataloaders: Optional[Union[DataLoader, List[DataLoader]]] = None,
+        test_dataloaders: Optional[Union[DataLoader, List[DataLoader]]] = None,
+        predict_dataloaders: Optional[Union[DataLoader, List[DataLoader]]] = None,
     ):
         # when dataloader is passed via fit, patch the train_dataloader
         # functions to overwrite with these implementations
@@ -119,7 +119,7 @@ def attach_dataloaders(
         if predict_dataloaders is not None:
             model.predict_dataloader = _PatchDataLoader(predict_dataloaders)
 
-    def attach_datamodule(self, model, datamodule: Optional[LightningDataModule]) -> None:
+    def attach_datamodule(self, model, datamodule: Optional[LightningDataModule] = None) -> None:
         # We use datamodule if it's been provided, otherwise we check model for it
         datamodule = datamodule or getattr(model, 'datamodule', None)
diff --git a/pytorch_lightning/trainer/states.py b/pytorch_lightning/trainer/states.py
index 33a2326c518d5..b1f188ab047fe 100644
--- a/pytorch_lightning/trainer/states.py
+++ b/pytorch_lightning/trainer/states.py
@@ -58,7 +58,7 @@ class RunningStage(LightningEnum):
     """
     TRAINING = 'train'
     SANITY_CHECKING = 'sanity_check'
-    VALIDATING = 'validation'
+    VALIDATING = 'validate'
    TESTING = 'test'
     PREDICTING = 'predict'
     TUNING = 'tune'
diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py
index ff8be336ee57a..c3039d24aadc0 100644
--- a/pytorch_lightning/trainer/trainer.py
+++ b/pytorch_lightning/trainer/trainer.py
@@ -820,6 +820,69 @@ def run_sanity_check(self, ref_model):
 
         self._running_stage = stage
 
+    def validate(
+        self,
+        model: Optional[LightningModule] = None,
+        val_dataloaders: Optional[Union[DataLoader, List[DataLoader]]] = None,
+        ckpt_path: Optional[str] = 'best',
+        verbose: bool = True,
+        datamodule: Optional[LightningDataModule] = None,
+    ):
+        r"""
+        Perform one evaluation epoch over the validation set.
+
+        Args:
+            model: The model to validate.
+
+            val_dataloaders: Either a single PyTorch DataLoader or a list of them,
+                specifying validation samples.
+
+            ckpt_path: Either ``best`` or path to the checkpoint you wish to validate.
+                If ``None``, use the current weights of the model.
+                When the model is given as an argument, this parameter will not apply.
+
+            verbose: If ``True``, prints the validation results.
+
+            datamodule: An instance of :class:`LightningDataModule`.
+
+        Returns:
+            The dictionary with final validation results returned by ``validation_epoch_end``.
+            If ``validation_epoch_end`` is not defined, the output is a list of the dictionaries
+            returned by ``validation_step``.
+        """
+        # --------------------
+        # SETUP HOOK
+        # --------------------
+        self.verbose_evaluate = verbose
+
+        self.state = TrainerState.VALIDATING
+        self.validating = True
+
+        # If you supply a datamodule you can't supply val_dataloaders
+        if val_dataloaders and datamodule:
+            raise MisconfigurationException(
+                'You cannot pass both `trainer.validate(val_dataloaders=..., datamodule=...)`'
+            )
+
+        model_provided = model is not None
+        model = model or self.lightning_module
+
+        # Attach datamodule to get setup/prepare_data added to model before the call to it below
+        self.data_connector.attach_datamodule(model, datamodule)
+        # Attach dataloaders (if given)
+        self.data_connector.attach_dataloaders(model, val_dataloaders=val_dataloaders)
+
+        if not model_provided:
+            self.validated_ckpt_path = self.__load_ckpt_weights(model, ckpt_path=ckpt_path)
+
+        # run validate
+        results = self.fit(model)
+
+        assert self.state.stopped
+        self.validating = False
+
+        return results
+
     def test(
         self,
         model: Optional[LightningModule] = None,
@@ -833,17 +896,19 @@ def test(
         fit to make sure you never run on your test set until you want to.
 
         Args:
-            ckpt_path: Either ``best`` or path to the checkpoint you wish to test.
-                If ``None``, use the current weights of the model. Default to ``best``.
-            datamodule: A instance of :class:`LightningDataModule`.
-
             model: The model to test.
 
             test_dataloaders: Either a single PyTorch DataLoader or a list of them,
                 specifying test samples.
 
+            ckpt_path: Either ``best`` or path to the checkpoint you wish to test.
+                If ``None``, use the current weights of the model.
+                When the model is given as an argument, this parameter will not apply.
+
+            verbose: If ``True``, prints the test results.
+
+            datamodule: An instance of :class:`LightningDataModule`.
+
         Returns:
            A list of dictionaries, one for each test dataloader, containing their respective metrics.
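+
+        Example::
+
+            # minimal usage sketch, assuming ``trainer.fit(model)`` has already run
+            # and a checkpoint was saved (the default ``ckpt_path='best'`` loads it)
+            results = trainer.test(verbose=False)
+            print(results[0])  # one metrics dict per test dataloader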
""" @@ -858,7 +923,7 @@ def test( # If you supply a datamodule you can't supply test_dataloaders if test_dataloaders and datamodule: raise MisconfigurationException( - 'You cannot pass test_dataloaders to trainer.test if you supply a datamodule' + 'You cannot pass both `trainer.test(test_dataloaders=..., datamodule=...)`' ) model_provided = model is not None @@ -866,22 +931,25 @@ def test( # Attach datamodule to get setup/prepare_data added to model before the call to it below self.data_connector.attach_datamodule(model, datamodule) - results = ( - self.__evaluate_given_model(model, dataloaders=test_dataloaders) if model_provided else - self.__evaluate_using_weights(model, ckpt_path=ckpt_path, dataloaders=test_dataloaders) - ) + # Attach dataloaders (if given) + self.data_connector.attach_dataloaders(model, test_dataloaders=test_dataloaders) + + if not model_provided: + self.tested_ckpt_path = self.__load_ckpt_weights(model, ckpt_path=ckpt_path) + + # run test + results = self.fit(model) assert self.state.stopped self.testing = False return results - def __evaluate_using_weights( + def __load_ckpt_weights( self, model, ckpt_path: Optional[str] = None, - dataloaders: Optional[Union[DataLoader, List[DataLoader]]] = None - ): + ) -> Optional[str]: # if user requests the best checkpoint but we don't have it, error if ckpt_path == 'best' and not self.checkpoint_callback.best_model_path: raise MisconfigurationException( @@ -894,42 +962,18 @@ def __evaluate_using_weights( if ckpt_path == 'best': ckpt_path = self.checkpoint_callback.best_model_path - if len(ckpt_path) == 0: - rank_zero_warn( - f'`.test()` found no path for the best weights, {ckpt_path}. Please' - ' specify a path for a checkpoint `.test(ckpt_path=PATH)`' + if not ckpt_path: + fn = self.state.value + raise MisconfigurationException( + f'`.{fn}()` found no path for the best weights: "{ckpt_path}". Please' + ' specify a path for a checkpoint `.{fn}(ckpt_path=PATH)`' ) - return {} self.training_type_plugin.barrier() ckpt = pl_load(ckpt_path, map_location=lambda storage, loc: storage) model.load_state_dict(ckpt['state_dict']) - - # attach dataloaders - if dataloaders is not None: - self.data_connector.attach_dataloaders(model, test_dataloaders=dataloaders) - - if self.validating: - self.validated_ckpt_path = ckpt_path - else: - self.tested_ckpt_path = ckpt_path - - # run test - results = self.fit(model) - - return results - - def __evaluate_given_model(self, model, dataloaders: Optional[Union[DataLoader, List[DataLoader]]] = None): - # attach data - if dataloaders is not None: - self.data_connector.attach_dataloaders(model, test_dataloaders=dataloaders) - - # run test - # sets up testing so we short circuit to eval - results = self.fit(model) - - return results + return ckpt_path def predict( self, @@ -970,15 +1014,11 @@ def predict( 'You cannot pass dataloaders to trainer.predict if you supply a datamodule.' 
) - if datamodule is not None: - # Attach datamodule to get setup/prepare_data added to model before the call to it below - self.data_connector.attach_datamodule(model, datamodule) - - # attach data - if dataloaders is not None: - self.data_connector.attach_dataloaders(model, predict_dataloaders=dataloaders) + # Attach datamodule to get setup/prepare_data added to model before the call to it below + self.data_connector.attach_datamodule(model, datamodule) + # Attach dataloaders (if given) + self.data_connector.attach_dataloaders(model, predict_dataloaders=dataloaders) - self.model = model results = self.fit(model) assert self.state.stopped diff --git a/tests/accelerators/test_common.py b/tests/accelerators/test_common.py new file mode 100644 index 0000000000000..6962af7249d1b --- /dev/null +++ b/tests/accelerators/test_common.py @@ -0,0 +1,44 @@ +import pytest +import torch + +import tests.helpers.utils as tutils +from pytorch_lightning import Trainer +from tests.accelerators.test_dp import CustomClassificationModelDP +from tests.helpers.datamodules import ClassifDataModule +from tests.helpers.runif import RunIf + + +@pytest.mark.parametrize("trainer_kwargs", ( + pytest.param({"gpus": 1}, marks=RunIf(min_gpus=1)), + pytest.param({"accelerator": "dp", "gpus": 2}, marks=RunIf(min_gpus=2)), + pytest.param({"accelerator": "ddp_spawn", "gpus": 2}, marks=RunIf(min_gpus=2)), +)) +def test_evaluate(tmpdir, trainer_kwargs): + tutils.set_random_master_port() + + dm = ClassifDataModule() + model = CustomClassificationModelDP() + trainer = Trainer( + default_root_dir=tmpdir, + max_epochs=2, + limit_train_batches=10, + limit_val_batches=10, + deterministic=True, + **trainer_kwargs + ) + + result = trainer.fit(model, datamodule=dm) + assert result + assert 'ckpt' in trainer.checkpoint_callback.best_model_path + + old_weights = model.layer_0.weight.clone().detach().cpu() + + result = trainer.validate(datamodule=dm) + assert result[0]['val_acc'] > 0.55 + + result = trainer.test(datamodule=dm) + assert result[0]['test_acc'] > 0.55 + + # make sure weights didn't change + new_weights = model.layer_0.weight.clone().detach().cpu() + torch.testing.assert_allclose(old_weights, new_weights) diff --git a/tests/accelerators/test_dp.py b/tests/accelerators/test_dp.py index 52f585409e865..6b84e1a70ae58 100644 --- a/tests/accelerators/test_dp.py +++ b/tests/accelerators/test_dp.py @@ -25,8 +25,6 @@ from tests.helpers.runif import RunIf from tests.helpers.simple_models import ClassificationModel -PRETEND_N_OF_GPUS = 16 - class CustomClassificationModelDP(ClassificationModel): @@ -96,36 +94,6 @@ def test_multi_gpu_model_dp(tmpdir): memory.get_memory_profile('min_max') -@RunIf(min_gpus=2) -def test_dp_test(tmpdir): - tutils.set_random_master_port() - - dm = ClassifDataModule() - model = CustomClassificationModelDP() - trainer = pl.Trainer( - default_root_dir=tmpdir, - max_epochs=2, - limit_train_batches=10, - limit_val_batches=10, - gpus=[0, 1], - accelerator='dp', - ) - trainer.fit(model, datamodule=dm) - assert 'ckpt' in trainer.checkpoint_callback.best_model_path - results = trainer.test(datamodule=dm) - assert 'test_acc' in results[0] - - old_weights = model.layer_0.weight.clone().detach().cpu() - - results = trainer.test(model, datamodule=dm) - assert 'test_acc' in results[0] - - # make sure weights didn't change - new_weights = model.layer_0.weight.clone().detach().cpu() - - assert torch.all(torch.eq(old_weights, new_weights)) - - class ReductionTestModel(BoringModel): def train_dataloader(self): diff --git 
a/tests/callbacks/test_callbacks.py b/tests/callbacks/test_callbacks.py index 2426348f770bf..626eb59dffb9c 100644 --- a/tests/callbacks/test_callbacks.py +++ b/tests/callbacks/test_callbacks.py @@ -19,8 +19,8 @@ @mock.patch("torch.save") # need to mock torch.save or we get pickle error -def test_trainer_callback_system_fit(_, tmpdir): - """Test the callback system for fit.""" +def test_trainer_callback_hook_system_fit(_, tmpdir): + """Test the callback hook system for fit.""" model = BoringModel() callback_mock = MagicMock() @@ -97,8 +97,8 @@ def test_trainer_callback_system_fit(_, tmpdir): ] -def test_trainer_callback_system_test(tmpdir): - """Test the callback system for test.""" +def test_trainer_callback_hook_system_test(tmpdir): + """Test the callback hook system for test.""" model = BoringModel() callback_mock = MagicMock() @@ -130,6 +130,42 @@ def test_trainer_callback_system_test(tmpdir): ] +def test_trainer_callback_hook_system_validate(tmpdir): + """Test the callback hook system for validate.""" + + model = BoringModel() + callback_mock = MagicMock() + trainer = Trainer( + default_root_dir=tmpdir, + callbacks=[callback_mock], + max_epochs=1, + limit_val_batches=2, + progress_bar_refresh_rate=0, + ) + + trainer.validate(model) + + assert callback_mock.method_calls == [ + call.on_init_start(trainer), + call.on_init_end(trainer), + call.setup(trainer, model, 'validate'), + call.on_before_accelerator_backend_setup(trainer, model), + call.on_validation_start(trainer, model), + call.on_validation_epoch_start(trainer, model), + call.on_validation_batch_start(trainer, model, ANY, 0, 0), + call.on_validation_batch_end(trainer, model, ANY, ANY, 0, 0), + call.on_validation_batch_start(trainer, model, ANY, 1, 0), + call.on_validation_batch_end(trainer, model, ANY, ANY, 1, 0), + call.on_validation_epoch_end(trainer, model), + call.on_epoch_end(trainer, model), + call.on_validation_end(trainer, model), + call.teardown(trainer, model, 'validate'), + ] + + +# TODO: add callback tests for predict and tune + + def test_callbacks_configured_in_model(tmpdir): """ Test the callback system with callbacks added through the model hook. 
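The model hook here is ``LightningModule.configure_callbacks``; the callbacks it returns are expected to be appended to ``trainer.callbacks``.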
""" @@ -166,22 +202,29 @@ def assert_expected_calls(_trainer, model_callback, trainer_callback): # .fit() trainer_options.update(callbacks=[trainer_callback_mock]) trainer = Trainer(**trainer_options) + assert trainer_callback_mock in trainer.callbacks assert model_callback_mock not in trainer.callbacks trainer.fit(model) + assert model_callback_mock in trainer.callbacks assert trainer.callbacks[-1] == model_callback_mock assert_expected_calls(trainer, model_callback_mock, trainer_callback_mock) # .test() - model_callback_mock.reset_mock() - trainer_callback_mock.reset_mock() - trainer_options.update(callbacks=[trainer_callback_mock]) - trainer = Trainer(**trainer_options) - trainer.test(model) - assert model_callback_mock in trainer.callbacks - assert trainer.callbacks[-1] == model_callback_mock - assert_expected_calls(trainer, model_callback_mock, trainer_callback_mock) + for fn in ("test", "validate"): + model_callback_mock.reset_mock() + trainer_callback_mock.reset_mock() + + trainer_options.update(callbacks=[trainer_callback_mock]) + trainer = Trainer(**trainer_options) + + trainer_fn = getattr(trainer, fn) + trainer_fn(model) + + assert model_callback_mock in trainer.callbacks + assert trainer.callbacks[-1] == model_callback_mock + assert_expected_calls(trainer, model_callback_mock, trainer_callback_mock) def test_configure_callbacks_hook_multiple_calls(tmpdir): @@ -208,10 +251,13 @@ def configure_callbacks(self): callbacks_after_fit = trainer.callbacks.copy() assert callbacks_after_fit == callbacks_before_fit + [model_callback_mock] - trainer.test(model) - callbacks_after_test = trainer.callbacks.copy() - assert callbacks_after_test == callbacks_after_fit + for fn in ("test", "validate"): + trainer_fn = getattr(trainer, fn) + trainer_fn(model) + + callbacks_after = trainer.callbacks.copy() + assert callbacks_after == callbacks_after_fit - trainer.test(ckpt_path=None) - callbacks_after_test = trainer.callbacks.copy() - assert callbacks_after_test == callbacks_after_fit + trainer_fn(ckpt_path=None) + callbacks_after = trainer.callbacks.copy() + assert callbacks_after == callbacks_after_fit diff --git a/tests/callbacks/test_progress_bar.py b/tests/callbacks/test_progress_bar.py index 67ea5a00cfda3..76f1e4cb0570f 100644 --- a/tests/callbacks/test_progress_bar.py +++ b/tests/callbacks/test_progress_bar.py @@ -90,7 +90,6 @@ def test_progress_bar_totals(tmpdir): trainer = Trainer( default_root_dir=tmpdir, progress_bar_refresh_rate=1, - limit_val_batches=1.0, max_epochs=1, ) bar = trainer.progress_bar_callback @@ -122,6 +121,12 @@ def test_progress_bar_totals(tmpdir): assert 0 == bar.total_test_batches assert bar.test_progress_bar is None + trainer.validate(model) + + assert bar.val_progress_bar.total == m + assert bar.val_progress_bar.n == m + assert bar.val_batch_idx == m + trainer.test(model) # check test progress bar total @@ -157,6 +162,13 @@ def test_progress_bar_fast_dev_run(tmpdir): assert 2 == progress_bar.main_progress_bar.total assert 2 == progress_bar.main_progress_bar.n + trainer.validate(model) + + # the validation progress bar should display 1 batch + assert 1 == progress_bar.val_batch_idx + assert 1 == progress_bar.val_progress_bar.total + assert 1 == progress_bar.val_progress_bar.n + trainer.test(model) # the test progress bar should display 1 batch @@ -214,8 +226,16 @@ def on_test_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, datal trainer.fit(model) assert progress_bar.train_batches_seen == 3 * progress_bar.total_train_batches assert 
progress_bar.val_batches_seen == 3 * progress_bar.total_val_batches + trainer.num_sanity_val_steps + assert progress_bar.test_batches_seen == 0 + + trainer.validate(model) + assert progress_bar.train_batches_seen == 3 * progress_bar.total_train_batches + assert progress_bar.val_batches_seen == 4 * progress_bar.total_val_batches + trainer.num_sanity_val_steps + assert progress_bar.test_batches_seen == 0 trainer.test(model) + assert progress_bar.train_batches_seen == 3 * progress_bar.total_train_batches + assert progress_bar.val_batches_seen == 4 * progress_bar.total_val_batches + trainer.num_sanity_val_steps assert progress_bar.test_batches_seen == progress_bar.total_test_batches diff --git a/tests/checkpointing/test_model_checkpoint.py b/tests/checkpointing/test_model_checkpoint.py index 845b05aed9b38..d96fe3dcab33d 100644 --- a/tests/checkpointing/test_model_checkpoint.py +++ b/tests/checkpointing/test_model_checkpoint.py @@ -865,6 +865,9 @@ def assert_checkpoint_log_dir(idx): assert_checkpoint_log_dir(0) assert_checkpoint_content(ckpt_dir) + trainer.validate(model) + assert trainer.current_epoch == epochs - 1 + trainer.test(model) assert trainer.current_epoch == epochs - 1 @@ -878,17 +881,24 @@ def assert_checkpoint_log_dir(idx): assert_trainer_init(trainer) model = ExtendedBoringModel() + trainer.test(model) assert not trainer.checkpoint_connector.has_trained # resume_from_checkpoint is resumed when calling `.fit` assert trainer.global_step == 0 assert trainer.current_epoch == 0 + trainer.fit(model) assert not trainer.checkpoint_connector.has_trained assert trainer.global_step == epochs * limit_train_batches assert trainer.current_epoch == epochs assert_checkpoint_log_dir(idx) + trainer.validate(model) + assert not trainer.checkpoint_connector.has_trained + assert trainer.global_step == epochs * limit_train_batches + assert trainer.current_epoch == epochs + def test_configure_model_checkpoint(tmpdir): """ Test all valid and invalid ways a checkpoint callback can be passed to the Trainer. 
""" diff --git a/tests/core/test_datamodules.py b/tests/core/test_datamodules.py index ab51a87329e2f..2118fec6c207b 100644 --- a/tests/core/test_datamodules.py +++ b/tests/core/test_datamodules.py @@ -19,7 +19,6 @@ import pytest import torch -import torch.nn.functional as F from pytorch_lightning import LightningDataModule, Trainer from pytorch_lightning.callbacks import ModelCheckpoint @@ -29,7 +28,7 @@ from tests.helpers.datamodules import ClassifDataModule from tests.helpers.runif import RunIf from tests.helpers.simple_models import ClassificationModel -from tests.helpers.utils import reset_seed, set_random_master_port +from tests.helpers.utils import reset_seed @mock.patch("pytorch_lightning.trainer.trainer.Trainer.node_rank", new_callable=PropertyMock) @@ -297,20 +296,6 @@ def on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None: assert checkpoint[dm.__class__.__name__] == dm.__class__.__name__ -def test_test_loop_only(tmpdir): - reset_seed() - - dm = BoringDataModule() - model = BoringModel() - - trainer = Trainer( - default_root_dir=tmpdir, - max_epochs=1, - weights_summary=None, - ) - trainer.test(model, datamodule=dm) - - def test_full_loop(tmpdir): reset_seed() @@ -327,109 +312,17 @@ def test_full_loop(tmpdir): # fit model result = trainer.fit(model, dm) assert trainer.state == TrainerState.FINISHED, f"Training failed with {trainer.state}" + assert dm.trainer is not None assert result - # test - result = trainer.test(datamodule=dm) - assert result[0]['test_acc'] > 0.6 - - -def test_trainer_attached_to_dm(tmpdir): - reset_seed() - - dm = BoringDataModule() - model = BoringModel() - - trainer = Trainer( - default_root_dir=tmpdir, - max_epochs=1, - limit_train_batches=2, - limit_val_batches=2, - limit_test_batches=2, - weights_summary=None, - deterministic=True, - ) - - # fit model - trainer.fit(model, dm) - assert trainer.state == TrainerState.FINISHED, f"Training failed with {trainer.state}" + # validate + result = trainer.validate(datamodule=dm) assert dm.trainer is not None + assert result[0]['val_acc'] > 0.7 # test result = trainer.test(datamodule=dm) - result = result[0] assert dm.trainer is not None - - -@RunIf(min_gpus=1) -def test_full_loop_single_gpu(tmpdir): - reset_seed() - - dm = ClassifDataModule() - model = ClassificationModel() - - trainer = Trainer( - default_root_dir=tmpdir, - max_epochs=1, - weights_summary=None, - gpus=1, - deterministic=True, - ) - - # fit model - result = trainer.fit(model, dm) - assert trainer.state == TrainerState.FINISHED, f"Training failed with {trainer.state}" - assert result - - # test - result = trainer.test(datamodule=dm) - assert result[0]['test_acc'] > 0.6 - - -@RunIf(min_gpus=2) -def test_full_loop_dp(tmpdir): - set_random_master_port() - - class CustomClassificationModelDP(ClassificationModel): - - def _step(self, batch, batch_idx): - x, y = batch - logits = self(x) - return {'logits': logits, 'y': y} - - def training_step(self, batch, batch_idx): - out = self._step(batch, batch_idx) - loss = F.cross_entropy(out['logits'], out['y']) - return loss - - def validation_step(self, batch, batch_idx): - return self._step(batch, batch_idx) - - def test_step(self, batch, batch_idx): - return self._step(batch, batch_idx) - - def test_step_end(self, outputs): - self.log('test_acc', self.test_acc(outputs['logits'], outputs['y'])) - - dm = ClassifDataModule() - model = CustomClassificationModelDP() - - trainer = Trainer( - default_root_dir=tmpdir, - max_epochs=1, - weights_summary=None, - accelerator='dp', - gpus=2, - deterministic=True, 
- ) - - # fit model - result = trainer.fit(model, datamodule=dm) - assert trainer.state == TrainerState.FINISHED, f"Training failed with {trainer.state}" - assert result - - # test - result = trainer.test(datamodule=dm) assert result[0]['test_acc'] > 0.6 diff --git a/tests/models/test_hooks.py b/tests/models/test_hooks.py index 7c53925bd7cc4..0d1c7cf40a2bf 100644 --- a/tests/models/test_hooks.py +++ b/tests/models/test_hooks.py @@ -466,7 +466,23 @@ def teardown(self, stage=None): 'on_fit_end', 'teardown', ] + assert model.called == expected + + model = HookedModel() + trainer.validate(model, verbose=False) + expected = [ + 'on_validation_model_eval', + 'on_validation_start', + 'on_validation_epoch_start', + 'on_validation_batch_start', + 'on_validation_batch_end', + 'on_validation_epoch_end', + 'on_epoch_end', + 'on_validation_end', + 'on_validation_model_train', + 'teardown', + ] assert model.called == expected model = HookedModel() diff --git a/tests/plugins/test_sharded_plugin.py b/tests/plugins/test_sharded_plugin.py index b59563f70e4aa..a48f048160ee5 100644 --- a/tests/plugins/test_sharded_plugin.py +++ b/tests/plugins/test_sharded_plugin.py @@ -259,30 +259,20 @@ def test_ddp_sharded_plugin_resume_from_checkpoint_gpu_to_cpu(tmpdir): @RunIf(skip_windows=True, special=True, fairscale=True) -def test_ddp_sharded_plugin_test(tmpdir): +@pytest.mark.parametrize("trainer_kwargs", ( + {'num_processes': 2}, + pytest.param({'gpus': 2}, marks=RunIf(min_gpus=2)) +)) +def test_ddp_sharded_plugin_test_multigpu(tmpdir, trainer_kwargs): """ - Test to ensure we can use test without fit + Test to ensure we can use validate and test without fit """ model = BoringModel() trainer = Trainer( accelerator='ddp_sharded_spawn', - num_processes=2, - fast_dev_run=True, - ) - - trainer.test(model) - - -@RunIf(min_gpus=2, skip_windows=True, fairscale=True) -def test_ddp_sharded_plugin_test_multigpu(tmpdir): - """ - Test to ensure we can use test without fit - """ - model = BoringModel() - trainer = Trainer( - accelerator='ddp_sharded_spawn', - gpus=2, fast_dev_run=True, + **trainer_kwargs, ) + trainer.validate(model) trainer.test(model) diff --git a/tests/trainer/optimization/test_optimizers.py b/tests/trainer/optimization/test_optimizers.py index 01c23ed18fe65..34845c46b45eb 100644 --- a/tests/trainer/optimization/test_optimizers.py +++ b/tests/trainer/optimization/test_optimizers.py @@ -287,15 +287,22 @@ def test_configure_optimizers_with_frequency(tmpdir): assert trainer.state == TrainerState.FINISHED, f"Training failed with {trainer.state}" -def test_init_optimizers_during_testing(tmpdir): +@pytest.mark.parametrize("fn", ("validate", "test")) +def test_init_optimizers_during_evaluation(tmpdir, fn): """ - Test that optimizers is an empty list during testing. 
+ Test that optimizers is an empty list during evaluation """ - model = EvalModelTemplate() - model.configure_optimizers = model.configure_optimizers__multiple_schedulers - - trainer = Trainer(default_root_dir=tmpdir, limit_test_batches=10) - trainer.test(model, ckpt_path=None) + class TestModel(BoringModel): + def configure_optimizers(self): + optimizer1 = torch.optim.Adam(self.parameters(), lr=0.1) + optimizer2 = torch.optim.Adam(self.parameters(), lr=0.1) + lr_scheduler1 = torch.optim.lr_scheduler.StepLR(optimizer1, step_size=1) + lr_scheduler2 = torch.optim.lr_scheduler.StepLR(optimizer2, step_size=1) + return [optimizer1, optimizer2], [lr_scheduler1, lr_scheduler2] + + trainer = Trainer(default_root_dir=tmpdir, limit_val_batches=10, limit_test_batches=10) + validate_or_test = getattr(trainer, fn) + validate_or_test(TestModel(), ckpt_path=None) assert len(trainer.lr_schedulers) == 0 assert len(trainer.optimizers) == 0 diff --git a/tests/trainer/test_config_validator.py b/tests/trainer/test_config_validator.py index 00ad020aa1b57..59e10480a485e 100644 --- a/tests/trainer/test_config_validator.py +++ b/tests/trainer/test_config_validator.py @@ -13,12 +13,9 @@ # limitations under the License. import pytest -import tests.helpers.utils as tutils from pytorch_lightning import Trainer from pytorch_lightning.utilities.exceptions import MisconfigurationException -from tests.base import EvalModelTemplate - -# TODO: add matching messages +from tests.helpers import BoringModel def test_wrong_train_setting(tmpdir): @@ -26,49 +23,44 @@ def test_wrong_train_setting(tmpdir): * Test that an error is thrown when no `train_dataloader()` is defined * Test that an error is thrown when no `training_step()` is defined """ - tutils.reset_seed() - hparams = EvalModelTemplate.get_default_hparams() trainer = Trainer(default_root_dir=tmpdir, max_epochs=1) - with pytest.raises(MisconfigurationException): - model = EvalModelTemplate(**hparams) + with pytest.raises(MisconfigurationException, match=r'No `train_dataloader\(\)` method defined.'): + model = BoringModel() model.train_dataloader = None trainer.fit(model) - with pytest.raises(MisconfigurationException): - model = EvalModelTemplate(**hparams) + with pytest.raises(MisconfigurationException, match=r'No `training_step\(\)` method defined.'): + model = BoringModel() model.training_step = None trainer.fit(model) def test_wrong_configure_optimizers(tmpdir): """ Test that an error is thrown when no `configure_optimizers()` is defined """ - tutils.reset_seed() trainer = Trainer(default_root_dir=tmpdir, max_epochs=1) - with pytest.raises(MisconfigurationException): - model = EvalModelTemplate() + with pytest.raises(MisconfigurationException, match=r'No `configure_optimizers\(\)` method defined.'): + model = BoringModel() model.configure_optimizers = None trainer.fit(model) -def test_val_loop_config(tmpdir): +def test_fit_val_loop_config(tmpdir): """" When either val loop or val data are missing raise warning """ - tutils.reset_seed() - hparams = EvalModelTemplate.get_default_hparams() trainer = Trainer(default_root_dir=tmpdir, max_epochs=1) # no val data has val loop - with pytest.warns(UserWarning): - model = EvalModelTemplate(**hparams) + with pytest.warns(UserWarning, match=r'you passed in a val_dataloader but have no validation_step'): + model = BoringModel() model.validation_step = None trainer.fit(model) # has val loop but no val data - with pytest.warns(UserWarning): - model = EvalModelTemplate(**hparams) + with pytest.warns(UserWarning, match=r'you defined a 
validation_step but have no val_dataloader'): + model = BoringModel() model.val_dataloader = None trainer.fit(model) @@ -77,17 +69,35 @@ def test_test_loop_config(tmpdir): """" When either test loop or test data are missing """ - hparams = EvalModelTemplate.get_default_hparams() trainer = Trainer(default_root_dir=tmpdir, max_epochs=1) # has test loop but no test data - with pytest.warns(UserWarning): - model = EvalModelTemplate(**hparams) + with pytest.warns(UserWarning, match=r'you defined a test_step but have no test_dataloader'): + model = BoringModel() model.test_dataloader = None trainer.test(model) # has test data but no test loop - with pytest.warns(UserWarning): - model = EvalModelTemplate(**hparams) + with pytest.warns(UserWarning, match=r'you passed in a test_dataloader but have no test_step'): + model = BoringModel() model.test_step = None - trainer.test(model, test_dataloaders=model.dataloader(train=False)) + trainer.test(model) + + +def test_val_loop_config(tmpdir): + """" + When either validation loop or validation data are missing + """ + trainer = Trainer(default_root_dir=tmpdir, max_epochs=1) + + # has val loop but no val data + with pytest.warns(UserWarning, match=r'you defined a validation_step but have no val_dataloader'): + model = BoringModel() + model.val_dataloader = None + trainer.validate(model) + + # has val data but no val loop + with pytest.warns(UserWarning, match=r'you passed in a val_dataloader but have no validation_step'): + model = BoringModel() + model.validation_step = None + trainer.validate(model) diff --git a/tests/trainer/test_dataloaders.py b/tests/trainer/test_dataloaders.py index 5530779b4f77d..e4aea38fb7f37 100644 --- a/tests/trainer/test_dataloaders.py +++ b/tests/trainer/test_dataloaders.py @@ -125,8 +125,7 @@ def test_multiple_val_dataloader(tmpdir): assert trainer.state == TrainerState.FINISHED, f"Training failed with {trainer.state}" # verify there are 2 val loaders - assert len(trainer.val_dataloaders) == 2, \ - 'Multiple val_dataloaders not initiated properly' + assert len(trainer.val_dataloaders) == 2, 'Multiple val_dataloaders not initiated properly' # make sure predictions are good for each val set for dataloader in trainer.val_dataloaders: @@ -134,18 +133,22 @@ def test_multiple_val_dataloader(tmpdir): @pytest.mark.parametrize('ckpt_path', [None, 'best', 'specific']) -def test_multiple_test_dataloader(tmpdir, ckpt_path): - """Verify multiple test_dataloader.""" - - model_template = EvalModelTemplate() +def test_multiple_eval_dataloader(tmpdir, ckpt_path): + """Verify multiple evaluation dataloaders.""" class MultipleTestDataloaderModel(EvalModelTemplate): - def test_dataloader(self): return [self.dataloader(train=False), self.dataloader(train=False)] - def test_step(self, batch, batch_idx, *args, **kwargs): - return model_template.test_step__multiple_dataloaders(batch, batch_idx, *args, **kwargs) + def test_step(self, *args, **kwargs): + return super().test_step__multiple_dataloaders(*args, **kwargs) + + def val_dataloader(self): + return self.test_dataloader() + + def validation_step(self, *args, **kwargs): + output = self.test_step(*args, **kwargs) + return {k.replace("test_", "val_"): v for k, v in output.items()} model = MultipleTestDataloaderModel() @@ -159,18 +162,19 @@ def test_step(self, batch, batch_idx, *args, **kwargs): trainer.fit(model) if ckpt_path == 'specific': ckpt_path = trainer.checkpoint_callback.best_model_path - trainer.test(ckpt_path=ckpt_path) - # verify there are 2 test loaders - assert 
len(trainer.test_dataloaders) == 2, 'Multiple test_dataloaders not initiated properly' + trainer.validate(ckpt_path=ckpt_path, verbose=False) + # verify there are 2 loaders + assert len(trainer.val_dataloaders) == 2 + # make sure predictions are good for each dl + for dataloader in trainer.val_dataloaders: + tpipes.run_prediction_eval_model_template(trainer.model, dataloader) - # make sure predictions are good for each test set + trainer.test(ckpt_path=ckpt_path, verbose=False) + assert len(trainer.test_dataloaders) == 2 for dataloader in trainer.test_dataloaders: tpipes.run_prediction_eval_model_template(trainer.model, dataloader) - # run the test method - trainer.test(ckpt_path=ckpt_path) - def test_train_dataloader_passed_to_fit(tmpdir): """Verify that train dataloader can be passed to fit """ @@ -189,90 +193,45 @@ def test_train_dataloader_passed_to_fit(tmpdir): assert trainer.state == TrainerState.FINISHED, f"Training failed with {trainer.state}" -def test_train_val_dataloaders_passed_to_fit(tmpdir): - """ Verify that train & val dataloader can be passed to fit """ - - # train, val passed to fit - model = EvalModelTemplate() - trainer = Trainer( - default_root_dir=tmpdir, - max_epochs=1, - limit_val_batches=0.1, - limit_train_batches=0.2, - ) - fit_options = dict(train_dataloader=model.dataloader(train=True), val_dataloaders=model.dataloader(train=False)) - - trainer.fit(model, **fit_options) - assert trainer.state == TrainerState.FINISHED, f"Training failed with {trainer.state}" - assert len(trainer.val_dataloaders) == 1, \ - f'`val_dataloaders` not initiated properly, got {trainer.val_dataloaders}' - - @pytest.mark.parametrize('ckpt_path', [None, 'best', 'specific']) -def test_all_dataloaders_passed_to_fit(tmpdir, ckpt_path): - """Verify train, val & test dataloader(s) can be passed to fit and test method""" +@pytest.mark.parametrize("n", (1, 2)) +def test_dataloaders_passed_to_fn(tmpdir, ckpt_path, n): + """Verify that dataloaders can be passed.""" model = EvalModelTemplate() + if n == 1: + dataloaders = model.dataloader(train=False) + else: + dataloaders = [model.dataloader(train=False)] * 2 + model.validation_step = model.validation_step__multiple_dataloaders + model.validation_epoch_end = model.validation_epoch_end__multiple_dataloaders + model.test_step = model.test_step__multiple_dataloaders - # train, val and test passed to fit + # train, multiple val and multiple test passed to fit trainer = Trainer( default_root_dir=tmpdir, max_epochs=1, limit_val_batches=0.1, limit_train_batches=0.2, ) - fit_options = dict(train_dataloader=model.dataloader(train=True), val_dataloaders=model.dataloader(train=False)) - trainer.fit(model, **fit_options) - - if ckpt_path == 'specific': - ckpt_path = trainer.checkpoint_callback.best_model_path - test_options = dict(test_dataloaders=model.dataloader(train=False), ckpt_path=ckpt_path) - trainer.test(**test_options) + trainer.fit(model, train_dataloader=model.dataloader(train=True), val_dataloaders=dataloaders) assert trainer.state == TrainerState.FINISHED, f"Training failed with {trainer.state}" - assert len(trainer.val_dataloaders) == 1, \ - f'val_dataloaders` not initiated properly, got {trainer.val_dataloaders}' - assert len(trainer.test_dataloaders) == 1, \ - f'test_dataloaders` not initiated properly, got {trainer.test_dataloaders}' - + assert len(trainer.val_dataloaders) == n -@pytest.mark.parametrize('ckpt_path', [None, 'best', 'specific']) -def test_multiple_dataloaders_passed_to_fit(tmpdir, ckpt_path): - """Verify that multiple val & 
test dataloaders can be passed to fit.""" - - model = EvalModelTemplate() - model.validation_step = model.validation_step__multiple_dataloaders - model.validation_epoch_end = model.validation_epoch_end__multiple_dataloaders - model.test_step = model.test_step__multiple_dataloaders - - # train, multiple val and multiple test passed to fit - trainer = Trainer( - default_root_dir=tmpdir, - max_epochs=1, - limit_val_batches=0.1, - limit_train_batches=0.2, - ) - fit_options = dict( - train_dataloader=model.dataloader(train=True), - val_dataloaders=[model.dataloader(train=False), model.dataloader(train=False)] - ) - trainer.fit(model, **fit_options) if ckpt_path == 'specific': ckpt_path = trainer.checkpoint_callback.best_model_path - test_options = dict( - test_dataloaders=[model.dataloader(train=False), model.dataloader(train=False)], ckpt_path=ckpt_path - ) - trainer.test(**test_options) - assert len(trainer.val_dataloaders) == 2, \ - f'Multiple `val_dataloaders` not initiated properly, got {trainer.val_dataloaders}' - assert len(trainer.test_dataloaders) == 2, \ - f'Multiple `test_dataloaders` not initiated properly, got {trainer.test_dataloaders}' + trainer.test(test_dataloaders=dataloaders, ckpt_path=ckpt_path) + trainer.validate(val_dataloaders=dataloaders, ckpt_path=ckpt_path) + + assert len(trainer.val_dataloaders) == n + assert len(trainer.test_dataloaders) == n @pytest.mark.parametrize(['limit_train_batches', 'limit_val_batches', 'limit_test_batches'], [ - pytest.param(0.0, 0.0, 0.0), - pytest.param(1.0, 1.0, 1.0), + (0.0, 0.0, 0.0), + (1.0, 1.0, 1.0), ]) def test_inf_dataloaders_with_limit_percent_batches(tmpdir, limit_train_batches, limit_val_batches, limit_test_batches): """Verify inf train, val & test dataloaders (e.g. IterableDataset) passed with batch limit in percent""" @@ -299,8 +258,8 @@ def test_inf_dataloaders_with_limit_percent_batches(tmpdir, limit_train_batches, @pytest.mark.parametrize(['limit_train_batches', 'limit_val_batches', 'limit_test_batches'], [ - pytest.param(0, 0, 0), - pytest.param(10, 10, 10), + (0, 0, 0), + (10, 10, 10), ]) def test_inf_dataloaders_with_limit_num_batches(tmpdir, limit_train_batches, limit_val_batches, limit_test_batches): """Verify inf train, val & test dataloaders (e.g. 
IterableDataset) passed with batch limit as number""" @@ -327,10 +286,10 @@ def test_inf_dataloaders_with_limit_num_batches(tmpdir, limit_train_batches, lim @pytest.mark.parametrize(['limit_train_batches', 'limit_val_batches', 'limit_test_batches'], [ - pytest.param(0.0, 0.0, 0.0), - pytest.param(0, 0, 0.5), - pytest.param(1.0, 1.0, 1.0), - pytest.param(0.2, 0.4, 0.4), + (0.0, 0.0, 0.0), + (0, 0, 0.5), + (1.0, 1.0, 1.0), + (0.2, 0.4, 0.4), ]) def test_dataloaders_with_limit_percent_batches(tmpdir, limit_train_batches, limit_val_batches, limit_test_batches): """Verify num_batches for train, val & test dataloaders passed with batch limit in percent""" @@ -362,9 +321,9 @@ def test_dataloaders_with_limit_percent_batches(tmpdir, limit_train_batches, lim @pytest.mark.parametrize(['limit_train_batches', 'limit_val_batches', 'limit_test_batches'], [ - pytest.param(0, 0, 0), - pytest.param(1, 2, 3), - pytest.param(1, 2, 1e50), + (0, 0, 0), + (1, 2, 3), + (1, 2, 1e50), ]) @mock.patch.dict(os.environ, {"PL_DEV_DEBUG": "1"}) def test_dataloaders_with_limit_num_batches(tmpdir, limit_train_batches, limit_val_batches, limit_test_batches): @@ -445,10 +404,10 @@ def test_dataloaders_with_fast_dev_run(tmpdir, fast_dev_run): if fast_dev_run == 'temp': with pytest.raises(MisconfigurationException, match='either a bool or an int'): - trainer = Trainer(**trainer_options) + Trainer(**trainer_options) elif fast_dev_run == -1: with pytest.raises(MisconfigurationException, match='should be >= 0'): - trainer = Trainer(**trainer_options) + Trainer(**trainer_options) else: trainer = Trainer(**trainer_options) @@ -1191,12 +1150,6 @@ def test_replace_sampler_with_multiprocessing_context(tmpdir): train = RandomDataset(32, 64) context = 'spawn' train = DataLoader(train, batch_size=32, num_workers=2, multiprocessing_context=context, shuffle=True) - - class ExtendedBoringModel(BoringModel): - - def train_dataloader(self): - return train - trainer = Trainer( max_epochs=1, progress_bar_refresh_rate=20, diff --git a/tests/trainer/test_trainer.py b/tests/trainer/test_trainer.py index f1a3687b43508..5b06879b1f6d1 100644 --- a/tests/trainer/test_trainer.py +++ b/tests/trainer/test_trainer.py @@ -599,44 +599,57 @@ def test_benchmark_option(tmpdir): assert torch.backends.cudnn.benchmark -@pytest.mark.parametrize("ckpt_path", [None, "best", "specific"]) -@pytest.mark.parametrize("save_top_k", [-1, 0, 1, 2]) -def test_tested_checkpoint_path(tmpdir, ckpt_path, save_top_k): - hparams = EvalModelTemplate.get_default_hparams() +@pytest.mark.parametrize("ckpt_path", (None, "best", "specific")) +@pytest.mark.parametrize("save_top_k", (-1, 0, 1, 2)) +@pytest.mark.parametrize("fn", ("validate", "test")) +def test_tested_checkpoint_path(tmpdir, ckpt_path, save_top_k, fn): + class TestModel(BoringModel): + def validation_step(self, batch, batch_idx): + self.log("foo", -batch_idx) + return super().validation_step(batch, batch_idx) - model = EvalModelTemplate(**hparams) + model = TestModel() trainer = Trainer( max_epochs=2, progress_bar_refresh_rate=0, default_root_dir=tmpdir, - callbacks=[ModelCheckpoint(monitor="early_stop_on", save_top_k=save_top_k)], + callbacks=[ModelCheckpoint(monitor="foo", save_top_k=save_top_k)], ) trainer.fit(model) + + test_or_validate = getattr(trainer, fn) if ckpt_path == "best": # ckpt_path is 'best', meaning we load the best weights if save_top_k == 0: with pytest.raises(MisconfigurationException, match=".*is not configured to save the best.*"): - trainer.test(ckpt_path=ckpt_path) + 
test_or_validate(ckpt_path=ckpt_path) else: - trainer.test(ckpt_path=ckpt_path) - assert trainer.tested_ckpt_path == trainer.checkpoint_callback.best_model_path + test_or_validate(ckpt_path=ckpt_path) + if fn == "test": + assert trainer.tested_ckpt_path == trainer.checkpoint_callback.best_model_path + else: + assert trainer.validated_ckpt_path == trainer.checkpoint_callback.best_model_path elif ckpt_path is None: # ckpt_path is None, meaning we don't load any checkpoints and # use the weights from the end of training - trainer.test(ckpt_path=ckpt_path) + test_or_validate(ckpt_path=ckpt_path) assert trainer.tested_ckpt_path is None + assert trainer.validated_ckpt_path is None else: # specific checkpoint, pick one from saved ones if save_top_k == 0: with pytest.raises(FileNotFoundError): - trainer.test(ckpt_path="random.ckpt") + test_or_validate(ckpt_path="random.ckpt") else: ckpt_path = str( list((Path(tmpdir) / f"lightning_logs/version_{trainer.logger.version}/checkpoints").iterdir() )[0].absolute() ) - trainer.test(ckpt_path=ckpt_path) - assert trainer.tested_ckpt_path == ckpt_path + test_or_validate(ckpt_path=ckpt_path) + if fn == "test": + assert trainer.tested_ckpt_path == ckpt_path + else: + assert trainer.validated_ckpt_path == ckpt_path def test_disabled_training(tmpdir): @@ -1292,10 +1305,11 @@ def test_trainer_pickle(tmpdir): cloudpickle.dumps(trainer) -def test_trainer_setup_call(tmpdir): - """Test setup call with fit and test call.""" +@pytest.mark.parametrize("stage", ("fit", "validate", "test")) +def test_trainer_setup_call(tmpdir, stage): + """Test setup call gets the correct stage""" - class CurrentModel(EvalModelTemplate): + class CurrentModel(BoringModel): def setup(self, stage): self.stage = stage @@ -1311,21 +1325,23 @@ def setup(self, model, stage): # fit model trainer = TrainerSubclass(default_root_dir=tmpdir, max_epochs=1, checkpoint_callback=False) - trainer.fit(model) - assert trainer.stage == "fit" - assert trainer.lightning_module.stage == "fit" + if stage == "fit": + trainer.fit(model) + elif stage == "validate": + trainer.validate(model, ckpt_path=None) + else: + trainer.test(model, ckpt_path=None) - trainer.test(ckpt_path=None) - assert trainer.stage == "test" - assert trainer.lightning_module.stage == "test" + assert trainer.stage == stage + assert trainer.lightning_module.stage == stage @pytest.mark.parametrize( "train_batches, max_steps, log_interval", [ - pytest.param(10, 10, 1), - pytest.param(3, 10, 1), - pytest.param(3, 10, 5), + (10, 10, 1), + (3, 10, 1), + (3, 10, 5), ], ) @patch("pytorch_lightning.loggers.tensorboard.TensorBoardLogger.log_metrics") @@ -1398,7 +1414,7 @@ def predict(tmpdir, accelerator, gpus, num_processes, model=None, plugins=None, dataloaders = [torch.utils.data.DataLoader(RandomDataset(32, 2)), torch.utils.data.DataLoader(RandomDataset(32, 2))] model = model or BoringModel() - datamodule = TestLightningDataModule(dataloaders) + dm = TestLightningDataModule(dataloaders) trainer = Trainer( default_root_dir=tmpdir, @@ -1411,7 +1427,7 @@ def predict(tmpdir, accelerator, gpus, num_processes, model=None, plugins=None, plugins=plugins, ) if datamodule: - results = trainer.predict(model, datamodule=datamodule) + results = trainer.predict(model, datamodule=dm) else: results = trainer.predict(model, dataloaders=dataloaders) diff --git a/tests/trainer/test_trainer_test_loop.py b/tests/trainer/test_trainer_test_loop.py deleted file mode 100644 index 7e2a9299fc8a0..0000000000000 --- a/tests/trainer/test_trainer_test_loop.py +++ /dev/null @@ 
-1,76 +0,0 @@ -# Copyright The PyTorch Lightning team. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -import torch - -import pytorch_lightning as pl -import tests.helpers.utils as tutils -from tests.base import EvalModelTemplate -from tests.helpers.runif import RunIf - - -@RunIf(min_gpus=2) -def test_single_gpu_test(tmpdir): - tutils.set_random_master_port() - - model = EvalModelTemplate() - trainer = pl.Trainer( - default_root_dir=tmpdir, - max_epochs=2, - limit_train_batches=10, - limit_val_batches=10, - gpus=[0], - ) - trainer.fit(model) - assert 'ckpt' in trainer.checkpoint_callback.best_model_path - results = trainer.test() - assert 'test_acc' in results[0] - - old_weights = model.c_d1.weight.clone().detach().cpu() - - results = trainer.test(model) - assert 'test_acc' in results[0] - - # make sure weights didn't change - new_weights = model.c_d1.weight.clone().detach().cpu() - - assert torch.all(torch.eq(old_weights, new_weights)) - - -@RunIf(min_gpus=2) -def test_ddp_spawn_test(tmpdir): - tutils.set_random_master_port() - - model = EvalModelTemplate() - trainer = pl.Trainer( - default_root_dir=tmpdir, - max_epochs=2, - limit_train_batches=10, - limit_val_batches=10, - gpus=[0, 1], - accelerator='ddp_spawn', - ) - trainer.fit(model) - assert 'ckpt' in trainer.checkpoint_callback.best_model_path - results = trainer.test() - assert 'test_acc' in results[0] - - old_weights = model.c_d1.weight.clone().detach().cpu() - - results = trainer.test(model) - assert 'test_acc' in results[0] - - # make sure weights didn't change - new_weights = model.c_d1.weight.clone().detach().cpu() - - assert torch.all(torch.eq(old_weights, new_weights)) From 2ecda5df52b2bd14e3125160d1d3a7837fd3c444 Mon Sep 17 00:00:00 2001 From: Max Frei <36265931+maxfrei750@users.noreply.github.com> Date: Thu, 11 Mar 2021 09:40:23 +0100 Subject: [PATCH 3/8] Allow user to disable the automatic formatting of checkpoint file names. (#6277) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * cleaning SWA (#6259) * rename * if * test * chlog * Remove opt from manual_backward in docs (#6267) * switch agents pool (#6270) * Allow user to disable the automatic formatting of checkpoint file names. * Added changelog entry. * Made flake8 happy. * Applied review suggestion: quotes for special characters in docstring Co-authored-by: Carlos Mocholí * Fixed example in docstring. * Fixed syntax error in docstring. Co-authored-by: Jirka Borovec Co-authored-by: Akihiro Nitta Co-authored-by: thomas chaton Co-authored-by: Carlos Mocholí --- CHANGELOG.md | 3 ++ .../callbacks/model_checkpoint.py | 31 +++++++++++++++++-- tests/checkpointing/test_model_checkpoint.py | 9 ++++++ 3 files changed, 41 insertions(+), 2 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 7cdc8fd5acfb8..08405cb89b392 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -27,6 +27,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). 
- Added `LightningEnvironment` for Lightning-specific DDP ([#5915](https://github.com/PyTorchLightning/pytorch-lightning/pull/5915)) +- Added `auto_insert_metric_name` parameter to `ModelCheckpoint` ([#6277](https://github.com/PyTorchLightning/pytorch-lightning/pull/6277)) + + - Added arg to `self.log` that enables users to give custom names when dealing with multiple dataloaders ([#6274](https://github.com/PyTorchLightning/pytorch-lightning/pull/6274)) diff --git a/pytorch_lightning/callbacks/model_checkpoint.py b/pytorch_lightning/callbacks/model_checkpoint.py index f457e9de7d0fa..f05a10a41996b 100644 --- a/pytorch_lightning/callbacks/model_checkpoint.py +++ b/pytorch_lightning/callbacks/model_checkpoint.py @@ -131,6 +131,16 @@ class ModelCheckpoint(Callback): ... filename='sample-mnist-{epoch:02d}-{val_loss:.2f}' ... ) + # save epoch and val_loss in name, but specify the formatting yourself (e.g. to avoid problems with Tensorboard + # or Neptune, due to the presence of characters like '=' or '/') + # saves a file like: my/path/sample-mnist-epoch02-val_loss0.32.ckpt + >>> checkpoint_callback = ModelCheckpoint( + ... monitor='val/loss', + ... dirpath='my/path/', + ... filename='sample-mnist-epoch{epoch:02d}-val_loss{val/loss:.2f}', + ... auto_insert_metric_name=False + ... ) + # retrieve the best checkpoint after training checkpoint_callback = ModelCheckpoint(dirpath='my/path/') trainer = Trainer(callbacks=[checkpoint_callback]) @@ -156,6 +166,7 @@ def __init__( save_weights_only: bool = False, mode: str = "min", period: int = 1, + auto_insert_metric_name: bool = True ): super().__init__() self.monitor = monitor @@ -164,6 +175,7 @@ def __init__( self.save_top_k = save_top_k self.save_weights_only = save_weights_only self.period = period + self.auto_insert_metric_name = auto_insert_metric_name self._last_global_step_saved = -1 self.current_score = None self.best_k_models = {} @@ -356,6 +368,7 @@ def _format_checkpoint_name( step: int, metrics: Dict[str, Any], prefix: str = "", + auto_insert_metric_name: bool = True ) -> str: if not filename: # filename is not set, use default name @@ -367,7 +380,10 @@ def _format_checkpoint_name( metrics.update({"epoch": epoch, 'step': step}) for group in groups: name = group[1:] - filename = filename.replace(group, name + "={" + name) + + if auto_insert_metric_name: + filename = filename.replace(group, name + "={" + name) + if name not in metrics: metrics[name] = 0 filename = filename.format(**metrics) @@ -392,6 +408,11 @@ def format_checkpoint_name(self, epoch: int, step: int, metrics: Dict[str, Any], >>> ckpt = ModelCheckpoint(dirpath=tmpdir, filename='{epoch}-{val_loss:.2f}') >>> os.path.basename(ckpt.format_checkpoint_name(2, 3, metrics=dict(val_loss=0.123456))) 'epoch=2-val_loss=0.12.ckpt' + >>> ckpt = ModelCheckpoint(dirpath=tmpdir, + ... filename='epoch={epoch}-validation_loss={val_loss:.2f}', + ... 
auto_insert_metric_name=False) + >>> os.path.basename(ckpt.format_checkpoint_name(2, 3, metrics=dict(val_loss=0.123456))) + 'epoch=2-validation_loss=0.12.ckpt' >>> ckpt = ModelCheckpoint(dirpath=tmpdir, filename='{missing:d}') >>> os.path.basename(ckpt.format_checkpoint_name(0, 4, metrics={})) 'missing=0.ckpt' @@ -400,7 +421,13 @@ def format_checkpoint_name(self, epoch: int, step: int, metrics: Dict[str, Any], 'step=0.ckpt' """ - filename = self._format_checkpoint_name(self.filename, epoch, step, metrics) + filename = self._format_checkpoint_name( + self.filename, + epoch, + step, + metrics, + auto_insert_metric_name=self.auto_insert_metric_name) + if ver is not None: filename = self.CHECKPOINT_JOIN_CHAR.join((filename, f"v{ver}")) diff --git a/tests/checkpointing/test_model_checkpoint.py b/tests/checkpointing/test_model_checkpoint.py index d96fe3dcab33d..4a8088070f041 100644 --- a/tests/checkpointing/test_model_checkpoint.py +++ b/tests/checkpointing/test_model_checkpoint.py @@ -432,6 +432,15 @@ def test_model_checkpoint_format_checkpoint_name(tmpdir): ckpt_name = ckpt.format_checkpoint_name(4, 3, {'val/loss': 0.03}) assert ckpt_name == 'epoch=4_val/loss=0.03000.ckpt' + # auto_insert_metric_name=False + ckpt_name = ModelCheckpoint._format_checkpoint_name( + 'epoch={epoch:03d}-val_acc={val/acc}', + 3, + 2, + {'val/acc': 0.03}, + auto_insert_metric_name=False) + assert ckpt_name == 'epoch=003-val_acc=0.03' + class ModelCheckpointExtensionTest(ModelCheckpoint): FILE_EXTENSION = '.tpkc' From 079fe9bc0908fffdd55a08f7321a199bceaef6f8 Mon Sep 17 00:00:00 2001 From: Kaushik B <45285388+kaushikb11@users.noreply.github.com> Date: Thu, 11 Mar 2021 16:49:48 +0530 Subject: [PATCH 4/8] Hotfix for torchvision (#6476) --- pl_examples/basic_examples/autoencoder.py | 5 +++-- pl_examples/basic_examples/backbone_image_classifier.py | 5 +++-- pl_examples/basic_examples/dali_image_classifier.py | 5 +++-- pl_examples/basic_examples/mnist_datamodule.py | 3 ++- pl_examples/domain_templates/generative_adversarial_net.py | 5 +++-- tests/helpers/datasets.py | 1 + 6 files changed, 15 insertions(+), 9 deletions(-) diff --git a/pl_examples/basic_examples/autoencoder.py b/pl_examples/basic_examples/autoencoder.py index b3188a21b7f04..a2010a89f4461 100644 --- a/pl_examples/basic_examples/autoencoder.py +++ b/pl_examples/basic_examples/autoencoder.py @@ -22,9 +22,10 @@ import pytorch_lightning as pl from pl_examples import _DATASETS_PATH, _TORCHVISION_AVAILABLE, _TORCHVISION_MNIST_AVAILABLE, cli_lightning_logo -if _TORCHVISION_AVAILABLE and _TORCHVISION_MNIST_AVAILABLE: +if _TORCHVISION_AVAILABLE: from torchvision import transforms - from torchvision.datasets.mnist import MNIST +if _TORCHVISION_MNIST_AVAILABLE: + from torchvision.datasets import MNIST else: from tests.helpers.datasets import MNIST diff --git a/pl_examples/basic_examples/backbone_image_classifier.py b/pl_examples/basic_examples/backbone_image_classifier.py index 01a5dca0de3c7..3546bee9ad129 100644 --- a/pl_examples/basic_examples/backbone_image_classifier.py +++ b/pl_examples/basic_examples/backbone_image_classifier.py @@ -21,9 +21,10 @@ import pytorch_lightning as pl from pl_examples import _DATASETS_PATH, _TORCHVISION_AVAILABLE, _TORCHVISION_MNIST_AVAILABLE, cli_lightning_logo -if _TORCHVISION_AVAILABLE and _TORCHVISION_MNIST_AVAILABLE: +if _TORCHVISION_AVAILABLE: from torchvision import transforms - from torchvision.datasets.mnist import MNIST +if _TORCHVISION_MNIST_AVAILABLE: + from torchvision.datasets import MNIST else: from tests.helpers.datasets 
import MNIST
diff --git a/pl_examples/basic_examples/dali_image_classifier.py b/pl_examples/basic_examples/dali_image_classifier.py
index b4bf1407a9b26..da5b1e4fd9e9c 100644
--- a/pl_examples/basic_examples/dali_image_classifier.py
+++ b/pl_examples/basic_examples/dali_image_classifier.py
@@ -31,9 +31,10 @@
     cli_lightning_logo,
 )
 
-if _TORCHVISION_AVAILABLE and _TORCHVISION_MNIST_AVAILABLE:
+if _TORCHVISION_AVAILABLE:
     from torchvision import transforms
-    from torchvision.datasets.mnist import MNIST
+if _TORCHVISION_MNIST_AVAILABLE:
+    from torchvision.datasets import MNIST
 else:
     from tests.helpers.datasets import MNIST
 
diff --git a/pl_examples/basic_examples/mnist_datamodule.py b/pl_examples/basic_examples/mnist_datamodule.py
index a50f67cdab301..a6d59c64d9aa0 100644
--- a/pl_examples/basic_examples/mnist_datamodule.py
+++ b/pl_examples/basic_examples/mnist_datamodule.py
@@ -20,8 +20,9 @@
 from pl_examples import _DATASETS_PATH, _TORCHVISION_AVAILABLE, _TORCHVISION_MNIST_AVAILABLE
 from pytorch_lightning import LightningDataModule
 
-if _TORCHVISION_AVAILABLE and _TORCHVISION_MNIST_AVAILABLE:
+if _TORCHVISION_AVAILABLE:
     from torchvision import transforms as transform_lib
+if _TORCHVISION_MNIST_AVAILABLE:
     from torchvision.datasets import MNIST
 else:
     from tests.helpers.datasets import MNIST
diff --git a/pl_examples/domain_templates/generative_adversarial_net.py b/pl_examples/domain_templates/generative_adversarial_net.py
index 285fba8b93f1b..e65ede17dac7a 100644
--- a/pl_examples/domain_templates/generative_adversarial_net.py
+++ b/pl_examples/domain_templates/generative_adversarial_net.py
@@ -32,9 +32,10 @@
 from pytorch_lightning.core import LightningDataModule, LightningModule
 from pytorch_lightning.trainer import Trainer
 
-if _TORCHVISION_AVAILABLE and _TORCHVISION_MNIST_AVAILABLE:
+if _TORCHVISION_AVAILABLE:
     import torchvision
-    import torchvision.transforms as transforms
+    from torchvision import transforms
+if _TORCHVISION_MNIST_AVAILABLE:
     from torchvision.datasets import MNIST
 else:
     from tests.helpers.datasets import MNIST
diff --git a/tests/helpers/datasets.py b/tests/helpers/datasets.py
index 5af3fbfbc4a11..e7bdad0f1538c 100644
--- a/tests/helpers/datasets.py
+++ b/tests/helpers/datasets.py
@@ -69,6 +69,7 @@ def __init__(
         train: bool = True,
         normalize: tuple = (0.1307, 0.3081),
         download: bool = True,
+        **kwargs,
     ):
         super().__init__()
         self.root = root

From afe0ededa3a2f766d56de4cfc0fc0180aeff9ce8 Mon Sep 17 00:00:00 2001
From: Jirka Borovec
Date: Thu, 11 Mar 2021 16:45:26 +0100
Subject: [PATCH 5/8] cover subproc coverage (#6477)

---
 .github/workflows/ci_test-conda.yml | 22 ++++++++++++++++++++--
 .github/workflows/ci_test-full.yml  |  6 +++---
 requirements/test.txt               |  7 ++++---
 3 files changed, 27 insertions(+), 8 deletions(-)

diff --git a/.github/workflows/ci_test-conda.yml b/.github/workflows/ci_test-conda.yml
index 419580b71cd10..812d06f310812 100644
--- a/.github/workflows/ci_test-conda.yml
+++ b/.github/workflows/ci_test-conda.yml
@@ -44,12 +44,30 @@ jobs:
     - name: Tests
       run: |
         # NOTE: running coverage on tests does not propagate failure status for Win, https://github.com/nedbat/coveragepy/issues/1003
-        python -m pytest pytorch_lightning tests -v --durations=50 --junitxml=junit/test-results-${{ runner.os }}-${{ matrix.python-version }}-${{ matrix.requires }}.xml
+        python -m pytest pytorch_lightning tests --cov=pytorch_lightning -v --durations=50 --junitxml=junit/test-results-${{ runner.os }}-torch${{ matrix.pytorch-version }}.xml
       shell: bash -l {0}
 
    - name:
Upload pytest results uses: actions/upload-artifact@v2 with: name: pytest-results-${{ runner.os }}-${{ matrix.python-version }}-${{ matrix.requires }} path: junit/test-results-${{ runner.os }}-${{ matrix.python-version }}-${{ matrix.requires }}.xml if: failure() + + - name: Statistics + if: success() + run: | + coverage report + coverage xml + + - name: Upload coverage to Codecov + uses: codecov/codecov-action@v1 + if: always() + # see: https://github.com/actions/toolkit/issues/399 + continue-on-error: true + with: + token: ${{ secrets.CODECOV_TOKEN }} + file: coverage.xml + flags: cpu,pytest,torch${{ matrix.pytorch-version }} + name: CPU-coverage + fail_ci_if_error: false diff --git a/.github/workflows/ci_test-full.yml b/.github/workflows/ci_test-full.yml index dd29777d9940c..3d3f7d11570a4 100644 --- a/.github/workflows/ci_test-full.yml +++ b/.github/workflows/ci_test-full.yml @@ -138,13 +138,13 @@ jobs: - name: Tests run: | # NOTE: do not include coverage report here, see: https://github.com/nedbat/coveragepy/issues/1003 - coverage run --source pytorch_lightning -m pytest pytorch_lightning tests -v --durations=50 --junitxml=junit/test-results-${{ runner.os }}-${{ matrix.python-version }}-${{ matrix.requires }}.xml + python -m pytest pytorch_lightning tests --cov=pytorch_lightning -v --durations=50 --junitxml=junit/test-results-${{ runner.os }}-py${{ matrix.python-version }}-${{ matrix.requires }}.xml - name: Examples run: | python -m pytest pl_examples -v --durations=10 - - name: Upload pytest test results + - name: Upload pytest results uses: actions/upload-artifact@v2 with: name: pytest-results-${{ runner.os }}-${{ matrix.python-version }}-${{ matrix.requires }} @@ -165,6 +165,6 @@ jobs: with: token: ${{ secrets.CODECOV_TOKEN }} file: coverage.xml - flags: cpu,pytest + flags: cpu,pytest,python${{ matrix.python-version }} name: CPU-coverage fail_ci_if_error: false diff --git a/requirements/test.txt b/requirements/test.txt index 2d47143ca58d4..60c861cea9c50 100644 --- a/requirements/test.txt +++ b/requirements/test.txt @@ -1,7 +1,8 @@ -coverage>=5.0 +coverage>=5.2 codecov>=2.1 -pytest>=5.0 -# pytest-cov +pytest>=6.0 +pytest-cov>2.10 +pytest-xdist flake8>=3.6 check-manifest twine==3.2 From e886d55ac1be6c34bd8ea56dba29029153d8eb5b Mon Sep 17 00:00:00 2001 From: Eric Cousineau Date: Thu, 11 Mar 2021 10:50:49 -0500 Subject: [PATCH 6/8] argparse: Add use_argument_group=True (#6088) * argparse: Add inplace option Replicate in GAN model * datamodule: Deduplicate logic w/ argparser utilities * Update pl_examples/domain_templates/generative_adversarial_net.py Co-authored-by: Jirka Borovec * Apply suggestions from code review Co-authored-by: Akihiro Nitta * Keep docstrings * Correct name * Whitespace * Consistency * fix weird type stuff * try alt - use_argument_group * fix syntax + lint * fix ci errs * fix ci * change examples... 
still failing w/ "unrecognized arguments: --batch_size" * address review * mnist_datamodule: add some docstrings * argparse: check cls or cls.__init__ for param didn't capture issue, but meh * fix lint * fix no-doc edge case * address review Co-authored-by: Jirka Borovec Co-authored-by: Akihiro Nitta Co-authored-by: Carlos Mocholi --- docs/source/common/hyperparameters.rst | 12 +- docs/source/common/trainer.rst | 5 +- .../04-transformers-text-classification.ipynb | 3 +- .../backbone_image_classifier.py | 4 +- .../basic_examples/dali_image_classifier.py | 4 +- .../basic_examples/mnist_datamodule.py | 2 + .../basic_examples/simple_image_classifier.py | 4 +- .../computer_vision_fine_tuning.py | 8 +- .../generative_adversarial_net.py | 12 +- pl_examples/domain_templates/imagenet.py | 4 +- .../domain_templates/reinforce_learn_Qnet.py | 4 +- .../domain_templates/reinforce_learn_ppo.py | 4 +- .../domain_templates/semantic_segmentation.py | 4 +- pytorch_lightning/core/datamodule.py | 83 +----------- pytorch_lightning/trainer/properties.py | 4 +- pytorch_lightning/utilities/argparse.py | 84 +++++++++--- tests/trainer/test_trainer_cli.py | 11 +- tests/utilities/test_argparse_utils.py | 122 +++++++++++++++++- 18 files changed, 243 insertions(+), 131 deletions(-) diff --git a/docs/source/common/hyperparameters.rst b/docs/source/common/hyperparameters.rst index 5240a4690e388..83398c1d63388 100644 --- a/docs/source/common/hyperparameters.rst +++ b/docs/source/common/hyperparameters.rst @@ -53,10 +53,10 @@ a module (i.e.: if your project has a model that trains on Imagenet and another @staticmethod def add_model_specific_args(parent_parser): - parser = ArgumentParser(parents=[parent_parser], add_help=False) + parser = parent_parser.add_argument_group("LitModel") parser.add_argument('--encoder_layers', type=int, default=12) parser.add_argument('--data_path', type=str, default='/some/path') - return parser + return parent_parser Now in your main trainer file, add the ``Trainer`` args, the program args, and add the model args @@ -226,9 +226,9 @@ polluting the ``main.py`` file, the ``LightningModule`` lets you define argument @staticmethod def add_model_specific_args(parent_parser): - parser = ArgumentParser(parents=[parent_parser], add_help=False) + parser = parent_parser.add_argument_group("LitMNIST") parser.add_argument('--layer_1_dim', type=int, default=128) - return parser + return parent_parser .. 
testcode::

@@ -240,9 +240,9 @@ polluting the ``main.py`` file, the ``LightningModule`` lets you define argument
 
         @staticmethod
         def add_model_specific_args(parent_parser):
-            parser = ArgumentParser(parents=[parent_parser], add_help=False)
+            parser = parent_parser.add_argument_group("GoodGAN")
             parser.add_argument('--encoder_layers', type=int, default=12)
-            return parser
+            return parent_parser
 
 Now we can allow each model to inject the arguments it needs in the ``main.py``
diff --git a/docs/source/common/trainer.rst b/docs/source/common/trainer.rst
index 6edf896ada01c..10c7c5ad59bfd 100644
--- a/docs/source/common/trainer.rst
+++ b/docs/source/common/trainer.rst
@@ -128,10 +128,7 @@ So you can run it like so:
 
     if __name__ == '__main__':
         parser = ArgumentParser()
-        parser = Trainer.add_argparse_args(
-            # group the Trainer arguments together
-            parser.add_argument_group(title="pl.Trainer args")
-        )
+        parser = Trainer.add_argparse_args(parser)
         args = parser.parse_args()
         main(args)
 
diff --git a/notebooks/04-transformers-text-classification.ipynb b/notebooks/04-transformers-text-classification.ipynb
index d0a150b7a7896..957255969f608 100644
--- a/notebooks/04-transformers-text-classification.ipynb
+++ b/notebooks/04-transformers-text-classification.ipynb
@@ -368,12 +368,13 @@
     "\n",
     "    @staticmethod\n",
     "    def add_model_specific_args(parent_parser):\n",
+    "        parser = parent_parser.add_argument_group(\"GLUETransformer\")",
     "        parser = ArgumentParser(parents=[parent_parser], add_help=False)\n",
     "        parser.add_argument(\"--learning_rate\", default=2e-5, type=float)\n",
     "        parser.add_argument(\"--adam_epsilon\", default=1e-8, type=float)\n",
     "        parser.add_argument(\"--warmup_steps\", default=0, type=int)\n",
     "        parser.add_argument(\"--weight_decay\", default=0.0, type=float)\n",
-    "        return parser"
+    "        return parent_parser"
    ]
   },
   {
diff --git a/pl_examples/basic_examples/backbone_image_classifier.py b/pl_examples/basic_examples/backbone_image_classifier.py
index 3546bee9ad129..1c78d264a8681 100644
--- a/pl_examples/basic_examples/backbone_image_classifier.py
+++ b/pl_examples/basic_examples/backbone_image_classifier.py
@@ -93,9 +93,9 @@ def configure_optimizers(self):
 
     @staticmethod
     def add_model_specific_args(parent_parser):
-        parser = ArgumentParser(parents=[parent_parser], add_help=False)
+        parser = parent_parser.add_argument_group("LitClassifier")
         parser.add_argument('--learning_rate', type=float, default=0.0001)
-        return parser
+        return parent_parser
 
 
 def cli_main():
diff --git a/pl_examples/basic_examples/dali_image_classifier.py b/pl_examples/basic_examples/dali_image_classifier.py
index da5b1e4fd9e9c..08bf64da252bf 100644
--- a/pl_examples/basic_examples/dali_image_classifier.py
+++ b/pl_examples/basic_examples/dali_image_classifier.py
@@ -175,10 +175,10 @@ def configure_optimizers(self):
 
     @staticmethod
     def add_model_specific_args(parent_parser):
-        parser = ArgumentParser(parents=[parent_parser], add_help=False)
+        parser = parent_parser.add_argument_group("LitClassifier")
         parser.add_argument('--hidden_dim', type=int, default=128)
         parser.add_argument('--learning_rate', type=float, default=0.0001)
-        return parser
+        return parent_parser
 
 
 def cli_main():
diff --git a/pl_examples/basic_examples/mnist_datamodule.py b/pl_examples/basic_examples/mnist_datamodule.py
index a6d59c64d9aa0..ea64f96c05d7d 100644
--- a/pl_examples/basic_examples/mnist_datamodule.py
+++ b/pl_examples/basic_examples/mnist_datamodule.py
@@ -55,6 +55,8 @@ def __init__(
             val_split: how many of the training images to use for the validation split
num_workers: how many workers to use for loading data normalize: If true applies image normalize + seed: starting seed for RNG. + batch_size: desired batch size. """ super().__init__(*args, **kwargs) if num_workers and platform.system() == "Windows": diff --git a/pl_examples/basic_examples/simple_image_classifier.py b/pl_examples/basic_examples/simple_image_classifier.py index b0b03446f3628..dfa5869779ff3 100644 --- a/pl_examples/basic_examples/simple_image_classifier.py +++ b/pl_examples/basic_examples/simple_image_classifier.py @@ -68,10 +68,10 @@ def configure_optimizers(self): @staticmethod def add_model_specific_args(parent_parser): - parser = ArgumentParser(parents=[parent_parser], add_help=False) + parser = parent_parser.add_argument_group("LitClassifier") parser.add_argument('--hidden_dim', type=int, default=128) parser.add_argument('--learning_rate', type=float, default=0.0001) - return parser + return parent_parser def cli_main(): diff --git a/pl_examples/domain_templates/computer_vision_fine_tuning.py b/pl_examples/domain_templates/computer_vision_fine_tuning.py index 823efaa53a5e5..88f4e66605741 100644 --- a/pl_examples/domain_templates/computer_vision_fine_tuning.py +++ b/pl_examples/domain_templates/computer_vision_fine_tuning.py @@ -147,14 +147,14 @@ def val_dataloader(self): @staticmethod def add_model_specific_args(parent_parser): - parser = argparse.ArgumentParser(parents=[parent_parser]) + parser = parent_parser.add_argument_group("CatDogImageDataModule") parser.add_argument( "--num-workers", default=0, type=int, metavar="W", help="number of CPU workers", dest="num_workers" ) parser.add_argument( "--batch-size", default=8, type=int, metavar="W", help="number of sample in a batch", dest="batch_size" ) - return parser + return parent_parser # --- Pytorch-lightning module --- @@ -268,7 +268,7 @@ def configure_optimizers(self): @staticmethod def add_model_specific_args(parent_parser): - parser = argparse.ArgumentParser(parents=[parent_parser]) + parser = parent_parser.add_argument_group("TransferLearningModel") parser.add_argument( "--backbone", default="resnet50", @@ -303,7 +303,7 @@ def add_model_specific_args(parent_parser): parser.add_argument( "--milestones", default=[2, 4], type=list, metavar="M", help="List of two epochs milestones" ) - return parser + return parent_parser def main(args: argparse.Namespace) -> None: diff --git a/pl_examples/domain_templates/generative_adversarial_net.py b/pl_examples/domain_templates/generative_adversarial_net.py index e65ede17dac7a..29fcf97de86db 100644 --- a/pl_examples/domain_templates/generative_adversarial_net.py +++ b/pl_examples/domain_templates/generative_adversarial_net.py @@ -135,14 +135,18 @@ def __init__( self.example_input_array = torch.zeros(2, self.hparams.latent_dim) @staticmethod - def add_argparse_args(parent_parser: ArgumentParser): - parser = ArgumentParser(parents=[parent_parser], add_help=False) + def add_argparse_args(parent_parser: ArgumentParser, *, use_argument_group=True): + if use_argument_group: + parser = parent_parser.add_argument_group("pl.GAN") + parser_out = parent_parser + else: + parser = ArgumentParser(parents=[parent_parser], add_help=False) + parser_out = parser parser.add_argument("--lr", type=float, default=0.0002, help="adam: learning rate") parser.add_argument("--b1", type=float, default=0.5, help="adam: decay of first order momentum of gradient") parser.add_argument("--b2", type=float, default=0.999, help="adam: decay of second order momentum of gradient") parser.add_argument("--latent_dim", 
type=int, default=100, help="dimensionality of the latent space") - - return parser + return parser_out def forward(self, z): return self.generator(z) diff --git a/pl_examples/domain_templates/imagenet.py b/pl_examples/domain_templates/imagenet.py index 35357fe291e8a..1b42edfde463b 100644 --- a/pl_examples/domain_templates/imagenet.py +++ b/pl_examples/domain_templates/imagenet.py @@ -198,7 +198,7 @@ def substitute_val_keys(out): @staticmethod def add_model_specific_args(parent_parser): # pragma: no-cover - parser = ArgumentParser(parents=[parent_parser]) + parser = parent_parser.add_argument_group("ImageNetLightningModel") parser.add_argument( '-a', '--arch', @@ -233,7 +233,7 @@ def add_model_specific_args(parent_parser): # pragma: no-cover dest='weight_decay' ) parser.add_argument('--pretrained', dest='pretrained', action='store_true', help='use pre-trained model') - return parser + return parent_parser def main(args: Namespace) -> None: diff --git a/pl_examples/domain_templates/reinforce_learn_Qnet.py b/pl_examples/domain_templates/reinforce_learn_Qnet.py index 887b7f1549f53..4d90faeb45bcf 100644 --- a/pl_examples/domain_templates/reinforce_learn_Qnet.py +++ b/pl_examples/domain_templates/reinforce_learn_Qnet.py @@ -390,7 +390,7 @@ def get_device(self, batch) -> str: @staticmethod def add_model_specific_args(parent_parser): # pragma: no-cover - parser = argparse.ArgumentParser(parents=[parent_parser]) + parser = parent_parser.add_argument_group("DQNLightning") parser.add_argument("--batch_size", type=int, default=16, help="size of the batches") parser.add_argument("--lr", type=float, default=1e-2, help="learning rate") parser.add_argument("--env", type=str, default="CartPole-v0", help="gym environment tag") @@ -407,7 +407,7 @@ def add_model_specific_args(parent_parser): # pragma: no-cover parser.add_argument("--eps_start", type=float, default=1.0, help="starting value of epsilon") parser.add_argument("--eps_end", type=float, default=0.01, help="final value of epsilon") parser.add_argument("--episode_length", type=int, default=200, help="max length of an episode") - return parser + return parent_parser def main(args) -> None: diff --git a/pl_examples/domain_templates/reinforce_learn_ppo.py b/pl_examples/domain_templates/reinforce_learn_ppo.py index 026784f900622..68ecc3fb22db0 100644 --- a/pl_examples/domain_templates/reinforce_learn_ppo.py +++ b/pl_examples/domain_templates/reinforce_learn_ppo.py @@ -446,7 +446,7 @@ def train_dataloader(self) -> DataLoader: @staticmethod def add_model_specific_args(parent_parser): # pragma: no-cover - parser = argparse.ArgumentParser(parents=[parent_parser]) + parser = parent_parser.add_argument_group("PPOLightning") parser.add_argument("--env", type=str, default="CartPole-v0") parser.add_argument("--gamma", type=float, default=0.99, help="discount factor") parser.add_argument("--lam", type=float, default=0.95, help="advantage discount factor") @@ -467,7 +467,7 @@ def add_model_specific_args(parent_parser): # pragma: no-cover "--clip_ratio", type=float, default=0.2, help="hyperparameter for clipping in the policy objective" ) - return parser + return parent_parser def main(args) -> None: diff --git a/pl_examples/domain_templates/semantic_segmentation.py b/pl_examples/domain_templates/semantic_segmentation.py index b5391c1f9b7ce..1ae10d40a4e53 100644 --- a/pl_examples/domain_templates/semantic_segmentation.py +++ b/pl_examples/domain_templates/semantic_segmentation.py @@ -245,7 +245,7 @@ def val_dataloader(self): @staticmethod def 
add_model_specific_args(parent_parser): # pragma: no-cover - parser = ArgumentParser(parents=[parent_parser]) + parser = parent_parser.add_argument_group("SegModel") parser.add_argument("--data_path", type=str, help="path where dataset is stored") parser.add_argument("--batch_size", type=int, default=16, help="size of the batches") parser.add_argument("--lr", type=float, default=0.001, help="adam: learning rate") @@ -257,7 +257,7 @@ def add_model_specific_args(parent_parser): # pragma: no-cover default=False, help="whether to use bilinear interpolation or transposed" ) - return parser + return parent_parser def main(hparams: Namespace): diff --git a/pytorch_lightning/core/datamodule.py b/pytorch_lightning/core/datamodule.py index 1b6852c071fe1..994c259f48964 100644 --- a/pytorch_lightning/core/datamodule.py +++ b/pytorch_lightning/core/datamodule.py @@ -14,7 +14,6 @@ """LightningDataModule for loading DataLoaders with ease.""" import functools -import inspect from abc import abstractmethod from argparse import ArgumentParser, Namespace from typing import Any, List, Mapping, Optional, Sequence, Tuple, Union @@ -23,7 +22,7 @@ from pytorch_lightning.core.hooks import CheckpointHooks, DataHooks from pytorch_lightning.utilities import rank_zero_only -from pytorch_lightning.utilities.parsing import str_to_bool, str_to_bool_or_str +from pytorch_lightning.utilities.argparse import add_argparse_args, from_argparse_args, get_init_arguments_and_types class _DataModuleWrapper(type): @@ -269,58 +268,13 @@ def setup(self, stage: Optional[str] = None): pass @classmethod - def add_argparse_args(cls, parent_parser: ArgumentParser) -> ArgumentParser: - r"""Extends existing argparse by default `LightningDataModule` attributes.""" - parser = ArgumentParser(parents=[parent_parser], add_help=False) - added_args = [x.dest for x in parser._actions] - - blacklist = ["kwargs"] - depr_arg_names = blacklist + added_args - depr_arg_names = set(depr_arg_names) - - allowed_types = (str, int, float, bool) - - # TODO: get "help" from docstring :) - for arg, arg_types, arg_default in ( - at for at in cls.get_init_arguments_and_types() if at[0] not in depr_arg_names - ): - arg_types = [at for at in allowed_types if at in arg_types] - if not arg_types: - # skip argument with not supported type - continue - arg_kwargs = {} - if bool in arg_types: - arg_kwargs.update(nargs="?", const=True) - # if the only arg type is bool - if len(arg_types) == 1: - use_type = str_to_bool - # if only two args (str, bool) - elif len(arg_types) == 2 and set(arg_types) == {str, bool}: - use_type = str_to_bool_or_str - else: - # filter out the bool as we need to use more general - use_type = [at for at in arg_types if at is not bool][0] - else: - use_type = arg_types[0] - - if arg_default == inspect._empty: - arg_default = None - - parser.add_argument( - f"--{arg}", - dest=arg, - default=arg_default, - type=use_type, - help=f"autogenerated by plb.{cls.__name__}", - **arg_kwargs, - ) - - return parser + def add_argparse_args(cls, parent_parser: ArgumentParser, **kwargs) -> ArgumentParser: + """Extends existing argparse by default `LightningDataModule` attributes.""" + return add_argparse_args(cls, parent_parser, **kwargs) @classmethod def from_argparse_args(cls, args: Union[Namespace, ArgumentParser], **kwargs): - """ - Create an instance from CLI arguments. + """Create an instance from CLI arguments. Args: args: The parser or namespace to take arguments from. 
Only known arguments will be @@ -329,22 +283,11 @@ def from_argparse_args(cls, args: Union[Namespace, ArgumentParser], **kwargs): These must be valid DataModule arguments. Example:: - parser = ArgumentParser(add_help=False) parser = LightningDataModule.add_argparse_args(parser) module = LightningDataModule.from_argparse_args(args) - """ - if isinstance(args, ArgumentParser): - args = cls.parse_argparser(args) - params = vars(args) - - # we only want to pass in valid DataModule args, the rest may be user specific - valid_kwargs = inspect.signature(cls.__init__).parameters - datamodule_kwargs = dict((name, params[name]) for name in valid_kwargs if name in params) - datamodule_kwargs.update(**kwargs) - - return cls(**datamodule_kwargs) + return from_argparse_args(cls, args, **kwargs) @classmethod def get_init_arguments_and_types(cls) -> List[Tuple[str, Tuple, Any]]: @@ -354,19 +297,7 @@ def get_init_arguments_and_types(cls) -> List[Tuple[str, Tuple, Any]]: List with tuples of 3 values: (argument name, set with argument types, argument default value). """ - datamodule_default_params = inspect.signature(cls.__init__).parameters - name_type_default = [] - for arg in datamodule_default_params: - arg_type = datamodule_default_params[arg].annotation - arg_default = datamodule_default_params[arg].default - try: - arg_types = tuple(arg_type.__args__) - except AttributeError: - arg_types = (arg_type, ) - - name_type_default.append((arg, arg_types, arg_default)) - - return name_type_default + return get_init_arguments_and_types(cls) @classmethod def from_datasets( diff --git a/pytorch_lightning/trainer/properties.py b/pytorch_lightning/trainer/properties.py index 8cbd53d93f37f..b5654b148afc6 100644 --- a/pytorch_lightning/trainer/properties.py +++ b/pytorch_lightning/trainer/properties.py @@ -229,8 +229,8 @@ def match_env_arguments(cls) -> Namespace: return parse_env_variables(cls) @classmethod - def add_argparse_args(cls, parent_parser: ArgumentParser) -> ArgumentParser: - return add_argparse_args(cls, parent_parser) + def add_argparse_args(cls, parent_parser: ArgumentParser, **kwargs) -> ArgumentParser: + return add_argparse_args(cls, parent_parser, **kwargs) @property def gpus(self) -> Optional[Union[List[int], str, int]]: diff --git a/pytorch_lightning/utilities/argparse.py b/pytorch_lightning/utilities/argparse.py index daf7270f1432b..ee42ab3241ff6 100644 --- a/pytorch_lightning/utilities/argparse.py +++ b/pytorch_lightning/utilities/argparse.py @@ -13,7 +13,7 @@ # limitations under the License. import inspect import os -from argparse import ArgumentParser, Namespace +from argparse import _ArgumentGroup, ArgumentParser, Namespace from contextlib import suppress from typing import Any, Dict, List, Tuple, Union @@ -21,8 +21,7 @@ def from_argparse_args(cls, args: Union[Namespace, ArgumentParser], **kwargs): - """ - Create an instance from CLI arguments. + """Create an instance from CLI arguments. Eventually use varibles from OS environement which are defined as "PL__" Args: @@ -135,37 +134,84 @@ def get_init_arguments_and_types(cls) -> List[Tuple[str, Tuple, Any]]: return name_type_default -def add_argparse_args(cls, parent_parser: ArgumentParser) -> ArgumentParser: - r"""Extends existing argparse by default `Trainer` attributes. +def get_abbrev_qualified_cls_name(cls): + assert isinstance(cls, type), repr(cls) + if cls.__module__.startswith("pytorch_lightning."): + # Abbreviate. + return f"pl.{cls.__name__}" + else: + # Fully qualified. 
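+        # e.g. a helper class defined in module "my_project.models" is reported
+        # as "my_project.models.MyModel" (module path plus qualified name), so
+        # argument-group titles stay unambiguous for non-Lightning classes.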
+ return f"{cls.__module__}.{cls.__qualname__}" + + +def add_argparse_args( + cls, + parent_parser: ArgumentParser, + *, + use_argument_group=True, +) -> ArgumentParser: + r"""Extends existing argparse by default attributes for ``cls``. Args: cls: Lightning class parent_parser: The custom cli arguments parser, which will be extended by - the Trainer default arguments. + the class's default arguments. + use_argument_group: + By default, this is True, and uses ``add_argument_group`` to add + a new group. + If False, this will use old behavior. + + Returns: + If use_argument_group is True, returns ``parent_parser`` to keep old + workflows. If False, will return the new ArgumentParser object. Only arguments of the allowed types (str, float, int, bool) will - extend the `parent_parser`. + extend the ``parent_parser``. Examples: + + # Option 1: Default usage. >>> import argparse >>> from pytorch_lightning import Trainer >>> parser = argparse.ArgumentParser() >>> parser = Trainer.add_argparse_args(parser) >>> args = parser.parse_args([]) + + # Option 2: Disable use_argument_group (old behavior). + >>> import argparse + >>> from pytorch_lightning import Trainer + >>> parser = argparse.ArgumentParser() + >>> parser = Trainer.add_argparse_args(parser, use_argument_group=False) + >>> args = parser.parse_args([]) """ - parser = ArgumentParser( - parents=[parent_parser], - add_help=False, - ) + if isinstance(parent_parser, _ArgumentGroup): + raise RuntimeError("Please only pass an ArgumentParser instance.") + if use_argument_group: + group_name = get_abbrev_qualified_cls_name(cls) + parser = parent_parser.add_argument_group(group_name) + else: + parser = ArgumentParser( + parents=[parent_parser], + add_help=False, + ) - blacklist = ['kwargs'] - depr_arg_names = cls.get_deprecated_arg_names() + blacklist + ignore_arg_names = ['self', 'args', 'kwargs'] + if hasattr(cls, "get_deprecated_arg_names"): + ignore_arg_names += cls.get_deprecated_arg_names() allowed_types = (str, int, float, bool) - args_help = parse_args_from_docstring(cls.__init__.__doc__ or cls.__doc__) - for arg, arg_types, arg_default in (at for at in get_init_arguments_and_types(cls) if at[0] not in depr_arg_names): + # Get symbols from cls or init function. 
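+    # Some classes expose their constructor signature on the class itself (e.g.
+    # when a metaclass wraps ``__init__``), others only on ``cls.__init__``;
+    # take the first of the two candidates that yields a non-empty argument list.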
+ for symbol in (cls, cls.__init__): + args_and_types = get_init_arguments_and_types(symbol) + args_and_types = [x for x in args_and_types if x[0] not in ignore_arg_names] + if len(args_and_types) > 0: + break + + args_help = parse_args_from_docstring(cls.__init__.__doc__ or cls.__doc__ or "") + + for arg, arg_types, arg_default in args_and_types: arg_types = [at for at in allowed_types if at in arg_types] if not arg_types: # skip argument with not supported type @@ -196,6 +242,9 @@ def add_argparse_args(cls, parent_parser: ArgumentParser) -> ArgumentParser: if arg == 'track_grad_norm': use_type = float + if arg_default is inspect._empty: + arg_default = None + parser.add_argument( f'--{arg}', dest=arg, @@ -205,7 +254,10 @@ def add_argparse_args(cls, parent_parser: ArgumentParser) -> ArgumentParser: **arg_kwargs, ) - return parser + if use_argument_group: + return parent_parser + else: + return parser def parse_args_from_docstring(docstring: str) -> Dict[str, str]: diff --git a/tests/trainer/test_trainer_cli.py b/tests/trainer/test_trainer_cli.py index ac3906eee3ec0..32da5d2b2fa99 100644 --- a/tests/trainer/test_trainer_cli.py +++ b/tests/trainer/test_trainer_cli.py @@ -67,10 +67,15 @@ def test_add_argparse_args_redefined(cli_args: list): @pytest.mark.parametrize('cli_args', [['--accumulate_grad_batches=22'], ['--weights_save_path=./'], []]) -def test_add_argparse_via_argument_group(cli_args: list): - """Simple test ensuring that passing an argument group still works""" +def test_add_argparse_args(cli_args: list): + """Simple test ensuring Trainer.add_argparse_args works.""" parser = ArgumentParser(add_help=False) - parser = Trainer.add_argparse_args(parser.add_argument_group(title="pl.Trainer args")) + parser = Trainer.add_argparse_args(parser) + args = parser.parse_args(cli_args) + assert Trainer.from_argparse_args(args) + + parser = ArgumentParser(add_help=False) + parser = Trainer.add_argparse_args(parser, use_argument_group=False) args = parser.parse_args(cli_args) assert Trainer.from_argparse_args(args) diff --git a/tests/utilities/test_argparse_utils.py b/tests/utilities/test_argparse_utils.py index 63227abf831ec..b2eac514941e6 100644 --- a/tests/utilities/test_argparse_utils.py +++ b/tests/utilities/test_argparse_utils.py @@ -1,4 +1,15 @@ -from pytorch_lightning.utilities.argparse import parse_args_from_docstring +import io +from argparse import ArgumentParser +from typing import List + +import pytest + +from pytorch_lightning import Trainer +from pytorch_lightning.utilities.argparse import ( + add_argparse_args, + get_abbrev_qualified_cls_name, + parse_args_from_docstring, +) def test_parse_args_from_docstring_normal(): @@ -48,3 +59,112 @@ def test_parse_args_from_docstring_empty(): """ ) assert len(args_help.keys()) == 0 + + +def test_get_abbrev_qualified_cls_name(): + assert get_abbrev_qualified_cls_name(Trainer) == "pl.Trainer" + + class NestedClass: + pass + + assert not __name__.startswith("pytorch_lightning.") + expected_name = f"{__name__}.test_get_abbrev_qualified_cls_name..NestedClass" + assert get_abbrev_qualified_cls_name(NestedClass) == expected_name + + +class AddArgparseArgsExampleClass: + """ + Args: + my_parameter: A thing. + """ + + def __init__(self, my_parameter: int = 0): + pass + + @staticmethod + def get_deprecated_arg_names() -> List[str]: + return [] + + +class AddArgparseArgsExampleClassViaInit: + + def __init__(self, my_parameter: int = 0): + """ + Args: + my_parameter: A thing. 
+ """ + pass + + +class AddArgparseArgsExampleClassNoDoc: + + def __init__(self, my_parameter: int = 0): + pass + + +def extract_help_text(parser): + help_str_buffer = io.StringIO() + parser.print_help(file=help_str_buffer) + help_str_buffer.seek(0) + return help_str_buffer.read() + + +@pytest.mark.parametrize(["cls", "name"], [ + [AddArgparseArgsExampleClass, "AddArgparseArgsExampleClass"], + [AddArgparseArgsExampleClassViaInit, "AddArgparseArgsExampleClassViaInit"], + [AddArgparseArgsExampleClassNoDoc, "AddArgparseArgsExampleClassNoDoc"], +]) +def test_add_argparse_args(cls, name): + """ + Tests that ``add_argparse_args`` handles argument groups correctly, and + can be parsed. + """ + parser = ArgumentParser() + parser_main = parser.add_argument_group("main") + parser_main.add_argument("--main_arg", type=str, default="") + parser_old = parser # For testing. + parser = add_argparse_args(cls, parser) + assert parser is parser_old + + # Check nominal argument groups. + help_text = extract_help_text(parser) + assert "main:" in help_text + assert "--main_arg" in help_text + assert f"{name}:" in help_text + assert "--my_parameter" in help_text + if cls is not AddArgparseArgsExampleClassNoDoc: + assert "A thing" in help_text + + fake_argv = ["--main_arg=abc", "--my_parameter=2"] + args = parser.parse_args(fake_argv) + assert args.main_arg == "abc" + assert args.my_parameter == 2 + + +def test_negative_add_argparse_args(): + with pytest.raises(RuntimeError, match="Please only pass an ArgumentParser instance."): + parser = ArgumentParser() + add_argparse_args(AddArgparseArgsExampleClass, parser.add_argument_group("bad workflow")) + + +def test_add_argparse_args_no_argument_group(): + """ + Tests that ``add_argparse_args(..., use_argument_group=False)`` (old + workflow) handles argument groups correctly, and can be parsed. + """ + parser = ArgumentParser() + parser.add_argument("--main_arg", type=str, default="") + parser_old = parser # For testing. + parser = add_argparse_args(AddArgparseArgsExampleClass, parser, use_argument_group=False) + assert parser is not parser_old + + # Check arguments. 
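+    # With use_argument_group=False the options are merged straight into the new
+    # top-level parser, so no per-class group header should appear in the help.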
+ help_text = extract_help_text(parser) + assert "--main_arg" in help_text + assert "--my_parameter" in help_text + assert "AddArgparseArgsExampleClass:" not in help_text + + fake_argv = ["--main_arg=abc", "--my_parameter=2"] + args = parser.parse_args(fake_argv) + assert args.main_arg == "abc" + assert args.my_parameter == 2 From c53edce1a1bada7be97c40c8c6d6ae3caca56f4a Mon Sep 17 00:00:00 2001 From: Rohit Gupta Date: Thu, 11 Mar 2021 21:21:10 +0530 Subject: [PATCH 7/8] Disable batch transfer in DP mode (#6098) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * add exceptions and test * hook * fix * clean up * clean up * regex * regex * docs * rev * comment and docs * chlog * Apply suggestions from code review Co-authored-by: Carlos Mocholí * Apply suggestions from code review Co-authored-by: chaton * Monkey-patch device count * docs * pep * api_change Co-authored-by: Carlos Mocholí Co-authored-by: chaton --- CHANGELOG.md | 3 ++ pytorch_lightning/accelerators/gpu.py | 11 +++- pytorch_lightning/core/hooks.py | 21 ++++++-- .../trainer/connectors/data_connector.py | 31 ++++++----- tests/accelerators/test_dp.py | 53 +++++++++++++++++++ 5 files changed, 99 insertions(+), 20 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 08405cb89b392..4f721b263668f 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -184,6 +184,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Fixed error message for AMP + CPU incompatibility ([#6107](https://github.com/PyTorchLightning/pytorch-lightning/pull/6107)) +- Disabled batch transfer in DP mode ([#6093](https://github.com/PyTorchLightning/pytorch-lightning/pull/6093)) + + ## [1.2.0] - 2021-02-18 ### Added diff --git a/pytorch_lightning/accelerators/gpu.py b/pytorch_lightning/accelerators/gpu.py index dd45e592bdd7e..af9ce25f902b3 100644 --- a/pytorch_lightning/accelerators/gpu.py +++ b/pytorch_lightning/accelerators/gpu.py @@ -1,10 +1,11 @@ import logging import os -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Any import torch from pytorch_lightning.accelerators.accelerator import Accelerator +from pytorch_lightning.plugins import DataParallelPlugin from pytorch_lightning.utilities.exceptions import MisconfigurationException if TYPE_CHECKING: @@ -48,3 +49,11 @@ def set_nvidia_flags() -> None: all_gpu_ids = ",".join([str(x) for x in range(torch.cuda.device_count())]) devices = os.getenv("CUDA_VISIBLE_DEVICES", all_gpu_ids) _log.info(f"LOCAL_RANK: {os.getenv('LOCAL_RANK', 0)} - CUDA_VISIBLE_DEVICES: [{devices}]") + + def to_device(self, batch: Any) -> Any: + # no need to transfer batch to device in DP mode + # TODO: Add support to allow batch transfer to device in Lightning for DP mode. + if not isinstance(self.training_type_plugin, DataParallelPlugin): + batch = super().to_device(batch) + + return batch diff --git a/pytorch_lightning/core/hooks.py b/pytorch_lightning/core/hooks.py index 9826f9d44ac2c..1399d1b3c66ba 100644 --- a/pytorch_lightning/core/hooks.py +++ b/pytorch_lightning/core/hooks.py @@ -615,10 +615,7 @@ def transfer_batch_to_device(self, batch: Any, device: Optional[torch.device] = Note: This hook only runs on single GPU training and DDP (no data-parallel). 
-            If you need multi-GPU support for your custom batch objects, you need to define your custom
-            :class:`~torch.nn.parallel.DistributedDataParallel` or
-            :class:`~pytorch_lightning.overrides.data_parallel.LightningDistributedDataParallel` and
-            override :meth:`~pytorch_lightning.core.lightning.LightningModule.configure_ddp`.
+            Data-Parallel support will come in the near future.
 
         Args:
             batch: A batch of data that needs to be transferred to a new device.
@@ -638,6 +635,10 @@ def transfer_batch_to_device(self, batch, device):
                 batch = super().transfer_batch_to_device(data, device)
                 return batch
 
+        Raises:
+            MisconfigurationException:
+                If using data-parallel, ``Trainer(accelerator='dp')``.
+
         See Also:
             - :meth:`move_data_to_device`
             - :meth:`apply_to_collection`
@@ -649,10 +650,11 @@ def on_before_batch_transfer(self, batch, dataloader_idx):
         """
         Override to alter or apply batch augmentations to your batch before it is transferred to the device.
 
-        .. warning:: dataloader_idx always returns 0, and will be updated to support the true idx in the future.
+        .. warning:: ``dataloader_idx`` always returns 0, and will be updated to support the true index in the future.
 
         Note:
             This hook only runs on single GPU training and DDP (no data-parallel).
+            Data-Parallel support will come in the near future.
 
         Args:
             batch: A batch of data that needs to be altered or augmented.
@@ -667,6 +669,10 @@ def on_before_batch_transfer(self, batch, dataloader_idx):
                 batch['x'] = transforms(batch['x'])
                 return batch
 
+        Raises:
+            MisconfigurationException:
+                If using data-parallel, ``Trainer(accelerator='dp')``.
+
         See Also:
             - :meth:`on_after_batch_transfer`
             - :meth:`transfer_batch_to_device`
@@ -681,6 +687,7 @@ def on_after_batch_transfer(self, batch, dataloader_idx):
 
         Note:
             This hook only runs on single GPU training and DDP (no data-parallel).
+            Data-Parallel support will come in the near future.
 
         Args:
             batch: A batch of data that needs to be altered or augmented.
@@ -695,6 +702,10 @@ def on_after_batch_transfer(self, batch, dataloader_idx):
                 batch['x'] = gpu_transforms(batch['x'])
                 return batch
 
+        Raises:
+            MisconfigurationException:
+                If using data-parallel, ``Trainer(accelerator='dp')``.
+
         See Also:
             - :meth:`on_before_batch_transfer`
             - :meth:`transfer_batch_to_device`
diff --git a/pytorch_lightning/trainer/connectors/data_connector.py b/pytorch_lightning/trainer/connectors/data_connector.py
index fbe1cecdd837e..b3fc0b4eb7b29 100644
--- a/pytorch_lightning/trainer/connectors/data_connector.py
+++ b/pytorch_lightning/trainer/connectors/data_connector.py
@@ -89,6 +89,7 @@ def attach_data(self, model, train_dataloader, val_dataloaders, datamodule):
         # set up the passed in dataloaders (if needed)
         self.attach_dataloaders(model, train_dataloader, val_dataloaders)
         self.attach_datamodule(model, datamodule)
+        self._validate_data_hooks(model)
 
     def __enforce_datamodule_dataloader_override(self, train_dataloader, val_dataloaders, datamodule):
         # If you supply a datamodule you can't supply train_dataloader or val_dataloaders
@@ -97,6 +98,14 @@ def __enforce_datamodule_dataloader_override(self, train_dataloader, val_dataloa
                 'You cannot pass train_dataloader or val_dataloaders to trainer.fit if you supply a datamodule'
             )
 
+    def _validate_data_hooks(self, model):
+        # Raise a MisconfigurationException since these hooks are not supported in DP mode
+        # TODO: Remove this blocker once batch transfer to device is integrated in Lightning for DP mode.
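+        # These hooks assume the whole batch lives on a single device, which no
+        # longer holds once DP has scattered the batch across GPUs in forward.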
+ batch_transfer_hooks = ('on_before_batch_transfer', 'transfer_batch_to_device', 'on_after_batch_transfer') + for hook in batch_transfer_hooks: + if self.trainer.accelerator_connector.use_dp and is_overridden(hook, model): + raise MisconfigurationException(f'Overriding `{hook}` is not supported in DP mode.') + def attach_dataloaders( self, model, @@ -127,22 +136,16 @@ def attach_datamodule(self, model, datamodule: Optional[LightningDataModule] = N if datamodule: # Override loader hooks - if is_overridden('train_dataloader', datamodule): - model.train_dataloader = datamodule.train_dataloader - if is_overridden('val_dataloader', datamodule): - model.val_dataloader = datamodule.val_dataloader - if is_overridden('test_dataloader', datamodule): - model.test_dataloader = datamodule.test_dataloader - if is_overridden('predict_dataloader', datamodule): - model.predict_dataloader = datamodule.predict_dataloader + dl_methods = ('train_dataloader', 'val_dataloader', 'test_dataloader', 'predict_dataloader') + for method in dl_methods: + if is_overridden(method, datamodule): + setattr(model, method, getattr(datamodule, method)) # Override data transfer hooks if dataset-specific to_device logic has been defined in datamodule - if is_overridden('on_before_batch_transfer', datamodule): - model.on_before_batch_transfer = datamodule.on_before_batch_transfer - if is_overridden('transfer_batch_to_device', datamodule): - model.transfer_batch_to_device = datamodule.transfer_batch_to_device - if is_overridden('on_after_batch_transfer', datamodule): - model.on_after_batch_transfer = datamodule.on_after_batch_transfer + batch_transfer_hooks = ('on_before_batch_transfer', 'transfer_batch_to_device', 'on_after_batch_transfer') + for hook in batch_transfer_hooks: + if is_overridden(hook, datamodule): + setattr(model, hook, getattr(datamodule, hook)) self.trainer.datamodule = datamodule datamodule.trainer = self.trainer diff --git a/tests/accelerators/test_dp.py b/tests/accelerators/test_dp.py index 6b84e1a70ae58..ab46aba3119fb 100644 --- a/tests/accelerators/test_dp.py +++ b/tests/accelerators/test_dp.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import pytest import torch import torch.nn.functional as F from torch.utils.data import DataLoader @@ -18,8 +19,10 @@ import pytorch_lightning as pl import tests.helpers.pipelines as tpipes import tests.helpers.utils as tutils +from pytorch_lightning import Trainer from pytorch_lightning.callbacks import EarlyStopping from pytorch_lightning.core import memory +from pytorch_lightning.utilities.exceptions import MisconfigurationException from tests.helpers import BoringModel, RandomDataset from tests.helpers.datamodules import ClassifDataModule from tests.helpers.runif import RunIf @@ -132,6 +135,56 @@ def training_epoch_end(self, outputs): assert outputs[0]["reduce_float"].item() == 0.5 # mean([0., 1.]) = 0.5 +def test_dp_raise_exception_with_batch_transfer_hooks(tmpdir, monkeypatch): + """ + Test that an exception is raised when overriding batch_transfer_hooks in DP model. 
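+    ``torch.cuda.device_count`` is monkey-patched to report two devices so the check runs without real GPUs.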
+ """ + monkeypatch.setattr("torch.cuda.device_count", lambda: 2) + + class CustomModel(BoringModel): + + def transfer_batch_to_device(self, batch, device): + batch = batch.to(device) + return batch + + trainer_options = dict( + default_root_dir=tmpdir, + max_steps=7, + gpus=[0, 1], + accelerator='dp', + ) + + trainer = Trainer(**trainer_options) + model = CustomModel() + + with pytest.raises(MisconfigurationException, match=r'Overriding `transfer_batch_to_device` is not .* in DP'): + trainer.fit(model) + + class CustomModel(BoringModel): + + def on_before_batch_transfer(self, batch, dataloader_idx): + batch += 1 + return batch + + trainer = Trainer(**trainer_options) + model = CustomModel() + + with pytest.raises(MisconfigurationException, match=r'Overriding `on_before_batch_transfer` is not .* in DP'): + trainer.fit(model) + + class CustomModel(BoringModel): + + def on_after_batch_transfer(self, batch, dataloader_idx): + batch += 1 + return batch + + trainer = Trainer(**trainer_options) + model = CustomModel() + + with pytest.raises(MisconfigurationException, match=r'Overriding `on_after_batch_transfer` is not .* in DP'): + trainer.fit(model) + + @RunIf(min_gpus=2) def test_dp_training_step_dict(tmpdir): """ This test verifies that dp properly reduces dictionaries """ From 62d4304ca4d42f5a681321cda1e4063e05edd096 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Thu, 11 Mar 2021 18:49:30 +0100 Subject: [PATCH 8/8] remove obsolete todo in pl_examples (#6475) --- pl_examples/basic_examples/simple_image_classifier.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/pl_examples/basic_examples/simple_image_classifier.py b/pl_examples/basic_examples/simple_image_classifier.py index dfa5869779ff3..3f7079d665ea8 100644 --- a/pl_examples/basic_examples/simple_image_classifier.py +++ b/pl_examples/basic_examples/simple_image_classifier.py @@ -105,8 +105,6 @@ def cli_main(): # ------------ # testing # ------------ - # todo: without passing model it fails for missing best weights - # MisconfigurationException, 'ckpt_path is "best", but ModelCheckpoint is not configured to save the best model.' result = trainer.test(model, datamodule=dm) pprint(result)
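Taken together, the series adds two user-facing pieces of API: `Trainer.validate(...)` (PATCH 2/8) and `ModelCheckpoint(auto_insert_metric_name=...)` (PATCH 3/8). A minimal sketch of how they combine, assuming a hypothetical `MyModel` LightningModule that logs `val_loss` during validation (illustrative only, not part of any patch above):

    from pytorch_lightning import Trainer
    from pytorch_lightning.callbacks import ModelCheckpoint

    model = MyModel()  # hypothetical module that logs `val_loss`
    checkpoint_callback = ModelCheckpoint(
        monitor='val_loss',
        # format the metric yourself to keep '=' and '/' out of the file name,
        # e.g. saves 'sample-epoch02-val_loss0.32.ckpt'
        filename='sample-epoch{epoch:02d}-val_loss{val_loss:.2f}',
        auto_insert_metric_name=False,
    )
    trainer = Trainer(max_epochs=2, callbacks=[checkpoint_callback])
    trainer.fit(model)

    # run one validation epoch on the best saved checkpoint
    trainer.validate(ckpt_path='best')
    assert trainer.validated_ckpt_path == checkpoint_callback.best_model_path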