From 4b8526a90af2c60b484d2383d8700f3a6949c9c9 Mon Sep 17 00:00:00 2001 From: otaj <6065855+otaj@users.noreply.github.com> Date: Wed, 10 Aug 2022 18:09:50 +0200 Subject: [PATCH] Fix a bug that caused spurious `AttributeError` when multiple `DataLoader` classes are imported (#14117) --- src/pytorch_lightning/CHANGELOG.md | 3 +++ src/pytorch_lightning/utilities/data.py | 10 +++++---- tests/tests_pytorch/utilities/test_data.py | 25 ++++++++++++++++++++++ 3 files changed, 34 insertions(+), 4 deletions(-) diff --git a/src/pytorch_lightning/CHANGELOG.md b/src/pytorch_lightning/CHANGELOG.md index 8c95200e02146..4139dc469dbd8 100644 --- a/src/pytorch_lightning/CHANGELOG.md +++ b/src/pytorch_lightning/CHANGELOG.md @@ -8,6 +8,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). ### Fixed +- Fixed a bug that caused spurious `AttributeError` when multiple `DataLoader` classes are imported ([#14117](https://github.com/Lightning-AI/lightning/pull/14117)) + + - Fixed epoch-end logging results not being reset after the end of the epoch ([#14061](https://github.com/Lightning-AI/lightning/pull/14061)) - Fixed saving hyperparameters in a composition where the parent class is not a `LightningModule` or `LightningDataModule` ([#14151](https://github.com/Lightning-AI/lightning/pull/14151)) diff --git a/src/pytorch_lightning/utilities/data.py b/src/pytorch_lightning/utilities/data.py index 862c7f2de905b..f2d3040125141 100644 --- a/src/pytorch_lightning/utilities/data.py +++ b/src/pytorch_lightning/utilities/data.py @@ -514,15 +514,17 @@ def _replace_init_method(base_cls: Type, store_explicit_arg: Optional[str] = Non It patches the ``__init__`` method. 
""" classes = _get_all_subclasses(base_cls) | {base_cls} - wrapped = set() for cls in classes: - if cls.__init__ not in wrapped: + # Check that __init__ belongs to the class + # https://stackoverflow.com/a/5253424 + if "__init__" in cls.__dict__: cls._old_init = cls.__init__ cls.__init__ = _wrap_init_method(cls.__init__, store_explicit_arg) - wrapped.add(cls.__init__) yield for cls in classes: - if hasattr(cls, "_old_init"): + # Check that _old_init belongs to the class + # https://stackoverflow.com/a/5253424 + if "_old_init" in cls.__dict__: cls.__init__ = cls._old_init del cls._old_init diff --git a/tests/tests_pytorch/utilities/test_data.py b/tests/tests_pytorch/utilities/test_data.py index ffb898efaa815..5b0087a245924 100644 --- a/tests/tests_pytorch/utilities/test_data.py +++ b/tests/tests_pytorch/utilities/test_data.py @@ -1,3 +1,4 @@ +import random from dataclasses import dataclass import pytest @@ -173,6 +174,30 @@ def __init__(self, randomize, *args, **kwargs): assert isinstance(new_dataloader, GoodImpl) +def test_replace_init_method_multiple_loaders_without_init(): + """In case of a class that inherits from a class that we are patching, but doesn't define its own `__init__` + method (the one we are wrapping), it can happen that `hasattr(cls, "_old_init")` is True because of parent + class, but it is impossible to delete, because that method is owned by parent class. Furthermore, the error + occurred only sometimes because it depends on the order in which we are iterating over a set of classes we are + patching. + + This test simulates the behavior by generating a sufficient number of dummy classes, which do not define `__init__` + and are children of `DataLoader`. We are testing that a) context manager `_replace_init_method` exits cleanly, and + b) the mechanism checking for presence of `_old_init` works as expected. 
+ """ + classes = [DataLoader] + for i in range(100): + classes.append(type(f"DataLoader_{i}", (random.choice(classes),), {})) + + with _replace_init_method(DataLoader, "dataset"): + for cls in classes[1:]: # First one is `DataLoader` + assert "_old_init" not in cls.__dict__ + assert hasattr(cls, "_old_init") + + assert "_old_init" in DataLoader.__dict__ + assert hasattr(DataLoader, "_old_init") + + class DataLoaderSubclass1(DataLoader): def __init__(self, attribute1, *args, **kwargs): self.at1 = attribute1