[Fix] Ensure we set the eval/train flag correctly on accelerator model #6877

Merged (10 commits) on Apr 8, 2021
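Why the fix matters (added context, not part of the original PR page): when an accelerator wraps the LightningModule, for example when `trainer.model` is a `DistributedDataParallel` module whose child is the LightningModule, calling `self.eval()` inside a hook only flips the flag on the inner module, leaving the wrapper's `.training` attribute stale. Calling `.eval()`/`.train()` on the outermost `trainer.model` propagates to every child, which is what the diff below switches to. A minimal sketch of the underlying `torch.nn.Module` semantics, with plain modules standing in for the wrapper and the LightningModule:

```python
import torch.nn as nn

# Stand-ins: `wrapper` plays the accelerator model (trainer.model),
# `inner` plays the LightningModule it wraps.
inner = nn.Linear(4, 4)
wrapper = nn.Sequential(inner)

# eval() on the inner module does not touch the wrapper's flag:
inner.eval()
assert not inner.training
assert wrapper.training  # stale: the wrapper still reports train mode

# eval() on the outermost module propagates to all children:
wrapper.eval()
assert not wrapper.training and not inner.training
```

This is the invariant the new test at the bottom of the diff asserts: `self.trainer.model.training` and `self.training` must agree in every stage.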
10 changes: 5 additions & 5 deletions pytorch_lightning/core/hooks.py
@@ -114,13 +114,13 @@ def on_validation_model_eval(self) -> None:
"""
Sets the model to eval during the val loop
"""
self.eval()
self.trainer.model.eval()

def on_validation_model_train(self) -> None:
"""
Sets the model to train during the val loop
"""
self.train()
self.trainer.model.train()

def on_validation_batch_start(self, batch: Any, batch_idx: int, dataloader_idx: int) -> None:
"""
@@ -172,19 +172,19 @@ def on_test_model_train(self) -> None:
"""
Sets the model to train during the test loop
"""
self.train()
self.trainer.model.train()

def on_test_model_eval(self) -> None:
"""
Sets the model to eval during the test loop
"""
self.eval()
self.trainer.model.eval()

def on_predict_model_eval(self) -> None:
"""
Sets the model to eval during the predict loop
"""
self.eval()
self.trainer.model.eval()

def on_epoch_start(self) -> None:
"""
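These `on_*_model_eval` / `on_*_model_train` hooks are overridable, so a model that needs non-default mode handling can customize them. A hedged sketch (the model and the Monte Carlo dropout use case are illustrative, not from this PR):

```python
import torch.nn as nn
from pytorch_lightning import LightningModule


class MCDropoutModel(LightningModule):
    """Illustrative model that keeps dropout active during validation."""

    def __init__(self):
        super().__init__()
        self.layer = nn.Linear(32, 2)
        self.dropout = nn.Dropout(0.5)

    def on_validation_model_eval(self) -> None:
        # mirror the new default: put the whole wrapped model in eval mode
        self.trainer.model.eval()
        # then selectively restore train-mode behaviour for dropout only
        self.dropout.train()
```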
1 change: 0 additions & 1 deletion pytorch_lightning/trainer/predict_loop.py
@@ -44,7 +44,6 @@ def on_predict_model_eval(self, *_, **__):
         model_ref.on_predict_model_eval()
 
     def setup(self, model, max_batches, dataloaders):
-        self.trainer.call_hook("on_predict_start")
 
         # copy properties for forward overrides
         self.trainer.model_connector.copy_trainer_model_properties(model)
9 changes: 5 additions & 4 deletions pytorch_lightning/trainer/trainer.py
@@ -582,11 +582,11 @@ def run_train(self) -> None:
         self.checkpoint_connector.has_trained = False
 
         # enable train mode
-        model = self.lightning_module
-        model.train()
+        self.model.train()
         torch.set_grad_enabled(True)
 
         # reload data when needed
+        model = self.lightning_module
         self.train_loop.reset_train_val_dataloaders(model)
 
         # hook
@@ -772,8 +772,6 @@ def run_evaluate(self):
         return eval_loop_results
 
     def run_predict(self):
-        self.predict_loop.on_predict_start()
-
         # prepare dataloaders
         dataloaders, max_batches = self.predict_loop.get_predict_dataloaders()
 
@@ -789,6 +787,9 @@
         model.zero_grad()
         torch.set_grad_enabled(False)
 
+        # call hook
+        self.predict_loop.on_predict_start()
+
         # set up the eval loop
         self.predict_loop.setup(model, max_batches, dataloaders)
 
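Note the ordering consequence of this move: `torch.set_grad_enabled(False)` now runs before `predict_loop.on_predict_start()`, so code dispatched from that hook executes with gradient tracking already disabled. A small sketch of what that guarantees, assuming (as the removed `call_hook` line suggests) that `predict_loop.on_predict_start()` dispatches the module-level `on_predict_start` hook:

```python
import torch
from pytorch_lightning import LightningModule


class PredictHookModel(LightningModule):  # illustrative only
    def on_predict_start(self) -> None:
        # with the reordered run_predict, grads are already off here
        assert not torch.is_grad_enabled()
```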
32 changes: 32 additions & 0 deletions tests/trainer/test_trainer.py
@@ -1438,7 +1438,9 @@ def setup(self, model, stage):
 )
 @patch("pytorch_lightning.loggers.tensorboard.TensorBoardLogger.log_metrics")
 def test_log_every_n_steps(log_metrics_mock, tmpdir, train_batches, max_steps, log_interval):
+
     class TestModel(BoringModel):
+
         def training_step(self, *args, **kwargs):
             self.log("foo", -1)
             return super().training_step(*args, **kwargs)
@@ -1888,3 +1890,33 @@ def test_exception_when_testing_or_validating_with_fast_dev_run(tmpdir):
     trainer.validate()
     with pytest.raises(MisconfigurationException, match=r"\.test\(\)` with `fast_dev_run=True"):
         trainer.test()
+
+
+class TrainerStagesModel(BoringModel):
+
+    def on_train_start(self) -> None:
+        assert self.trainer.model.training
+        assert self.training
+
+    def on_validation_start(self) -> None:
+        assert not self.trainer.model.training
+        assert not self.training
+
+    def on_test_start(self) -> None:
+        assert not self.trainer.model.training
+        assert not self.training
+
+    def on_predict_start(self) -> None:
+        assert not self.trainer.model.training
+        assert not self.training
+
+
+@pytest.mark.parametrize(['accelerator', 'num_processes'],
+                         [(None, 1), pytest.param('ddp', 2, marks=RunIf(skip_windows=True))])
+def test_model_in_correct_mode_during_stages(tmpdir, accelerator, num_processes):
+    model = TrainerStagesModel()
+    trainer = Trainer(default_root_dir=tmpdir, accelerator=accelerator, num_processes=num_processes, fast_dev_run=True)
+    trainer.fit(model)
+    trainer.validate(model)
+    trainer.test(model)
+    trainer.predict(model, model.val_dataloader())
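For reference, the same flag invariant could also be asserted from a `Callback` instead of hook overrides on the model. The sketch below is a hypothetical companion, not part of the PR; the `BoringModel` import path is assumed from this era's test-suite layout:

```python
from pytorch_lightning import Callback, Trainer
from tests.helpers.boring_model import BoringModel


class ModeAssertCallback(Callback):
    """Illustrative callback mirroring TrainerStagesModel's assertions."""

    def on_train_start(self, trainer, pl_module):
        assert trainer.model.training and pl_module.training

    def on_validation_start(self, trainer, pl_module):
        assert not trainer.model.training and not pl_module.training


def test_modes_via_callback(tmpdir):  # hypothetical companion test
    trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True,
                      callbacks=[ModeAssertCallback()])
    trainer.fit(BoringModel())
```

The new test itself can be run in isolation with `python -m pytest tests/trainer/test_trainer.py -k test_model_in_correct_mode_during_stages`.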