Skip to content

Commit

Permalink
Deprecate TrainerModelHooksMixin (#7422)
Browse files Browse the repository at this point in the history
* Deprecate TrainerModelHooksMixin

* Update CHANGELOG.md

* Update model_hooks.py

* Update model_hooks.py
  • Loading branch information
ananthsub authored May 7, 2021
1 parent 8208c33 commit fecce50
Show file tree
Hide file tree
Showing 4 changed files with 59 additions and 7 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
### Deprecated


- Deprecated `TrainerModelHooksMixin` in favor of `pytorch_lightning.utilities.signature_utils` ([#7422](https://github.com/PyTorchLightning/pytorch-lightning/pull/7422))


### Removed


Expand Down
21 changes: 19 additions & 2 deletions pytorch_lightning/trainer/model_hooks.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,25 +12,42 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import inspect
from abc import ABC
from typing import Optional

from pytorch_lightning.core.lightning import LightningModule
from pytorch_lightning.utilities.distributed import rank_zero_deprecation
from pytorch_lightning.utilities.signature_utils import is_param_in_hook_signature


class TrainerModelHooksMixin(ABC):
"""
TODO: Remove this class in v1.6.
Use the utilities from ``pytorch_lightning.utilities.signature_utils`` instead.
"""

lightning_module: LightningModule

def is_function_implemented(self, f_name: str, model: Optional[LightningModule] = None) -> bool:
rank_zero_deprecation(
"Internal: TrainerModelHooksMixin.is_function_implemented is deprecated in v1.4"
" and will be removed in v1.6."
)
# note: currently unused - kept as it is public
if model is None:
model = self.lightning_module
f_op = getattr(model, f_name, None)
return callable(f_op)

def has_arg(self, f_name: str, arg_name: str) -> bool:
rank_zero_deprecation(
"Internal: TrainerModelHooksMixin.is_function_implemented is deprecated in v1.4"
" and will be removed in v1.6."
" Use `pytorch_lightning.utilities.signature_utils.is_param_in_hook_signature` instead."
)
model = self.lightning_module
f_op = getattr(model, f_name, None)
return arg_name in inspect.signature(f_op).parameters
if not f_op:
return False
return is_param_in_hook_signature(f_op, arg_name)
12 changes: 7 additions & 5 deletions pytorch_lightning/trainer/training_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -936,18 +936,20 @@ def build_train_args(self, batch, batch_idx, opt_idx, hiddens):
# enable not needing to add opt_idx to training_step
args = [batch, batch_idx]

lightning_module = self.trainer.lightning_module

if len(self.trainer.optimizers) > 1:
if self.trainer.has_arg("training_step", "optimizer_idx"):
if not self.trainer.lightning_module.automatic_optimization:
training_step_fx = getattr(lightning_module, "training_step")
has_opt_idx_in_train_step = is_param_in_hook_signature(training_step_fx, "optimizer_idx")
if has_opt_idx_in_train_step:
if not lightning_module.automatic_optimization:
self.warning_cache.warn(
"`training_step` hook signature has changed in v1.3."
" `optimizer_idx` argument has been removed in case of manual optimization. Support for"
" the old signature will be removed in v1.5", DeprecationWarning
)
args.append(opt_idx)
elif not self.trainer.has_arg(
"training_step", "optimizer_idx"
) and self.trainer.lightning_module.automatic_optimization:
elif not has_opt_idx_in_train_step and lightning_module.automatic_optimization:
raise ValueError(
f"Your LightningModule defines {len(self.trainer.optimizers)} optimizers but"
' `training_step` is missing the `optimizer_idx` argument.'
Expand Down
30 changes: 30 additions & 0 deletions tests/deprecated_api/test_remove_1-6.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" Test deprecated functionality which will be removed in v1.6.0 """

import pytest

from pytorch_lightning import Trainer
from tests.helpers import BoringModel


def test_v1_6_0_trainer_model_hook_mixin(tmpdir):
model = BoringModel()
trainer = Trainer(default_root_dir=tmpdir, max_epochs=1, checkpoint_callback=False, logger=False)
trainer.fit(model)
with pytest.deprecated_call(match="is deprecated in v1.4 and will be removed in v1.6"):
trainer.is_function_implemented("training_step", model)

with pytest.deprecated_call(match="is deprecated in v1.4 and will be removed in v1.6"):
trainer.has_arg("training_step", "batch")

0 comments on commit fecce50

Please sign in to comment.