diff --git a/CHANGELOG.md b/CHANGELOG.md index 7bc202168f..cd6b1a02c9 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -9,6 +9,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). ### Added - Added support for `torch.jit` to tasks where possible and documented task JIT compatibility ([#389](https://github.com/PyTorchLightning/lightning-flash/pull/389)) +- Added option to provide a `Sampler` to the `DataModule` to use when creating a `DataLoader` ([#390](https://github.com/PyTorchLightning/lightning-flash/pull/390)) ### Changed diff --git a/flash/core/data/data_module.py b/flash/core/data/data_module.py index 58a3337e1d..4ed185e93b 100644 --- a/flash/core/data/data_module.py +++ b/flash/core/data/data_module.py @@ -22,6 +22,7 @@ from pytorch_lightning.utilities.exceptions import MisconfigurationException from torch.utils.data import DataLoader, Dataset from torch.utils.data.dataset import IterableDataset, Subset +from torch.utils.data.sampler import Sampler from flash.core.data.auto_dataset import BaseAutoDataset, IterableAutoDataset from flash.core.data.base_viz import BaseVisualization @@ -58,6 +59,8 @@ class DataModule(pl.LightningDataModule): num_workers: The number of workers to use for parallelized loading. Defaults to None which equals the number of available CPU threads, or 0 for Windows or Darwin platform. + sampler: A sampler following the :class:`~torch.utils.data.sampler.Sampler` type. + Will be passed to the DataLoader for the training dataset. Defaults to None. 
""" preprocess_cls = DefaultPreprocess @@ -76,6 +79,7 @@ def __init__( val_split: Optional[float] = None, batch_size: int = 1, num_workers: Optional[int] = None, + sampler: Optional[Sampler] = None, ) -> None: super().__init__() @@ -118,6 +122,7 @@ def __init__( else: num_workers = os.cpu_count() self.num_workers = num_workers + self.sampler = sampler self.set_running_stages() @@ -259,11 +264,14 @@ def _resolve_collate_fn(self, dataset: Dataset, running_stage: RunningStage) -> def _train_dataloader(self) -> DataLoader: train_ds: Dataset = self._train_ds() if isinstance(self._train_ds, Callable) else self._train_ds - shuffle = not isinstance(train_ds, (IterableDataset, IterableAutoDataset)) + shuffle: bool = False + if self.sampler is None: + shuffle = not isinstance(train_ds, (IterableDataset, IterableAutoDataset)) return DataLoader( train_ds, batch_size=self.batch_size, shuffle=shuffle, + sampler=self.sampler, num_workers=self.num_workers, pin_memory=True, drop_last=True, @@ -372,6 +380,7 @@ def from_data_source( val_split: Optional[float] = None, batch_size: int = 4, num_workers: Optional[int] = None, + sampler: Optional[Sampler] = None, **preprocess_kwargs: Any, ) -> 'DataModule': """Creates a :class:`~flash.core.data.data_module.DataModule` object from the given inputs to @@ -407,6 +416,7 @@ def from_data_source( val_split: The ``val_split`` argument to pass to the :class:`~flash.core.data.data_module.DataModule`. batch_size: The ``batch_size`` argument to pass to the :class:`~flash.core.data.data_module.DataModule`. num_workers: The ``num_workers`` argument to pass to the :class:`~flash.core.data.data_module.DataModule`. + sampler: The ``sampler`` argument to pass to the :class:`~flash.core.data.data_module.DataModule`. preprocess_kwargs: Additional keyword arguments to use when constructing the preprocess. Will only be used if ``preprocess = None``. 
@@ -451,6 +461,7 @@ def from_data_source( val_split=val_split, batch_size=batch_size, num_workers=num_workers, + sampler=sampler, ) @classmethod @@ -469,6 +480,7 @@ def from_folders( val_split: Optional[float] = None, batch_size: int = 4, num_workers: Optional[int] = None, + sampler: Optional[Sampler] = None, **preprocess_kwargs: Any, ) -> 'DataModule': """Creates a :class:`~flash.core.data.data_module.DataModule` object from the given folders using the @@ -497,6 +509,7 @@ def from_folders( val_split: The ``val_split`` argument to pass to the :class:`~flash.core.data.data_module.DataModule`. batch_size: The ``batch_size`` argument to pass to the :class:`~flash.core.data.data_module.DataModule`. num_workers: The ``num_workers`` argument to pass to the :class:`~flash.core.data.data_module.DataModule`. + sampler: The ``sampler`` argument to pass to the :class:`~flash.core.data.data_module.DataModule`. preprocess_kwargs: Additional keyword arguments to use when constructing the preprocess. Will only be used if ``preprocess = None``. @@ -527,6 +540,7 @@ def from_folders( val_split=val_split, batch_size=batch_size, num_workers=num_workers, + sampler=sampler, **preprocess_kwargs, ) @@ -549,6 +563,7 @@ def from_files( val_split: Optional[float] = None, batch_size: int = 4, num_workers: Optional[int] = None, + sampler: Optional[Sampler] = None, **preprocess_kwargs: Any, ) -> 'DataModule': """Creates a :class:`~flash.core.data.data_module.DataModule` object from the given sequences of files using @@ -580,6 +595,7 @@ def from_files( val_split: The ``val_split`` argument to pass to the :class:`~flash.core.data.data_module.DataModule`. batch_size: The ``batch_size`` argument to pass to the :class:`~flash.core.data.data_module.DataModule`. num_workers: The ``num_workers`` argument to pass to the :class:`~flash.core.data.data_module.DataModule`. + sampler: The ``sampler`` argument to pass to the :class:`~flash.core.data.data_module.DataModule`. 
preprocess_kwargs: Additional keyword arguments to use when constructing the preprocess. Will only be used if ``preprocess = None``. @@ -611,6 +627,7 @@ def from_files( val_split=val_split, batch_size=batch_size, num_workers=num_workers, + sampler=sampler, **preprocess_kwargs, ) @@ -633,6 +650,7 @@ def from_tensors( val_split: Optional[float] = None, batch_size: int = 4, num_workers: Optional[int] = None, + sampler: Optional[Sampler] = None, **preprocess_kwargs: Any, ) -> 'DataModule': """Creates a :class:`~flash.core.data.data_module.DataModule` object from the given tensors using the @@ -664,6 +682,7 @@ def from_tensors( val_split: The ``val_split`` argument to pass to the :class:`~flash.core.data.data_module.DataModule`. batch_size: The ``batch_size`` argument to pass to the :class:`~flash.core.data.data_module.DataModule`. num_workers: The ``num_workers`` argument to pass to the :class:`~flash.core.data.data_module.DataModule`. + sampler: The ``sampler`` argument to pass to the :class:`~flash.core.data.data_module.DataModule`. preprocess_kwargs: Additional keyword arguments to use when constructing the preprocess. Will only be used if ``preprocess = None``. @@ -695,6 +714,7 @@ def from_tensors( val_split=val_split, batch_size=batch_size, num_workers=num_workers, + sampler=sampler, **preprocess_kwargs, ) @@ -717,6 +737,7 @@ def from_numpy( val_split: Optional[float] = None, batch_size: int = 4, num_workers: Optional[int] = None, + sampler: Optional[Sampler] = None, **preprocess_kwargs: Any, ) -> 'DataModule': """Creates a :class:`~flash.core.data.data_module.DataModule` object from the given numpy array using the @@ -748,6 +769,7 @@ def from_numpy( val_split: The ``val_split`` argument to pass to the :class:`~flash.core.data.data_module.DataModule`. batch_size: The ``batch_size`` argument to pass to the :class:`~flash.core.data.data_module.DataModule`. num_workers: The ``num_workers`` argument to pass to the :class:`~flash.core.data.data_module.DataModule`. 
+ sampler: The ``sampler`` argument to pass to the :class:`~flash.core.data.data_module.DataModule`. preprocess_kwargs: Additional keyword arguments to use when constructing the preprocess. Will only be used if ``preprocess = None``. @@ -779,6 +801,7 @@ def from_numpy( val_split=val_split, batch_size=batch_size, num_workers=num_workers, + sampler=sampler, **preprocess_kwargs, ) @@ -800,6 +823,7 @@ def from_json( val_split: Optional[float] = None, batch_size: int = 4, num_workers: Optional[int] = None, + sampler: Optional[Sampler] = None, **preprocess_kwargs: Any, ) -> 'DataModule': """Creates a :class:`~flash.core.data.data_module.DataModule` object from the given JSON files using the @@ -830,6 +854,7 @@ def from_json( val_split: The ``val_split`` argument to pass to the :class:`~flash.core.data.data_module.DataModule`. batch_size: The ``batch_size`` argument to pass to the :class:`~flash.core.data.data_module.DataModule`. num_workers: The ``num_workers`` argument to pass to the :class:`~flash.core.data.data_module.DataModule`. + sampler: The ``sampler`` argument to pass to the :class:`~flash.core.data.data_module.DataModule`. preprocess_kwargs: Additional keyword arguments to use when constructing the preprocess. Will only be used if ``preprocess = None``. @@ -862,6 +887,7 @@ def from_json( val_split=val_split, batch_size=batch_size, num_workers=num_workers, + sampler=sampler, **preprocess_kwargs, ) @@ -883,6 +909,7 @@ def from_csv( val_split: Optional[float] = None, batch_size: int = 4, num_workers: Optional[int] = None, + sampler: Optional[Sampler] = None, **preprocess_kwargs: Any, ) -> 'DataModule': """Creates a :class:`~flash.core.data.data_module.DataModule` object from the given CSV files using the @@ -913,6 +940,7 @@ def from_csv( val_split: The ``val_split`` argument to pass to the :class:`~flash.core.data.data_module.DataModule`. batch_size: The ``batch_size`` argument to pass to the :class:`~flash.core.data.data_module.DataModule`. 
num_workers: The ``num_workers`` argument to pass to the :class:`~flash.core.data.data_module.DataModule`. + sampler: The ``sampler`` argument to pass to the :class:`~flash.core.data.data_module.DataModule`. preprocess_kwargs: Additional keyword arguments to use when constructing the preprocess. Will only be used if ``preprocess = None``. @@ -945,6 +973,7 @@ def from_csv( val_split=val_split, batch_size=batch_size, num_workers=num_workers, + sampler=sampler, **preprocess_kwargs, ) @@ -964,6 +993,7 @@ def from_datasets( val_split: Optional[float] = None, batch_size: int = 4, num_workers: Optional[int] = None, + sampler: Optional[Sampler] = None, **preprocess_kwargs: Any, ) -> 'DataModule': """Creates a :class:`~flash.core.data.data_module.DataModule` object from the given datasets using the @@ -992,6 +1022,7 @@ def from_datasets( val_split: The ``val_split`` argument to pass to the :class:`~flash.core.data.data_module.DataModule`. batch_size: The ``batch_size`` argument to pass to the :class:`~flash.core.data.data_module.DataModule`. num_workers: The ``num_workers`` argument to pass to the :class:`~flash.core.data.data_module.DataModule`. + sampler: The ``sampler`` argument to pass to the :class:`~flash.core.data.data_module.DataModule`. preprocess_kwargs: Additional keyword arguments to use when constructing the preprocess. Will only be used if ``preprocess = None``. @@ -1022,5 +1053,6 @@ def from_datasets( val_split=val_split, batch_size=batch_size, num_workers=num_workers, + sampler=sampler, **preprocess_kwargs, ) diff --git a/tests/core/data/test_sampler.py b/tests/core/data/test_sampler.py new file mode 100644 index 0000000000..9ee9ace3a1 --- /dev/null +++ b/tests/core/data/test_sampler.py @@ -0,0 +1,32 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from unittest import mock
+
+from flash import DataModule
+
+
+@mock.patch("flash.core.data.data_module.DataLoader")
+def test_dataloaders_with_sampler(mock_dataloader):
+    """Check the sampler reaches only the training ``DataLoader``; val/test must omit it."""
+    train_ds = val_ds = test_ds = 'dataset'
+    mock_sampler = 'sampler'
+    dm = DataModule(train_ds, val_ds, test_ds, num_workers=0, sampler=mock_sampler)
+    assert dm.sampler is mock_sampler
+    dm.train_dataloader()
+    assert mock_dataloader.call_args[1].get('sampler') is mock_sampler
+    # Build each loader inside the loop so ``call_args`` reflects that loader, not just the last one created.
+    for make_dataloader in (dm.val_dataloader, dm.test_dataloader):
+        make_dataloader()
+        assert 'sampler' not in mock_dataloader.call_args[1]