diff --git a/CHANGELOG.md b/CHANGELOG.md index 8dfc9a1c021ac..d8cafbfe8a22e 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -18,6 +18,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). ### Removed +- Removed support for passing a bool value to `profiler` argument of Trainer ([#6164](https://github.com/PyTorchLightning/pytorch-lightning/pull/6164)) + ### Fixed diff --git a/docs/source/starter/new-project.rst b/docs/source/starter/new-project.rst index 0a07e9903b348..593d995a34d5b 100644 --- a/docs/source/starter/new-project.rst +++ b/docs/source/starter/new-project.rst @@ -737,7 +737,7 @@ Lightning has many tools for debugging. Here is an example of just a few of them .. testcode:: # Profile your code to find speed/memory bottlenecks - Trainer(profiler=True) + Trainer(profiler="simple") --------------- diff --git a/pytorch_lightning/trainer/connectors/profiler_connector.py b/pytorch_lightning/trainer/connectors/profiler_connector.py index 225c6b9c981ef..98d65c1285ff7 100644 --- a/pytorch_lightning/trainer/connectors/profiler_connector.py +++ b/pytorch_lightning/trainer/connectors/profiler_connector.py @@ -21,10 +21,13 @@ PyTorchProfiler, SimpleProfiler, ) -from pytorch_lightning.utilities import rank_zero_warn from pytorch_lightning.utilities.exceptions import MisconfigurationException -PROFILERS = {"simple": SimpleProfiler, "advanced": AdvancedProfiler, "pytorch": PyTorchProfiler} +PROFILERS = { + "simple": SimpleProfiler, + "advanced": AdvancedProfiler, + "pytorch": PyTorchProfiler, +} class ProfilerConnector: @@ -32,24 +35,15 @@ class ProfilerConnector: def __init__(self, trainer): self.trainer = trainer - def on_trainer_init(self, profiler: Union[BaseProfiler, bool, str]): + def on_trainer_init(self, profiler: Union[BaseProfiler, str]): - if profiler and not isinstance(profiler, (bool, str, BaseProfiler)): - # TODO: Update exception on removal of bool + if profiler and not isinstance(profiler, (str, BaseProfiler)): raise MisconfigurationException( - "Only None, bool, str and subclasses of `BaseProfiler`" + "Only None, str and subclasses of `BaseProfiler`" " are valid values for `Trainer`'s `profiler` parameter." f" Received {profiler} which is of type {type(profiler)}." ) - - if isinstance(profiler, bool): - rank_zero_warn( - "Passing a bool value as a `profiler` argument to `Trainer` is deprecated" - " and will be removed in v1.3. Use str ('simple' or 'advanced') instead.", DeprecationWarning - ) - if profiler: - profiler = SimpleProfiler() - elif isinstance(profiler, str): + if isinstance(profiler, str): if profiler.lower() in PROFILERS: profiler_class = PROFILERS[profiler.lower()] profiler = profiler_class() diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index cf3bfd7a3e5a3..19cf5a02a6a93 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -122,7 +122,7 @@ def __init__( num_sanity_val_steps: int = 2, truncated_bptt_steps: Optional[int] = None, resume_from_checkpoint: Optional[Union[Path, str]] = None, - profiler: Optional[Union[BaseProfiler, bool, str]] = None, + profiler: Optional[Union[BaseProfiler, str]] = None, benchmark: bool = False, deterministic: bool = False, reload_dataloaders_every_epoch: bool = False, @@ -177,7 +177,7 @@ def __init__( checkpoint_callback: If ``True``, enable checkpointing. It will configure a default ModelCheckpoint callback if there is no user-defined ModelCheckpoint in - :paramref:`~pytorch_lightning.trainer.trainer.Trainer.callbacks`. Default: ``True``. + :paramref:`~pytorch_lightning.trainer.trainer.Trainer.callbacks`. .. warning:: Passing a ModelCheckpoint instance to this argument is deprecated since v1.1 and will be unsupported from v1.3. Use `callbacks` argument instead. @@ -226,10 +226,9 @@ def __init__( Ignored when a custom progress bar is passed to :paramref:`~Trainer.callbacks`. Default: None, means a suitable value will be chosen based on the environment (terminal, Google COLAB, etc.). - profiler: To profile individual steps during training and assist in identifying bottlenecks. Passing bool - value is deprecated in v1.1 and will be removed in v1.3. + profiler: To profile individual steps during training and assist in identifying bottlenecks. - overfit_batches: Overfit a percent of training data (float) or a set number of batches (int). Default: 0.0 + overfit_batches: Overfit a percent of training data (float) or a set number of batches (int). plugins: Plugins allow modification of core behavior like ddp and amp, and enable custom lightning plugins. @@ -250,7 +249,7 @@ def __init__( num_processes: number of processes for distributed training with distributed_backend="ddp_cpu" num_sanity_val_steps: Sanity check runs n validation batches before starting the training routine. - Set it to `-1` to run all batches in all validation dataloaders. Default: 2 + Set it to `-1` to run all batches in all validation dataloaders. reload_dataloaders_every_epoch: Set to True to reload dataloaders every epoch. diff --git a/tests/accelerators/test_cpu.py b/tests/accelerators/test_cpu.py index 5c97527be048b..81a5132e47356 100644 --- a/tests/accelerators/test_cpu.py +++ b/tests/accelerators/test_cpu.py @@ -14,8 +14,7 @@ def test_unsupported_precision_plugins(): trainer = Mock() model = Mock() accelerator = CPUAccelerator( - training_type_plugin=SingleDevicePlugin(torch.device("cpu")), - precision_plugin=MixedPrecisionPlugin() + training_type_plugin=SingleDevicePlugin(torch.device("cpu")), precision_plugin=MixedPrecisionPlugin() ) with pytest.raises(MisconfigurationException, match=r"amp \+ cpu is not supported."): accelerator.setup(trainer=trainer, model=model) diff --git a/tests/base/model_optimizers.py b/tests/base/model_optimizers.py index 4f607d45062a8..39e67748f0a90 100644 --- a/tests/base/model_optimizers.py +++ b/tests/base/model_optimizers.py @@ -45,14 +45,8 @@ def configure_optimizers__multiple_optimizers_frequency(self): optimizer1 = optim.Adam(self.parameters(), lr=self.learning_rate) optimizer2 = optim.Adam(self.parameters(), lr=self.learning_rate) return [ - { - 'optimizer': optimizer1, - 'frequency': 1 - }, - { - 'optimizer': optimizer2, - 'frequency': 5 - }, + dict(optimizer=optimizer1, frequency=1), + dict(optimizer=optimizer2, frequency=5), ] def configure_optimizers__single_scheduler(self): diff --git a/tests/base/model_test_steps.py b/tests/base/model_test_steps.py index e28ecd837cf9a..0b81143ee57f2 100644 --- a/tests/base/model_test_steps.py +++ b/tests/base/model_test_steps.py @@ -54,9 +54,7 @@ def test_step(self, batch, batch_idx, *args, **kwargs): output = OrderedDict({ 'test_loss': loss_test, 'test_acc': test_acc, - 'test_dic': { - 'test_loss_a': loss_test - }, + 'test_dic': dict(test_loss_a=loss_test), }) return output @@ -90,9 +88,7 @@ def test_step__multiple_dataloaders(self, batch, batch_idx, dataloader_idx, **kw output = OrderedDict({ 'test_loss': loss_test, 'test_acc': test_acc, - 'test_dic': { - 'test_loss_a': loss_test - }, + 'test_dic': dict(test_loss_a=loss_test), }) return output if batch_idx % 5 == 0: diff --git a/tests/base/model_train_steps.py b/tests/base/model_train_steps.py index 217395e7867fc..2a4161a23e053 100644 --- a/tests/base/model_train_steps.py +++ b/tests/base/model_train_steps.py @@ -44,12 +44,8 @@ def training_step(self, batch, batch_idx, optimizer_idx=None): output = OrderedDict({ 'loss': loss_train, - 'progress_bar': { - 'some_val': log_train * log_train - }, - 'log': { - 'train_some_val': log_train * log_train - }, + 'progress_bar': dict(some_val=log_train * log_train), + 'log': dict(train_some_val=log_train * log_train), }) return output diff --git a/tests/base/model_valid_steps.py b/tests/base/model_valid_steps.py index 11863e0af3d62..554d76253e4db 100644 --- a/tests/base/model_valid_steps.py +++ b/tests/base/model_valid_steps.py @@ -43,9 +43,7 @@ def validation_step(self, batch, batch_idx, *args, **kwargs): output = OrderedDict({ 'val_loss': loss_val, 'val_acc': val_acc, - 'test_dic': { - 'val_loss_a': loss_val - }, + 'test_dic': dict(val_loss_a=loss_val), }) return output diff --git a/tests/deprecated_api/test_remove_1-3.py b/tests/deprecated_api/test_remove_1-3.py index 99cb280e96797..ad2aa18aecc95 100644 --- a/tests/deprecated_api/test_remove_1-3.py +++ b/tests/deprecated_api/test_remove_1-3.py @@ -12,15 +12,12 @@ # See the License for the specific language governing permissions and # limitations under the License. """Test deprecated functionality which will be removed in vX.Y.Z""" -from argparse import ArgumentParser -from unittest import mock import pytest import torch from pytorch_lightning import LightningModule, Trainer from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint -from pytorch_lightning.profiler.profilers import PassThroughProfiler, SimpleProfiler def test_v1_3_0_deprecated_arguments(tmpdir): @@ -111,38 +108,6 @@ def test_v1_3_0_deprecated_metrics(): ) -# TODO: remove bool from Trainer.profiler param in v1.3.0, update profiler_connector.py -@pytest.mark.parametrize(['profiler', 'expected'], [ - (True, SimpleProfiler), - (False, PassThroughProfiler), -]) -def test_trainer_profiler_remove_in_v1_3_0(profiler, expected): - # remove bool from Trainer.profiler param in v1.3.0, update profiler_connector.py - with pytest.deprecated_call(match='will be removed in v1.3'): - trainer = Trainer(profiler=profiler) - assert isinstance(trainer.profiler, expected) - - -@pytest.mark.parametrize( - ['cli_args', 'expected_parsed_arg', 'expected_profiler'], - [ - ('--profiler', True, SimpleProfiler), - ('--profiler True', True, SimpleProfiler), - ('--profiler False', False, PassThroughProfiler), - ], -) -def test_v1_3_0_trainer_cli_profiler(cli_args, expected_parsed_arg, expected_profiler): - cli_args = cli_args.split(' ') - with mock.patch("argparse._sys.argv", ["any.py"] + cli_args): - parser = ArgumentParser(add_help=False) - parser = Trainer.add_argparse_args(parent_parser=parser) - args = Trainer.parse_argparser(parser) - - assert getattr(args, "profiler") == expected_parsed_arg - trainer = Trainer.from_argparse_args(args) - assert isinstance(trainer.profiler, expected_profiler) - - def test_trainer_enable_pl_optimizer(tmpdir): with pytest.deprecated_call(match='will be removed in v1.3'): Trainer(enable_pl_optimizer=True) diff --git a/tests/trainer/logging_/test_train_loop_logging_1_0.py b/tests/trainer/logging_/test_train_loop_logging_1_0.py index d957f56738cbe..1082b4f4c14c3 100644 --- a/tests/trainer/logging_/test_train_loop_logging_1_0.py +++ b/tests/trainer/logging_/test_train_loop_logging_1_0.py @@ -178,9 +178,7 @@ def backward(self, loss, optimizer, optimizer_idx): assert logged_metrics == expected_logged_metrics pbar_metrics = set(trainer.progress_bar_metrics.keys()) - expected_pbar_metrics = { - 'b', - } + expected_pbar_metrics = {'b'} assert pbar_metrics == expected_pbar_metrics callback_metrics = set(trainer.callback_metrics.keys()) diff --git a/tests/trainer/test_trainer.py b/tests/trainer/test_trainer.py index 167930425dab1..34305e434575a 100644 --- a/tests/trainer/test_trainer.py +++ b/tests/trainer/test_trainer.py @@ -1475,16 +1475,14 @@ def test_trainer_profiler_incorrect_str_arg(): @pytest.mark.parametrize('profiler', ( 42, [42], - { - "a": 42 - }, + dict(a=42), torch.tensor(42), Trainer(), )) def test_trainer_profiler_incorrect_arg_type(profiler): with pytest.raises( MisconfigurationException, - match=r"Only None, bool, str and subclasses of `BaseProfiler`" + match="Only None, str and subclasses of `BaseProfiler`" r" are valid values for `Trainer`'s `profiler` parameter. *" ): Trainer(profiler=profiler)