From b8085f0a14e105fd5ccef317b965dfa91421a05e Mon Sep 17 00:00:00 2001 From: Ethan Harris Date: Tue, 30 Nov 2021 19:27:49 +0000 Subject: [PATCH] Refactor image inputs and update to new input object (#997) Co-authored-by: thomas chaton Co-authored-by: Ananya Harsh Jha --- CHANGELOG.md | 2 + docs/source/api/data.rst | 1 - docs/source/api/image.rst | 3 +- docs/source/template/data.rst | 2 +- flash/audio/classification/data.py | 14 +- flash/core/classification.py | 19 +- flash/core/data/data_module.py | 6 +- flash/core/data/io/classification_input.py | 79 +++ flash/core/data/io/input.py | 29 +- flash/core/data/io/input_base.py | 21 +- flash/core/data/utilities/classification.py | 306 ++++++++++ flash/core/data/utilities/data_frame.py | 90 +++ flash/core/data/utilities/paths.py | 99 ++-- flash/core/data/utilities/samples.py | 33 ++ flash/core/integrations/icevision/data.py | 5 +- flash/image/classification/adapters.py | 3 +- flash/image/classification/cli.py | 9 +- flash/image/classification/data.py | 540 +++++++++++------- .../classification/integrations/baal/data.py | 4 +- flash/image/data.py | 85 ++- flash/image/detection/output.py | 13 +- flash/image/face_detection/data.py | 91 +-- flash/image/segmentation/data.py | 2 +- flash/image/style_transfer/cli.py | 5 +- flash/image/style_transfer/data.py | 125 ++-- flash/tabular/data.py | 6 +- flash/template/classification/data.py | 5 +- flash/text/classification/data.py | 15 +- flash/video/classification/data.py | 7 +- .../image_classification_multi_label.py | 11 +- tests/core/data/io/test_output.py | 4 +- tests/core/data/utilities/__init__.py | 0 .../data/utilities/test_classification.py | 127 ++++ tests/core/test_model.py | 6 +- tests/image/classification/test_data.py | 10 +- 35 files changed, 1325 insertions(+), 452 deletions(-) create mode 100644 flash/core/data/io/classification_input.py create mode 100644 flash/core/data/utilities/classification.py create mode 100644 flash/core/data/utilities/data_frame.py create mode 100644 flash/core/data/utilities/samples.py create mode 100644 tests/core/data/utilities/__init__.py create mode 100644 tests/core/data/utilities/test_classification.py diff --git a/CHANGELOG.md b/CHANGELOG.md index 981654fc32..d3bebc802d 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -14,6 +14,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). 
 - Added `GraphEmbedder` task ([#592](https://github.com/PyTorchLightning/lightning-flash/pull/592))
 
+- Added support for comma delimited multi-label targets to the `ImageClassifier` ([#997](https://github.com/PyTorchLightning/lightning-flash/pull/997))
+
 ### Changed
 
 - Changed `DataSource` to `Input` ([#929](https://github.com/PyTorchLightning/lightning-flash/pull/929))
diff --git a/docs/source/api/data.rst b/docs/source/api/data.rst
index ad4f6ffac2..f8d730ab9b 100644
--- a/docs/source/api/data.rst
+++ b/docs/source/api/data.rst
@@ -95,7 +95,6 @@ ___________________________
     ~flash.core.data.io.input.InputFormat
     ~flash.core.data.io.input.FiftyOneInput
     ~flash.core.data.io.input.ImageLabelsMap
-    ~flash.core.data.io.input.LabelsState
     ~flash.core.data.io.input.MockDataset
     ~flash.core.data.io.input.NumpyInput
     ~flash.core.data.io.input.PathsInput
diff --git a/docs/source/api/image.rst b/docs/source/api/image.rst
index b8db9c6ad9..a8b2872194 100644
--- a/docs/source/api/image.rst
+++ b/docs/source/api/image.rst
@@ -18,6 +18,7 @@ ______________
     :template: classtemplate.rst
 
     ~classification.model.ImageClassifier
+    ~classification.data.ImageClassificationFiftyOneInput
     ~classification.data.ImageClassificationData
     ~classification.data.ImageClassificationInputTransform
 
@@ -140,7 +141,5 @@ ________________
     :template: classtemplate.rst
 
     ~data.ImageDeserializer
-    ~data.ImageFiftyOneInput
     ~data.ImageNumpyInput
-    ~data.ImagePathsInput
     ~data.ImageTensorInput
diff --git a/docs/source/template/data.rst b/docs/source/template/data.rst
index bad1170438..3d12658781 100644
--- a/docs/source/template/data.rst
+++ b/docs/source/template/data.rst
@@ -52,7 +52,7 @@ We override our ``TemplateNumpyInput`` so that we can call ``super`` with the da
 We perform two additional steps here to improve the user experience:
 
 1. We set the ``num_classes`` attribute on the ``dataset``. If ``num_classes`` is set, it is automatically made available as a property of the :class:`~flash.core.data.data_module.DataModule`.
-2. We create and set a :class:`~flash.core.data.io.input.LabelsState`. The labels provided here will be shared with the :class:`~flash.core.classification.Labels` output, so the user doesn't need to provide them.
+2. We create and set a :class:`~flash.core.data.io.classification_input.ClassificationState`. The labels provided here will be shared with the :class:`~flash.core.classification.Labels` output, so the user doesn't need to provide them.
Here's the code for the ``TemplateSKLearnInput.load_data`` method:

diff --git a/flash/audio/classification/data.py b/flash/audio/classification/data.py
index 0b3c688f94..5becd0bd0e 100644
--- a/flash/audio/classification/data.py
+++ b/flash/audio/classification/data.py
@@ -16,6 +16,8 @@
 import numpy as np
 
 from flash.audio.classification.transforms import default_transforms, train_default_transforms
+from flash.core.data.callback import BaseDataFetcher
+from flash.core.data.data_module import DataModule
 from flash.core.data.io.input import (
     DataKeys,
     has_file_allowed_extension,
@@ -27,7 +29,7 @@
 from flash.core.data.io.input_transform import InputTransform
 from flash.core.data.process import Deserializer
 from flash.core.data.utils import image_default_loader
-from flash.image.classification.data import ImageClassificationData
+from flash.image.classification.data import MatplotlibVisualization
 from flash.image.data import ImageDeserializer, IMG_EXTENSIONS, NP_EXTENSIONS
 
 
@@ -114,7 +116,15 @@ def train_default_transforms(self) -> Optional[Dict[str, Callable]]:
         return train_default_transforms(self.spectrogram_size, self.time_mask_param, self.freq_mask_param)
 
 
-class AudioClassificationData(ImageClassificationData):
+class AudioClassificationData(DataModule):
     """Data module for audio classification."""
 
     input_transform_cls = AudioClassificationInputTransform
+
+    def set_block_viz_window(self, value: bool) -> None:
+        """Setter method to switch on/off the matplotlib pop-up windows."""
+        self.data_fetcher.block_viz_window = value
+
+    @staticmethod
+    def configure_data_fetcher(*args, **kwargs) -> BaseDataFetcher:
+        return MatplotlibVisualization(*args, **kwargs)
diff --git a/flash/core/classification.py b/flash/core/classification.py
index 39f0bd0c80..095d8174f6 100644
--- a/flash/core/classification.py
+++ b/flash/core/classification.py
@@ -19,7 +19,8 @@
 from pytorch_lightning.utilities import rank_zero_warn
 
 from flash.core.adapter import AdapterTask
-from flash.core.data.io.input import DataKeys, LabelsState
+from flash.core.data.io.classification_input import ClassificationState
+from flash.core.data.io.input import DataKeys
 from flash.core.data.io.output import Output
 from flash.core.model import Task
 from flash.core.utilities.imports import _FIFTYONE_AVAILABLE, lazy_import, requires
@@ -186,7 +187,7 @@ class Labels(Classes):
 
     Args:
         labels: A list of labels, assumed to map the class index to the label for that class. If ``labels`` is not
-            provided, will attempt to get them from the :class:`.LabelsState`.
+            provided, will attempt to get them from the :class:`.ClassificationState`.
         multi_label: If true, treats outputs as multi label logits.
         threshold: The threshold to use for multi_label classification.
""" @@ -196,7 +197,7 @@ def __init__(self, labels: Optional[List[str]] = None, multi_label: bool = False self._labels = labels if labels is not None: - self.set_state(LabelsState(labels)) + self.set_state(ClassificationState(labels)) def transform(self, sample: Any) -> Union[int, List[int], str, List[str]]: labels = None @@ -204,7 +205,7 @@ def transform(self, sample: Any) -> Union[int, List[int], str, List[str]]: if self._labels is not None: labels = self._labels else: - state = self.get_state(LabelsState) + state = self.get_state(ClassificationState) if state is not None: labels = state.labels @@ -214,7 +215,7 @@ def transform(self, sample: Any) -> Union[int, List[int], str, List[str]]: if self.multi_label: return [labels[cls] for cls in classes] return labels[classes] - rank_zero_warn("No LabelsState was found, this output will act as a Classes output.", UserWarning) + rank_zero_warn("No ClassificationState was found, this output will act as a Classes output.", UserWarning) return classes @@ -223,7 +224,7 @@ class FiftyOneLabels(ClassificationOutput): Args: labels: A list of labels, assumed to map the class index to the label for that class. If ``labels`` is not - provided, will attempt to get them from the :class:`.LabelsState`. + provided, will attempt to get them from the :class:`.ClassificationState`. multi_label: If true, treats outputs as multi label logits. threshold: A threshold to use to filter candidate labels. In the single label case, predictions below this threshold will be replaced with None @@ -252,7 +253,7 @@ def __init__( self.return_filepath = return_filepath if labels is not None: - self.set_state(LabelsState(labels)) + self.set_state(ClassificationState(labels)) def transform( self, @@ -266,7 +267,7 @@ def transform( if self._labels is not None: labels = self._labels else: - state = self.get_state(LabelsState) + state = self.get_state(ClassificationState) if state is not None: labels = state.labels @@ -309,7 +310,7 @@ def transform( logits=logits, ) else: - rank_zero_warn("No LabelsState was found, int targets will be used as label strings", UserWarning) + rank_zero_warn("No ClassificationState was found, int targets will be used as label strings", UserWarning) if self.multi_label: classifications = [] diff --git a/flash/core/data/data_module.py b/flash/core/data/data_module.py index 6b2b04ddeb..59c14f477c 100644 --- a/flash/core/data/data_module.py +++ b/flash/core/data/data_module.py @@ -149,6 +149,10 @@ def __init__( self.set_running_stages() + # Share state between input objects (this will be available in ``load_sample`` but not in ``load_data``) + data_pipeline = self.data_pipeline + data_pipeline.initialize() + @property def train_dataset(self) -> Optional[Dataset]: """This property returns the train dataset.""" @@ -420,7 +424,7 @@ def num_classes(self) -> Optional[int]: @property def multi_label(self) -> Optional[bool]: - """Property that returns the number of labels of the datamodule if a multilabel task.""" + """Property that returns ``True`` if this ``DataModule`` contains multi-label data.""" multi_label_train = getattr(self.train_dataset, "multi_label", None) multi_label_val = getattr(self.val_dataset, "multi_label", None) multi_label_test = getattr(self.test_dataset, "multi_label", None) diff --git a/flash/core/data/io/classification_input.py b/flash/core/data/io/classification_input.py new file mode 100644 index 0000000000..bb41e0f1ae --- /dev/null +++ b/flash/core/data/io/classification_input.py @@ -0,0 +1,79 @@ +# Copyright The PyTorch Lightning team. 
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from dataclasses import dataclass
+from functools import lru_cache
+from typing import Any, List, Optional, Sequence
+
+from flash.core.data.io.input_base import Input
+from flash.core.data.properties import ProcessState
+from flash.core.data.utilities.classification import (
+    get_target_details,
+    get_target_formatter,
+    get_target_mode,
+    TargetFormatter,
+)
+
+
+@dataclass(unsafe_hash=True, frozen=True)
+class ClassificationState(ProcessState):
+    """A :class:`~flash.core.data.properties.ProcessState` containing ``labels`` (a mapping from class index to
+    label) and ``num_classes``."""
+
+    labels: Optional[Sequence[str]]
+    num_classes: Optional[int] = None
+
+
+class ClassificationInput(Input):
+    """The ``ClassificationInput`` class provides utility methods for handling classification targets.
+    :class:`~flash.core.data.io.input_base.Input` objects that extend ``ClassificationInput`` should do the following:
+
+    * In the ``load_data`` method, include a call to ``load_target_metadata``. This will determine the format of the
+      targets and store metadata like ``labels`` and ``num_classes``.
+    * In the ``load_sample`` method, use ``format_target`` to convert the target to a standard format for use with our
+      tasks.
+    """
+
+    @property
+    @lru_cache(maxsize=None)
+    def target_formatter(self) -> TargetFormatter:
+        """Get the :class:`~flash.core.data.utilities.classification.TargetFormatter` to use when formatting
+        targets.
+
+        This property uses ``functools.lru_cache`` so that we only instantiate the formatter once.
+        """
+        classification_state = self.get_state(ClassificationState)
+        return get_target_formatter(self.target_mode, classification_state.labels, classification_state.num_classes)
+
+    def load_target_metadata(self, targets: List[Any]) -> None:
+        """Determine the target format and store the ``labels`` and ``num_classes``.
+
+        Args:
+            targets: The list of targets.
+        """
+        self.target_mode = get_target_mode(targets)
+        self.multi_label = self.target_mode.multi_label
+        if self.training:
+            self.labels, self.num_classes = get_target_details(targets, self.target_mode)
+            self.set_state(ClassificationState(self.labels, self.num_classes))
+
+    def format_target(self, target: Any) -> Any:
+        """Format a single target according to the previously computed target format and metadata.
+
+        Args:
+            target: The target to format.
+
+        Returns:
+            The formatted target.
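To make the contract above concrete, here is a minimal sketch (not part of this patch) of a custom ``Input`` that follows the two steps described in the class docstring. The class name and the input/target data layout are hypothetical; the ``load_target_metadata`` and ``format_target`` hooks are the ones added here:

```python
# Minimal sketch, assuming a plain list of inputs and raw targets.
from typing import Any, Dict, List

from flash.core.data.io.classification_input import ClassificationInput
from flash.core.data.io.input import DataKeys


class MyClassificationInput(ClassificationInput):  # hypothetical name
    def load_data(self, inputs: List[Any], targets: List[Any]) -> List[Dict[str, Any]]:
        # Step 1: inspect all targets once so that ``labels`` and ``num_classes``
        # are inferred and shared through the ``ClassificationState``.
        self.load_target_metadata(targets)
        return [{DataKeys.INPUT: i, DataKeys.TARGET: t} for i, t in zip(inputs, targets)]

    def load_sample(self, sample: Dict[str, Any]) -> Dict[str, Any]:
        # Step 2: convert each raw target (e.g. "cat" or "cat,dog") to the
        # canonical format (a class index, or a multi-hot list for multi-label).
        sample[DataKeys.TARGET] = self.format_target(sample[DataKeys.TARGET])
        return sample
```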
+ """ + return self.target_formatter(target) diff --git a/flash/core/data/io/input.py b/flash/core/data/io/input.py index 5969608480..55af982c2e 100644 --- a/flash/core/data/io/input.py +++ b/flash/core/data/io/input.py @@ -45,8 +45,9 @@ from tqdm import tqdm from flash.core.data.auto_dataset import AutoDataset, BaseAutoDataset, IterableAutoDataset +from flash.core.data.io.classification_input import ClassificationState from flash.core.data.properties import ProcessState, Properties -from flash.core.data.utilities.paths import read_csv +from flash.core.data.utilities.data_frame import read_csv from flash.core.data.utils import CurrentRunningStageFuncContext from flash.core.utilities.imports import _FIFTYONE_AVAILABLE, lazy_import, requires from flash.core.utilities.stages import RunningStage @@ -72,7 +73,7 @@ def has_file_allowed_extension(filename: str, extensions: Tuple[str, ...]) -> bo Returns: bool: True if the filename ends with one of given extensions """ - return filename.lower().endswith(extensions) + return str(filename).lower().endswith(extensions) # Credit to the PyTorchVision Team: @@ -135,14 +136,6 @@ def has_len(data: Union[Sequence[Any], Iterable[Any]]) -> bool: return False -@dataclass(unsafe_hash=True, frozen=True) -class LabelsState(ProcessState): - """A :class:`~flash.core.data.properties.ProcessState` containing ``labels``, a mapping from class index to - label.""" - - labels: Optional[Sequence[str]] - - @dataclass(unsafe_hash=True, frozen=True) class ImageLabelsMap(ProcessState): @@ -361,7 +354,7 @@ class DatasetInput(Input[Dataset]): Args: labels: Optionally pass the labels as a mapping from class index to label string. These will then be set as the - :class:`~flash.core.data.io.input.LabelsState`. + :class:`~flash.core.data.io.input.ClassificationState`. """ def load_sample(self, sample: Any, dataset: Optional[Any] = None) -> Mapping[str, Any]: @@ -380,7 +373,7 @@ class SequenceInput( Args: labels: Optionally pass the labels as a mapping from class index to label string. These will then be set as the - :class:`~flash.core.data.io.input.LabelsState`. + :class:`~flash.core.data.io.input.ClassificationState`. """ def __init__(self, labels: Optional[Sequence[str]] = None): @@ -389,7 +382,7 @@ def __init__(self, labels: Optional[Sequence[str]] = None): self.labels = labels if self.labels is not None: - self.set_state(LabelsState(self.labels)) + self.set_state(ClassificationState(self.labels)) def load_data( self, @@ -415,7 +408,7 @@ class PathsInput(SequenceInput): Args: extensions: The file extensions supported by this data source (e.g. ``(".jpg", ".png")``). labels: Optionally pass the labels as a mapping from class index to label string. These will then be set as the - :class:`~flash.core.data.io.input.LabelsState`. + :class:`~flash.core.data.io.input.ClassificationState`. 
""" def __init__( @@ -459,7 +452,7 @@ def load_data( classes, class_to_idx = self.find_classes(data) if not classes: return self.predict_load_data(data) - self.set_state(LabelsState(classes)) + self.set_state(ClassificationState(classes)) if dataset is not None: dataset.num_classes = len(classes) @@ -577,7 +570,7 @@ def load_data( if isinstance(target_keys, List): dataset.multi_label = True dataset.num_classes = len(target_keys) - self.set_state(LabelsState(target_keys)) + self.set_state(ClassificationState(target_keys)) data_frame = data_frame.apply(partial(self._resolve_multi_target, target_keys), axis=1) target_keys = target_keys[0] else: @@ -585,9 +578,9 @@ def load_data( if self.training: labels = list(sorted(data_frame[target_keys].unique())) dataset.num_classes = len(labels) - self.set_state(LabelsState(labels)) + self.set_state(ClassificationState(labels)) - labels = self.get_state(LabelsState) + labels = self.get_state(ClassificationState) if labels is not None: labels = labels.labels diff --git a/flash/core/data/io/input_base.py b/flash/core/data/io/input_base.py index d8c58746a1..94ef607f49 100644 --- a/flash/core/data/io/input_base.py +++ b/flash/core/data/io/input_base.py @@ -14,6 +14,7 @@ import functools import os import sys +from copy import copy, deepcopy from typing import Any, cast, Dict, Iterable, MutableMapping, Optional, Sequence, Tuple, Union from torch.utils.data import Dataset @@ -147,7 +148,7 @@ def _call_load_sample(self, sample: Any) -> Any: InputBase, ), ) - return load_sample(sample) + return load_sample(copy(sample)) @staticmethod def load_data(*args: Any, **kwargs: Any) -> Union[Sequence, Iterable]: @@ -190,6 +191,24 @@ def __setstate__(self, newstate): newstate["data"] = None self.__dict__.update(newstate) + def __copy__(self): + """The default copy implementation seems to use ``__getstate__`` and ``__setstate__`` so we override it + here with a custom implementation to ensure that it includes the data list.""" + cls = self.__class__ + result = cls.__new__(cls) + result.__dict__.update(self.__dict__) + return result + + def __deepcopy__(self, memo): + """The default deepcopy implementation seems to use ``__getstate__`` and ``__setstate__`` so we override it + here with a custom implementation to ensure that it includes the data list.""" + cls = self.__class__ + result = cls.__new__(cls) + memo[id(self)] = result + for k, v in self.__dict__.items(): + setattr(result, k, deepcopy(v, memo)) + return result + def __bool__(self): """If ``self.data`` is ``None`` then the ``InputBase`` is considered falsey. diff --git a/flash/core/data/utilities/classification.py b/flash/core/data/utilities/classification.py new file mode 100644 index 0000000000..1fc53e0421 --- /dev/null +++ b/flash/core/data/utilities/classification.py @@ -0,0 +1,306 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+import re
+from enum import auto, Enum
+from functools import reduce
+from typing import Any, cast, Iterable, List, Optional, Tuple, Union
+
+import numpy as np
+import torch
+
+
+def _is_list_like(x: Any) -> bool:
+    try:
+        _ = x[0]
+        _ = len(x)
+        return True
+    except (TypeError, IndexError):  # Single element tensors raise an `IndexError`
+        return False
+
+
+def _as_list(x: Union[List, torch.Tensor, np.ndarray]) -> List:
+    if torch.is_tensor(x) or isinstance(x, np.ndarray):
+        return cast(List, x.tolist())
+    return x
+
+
+def _convert(text: str) -> Union[int, str]:
+    return int(text) if text.isdigit() else text
+
+
+def _alphanumeric_key(key: str) -> List[Union[int, str]]:
+    return [_convert(c) for c in re.split("([0-9]+)", key)]
+
+
+def _sorted_nicely(iterable: Iterable[str]) -> Iterable[str]:
+    """Sort the given iterable in the way that humans expect. For example, given ``{"class_1", "class_11",
+    "class_2"}`` this returns ``["class_1", "class_2", "class_11"]``.
+
+    Copied from:
+    https://blog.codinghorror.com/sorting-for-humans-natural-sort-order/
+    """
+    return sorted(iterable, key=_alphanumeric_key)
+
+
+class TargetMode(Enum):
+    """The ``TargetMode`` Enum describes the different supported formats for targets in Flash."""
+
+    MULTI_TOKEN = auto()
+    MULTI_NUMERIC = auto()
+    MULTI_COMMA_DELIMITED = auto()
+    MULTI_BINARY = auto()
+
+    SINGLE_TOKEN = auto()
+    SINGLE_NUMERIC = auto()
+    SINGLE_BINARY = auto()
+
+    @classmethod
+    def from_target(cls, target: Any) -> "TargetMode":
+        """Determine the ``TargetMode`` for a given target.
+
+        Multi-label targets can be:
+            * Comma delimited string - ``TargetMode.MULTI_COMMA_DELIMITED`` (e.g. ["blue,green", "red"])
+            * List of strings - ``TargetMode.MULTI_TOKEN`` (e.g. [["blue", "green"], ["red"]])
+            * List of numbers - ``TargetMode.MULTI_NUMERIC`` (e.g. [[0, 1], [2]])
+            * Binary list - ``TargetMode.MULTI_BINARY`` (e.g. [[1, 1, 0], [0, 0, 1]])
+
+        Single-label targets can be:
+            * Single string - ``TargetMode.SINGLE_TOKEN`` (e.g. ["blue", "green", "red"])
+            * Single number - ``TargetMode.SINGLE_NUMERIC`` (e.g. [0, 1, 2])
+            * One-hot binary list - ``TargetMode.SINGLE_BINARY`` (e.g. [[1, 0, 0], [0, 1, 0], [0, 0, 1]])
+
+        Args:
+            target: A target that is one of: a single target, a list of targets, a comma delimited string.
+ """ + if isinstance(target, str): + # TODO: This could be a dangerous assumption if people happen to have a label that contains a comma + if "," in target: + return TargetMode.MUTLI_COMMA_DELIMITED + else: + return TargetMode.SINGLE_TOKEN + elif _is_list_like(target): + if isinstance(target[0], str): + return TargetMode.MULTI_TOKEN + elif all(t == 0 or t == 1 for t in target): + if sum(target) == 1: + return TargetMode.SINGLE_BINARY + return TargetMode.MULTI_BINARY + return TargetMode.MULTI_NUMERIC + return TargetMode.SINGLE_NUMERIC + + @property + def multi_label(self) -> bool: + return any( + [ + self is TargetMode.MUTLI_COMMA_DELIMITED, + self is TargetMode.MULTI_NUMERIC, + self is TargetMode.MULTI_TOKEN, + self is TargetMode.MULTI_BINARY, + ] + ) + + @property + def numeric(self) -> bool: + return any( + [ + self is TargetMode.MULTI_NUMERIC, + self is TargetMode.SINGLE_NUMERIC, + ] + ) + + @property + def binary(self) -> bool: + return any( + [ + self is TargetMode.MULTI_BINARY, + self is TargetMode.SINGLE_BINARY, + ] + ) + + +_RESOLUTION_MAPPING = { + TargetMode.MULTI_BINARY: [TargetMode.MULTI_NUMERIC], + TargetMode.SINGLE_BINARY: [TargetMode.MULTI_BINARY, TargetMode.MULTI_NUMERIC], + TargetMode.SINGLE_TOKEN: [TargetMode.MUTLI_COMMA_DELIMITED], +} + + +def _resolve_target_mode(a: TargetMode, b: TargetMode) -> TargetMode: + """The purpose of the addition here is to reduce the ``TargetMode`` over multiple targets. If one target mode + is a comma delimited string and the other a single string then their sum will be comma delimited. If one target + is multi binary and the other is single binary, their sum will be multi binary. Otherwise, we expect that both + target modes are the same. + + Raises: + ValueError: If the two target modes could not be resolved to a single mode. + """ + if a is b: + return a + elif a in _RESOLUTION_MAPPING and b in _RESOLUTION_MAPPING[a]: + return b + elif b in _RESOLUTION_MAPPING and a in _RESOLUTION_MAPPING[b]: + return a + raise ValueError( + "Found inconsistent target modes. All targets should be either: single values, lists of values, or " + "comma-delimited strings." + ) + + +def get_target_mode(targets: List[Any]) -> TargetMode: + """Aggregate the ``TargetMode`` for a list of targets. + + Args: + targets: The list of targets to get the label mode for. + + Returns: + The total ``TargetMode`` of the list of targets. 
+ """ + targets = _as_list(targets) + return reduce(_resolve_target_mode, [TargetMode.from_target(target) for target in targets]) + + +class TargetFormatter: + """A ``TargetFormatter`` is used to convert targets of a given type to a standard format required by the + task.""" + + def __call__(self, target: Any) -> Any: + return self.format(target) + + def format(self, target: Any) -> Any: + return _as_list(target) + + +class SingleLabelTargetFormatter(TargetFormatter): + def __init__(self, labels: List[Any]): + self.label_to_idx = {label: idx for idx, label in enumerate(labels)} + + def format(self, target: Any) -> Any: + return self.label_to_idx[(target[0] if not isinstance(target, str) else target).strip()] + + +class MultiLabelTargetFormatter(SingleLabelTargetFormatter): + def __init__(self, labels: List[Any]): + super().__init__(labels) + + self.num_classes = len(labels) + + def format(self, target: Any) -> Any: + result = [0] * self.num_classes + for t in target: + idx = super().format(t) + result[idx] = 1 + return result + + +class CommaDelimitedTargetFormatter(MultiLabelTargetFormatter): + def format(self, target: Any) -> Any: + return super().format(target.split(",")) + + +class MultiNumericTargetFormatter(TargetFormatter): + def __init__(self, num_classes: int): + self.num_classes = num_classes + + def format(self, target: Any) -> Any: + result = [0] * self.num_classes + for idx in target: + result[idx] = 1 + return result + + +class OneHotTargetFormatter(TargetFormatter): + def format(self, target: Any) -> Any: + for idx, t in enumerate(target): + if t == 1: + return idx + return 0 + + +def get_target_formatter( + target_mode: TargetMode, labels: Optional[List[Any]], num_classes: Optional[int] +) -> TargetFormatter: + """Get the ``TargetFormatter`` object to use for the given ``TargetMode``, ``labels``, and ``num_classes``. + + Args: + target_mode: The target mode to format. + labels: Labels used by the target (if available). + num_classes: The number of classes in the targets. + + Returns: + The target formatter to use when formatting targets. + """ + if target_mode is TargetMode.SINGLE_NUMERIC or target_mode is TargetMode.MULTI_BINARY: + return TargetFormatter() + elif target_mode is TargetMode.SINGLE_BINARY: + return OneHotTargetFormatter() + elif target_mode is TargetMode.MULTI_NUMERIC: + return MultiNumericTargetFormatter(num_classes) + elif target_mode is TargetMode.SINGLE_TOKEN: + return SingleLabelTargetFormatter(labels) + elif target_mode is TargetMode.MUTLI_COMMA_DELIMITED: + return CommaDelimitedTargetFormatter(labels) + return MultiLabelTargetFormatter(labels) + + +def get_target_details(targets: List[Any], target_mode: TargetMode) -> Tuple[Optional[List[Any]], int]: + """Given a list of targets and their ``TargetMode``, this function determines the ``labels`` and + ``num_classes``. Targets can be: + + * Token-based: ``labels`` is the unique tokens, ``num_classes`` is the number of unique tokens. + * Numeric: ``labels`` is ``None`` and ``num_classes`` is the maximum value plus one. + * Binary: ``labels`` is ``None`` and ``num_classes`` is the length of the binary target. + + Args: + targets: A list of targets. + target_mode: The ``TargetMode`` of the targets from ``get_target_mode``. + + Returns: + (labels, num_classes): Tuple containing the inferred ``labels`` (or ``None`` if no labels could be inferred) + and ``num_classes``. 
+ """ + targets = _as_list(targets) + if target_mode.numeric: + # Take a max over all values + if target_mode is TargetMode.MULTI_NUMERIC: + values = [] + for target in targets: + values.extend(target) + else: + values = targets + num_classes = max(values) + if _is_list_like(num_classes): + num_classes = num_classes[0] + num_classes = num_classes + 1 + labels = None + elif target_mode.binary: + # Take a length + # TODO: Add a check here and error if target lengths are not all equal + num_classes = len(targets[0]) + labels = None + else: + # Compute tokens + tokens = [] + if target_mode is TargetMode.MUTLI_COMMA_DELIMITED: + for target in targets: + tokens.extend(target.split(",")) + elif target_mode is TargetMode.MULTI_TOKEN: + for target in targets: + tokens.extend(target) + else: + tokens = targets + + tokens = [token.strip() for token in tokens] + labels = list(_sorted_nicely(set(tokens))) + num_classes = len(labels) + return labels, num_classes diff --git a/flash/core/data/utilities/data_frame.py b/flash/core/data/utilities/data_frame.py new file mode 100644 index 0000000000..07452f422b --- /dev/null +++ b/flash/core/data/utilities/data_frame.py @@ -0,0 +1,90 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import os +from functools import partial +from typing import Any, Callable, List, Optional, Union + +import pandas as pd +from pytorch_lightning.utilities import rank_zero_warn + +from flash.core.data.utilities.paths import PATH_TYPE +from flash.core.utilities.imports import _PANDAS_GREATER_EQUAL_1_3_0 + + +def read_csv(file: PATH_TYPE) -> pd.DataFrame: + """A wrapper for ``pd.read_csv`` which tries to handle errors gracefully. + + Args: + file: The CSV file to read. + + Returns: + A ``DataFrame`` containing the contents of the file. + """ + try: + return pd.read_csv(file, encoding="utf-8") + except UnicodeDecodeError: + rank_zero_warn("A UnicodeDecodeError was raised when reading the CSV. This error will be ignored.") + if _PANDAS_GREATER_EQUAL_1_3_0: + return pd.read_csv(file, encoding="utf-8", encoding_errors="ignore") + else: + return pd.read_csv(file, encoding=None, engine="python") + + +def _resolve_multi_target(target_keys: List[str], row: pd.Series) -> List[Any]: + return [row[target_key] for target_key in target_keys] + + +def resolve_targets(data_frame: pd.DataFrame, target_keys: Union[str, List[str]]) -> List[Any]: + """Given a data frame and a target key or list of target keys, this function returns a list of targets. + + Args: + data_frame: The ``pd.DataFrame`` containing the target column / columns. + target_keys: The column in the data frame (or a list of columns) from which to resolve the target. 
+ """ + if not isinstance(target_keys, List): + return data_frame[target_keys].tolist() + return data_frame.apply(partial(_resolve_multi_target, target_keys), axis=1).tolist() + + +def _resolve_file( + resolver: Callable[[PATH_TYPE, Any], PATH_TYPE], root: PATH_TYPE, input_key: str, row: pd.Series +) -> PATH_TYPE: + return resolver(root, row[input_key]) + + +def default_resolver(root: Optional[PATH_TYPE], file_id: Any) -> PATH_TYPE: + file = os.path.join(root, file_id) if root is not None else file_id + if os.path.isfile(file): + return file + raise ValueError( + f"File ID `{file_id}` did not resolve to an existing file. For use cases which involve first converting the ID " + f"to a file you should pass a custom resolver when loading the data." + ) + + +def resolve_files( + data_frame: pd.DataFrame, key: str, root: PATH_TYPE, resolver: Callable[[Optional[PATH_TYPE], Any], PATH_TYPE] +) -> List[PATH_TYPE]: + """Resolves a list of files from a given column in a data frame. + + Args: + data_frame: The ``pd.DataFrame`` containing file IDs. + key: The column in the data frame containing the file IDs. + root: The root path to use when resolving files. + resolver: The resolver function to use. This function should receive the root and a file ID as input and return + the path to an existing file. + """ + if resolver is None: + resolver = default_resolver + return data_frame.apply(partial(_resolve_file, resolver, root, key), axis=1).tolist() diff --git a/flash/core/data/utilities/paths.py b/flash/core/data/utilities/paths.py index 77f737241a..c2f43915e3 100644 --- a/flash/core/data/utilities/paths.py +++ b/flash/core/data/utilities/paths.py @@ -12,15 +12,12 @@ # See the License for the specific language governing permissions and # limitations under the License. import os -import warnings -from typing import Any, Callable, cast, Dict, List, Optional, Tuple, Union - -import pandas as pd - -from flash.core.utilities.imports import _PANDAS_GREATER_EQUAL_1_3_0 +from typing import Any, Callable, cast, List, Optional, Tuple, TypeVar, Union PATH_TYPE = Union[str, bytes, os.PathLike] +T = TypeVar("T") + # adapted from torchvision: # https://github.com/pytorch/vision/blob/master/torchvision/datasets/folder.py#L10 @@ -37,19 +34,17 @@ def has_file_allowed_extension(filename: PATH_TYPE, extensions: Tuple[str, ...]) return str(filename).lower().endswith(extensions) -# Copied from torchvision: +# Adapted from torchvision: # https://github.com/pytorch/vision/blob/master/torchvision/datasets/folder.py#L48 def make_dataset( directory: PATH_TYPE, - class_to_idx: Dict[str, int], extensions: Optional[Tuple[str, ...]] = None, is_valid_file: Optional[Callable[[str], bool]] = None, -) -> List[Tuple[str, int]]: +) -> Tuple[List[PATH_TYPE], Optional[List[PATH_TYPE]]]: """Generates a list of samples of a form (path_to_sample, class). Args: directory (str): root dataset directory - class_to_idx (Dict[str, int]): dictionary mapping class name to class index extensions (optional): A list of allowed extensions. Either extensions or is_valid_file should be passed. Defaults to None. is_valid_file (optional): A function that takes path of a file @@ -61,9 +56,9 @@ def make_dataset( ValueError: In case ``extensions`` and ``is_valid_file`` are None or both are not None. Returns: - List[Tuple[str, int]]: samples of a form (path_to_sample, class) + (files, targets) Tuple containing the list of files and corresponding list of targets. 
""" - instances = [] + files, targets = [], [] directory = os.path.expanduser(str(directory)) both_none = extensions is None and is_valid_file is None both_something = extensions is not None and is_valid_file is not None @@ -75,18 +70,20 @@ def is_valid_file(x: str) -> bool: return has_file_allowed_extension(x, cast(Tuple[str, ...], extensions)) is_valid_file = cast(Callable[[str], bool], is_valid_file) - for target_class in sorted(class_to_idx.keys()): - class_index = class_to_idx[target_class] - target_dir = os.path.join(directory, target_class) - if not os.path.isdir(target_dir): - continue - for root, _, fnames in sorted(os.walk(target_dir, followlinks=True)): - for fname in sorted(fnames): - path = os.path.join(root, fname) - if is_valid_file(path): - item = path, class_index - instances.append(item) - return instances + subdirs = list_subdirs(directory) + if len(subdirs) > 0: + for target_class in subdirs: + target_dir = os.path.join(directory, target_class) + if not os.path.isdir(target_dir): + continue + for root, _, fnames in sorted(os.walk(target_dir, followlinks=True)): + for fname in sorted(fnames): + path = os.path.join(root, fname) + if is_valid_file(path): + files.append(path) + targets.append(target_class) + return files, targets + return list_valid_files(directory), None def isdir(path: Any) -> bool: @@ -97,19 +94,16 @@ def isdir(path: Any) -> bool: return False -def find_classes(dir: PATH_TYPE) -> Tuple[List[str], Dict[str, int]]: - """Finds the class folders in a dataset. Ensures that no class is a subdirectory of another. +def list_subdirs(dir: PATH_TYPE) -> List[str]: + """List the subdirectories of a given directory. Args: - dir: Root directory path. + dir: The directory to scan. Returns: - (classes, class_to_idx) where classes are relative to (dir), and class_to_idx is a dictionary. + The list of subdirectories. """ - classes = [d.name for d in os.scandir(str(dir)) if d.is_dir()] - classes.sort() - class_to_idx = {cls_name: i for i, cls_name in enumerate(classes)} - return classes, class_to_idx + return [d.name for d in os.scandir(str(dir)) if d.is_dir()] def list_valid_files( @@ -132,28 +126,33 @@ def list_valid_files( if valid_extensions is None: return paths - return list( - filter( - lambda file: has_file_allowed_extension(file, valid_extensions), - paths, - ) - ) + return [path for path in paths if has_file_allowed_extension(path, valid_extensions)] -def read_csv(file: PATH_TYPE) -> pd.DataFrame: - """A wrapper for ``pd.read_csv`` which tries to handle errors gracefully. +def filter_valid_files( + files: Union[PATH_TYPE, List[PATH_TYPE]], + *additional_lists: List[Any], + valid_extensions: Optional[Tuple[str, ...]] = None +) -> Union[List[Any], Tuple[List[Any], ...]]: + """Filter the given list of files and any additional lists to include only the entries that contain a file with + a valid extension. Args: - file: The CSV file to read. + files: The list of files to filter by. + additional_lists: Any additional lists to be filtered together with files. + valid_extensions: The tuple of valid file extensions. Returns: - A ``DataFrame`` containing the contents of the file. + The filtered lists. """ - try: - return pd.read_csv(file, encoding="utf-8") - except UnicodeDecodeError: - warnings.warn("A UnicodeDecodeError was raised when reading the CSV. 
This error will be ignored.") - if _PANDAS_GREATER_EQUAL_1_3_0: - return pd.read_csv(file, encoding="utf-8", encoding_errors="ignore") - else: - return pd.read_csv(file, encoding=None, engine="python") + if not isinstance(files, List): + files = [files] + + if valid_extensions is None: + return (files,) + additional_lists + filtered = list( + filter(lambda sample: has_file_allowed_extension(sample[0], valid_extensions), zip(files, *additional_lists)) + ) + if len(additional_lists) > 0: + return tuple(zip(*filtered)) + return [f[0] for f in filtered] diff --git a/flash/core/data/utilities/samples.py b/flash/core/data/utilities/samples.py new file mode 100644 index 0000000000..4d3cfbe79e --- /dev/null +++ b/flash/core/data/utilities/samples.py @@ -0,0 +1,33 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Any, Dict, List, Optional, TypeVar + +from flash.core.data.io.input import DataKeys + +T = TypeVar("T") + + +def to_samples(inputs: List[Any], targets: Optional[List[Any]] = None) -> List[Dict[str, Any]]: + """Package a list of inputs and, optionally, a list of targets in a list of dictionaries (samples). + + Args: + inputs: The list of inputs to package as dictionaries. + targets: Optionally provide a list of targets to also be included in the samples. + + Returns: + A list of sample dictionaries. 
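A short sketch of how these two helpers combine (illustrative file names, not from the patch):

```python
from flash.core.data.io.input import DataKeys
from flash.core.data.utilities.paths import filter_valid_files
from flash.core.data.utilities.samples import to_samples

files = ["a.jpg", "notes.txt", "b.png"]
targets = ["cat", "skipped", "dog"]

# Entries whose file has an invalid extension are dropped from both lists:
files, targets = filter_valid_files(files, targets, valid_extensions=(".jpg", ".png"))
assert files == ("a.jpg", "b.png") and targets == ("cat", "dog")

samples = to_samples(list(files), list(targets))
assert samples[0] == {DataKeys.INPUT: "a.jpg", DataKeys.TARGET: "cat"}
```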
+ """ + if targets is None: + return [{DataKeys.INPUT: input} for input in inputs] + return [{DataKeys.INPUT: input, DataKeys.TARGET: target} for input, target in zip(inputs, targets)] diff --git a/flash/core/integrations/icevision/data.py b/flash/core/integrations/icevision/data.py index 1e249a1333..a6b8cb9e92 100644 --- a/flash/core/integrations/icevision/data.py +++ b/flash/core/integrations/icevision/data.py @@ -16,7 +16,8 @@ import numpy as np -from flash.core.data.io.input import DataKeys, LabelsState +from flash.core.data.io.classification_input import ClassificationState +from flash.core.data.io.input import DataKeys from flash.core.data.io.input_base import Input from flash.core.data.utilities.paths import list_valid_files from flash.core.integrations.icevision.transforms import from_icevision_record @@ -44,7 +45,7 @@ def load_data( else: raise ValueError("The parser must be a callable or an IceVision Parser type.") self.num_classes = parser.class_map.num_classes - self.set_state(LabelsState([parser.class_map.get_by_id(i) for i in range(self.num_classes)])) + self.set_state(ClassificationState([parser.class_map.get_by_id(i) for i in range(self.num_classes)])) records = parser.parse(data_splitter=SingleSplitSplitter()) return [{DataKeys.INPUT: record} for record in records[0]] diff --git a/flash/image/classification/adapters.py b/flash/image/classification/adapters.py index 8ed89d6fcc..60cf9292f9 100644 --- a/flash/image/classification/adapters.py +++ b/flash/image/classification/adapters.py @@ -30,6 +30,7 @@ from flash.core.adapter import Adapter, AdapterTask from flash.core.data.auto_dataset import BaseAutoDataset from flash.core.data.io.input import DataKeys +from flash.core.data.io.input_base import InputBase from flash.core.model import Task from flash.core.registry import FlashRegistry from flash.core.utilities.compatibility import accelerator_connector @@ -204,7 +205,7 @@ def _convert_dataset( num_task: int, epoch_length: int, ): - if isinstance(dataset, BaseAutoDataset): + if isinstance(dataset, (InputBase, BaseAutoDataset)): metadata = getattr(dataset, "data", None) if metadata is None or (metadata is not None and not isinstance(dataset.data, list)): diff --git a/flash/image/classification/cli.py b/flash/image/classification/cli.py index 350e0fe68d..1ffc4052d6 100644 --- a/flash/image/classification/cli.py +++ b/flash/image/classification/cli.py @@ -22,7 +22,7 @@ def from_hymenoptera( batch_size: int = 4, num_workers: int = 0, - **input_transform_kwargs, + **data_module_kwargs, ) -> ImageClassificationData: """Downloads and loads the Hymenoptera (Ants, Bees) data set.""" download_data("https://pl-flash-data.s3.amazonaws.com/hymenoptera_data.zip", "./data") @@ -31,14 +31,14 @@ def from_hymenoptera( val_folder="data/hymenoptera_data/val/", batch_size=batch_size, num_workers=num_workers, - **input_transform_kwargs, + **data_module_kwargs, ) def from_movie_posters( batch_size: int = 4, num_workers: int = 0, - **input_transform_kwargs, + **data_module_kwargs, ) -> ImageClassificationData: """Downloads and loads the movie posters genre classification data set.""" download_data("https://pl-flash-data.s3.amazonaws.com/movie_posters.zip", "./data") @@ -49,7 +49,7 @@ def from_movie_posters( val_file="data/movie_posters/val/metadata.csv", batch_size=batch_size, num_workers=num_workers, - **input_transform_kwargs, + **data_module_kwargs, ) @@ -64,7 +64,6 @@ def image_classification(): "trainer.max_epochs": 3, }, datamodule_attributes={"num_classes", "multi_label"}, - legacy=True, ) 
cli.trainer.save_checkpoint("image_classification_model.pt") diff --git a/flash/image/classification/data.py b/flash/image/classification/data.py index 35f0ddb759..e5b6fe1459 100644 --- a/flash/image/classification/data.py +++ b/flash/image/classification/data.py @@ -11,30 +11,38 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Type, Union +import os +from typing import Any, Callable, Collection, Dict, List, Optional, Sequence, Tuple, Union import numpy as np import pandas as pd import torch -from torch.utils.data.sampler import Sampler -from flash.core.data.base_viz import BaseVisualization # for viz +from flash.core.data.base_viz import BaseVisualization from flash.core.data.callback import BaseDataFetcher from flash.core.data.data_module import DataModule -from flash.core.data.io.input import DataKeys, InputFormat, LoaderDataFrameInput +from flash.core.data.io.classification_input import ClassificationInput, ClassificationState +from flash.core.data.io.input import DataKeys, InputFormat from flash.core.data.io.input_transform import InputTransform from flash.core.data.process import Deserializer +from flash.core.data.utilities.classification import TargetMode +from flash.core.data.utilities.data_frame import read_csv, resolve_files, resolve_targets +from flash.core.data.utilities.paths import filter_valid_files, make_dataset, PATH_TYPE +from flash.core.data.utilities.samples import to_samples +from flash.core.integrations.fiftyone.utils import FiftyOneLabelUtilities from flash.core.integrations.labelstudio.input import LabelStudioImageClassificationInput from flash.core.utilities.imports import _MATPLOTLIB_AVAILABLE, Image, requires from flash.core.utilities.stages import RunningStage from flash.image.classification.transforms import default_transforms, train_default_transforms from flash.image.data import ( - image_loader, + fol, ImageDeserializer, - ImageFiftyOneInput, + ImageFilesInput, ImageNumpyInput, - ImagePathsInput, ImageTensorInput, + IMG_EXTENSIONS, + NP_EXTENSIONS, + SampleCollection, ) if _MATPLOTLIB_AVAILABLE: @@ -43,30 +51,119 @@ plt = None -class ImageClassificationDataFrameInput(LoaderDataFrameInput): - def __init__(self): - super().__init__(image_loader) +class ImageClassificationFilesInput(ClassificationInput, ImageFilesInput): + def load_data( + self, + files: List[PATH_TYPE], + targets: Optional[List[Any]] = None, + ) -> List[Dict[str, Any]]: + if targets is None: + return super().load_data(files) + files, targets = filter_valid_files(files, targets, valid_extensions=IMG_EXTENSIONS + NP_EXTENSIONS) + self.load_target_metadata(targets) + return to_samples(files, targets) + + def load_sample(self, sample: Dict[str, Any]) -> Dict[str, Any]: + sample = super().load_sample(sample) + if DataKeys.TARGET in sample: + sample[DataKeys.TARGET] = self.format_target(sample[DataKeys.TARGET]) + return sample - @requires("image") - def load_sample(self, sample: Dict[str, Any], dataset: Optional[Any] = None) -> Dict[str, Any]: - sample = super().load_sample(sample, dataset) - w, h = sample[DataKeys.INPUT].size # WxH - sample[DataKeys.METADATA]["size"] = (h, w) + +class ImageClassificationFolderInput(ImageClassificationFilesInput): + def load_data(self, folder: PATH_TYPE) -> List[Dict[str, Any]]: + files, targets = make_dataset(folder, extensions=IMG_EXTENSIONS + NP_EXTENSIONS) + 
return super().load_data(files, targets)
+
+
+class ImageClassificationFiftyOneInput(ImageClassificationFilesInput):
+    @requires("fiftyone")
+    def load_data(self, sample_collection: SampleCollection, label_field: str = "ground_truth") -> List[Dict[str, Any]]:
+        label_utilities = FiftyOneLabelUtilities(label_field, fol.Label)
+        label_utilities.validate(sample_collection)
+
+        label_path = sample_collection._get_label_field_path(label_field, "label")[1]
+
+        filepaths = sample_collection.values("filepath")
+        targets = sample_collection.values(label_path)
+
+        return super().load_data(filepaths, targets)
+
+    @requires("fiftyone")
+    def predict_load_data(self, data: SampleCollection) -> List[Dict[str, Any]]:
+        return super().load_data(data.values("filepath"))
+
+
+class ImageClassificationTensorInput(ClassificationInput, ImageTensorInput):
+    def load_data(self, tensor: Any, targets: Optional[List[Any]] = None) -> List[Dict[str, Any]]:
+        if targets is not None:
+            self.load_target_metadata(targets)
+
+        return to_samples(tensor, targets)
+
+    def load_sample(self, sample: Dict[str, Any]) -> Dict[str, Any]:
+        sample = super().load_sample(sample)
+        if DataKeys.TARGET in sample:
+            sample[DataKeys.TARGET] = self.format_target(sample[DataKeys.TARGET])
+        return sample
+
+
+class ImageClassificationNumpyInput(ClassificationInput, ImageNumpyInput):
+    def load_data(self, array: Any, targets: Optional[List[Any]] = None) -> List[Dict[str, Any]]:
+        if targets is not None:
+            self.load_target_metadata(targets)
+
+        return to_samples(array, targets)
+
+    def load_sample(self, sample: Dict[str, Any]) -> Dict[str, Any]:
+        sample = super().load_sample(sample)
+        if DataKeys.TARGET in sample:
+            sample[DataKeys.TARGET] = self.format_target(sample[DataKeys.TARGET])
+        return sample
+
+
+class ImageClassificationDataFrameInput(ImageClassificationFilesInput):
+    def load_data(
+        self,
+        data_frame: pd.DataFrame,
+        input_key: str,
+        target_keys: Optional[Union[str, List[str]]] = None,
+        root: Optional[PATH_TYPE] = None,
+        resolver: Optional[Callable[[Optional[PATH_TYPE], Any], PATH_TYPE]] = None,
+    ) -> List[Dict[str, Any]]:
+        files = resolve_files(data_frame, input_key, root, resolver)
+        if target_keys is not None:
+            targets = resolve_targets(data_frame, target_keys)
+        else:
+            targets = None
+        result = super().load_data(files, targets)
+
+        # If we had binary multi-class targets then we also know the labels (column names)
+        if self.training and self.target_mode is TargetMode.MULTI_BINARY and isinstance(target_keys, List):
+            classification_state = self.get_state(ClassificationState)
+            self.set_state(ClassificationState(target_keys, classification_state.num_classes))
+
+        return result
+
+
+class ImageClassificationCSVInput(ImageClassificationDataFrameInput):
+    def load_data(
+        self,
+        csv_file: PATH_TYPE,
+        input_key: str,
+        target_keys: Optional[Union[str, List[str]]] = None,
+        root: Optional[PATH_TYPE] = None,
+        resolver: Optional[Callable[[Optional[PATH_TYPE], Any], PATH_TYPE]] = None,
+    ) -> List[Dict[str, Any]]:
+        data_frame = read_csv(csv_file)
+        if root is None:
+            root = os.path.dirname(csv_file)
+        return super().load_data(data_frame, input_key, target_keys, root, resolver)
+
+
 class ImageClassificationInputTransform(InputTransform):
-    """Preprocssing of data of image classification.
-
-    Args::
-        train_transfor:m
-        val_transform:
-        test_transform:
-        predict_transform:
-        image_size: tuple with the (heigh, width) of the images
-        deserializer:
-        input_kwargs: Additional kwargs for the data source initializer
-    """
+    """Preprocessing of image classification data."""
 
     def __init__(
         self,
@@ -76,7 +173,6 @@ def __init__(
         predict_transform: Optional[Dict[str, Callable]] = None,
         image_size: Tuple[int, int] = (196, 196),
         deserializer: Optional[Deserializer] = None,
-        **input_kwargs: Any,
     ):
         self.image_size = image_size
 
@@ -86,13 +182,13 @@ def __init__(
             test_transform=test_transform,
             predict_transform=predict_transform,
             inputs={
-                InputFormat.FIFTYONE: ImageFiftyOneInput(**input_kwargs),
-                InputFormat.FILES: ImagePathsInput(),
-                InputFormat.FOLDERS: ImagePathsInput(),
-                InputFormat.NUMPY: ImageNumpyInput(),
-                InputFormat.TENSORS: ImageTensorInput(),
-                "data_frame": ImageClassificationDataFrameInput(),
-                InputFormat.CSV: ImageClassificationDataFrameInput(),
+                InputFormat.FIFTYONE: ImageClassificationFiftyOneInput,
+                InputFormat.FILES: ImageClassificationFilesInput,
+                InputFormat.FOLDERS: ImageClassificationFolderInput,
+                InputFormat.NUMPY: ImageClassificationNumpyInput,
+                InputFormat.TENSORS: ImageClassificationTensorInput,
+                InputFormat.DATAFRAME: ImageClassificationDataFrameInput,
+                InputFormat.CSV: ImageClassificationCSVInput,
                 InputFormat.LABELSTUDIO: LabelStudioImageClassificationInput(),
             },
            deserializer=deserializer or ImageDeserializer(),
@@ -118,6 +214,131 @@ class ImageClassificationData(DataModule):
 
     input_transform_cls = ImageClassificationInputTransform
 
+    @classmethod
+    def from_files(
+        cls,
+        train_files: Optional[Sequence[str]] = None,
+        train_targets: Optional[Sequence[Any]] = None,
+        val_files: Optional[Sequence[str]] = None,
+        val_targets: Optional[Sequence[Any]] = None,
+        test_files: Optional[Sequence[str]] = None,
+        test_targets: Optional[Sequence[Any]] = None,
+        predict_files: Optional[Sequence[str]] = None,
+        train_transform: Optional[Dict[str, Callable]] = None,
+        val_transform: Optional[Dict[str, Callable]] = None,
+        test_transform: Optional[Dict[str, Callable]] = None,
+        predict_transform: Optional[Dict[str, Callable]] = None,
+        image_size: Tuple[int, int] = (196, 196),
+        **data_module_kwargs: Any,
+    ) -> "ImageClassificationData":
+        return cls(
+            ImageClassificationFilesInput(RunningStage.TRAINING, train_files, train_targets),
+            ImageClassificationFilesInput(RunningStage.VALIDATING, val_files, val_targets),
+            ImageClassificationFilesInput(RunningStage.TESTING, test_files, test_targets),
+            ImageClassificationFilesInput(RunningStage.PREDICTING, predict_files),
+            input_transform=cls.input_transform_cls(
+                train_transform,
+                val_transform,
+                test_transform,
+                predict_transform,
+                image_size=image_size,
+            ),
+            **data_module_kwargs,
+        )
+
+    @classmethod
+    def from_folders(
+        cls,
+        train_folder: Optional[str] = None,
+        val_folder: Optional[str] = None,
+        test_folder: Optional[str] = None,
+        predict_folder: Optional[str] = None,
+        train_transform: Optional[Dict[str, Callable]] = None,
+        val_transform: Optional[Dict[str, Callable]] = None,
+        test_transform: Optional[Dict[str, Callable]] = None,
+        predict_transform: Optional[Dict[str, Callable]] = None,
+        image_size: Tuple[int, int] = (196, 196),
+        **data_module_kwargs: Any,
+    ) -> "ImageClassificationData":
+        return cls(
+            ImageClassificationFolderInput(RunningStage.TRAINING, train_folder),
+            ImageClassificationFolderInput(RunningStage.VALIDATING, val_folder),
+
ImageClassificationFolderInput(RunningStage.TESTING, test_folder), + ImageClassificationFolderInput(RunningStage.PREDICTING, predict_folder), + input_transform=cls.input_transform_cls( + train_transform, + val_transform, + test_transform, + predict_transform, + image_size=image_size, + ), + **data_module_kwargs, + ) + + @classmethod + def from_numpy( + cls, + train_data: Optional[Collection[np.ndarray]] = None, + train_targets: Optional[Collection[Any]] = None, + val_data: Optional[Collection[np.ndarray]] = None, + val_targets: Optional[Sequence[Any]] = None, + test_data: Optional[Collection[np.ndarray]] = None, + test_targets: Optional[Sequence[Any]] = None, + predict_data: Optional[Collection[np.ndarray]] = None, + train_transform: Optional[Dict[str, Callable]] = None, + val_transform: Optional[Dict[str, Callable]] = None, + test_transform: Optional[Dict[str, Callable]] = None, + predict_transform: Optional[Dict[str, Callable]] = None, + image_size: Tuple[int, int] = (196, 196), + **data_module_kwargs: Any, + ) -> "ImageClassificationData": + return cls( + ImageClassificationNumpyInput(RunningStage.TRAINING, train_data, train_targets), + ImageClassificationNumpyInput(RunningStage.VALIDATING, val_data, val_targets), + ImageClassificationNumpyInput(RunningStage.TESTING, test_data, test_targets), + ImageClassificationNumpyInput(RunningStage.PREDICTING, predict_data), + input_transform=cls.input_transform_cls( + train_transform, + val_transform, + test_transform, + predict_transform, + image_size=image_size, + ), + **data_module_kwargs, + ) + + @classmethod + def from_tensors( + cls, + train_data: Optional[Collection[torch.Tensor]] = None, + train_targets: Optional[Collection[Any]] = None, + val_data: Optional[Collection[torch.Tensor]] = None, + val_targets: Optional[Sequence[Any]] = None, + test_data: Optional[Collection[torch.Tensor]] = None, + test_targets: Optional[Sequence[Any]] = None, + predict_data: Optional[Collection[torch.Tensor]] = None, + train_transform: Optional[Dict[str, Callable]] = None, + val_transform: Optional[Dict[str, Callable]] = None, + test_transform: Optional[Dict[str, Callable]] = None, + predict_transform: Optional[Dict[str, Callable]] = None, + image_size: Tuple[int, int] = (196, 196), + **data_module_kwargs: Any, + ) -> "ImageClassificationData": + return cls( + ImageClassificationTensorInput(RunningStage.TRAINING, train_data, train_targets), + ImageClassificationTensorInput(RunningStage.VALIDATING, val_data, val_targets), + ImageClassificationTensorInput(RunningStage.TESTING, test_data, test_targets), + ImageClassificationTensorInput(RunningStage.PREDICTING, predict_data), + input_transform=cls.input_transform_cls( + train_transform, + val_transform, + test_transform, + predict_transform, + image_size=image_size, + ), + **data_module_kwargs, + ) + @classmethod def from_data_frame( cls, @@ -139,179 +360,112 @@ def from_data_frame( val_transform: Optional[Union[Callable, List, Dict[str, Callable]]] = None, test_transform: Optional[Union[Callable, List, Dict[str, Callable]]] = None, predict_transform: Optional[Dict[str, Callable]] = None, - data_fetcher: Optional[BaseDataFetcher] = None, - input_transform: Optional[InputTransform] = None, - val_split: Optional[float] = None, - batch_size: int = 4, - num_workers: int = 0, - sampler: Optional[Type[Sampler]] = None, - **input_transform_kwargs: Any, - ) -> "DataModule": - """Creates a :class:`~flash.image.classification.data.ImageClassificationData` object from the given pandas - ``DataFrame`` objects. 
- - Args: - input_field: The field (column) in the pandas ``DataFrame`` to use for the input. - target_fields: The field or fields (columns) in the pandas ``DataFrame`` to use for the target. - train_data_frame: The pandas ``DataFrame`` containing the training data. - train_images_root: The directory containing the train images. If ``None``, values in the ``input_field`` - will be assumed to be the full file paths. - train_resolver: The function to use to resolve filenames given the ``train_images_root`` and IDs from the - ``input_field`` column. - val_data_frame: The pandas ``DataFrame`` containing the validation data. - val_images_root: The directory containing the validation images. If ``None``, the directory containing the - ``val_file`` will be used. - val_resolver: The function to use to resolve filenames given the ``val_images_root`` and IDs from the - ``input_field`` column. - test_data_frame: The pandas ``DataFrame`` containing the testing data. - test_images_root: The directory containing the test images. If ``None``, the directory containing the - ``test_file`` will be used. - test_resolver: The function to use to resolve filenames given the ``test_images_root`` and IDs from the - ``input_field`` column. - predict_data_frame: The pandas ``DataFrame`` containing the data to use when predicting. - predict_images_root: The directory containing the predict images. If ``None``, the directory containing the - ``predict_file`` will be used. - predict_resolver: The function to use to resolve filenames given the ``predict_images_root`` and IDs from - the ``input_field`` column. - train_transform: The dictionary of transforms to use during training which maps - :class:`~flash.core.data.io.input_transform.InputTransform` hook names to callable transforms. - val_transform: The dictionary of transforms to use during validation which maps - :class:`~flash.core.data.io.input_transform.InputTransform` hook names to callable transforms. - test_transform: The dictionary of transforms to use during testing which maps - :class:`~flash.core.data.io.input_transform.InputTransform` hook names to callable transforms. - predict_transform: The dictionary of transforms to use during predicting which maps - :class:`~flash.core.data.io.input_transform.InputTransform` hook names to callable transforms. - data_fetcher: The :class:`~flash.core.data.callback.BaseDataFetcher` to pass to the - :class:`~flash.core.data.data_module.DataModule`. - input_transform: The :class:`~flash.core.data.data.InputTransform` to pass to the - :class:`~flash.core.data.data_module.DataModule`. If ``None``, ``cls.input_transform_cls`` - will be constructed and used. - val_split: The ``val_split`` argument to pass to the :class:`~flash.core.data.data_module.DataModule`. - batch_size: The ``batch_size`` argument to pass to the :class:`~flash.core.data.data_module.DataModule`. - num_workers: The ``num_workers`` argument to pass to the :class:`~flash.core.data.data_module.DataModule`. - sampler: The ``sampler`` to use for the ``train_dataloader``. - input_transform_kwargs: Additional keyword arguments to use when constructing the input_transform. - Will only be used if ``input_transform = None``. - - Returns: - The constructed data module. 
- """ - return cls.from_input( - "data_frame", - (train_data_frame, input_field, target_fields, train_images_root, train_resolver), - (val_data_frame, input_field, target_fields, val_images_root, val_resolver), - (test_data_frame, input_field, target_fields, test_images_root, test_resolver), - (predict_data_frame, input_field, target_fields, predict_images_root, predict_resolver), - train_transform=train_transform, - val_transform=val_transform, - test_transform=test_transform, - predict_transform=predict_transform, - data_fetcher=data_fetcher, - input_transform=input_transform, - val_split=val_split, - batch_size=batch_size, - num_workers=num_workers, - sampler=sampler, - **input_transform_kwargs, + image_size: Tuple[int, int] = (196, 196), + **data_module_kwargs: Any, + ) -> "ImageClassificationData": + return cls( + ImageClassificationDataFrameInput( + RunningStage.TRAINING, train_data_frame, input_field, target_fields, train_images_root, train_resolver + ), + ImageClassificationCSVInput( + RunningStage.VALIDATING, val_data_frame, input_field, target_fields, val_images_root, val_resolver + ), + ImageClassificationCSVInput( + RunningStage.TESTING, test_data_frame, input_field, target_fields, test_images_root, test_resolver + ), + ImageClassificationCSVInput( + RunningStage.PREDICTING, + predict_data_frame, + input_field, + root=predict_images_root, + resolver=predict_resolver, + ), + input_transform=cls.input_transform_cls( + train_transform, + val_transform, + test_transform, + predict_transform, + image_size=image_size, + ), + **data_module_kwargs, ) @classmethod def from_csv( cls, input_field: str, - target_fields: Optional[Union[str, Sequence[str]]] = None, - train_file: Optional[str] = None, - train_images_root: Optional[str] = None, - train_resolver: Optional[Callable[[str, str], str]] = None, - val_file: Optional[str] = None, - val_images_root: Optional[str] = None, - val_resolver: Optional[Callable[[str, str], str]] = None, + target_fields: Optional[Union[str, List[str]]] = None, + train_file: Optional[PATH_TYPE] = None, + train_images_root: Optional[PATH_TYPE] = None, + train_resolver: Optional[Callable[[PATH_TYPE, Any], PATH_TYPE]] = None, + val_file: Optional[PATH_TYPE] = None, + val_images_root: Optional[PATH_TYPE] = None, + val_resolver: Optional[Callable[[PATH_TYPE, Any], PATH_TYPE]] = None, test_file: Optional[str] = None, test_images_root: Optional[str] = None, - test_resolver: Optional[Callable[[str, str], str]] = None, + test_resolver: Optional[Callable[[PATH_TYPE, Any], PATH_TYPE]] = None, predict_file: Optional[str] = None, predict_images_root: Optional[str] = None, - predict_resolver: Optional[Callable[[str, str], str]] = None, + predict_resolver: Optional[Callable[[PATH_TYPE, Any], PATH_TYPE]] = None, train_transform: Optional[Union[Callable, List, Dict[str, Callable]]] = None, val_transform: Optional[Union[Callable, List, Dict[str, Callable]]] = None, test_transform: Optional[Union[Callable, List, Dict[str, Callable]]] = None, predict_transform: Optional[Dict[str, Callable]] = None, - data_fetcher: Optional[BaseDataFetcher] = None, - input_transform: Optional[InputTransform] = None, - val_split: Optional[float] = None, - batch_size: int = 4, - num_workers: int = 0, - sampler: Optional[Type[Sampler]] = None, - **input_transform_kwargs: Any, - ) -> "DataModule": - """Creates a :class:`~flash.image.classification.data.ImageClassificationData` object from the given CSV - files using the :class:`~flash.core.data.io.input.Input` of name - 
:attr:`~flash.core.data.io.input.InputFormat.CSV` from the passed or constructed - :class:`~flash.core.data.io.input_transform.InputTransform`. - - Args: - input_field: The field (column) in the CSV file to use for the input. - target_fields: The field or fields (columns) in the CSV file to use for the target. - train_file: The CSV file containing the training data. - train_images_root: The directory containing the train images. If ``None``, the directory containing the - ``train_file`` will be used. - train_resolver: The function to use to resolve filenames given the ``train_images_root`` and IDs from the - ``input_field`` column. - val_file: The CSV file containing the validation data. - val_images_root: The directory containing the validation images. If ``None``, the directory containing the - ``val_file`` will be used. - val_resolver: The function to use to resolve filenames given the ``val_images_root`` and IDs from the - ``input_field`` column. - test_file: The CSV file containing the testing data. - test_images_root: The directory containing the test images. If ``None``, the directory containing the - ``test_file`` will be used. - test_resolver: The function to use to resolve filenames given the ``test_images_root`` and IDs from the - ``input_field`` column. - predict_file: The CSV file containing the data to use when predicting. - predict_images_root: The directory containing the predict images. If ``None``, the directory containing the - ``predict_file`` will be used. - predict_resolver: The function to use to resolve filenames given the ``predict_images_root`` and IDs from - the ``input_field`` column. - train_transform: The dictionary of transforms to use during training which maps - :class:`~flash.core.data.io.input_transform.InputTransform` hook names to callable transforms. - val_transform: The dictionary of transforms to use during validation which maps - :class:`~flash.core.data.io.input_transform.InputTransform` hook names to callable transforms. - test_transform: The dictionary of transforms to use during testing which maps - :class:`~flash.core.data.io.input_transform.InputTransform` hook names to callable transforms. - predict_transform: The dictionary of transforms to use during predicting which maps - :class:`~flash.core.data.io.input_transform.InputTransform` hook names to callable transforms. - data_fetcher: The :class:`~flash.core.data.callback.BaseDataFetcher` to pass to the - :class:`~flash.core.data.data_module.DataModule`. - input_transform: The :class:`~flash.core.data.data.InputTransform` to pass to the - :class:`~flash.core.data.data_module.DataModule`. If ``None``, ``cls.input_transform_cls`` - will be constructed and used. - val_split: The ``val_split`` argument to pass to the :class:`~flash.core.data.data_module.DataModule`. - batch_size: The ``batch_size`` argument to pass to the :class:`~flash.core.data.data_module.DataModule`. - num_workers: The ``num_workers`` argument to pass to the :class:`~flash.core.data.data_module.DataModule`. - sampler: The ``sampler`` to use for the ``train_dataloader``. - input_transform_kwargs: Additional keyword arguments to use when constructing the input_transform. - Will only be used if ``input_transform = None``. - - Returns: - The constructed data module. 
- """ - return cls.from_input( - InputFormat.CSV, - (train_file, input_field, target_fields, train_images_root, train_resolver), - (val_file, input_field, target_fields, val_images_root, val_resolver), - (test_file, input_field, target_fields, test_images_root, test_resolver), - (predict_file, input_field, target_fields, predict_images_root, predict_resolver), - train_transform=train_transform, - val_transform=val_transform, - test_transform=test_transform, - predict_transform=predict_transform, - data_fetcher=data_fetcher, - input_transform=input_transform, - val_split=val_split, - batch_size=batch_size, - num_workers=num_workers, - sampler=sampler, - **input_transform_kwargs, + image_size: Tuple[int, int] = (196, 196), + **data_module_kwargs: Any, + ) -> "ImageClassificationData": + return cls( + ImageClassificationCSVInput( + RunningStage.TRAINING, train_file, input_field, target_fields, train_images_root, train_resolver + ), + ImageClassificationCSVInput( + RunningStage.VALIDATING, val_file, input_field, target_fields, val_images_root, val_resolver + ), + ImageClassificationCSVInput( + RunningStage.TESTING, test_file, input_field, target_fields, test_images_root, test_resolver + ), + ImageClassificationCSVInput( + RunningStage.PREDICTING, predict_file, input_field, root=predict_images_root, resolver=predict_resolver + ), + input_transform=cls.input_transform_cls( + train_transform, + val_transform, + test_transform, + predict_transform, + image_size=image_size, + ), + **data_module_kwargs, + ) + + @classmethod + @requires("fiftyone") + def from_fiftyone( + cls, + train_dataset: Optional[SampleCollection] = None, + val_dataset: Optional[SampleCollection] = None, + test_dataset: Optional[SampleCollection] = None, + predict_dataset: Optional[SampleCollection] = None, + train_transform: Optional[Dict[str, Callable]] = None, + val_transform: Optional[Dict[str, Callable]] = None, + test_transform: Optional[Dict[str, Callable]] = None, + predict_transform: Optional[Dict[str, Callable]] = None, + label_field: str = "ground_truth", + image_size: Tuple[int, int] = (196, 196), + **data_module_kwargs, + ) -> "ImageClassificationData": + return cls( + ImageClassificationFiftyOneInput(RunningStage.TRAINING, train_dataset, label_field), + ImageClassificationFiftyOneInput(RunningStage.VALIDATING, val_dataset, label_field), + ImageClassificationFiftyOneInput(RunningStage.TESTING, test_dataset, label_field), + ImageClassificationFiftyOneInput(RunningStage.PREDICTING, predict_dataset, label_field), + input_transform=cls.input_transform_cls( + train_transform, + val_transform, + test_transform, + predict_transform, + image_size=image_size, + ), + **data_module_kwargs, ) def set_block_viz_window(self, value: bool) -> None: diff --git a/flash/image/classification/integrations/baal/data.py b/flash/image/classification/integrations/baal/data.py index c0badc5c96..1c76b3ca9c 100644 --- a/flash/image/classification/integrations/baal/data.py +++ b/flash/image/classification/integrations/baal/data.py @@ -92,7 +92,7 @@ def __init__( if not self.labelled.num_classes: raise MisconfigurationException("The labelled dataset should be labelled") - if self.labelled and (self.labelled._val_ds is not None or self.labelled._predict_ds is not None): + if self.labelled and (self.labelled._val_ds or self.labelled._predict_ds): raise MisconfigurationException("The labelled `datamodule` should have only train data.") self._dataset = ActiveLearningDataset( @@ -116,7 +116,7 @@ def __init__( @property def has_test(self) -> bool: - 
return self.labelled._test_ds is not None
+        return bool(self.labelled._test_ds)
 
     @property
     def has_labelled_data(self) -> bool:
diff --git a/flash/image/data.py b/flash/image/data.py
index 3d098e4c17..139cb5a9dd 100644
--- a/flash/image/data.py
+++ b/flash/image/data.py
@@ -12,26 +12,21 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import base64
-from collections import defaultdict
 from io import BytesIO
 from pathlib import Path
-from typing import Any, Dict, Optional
+from typing import Any, Dict, List, TYPE_CHECKING
 
 import numpy as np
 import torch
 
 import flash
-from flash.core.data.io.input import (
-    DataKeys,
-    FiftyOneInput,
-    has_file_allowed_extension,
-    NumpyInput,
-    PathsInput,
-    TensorInput,
-)
+from flash.core.data.io.input import DataKeys
+from flash.core.data.io.input_base import Input
 from flash.core.data.process import Deserializer
+from flash.core.data.utilities.paths import filter_valid_files, has_file_allowed_extension, PATH_TYPE
+from flash.core.data.utilities.samples import to_samples
 from flash.core.data.utils import image_default_loader
-from flash.core.utilities.imports import _TORCHVISION_AVAILABLE, Image, requires
+from flash.core.utilities.imports import _FIFTYONE_AVAILABLE, _TORCHVISION_AVAILABLE, Image, lazy_import, requires
 
 if _TORCHVISION_AVAILABLE:
     from torchvision.datasets.folder import IMG_EXTENSIONS
@@ -39,6 +34,13 @@
 else:
     IMG_EXTENSIONS = (".jpg", ".jpeg", ".png", ".ppm", ".bmp", ".pgm", ".tif", ".tiff", ".webp")
 
+SampleCollection = None
+if _FIFTYONE_AVAILABLE:
+    fol = lazy_import("fiftyone.core.labels")
+    if TYPE_CHECKING:
+        from fiftyone.core.collections import SampleCollection
+else:
+    fol = None
 
 NP_EXTENSIONS = (".npy",)
 
@@ -73,55 +75,38 @@ def example_input(self) -> str:
             return base64.b64encode(f.read()).decode("UTF-8")
 
 
-def _labels_to_indices(data):
-    out = defaultdict(list)
-    for idx, sample in enumerate(data):
-        label = sample[DataKeys.TARGET]
-        if torch.is_tensor(label):
-            label = label.item()
-        out[label].append(idx)
-    return out
-
-
-class ImagePathsInput(PathsInput):
-    def __init__(self):
-        super().__init__(loader=image_loader, extensions=IMG_EXTENSIONS + NP_EXTENSIONS)
-
+class ImageInput(Input):
     @requires("image")
-    def load_sample(self, sample: Dict[str, Any], dataset: Optional[Any] = None) -> Dict[str, Any]:
-        sample = super().load_sample(sample, dataset)
+    def load_sample(self, sample: Dict[str, Any]) -> Dict[str, Any]:
         w, h = sample[DataKeys.INPUT].size  # WxH
+        if DataKeys.METADATA not in sample:
+            sample[DataKeys.METADATA] = {}
         sample[DataKeys.METADATA]["size"] = (h, w)
         return sample
 
 
-class ImageTensorInput(TensorInput):
-    def load_sample(self, sample: Dict[str, Any], dataset: Optional[Any] = None) -> Dict[str, Any]:
-        img = to_pil_image(sample[DataKeys.INPUT])
-        sample[DataKeys.INPUT] = img
-        w, h = img.size  # WxH
-        sample[DataKeys.METADATA] = {"size": (h, w)}
+class ImageFilesInput(ImageInput):
+    def load_data(self, files: List[PATH_TYPE]) -> List[Dict[str, Any]]:
+        files = filter_valid_files(files, valid_extensions=IMG_EXTENSIONS + NP_EXTENSIONS)
+        return to_samples(files)
+
+    def load_sample(self, sample: Dict[str, Any]) -> Dict[str, Any]:
+        filepath = sample[DataKeys.INPUT]
+        sample[DataKeys.INPUT] = image_loader(filepath)
+        sample = super().load_sample(sample)
+        sample[DataKeys.METADATA]["filepath"] = filepath
         return sample
 
 
-class ImageNumpyInput(NumpyInput):
-    def load_sample(self, sample: Dict[str, Any], dataset: Optional[Any] = None) -> Dict[str, Any]:
-        img = to_pil_image(torch.from_numpy(sample[DataKeys.INPUT]))
+class ImageTensorInput(ImageInput): + def load_sample(self, sample: Dict[str, Any]) -> Dict[str, Any]: + img = to_pil_image(sample[DataKeys.INPUT]) sample[DataKeys.INPUT] = img - w, h = img.size # WxH - sample[DataKeys.METADATA] = {"size": (h, w)} - return sample + return super().load_sample(sample) -class ImageFiftyOneInput(FiftyOneInput): - @staticmethod - def load_sample(sample: Dict[str, Any], dataset: Optional[Any] = None) -> Dict[str, Any]: - img_path = sample[DataKeys.INPUT] - img = image_default_loader(img_path) +class ImageNumpyInput(ImageInput): + def load_sample(self, sample: Dict[str, Any]) -> Dict[str, Any]: + img = to_pil_image(torch.from_numpy(sample[DataKeys.INPUT])) sample[DataKeys.INPUT] = img - w, h = img.size # WxH - sample[DataKeys.METADATA] = { - "filepath": img_path, - "size": (h, w), - } - return sample + return super().load_sample(sample) diff --git a/flash/image/detection/output.py b/flash/image/detection/output.py index b52c24cbe4..42efb058c2 100644 --- a/flash/image/detection/output.py +++ b/flash/image/detection/output.py @@ -15,7 +15,8 @@ from pytorch_lightning.utilities import rank_zero_warn -from flash.core.data.io.input import DataKeys, LabelsState +from flash.core.data.io.classification_input import ClassificationState +from flash.core.data.io.input import DataKeys from flash.core.data.io.output import Output from flash.core.utilities.imports import _FIFTYONE_AVAILABLE, lazy_import, requires @@ -33,7 +34,7 @@ class FiftyOneDetectionLabels(Output): Args: labels: A list of labels, assumed to map the class index to the label for that class. If ``labels`` is not - provided, will attempt to get them from the :class:`.LabelsState`. + provided, will attempt to get them from the :class:`.ClassificationState`. threshold: a score threshold to apply to candidate detections. return_filepath: Boolean determining whether to return a dict containing filepath and FiftyOne labels (True) or only a @@ -53,7 +54,7 @@ def __init__( self.return_filepath = return_filepath if labels is not None: - self.set_state(LabelsState(labels)) + self.set_state(ClassificationState(labels)) def transform(self, sample: Dict[str, Any]) -> Union[Detections, Dict[str, Any]]: if DataKeys.METADATA not in sample: @@ -63,11 +64,13 @@ def transform(self, sample: Dict[str, Any]) -> Union[Detections, Dict[str, Any]] if self._labels is not None: labels = self._labels else: - state = self.get_state(LabelsState) + state = self.get_state(ClassificationState) if state is not None: labels = state.labels else: - rank_zero_warn("No LabelsState was found, int targets will be used as label strings", UserWarning) + rank_zero_warn( + "No ClassificationState was found, int targets will be used as label strings", UserWarning + ) height, width = sample[DataKeys.METADATA]["size"] diff --git a/flash/image/face_detection/data.py b/flash/image/face_detection/data.py index a241c38ea9..d685fc3c22 100644 --- a/flash/image/face_detection/data.py +++ b/flash/image/face_detection/data.py @@ -11,20 +11,21 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
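# A minimal sketch of the shared-state pattern behind the ``LabelsState`` ->
# ``ClassificationState`` rename above: any component attached to the same data
# pipeline can set or read the labels. This assumes only that
# ``ClassificationState`` keeps the ``labels`` field of the old ``LabelsState``;
# the label values below are illustrative.

from flash.core.data.data_pipeline import DataPipelineState
from flash.core.data.io.classification_input import ClassificationState

pipeline_state = DataPipelineState()
pipeline_state.set_state(ClassificationState(["ant", "bee"]))  # e.g. set by an Input when loading targets

# An Output (such as FiftyOneDetectionLabels above) can read the labels back:
labels = pipeline_state.get_state(ClassificationState).labels  # ["ant", "bee"]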
-from typing import Any, Callable, Dict, Mapping, Optional, Sequence, Tuple +from typing import Any, Callable, Dict, List, Optional, Sequence import torch import torch.nn as nn from torch.utils.data import Dataset -from flash.core.data.io.input import DataKeys, DatasetInput, InputFormat +from flash.core.data.data_module import DataModule +from flash.core.data.io.input import DataKeys, InputFormat from flash.core.data.io.input_transform import InputTransform from flash.core.data.io.output_transform import OutputTransform from flash.core.data.transforms import ApplyToKeys -from flash.core.data.utils import image_default_loader from flash.core.utilities.imports import _FASTFACE_AVAILABLE, _TORCHVISION_AVAILABLE -from flash.image.data import ImagePathsInput -from flash.image.detection import ObjectDetectionData +from flash.core.utilities.stages import RunningStage +from flash.image.classification.data import ImageClassificationFilesInput, ImageClassificationFolderInput +from flash.image.data import ImageInput if _TORCHVISION_AVAILABLE: import torchvision @@ -61,39 +62,18 @@ def fastface_collate_fn(samples: Sequence[Dict[str, Any]]) -> Dict[str, Sequence return samples -class FastFaceInput(DatasetInput): +class FastFaceInput(ImageInput): """Logic for loading from FDDBDataset.""" - def load_data(self, data: Dataset, dataset: Any = None) -> Dataset: - new_data = [] - for img_file_path, targets in zip(data.ids, data.targets): - new_data.append( - super().load_sample( - ( - img_file_path, - dict( - boxes=targets["target_boxes"], - # label `1` indicates positive sample - labels=[1 for _ in range(targets["target_boxes"].shape[0])], - ), - ) - ) - ) - - return new_data - - def load_sample(self, sample: Any, dataset: Optional[Any] = None) -> Mapping[str, Any]: - filepath = sample[DataKeys.INPUT] - img = image_default_loader(filepath) - sample[DataKeys.INPUT] = img - - w, h = img.size # WxH - sample[DataKeys.METADATA] = { - "filepath": filepath, - "size": (h, w), - } - - return sample + def load_data(self, dataset: Dataset) -> List[Dict[str, Any]]: + return [ + { + DataKeys.INPUT: filepath, + "boxes": targets["target_boxes"], + "labels": [1] * targets["target_boxes"].shape[0], + } + for filepath, targets in zip(dataset.ids, dataset.targets) + ] class FaceDetectionInputTransform(InputTransform): @@ -105,19 +85,16 @@ def __init__( val_transform: Optional[Dict[str, Callable]] = None, test_transform: Optional[Dict[str, Callable]] = None, predict_transform: Optional[Dict[str, Callable]] = None, - image_size: Tuple[int, int] = (128, 128), ): - self.image_size = image_size - super().__init__( train_transform=train_transform, val_transform=val_transform, test_transform=test_transform, predict_transform=predict_transform, inputs={ - InputFormat.FILES: ImagePathsInput(), - InputFormat.FOLDERS: ImagePathsInput(), - InputFormat.DATASETS: FastFaceInput(), + InputFormat.FILES: ImageClassificationFilesInput, + InputFormat.FOLDERS: ImageClassificationFolderInput, + InputFormat.DATASETS: FastFaceInput, }, default_input=InputFormat.FILES, ) @@ -166,6 +143,34 @@ def per_batch_transform(batch: Any) -> Any: return batch -class FaceDetectionData(ObjectDetectionData): +class FaceDetectionData(DataModule): input_transform_cls = FaceDetectionInputTransform output_transform_cls = FaceDetectionOutputTransform + + @classmethod + def from_datasets( + cls, + train_dataset: Optional[Dataset] = None, + val_dataset: Optional[Dataset] = None, + test_dataset: Optional[Dataset] = None, + predict_dataset: Optional[Dataset] = None, + 
train_transform: Optional[Dict[str, Callable]] = None, + val_transform: Optional[Dict[str, Callable]] = None, + test_transform: Optional[Dict[str, Callable]] = None, + predict_transform: Optional[Dict[str, Callable]] = None, + **data_module_kwargs, + ) -> "FaceDetectionData": + return cls( + FastFaceInput(RunningStage.TRAINING, train_dataset), + FastFaceInput(RunningStage.VALIDATING, val_dataset), + FastFaceInput(RunningStage.TESTING, test_dataset), + FastFaceInput(RunningStage.PREDICTING, predict_dataset), + input_transform=cls.input_transform_cls( + train_transform, + val_transform, + test_transform, + predict_transform, + ), + output_transform=cls.output_transform_cls(), + **data_module_kwargs, + ) diff --git a/flash/image/segmentation/data.py b/flash/image/segmentation/data.py index 69eb67b783..c0a427ab36 100644 --- a/flash/image/segmentation/data.py +++ b/flash/image/segmentation/data.py @@ -21,7 +21,7 @@ import flash from flash.core.data.auto_dataset import BaseAutoDataset -from flash.core.data.base_viz import BaseVisualization # for viz +from flash.core.data.base_viz import BaseVisualization from flash.core.data.callback import BaseDataFetcher from flash.core.data.data_module import DataModule from flash.core.data.io.input import ( diff --git a/flash/image/style_transfer/cli.py b/flash/image/style_transfer/cli.py index c3fa094d0e..564b1a8c0f 100644 --- a/flash/image/style_transfer/cli.py +++ b/flash/image/style_transfer/cli.py @@ -24,7 +24,7 @@ def from_coco_128( batch_size: int = 4, num_workers: int = 0, - **input_transform_kwargs, + **data_module_kwargs, ) -> StyleTransferData: """Downloads and loads the COCO 128 data set.""" download_data("https://github.com/zhiqwang/yolov5-rt-stack/releases/download/v0.3.0/coco128.zip", "data/") @@ -32,7 +32,7 @@ def from_coco_128( train_folder="data/coco128/images/train2017/", batch_size=batch_size, num_workers=num_workers, - **input_transform_kwargs, + **data_module_kwargs, ) @@ -47,7 +47,6 @@ def style_transfer(): "model.style_image": os.path.join(flash.ASSETS_ROOT, "starry_night.jpg"), }, finetune=False, - legacy=True, ) cli.trainer.save_checkpoint("style_transfer_model.pt") diff --git a/flash/image/style_transfer/data.py b/flash/image/style_transfer/data.py index 4201844930..cdb7524732 100644 --- a/flash/image/style_transfer/data.py +++ b/flash/image/style_transfer/data.py @@ -12,9 +12,10 @@ # See the License for the specific language governing permissions and # limitations under the License. 
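# A hedged usage sketch for the new ``FaceDetectionData.from_datasets``
# constructor above. It assumes the ``fastface`` package is installed and that
# ``FDDBDataset`` reads FDDB data from ``data/``; the exact fastface API may
# differ between versions.

import fastface as ff

from flash.image.face_detection.data import FaceDetectionData

train_dataset = ff.dataset.FDDBDataset(source_dir="data/", phase="train")
datamodule = FaceDetectionData.from_datasets(train_dataset=train_dataset, batch_size=2)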
import functools -import pathlib -from typing import Any, Callable, Dict, Optional, Sequence, Union +from typing import Any, Callable, Collection, Dict, Optional, Sequence, Union +import numpy as np +import torch from torch import nn from flash.core.data.data_module import DataModule @@ -22,9 +23,9 @@ from flash.core.data.io.input_transform import InputTransform from flash.core.data.transforms import ApplyToKeys from flash.core.utilities.imports import _TORCHVISION_AVAILABLE -from flash.image.classification import ImageClassificationData -from flash.image.data import ImageNumpyInput, ImagePathsInput, ImageTensorInput -from flash.image.style_transfer.utils import raise_not_supported +from flash.core.utilities.stages import RunningStage +from flash.image.classification.data import ImageClassificationFilesInput, ImageClassificationFolderInput +from flash.image.data import ImageFilesInput, ImageNumpyInput, ImageTensorInput if _TORCHVISION_AVAILABLE: from torchvision import transforms as T @@ -55,11 +56,6 @@ def __init__( predict_transform: Optional[Dict[str, Callable]] = None, image_size: int = 256, ): - if val_transform: - raise_not_supported("validation") - if test_transform: - raise_not_supported("test") - if isinstance(image_size, int): image_size = (image_size, image_size) @@ -71,11 +67,10 @@ def __init__( test_transform=test_transform, predict_transform=predict_transform, inputs={ - InputFormat.FILES: ImagePathsInput(), - InputFormat.FOLDERS: ImagePathsInput(), - InputFormat.NUMPY: ImageNumpyInput(), - InputFormat.TENSORS: ImageTensorInput(), - InputFormat.TENSORS: ImageTensorInput(), + InputFormat.FILES: ImageFilesInput, + InputFormat.FOLDERS: ImageClassificationFolderInput, + InputFormat.NUMPY: ImageNumpyInput, + InputFormat.TENSORS: ImageTensorInput, }, default_input=InputFormat.FILES, ) @@ -106,35 +101,89 @@ def default_transforms(self) -> Optional[Dict[str, Callable]]: return None -class StyleTransferData(ImageClassificationData): +class StyleTransferData(DataModule): input_transform_cls = StyleTransferInputTransform @classmethod - def from_folders( + def from_files( cls, - train_folder: Optional[Union[str, pathlib.Path]] = None, - predict_folder: Optional[Union[str, pathlib.Path]] = None, - train_transform: Optional[Union[str, Dict]] = None, - predict_transform: Optional[Union[str, Dict]] = None, - input_transform: Optional[InputTransform] = None, - **kwargs: Any, - ) -> "DataModule": - - if any(param in kwargs and kwargs[param] is not None for param in ("val_folder", "val_transform")): - raise_not_supported("validation") + train_files: Optional[Sequence[str]] = None, + predict_files: Optional[Sequence[str]] = None, + train_transform: Optional[Dict[str, Callable]] = None, + predict_transform: Optional[Dict[str, Callable]] = None, + image_size: int = 256, + **data_module_kwargs: Any, + ) -> "StyleTransferData": + return cls( + ImageFilesInput(RunningStage.TRAINING, train_files), + predict_dataset=ImageClassificationFilesInput(RunningStage.PREDICTING, predict_files), + input_transform=cls.input_transform_cls( + train_transform, + predict_transform=predict_transform, + image_size=image_size, + ), + **data_module_kwargs, + ) - if any(param in kwargs and kwargs[param] is not None for param in ("test_folder", "test_transform")): - raise_not_supported("test") + @classmethod + def from_folders( + cls, + train_folder: Optional[str] = None, + predict_folder: Optional[str] = None, + train_transform: Optional[Dict[str, Callable]] = None, + predict_transform: Optional[Dict[str, Callable]] = 
None, + image_size: int = 256, + **data_module_kwargs: Any, + ) -> "StyleTransferData": + return cls( + ImageClassificationFolderInput(RunningStage.TRAINING, train_folder), + predict_dataset=ImageClassificationFolderInput(RunningStage.PREDICTING, predict_folder), + input_transform=cls.input_transform_cls( + train_transform, + predict_transform=predict_transform, + image_size=image_size, + ), + **data_module_kwargs, + ) - input_transform = input_transform or cls.input_transform_cls( - train_transform=train_transform, - predict_transform=predict_transform, + @classmethod + def from_numpy( + cls, + train_data: Optional[Collection[np.ndarray]] = None, + predict_data: Optional[Collection[np.ndarray]] = None, + train_transform: Optional[Dict[str, Callable]] = None, + predict_transform: Optional[Dict[str, Callable]] = None, + image_size: int = 256, + **data_module_kwargs: Any, + ) -> "StyleTransferData": + return cls( + ImageNumpyInput(RunningStage.TRAINING, train_data), + predict_dataset=ImageNumpyInput(RunningStage.PREDICTING, predict_data), + input_transform=cls.input_transform_cls( + train_transform, + predict_transform=predict_transform, + image_size=image_size, + ), + **data_module_kwargs, ) - return cls.from_input( - InputFormat.FOLDERS, - train_data=train_folder, - predict_data=predict_folder, - input_transform=input_transform, - **kwargs, + @classmethod + def from_tensors( + cls, + train_data: Optional[Collection[torch.Tensor]] = None, + predict_data: Optional[Collection[torch.Tensor]] = None, + train_transform: Optional[Dict[str, Callable]] = None, + predict_transform: Optional[Dict[str, Callable]] = None, + image_size: int = 256, + **data_module_kwargs: Any, + ) -> "StyleTransferData": + return cls( + ImageTensorInput(RunningStage.TRAINING, train_data), + predict_dataset=ImageTensorInput(RunningStage.PREDICTING, predict_data), + input_transform=cls.input_transform_cls( + train_transform, + predict_transform=predict_transform, + image_size=image_size, + ), + **data_module_kwargs, ) diff --git a/flash/tabular/data.py b/flash/tabular/data.py index b0a43ce878..91ccef1043 100644 --- a/flash/tabular/data.py +++ b/flash/tabular/data.py @@ -18,15 +18,15 @@ import numpy as np from pytorch_lightning.utilities.exceptions import MisconfigurationException -from flash.core.classification import LabelsState from flash.core.data.data_module import DataModule +from flash.core.data.io.classification_input import ClassificationState from flash.core.data.io.input import DataKeys, InputFormat from flash.core.data.io.input_base import Input from flash.core.data.io.input_transform import InputTransform from flash.core.data.io.output_transform import OutputTransform from flash.core.data.process import Deserializer from flash.core.data.properties import ProcessState -from flash.core.data.utilities.paths import read_csv +from flash.core.data.utilities.data_frame import read_csv from flash.core.utilities.imports import _PANDAS_AVAILABLE from flash.core.utilities.stages import RunningStage from flash.tabular.classification.utils import ( @@ -115,7 +115,7 @@ def load_data( parameters = self.compute_parameters(df, target_field, numerical_fields, categorical_fields, is_regression) self.set_state(TabularParametersState(parameters)) - self.set_state(LabelsState(parameters["classes"])) + self.set_state(ClassificationState(parameters["classes"])) else: parameters_state = self.get_state(TabularParametersState) parameters = parameters or (parameters_state.parameters if parameters_state is not None else None) diff --git 
a/flash/template/classification/data.py b/flash/template/classification/data.py index aa24f8013b..e3786b5c64 100644 --- a/flash/template/classification/data.py +++ b/flash/template/classification/data.py @@ -20,7 +20,8 @@ from flash.core.data.base_viz import BaseVisualization from flash.core.data.callback import BaseDataFetcher from flash.core.data.data_module import DataModule -from flash.core.data.io.input import DataKeys, InputFormat, LabelsState, NumpyInput +from flash.core.data.io.classification_input import ClassificationState +from flash.core.data.io.input import DataKeys, InputFormat, NumpyInput from flash.core.data.io.input_transform import InputTransform from flash.core.data.transforms import ApplyToKeys from flash.core.utilities.imports import _SKLEARN_AVAILABLE @@ -67,7 +68,7 @@ def load_data(self, data: Bunch, dataset: Any) -> Sequence[Mapping[str, Any]]: A sequence of samples / sample metadata. """ dataset.num_classes = len(data.target_names) - self.set_state(LabelsState(data.target_names)) + self.set_state(ClassificationState(data.target_names)) return super().load_data((data.data, data.target), dataset=dataset) def predict_load_data(self, data: Bunch) -> Sequence[Mapping[str, Any]]: diff --git a/flash/text/classification/data.py b/flash/text/classification/data.py index 6309d951c9..daa31b8d69 100644 --- a/flash/text/classification/data.py +++ b/flash/text/classification/data.py @@ -23,7 +23,8 @@ from flash.core.data.auto_dataset import AutoDataset from flash.core.data.callback import BaseDataFetcher from flash.core.data.data_module import DataModule -from flash.core.data.io.input import DataKeys, Input, InputFormat, LabelsState +from flash.core.data.io.classification_input import ClassificationState +from flash.core.data.io.input import DataKeys, Input, InputFormat from flash.core.data.io.input_transform import InputTransform from flash.core.data.io.output_transform import OutputTransform from flash.core.data.process import Deserializer @@ -115,15 +116,15 @@ def load_data( dataset.multi_label = True hf_dataset = hf_dataset.map(partial(self._multilabel_target, target)) # NOTE: renames target column dataset.num_classes = len(target) - self.set_state(LabelsState(target)) + self.set_state(ClassificationState(target)) else: dataset.multi_label = False if self.training: labels = list(sorted(list(set(hf_dataset[target])))) dataset.num_classes = len(labels) - self.set_state(LabelsState(labels)) + self.set_state(ClassificationState(labels)) - labels = self.get_state(LabelsState) + labels = self.get_state(ClassificationState) # convert labels to ids (note: the target column get overwritten) if labels is not None: @@ -224,15 +225,15 @@ def load_data( # multi-target_list dataset.multi_label = True dataset.num_classes = len(target_list[0]) - self.set_state(LabelsState(target_list)) + self.set_state(ClassificationState(target_list)) else: dataset.multi_label = False if self.training: labels = list(sorted(list(set(hf_dataset[DataKeys.TARGET])))) dataset.num_classes = len(labels) - self.set_state(LabelsState(labels)) + self.set_state(ClassificationState(labels)) - labels = self.get_state(LabelsState) + labels = self.get_state(ClassificationState) # convert labels to ids if labels is not None: diff --git a/flash/video/classification/data.py b/flash/video/classification/data.py index 3ff714c2cd..204ad8efbc 100644 --- a/flash/video/classification/data.py +++ b/flash/video/classification/data.py @@ -20,7 +20,8 @@ from torch.utils.data import Sampler from flash.core.data.data_module import 
DataModule -from flash.core.data.io.input import DataKeys, InputFormat, LabelsState +from flash.core.data.io.classification_input import ClassificationState +from flash.core.data.io.input import DataKeys, InputFormat from flash.core.data.io.input_base import Input, IterableInput from flash.core.data.io.input_transform import InputTransform from flash.core.data.utilities.paths import list_valid_files @@ -75,7 +76,7 @@ class VideoClassificationInput(IterableInput): def load_data(self, dataset: "LabeledVideoDataset") -> "LabeledVideoDataset": if self.training: label_to_class_mapping = {p[1]: p[0].split("/")[-2] for p in dataset._labeled_videos._paths_and_labels} - self.set_state(LabelsState(label_to_class_mapping)) + self.set_state(ClassificationState(label_to_class_mapping)) self.num_classes = len(np.unique([s[1]["label"] for s in dataset._labeled_videos])) return dataset @@ -203,7 +204,7 @@ def load_data( decoder=decoder, ) if self.training: - self.set_state(LabelsState(self.id_to_label)) + self.set_state(ClassificationState(self.id_to_label)) self.num_classes = len(self.labels_set) return dataset diff --git a/flash_examples/image_classification_multi_label.py b/flash_examples/image_classification_multi_label.py index fc05161a5e..cb51698038 100644 --- a/flash_examples/image_classification_multi_label.py +++ b/flash_examples/image_classification_multi_label.py @@ -11,6 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import os + import torch import flash @@ -22,16 +24,23 @@ # More info here: https://www.cs.ccu.edu.tw/~wtchu/projects/MoviePoster/ download_data("https://pl-flash-data.s3.amazonaws.com/movie_posters.zip") + +def resolver(root, file_id): + return os.path.join(root, f"{file_id}.jpg") + + datamodule = ImageClassificationData.from_csv( "Id", ["Action", "Romance", "Crime", "Thriller", "Adventure"], train_file="data/movie_posters/train/metadata.csv", + train_resolver=resolver, val_file="data/movie_posters/val/metadata.csv", + val_resolver=resolver, image_size=(128, 128), ) # 2. Build the task -model = ImageClassifier(backbone="resnet18", num_classes=datamodule.num_classes, multi_label=True) +model = ImageClassifier(backbone="resnet18", num_classes=datamodule.num_classes, multi_label=datamodule.multi_label) # 3. 
Create the trainer and finetune the model trainer = flash.Trainer(max_epochs=3, gpus=torch.cuda.device_count()) diff --git a/tests/core/data/io/test_output.py b/tests/core/data/io/test_output.py index c7626a1b6e..eaf94cebdb 100644 --- a/tests/core/data/io/test_output.py +++ b/tests/core/data/io/test_output.py @@ -19,7 +19,7 @@ from flash.core.classification import Labels from flash.core.data.data_pipeline import DataPipeline, DataPipelineState -from flash.core.data.io.input import LabelsState +from flash.core.data.io.classification_input import ClassificationState from flash.core.data.io.input_transform import DefaultInputTransform from flash.core.data.io.output import Output from flash.core.model import Task @@ -56,4 +56,4 @@ def __init__(self): trainer.save_checkpoint(checkpoint_file) model = CustomModel.load_from_checkpoint(checkpoint_file) assert isinstance(model._data_pipeline_state, DataPipelineState) - assert model._data_pipeline_state._state[LabelsState] == LabelsState(["a", "b"]) + assert model._data_pipeline_state._state[ClassificationState] == ClassificationState(["a", "b"]) diff --git a/tests/core/data/utilities/__init__.py b/tests/core/data/utilities/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/core/data/utilities/test_classification.py b/tests/core/data/utilities/test_classification.py new file mode 100644 index 0000000000..3b250cc19d --- /dev/null +++ b/tests/core/data/utilities/test_classification.py @@ -0,0 +1,127 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
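# A minimal sketch of the three-step target-resolution flow exercised by the
# new test module below, using the comma-delimited multi-label case from its
# ``cases`` table (``MUTLI_COMMA_DELIMITED`` is the enum spelling this PR uses):

from flash.core.data.utilities.classification import (
    get_target_details,
    get_target_formatter,
    get_target_mode,
)

targets = ["blue", "green,red", "red,blue"]

mode = get_target_mode(targets)  # TargetMode.MUTLI_COMMA_DELIMITED
labels, num_classes = get_target_details(targets, mode)  # (["blue", "green", "red"], 3)
formatter = get_target_formatter(mode, labels, num_classes)

# Each raw target is formatted to a multi-hot binary vector:
assert [formatter(t) for t in targets] == [[1, 0, 0], [0, 1, 1], [1, 0, 1]]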
+import time +from collections import namedtuple + +import numpy as np +import pytest +import torch + +from flash.core.data.utilities.classification import ( + get_target_details, + get_target_formatter, + get_target_mode, + TargetMode, +) + +Case = namedtuple("Case", ["target", "formatted_target", "target_mode", "labels", "num_classes"]) + +cases = [ + # Single + Case([0, 1, 2], [0, 1, 2], TargetMode.SINGLE_NUMERIC, None, 3), + Case([[1, 0, 0], [0, 1, 0], [0, 0, 1]], [0, 1, 2], TargetMode.SINGLE_BINARY, None, 3), + Case(["blue", "green", "red"], [0, 1, 2], TargetMode.SINGLE_TOKEN, ["blue", "green", "red"], 3), + # Multi + Case([[0, 1], [1, 2], [2, 0]], [[1, 1, 0], [0, 1, 1], [1, 0, 1]], TargetMode.MULTI_NUMERIC, None, 3), + Case([[1, 1, 0], [0, 1, 1], [1, 0, 1]], [[1, 1, 0], [0, 1, 1], [1, 0, 1]], TargetMode.MULTI_BINARY, None, 3), + Case( + [["blue", "green"], ["green", "red"], ["red", "blue"]], + [[1, 1, 0], [0, 1, 1], [1, 0, 1]], + TargetMode.MULTI_TOKEN, + ["blue", "green", "red"], + 3, + ), + Case( + ["blue,green", "green,red", "red,blue"], + [[1, 1, 0], [0, 1, 1], [1, 0, 1]], + TargetMode.MUTLI_COMMA_DELIMITED, + ["blue", "green", "red"], + 3, + ), + # Ambiguous + Case([[0], [1, 2], [2, 0]], [[1, 0, 0], [0, 1, 1], [1, 0, 1]], TargetMode.MULTI_NUMERIC, None, 3), + Case([[1, 0, 0], [0, 1, 1], [1, 0, 1]], [[1, 0, 0], [0, 1, 1], [1, 0, 1]], TargetMode.MULTI_BINARY, None, 3), + Case( + [["blue"], ["green", "red"], ["red", "blue"]], + [[1, 0, 0], [0, 1, 1], [1, 0, 1]], + TargetMode.MULTI_TOKEN, + ["blue", "green", "red"], + 3, + ), + Case( + ["blue", "green,red", "red,blue"], + [[1, 0, 0], [0, 1, 1], [1, 0, 1]], + TargetMode.MUTLI_COMMA_DELIMITED, + ["blue", "green", "red"], + 3, + ), + # Special cases + Case(["blue ", " green", "red"], [0, 1, 2], TargetMode.SINGLE_TOKEN, ["blue", "green", "red"], 3), + Case( + ["blue", "green, red", "red, blue"], + [[1, 0, 0], [0, 1, 1], [1, 0, 1]], + TargetMode.MUTLI_COMMA_DELIMITED, + ["blue", "green", "red"], + 3, + ), + Case( + [f"class_{i}" for i in range(10000)], + list(range(10000)), + TargetMode.SINGLE_TOKEN, + [f"class_{i}" for i in range(10000)], + 10000, + ), + # Array types + Case(torch.tensor([0, 1, 2]), [0, 1, 2], TargetMode.SINGLE_NUMERIC, None, 3), + Case(np.array([0, 1, 2]), [0, 1, 2], TargetMode.SINGLE_NUMERIC, None, 3), +] + + +@pytest.mark.parametrize("case", cases) +def test_case(case): + target_mode = get_target_mode(case.target) + assert target_mode is case.target_mode + + labels, num_classes = get_target_details(case.target, target_mode) + assert labels == case.labels + assert num_classes == case.num_classes + + formatter = get_target_formatter(target_mode, labels, num_classes) + assert [formatter(t) for t in case.target] == case.formatted_target + + +@pytest.mark.parametrize("case", cases) +def test_speed(case): + repeats = int(1e5 / len(case.target)) # Approx. 
a hundred thousand targets + + if torch.is_tensor(case.target): + targets = case.target.repeat(repeats) + elif isinstance(case.target, np.ndarray): + targets = np.repeat(case.target, repeats) + else: + targets = case.target * repeats + + start = time.time() + target_mode = get_target_mode(targets) + labels, num_classes = get_target_details(targets, target_mode) + formatter = get_target_formatter(target_mode, labels, num_classes) + end = time.time() + + assert (end - start) / len(targets) < 1e-5 # 0.01ms per target + + start = time.time() + _ = [formatter(t) for t in targets] + end = time.time() + + assert (end - start) / len(targets) < 1e-5 # 0.01ms per target diff --git a/tests/core/test_model.py b/tests/core/test_model.py index 3f91113201..a21615545a 100644 --- a/tests/core/test_model.py +++ b/tests/core/test_model.py @@ -206,7 +206,11 @@ def _rand_image(): datamodule = ImageClassificationData.from_folders(predict_folder=train_dir) task = ImageClassifier(num_classes=10) - predictions = task.predict(str(train_dir), data_pipeline=datamodule.data_pipeline) + predictions = task.predict( + str(train_dir), + input="folders", + data_pipeline=datamodule.data_pipeline, + ) assert len(predictions) == 2 diff --git a/tests/image/classification/test_data.py b/tests/image/classification/test_data.py index e1d3d501cc..dbcc009cf4 100644 --- a/tests/image/classification/test_data.py +++ b/tests/image/classification/test_data.py @@ -498,8 +498,8 @@ def single_target_csv(image_tmpdir): fieldnames = ["image", "target"] writer = csv.DictWriter(csvfile, fieldnames) writer.writeheader() - writer.writerow({"image": "image_1", "target": "Ants"}) - writer.writerow({"image": "image_2", "target": "Bees"}) + writer.writerow({"image": "image_1.png", "target": "Ants"}) + writer.writerow({"image": "image_2.png", "target": "Bees"}) return str(image_tmpdir / "metadata.csv") @@ -526,8 +526,8 @@ def multi_target_csv(image_tmpdir): fieldnames = ["image", "target_1", "target_2"] writer = csv.DictWriter(csvfile, fieldnames) writer.writeheader() - writer.writerow({"image": "image_1", "target_1": 1, "target_2": 0}) - writer.writerow({"image": "image_2", "target_1": 1, "target_2": 1}) + writer.writerow({"image": "image_1.png", "target_1": 1, "target_2": 0}) + writer.writerow({"image": "image_2.png", "target_1": 1, "target_2": 1}) return str(image_tmpdir / "metadata.csv") @@ -560,7 +560,7 @@ def bad_csv_no_image(image_tmpdir): @pytest.mark.skipif(not _IMAGE_TESTING, reason="image libraries aren't installed.") def test_from_bad_csv_no_image(bad_csv_no_image): - with pytest.raises(ValueError, match="Found no matches"): + with pytest.raises(ValueError, match="File ID `image_3` did not resolve to an existing file."): img_data = ImageClassificationData.from_csv( "image", ["target"],