From cbdf2a854ffc70e7ccaa6d6b4be5d24313ebdef1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Sat, 20 Feb 2021 20:50:50 +0100 Subject: [PATCH] fix running stage access --- tests/models/test_restore.py | 2 +- tests/overrides/test_data_parallel.py | 8 +++++--- 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/tests/models/test_restore.py b/tests/models/test_restore.py index d28ab6177f21c..a3f88e37bb09a 100644 --- a/tests/models/test_restore.py +++ b/tests/models/test_restore.py @@ -453,7 +453,7 @@ def on_train_start(self): # haven't trained with the new loaded model dp_model = new_trainer.model dp_model.eval() - dp_model.module.module.running_stage = RunningStage.EVALUATING + new_trainer._running_stage = RunningStage.EVALUATING dataloader = self.train_dataloader() tpipes.run_prediction(self.trainer.lightning_module, dataloader) diff --git a/tests/overrides/test_data_parallel.py b/tests/overrides/test_data_parallel.py index 64481bd70390d..90bb6fac88457 100644 --- a/tests/overrides/test_data_parallel.py +++ b/tests/overrides/test_data_parallel.py @@ -1,4 +1,4 @@ -from unittest.mock import MagicMock +from unittest.mock import MagicMock, Mock import pytest import torch @@ -103,7 +103,8 @@ def training_step(self, batch, batch_idx): return {"loss": loss} model = TestModel() - model.running_stage = RunningStage.TRAINING + model.trainer = Mock() + model.trainer._running_stage = RunningStage.TRAINING batch = torch.rand(2, 32).cuda() batch_idx = 0 @@ -146,7 +147,8 @@ def training_step(self, batch, batch_idx): model = TestModel() model.to(device) - model.running_stage = RunningStage.TRAINING + model.trainer = Mock() + model.trainer._running_stage = RunningStage.TRAINING batch = torch.rand(2, 32).to(device) batch_idx = 0