Skip to content

Commit

Permalink
[Fix] Ensure we set the eval/train flag correctly on accelerator model (
Browse files Browse the repository at this point in the history
#6877)

* Ensure we move the model to eval mode before running evaluation

* Ensure we set the flag appropriately across all stages

* Add test, move hooks logic

* Apply same fix to the validate loop

* Update pytorch_lightning/trainer/trainer.py

* Fix function name

* Fix order, add predict

* Shorten the name

* Fix input dm, drop duplicate on predict start hook call, as it's called in the setup function

* Use hook, remove double call
  • Loading branch information
Sean Naren authored Apr 8, 2021
1 parent 851fd7f commit 742c48e
Show file tree
Hide file tree
Showing 4 changed files with 42 additions and 10 deletions.
10 changes: 5 additions & 5 deletions pytorch_lightning/core/hooks.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,13 +114,13 @@ def on_validation_model_eval(self) -> None:
"""
Sets the model to eval during the val loop
"""
self.eval()
self.trainer.model.eval()

def on_validation_model_train(self) -> None:
"""
Sets the model to train during the val loop
"""
self.train()
self.trainer.model.train()

def on_validation_batch_start(self, batch: Any, batch_idx: int, dataloader_idx: int) -> None:
"""
Expand Down Expand Up @@ -172,19 +172,19 @@ def on_test_model_train(self) -> None:
"""
Sets the model to train during the test loop
"""
self.train()
self.trainer.model.train()

def on_test_model_eval(self) -> None:
"""
Sets the model to eval during the test loop
"""
self.eval()
self.trainer.model.eval()

def on_predict_model_eval(self) -> None:
"""
Sets the model to eval during the predict loop
"""
self.eval()
self.trainer.model.eval()

def on_epoch_start(self) -> None:
"""
Expand Down
1 change: 0 additions & 1 deletion pytorch_lightning/trainer/predict_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,6 @@ def on_predict_model_eval(self, *_, **__):
model_ref.on_predict_model_eval()

def setup(self, model, max_batches, dataloaders):
self.trainer.call_hook("on_predict_start")

# copy properties for forward overrides
self.trainer.model_connector.copy_trainer_model_properties(model)
Expand Down
9 changes: 5 additions & 4 deletions pytorch_lightning/trainer/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -582,11 +582,11 @@ def run_train(self) -> None:
self.checkpoint_connector.has_trained = False

# enable train mode
model = self.lightning_module
model.train()
self.model.train()
torch.set_grad_enabled(True)

# reload data when needed
model = self.lightning_module
self.train_loop.reset_train_val_dataloaders(model)

# hook
Expand Down Expand Up @@ -772,8 +772,6 @@ def run_evaluate(self):
return eval_loop_results

def run_predict(self):
self.predict_loop.on_predict_start()

# prepare dataloaders
dataloaders, max_batches = self.predict_loop.get_predict_dataloaders()

Expand All @@ -789,6 +787,9 @@ def run_predict(self):
model.zero_grad()
torch.set_grad_enabled(False)

# call hook
self.predict_loop.on_predict_start()

# set up the eval loop
self.predict_loop.setup(model, max_batches, dataloaders)

Expand Down
32 changes: 32 additions & 0 deletions tests/trainer/test_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -1438,7 +1438,9 @@ def setup(self, model, stage):
)
@patch("pytorch_lightning.loggers.tensorboard.TensorBoardLogger.log_metrics")
def test_log_every_n_steps(log_metrics_mock, tmpdir, train_batches, max_steps, log_interval):

class TestModel(BoringModel):

def training_step(self, *args, **kwargs):
self.log("foo", -1)
return super().training_step(*args, **kwargs)
Expand Down Expand Up @@ -1888,3 +1890,33 @@ def test_exception_when_testing_or_validating_with_fast_dev_run(tmpdir):
trainer.validate()
with pytest.raises(MisconfigurationException, match=r"\.test\(\)` with `fast_dev_run=True"):
trainer.test()


class TrainerStagesModel(BoringModel):

def on_train_start(self) -> None:
assert self.trainer.model.training
assert self.training

def on_validation_start(self) -> None:
assert not self.trainer.model.training
assert not self.training

def on_test_start(self) -> None:
assert not self.trainer.model.training
assert not self.training

def on_predict_start(self) -> None:
assert not self.trainer.model.training
assert not self.training


@pytest.mark.parametrize(['accelerator', 'num_processes'],
[(None, 1), pytest.param('ddp', 2, marks=RunIf(skip_windows=True))])
def test_model_in_correct_mode_during_stages(tmpdir, accelerator, num_processes):
model = TrainerStagesModel()
trainer = Trainer(default_root_dir=tmpdir, accelerator=accelerator, num_processes=num_processes, fast_dev_run=True)
trainer.fit(model)
trainer.validate(model)
trainer.test(model)
trainer.predict(model, model.val_dataloader())

0 comments on commit 742c48e

Please sign in to comment.