Skip to content

Commit

Permalink
prune deprecated profiler as bool (Lightning-AI#6164)
Browse files Browse the repository at this point in the history
* prune profiler

* chlog
  • Loading branch information
Borda authored and ananthsub committed Feb 24, 2021
1 parent b42dfb5 commit 84061c2
Show file tree
Hide file tree
Showing 12 changed files with 28 additions and 89 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Removed

- Removed support for passing a bool value to `profiler` argument of Trainer ([#6164](https://github.com/PyTorchLightning/pytorch-lightning/pull/6164))


### Fixed

Expand Down
2 changes: 1 addition & 1 deletion docs/source/starter/new-project.rst
Original file line number Diff line number Diff line change
Expand Up @@ -737,7 +737,7 @@ Lightning has many tools for debugging. Here is an example of just a few of them
.. testcode::

# Profile your code to find speed/memory bottlenecks
Trainer(profiler=True)
Trainer(profiler="simple")

---------------

Expand Down
24 changes: 9 additions & 15 deletions pytorch_lightning/trainer/connectors/profiler_connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,35 +21,29 @@
PyTorchProfiler,
SimpleProfiler,
)
from pytorch_lightning.utilities import rank_zero_warn
from pytorch_lightning.utilities.exceptions import MisconfigurationException

PROFILERS = {"simple": SimpleProfiler, "advanced": AdvancedProfiler, "pytorch": PyTorchProfiler}
PROFILERS = {
"simple": SimpleProfiler,
"advanced": AdvancedProfiler,
"pytorch": PyTorchProfiler,
}


class ProfilerConnector:

def __init__(self, trainer):
self.trainer = trainer

def on_trainer_init(self, profiler: Union[BaseProfiler, bool, str]):
def on_trainer_init(self, profiler: Union[BaseProfiler, str]):

if profiler and not isinstance(profiler, (bool, str, BaseProfiler)):
# TODO: Update exception on removal of bool
if profiler and not isinstance(profiler, (str, BaseProfiler)):
raise MisconfigurationException(
"Only None, bool, str and subclasses of `BaseProfiler`"
"Only None, str and subclasses of `BaseProfiler`"
" are valid values for `Trainer`'s `profiler` parameter."
f" Received {profiler} which is of type {type(profiler)}."
)

if isinstance(profiler, bool):
rank_zero_warn(
"Passing a bool value as a `profiler` argument to `Trainer` is deprecated"
" and will be removed in v1.3. Use str ('simple' or 'advanced') instead.", DeprecationWarning
)
if profiler:
profiler = SimpleProfiler()
elif isinstance(profiler, str):
if isinstance(profiler, str):
if profiler.lower() in PROFILERS:
profiler_class = PROFILERS[profiler.lower()]
profiler = profiler_class()
Expand Down
11 changes: 5 additions & 6 deletions pytorch_lightning/trainer/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,7 @@ def __init__(
num_sanity_val_steps: int = 2,
truncated_bptt_steps: Optional[int] = None,
resume_from_checkpoint: Optional[Union[Path, str]] = None,
profiler: Optional[Union[BaseProfiler, bool, str]] = None,
profiler: Optional[Union[BaseProfiler, str]] = None,
benchmark: bool = False,
deterministic: bool = False,
reload_dataloaders_every_epoch: bool = False,
Expand Down Expand Up @@ -177,7 +177,7 @@ def __init__(
checkpoint_callback: If ``True``, enable checkpointing.
It will configure a default ModelCheckpoint callback if there is no user-defined ModelCheckpoint in
:paramref:`~pytorch_lightning.trainer.trainer.Trainer.callbacks`. Default: ``True``.
:paramref:`~pytorch_lightning.trainer.trainer.Trainer.callbacks`.
.. warning:: Passing a ModelCheckpoint instance to this argument is deprecated since
v1.1 and will be unsupported from v1.3. Use `callbacks` argument instead.
Expand Down Expand Up @@ -226,10 +226,9 @@ def __init__(
Ignored when a custom progress bar is passed to :paramref:`~Trainer.callbacks`. Default: None, means
a suitable value will be chosen based on the environment (terminal, Google COLAB, etc.).
profiler: To profile individual steps during training and assist in identifying bottlenecks. Passing bool
value is deprecated in v1.1 and will be removed in v1.3.
profiler: To profile individual steps during training and assist in identifying bottlenecks.
overfit_batches: Overfit a percent of training data (float) or a set number of batches (int). Default: 0.0
overfit_batches: Overfit a percent of training data (float) or a set number of batches (int).
plugins: Plugins allow modification of core behavior like ddp and amp, and enable custom lightning plugins.
Expand All @@ -250,7 +249,7 @@ def __init__(
num_processes: number of processes for distributed training with distributed_backend="ddp_cpu"
num_sanity_val_steps: Sanity check runs n validation batches before starting the training routine.
Set it to `-1` to run all batches in all validation dataloaders. Default: 2
Set it to `-1` to run all batches in all validation dataloaders.
reload_dataloaders_every_epoch: Set to True to reload dataloaders every epoch.
Expand Down
3 changes: 1 addition & 2 deletions tests/accelerators/test_cpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,7 @@ def test_unsupported_precision_plugins():
trainer = Mock()
model = Mock()
accelerator = CPUAccelerator(
training_type_plugin=SingleDevicePlugin(torch.device("cpu")),
precision_plugin=MixedPrecisionPlugin()
training_type_plugin=SingleDevicePlugin(torch.device("cpu")), precision_plugin=MixedPrecisionPlugin()
)
with pytest.raises(MisconfigurationException, match=r"amp \+ cpu is not supported."):
accelerator.setup(trainer=trainer, model=model)
10 changes: 2 additions & 8 deletions tests/base/model_optimizers.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,14 +45,8 @@ def configure_optimizers__multiple_optimizers_frequency(self):
optimizer1 = optim.Adam(self.parameters(), lr=self.learning_rate)
optimizer2 = optim.Adam(self.parameters(), lr=self.learning_rate)
return [
{
'optimizer': optimizer1,
'frequency': 1
},
{
'optimizer': optimizer2,
'frequency': 5
},
dict(optimizer=optimizer1, frequency=1),
dict(optimizer=optimizer2, frequency=5),
]

def configure_optimizers__single_scheduler(self):
Expand Down
8 changes: 2 additions & 6 deletions tests/base/model_test_steps.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,9 +54,7 @@ def test_step(self, batch, batch_idx, *args, **kwargs):
output = OrderedDict({
'test_loss': loss_test,
'test_acc': test_acc,
'test_dic': {
'test_loss_a': loss_test
},
'test_dic': dict(test_loss_a=loss_test),
})
return output

Expand Down Expand Up @@ -90,9 +88,7 @@ def test_step__multiple_dataloaders(self, batch, batch_idx, dataloader_idx, **kw
output = OrderedDict({
'test_loss': loss_test,
'test_acc': test_acc,
'test_dic': {
'test_loss_a': loss_test
},
'test_dic': dict(test_loss_a=loss_test),
})
return output
if batch_idx % 5 == 0:
Expand Down
8 changes: 2 additions & 6 deletions tests/base/model_train_steps.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,12 +44,8 @@ def training_step(self, batch, batch_idx, optimizer_idx=None):

output = OrderedDict({
'loss': loss_train,
'progress_bar': {
'some_val': log_train * log_train
},
'log': {
'train_some_val': log_train * log_train
},
'progress_bar': dict(some_val=log_train * log_train),
'log': dict(train_some_val=log_train * log_train),
})
return output

Expand Down
4 changes: 1 addition & 3 deletions tests/base/model_valid_steps.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,9 +43,7 @@ def validation_step(self, batch, batch_idx, *args, **kwargs):
output = OrderedDict({
'val_loss': loss_val,
'val_acc': val_acc,
'test_dic': {
'val_loss_a': loss_val
},
'test_dic': dict(val_loss_a=loss_val),
})
return output

Expand Down
35 changes: 0 additions & 35 deletions tests/deprecated_api/test_remove_1-3.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,15 +12,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Test deprecated functionality which will be removed in vX.Y.Z"""
from argparse import ArgumentParser
from unittest import mock

import pytest
import torch

from pytorch_lightning import LightningModule, Trainer
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
from pytorch_lightning.profiler.profilers import PassThroughProfiler, SimpleProfiler


def test_v1_3_0_deprecated_arguments(tmpdir):
Expand Down Expand Up @@ -111,38 +108,6 @@ def test_v1_3_0_deprecated_metrics():
)


# TODO: remove bool from Trainer.profiler param in v1.3.0, update profiler_connector.py
@pytest.mark.parametrize(['profiler', 'expected'], [
(True, SimpleProfiler),
(False, PassThroughProfiler),
])
def test_trainer_profiler_remove_in_v1_3_0(profiler, expected):
# remove bool from Trainer.profiler param in v1.3.0, update profiler_connector.py
with pytest.deprecated_call(match='will be removed in v1.3'):
trainer = Trainer(profiler=profiler)
assert isinstance(trainer.profiler, expected)


@pytest.mark.parametrize(
['cli_args', 'expected_parsed_arg', 'expected_profiler'],
[
('--profiler', True, SimpleProfiler),
('--profiler True', True, SimpleProfiler),
('--profiler False', False, PassThroughProfiler),
],
)
def test_v1_3_0_trainer_cli_profiler(cli_args, expected_parsed_arg, expected_profiler):
cli_args = cli_args.split(' ')
with mock.patch("argparse._sys.argv", ["any.py"] + cli_args):
parser = ArgumentParser(add_help=False)
parser = Trainer.add_argparse_args(parent_parser=parser)
args = Trainer.parse_argparser(parser)

assert getattr(args, "profiler") == expected_parsed_arg
trainer = Trainer.from_argparse_args(args)
assert isinstance(trainer.profiler, expected_profiler)


def test_trainer_enable_pl_optimizer(tmpdir):
with pytest.deprecated_call(match='will be removed in v1.3'):
Trainer(enable_pl_optimizer=True)
4 changes: 1 addition & 3 deletions tests/trainer/logging_/test_train_loop_logging_1_0.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,9 +178,7 @@ def backward(self, loss, optimizer, optimizer_idx):
assert logged_metrics == expected_logged_metrics

pbar_metrics = set(trainer.progress_bar_metrics.keys())
expected_pbar_metrics = {
'b',
}
expected_pbar_metrics = {'b'}
assert pbar_metrics == expected_pbar_metrics

callback_metrics = set(trainer.callback_metrics.keys())
Expand Down
6 changes: 2 additions & 4 deletions tests/trainer/test_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -1475,16 +1475,14 @@ def test_trainer_profiler_incorrect_str_arg():
@pytest.mark.parametrize('profiler', (
42,
[42],
{
"a": 42
},
dict(a=42),
torch.tensor(42),
Trainer(),
))
def test_trainer_profiler_incorrect_arg_type(profiler):
with pytest.raises(
MisconfigurationException,
match=r"Only None, bool, str and subclasses of `BaseProfiler`"
match="Only None, str and subclasses of `BaseProfiler`"
r" are valid values for `Trainer`'s `profiler` parameter. *"
):
Trainer(profiler=profiler)
Expand Down

0 comments on commit 84061c2

Please sign in to comment.