diff --git a/pytorch_lightning/core/datamodule.py b/pytorch_lightning/core/datamodule.py index 23626ed9cbeae..84210e9d7b667 100644 --- a/pytorch_lightning/core/datamodule.py +++ b/pytorch_lightning/core/datamodule.py @@ -37,7 +37,7 @@ def __init__(self): def prepare_data(self): # download, split, etc... # only called on 1 GPU/TPU in distributed - def setup(self): + def setup(self, stage): # make assignments here (val/train/test split) # called on every process in DDP def train_dataloader(self):