Skip to content

Commit

Permalink
add proper deprecation
Browse files Browse the repository at this point in the history
  • Loading branch information
awaelchli committed Feb 17, 2021
1 parent 58d9b80 commit 74c55c4
Show file tree
Hide file tree
Showing 3 changed files with 17 additions and 4 deletions.
10 changes: 10 additions & 0 deletions pytorch_lightning/trainer/deprecated_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from pytorch_lightning.core.lightning import LightningModule
from pytorch_lightning.trainer.connectors.accelerator_connector import AcceleratorConnector
from pytorch_lightning.trainer.states import RunningStage
from pytorch_lightning.utilities import DeviceType, DistributedType, rank_zero_warn
Expand All @@ -23,6 +24,7 @@ class DeprecatedDistDeviceAttributes:
_running_stage: RunningStage
num_gpus: int
accelerator_connector: AcceleratorConnector
lightning_module = LightningModule

@property
def on_cpu(self) -> bool:
Expand Down Expand Up @@ -130,3 +132,11 @@ def use_single_gpu(self, val: bool) -> None:
)
if val:
self.accelerator_connector._device_type = DeviceType.GPU

def get_model(self) -> LightningModule:
rank_zero_warn(
"The use of `Trainer.get_model()` is deprecated in favor of `Trainer.lightning_module`"
" and will be removed in v1.4.",
DeprecationWarning,
)
return self.lightning_module
4 changes: 0 additions & 4 deletions pytorch_lightning/trainer/properties.py
Original file line number Diff line number Diff line change
Expand Up @@ -352,10 +352,6 @@ def model(self, model: torch.nn.Module) -> None:
"""
self.accelerator.model = model

def get_model(self) -> LightningModule:
# backward compatible
return self.lightning_module

@property
def lightning_optimizers(self) -> List[LightningOptimizer]:
if self._lightning_optimizers is None:
Expand Down
7 changes: 7 additions & 0 deletions tests/deprecated_api/test_remove_1-4.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,13 @@
from tests.helpers import BoringModel


def test_v1_4_0_deprecated_trainer_methods():
with pytest.deprecated_call(match='will be removed in v1.4'):
trainer = Trainer()
_ = trainer.get_model()
assert trainer.get_model() == trainer.lightning_module


def test_v1_4_0_deprecated_imports():
_soft_unimport_module('pytorch_lightning.utilities.argparse_utils')
with pytest.deprecated_call(match='will be removed in v1.4'):
Expand Down

0 comments on commit 74c55c4

Please sign in to comment.