Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

prune deprecated profiler as bool #6164

Merged
merged 8 commits into from
Feb 24, 2021
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Removed

- Removed Trainer argument `profiler` as bool ([#6164](https://github.com/PyTorchLightning/pytorch-lightning/pull/6164))
Borda marked this conversation as resolved.
Show resolved Hide resolved


### Fixed

Expand Down
2 changes: 1 addition & 1 deletion docs/source/starter/new-project.rst
Original file line number Diff line number Diff line change
Expand Up @@ -737,7 +737,7 @@ Lightning has many tools for debugging. Here is an example of just a few of them
.. testcode::

# Profile your code to find speed/memory bottlenecks
Trainer(profiler=True)
Trainer(profiler="simple")

---------------

Expand Down
24 changes: 9 additions & 15 deletions pytorch_lightning/trainer/connectors/profiler_connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,35 +21,29 @@
PyTorchProfiler,
SimpleProfiler,
)
from pytorch_lightning.utilities import rank_zero_warn
from pytorch_lightning.utilities.exceptions import MisconfigurationException

PROFILERS = {"simple": SimpleProfiler, "advanced": AdvancedProfiler, "pytorch": PyTorchProfiler}
PROFILERS = {
"simple": SimpleProfiler,
"advanced": AdvancedProfiler,
"pytorch": PyTorchProfiler,
}


class ProfilerConnector:

def __init__(self, trainer):
self.trainer = trainer

def on_trainer_init(self, profiler: Union[BaseProfiler, bool, str]):
def on_trainer_init(self, profiler: Union[BaseProfiler, str]):

if profiler and not isinstance(profiler, (bool, str, BaseProfiler)):
# TODO: Update exception on removal of bool
if profiler and not isinstance(profiler, (str, BaseProfiler)):
raise MisconfigurationException(
"Only None, bool, str and subclasses of `BaseProfiler`"
"Only None, str and subclasses of `BaseProfiler`"
" are valid values for `Trainer`'s `profiler` parameter."
f" Received {profiler} which is of type {type(profiler)}."
)

if isinstance(profiler, bool):
rank_zero_warn(
"Passing a bool value as a `profiler` argument to `Trainer` is deprecated"
" and will be removed in v1.3. Use str ('simple' or 'advanced') instead.", DeprecationWarning
)
if profiler:
profiler = SimpleProfiler()
elif isinstance(profiler, str):
if isinstance(profiler, str):
if profiler.lower() in PROFILERS:
profiler_class = PROFILERS[profiler.lower()]
profiler = profiler_class()
Expand Down
11 changes: 5 additions & 6 deletions pytorch_lightning/trainer/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,7 @@ def __init__(
num_sanity_val_steps: int = 2,
truncated_bptt_steps: Optional[int] = None,
resume_from_checkpoint: Optional[Union[Path, str]] = None,
profiler: Optional[Union[BaseProfiler, bool, str]] = None,
profiler: Optional[Union[BaseProfiler, str]] = None,
benchmark: bool = False,
deterministic: bool = False,
reload_dataloaders_every_epoch: bool = False,
Expand Down Expand Up @@ -177,7 +177,7 @@ def __init__(

checkpoint_callback: If ``True``, enable checkpointing.
It will configure a default ModelCheckpoint callback if there is no user-defined ModelCheckpoint in
:paramref:`~pytorch_lightning.trainer.trainer.Trainer.callbacks`. Default: ``True``.
:paramref:`~pytorch_lightning.trainer.trainer.Trainer.callbacks`.

.. warning:: Passing a ModelCheckpoint instance to this argument is deprecated since
v1.1 and will be unsupported from v1.3. Use `callbacks` argument instead.
Expand Down Expand Up @@ -226,10 +226,9 @@ def __init__(
Ignored when a custom progress bar is passed to :paramref:`~Trainer.callbacks`. Default: None, means
a suitable value will be chosen based on the environment (terminal, Google COLAB, etc.).

profiler: To profile individual steps during training and assist in identifying bottlenecks. Passing bool
value is deprecated in v1.1 and will be removed in v1.3.
profiler: To profile individual steps during training and assist in identifying bottlenecks.

overfit_batches: Overfit a percent of training data (float) or a set number of batches (int). Default: 0.0
overfit_batches: Overfit a percent of training data (float) or a set number of batches (int).

plugins: Plugins allow modification of core behavior like ddp and amp, and enable custom lightning plugins.

Expand All @@ -250,7 +249,7 @@ def __init__(
num_processes: number of processes for distributed training with distributed_backend="ddp_cpu"

num_sanity_val_steps: Sanity check runs n validation batches before starting the training routine.
Set it to `-1` to run all batches in all validation dataloaders. Default: 2
Borda marked this conversation as resolved.
Show resolved Hide resolved
Set it to `-1` to run all batches in all validation dataloaders.

reload_dataloaders_every_epoch: Set to True to reload dataloaders every epoch.

Expand Down
3 changes: 1 addition & 2 deletions tests/accelerators/test_cpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,7 @@ def test_unsupported_precision_plugins():
trainer = Mock()
model = Mock()
accelerator = CPUAccelerator(
training_type_plugin=SingleDevicePlugin(torch.device("cpu")),
precision_plugin=MixedPrecisionPlugin()
training_type_plugin=SingleDevicePlugin(torch.device("cpu")), precision_plugin=MixedPrecisionPlugin()
)
with pytest.raises(MisconfigurationException, match=r"amp \+ cpu is not supported."):
accelerator.setup(trainer=trainer, model=model)
10 changes: 2 additions & 8 deletions tests/base/model_optimizers.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,14 +45,8 @@ def configure_optimizers__multiple_optimizers_frequency(self):
optimizer1 = optim.Adam(self.parameters(), lr=self.learning_rate)
optimizer2 = optim.Adam(self.parameters(), lr=self.learning_rate)
return [
{
'optimizer': optimizer1,
'frequency': 1
},
{
'optimizer': optimizer2,
'frequency': 5
},
dict(optimizer=optimizer1, frequency=1),
dict(optimizer=optimizer2, frequency=5),
]

def configure_optimizers__single_scheduler(self):
Expand Down
8 changes: 2 additions & 6 deletions tests/base/model_test_steps.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,9 +54,7 @@ def test_step(self, batch, batch_idx, *args, **kwargs):
output = OrderedDict({
'test_loss': loss_test,
'test_acc': test_acc,
'test_dic': {
'test_loss_a': loss_test
},
'test_dic': dict(test_loss_a=loss_test),
})
return output

Expand Down Expand Up @@ -90,9 +88,7 @@ def test_step__multiple_dataloaders(self, batch, batch_idx, dataloader_idx, **kw
output = OrderedDict({
'test_loss': loss_test,
'test_acc': test_acc,
'test_dic': {
'test_loss_a': loss_test
},
'test_dic': dict(test_loss_a=loss_test),
})
return output
if batch_idx % 5 == 0:
Expand Down
8 changes: 2 additions & 6 deletions tests/base/model_train_steps.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,12 +44,8 @@ def training_step(self, batch, batch_idx, optimizer_idx=None):

output = OrderedDict({
'loss': loss_train,
'progress_bar': {
'some_val': log_train * log_train
},
'log': {
'train_some_val': log_train * log_train
},
'progress_bar': dict(some_val=log_train * log_train),
'log': dict(train_some_val=log_train * log_train),
})
return output

Expand Down
4 changes: 1 addition & 3 deletions tests/base/model_valid_steps.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,9 +43,7 @@ def validation_step(self, batch, batch_idx, *args, **kwargs):
output = OrderedDict({
'val_loss': loss_val,
'val_acc': val_acc,
'test_dic': {
'val_loss_a': loss_val
},
'test_dic': dict(val_loss_a=loss_val),
})
return output

Expand Down
35 changes: 0 additions & 35 deletions tests/deprecated_api/test_remove_1-3.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,15 +12,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Test deprecated functionality which will be removed in vX.Y.Z"""
from argparse import ArgumentParser
from unittest import mock

import pytest
import torch

from pytorch_lightning import LightningModule, Trainer
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
from pytorch_lightning.profiler.profilers import PassThroughProfiler, SimpleProfiler


def test_v1_3_0_deprecated_arguments(tmpdir):
Expand Down Expand Up @@ -111,38 +108,6 @@ def test_v1_3_0_deprecated_metrics():
)


# TODO: remove bool from Trainer.profiler param in v1.3.0, update profiler_connector.py
@pytest.mark.parametrize(['profiler', 'expected'], [
(True, SimpleProfiler),
(False, PassThroughProfiler),
])
def test_trainer_profiler_remove_in_v1_3_0(profiler, expected):
# remove bool from Trainer.profiler param in v1.3.0, update profiler_connector.py
with pytest.deprecated_call(match='will be removed in v1.3'):
trainer = Trainer(profiler=profiler)
assert isinstance(trainer.profiler, expected)


@pytest.mark.parametrize(
['cli_args', 'expected_parsed_arg', 'expected_profiler'],
[
('--profiler', True, SimpleProfiler),
('--profiler True', True, SimpleProfiler),
('--profiler False', False, PassThroughProfiler),
],
)
def test_v1_3_0_trainer_cli_profiler(cli_args, expected_parsed_arg, expected_profiler):
cli_args = cli_args.split(' ')
with mock.patch("argparse._sys.argv", ["any.py"] + cli_args):
parser = ArgumentParser(add_help=False)
parser = Trainer.add_argparse_args(parent_parser=parser)
args = Trainer.parse_argparser(parser)

assert getattr(args, "profiler") == expected_parsed_arg
trainer = Trainer.from_argparse_args(args)
assert isinstance(trainer.profiler, expected_profiler)


def test_trainer_enable_pl_optimizer(tmpdir):
with pytest.deprecated_call(match='will be removed in v1.3'):
Trainer(enable_pl_optimizer=True)
4 changes: 1 addition & 3 deletions tests/trainer/logging_/test_train_loop_logging_1_0.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,9 +178,7 @@ def backward(self, loss, optimizer, optimizer_idx):
assert logged_metrics == expected_logged_metrics

pbar_metrics = set(trainer.progress_bar_metrics.keys())
expected_pbar_metrics = {
'b',
}
expected_pbar_metrics = {'b'}
assert pbar_metrics == expected_pbar_metrics

callback_metrics = set(trainer.callback_metrics.keys())
Expand Down
6 changes: 2 additions & 4 deletions tests/trainer/test_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -1475,16 +1475,14 @@ def test_trainer_profiler_incorrect_str_arg():
@pytest.mark.parametrize('profiler', (
42,
[42],
{
"a": 42
},
dict(a=42),
torch.tensor(42),
Trainer(),
))
def test_trainer_profiler_incorrect_arg_type(profiler):
with pytest.raises(
MisconfigurationException,
match=r"Only None, bool, str and subclasses of `BaseProfiler`"
match="Only None, str and subclasses of `BaseProfiler`"
r" are valid values for `Trainer`'s `profiler` parameter. *"
):
Trainer(profiler=profiler)
Expand Down