Skip to content

Commit

Permalink
fix running stage access
Browse files Browse the repository at this point in the history
  • Loading branch information
awaelchli committed Feb 20, 2021
1 parent bab7691 commit cbdf2a8
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 4 deletions.
2 changes: 1 addition & 1 deletion tests/models/test_restore.py
Original file line number Diff line number Diff line change
Expand Up @@ -453,7 +453,7 @@ def on_train_start(self):
# haven't trained with the new loaded model
dp_model = new_trainer.model
dp_model.eval()
dp_model.module.module.running_stage = RunningStage.EVALUATING
new_trainer._running_stage = RunningStage.EVALUATING

dataloader = self.train_dataloader()
tpipes.run_prediction(self.trainer.lightning_module, dataloader)
Expand Down
8 changes: 5 additions & 3 deletions tests/overrides/test_data_parallel.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from unittest.mock import MagicMock
from unittest.mock import MagicMock, Mock

import pytest
import torch
Expand Down Expand Up @@ -103,7 +103,8 @@ def training_step(self, batch, batch_idx):
return {"loss": loss}

model = TestModel()
model.running_stage = RunningStage.TRAINING
model.trainer = Mock()
model.trainer._running_stage = RunningStage.TRAINING
batch = torch.rand(2, 32).cuda()
batch_idx = 0

Expand Down Expand Up @@ -146,7 +147,8 @@ def training_step(self, batch, batch_idx):

model = TestModel()
model.to(device)
model.running_stage = RunningStage.TRAINING
model.trainer = Mock()
model.trainer._running_stage = RunningStage.TRAINING
batch = torch.rand(2, 32).to(device)
batch_idx = 0

Expand Down

0 comments on commit cbdf2a8

Please sign in to comment.