[Fix] Ensure we set the eval/train flag correctly on accelerator model (#6877)

* Ensure we move the model to eval mode before running evaluation

* Ensure we set the flag appropriately across all stages

* Add test, move hooks logic

* Apply same fix to the validate loop

* Update pytorch_lightning/trainer/trainer.py

* Fix function name

* Fix order, add predict

* Shorten the name

* Fix input dm; drop the duplicate on_predict_start hook call, as it's already called in the setup function

* Use hook, remove double call
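
In short: these hooks used to flip the train/eval flag on the bare LightningModule only, so when the Trainer wraps the module (for example in DistributedDataParallel) the wrapper's own training flag could be left stale. Setting the flag on self.trainer.model instead relies on nn.Module.train()/.eval() recursing into child modules. A minimal sketch of that propagation in plain PyTorch — the wrapper below is a stand-in, not Lightning code:

    # nn.Module.train()/.eval() recurse into children, so flipping the outer
    # wrapper also flips the inner module, while flipping only the inner
    # module leaves the wrapper's flag out of sync.
    import torch.nn as nn

    inner = nn.Linear(4, 4)          # stands in for the LightningModule
    wrapper = nn.Sequential(inner)   # stands in for the accelerator/DDP wrapper

    inner.eval()                     # old behaviour: flag set on the inner module only
    assert not inner.training
    assert wrapper.training          # the wrapper is now stale

    wrapper.train()                  # reset both flags
    wrapper.eval()                   # fixed behaviour: flag set on the wrapper
    assert not wrapper.training
    assert not inner.training        # propagated down to the inner module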

(cherry picked from commit 742c48e)
Sean Naren committed Apr 13, 2021
1 parent 6f7cf59 commit b79dc3c
Showing 4 changed files with 48 additions and 8 deletions.
10 changes: 5 additions & 5 deletions pytorch_lightning/core/hooks.py

@@ -150,13 +150,13 @@ def on_validation_model_eval(self) -> None:
         """
         Sets the model to eval during the val loop
         """
-        self.eval()
+        self.trainer.model.eval()

     def on_validation_model_train(self) -> None:
         """
         Sets the model to train during the val loop
         """
-        self.train()
+        self.trainer.model.train()

     def on_validation_batch_start(self, batch: Any, batch_idx: int, dataloader_idx: int) -> None:
         """
@@ -208,19 +208,19 @@ def on_test_model_train(self) -> None:
         """
         Sets the model to train during the test loop
         """
-        self.train()
+        self.trainer.model.train()

     def on_test_model_eval(self) -> None:
         """
         Sets the model to eval during the test loop
         """
-        self.eval()
+        self.trainer.model.eval()

     def on_predict_model_eval(self) -> None:
         """
         Sets the model to eval during the predict loop
         """
-        self.eval()
+        self.trainer.model.eval()

     def on_epoch_start(self) -> None:
         """
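These per-stage hooks are overridable, which is why the fix lives here rather than hard-coded in the loops. A hedged sketch of a user override, assuming the post-fix default of flipping self.trainer.model; MyModel and the dropout tweak are illustrative, not part of this commit:

    # Hypothetical override: keep the default mode switch, but re-enable
    # dropout layers during validation (e.g. for Monte Carlo dropout).
    import torch
    import pytorch_lightning as pl

    class MyModel(pl.LightningModule):
        def on_validation_model_eval(self) -> None:
            # default behaviour after this commit: flip the (possibly wrapped) model
            self.trainer.model.eval()
            # illustrative customisation only
            for module in self.trainer.model.modules():
                if isinstance(module, torch.nn.Dropout):
                    module.train()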
1 change: 1 addition & 0 deletions pytorch_lightning/trainer/predict_loop.py

@@ -44,6 +44,7 @@ def on_predict_model_eval(self, *_, **__):
         model_ref.on_predict_model_eval()

     def setup(self, model, max_batches, dataloaders):
+
         # copy properties for forward overrides
         self.trainer.model_connector.copy_trainer_model_properties(model)
7 changes: 5 additions & 2 deletions pytorch_lightning/trainer/trainer.py

@@ -612,11 +612,11 @@ def run_train(self):
         self.checkpoint_connector.has_trained = False

         # enable train mode
-        model = self.lightning_module
-        model.train()
+        self.model.train()
         torch.set_grad_enabled(True)

         # reload data when needed
+        model = self.lightning_module
         self.train_loop.reset_train_val_dataloaders(model)

         # hook
@@ -814,6 +814,9 @@ def run_predict(self):
         model.zero_grad()
         torch.set_grad_enabled(False)

+        # call hook
+        self.predict_loop.on_predict_start()
+
         # set up the eval loop
         self.predict_loop.setup(model, max_batches, dataloaders)
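For context, run_predict now pairs the eval-mode flag with autograd being disabled before the hook and setup run. A standalone plain-PyTorch sketch of the same two invariants:

    # Prediction invariants: no autograd graph, model flagged as eval.
    import torch
    import torch.nn as nn

    model = nn.Linear(4, 2)
    model.zero_grad()
    torch.set_grad_enabled(False)   # no graph is built during prediction
    model.eval()                    # the eval flag is set, as the hook now does

    out = model(torch.randn(1, 4))
    assert not out.requires_grad
    assert not model.training
    torch.set_grad_enabled(True)    # restore the default for later training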
38 changes: 37 additions & 1 deletion tests/trainer/test_trainer.py

@@ -1381,7 +1381,14 @@ def setup(self, model, stage):
 )
 @patch("pytorch_lightning.loggers.tensorboard.TensorBoardLogger.log_metrics")
 def test_log_every_n_steps(log_metrics_mock, tmpdir, train_batches, max_steps, log_interval):
-    model = EvalModelTemplate()
+
+    class TestModel(BoringModel):
+
+        def training_step(self, *args, **kwargs):
+            self.log("foo", -1)
+            return super().training_step(*args, **kwargs)
+
+    model = TestModel()
     trainer = Trainer(
         default_root_dir=tmpdir,
         log_every_n_steps=log_interval,
@@ -1932,3 +1939,32 @@ def forward(self, x):

     with pytest.raises(MisconfigurationException, match="Dataloader not found for `Trainer.predict`"):
         trainer.predict(model)
+
+
+class TrainerStagesModel(BoringModel):
+
+    def on_train_start(self) -> None:
+        assert self.trainer.model.training
+        assert self.training
+
+    def on_validation_start(self) -> None:
+        assert not self.trainer.model.training
+        assert not self.training
+
+    def on_test_start(self) -> None:
+        assert not self.trainer.model.training
+        assert not self.training
+
+    def on_predict_start(self) -> None:
+        assert not self.trainer.model.training
+        assert not self.training
+
+
+@pytest.mark.parametrize(['accelerator', 'num_processes'],
+                         [(None, 1), pytest.param('ddp', 2, marks=RunIf(skip_windows=True))])
+def test_model_in_correct_mode_during_stages(tmpdir, accelerator, num_processes):
+    model = TrainerStagesModel()
+    trainer = Trainer(default_root_dir=tmpdir, accelerator=accelerator, num_processes=num_processes, fast_dev_run=True)
+    trainer.fit(model)
+    trainer.test(model)
+    trainer.predict(model, model.val_dataloader())
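The new test asserts that both the wrapped trainer.model and the LightningModule itself carry the expected flag at the start of each stage, in a single-process run and a two-process ddp run. A hedged way to run just this test from a source checkout, using pytest's programmatic entry point (equivalent to the usual command-line -k selection):

    # Assumes a pytorch_lightning source checkout with test deps installed.
    import pytest

    pytest.main(["tests/trainer/test_trainer.py", "-k", "test_model_in_correct_mode_during_stages"])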
