diff --git a/CHANGELOG.md b/CHANGELOG.md index db0dc3f7a5..877962446e 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -8,6 +8,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). ### Added +- Added support for (input, target) style datasets (e.g. torchvision) to the from_datasets method ([#552](https://github.com/PyTorchLightning/lightning-flash/pull/552)) ### Changed diff --git a/flash/core/data/data_source.py b/flash/core/data/data_source.py index 97cf0e6fd3..f2d07b4b0d 100644 --- a/flash/core/data/data_source.py +++ b/flash/core/data/data_source.py @@ -334,20 +334,18 @@ def generate_dataset( SEQUENCE_DATA_TYPE = TypeVar("SEQUENCE_DATA_TYPE") -class DatasetDataSource(DataSource): - - def load_data(self, dataset: Dataset, auto_dataset: AutoDataset) -> Dataset: - if self.training: - # store a sample to infer the shape - parameters = signature(self.load_sample).parameters - if len(parameters) > 1 and AutoDataset.DATASET_KEY in parameters: - auto_dataset.sample = self.load_sample(dataset[0], self) - else: - auto_dataset.sample = self.load_sample(dataset[0]) - return dataset +class DatasetDataSource(DataSource[Dataset]): + """The ``DatasetDataSource`` implements default behaviours for data sources which expect the input to + :meth:`~flash.core.data.data_source.DataSource.load_data` to be a :class:`torch.utils.data.dataset.Dataset` + + Args: + labels: Optionally pass the labels as a mapping from class index to label string. These will then be set as the + :class:`~flash.core.data.data_source.LabelsState`. + """ - def load_sample(self, sample: Mapping[str, Any], dataset: Optional[Any]) -> Any: - # wrap everything within `.INPUT`. + def load_sample(self, sample: Any, dataset: Optional[Any] = None) -> Mapping[str, Any]: + if isinstance(sample, tuple) and len(sample) == 2: + return {DefaultDataKeys.INPUT: sample[0], DefaultDataKeys.TARGET: sample[1]} return {DefaultDataKeys.INPUT: sample} diff --git a/tests/core/data/test_auto_dataset.py b/tests/core/data/test_auto_dataset.py index 0051ae4f4b..7acbffe671 100644 --- a/tests/core/data/test_auto_dataset.py +++ b/tests/core/data/test_auto_dataset.py @@ -18,8 +18,7 @@ from flash.core.data.auto_dataset import AutoDataset, BaseAutoDataset, IterableAutoDataset from flash.core.data.callback import FlashCallback -from flash.core.data.data_module import DataModule -from flash.core.data.data_source import DataSource, DefaultDataKeys +from flash.core.data.data_source import DataSource class _AutoDatasetTestDataSource(DataSource): @@ -189,9 +188,3 @@ def test_preprocessing_data_source_with_running_stage(with_dataset): else: assert data_source.train_load_sample_count == len(dataset) assert data_source.train_load_data_count == 1 - - -def test_dataset_data_source(): - - dm = DataModule.from_datasets(range(10), range(10)) - assert dm.train_dataset.sample == {DefaultDataKeys.INPUT: 0} diff --git a/tests/core/data/test_data_source.py b/tests/core/data/test_data_source.py new file mode 100644 index 0000000000..77dbb173be --- /dev/null +++ b/tests/core/data/test_data_source.py @@ -0,0 +1,23 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from flash.core.data.data_source import DatasetDataSource, DefaultDataKeys + + +def test_dataset_data_source(): + data_source = DatasetDataSource() + + input, target = 'test', 3 + + assert data_source.load_sample((input, target)) == {DefaultDataKeys.INPUT: input, DefaultDataKeys.TARGET: target} + assert data_source.load_sample(input) == {DefaultDataKeys.INPUT: input} diff --git a/tests/image/classification/test_data.py b/tests/image/classification/test_data.py index b7d48a68fe..183f3427a4 100644 --- a/tests/image/classification/test_data.py +++ b/tests/image/classification/test_data.py @@ -27,6 +27,7 @@ if _TORCHVISION_AVAILABLE: import torchvision + from torchvision.datasets import FakeData if _PIL_AVAILABLE: from PIL import Image @@ -443,3 +444,32 @@ def test_from_fiftyone(tmpdir): assert imgs.shape == (2, 3, 196, 196) assert labels.shape == (2, ) assert sorted(list(labels.numpy())) == [0, 1] + + +@pytest.mark.skipif(not _IMAGE_TESTING, reason="image libraries aren't installed.") +def test_from_datasets(): + img_data = ImageClassificationData.from_datasets( + train_dataset=FakeData(size=3, num_classes=2), + val_dataset=FakeData(size=3, num_classes=2), + test_dataset=FakeData(size=3, num_classes=2), + batch_size=2, + num_workers=0, + ) + + # check training data + data = next(iter(img_data.train_dataloader())) + imgs, labels = data[DefaultDataKeys.INPUT], data[DefaultDataKeys.TARGET] + assert imgs.shape == (2, 3, 196, 196) + assert labels.shape == (2, ) + + # check validation data + data = next(iter(img_data.val_dataloader())) + imgs, labels = data[DefaultDataKeys.INPUT], data[DefaultDataKeys.TARGET] + assert imgs.shape == (2, 3, 196, 196) + assert labels.shape == (2, ) + + # check test data + data = next(iter(img_data.test_dataloader())) + imgs, labels = data[DefaultDataKeys.INPUT], data[DefaultDataKeys.TARGET] + assert imgs.shape == (2, 3, 196, 196) + assert labels.shape == (2, )