diff --git a/CHANGELOG.md b/CHANGELOG.md
index 2ec7c9762a34a..233d54476889b 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -9,6 +9,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 ### Added
 
+- Add support for multiple train loaders ([#1959](https://github.com/PyTorchLightning/pytorch-lightning/pull/1959))
+
 - `Accuracy` metric now generalizes to Top-k accuracy for (multi-dimensional) multi-class inputs using the `top_k` parameter ([#4838](https://github.com/PyTorchLightning/pytorch-lightning/pull/4838))
 
 - `Accuracy` metric now enables the computation of subset accuracy for multi-label or multi-dimensional multi-class inputs with the `subset_accuracy` parameter ([#4838](https://github.com/PyTorchLightning/pytorch-lightning/pull/4838))
diff --git a/docs/source/multiple_loaders.rst b/docs/source/multiple_loaders.rst
index ee7b32555c53f..fb1aa33f80462 100644
--- a/docs/source/multiple_loaders.rst
+++ b/docs/source/multiple_loaders.rst
@@ -9,14 +9,16 @@ Multiple Datasets
 Lightning supports multiple dataloaders in a few ways.
 
 1. Create a dataloader that iterates multiple datasets under the hood.
-2. In the validation and test loop you also have the option to return multiple dataloaders
+2. In the training loop you can pass multiple loaders as a dict or list/tuple and lightning
+   will automatically combine the batches from the different loaders.
+3. In the validation and test loop you also have the option to return multiple dataloaders
    which lightning will call sequentially.
 
 ----------
 
 Multiple training dataloaders
 -----------------------------
-For training, the best way to use multiple dataloaders is to create a ``DataLoader`` class
+For training, the usual way to use multiple dataloaders is to create a ``DataLoader`` class
 which wraps your multiple dataloaders (this of course also works for testing and validation
 dataloaders).
 
@@ -59,6 +61,31 @@ dataloaders).
             # SAME
             ...
 
+However, with lightning you can also return multiple loaders directly and lightning will take care of combining the batches from the different loaders.
+
+For more details, please have a look at :attr:`~pytorch_lightning.trainer.trainer.Trainer.multiple_trainloader_mode`.
+
+.. testcode::
+
+    class LitModel(LightningModule):
+
+        def train_dataloader(self):
+
+            loader_a = torch.utils.data.DataLoader(range(6), batch_size=4)
+            loader_b = torch.utils.data.DataLoader(range(15), batch_size=5)
+
+            # pass loaders as a dict. This will create batches like this:
+            # {'a': batch from loader_a, 'b': batch from loader_b}
+            loaders = {'a': loader_a,
+                       'b': loader_b}
+
+            # OR:
+            # pass loaders as a sequence. 
This will create batches like this: + # [batch from loader_a, batch from loader_b] + loaders = [loader_a, loader_b] + + return loaders + ---------- Test/Val dataloaders diff --git a/pytorch_lightning/trainer/data_loading.py b/pytorch_lightning/trainer/data_loading.py index f8452067e0c67..3db83c415aded 100644 --- a/pytorch_lightning/trainer/data_loading.py +++ b/pytorch_lightning/trainer/data_loading.py @@ -29,6 +29,9 @@ from pytorch_lightning.utilities.exceptions import MisconfigurationException from pytorch_lightning.utilities.model_helpers import is_overridden +from pytorch_lightning.utilities.apply_func import apply_to_collection +from pytorch_lightning.trainer.supporters import CombinedLoader + class TrainerDataLoadingMixin(ABC): @@ -137,6 +140,7 @@ def reset_train_dataloader(self, model: LightningModule) -> None: model: The current `LightningModule` """ self.train_dataloader = self.request_dataloader(model.train_dataloader) + if (self.overfit_batches > 0): if hasattr(self.train_dataloader, 'sampler') and isinstance(self.train_dataloader.sampler, RandomSampler): rank_zero_warn('You requested to overfit but enabled training dataloader shuffling.' @@ -147,13 +151,17 @@ def reset_train_dataloader(self, model: LightningModule) -> None: # debugging self.dev_debugger.track_load_dataloader_call('train_dataloader', dataloaders=[self.train_dataloader]) - self.num_training_batches = 0 - # automatically add samplers - self.train_dataloader = self.auto_add_sampler(self.train_dataloader, shuffle=True) + self.train_dataloader = apply_to_collection( + self.train_dataloader, DataLoader, self.auto_add_sampler, shuffle=True) + + # check the workers recursively + apply_to_collection(self.train_dataloader, DataLoader, self._worker_check, 'train dataloader') + + # wrap the sequence of train loaders to a CombinedLoader object for computing the num_training_batches + self.train_dataloader = CombinedLoader(self.train_dataloader, self._multiple_trainloader_mode) self.num_training_batches = len(self.train_dataloader) if has_len(self.train_dataloader) else float('inf') - self._worker_check(self.train_dataloader, 'train dataloader') if isinstance(self.limit_train_batches, int) or self.limit_train_batches == 0.0: self.num_training_batches = min(self.num_training_batches, int(self.limit_train_batches)) diff --git a/pytorch_lightning/trainer/supporters.py b/pytorch_lightning/trainer/supporters.py index 445ddbd87686c..81d4f0cfcbcf1 100644 --- a/pytorch_lightning/trainer/supporters.py +++ b/pytorch_lightning/trainer/supporters.py @@ -18,6 +18,11 @@ import torch from pytorch_lightning.utilities.cloud_io import get_filesystem from torch import Tensor +from torch.utils.data import Dataset +from pytorch_lightning.utilities.apply_func import apply_to_collection +from pytorch_lightning.utilities.data import get_len +from collections.abc import Iterable, Iterator, Mapping, Sequence +from typing import Any, Union class TensorRunningAccum(object): @@ -176,3 +181,321 @@ def to_disk(self) -> None: # Write predictions for current file to disk with fs.open(filepath, "wb") as fp: torch.save(outputs, fp) + + +class CycleIterator(object): + """ + Iterator for restarting a dataloader if it runs out of samples + """ + def __init__(self, loader: Any, length: Optional[int] = None): + """ + + Args: + loader: the loader to restart for cyclic (and optionally infinite) sampling + length: the number of batches to sample (with restarted loaders if necessary) before raising StopIteration + if None: infinite + + """ + if length is None: + length 
= float('inf')
+
+        self.length = length
+        self.loader = loader
+        self._loader_iter = None
+        self.counter = 0
+
+    def __iter__(self) -> Any:
+        """
+
+        Creates the internal iterator and returns self
+
+        Returns:
+            CycleIterator: self
+
+        """
+        self.counter = 0
+        self._loader_iter = iter(self.loader)
+        return self
+
+    def __next__(self) -> Any:
+        """
+        Fetches the next batch from the internal dataloader and restarts
+        it if necessary
+
+        Returns:
+            Any: the resulting batch
+
+        Raises:
+            StopIteration: if more than :attr:`length` batches have been returned
+
+        """
+        # Note: if self.length is `inf`, then the iterator will never stop
+        if self.counter >= self.__len__():
+            raise StopIteration
+
+        try:
+            return next(self._loader_iter)
+
+        except StopIteration:
+            self._loader_iter = iter(self.loader)
+            return next(self._loader_iter)
+
+        finally:
+            self.counter += 1
+
+    def __len__(self) -> Union[int, float]:
+        return self.length
+
+
+class CombinedDataset(object):
+    """
+    Combine multiple datasets and compute their statistics
+    """
+    def __init__(self, datasets: Union[Sequence, Mapping], mode: str = 'min_size'):
+        """
+
+        Args:
+            datasets: a sequence/mapping of datasets. Can be a collection of torch.utils.data.Dataset,
+                Iterable or even None.
+            mode: whether to use the minimum ('min_size') or the maximum ('max_size_cycle')
+                number of batches across all datasets.
+
+        """
+        self.datasets = datasets
+        self.mode = mode
+
+    @property
+    def max_len(self) -> Union[int, float]:
+        """Return the maximum length of the datasets."""
+        return self._calc_num_data(self.datasets, 'max_size_cycle')
+
+    @property
+    def min_len(self) -> Union[int, float]:
+        """Return the minimum length of the datasets."""
+        return self._calc_num_data(self.datasets, 'min_size')
+
+    @staticmethod
+    def _calc_num_data(datasets: Union[Sequence, Mapping], mode: str) -> Union[int, float]:
+        """
+        Compute the length of `CombinedDataset` according to the `mode`.
+
+        Args:
+            datasets: a sequence/mapping of datasets. Can be a collection of torch.utils.data.Dataset,
+                Iterable or even None.
+            mode: determines whether the length of `CombinedDataset` is the maximum or the minimum of
+                the dataset lengths.
+
+        Returns:
+            length: the length of `CombinedDataset`
+
+        """
+        if mode not in ['min_size', 'max_size_cycle']:
+            raise ValueError(f"Invalid Mode: {mode}")
+
+        # extract the lengths
+        all_lengths = apply_to_collection(datasets, (Dataset, Iterable, type(None)), get_len,
+                                          wrong_dtype=(Sequence, Mapping))
+
+        # pick the reduction that matches the mode
+        compute_func = {'min_size': min, 'max_size_cycle': max}[mode]
+
+        if isinstance(all_lengths, (int, float)):
+            length = all_lengths
+
+        elif isinstance(all_lengths, Mapping):
+            length = compute_func(all_lengths.values())
+
+        elif isinstance(all_lengths, Sequence):
+            length = compute_func(all_lengths)
+
+        return length
+
+    def __len__(self) -> int:
+        """Return the length of the datasets according to the mode."""
+        return self._calc_num_data(self.datasets, self.mode)
+
+
+class CombinedLoader(object):
+    """
+    Combines different dataloaders and allows sampling in parallel.
+
+    Supported modes are 'min_size', which raises StopIteration after the shortest loader
+    (the one with the lowest number of batches) is done, and 'max_size_cycle', which raises
+    StopIteration after the longest loader (the one with the most batches) is done, while cycling
+    through the shorter loaders.
+
+    Examples:
+        >>> loaders = {'a': torch.utils.data.DataLoader(range(6), batch_size=4),
+        ...            'b': torch.utils.data.DataLoader(range(15), batch_size=5)}
+        >>> combined_loader = CombinedLoader(loaders, 'max_size_cycle')
+        >>> for item in combined_loader:
+        ...     print(item)
+        {'a': tensor([0, 1, 2, 3]), 'b': tensor([0, 1, 2, 3, 4])}
+        {'a': tensor([4, 5]), 'b': tensor([5, 6, 7, 8, 9])}
+        {'a': tensor([0, 1, 2, 3]), 'b': tensor([10, 11, 12, 13, 14])}
+        >>> combined_loader = CombinedLoader(loaders, 'min_size')
+        >>> for item in combined_loader:
+        ...     
print(item) + {'a': tensor([0, 1, 2, 3]), 'b': tensor([0, 1, 2, 3, 4])} + {'a': tensor([4, 5]), 'b': tensor([5, 6, 7, 8, 9])} + + """ + SUPPORTED_MODES = ('min_size', 'max_size_cycle') + + def __init__(self, loaders: Any, mode: str = 'min_size'): + """ + + Args: + loaders: the loaders to sample from. Can be all kind of collection + mode: the mode. Supported are 'min_size' which stops if the shortest loader is exhausted and + 'max_size_cycle' which stops if the longest loader is exhausted and cycles through the smaller ones. + + """ + self.loaders = loaders + + datasets = apply_to_collection(self.loaders, Iterable, getattr, 'dataset', None, + wrong_dtype=(Sequence, Mapping)) + # could be multiple datasets, but use self.dataset to follow the name convention in DataLoader + self.dataset = CombinedDataset(datasets, mode) + + if mode not in self.SUPPORTED_MODES: + raise ValueError(f"Invalid Mode: {mode}") + + self.mode = mode + + if self.mode == 'max_size_cycle': + self._wrap_loaders_max_size_cycle() + + @property + def sampler(self) -> Union[Iterable, Sequence, Mapping]: + """Return a collections of samplers extracting from loaders.""" + return apply_to_collection(self.loaders, Iterable, getattr, 'sampler', None, + wrong_dtype=(Sequence, Mapping)) + + def _wrap_loaders_max_size_cycle(self) -> Any: + """ + Wraps all loaders to make sure they are cycled until the longest loader is exhausted + + Returns: + Any: the wrapped loaders + + """ + all_lengths = apply_to_collection(self.loaders, Iterable, get_len, + wrong_dtype=(Sequence, Mapping)) + + if isinstance(all_lengths, (int, float)): + length = all_lengths + + elif isinstance(all_lengths, Mapping): + length = max(all_lengths.values()) + + elif isinstance(all_lengths, Sequence): + length = max(all_lengths) + + if isinstance(self.loaders, Mapping): + self.loaders = type(self.loaders)({k: CycleIterator(v, length=length) + for k, v in self.loaders.items()}) + + elif isinstance(self.loaders, Sequence): + self.loaders = type(self.loaders)([CycleIterator(v, length=length) + for v in self.loaders]) + + # dataloaders are iterable but not sequence + elif isinstance(self.loaders, Iterable): + # only one dataloader, just keep it the same. + pass + else: + raise ValueError(f'Invalid Datatype for loaders: {type(self.loaders).__name__}') + + def __iter__(self) -> Any: + """ + Create and return an iterator, `CombinedLoaderIterator`, for the combined loader. + """ + return CombinedLoaderIterator(self.loaders) + + @staticmethod + def _calc_num_batches(loaders: Any) -> Union[int, float]: + """ + Compute the length (aka the number of batches) of `CombinedLoader`. + + Args: + loaders: a collections of loaders. + + Returns: + length: the minimum length of loaders + + """ + all_lengths = apply_to_collection(loaders, Iterable, get_len, + wrong_dtype=(Sequence, Mapping)) + + if isinstance(all_lengths, (int, float)): + return all_lengths + + elif isinstance(all_lengths, Mapping): + return min(all_lengths.values()) + + elif isinstance(all_lengths, Sequence): + return min(all_lengths) + + raise TypeError(f'Got Type {type(all_lengths).__name__}, but expected one of Sequence, int or Mapping') + + def __len__(self) -> int: + return self._calc_num_batches(self.loaders) + + +class CombinedLoaderIterator(object): + """ + Custom Iterator returning data from multple loaders, and allows sampling in parallel + """ + def __init__(self, loaders: Any): + """ + + Args: + loaders: the loaders to sample from. 
Can be all kind of collection + + """ + self.loaders = loaders + self._loader_iters = None + + @property + def loader_iters(self) -> Any: + """ + Get the `_loader_iters` and create one if it is None. + """ + if self._loader_iters is None: + self._loader_iters = self.create_loader_iters(self.loaders) + + return self._loader_iters + + def __iter__(self) -> Any: + return self + + def __next__(self) -> Any: + """ + Fetches the next batch from multiple data loaders + + Returns: + Any: a collections of batch data + + """ + return self.request_next_batch(self.loader_iters) + + @staticmethod + def request_next_batch(loader_iters: Union[Iterator, Sequence, Mapping]) -> Any: + """ + Return the batch of data from multiple iterators. + + Args: + loader_iters: a collections of iterators + + Returns + Any: a collections of batch data + + """ + return apply_to_collection(loader_iters, Iterator, next) + + @staticmethod + def create_loader_iters( + loaders: Union[Any, Iterator, Sequence, Mapping] + ) -> Union[Any, Iterator, Sequence, Mapping]: + """ + Create and return a collection of iterators from loaders. + + Args: + loaders: a collections of loaders + + Returns + a collections of iterators + + """ + # dataloaders are Iterable but not Sequences. Need this to specifically exclude sequences + return apply_to_collection(loaders, Iterable, iter, wrong_dtype=(Sequence, Mapping)) diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index 62d7deb0eb378..014e0a62679dd 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -134,6 +134,7 @@ def __init__( automatic_optimization: Optional[bool] = None, move_metrics_to_cpu: bool = False, enable_pl_optimizer: bool = True, + multiple_trainloader_mode: str = 'max_size_cycle', ): r""" Customize every aspect of training via flags @@ -282,6 +283,11 @@ def __init__( enable_pl_optimizer: If True, each optimizer will be wrapped by `pytorch_lightning.core.optimizer.LightningOptimizer`. It allows Lightning to handle AMP, TPU, accumulated_gradients, etc.. + + multiple_trainloader_mode: How to loop over the datasets when there are multiple train loaders. + In 'max_size_cycle' mode, the trainer ends one epoch when the largest dataset is traversed, + and smaller datasets reload when running out of their data. In 'min_size' mode, all the datasets + reload when reaching the minimum length of datasets. 
""" super().__init__() self._device_type = DeviceType.CPU @@ -305,7 +311,7 @@ def __init__( self.tuner = Tuner(self) self.accelerator_backend = None self.evaluation_loop = EvaluationLoop(self) - self.train_loop = TrainLoop(self) + self.train_loop = TrainLoop(self, multiple_trainloader_mode) self.plugin_connector = PluginConnector(self) # training state diff --git a/pytorch_lightning/trainer/training_loop.py b/pytorch_lightning/trainer/training_loop.py index 5dbc53ddde332..4b2d293aadf1a 100644 --- a/pytorch_lightning/trainer/training_loop.py +++ b/pytorch_lightning/trainer/training_loop.py @@ -34,7 +34,7 @@ class TrainLoop: - def __init__(self, trainer): + def __init__(self, trainer, multiple_trainloader_mode): self.trainer = trainer self.early_stopping_accumulator = None self.checkpoint_accumulator = None @@ -45,6 +45,8 @@ def __init__(self, trainer): self.automatic_optimization = True self._curr_step_result = None self._cur_grad_norm_dict = None + self._multiple_trainloader_mode = multiple_trainloader_mode + self.trainer._multiple_trainloader_mode = multiple_trainloader_mode def on_trainer_init( self, max_epochs, min_epochs, max_steps, min_steps, num_sanity_val_steps, automatic_optimization @@ -545,10 +547,10 @@ def run_training_epoch(self): # track epoch output epoch_output = [[] for _ in range(self.num_optimizers)] - # enable profiling for the dataloader train_dataloader = self.trainer.data_connector.get_profiled_train_dataloader(train_dataloader) dataloader_idx = 0 should_check_val = False + for batch_idx, (batch, is_last_batch) in train_dataloader: self.trainer.batch_idx = batch_idx diff --git a/pytorch_lightning/utilities/apply_func.py b/pytorch_lightning/utilities/apply_func.py index afcc4dd36faf5..95edb16c27b00 100644 --- a/pytorch_lightning/utilities/apply_func.py +++ b/pytorch_lightning/utilities/apply_func.py @@ -15,7 +15,7 @@ from abc import ABC from collections.abc import Mapping, Sequence from copy import copy -from typing import Any, Callable, Union +from typing import Any, Callable, Union, Optional import torch @@ -27,7 +27,8 @@ Batch = type(None) -def apply_to_collection(data: Any, dtype: Union[type, tuple], function: Callable, *args, **kwargs) -> Any: +def apply_to_collection(data: Any, dtype: Union[type, tuple], function: Callable, *args, + wrong_dtype: Optional[Union[type, tuple]] = None, **kwargs) -> Any: """ Recursively applies a function to all elements of a certain dtype. 
@@ -36,6 +37,8 @@ def apply_to_collection(data: Any, dtype: Union[type, tuple], function: Callable
         dtype: the given function will be applied to all elements of this dtype
         function: the function to apply
         *args: positional arguments (will be forwarded to calls of ``function``)
+        wrong_dtype: the given function won't be applied if this type is specified and the given collection is of
+            the :attr:`wrong_dtype` even if it is of type :attr:`dtype`
         **kwargs: keyword arguments (will be forwarded to calls of ``function``)
 
     Returns:
@@ -45,7 +48,7 @@ def apply_to_collection(data: Any, dtype: Union[type, tuple], function: Callable
     elem_type = type(data)
 
     # Breaking condition
-    if isinstance(data, dtype):
+    if isinstance(data, dtype) and (wrong_dtype is None or not isinstance(data, wrong_dtype)):
         return function(data, *args, **kwargs)
 
     # Recursively apply to collection items
diff --git a/pytorch_lightning/utilities/data.py b/pytorch_lightning/utilities/data.py
index 54f81f20f9ab7..1b4907ab8c2d4 100644
--- a/pytorch_lightning/utilities/data.py
+++ b/pytorch_lightning/utilities/data.py
@@ -17,6 +17,7 @@
 from torch.utils.data import DataLoader, IterableDataset
 
 from pytorch_lightning.utilities import rank_zero_warn
+from typing import Union
 
 
 def has_iterable_dataset(dataloader: DataLoader):
@@ -45,3 +46,12 @@ def has_len(dataloader: DataLoader) -> bool:
             ' this can lead to unintended side effects since the samples will be duplicated.'
         )
     return has_len
+
+
+def get_len(dataloader: DataLoader) -> Union[int, float]:
+    """ Return the length of the given DataLoader. If the ``__len__`` method is not implemented, return float('inf'). """
+
+    if has_len(dataloader):
+        return len(dataloader)
+
+    return float('inf')
diff --git a/tests/base/model_train_dataloaders.py b/tests/base/model_train_dataloaders.py
index ad980f14fe95c..65873cfa8d6c4 100644
--- a/tests/base/model_train_dataloaders.py
+++ b/tests/base/model_train_dataloaders.py
@@ -20,7 +20,7 @@ class TrainDataloaderVariations(ABC):
 
     @abstractmethod
-    def dataloader(self, train: bool):
+    def dataloader(self, train: bool, *args, **kwargs):
         """placeholder"""
 
     def train_dataloader(self):
@@ -37,3 +37,12 @@ def train_dataloader__zero_length(self):
         dataloader.dataset.data = dataloader.dataset.data[:0]
         dataloader.dataset.targets = dataloader.dataset.targets[:0]
         return dataloader
+
+    def train_dataloader__multiple_mapping(self):
+        """Return a mapping of loaders with different lengths"""
+        return {'a': self.dataloader(train=True, num_samples=100),
+                'b': self.dataloader(train=True, num_samples=50)}
+
+    def train_dataloader__multiple_sequence(self):
+        return [self.dataloader(train=True, num_samples=100),
+                self.dataloader(train=True, num_samples=50)]
diff --git a/tests/base/model_train_steps.py b/tests/base/model_train_steps.py
index 773d8cd2ad80f..e12d004db8f98 100644
--- a/tests/base/model_train_steps.py
+++ b/tests/base/model_train_steps.py
@@ -79,3 +79,99 @@ def training_step__result_obj_dp(self, batch, batch_idx, optimizer_idx=None):
         self.log('train_some_val', log_train * log_train)
 
         return loss_train
+
+    def training_step_end_full_loop_result_obj_dp(self, result):
+        """
+        Full loop flow train step (result obj + dp)
+        """
+        result.minimize = result.minimize.mean()
+        result.checkpoint_on = result.checkpoint_on.mean()
+        result.train_step_metric = result.train_step_metric.mean()
+        result.log('train_step_end_metric', 1)
+        self.training_step_end_called = True
+        return result
+
+    def training_epoch_end_full_loop_result_obj_dp(self, result):
+        """
+        Full loop flow train step (result obj 
+ dp) + """ + result.log('train_epoch_end_metric', 1, on_epoch=True) + self.training_epoch_end_called = True + + return result + + def eval_step_end_full_loop_result_obj_dp(self, result): + """ + Full loop flow train step (result obj + dp) + """ + eval_name = 'validation' if not self.trainer.testing else 'test' + reduced = getattr(result, f'{eval_name}_step_metric_step').mean() + setattr(result, f'{eval_name}_step_metric_step', reduced) + + reduced = getattr(result, f'{eval_name}_step_metric_epoch').mean() + setattr(result, f'{eval_name}_step_metric_epoch', reduced) + + reduced = getattr(result, f'{eval_name}_step_metric').mean() + setattr(result, f'{eval_name}_step_metric', reduced) + + result.checkpoint_on = result.checkpoint_on.mean() + result.early_stop_on = result.early_stop_on.mean() + result.log(f'{eval_name}_step_end_metric', torch.tensor(1).type_as(result.checkpoint_on)) + setattr(self, f'{eval_name}_step_end_called', True) + + return result + + def eval_epoch_end_full_loop_result_obj_dp(self, result): + """ + Full loop flow train step (result obj + dp) + """ + eval_name = 'validation' if not self.trainer.testing else 'test' + result.log(f'{eval_name}_epoch_end_metric', torch.tensor(1).type_as(result.checkpoint_on), on_epoch=True) + result.checkpoint_on = result.checkpoint_on.mean() + result.early_stop_on = result.early_stop_on.mean() + setattr(self, f'{eval_name}_epoch_end_called', True) + + # reduce the parametrized values + reduced = getattr(result, f'{eval_name}_step_metric_step').mean() + setattr(result, f'{eval_name}_step_metric_step', reduced) + + reduced = getattr(result, f'{eval_name}_step_metric_epoch').mean() + setattr(result, f'{eval_name}_step_metric_epoch', reduced) + + reduced = getattr(result, f'{eval_name}_step_end_metric').mean() + setattr(result, f'{eval_name}_step_end_metric', reduced) + + reduced = getattr(result, f'{eval_name}_step_metric').mean() + setattr(result, f'{eval_name}_step_metric', reduced) + + return result + + def training_step__multiple_dataloaders(self, batch, batch_idx, optimizer_idx=None): + """Training step for multiple train loaders""" + + assert isinstance(batch, dict) + assert len(batch) == 2 + assert 'a' in batch and 'b' in batch + + # forward pass + x, y = batch['a'] + x = x.view(x.size(0), -1) + y_hat = self(x) + + # calculate loss + loss_val = self.loss(y, y_hat) + log_val = loss_val + + # alternate between tensors and scalars for "log" and "progress_bar" + if batch_idx % 2 == 0: + log_val = log_val.item() + + output = OrderedDict( + { + 'loss': loss_val, + 'progress_bar': {'some_val': log_val * log_val}, + 'log': {'train_some_val': log_val * log_val}, + } + ) + return output + diff --git a/tests/trainer/test_dataloaders.py b/tests/trainer/test_dataloaders.py index 50c426c174349..599ae862b39f1 100644 --- a/tests/trainer/test_dataloaders.py +++ b/tests/trainer/test_dataloaders.py @@ -33,7 +33,6 @@ def test_fit_train_loader_only(tmpdir): - model = EvalModelTemplate() train_dataloader = model.train_dataloader() @@ -52,7 +51,6 @@ def test_fit_train_loader_only(tmpdir): def test_fit_val_loader_only(tmpdir): - model = EvalModelTemplate() train_dataloader = model.train_dataloader() val_dataloader = model.val_dataloader() @@ -658,6 +656,62 @@ def test_warning_with_few_workers(mock, tmpdir, ckpt_path): trainer.test(**test_options) +@pytest.mark.skipif(platform.system() == 'Windows', reason='Does not apply to Windows platform.') +@pytest.mark.parametrize('ckpt_path', [None, 'best', 'specific']) 
+@patch('pytorch_lightning.trainer.data_loading.multiprocessing.cpu_count', return_value=4) +def test_warning_with_few_workers_multi_loader(mock, tmpdir, ckpt_path): + """ Test that error is raised if dataloader with only a few workers is used """ + + model = EvalModelTemplate() + model.training_step = model.training_step__multiple_dataloaders + model.validation_step = model.validation_step__multiple_dataloaders + model.validation_epoch_end = model.validation_epoch_end__multiple_dataloaders + model.test_step = model.test_step__multiple_dataloaders + model.test_epoch_end = model.test_epoch_end__multiple_dataloaders + + # logger file to get meta + train_dl = model.dataloader(train=True) + train_dl.num_workers = 0 + + val_dl = model.dataloader(train=False) + val_dl.num_workers = 0 + + train_dl = model.dataloader(train=False) + train_dl.num_workers = 0 + + train_multi_dl = {'a': train_dl, 'b': train_dl} + val_multi_dl = [val_dl, val_dl] + test_multi_dl = [train_dl, train_dl] + + fit_options = dict(train_dataloader=train_multi_dl, + val_dataloaders=val_multi_dl) + trainer = Trainer( + default_root_dir=tmpdir, + max_epochs=1, + limit_val_batches=0.1, + limit_train_batches=0.2, + ) + + # fit model + with pytest.warns( + UserWarning, match='The dataloader, train dataloader, does not have many workers which may be a bottleneck.' + ): + trainer.fit(model, **fit_options) + + with pytest.warns( + UserWarning, match='The dataloader, val dataloader 0, does not have many workers which may be a bottleneck.' + ): + trainer.fit(model, **fit_options) + + if ckpt_path == 'specific': + ckpt_path = trainer.checkpoint_callback.best_model_path + test_options = dict(test_dataloaders=test_multi_dl, ckpt_path=ckpt_path) + with pytest.warns( + UserWarning, match='The dataloader, test dataloader 0, does not have many workers which may be a bottleneck.' + ): + trainer.test(**test_options) + + @pytest.mark.xfail( LooseVersion(torch.__version__) < LooseVersion("1.4.0"), reason="IterableDataset with __len__ before 1.4 raises", @@ -857,6 +911,26 @@ def train_dataloader(self): assert 1 == result +@pytest.mark.parametrize(['multiple_trainloader_mode', 'num_training_batches'], [ + pytest.param("min_size", 5), + pytest.param("max_size_cycle", 10), +]) +def test_fit_multiple_train_loaders(tmpdir, multiple_trainloader_mode, num_training_batches): + """Integration test for multple train loaders""" + model = EvalModelTemplate() + + model.train_dataloader = model.train_dataloader__multiple + model.training_step = model.training_step__multiple_dataloaders + + trainer = Trainer( + max_epochs=1, default_root_dir=tmpdir, multiple_trainloader_mode=multiple_trainloader_mode + ) + + assert 1 == trainer.fit(model) + # verify the num_training_batches according to the multiple_trainloader_mode + assert num_training_batches == trainer.num_training_batches + + @pytest.mark.parametrize('check_interval', [1.0]) def test_val_dataloader_not_implemented_error(tmpdir, check_interval): """Test not_implemented_error data loader (e.g. 
IterableDataset)""" diff --git a/tests/trainer/test_supporters.py b/tests/trainer/test_supporters.py new file mode 100644 index 0000000000000..88812f01b1a22 --- /dev/null +++ b/tests/trainer/test_supporters.py @@ -0,0 +1,148 @@ +from collections import Sequence + +import pytest +import torch + +from torch.utils.data import TensorDataset +from pytorch_lightning.trainer.supporters import CycleIterator, CombinedLoader, CombinedDataset, CombinedLoaderIterator + + +def test_cycle_iterator(): + """Test the cycling function of `CycleIterator`""" + iterator = CycleIterator(range(100), 1000) + assert len(iterator) == 1000 + for idx, item in enumerate(iterator): + assert item < 100 + + assert idx == len(iterator) - 1 + + +def test_none_length_cycle_iterator(): + """Test the infinite cycling function of `CycleIterator`""" + iterator = CycleIterator(range(100)) + assert iterator.__len__() == float('inf') + + # test infinite loop + for idx, item in enumerate(iterator): + if idx == 1000: + break + assert item == 0 + + +@pytest.mark.parametrize(['dataset_1', 'dataset_2'], [ + ([list(range(10)), list(range(20))]), + ([range(10), range(20)]), + ([torch.randn(10, 3, 2), torch.randn(20, 5, 6)]), + ([TensorDataset(torch.randn(10, 3, 2)), TensorDataset(torch.randn(20, 5, 6))]) +]) +def test_combined_dataset(dataset_1, dataset_2): + """Verify the length of the CombinedDataset""" + datasets = [dataset_1, dataset_2] + combined_dataset = CombinedDataset(datasets) + + assert combined_dataset.max_len == 20 + assert combined_dataset.min_len == len(combined_dataset) == 10 + + +def test_combined_dataset_length_mode_error(): + with pytest.raises(ValueError, match='Invalid Mode'): + CombinedDataset._calc_num_data([range(10)], 'test') + + +def test_combined_loader_iterator_dict_min_size(): + """Test `CombinedLoaderIterator` given mapping loaders""" + loaders = {'a': torch.utils.data.DataLoader(range(10), batch_size=4), + 'b': torch.utils.data.DataLoader(range(20), batch_size=5)} + + combined_iter = CombinedLoaderIterator(loaders) + + for idx, item in enumerate(combined_iter): + assert isinstance(item, dict) + assert len(item) == 2 + assert 'a' in item and 'b' in item + + assert idx == min(len(loaders['a']), len(loaders['b'])) - 1 + + +def test_combined_loader_init_mode_error(): + """Test the ValueError when constructing `CombinedLoader`""" + with pytest.raises(ValueError, match='Invalid Mode'): + CombinedLoader([range(10)], 'test') + + +def test_combined_loader_loader_type_error(): + """Test the ValueError when wrapping the loaders""" + with pytest.raises(ValueError, match='Invalid Datatype'): + CombinedLoader(None, 'max_size_cycle') + + +def test_combined_loader_calc_length_mode_error(): + """Test the ValueError when calculating the number of batches""" + with pytest.raises(TypeError, match='Got Type NoneType, but expected one of Sequence, int or Mapping'): + CombinedLoader._calc_num_batches(None) + + +def test_combined_loader_dict_min_size(): + """Test `CombinedLoader` of mode 'min_size' given mapping loaders""" + loaders = {'a': torch.utils.data.DataLoader(range(10), batch_size=4), + 'b': torch.utils.data.DataLoader(range(20), batch_size=5)} + + combined_loader = CombinedLoader(loaders, 'min_size') + + assert len(combined_loader) == min([len(v) for v in loaders.values()]) + + for idx, item in enumerate(combined_loader): + assert isinstance(item, dict) + assert len(item) == 2 + assert 'a' in item and 'b' in item + + assert idx == len(combined_loader) - 1 + + +def test_combined_loader_dict_max_size_cycle(): + """Test 
`CombinedLoader` of mode 'max_size_cycle' given mapping loaders""" + loaders = {'a': torch.utils.data.DataLoader(range(10), batch_size=4), + 'b': torch.utils.data.DataLoader(range(20), batch_size=5)} + + combined_loader = CombinedLoader(loaders, 'max_size_cycle') + + assert len(combined_loader) == max([len(v) for v in loaders.values()]) + + for idx, item in enumerate(combined_loader): + assert isinstance(item, dict) + assert len(item) == 2 + assert 'a' in item and 'b' in item + + assert idx == len(combined_loader) - 1 + + +def test_combined_loader_sequence_min_size(): + """Test `CombinedLoader` of mode 'min_size' given sequence loaders""" + loaders = [torch.utils.data.DataLoader(range(10), batch_size=4), + torch.utils.data.DataLoader(range(20), batch_size=5)] + + combined_loader = CombinedLoader(loaders, 'min_size') + + assert len(combined_loader) == min([len(v) for v in loaders]) + + for idx, item in enumerate(combined_loader): + assert isinstance(item, Sequence) + assert len(item) == 2 + + assert idx == len(combined_loader) - 1 + + +def test_combined_loader_sequence_max_size_cycle(): + """Test `CombinedLoader` of mode 'max_size_cycle' given sequence loaders""" + loaders = [torch.utils.data.DataLoader(range(10), batch_size=4), + torch.utils.data.DataLoader(range(20), batch_size=5)] + + combined_loader = CombinedLoader(loaders, 'max_size_cycle') + + assert len(combined_loader) == max([len(v) for v in loaders]) + + for idx, item in enumerate(combined_loader): + assert isinstance(item, Sequence) + assert len(item) == 2 + + assert idx == len(combined_loader) - 1
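
For reference, here is a minimal end-to-end sketch of how the pieces added in this PR fit together: a LightningModule returning a dict of train loaders, a training_step receiving a dict batch, and the new `multiple_trainloader_mode` Trainer flag. The `ToyModel` module, its layer size, and the dataset lengths are illustrative assumptions, not part of the diff.

import torch
from torch.utils.data import DataLoader
from pytorch_lightning import LightningModule, Trainer


class ToyModel(LightningModule):
    # hypothetical module used only to demonstrate the multiple-train-loader API
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)

    def train_dataloader(self):
        # two loaders of different lengths, keyed by name
        loader_a = DataLoader(torch.randn(64, 32), batch_size=8)
        loader_b = DataLoader(torch.randn(40, 32), batch_size=8)
        return {'a': loader_a, 'b': loader_b}

    def training_step(self, batch, batch_idx):
        # with a dict of loaders, each batch is a dict with the same keys
        out_a = self.layer(batch['a'])
        out_b = self.layer(batch['b'])
        return out_a.sum() + out_b.sum()

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)


# 'max_size_cycle' (the default) cycles the shorter loaders until the longest one
# is exhausted; 'min_size' ends the epoch once the shortest loader runs out.
trainer = Trainer(max_epochs=1, multiple_trainloader_mode='min_size')
trainer.fit(ToyModel())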