diff --git a/pytorch_lightning/trainer/data_loading.py b/pytorch_lightning/trainer/data_loading.py
index d66ed38a5604b..c0c66948ad524 100644
--- a/pytorch_lightning/trainer/data_loading.py
+++ b/pytorch_lightning/trainer/data_loading.py
@@ -227,17 +227,29 @@ def _reset_eval_dataloader(self, model: LightningModule, mode: str) -> Tuple[Uni
         Returns:
             Tuple (num_batches, dataloaders)
         """
-        dataloaders = self.request_dataloader(getattr(model, f'{mode}_dataloader'))
+        # use the training loader as val and test when overfitting
+        if self.overfit_pct > 0:
+            dataloaders = self.request_dataloader(getattr(model, f'train_dataloader'))
+        else:
+            dataloaders = self.request_dataloader(getattr(model, f'{mode}_dataloader'))
 
         if not isinstance(dataloaders, list):
             dataloaders = [dataloaders]
 
-        # shuffling in val and test set is bad practice
         for loader in dataloaders:
+            # shuffling in val and test set is bad practice
             if mode in ('val', 'test') and hasattr(loader, 'sampler') and isinstance(loader.sampler, RandomSampler):
-                rank_zero_warn(
-                    f'Your {mode}_dataloader has shuffle=True, it is best practice to turn'
-                    ' this off for validation and test dataloaders.')
+
+                # when overfitting, remove the randomsampler for the train loaders
+                if self.overfit_pct > 0:
+                    rank_zero_warn(
+                        f'You requested to overfit but enabled training Dataloader shuffling. Disabling it for you'
+                    )
+                    loader.sampler = SequentialSampler(loader.dataset)
+                else:
+                    rank_zero_warn(
+                        f'Your {mode}_dataloader has shuffle=True, it is best practice to turn'
+                        ' this off for validation and test dataloaders.')
 
         if any([dl is None for dl in dataloaders]):
            rank_zero_warn("One of given dataloaders is None and it will be skipped.")
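
For context, here is a minimal standalone sketch of the de-shuffling idea this hunk applies when overfitting, using a plain torch DataLoader (the dataset and loader names below are illustrative, not taken from this patch). Note that some torch versions refuse assignment to `loader.sampler` after the DataLoader has been constructed, so the sketch rebuilds the loader with a SequentialSampler instead of mutating it in place.

```python
# Minimal sketch (assumption: plain torch; names are illustrative).
# Some torch versions raise if `loader.sampler` is assigned post-construction,
# so de-shuffling is shown here by rebuilding the DataLoader.
import torch
from torch.utils.data import DataLoader, SequentialSampler, TensorDataset

dataset = TensorDataset(torch.arange(10).float())
shuffled_loader = DataLoader(dataset, batch_size=2, shuffle=True)

# rebuild with a deterministic, in-order sampler for the overfit val/test passes
sequential_loader = DataLoader(
    shuffled_loader.dataset,
    batch_size=shuffled_loader.batch_size,
    sampler=SequentialSampler(shuffled_loader.dataset),
)

for (batch,) in sequential_loader:
    print(batch)  # batches arrive in dataset order: [0., 1.], [2., 3.], ...
```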