Skip to content
This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

Commit

Permalink
Refactor image inputs and update to new input object (#997)
Browse files Browse the repository at this point in the history
Co-authored-by: thomas chaton <thomas@grid.ai>
Co-authored-by: Ananya Harsh Jha <ananya@pytorchlightning.ai>
  • Loading branch information
3 people authored Nov 30, 2021
1 parent ce18c08 commit b8085f0
Show file tree
Hide file tree
Showing 35 changed files with 1,325 additions and 452 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

- Added `GraphEmbedder` task ([#592](https://github.com/PyTorchLightning/lightning-flash/pull/592))

- Added support for comma delimited multi-label targets to the `ImageClassifier` ([#997](https://github.com/PyTorchLightning/lightning-flash/pull/997))

### Changed

- Changed `DataSource` to `Input` ([#929](https://github.com/PyTorchLightning/lightning-flash/pull/929))
Expand Down
1 change: 0 additions & 1 deletion docs/source/api/data.rst
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,6 @@ ___________________________
~flash.core.data.io.input.InputFormat
~flash.core.data.io.input.FiftyOneInput
~flash.core.data.io.input.ImageLabelsMap
~flash.core.data.io.input.LabelsState
~flash.core.data.io.input.MockDataset
~flash.core.data.io.input.NumpyInput
~flash.core.data.io.input.PathsInput
Expand Down
3 changes: 1 addition & 2 deletions docs/source/api/image.rst
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ ______________
:template: classtemplate.rst

~classification.model.ImageClassifier
~classification.data.ImageClassificationFiftyOneInput
~classification.data.ImageClassificationData
~classification.data.ImageClassificationInputTransform

Expand Down Expand Up @@ -140,7 +141,5 @@ ________________
:template: classtemplate.rst

~data.ImageDeserializer
~data.ImageFiftyOneInput
~data.ImageNumpyInput
~data.ImagePathsInput
~data.ImageTensorInput
2 changes: 1 addition & 1 deletion docs/source/template/data.rst
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ We override our ``TemplateNumpyInput`` so that we can call ``super`` with the da
We perform two additional steps here to improve the user experience:

1. We set the ``num_classes`` attribute on the ``dataset``. If ``num_classes`` is set, it is automatically made available as a property of the :class:`~flash.core.data.data_module.DataModule`.
2. We create and set a :class:`~flash.core.data.io.input.LabelsState`. The labels provided here will be shared with the :class:`~flash.core.classification.Labels` output, so the user doesn't need to provide them.
2. We create and set a :class:`~flash.core.data.io.input.ClassificationState`. The labels provided here will be shared with the :class:`~flash.core.classification.Labels` output, so the user doesn't need to provide them.

Here's the code for the ``TemplateSKLearnInput.load_data`` method:

Expand Down
14 changes: 12 additions & 2 deletions flash/audio/classification/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@
import numpy as np

from flash.audio.classification.transforms import default_transforms, train_default_transforms
from flash.core.data.callback import BaseDataFetcher
from flash.core.data.data_module import DataModule
from flash.core.data.io.input import (
DataKeys,
has_file_allowed_extension,
Expand All @@ -27,7 +29,7 @@
from flash.core.data.io.input_transform import InputTransform
from flash.core.data.process import Deserializer
from flash.core.data.utils import image_default_loader
from flash.image.classification.data import ImageClassificationData
from flash.image.classification.data import MatplotlibVisualization
from flash.image.data import ImageDeserializer, IMG_EXTENSIONS, NP_EXTENSIONS


Expand Down Expand Up @@ -114,7 +116,15 @@ def train_default_transforms(self) -> Optional[Dict[str, Callable]]:
return train_default_transforms(self.spectrogram_size, self.time_mask_param, self.freq_mask_param)


class AudioClassificationData(ImageClassificationData):
class AudioClassificationData(DataModule):
"""Data module for audio classification."""

input_transform_cls = AudioClassificationInputTransform

def set_block_viz_window(self, value: bool) -> None:
"""Setter method to switch on/off matplotlib to pop up windows."""
self.data_fetcher.block_viz_window = value

@staticmethod
def configure_data_fetcher(*args, **kwargs) -> BaseDataFetcher:
return MatplotlibVisualization(*args, **kwargs)
19 changes: 10 additions & 9 deletions flash/core/classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,8 @@
from pytorch_lightning.utilities import rank_zero_warn

from flash.core.adapter import AdapterTask
from flash.core.data.io.input import DataKeys, LabelsState
from flash.core.data.io.classification_input import ClassificationState
from flash.core.data.io.input import DataKeys
from flash.core.data.io.output import Output
from flash.core.model import Task
from flash.core.utilities.imports import _FIFTYONE_AVAILABLE, lazy_import, requires
Expand Down Expand Up @@ -186,7 +187,7 @@ class Labels(Classes):
Args:
labels: A list of labels, assumed to map the class index to the label for that class. If ``labels`` is not
provided, will attempt to get them from the :class:`.LabelsState`.
provided, will attempt to get them from the :class:`.ClassificationState`.
multi_label: If true, treats outputs as multi label logits.
threshold: The threshold to use for multi_label classification.
"""
Expand All @@ -196,15 +197,15 @@ def __init__(self, labels: Optional[List[str]] = None, multi_label: bool = False
self._labels = labels

if labels is not None:
self.set_state(LabelsState(labels))
self.set_state(ClassificationState(labels))

def transform(self, sample: Any) -> Union[int, List[int], str, List[str]]:
labels = None

if self._labels is not None:
labels = self._labels
else:
state = self.get_state(LabelsState)
state = self.get_state(ClassificationState)
if state is not None:
labels = state.labels

Expand All @@ -214,7 +215,7 @@ def transform(self, sample: Any) -> Union[int, List[int], str, List[str]]:
if self.multi_label:
return [labels[cls] for cls in classes]
return labels[classes]
rank_zero_warn("No LabelsState was found, this output will act as a Classes output.", UserWarning)
rank_zero_warn("No ClassificationState was found, this output will act as a Classes output.", UserWarning)
return classes


Expand All @@ -223,7 +224,7 @@ class FiftyOneLabels(ClassificationOutput):
Args:
labels: A list of labels, assumed to map the class index to the label for that class. If ``labels`` is not
provided, will attempt to get them from the :class:`.LabelsState`.
provided, will attempt to get them from the :class:`.ClassificationState`.
multi_label: If true, treats outputs as multi label logits.
threshold: A threshold to use to filter candidate labels. In the single label case, predictions below this
threshold will be replaced with None
Expand Down Expand Up @@ -252,7 +253,7 @@ def __init__(
self.return_filepath = return_filepath

if labels is not None:
self.set_state(LabelsState(labels))
self.set_state(ClassificationState(labels))

def transform(
self,
Expand All @@ -266,7 +267,7 @@ def transform(
if self._labels is not None:
labels = self._labels
else:
state = self.get_state(LabelsState)
state = self.get_state(ClassificationState)
if state is not None:
labels = state.labels

Expand Down Expand Up @@ -309,7 +310,7 @@ def transform(
logits=logits,
)
else:
rank_zero_warn("No LabelsState was found, int targets will be used as label strings", UserWarning)
rank_zero_warn("No ClassificationState was found, int targets will be used as label strings", UserWarning)

if self.multi_label:
classifications = []
Expand Down
6 changes: 5 additions & 1 deletion flash/core/data/data_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,6 +149,10 @@ def __init__(

self.set_running_stages()

# Share state between input objects (this will be available in ``load_sample`` but not in ``load_data``)
data_pipeline = self.data_pipeline
data_pipeline.initialize()

@property
def train_dataset(self) -> Optional[Dataset]:
"""This property returns the train dataset."""
Expand Down Expand Up @@ -420,7 +424,7 @@ def num_classes(self) -> Optional[int]:

@property
def multi_label(self) -> Optional[bool]:
"""Property that returns the number of labels of the datamodule if a multilabel task."""
"""Property that returns ``True`` if this ``DataModule`` contains multi-label data."""
multi_label_train = getattr(self.train_dataset, "multi_label", None)
multi_label_val = getattr(self.val_dataset, "multi_label", None)
multi_label_test = getattr(self.test_dataset, "multi_label", None)
Expand Down
79 changes: 79 additions & 0 deletions flash/core/data/io/classification_input.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass
from functools import lru_cache
from typing import Any, List, Optional, Sequence

from flash.core.data.io.input_base import Input
from flash.core.data.properties import ProcessState
from flash.core.data.utilities.classification import (
get_target_details,
get_target_formatter,
get_target_mode,
TargetFormatter,
)


@dataclass(unsafe_hash=True, frozen=True)
class ClassificationState(ProcessState):
"""A :class:`~flash.core.data.properties.ProcessState` containing ``labels`` (a mapping from class index to
label) and ``num_classes``."""

labels: Optional[Sequence[str]]
num_classes: Optional[int] = None


class ClassificationInput(Input):
"""The ``ClassificationInput`` class provides utility methods for handling classification targets.
:class:`~flash.core.data.io.input_base.Input` objects that extend ``ClassificationInput`` should do the following:
* In the ``load_data`` method, include a call to ``load_target_metadata``. This will determine the format of the
targets and store metadata like ``labels`` and ``num_classes``.
* In the ``load_sample`` method, use ``format_target`` to convert the target to a standard format for use with our
tasks.
"""

@property
@lru_cache(maxsize=None)
def target_formatter(self) -> TargetFormatter:
"""Get the :class:`~flash.core.data.utiltiies.classification.TargetFormatter` to use when formatting
targets.
This property uses ``functools.lru_cache`` so that we only instantiate the formatter once.
"""
classification_state = self.get_state(ClassificationState)
return get_target_formatter(self.target_mode, classification_state.labels, classification_state.num_classes)

def load_target_metadata(self, targets: List[Any]) -> None:
"""Determine the target format and store the ``labels`` and ``num_classes``.
Args:
targets: The list of targets.
"""
self.target_mode = get_target_mode(targets)
self.multi_label = self.target_mode.multi_label
if self.training:
self.labels, self.num_classes = get_target_details(targets, self.target_mode)
self.set_state(ClassificationState(self.labels, self.num_classes))

def format_target(self, target: Any) -> Any:
"""Format a single target according to the previously computed target format and metadata.
Args:
target: The target to format.
Returns:
The formatted target.
"""
return self.target_formatter(target)
29 changes: 11 additions & 18 deletions flash/core/data/io/input.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,8 +45,9 @@
from tqdm import tqdm

from flash.core.data.auto_dataset import AutoDataset, BaseAutoDataset, IterableAutoDataset
from flash.core.data.io.classification_input import ClassificationState
from flash.core.data.properties import ProcessState, Properties
from flash.core.data.utilities.paths import read_csv
from flash.core.data.utilities.data_frame import read_csv
from flash.core.data.utils import CurrentRunningStageFuncContext
from flash.core.utilities.imports import _FIFTYONE_AVAILABLE, lazy_import, requires
from flash.core.utilities.stages import RunningStage
Expand All @@ -72,7 +73,7 @@ def has_file_allowed_extension(filename: str, extensions: Tuple[str, ...]) -> bo
Returns:
bool: True if the filename ends with one of given extensions
"""
return filename.lower().endswith(extensions)
return str(filename).lower().endswith(extensions)


# Credit to the PyTorchVision Team:
Expand Down Expand Up @@ -135,14 +136,6 @@ def has_len(data: Union[Sequence[Any], Iterable[Any]]) -> bool:
return False


@dataclass(unsafe_hash=True, frozen=True)
class LabelsState(ProcessState):
"""A :class:`~flash.core.data.properties.ProcessState` containing ``labels``, a mapping from class index to
label."""

labels: Optional[Sequence[str]]


@dataclass(unsafe_hash=True, frozen=True)
class ImageLabelsMap(ProcessState):

Expand Down Expand Up @@ -361,7 +354,7 @@ class DatasetInput(Input[Dataset]):
Args:
labels: Optionally pass the labels as a mapping from class index to label string. These will then be set as the
:class:`~flash.core.data.io.input.LabelsState`.
:class:`~flash.core.data.io.input.ClassificationState`.
"""

def load_sample(self, sample: Any, dataset: Optional[Any] = None) -> Mapping[str, Any]:
Expand All @@ -380,7 +373,7 @@ class SequenceInput(
Args:
labels: Optionally pass the labels as a mapping from class index to label string. These will then be set as the
:class:`~flash.core.data.io.input.LabelsState`.
:class:`~flash.core.data.io.input.ClassificationState`.
"""

def __init__(self, labels: Optional[Sequence[str]] = None):
Expand All @@ -389,7 +382,7 @@ def __init__(self, labels: Optional[Sequence[str]] = None):
self.labels = labels

if self.labels is not None:
self.set_state(LabelsState(self.labels))
self.set_state(ClassificationState(self.labels))

def load_data(
self,
Expand All @@ -415,7 +408,7 @@ class PathsInput(SequenceInput):
Args:
extensions: The file extensions supported by this data source (e.g. ``(".jpg", ".png")``).
labels: Optionally pass the labels as a mapping from class index to label string. These will then be set as the
:class:`~flash.core.data.io.input.LabelsState`.
:class:`~flash.core.data.io.input.ClassificationState`.
"""

def __init__(
Expand Down Expand Up @@ -459,7 +452,7 @@ def load_data(
classes, class_to_idx = self.find_classes(data)
if not classes:
return self.predict_load_data(data)
self.set_state(LabelsState(classes))
self.set_state(ClassificationState(classes))

if dataset is not None:
dataset.num_classes = len(classes)
Expand Down Expand Up @@ -577,17 +570,17 @@ def load_data(
if isinstance(target_keys, List):
dataset.multi_label = True
dataset.num_classes = len(target_keys)
self.set_state(LabelsState(target_keys))
self.set_state(ClassificationState(target_keys))
data_frame = data_frame.apply(partial(self._resolve_multi_target, target_keys), axis=1)
target_keys = target_keys[0]
else:
dataset.multi_label = False
if self.training:
labels = list(sorted(data_frame[target_keys].unique()))
dataset.num_classes = len(labels)
self.set_state(LabelsState(labels))
self.set_state(ClassificationState(labels))

labels = self.get_state(LabelsState)
labels = self.get_state(ClassificationState)

if labels is not None:
labels = labels.labels
Expand Down
Loading

0 comments on commit b8085f0

Please sign in to comment.