Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

prune deprecated profiler as bool #6164

Merged
merged 8 commits into from
Feb 24, 2021
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Removed

- Removed Trainer argument `profiler` as bool ([#6164](https://github.com/PyTorchLightning/pytorch-lightning/pull/6164))
Borda marked this conversation as resolved.
Show resolved Hide resolved


### Fixed

Expand Down
14 changes: 3 additions & 11 deletions pytorch_lightning/trainer/connectors/profiler_connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,24 +32,16 @@ class ProfilerConnector:
def __init__(self, trainer):
self.trainer = trainer

def on_trainer_init(self, profiler: Union[BaseProfiler, bool, str]):
def on_trainer_init(self, profiler: Union[BaseProfiler, str]):

if profiler and not isinstance(profiler, (bool, str, BaseProfiler)):
if profiler and not isinstance(profiler, (str, BaseProfiler)):
# TODO: Update exception on removal of bool
raise MisconfigurationException(
"Only None, bool, str and subclasses of `BaseProfiler`"
carmocca marked this conversation as resolved.
Show resolved Hide resolved
" are valid values for `Trainer`'s `profiler` parameter."
f" Received {profiler} which is of type {type(profiler)}."
)

if isinstance(profiler, bool):
rank_zero_warn(
"Passing a bool value as a `profiler` argument to `Trainer` is deprecated"
" and will be removed in v1.3. Use str ('simple' or 'advanced') instead.", DeprecationWarning
)
if profiler:
profiler = SimpleProfiler()
elif isinstance(profiler, str):
if isinstance(profiler, str):
if profiler.lower() in PROFILERS:
profiler_class = PROFILERS[profiler.lower()]
profiler = profiler_class()
Expand Down
11 changes: 5 additions & 6 deletions pytorch_lightning/trainer/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,7 @@ def __init__(
num_sanity_val_steps: int = 2,
truncated_bptt_steps: Optional[int] = None,
resume_from_checkpoint: Optional[Union[Path, str]] = None,
profiler: Optional[Union[BaseProfiler, bool, str]] = None,
profiler: Optional[Union[BaseProfiler, str]] = None,
benchmark: bool = False,
deterministic: bool = False,
reload_dataloaders_every_epoch: bool = False,
Expand Down Expand Up @@ -177,7 +177,7 @@ def __init__(

checkpoint_callback: If ``True``, enable checkpointing.
It will configure a default ModelCheckpoint callback if there is no user-defined ModelCheckpoint in
:paramref:`~pytorch_lightning.trainer.trainer.Trainer.callbacks`. Default: ``True``.
:paramref:`~pytorch_lightning.trainer.trainer.Trainer.callbacks`.

.. warning:: Passing a ModelCheckpoint instance to this argument is deprecated since
v1.1 and will be unsupported from v1.3. Use `callbacks` argument instead.
Expand Down Expand Up @@ -226,10 +226,9 @@ def __init__(
Ignored when a custom progress bar is passed to :paramref:`~Trainer.callbacks`. Default: None, means
a suitable value will be chosen based on the environment (terminal, Google COLAB, etc.).

profiler: To profile individual steps during training and assist in identifying bottlenecks. Passing bool
value is deprecated in v1.1 and will be removed in v1.3.
profiler: To profile individual steps during training and assist in identifying bottlenecks.

overfit_batches: Overfit a percent of training data (float) or a set number of batches (int). Default: 0.0
overfit_batches: Overfit a percent of training data (float) or a set number of batches (int).

plugins: Plugins allow modification of core behavior like ddp and amp, and enable custom lightning plugins.

Expand All @@ -250,7 +249,7 @@ def __init__(
num_processes: number of processes for distributed training with distributed_backend="ddp_cpu"

num_sanity_val_steps: Sanity check runs n validation batches before starting the training routine.
Set it to `-1` to run all batches in all validation dataloaders. Default: 2
Borda marked this conversation as resolved.
Show resolved Hide resolved
Set it to `-1` to run all batches in all validation dataloaders.

reload_dataloaders_every_epoch: Set to True to reload dataloaders every epoch.

Expand Down
33 changes: 0 additions & 33 deletions tests/deprecated_api/test_remove_1-3.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@

from pytorch_lightning import LightningModule, Trainer
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
from pytorch_lightning.profiler.profilers import PassThroughProfiler, SimpleProfiler


def test_v1_3_0_deprecated_arguments(tmpdir):
Expand Down Expand Up @@ -111,38 +110,6 @@ def test_v1_3_0_deprecated_metrics():
)


# TODO: remove bool from Trainer.profiler param in v1.3.0, update profiler_connector.py
@pytest.mark.parametrize(['profiler', 'expected'], [
(True, SimpleProfiler),
(False, PassThroughProfiler),
])
def test_trainer_profiler_remove_in_v1_3_0(profiler, expected):
# remove bool from Trainer.profiler param in v1.3.0, update profiler_connector.py
with pytest.deprecated_call(match='will be removed in v1.3'):
trainer = Trainer(profiler=profiler)
assert isinstance(trainer.profiler, expected)


@pytest.mark.parametrize(
['cli_args', 'expected_parsed_arg', 'expected_profiler'],
[
('--profiler', True, SimpleProfiler),
('--profiler True', True, SimpleProfiler),
('--profiler False', False, PassThroughProfiler),
],
)
def test_v1_3_0_trainer_cli_profiler(cli_args, expected_parsed_arg, expected_profiler):
cli_args = cli_args.split(' ')
with mock.patch("argparse._sys.argv", ["any.py"] + cli_args):
parser = ArgumentParser(add_help=False)
parser = Trainer.add_argparse_args(parent_parser=parser)
args = Trainer.parse_argparser(parser)

assert getattr(args, "profiler") == expected_parsed_arg
trainer = Trainer.from_argparse_args(args)
assert isinstance(trainer.profiler, expected_profiler)


def test_trainer_enable_pl_optimizer(tmpdir):
with pytest.deprecated_call(match='will be removed in v1.3'):
Trainer(enable_pl_optimizer=True)