Skip to content

Commit

Permalink
Remove no return warning from val/test step (#6139)
Browse files Browse the repository at this point in the history
* remove warning

* auto_opt

* chlog

* auto_opt

* no_warning_call

* rm old code

* add warning for predict

* Apply suggestions from code review

Co-authored-by: Adrian Wälchli <aedu.waelchli@gmail.com>

Co-authored-by: Adrian Wälchli <aedu.waelchli@gmail.com>
  • Loading branch information
rohitgr7 and awaelchli authored Mar 6, 2021
1 parent 217470b commit facfda8
Show file tree
Hide file tree
Showing 14 changed files with 72 additions and 102 deletions.
6 changes: 6 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Added arg to `self.log` that enables users to give custom names when dealing with multiple dataloaders ([#6274](https://github.com/PyTorchLightning/pytorch-lightning/pull/6274))


- Added no return warning to predict ([#6139](https://github.com/PyTorchLightning/pytorch-lightning/pull/6139))


### Changed

- Renamed `pytorch_lightning.callbacks.swa` to `pytorch_lightning.callbacks.stochastic_weight_avg` ([#6259](https://github.com/PyTorchLightning/pytorch-lightning/pull/6259))
Expand All @@ -49,6 +52,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Removed support for passing a bool value to `profiler` argument of Trainer ([#6164](https://github.com/PyTorchLightning/pytorch-lightning/pull/6164))


- Removed no return warning from val/test step ([#6139](https://github.com/PyTorchLightning/pytorch-lightning/pull/6139))


- Removed passing a `ModelCheckpoint` instance to `Trainer(checkpoint_callback)` ([#6166](https://github.com/PyTorchLightning/pytorch-lightning/pull/6166))


Expand Down
19 changes: 0 additions & 19 deletions pytorch_lightning/overrides/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,17 +11,12 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any

import torch
from torch.nn import DataParallel
from torch.nn.parallel import DistributedDataParallel

from pytorch_lightning.core.lightning import LightningModule
from pytorch_lightning.utilities.device_dtype_mixin import DeviceDtypeModuleMixin
from pytorch_lightning.utilities.warnings import WarningCache

warning_cache = WarningCache()


class _LightningModuleWrapperBase(DeviceDtypeModuleMixin, torch.nn.Module):
Expand Down Expand Up @@ -53,20 +48,12 @@ def forward(self, *inputs, **kwargs):
# ddp_plugin ``post_training_step`` hook
if not self.module.automatic_optimization:
trainer.model.require_backward_grad_sync = False
warn_if_output_is_none(output, "training_step")

elif trainer and trainer.testing:
output = self.module.test_step(*inputs, **kwargs)
warn_if_output_is_none(output, "test_step")

elif trainer and (trainer.sanity_checking or trainer.validating):
output = self.module.validation_step(*inputs, **kwargs)
warn_if_output_is_none(output, "validation_step")

elif trainer and trainer.predicting:
output = self.module.predict(*inputs, **kwargs)
warn_if_output_is_none(output, "predict")

else:
output = self.module(*inputs, **kwargs)

Expand All @@ -76,12 +63,6 @@ def on_post_move_to_device(self):
pass


def warn_if_output_is_none(output: Any, method_name: str) -> None:
""" Warns user about which method returned None. """
if output is None:
warning_cache.warn(f'Your {method_name} returned None. Did you forget to return an output?')


def unwrap_lightning_module(wrapped_model) -> LightningModule:
model = wrapped_model
if isinstance(model, (DistributedDataParallel, DataParallel)):
Expand Down
3 changes: 1 addition & 2 deletions pytorch_lightning/plugins/training_type/ddp_spawn.py
Original file line number Diff line number Diff line change
Expand Up @@ -214,8 +214,7 @@ def transfer_distrib_spawn_state_on_fit_end(self, results):
# save the last weights
last_path = None
if (
self.lightning_module.trainer.state == TrainerState.FITTING
and best_model_path is not None
self.lightning_module.trainer.state == TrainerState.FITTING and best_model_path is not None
and len(best_model_path) > 0
):
last_path = re.sub(".ckpt", ".tmp_end.ckpt", best_model_path)
Expand Down
3 changes: 1 addition & 2 deletions pytorch_lightning/plugins/training_type/tpu_spawn.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,8 +139,7 @@ def transfer_distrib_spawn_state_on_fit_end(self, results):
# save the last weights
last_path = None
if (
self.lightning_module.trainer.state == TrainerState.FITTING
and best_model_path is not None
self.lightning_module.trainer.state == TrainerState.FITTING and best_model_path is not None
and len(best_model_path) > 0
):
last_path = re.sub(".ckpt", ".tmp_end.ckpt", best_model_path)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -297,9 +297,7 @@ def get_evaluate_epoch_results(self):

# log results of evaluation
if (
self.trainer.state != TrainerState.FITTING
and self.trainer.evaluating
and self.trainer.is_global_zero
self.trainer.state != TrainerState.FITTING and self.trainer.evaluating and self.trainer.is_global_zero
and self.trainer.verbose_evaluate
):
print('-' * 80)
Expand Down
3 changes: 1 addition & 2 deletions pytorch_lightning/trainer/evaluation_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,8 +60,7 @@ def get_evaluation_dataloaders(self):
self.trainer.reset_val_dataloader(model)
if self.trainer.sanity_checking:
self.trainer.num_sanity_val_batches = [
min(self.trainer.num_sanity_val_steps, val_batches)
for val_batches in self.trainer.num_val_batches
min(self.trainer.num_sanity_val_steps, val_batches) for val_batches in self.trainer.num_val_batches
]
max_batches = self.trainer.num_sanity_val_batches
else:
Expand Down
6 changes: 6 additions & 0 deletions pytorch_lightning/trainer/predict_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
import torch

from pytorch_lightning.utilities.apply_func import apply_to_collection
from pytorch_lightning.utilities.warnings import WarningCache


class PredictLoop(object):
Expand All @@ -22,6 +23,7 @@ def __init__(self, trainer):
self.trainer = trainer
self.max_batches = None
self.num_dataloaders = None
self.warning_cache = WarningCache()

def on_trainer_init(self):
self.trainer.num_predict_batches = []
Expand Down Expand Up @@ -74,6 +76,10 @@ def predict(self, batch, batch_idx, dataloader_idx):

model_ref._current_fx_name = "predict"
predictions = self.trainer.accelerator.predict(args)

if predictions is None:
self.warning_cache.warn("predict returned None if it was on purpose, ignore this warning...")

self._predictions[dataloader_idx].append(predictions)
self.trainer._progress_bar_callback.on_predict_batch_end(
self.trainer, model_ref, predictions, batch, batch_idx, dataloader_idx
Expand Down
3 changes: 1 addition & 2 deletions pytorch_lightning/trainer/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -878,8 +878,7 @@ def test(
# Attach datamodule to get setup/prepare_data added to model before the call to it below
self.data_connector.attach_datamodule(model, datamodule)
results = (
self.__evaluate_given_model(model, dataloaders=test_dataloaders)
if model_provided else
self.__evaluate_given_model(model, dataloaders=test_dataloaders) if model_provided else
self.__evaluate_using_weights(model, ckpt_path=ckpt_path, dataloaders=test_dataloaders)
)

Expand Down
2 changes: 1 addition & 1 deletion tests/helpers/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,7 @@ def no_warning_call(warning_type, match: Optional[str] = None):

try:
w = record.pop(warning_type)
if not ((match and match in w.text) or w):
if not (match and match in str(w.message)):
return
except AssertionError:
# no warning raised
Expand Down
66 changes: 9 additions & 57 deletions tests/overrides/test_data_parallel.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
from torch.nn import DataParallel

from pytorch_lightning.overrides import LightningDistributedModule
from pytorch_lightning.overrides.base import warning_cache
from pytorch_lightning.overrides.data_parallel import (
LightningParallelModule,
python_scalar_to_tensor,
Expand All @@ -20,12 +19,14 @@
LightningParallelModule,
LightningDistributedModule,
])
@pytest.mark.parametrize("stage", [
("training", "training_step"),
("testing", "test_step"),
("validating", "validation_step"),
("predicting", "predict"),
])
@pytest.mark.parametrize(
"stage", [
("training", "training_step"),
("testing", "test_step"),
("validating", "validation_step"),
("predicting", "predict"),
]
)
def test_lightning_wrapper_module_methods(wrapper_class, stage):
""" Test that the LightningWrapper redirects .forward() to the LightningModule methods. """
pl_module = MagicMock()
Expand All @@ -36,63 +37,14 @@ def test_lightning_wrapper_module_methods(wrapper_class, stage):

prop, step = stage
pl_module.trainer.sanity_checking = False

for p in ("training", "testing", "validating", "predicting"):
setattr(pl_module.trainer, p, p == prop)

wrapped_module(batch, batch_idx)

getattr(pl_module, step).assert_called_with(batch, batch_idx)


@pytest.mark.parametrize("wrapper_class", [
LightningParallelModule,
LightningDistributedModule,
])
@pytest.mark.parametrize("stage", [
("training", "training_step"),
("testing", "test_step"),
("validating", "validation_step"),
])
def test_lightning_wrapper_module_warn_none_output(wrapper_class, stage):
""" Test that the LightningWrapper module warns about forgotten return statement. """
warning_cache.clear()
pl_module = MagicMock()

prop, step = stage
pl_module.trainer.sanity_checking = False
for p in ("training", "testing", "validating", "predicting"):
setattr(pl_module.trainer, p, p == prop)

wrapped_module = wrapper_class(pl_module)

getattr(pl_module, step).return_value = None

with pytest.warns(UserWarning, match=f"Your {step} returned None"):
wrapped_module()


@pytest.mark.parametrize("wrapper_class", [
LightningParallelModule,
LightningDistributedModule,
])
def test_lightning_wrapper_module_no_warn(wrapper_class):
warning_cache.clear()
pl_module = MagicMock()

pl_module.trainer.sanity_checking = False
pl_module.trainer.training = False
pl_module.trainer.testing = False
pl_module.trainer.validating = False
pl_module.trainer.predicting = False

wrapped_module = wrapper_class(pl_module)

with pytest.warns(None) as record:
wrapped_module()
pl_module.assert_called()
assert not record


@pytest.mark.parametrize(
"inp,expected", [
[torch.tensor(1.0), torch.tensor([1.0])],
Expand Down
24 changes: 16 additions & 8 deletions tests/trainer/data_flow/test_train_loop_flow_scalar_1_0.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
from pytorch_lightning.core.lightning import LightningModule
from tests.helpers.boring_model import BoringModel
from tests.helpers.deterministic_model import DeterministicModel
from tests.helpers.utils import no_warning_call


@mock.patch.dict(os.environ, {"PL_DEV_DEBUG": "1"})
Expand Down Expand Up @@ -211,7 +212,8 @@ def backward(self, loss, optimizer, optimizer_idx):

def test_train_step_no_return(tmpdir):
"""
Tests that only training_step can be used
Tests that only training_step raises a warning when
nothing is returned in case of automatic_optimization
"""

class TestModel(BoringModel):
Expand All @@ -231,20 +233,26 @@ def validation_epoch_end(self, outputs):
assert len(outputs) == 0

model = TestModel()
trainer = Trainer(
trainer_args = dict(
default_root_dir=tmpdir,
limit_train_batches=2,
limit_val_batches=2,
max_epochs=2,
log_every_n_steps=1,
weights_summary=None,
fast_dev_run=2,
)

with pytest.warns(UserWarning, match=r'.*training_step returned None.*'):
trainer = Trainer(**trainer_args)

with pytest.warns(UserWarning, match=r'training_step returned None .*'):
trainer.fit(model)

assert model.training_step_called
assert model.validation_step_called

model = TestModel()
model.automatic_optimization = False
trainer = Trainer(**trainer_args)

with no_warning_call(UserWarning, match=r'training_step returned None .*'):
trainer.fit(model)


def test_training_step_no_return_when_even(tmpdir):
"""
Expand Down
2 changes: 2 additions & 0 deletions tests/trainer/test_states.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ def test_trainer_state_while_running(tmpdir, extra_params):
trainer = Trainer(default_root_dir=tmpdir, **extra_params, auto_lr_find=True)

class TestModel(BoringModel):

def __init__(self, expected_state):
super().__init__()
self.expected_state = expected_state
Expand Down Expand Up @@ -78,6 +79,7 @@ def test_interrupt_state_on_keyboard_interrupt(tmpdir, extra_params):
model = BoringModel()

class InterruptCallback(Callback):

def on_batch_start(self, trainer, pl_module):
raise KeyboardInterrupt

Expand Down
21 changes: 19 additions & 2 deletions tests/trainer/test_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -1393,11 +1393,11 @@ def predict_dataloader(self):
return self._dataloaders


def predict(tmpdir, accelerator, gpus, num_processes, plugins=None, datamodule=True):
def predict(tmpdir, accelerator, gpus, num_processes, model=None, plugins=None, datamodule=True):

dataloaders = [torch.utils.data.DataLoader(RandomDataset(32, 2)), torch.utils.data.DataLoader(RandomDataset(32, 2))]

model = BoringModel()
model = model or BoringModel()
datamodule = TestLightningDataModule(dataloaders)

trainer = Trainer(
Expand All @@ -1422,6 +1422,23 @@ def predict(tmpdir, accelerator, gpus, num_processes, plugins=None, datamodule=T
assert results[0][0].shape == torch.Size([1, 2])


def test_trainer_predict_no_return(tmpdir):
"""
Test trainer.predict warns when nothing is returned
"""

class CustomBoringModel(BoringModel):

def predict(self, batch, batch_idx, dataloader_idx=None):
if (batch_idx + 1) % 2 == 0:
return

return super().predict(batch, batch_idx, dataloader_idx)

with pytest.warns(UserWarning, match='predict returned None'):
predict(tmpdir, None, None, 1, model=CustomBoringModel())


@pytest.mark.parametrize('datamodule', [False, True])
def test_trainer_predict_cpu(tmpdir, datamodule):
predict(tmpdir, None, None, 1, datamodule=datamodule)
Expand Down
Loading

0 comments on commit facfda8

Please sign in to comment.