Skip to content

Commit

Permalink
Fix _PatchDataLoader
Browse files Browse the repository at this point in the history
  • Loading branch information
carmocca committed Aug 22, 2021
1 parent b13ca7d commit b947f07
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 3 deletions.
10 changes: 7 additions & 3 deletions pytorch_lightning/trainer/connectors/data_connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,10 +190,14 @@ def __init__(self, dataloader: Union[TRAIN_DATALOADERS, EVAL_DATALOADERS], stage
def __call__(self) -> Union[TRAIN_DATALOADERS, EVAL_DATALOADERS]:
return self.dataloader

@property
def __name__(self) -> str:
return self.stage + "_dataloader"

def patch(self, model: "pl.LightningModule") -> None:
self._old_loader = getattr(model, self.stage + "_dataloader")
setattr(model, self.stage + "_dataloader", self)
self._old_loader = getattr(model, self.__name__)
setattr(model, self.__name__, self)

def unpatch(self, model: "pl.LightningModule") -> None:
setattr(model, self.stage + "_dataloader", self._old_loader)
setattr(model, self.__name__, self._old_loader)
self._old_loader = None
4 changes: 4 additions & 0 deletions tests/trainer/test_dataloaders.py
Original file line number Diff line number Diff line change
Expand Up @@ -1477,6 +1477,10 @@ def __init__(self, loader):
def __call__(self):
return self.loader

def __name__(self):
# `_PatchDataLoader` requires this when passed to `trainer.call_hook`
return "foo"

class TestModel(BoringModel):
def __init__(self):
super().__init__()
Expand Down

0 comments on commit b947f07

Please sign in to comment.