Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Deprecate returning extras with grads #7994

Merged
merged 5 commits into from
Jun 18, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -199,6 +199,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Deprecated `is_overridden(model=...)` in favor of `is_overridden(instance=...)` ([#7918](https://github.com/PyTorchLightning/pytorch-lightning/pull/7918))


- Deprecated automatically detaching returned extras with grads ([#7994](https://github.com/PyTorchLightning/pytorch-lightning/pull/7994))


- Deprecated default value of `monitor` argument in EarlyStopping callback to enforce `monitor` as a required argument ([#7907](https://github.com/PyTorchLightning/pytorch-lightning/pull/7907))


Expand Down
13 changes: 11 additions & 2 deletions pytorch_lightning/trainer/connectors/logger_connector/result.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,11 +26,14 @@
from pytorch_lightning.utilities.enums import LightningEnum
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.metrics import metrics_to_scalars
from pytorch_lightning.utilities.warnings import WarningCache

# re-define the ones from pytorch_lightning.utilities.types without the `Number` type
_METRIC = Union[Metric, torch.Tensor]
_METRIC_COLLECTION = Union[_METRIC, Mapping[str, _METRIC]]

warning_cache = WarningCache()


class MetricSource(LightningEnum):
CALLBACK = "callback"
Expand Down Expand Up @@ -279,9 +282,15 @@ def extra(self, extra: Mapping[str, Any]) -> None:

def check_fn(v):
if v.grad_fn is not None:
raise MisconfigurationException(f'You returned a tensor with `grad_fn`. The extra values are {extra}')
warning_cache.warn(
f"One of the returned values {set(extra.keys())} has a `grad_fn`. We will detach it automatically"
" but this behaviour will change in v1.6. Please detach it manually:"
" `return {'loss': ..., 'something': something.detach()}`", DeprecationWarning
)
return v.detach()
return v

apply_to_collection(extra, torch.Tensor, check_fn)
extra = apply_to_collection(extra, torch.Tensor, check_fn)
self['_extra'] = extra

def log(
Expand Down
15 changes: 15 additions & 0 deletions tests/deprecated_api/test_remove_1-6.py
Original file line number Diff line number Diff line change
Expand Up @@ -212,3 +212,18 @@ def test_v1_6_0_early_stopping_monitor(tmpdir):
" For backward compatibility, setting this to `early_stop_on`."
):
EarlyStopping()


def test_v1_6_0_extras_with_gradients(tmpdir):

class TestModel(BoringModel):

def training_step(self, *args):
loss = super().training_step(*args)['loss']
return {"loss": loss, 'foo': loss}

trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=1)
model = TestModel()
match = r"\{'foo'\} has a `grad_fn`.*behaviour will change in v1\.6"
with pytest.deprecated_call(match=match):
trainer.fit(model)
11 changes: 0 additions & 11 deletions tests/trainer/logging_/test_train_loop_logging.py
Original file line number Diff line number Diff line change
Expand Up @@ -690,17 +690,6 @@ def training_step(self, batch, batch_idx):
with pytest.raises(MisconfigurationException, match='`self.log` with the key `foo/dataloader_idx_0`'):
trainer.fit(model)

class TestModel(BoringModel):

def training_step(self, *args):
loss = super().training_step(*args)['loss']
return {"loss": loss, 'foo': loss}

trainer = Trainer(default_root_dir=tmpdir)
model = TestModel()
with pytest.raises(MisconfigurationException, match='You returned a tensor with `grad_fn`'):
trainer.fit(model)

class TestModel(BoringModel):

def training_step(self, *args):
Expand Down