Skip to content

Commit

Permalink
Fix a bug that caused spurious AttributeError when multiple `DataLo…
Browse files Browse the repository at this point in the history
…ader` classes are imported (#14117)
  • Loading branch information
otaj authored and awaelchli committed Aug 11, 2022
1 parent fbe63d2 commit 4b8526a
Show file tree
Hide file tree
Showing 3 changed files with 34 additions and 4 deletions.
3 changes: 3 additions & 0 deletions src/pytorch_lightning/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Fixed

- Fixed a bug that caused spurious `AttributeError` when multiple `DataLoader` classes are imported ([#14117](https://github.com/Lightning-AI/lightning/pull/14117))


- Fixed epoch-end logging results not being reset after the end of the epoch ([#14061](https://github.com/Lightning-AI/lightning/pull/14061))
- Fixed saving hyperparameters in a composition where the parent class is not a `LightningModule` or `LightningDataModule` ([#14151](https://github.com/Lightning-AI/lightning/pull/14151))

Expand Down
10 changes: 6 additions & 4 deletions src/pytorch_lightning/utilities/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -514,15 +514,17 @@ def _replace_init_method(base_cls: Type, store_explicit_arg: Optional[str] = Non
It patches the ``__init__`` method.
"""
classes = _get_all_subclasses(base_cls) | {base_cls}
wrapped = set()
for cls in classes:
if cls.__init__ not in wrapped:
# Check that __init__ belongs to the class
# https://stackoverflow.com/a/5253424
if "__init__" in cls.__dict__:
cls._old_init = cls.__init__
cls.__init__ = _wrap_init_method(cls.__init__, store_explicit_arg)
wrapped.add(cls.__init__)
yield
for cls in classes:
if hasattr(cls, "_old_init"):
# Check that _old_init belongs to the class
# https://stackoverflow.com/a/5253424
if "_old_init" in cls.__dict__:
cls.__init__ = cls._old_init
del cls._old_init

Expand Down
25 changes: 25 additions & 0 deletions tests/tests_pytorch/utilities/test_data.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import random
from dataclasses import dataclass

import pytest
Expand Down Expand Up @@ -173,6 +174,30 @@ def __init__(self, randomize, *args, **kwargs):
assert isinstance(new_dataloader, GoodImpl)


def test_replace_init_method_multiple_loaders_without_init():
"""In case of a class, that inherits from a class that we are patching, but doesn't define its own `__init__`
method (the one we are wrapping), it can happen, that `hasattr(cls, "_old_init")` is True because of parent
class, but it is impossible to delete, because that method is owned by parent class. Furthermore, the error
occured only sometimes because it depends on the order in which we are iterating over a set of classes we are
patching.
This test simulates the behavior by generating sufficient number of dummy classes, which do not define `__init__`
and are children of `DataLoader`. We are testing that a) context manager `_replace_init_method` exits cleanly, and
b) the mechanism checking for presence of `_old_init` works as expected.
"""
classes = [DataLoader]
for i in range(100):
classes.append(type(f"DataLoader_{i}", (random.choice(classes),), {}))

with _replace_init_method(DataLoader, "dataset"):
for cls in classes[1:]: # First one is `DataLoader`
assert "_old_init" not in cls.__dict__
assert hasattr(cls, "_old_init")

assert "_old_init" in DataLoader.__dict__
assert hasattr(DataLoader, "_old_init")


class DataLoaderSubclass1(DataLoader):
def __init__(self, attribute1, *args, **kwargs):
self.at1 = attribute1
Expand Down

0 comments on commit 4b8526a

Please sign in to comment.