Skip to content
This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

Rename ClassificationInput to ClassificationInputMixin #1116

Merged
merged 7 commits into from
Jan 14, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

- Changed `Wav2Vec2Processor` to `AutoProcessor` and separate it from backbone [optional] ([#1075](https://github.com/PyTorchLightning/lightning-flash/pull/1075))

- Renamed `ClassificationInput` to `ClassificationInputMixin` ([#1116](https://github.com/PyTorchLightning/lightning-flash/pull/1116))

### Deprecated

### Fixed
Expand Down
11 changes: 11 additions & 0 deletions docs/source/api/data.rst
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,17 @@ ___________________________
~flash.core.data.io.input.InputFormat
~flash.core.data.io.input.ImageLabelsMap

flash.core.data.io.classification_input
_______________________________________

.. autosummary::
:toctree: generated/
:nosignatures:
:template: classtemplate.rst

~flash.core.data.io.classification_input.ClassificationState
~flash.core.data.io.classification_input.ClassificationInputMixin

flash.core.data.process
_______________________

Expand Down
4 changes: 2 additions & 2 deletions docs/source/template/data.rst
Original file line number Diff line number Diff line change
Expand Up @@ -27,9 +27,9 @@ Each :class:`~flash.core.data.io.input.Input` has 2 methods:
By default these methods just return their input, so you don't need both a :meth:`~flash.core.data.io.input.Input.load_data` and a :meth:`~flash.core.data.io.input.Input.load_sample` to create a :class:`~flash.core.data.io.input.Input`.
Where possible, you should override one of our existing :class:`~flash.core.data.io.input.Input` classes.

Let's start by implementing a ``TemplateNumpyClassificationInput``, which overrides :class:`~flash.core.data.io.classification_input.ClassificationInput`.
Let's start by implementing a ``TemplateNumpyClassificationInput``, which extends :class:`~flash.core.data.io.classification_input.ClassificationInputMixin`.
The main :class:`~flash.core.data.io.input.Input` method that we have to implement is :meth:`~flash.core.data.io.input.Input.load_data`.
:class:`~flash.core.data.io.classification_input.ClassificationInput` provides utilities for handling targets within flash which need to be called from the :meth:`~flash.core.data.io.input.Input.load_data` and :meth:`~flash.core.data.io.input.Input.load_sample`.
:class:`~flash.core.data.io.classification_input.ClassificationInputMixin` provides utilities for handling targets within flash which need to be called from the :meth:`~flash.core.data.io.input.Input.load_data` and :meth:`~flash.core.data.io.input.Input.load_sample`.
In this ``Input``, we'll also set the ``num_features`` attribute so that we can access it later.

Here's the code for our ``TemplateNumpyClassificationInput.load_data`` method:
Expand Down
6 changes: 3 additions & 3 deletions flash/audio/classification/input.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,8 @@
import numpy as np
import pandas as pd

from flash.core.data.io.classification_input import ClassificationInput, ClassificationState
from flash.core.data.io.input import DataKeys
from flash.core.data.io.classification_input import ClassificationInputMixin, ClassificationState
from flash.core.data.io.input import DataKeys, Input
from flash.core.data.utilities.classification import TargetMode
from flash.core.data.utilities.data_frame import read_csv, resolve_files, resolve_targets
from flash.core.data.utilities.paths import filter_valid_files, has_file_allowed_extension, make_dataset, PATH_TYPE
Expand All @@ -37,7 +37,7 @@ def spectrogram_loader(filepath: str):
return data


class AudioClassificationInput(ClassificationInput):
class AudioClassificationInput(Input, ClassificationInputMixin):
@requires("audio")
def load_sample(self, sample: Dict[str, Any]) -> Dict[str, Any]:
h, w = sample[DataKeys.INPUT].shape[-2:] # H x W
Expand Down
11 changes: 5 additions & 6 deletions flash/core/data/io/classification_input.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,7 @@
from functools import lru_cache
from typing import Any, List, Optional, Sequence

from flash.core.data.io.input import Input
from flash.core.data.properties import ProcessState
from flash.core.data.properties import ProcessState, Properties
from flash.core.data.utilities.classification import (
get_target_details,
get_target_formatter,
Expand All @@ -34,9 +33,9 @@ class ClassificationState(ProcessState):
num_classes: Optional[int] = None


class ClassificationInput(Input):
"""The ``ClassificationInput`` class provides utility methods for handling classification targets.
:class:`~flash.core.data.io.input.Input` objects that extend ``ClassificationInput`` should do the following:
class ClassificationInputMixin(Properties):
"""The ``ClassificationInputMixin`` class provides utility methods for handling classification targets.
:class:`~flash.core.data.io.input.Input` objects that extend ``ClassificationInputMixin`` should do the following:

* In the ``load_data`` method, include a call to ``load_target_metadata``. This will determine the format of the
targets and store metadata like ``labels`` and ``num_classes``.
Expand All @@ -47,7 +46,7 @@ class ClassificationInput(Input):
@property
@lru_cache(maxsize=None)
def target_formatter(self) -> TargetFormatter:
"""Get the :class:`~flash.core.data.utiltiies.classification.TargetFormatter` to use when formatting
"""Get the :class:`~flash.core.data.utilities.classification.TargetFormatter` to use when formatting
targets.

This property uses ``functools.lru_cache`` so that we only instantiate the formatter once.
Expand Down
4 changes: 2 additions & 2 deletions flash/graph/classification/input.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
from torch.utils.data import Dataset

from flash.core.data.data_module import DatasetInput
from flash.core.data.io.classification_input import ClassificationInput, ClassificationState
from flash.core.data.io.classification_input import ClassificationInputMixin, ClassificationState
from flash.core.data.io.input import DataKeys
from flash.core.utilities.imports import _GRAPH_AVAILABLE, requires

Expand All @@ -26,7 +26,7 @@
from torch_geometric.data import InMemoryDataset


class GraphClassificationDatasetInput(DatasetInput, ClassificationInput):
class GraphClassificationDatasetInput(DatasetInput, ClassificationInputMixin):
@requires("graph")
def load_data(self, dataset: Dataset) -> Dataset:
if not self.predicting:
Expand Down
8 changes: 4 additions & 4 deletions flash/image/classification/input.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@

import pandas as pd

from flash.core.data.io.classification_input import ClassificationInput, ClassificationState
from flash.core.data.io.classification_input import ClassificationInputMixin, ClassificationState
from flash.core.data.io.input import DataKeys
from flash.core.data.utilities.classification import TargetMode
from flash.core.data.utilities.data_frame import read_csv, resolve_files, resolve_targets
Expand All @@ -34,7 +34,7 @@
SampleCollection = None


class ImageClassificationFilesInput(ClassificationInput, ImageFilesInput):
class ImageClassificationFilesInput(ClassificationInputMixin, ImageFilesInput):
def load_data(self, files: List[PATH_TYPE], targets: Optional[List[Any]] = None) -> List[Dict[str, Any]]:
if targets is None:
return super().load_data(files)
Expand Down Expand Up @@ -74,7 +74,7 @@ def predict_load_data(data: SampleCollection) -> List[Dict[str, Any]]:
return super().load_data(data.values("filepath"))


class ImageClassificationTensorInput(ClassificationInput, ImageTensorInput):
class ImageClassificationTensorInput(ClassificationInputMixin, ImageTensorInput):
def load_data(self, tensor: Any, targets: Optional[List[Any]] = None) -> List[Dict[str, Any]]:
if targets is not None:
self.load_target_metadata(targets)
Expand All @@ -87,7 +87,7 @@ def load_sample(self, sample: Dict[str, Any]) -> Dict[str, Any]:
return sample


class ImageClassificationNumpyInput(ClassificationInput, ImageNumpyInput):
class ImageClassificationNumpyInput(ClassificationInputMixin, ImageNumpyInput):
def load_data(self, array: Any, targets: Optional[List[Any]] = None) -> List[Dict[str, Any]]:
if targets is not None:
self.load_target_metadata(targets)
Expand Down
6 changes: 3 additions & 3 deletions flash/tabular/classification/input.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,8 @@
# limitations under the License.
from typing import Any, Dict, List, Optional, Union

from flash import DataKeys
from flash.core.data.io.classification_input import ClassificationInput
from flash.core.data.io.classification_input import ClassificationInputMixin
from flash.core.data.io.input import DataKeys
from flash.core.data.utilities.data_frame import read_csv, resolve_targets
from flash.core.utilities.imports import _PANDAS_AVAILABLE
from flash.tabular.input import TabularDataFrameInput
Expand All @@ -25,7 +25,7 @@
DataFrame = object


class TabularClassificationDataFrameInput(TabularDataFrameInput, ClassificationInput):
class TabularClassificationDataFrameInput(TabularDataFrameInput, ClassificationInputMixin):
def load_data(
self,
data_frame: DataFrame,
Expand Down
2 changes: 1 addition & 1 deletion flash/tabular/regression/input.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@

import numpy as np

from flash import DataKeys
from flash.core.data.io.input import DataKeys
from flash.core.data.utilities.data_frame import read_csv
from flash.core.utilities.imports import _PANDAS_AVAILABLE
from flash.tabular.input import TabularDataFrameInput
Expand Down
4 changes: 2 additions & 2 deletions flash/template/classification/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
from flash.core.data.callback import BaseDataFetcher
from flash.core.data.data_module import DataModule
from flash.core.data.data_pipeline import DataPipelineState
from flash.core.data.io.classification_input import ClassificationInput
from flash.core.data.io.classification_input import ClassificationInputMixin
from flash.core.data.io.input import DataKeys, Input
from flash.core.data.io.input_transform import InputTransform
from flash.core.data.utilities.samples import to_samples
Expand All @@ -34,7 +34,7 @@
Bunch = object


class TemplateNumpyClassificationInput(ClassificationInput):
class TemplateNumpyClassificationInput(Input, ClassificationInputMixin):
"""An example data source that records ``num_features`` on the dataset."""

def load_data(
Expand Down
6 changes: 3 additions & 3 deletions flash/text/classification/input.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,8 @@

import pandas as pd

from flash.core.data.io.classification_input import ClassificationInput, ClassificationState
from flash.core.data.io.input import DataKeys
from flash.core.data.io.classification_input import ClassificationInputMixin, ClassificationState
from flash.core.data.io.input import DataKeys, Input
from flash.core.data.utilities.classification import TargetMode
from flash.core.data.utilities.paths import PATH_TYPE
from flash.core.integrations.transformers.states import TransformersBackboneState
Expand All @@ -29,7 +29,7 @@
Dataset = object


class TextClassificationInput(ClassificationInput):
class TextClassificationInput(Input, ClassificationInputMixin):
@staticmethod
def _resolve_target(target_keys: Union[str, List[str]], element: Dict[str, Any]) -> Dict[str, Any]:
if not isinstance(target_keys, List):
Expand Down