diff --git a/CHANGELOG.md b/CHANGELOG.md
index ddaf4288a0202..5575f7aeaf3f3 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -199,6 +199,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Deprecated `is_overridden(model=...)` in favor of `is_overridden(instance=...)` ([#7918](https://github.com/PyTorchLightning/pytorch-lightning/pull/7918))
 
 
+- Deprecated automatically detaching returned extras with grads ([#7994](https://github.com/PyTorchLightning/pytorch-lightning/pull/7994))
+
+
 - Deprecated default value of `monitor` argument in EarlyStopping callback to enforce `monitor` as a required argument ([#7907](https://github.com/PyTorchLightning/pytorch-lightning/pull/7907))
 
 
diff --git a/pytorch_lightning/trainer/connectors/logger_connector/result.py b/pytorch_lightning/trainer/connectors/logger_connector/result.py
index cbc3dcfdefd98..ed20aafd982ea 100644
--- a/pytorch_lightning/trainer/connectors/logger_connector/result.py
+++ b/pytorch_lightning/trainer/connectors/logger_connector/result.py
@@ -26,11 +26,14 @@
 from pytorch_lightning.utilities.enums import LightningEnum
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from pytorch_lightning.utilities.metrics import metrics_to_scalars
+from pytorch_lightning.utilities.warnings import WarningCache
 
 # re-define the ones from pytorch_lightning.utilities.types without the `Number` type
 _METRIC = Union[Metric, torch.Tensor]
 _METRIC_COLLECTION = Union[_METRIC, Mapping[str, _METRIC]]
 
+warning_cache = WarningCache()
+
 
 class MetricSource(LightningEnum):
     CALLBACK = "callback"
@@ -279,9 +282,15 @@ def extra(self, extra: Mapping[str, Any]) -> None:
 
         def check_fn(v):
             if v.grad_fn is not None:
-                raise MisconfigurationException(f'You returned a tensor with `grad_fn`. The extra values are {extra}')
+                warning_cache.warn(
+                    f"One of the returned values {set(extra.keys())} has a `grad_fn`. We will detach it automatically"
+                    " but this behaviour will change in v1.6. Please detach it manually:"
+                    " `return {'loss': ..., 'something': something.detach()}`", DeprecationWarning
+                )
+                return v.detach()
+            return v
 
-        apply_to_collection(extra, torch.Tensor, check_fn)
+        extra = apply_to_collection(extra, torch.Tensor, check_fn)
         self['_extra'] = extra
 
     def log(
diff --git a/tests/deprecated_api/test_remove_1-6.py b/tests/deprecated_api/test_remove_1-6.py
index cb150cb013ec2..2fa54f3b253fb 100644
--- a/tests/deprecated_api/test_remove_1-6.py
+++ b/tests/deprecated_api/test_remove_1-6.py
@@ -212,3 +212,18 @@ def test_v1_6_0_early_stopping_monitor(tmpdir):
         " For backward compatibility, setting this to `early_stop_on`."
     ):
         EarlyStopping()
+
+
+def test_v1_6_0_extras_with_gradients(tmpdir):
+
+    class TestModel(BoringModel):
+
+        def training_step(self, *args):
+            loss = super().training_step(*args)['loss']
+            return {"loss": loss, 'foo': loss}
+
+    trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=1)
+    model = TestModel()
+    match = r"\{'foo'\} has a `grad_fn`.*behaviour will change in v1\.6"
+    with pytest.deprecated_call(match=match):
+        trainer.fit(model)
diff --git a/tests/trainer/logging_/test_train_loop_logging.py b/tests/trainer/logging_/test_train_loop_logging.py
index bff558e81b29e..8ca51b2dee3ef 100644
--- a/tests/trainer/logging_/test_train_loop_logging.py
+++ b/tests/trainer/logging_/test_train_loop_logging.py
@@ -690,17 +690,6 @@ def training_step(self, batch, batch_idx):
     with pytest.raises(MisconfigurationException, match='`self.log` with the key `foo/dataloader_idx_0`'):
         trainer.fit(model)
 
-    class TestModel(BoringModel):
-
-        def training_step(self, *args):
-            loss = super().training_step(*args)['loss']
-            return {"loss": loss, 'foo': loss}
-
-    trainer = Trainer(default_root_dir=tmpdir)
-    model = TestModel()
-    with pytest.raises(MisconfigurationException, match='You returned a tensor with `grad_fn`'):
-        trainer.fit(model)
-
     class TestModel(BoringModel):
 
         def training_step(self, *args):
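Summary of the user-facing change, with a minimal sketch (the `MyModel` class and its loss computation below are illustrative assumptions, not part of this patch): before this diff, returning an extra tensor that still carried a `grad_fn` from `training_step` raised a `MisconfigurationException`; after it, Lightning detaches the tensor automatically and emits a `DeprecationWarning` (deduplicated via `WarningCache`), and from v1.6 users are expected to detach such extras themselves.

    import torch
    from pytorch_lightning import LightningModule

    class MyModel(LightningModule):  # hypothetical module, for illustration only
        def __init__(self):
            super().__init__()
            self.layer = torch.nn.Linear(32, 2)

        def training_step(self, batch, batch_idx):
            loss = self.layer(batch).sum()  # stand-in for a real loss computation
            # Deprecated by this patch: 'foo' still carries a grad_fn, so
            # Lightning detaches it automatically and warns once per message.
            # return {"loss": loss, "foo": loss}
            # Recommended (and required from v1.6): detach extras manually.
            return {"loss": loss, "foo": loss.detach()}

Detaching in the setter (rather than raising) keeps old user code running during the deprecation window while still breaking the reference to the autograd graph, which would otherwise keep activations alive and leak memory across steps.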