From b947f07bfcb930dd4c5c1795d908d20c3a259bb5 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Sun, 22 Aug 2021 02:12:49 +0200 Subject: [PATCH] Fix `_PatchDataLoader` --- pytorch_lightning/trainer/connectors/data_connector.py | 10 +++++++--- tests/trainer/test_dataloaders.py | 4 ++++ 2 files changed, 11 insertions(+), 3 deletions(-) diff --git a/pytorch_lightning/trainer/connectors/data_connector.py b/pytorch_lightning/trainer/connectors/data_connector.py index 7526faab547e1..6695bd6e6b5db 100644 --- a/pytorch_lightning/trainer/connectors/data_connector.py +++ b/pytorch_lightning/trainer/connectors/data_connector.py @@ -190,10 +190,14 @@ def __init__(self, dataloader: Union[TRAIN_DATALOADERS, EVAL_DATALOADERS], stage def __call__(self) -> Union[TRAIN_DATALOADERS, EVAL_DATALOADERS]: return self.dataloader + @property + def __name__(self) -> str: + return self.stage + "_dataloader" + def patch(self, model: "pl.LightningModule") -> None: - self._old_loader = getattr(model, self.stage + "_dataloader") - setattr(model, self.stage + "_dataloader", self) + self._old_loader = getattr(model, self.__name__) + setattr(model, self.__name__, self) def unpatch(self, model: "pl.LightningModule") -> None: - setattr(model, self.stage + "_dataloader", self._old_loader) + setattr(model, self.__name__, self._old_loader) self._old_loader = None diff --git a/tests/trainer/test_dataloaders.py b/tests/trainer/test_dataloaders.py index e6686cf8117e0..1d42a13ace9ab 100644 --- a/tests/trainer/test_dataloaders.py +++ b/tests/trainer/test_dataloaders.py @@ -1477,6 +1477,10 @@ def __init__(self, loader): def __call__(self): return self.loader + def __name__(self): + # `_PatchDataLoader` requires this when passed to `trainer.call_hook` + return "foo" + class TestModel(BoringModel): def __init__(self): super().__init__()