diff --git a/.gitignore b/.gitignore
index f682da63dd..7e393940c4 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1,3 +1,4 @@
+.DS_Store
.lock
lightning_logs
@@ -149,3 +150,4 @@ xsum
coco128
wmt_en_ro
kinetics
+movie_posters
diff --git a/docs/source/custom_task.rst b/docs/source/custom_task.rst
index 3cc818a94f..93250e98a5 100644
--- a/docs/source/custom_task.rst
+++ b/docs/source/custom_task.rst
@@ -115,13 +115,13 @@ We will define a custom ``NumpyDataModule`` class subclassing :class:`~flash.dat
This ``NumpyDataModule`` class will provide a ``from_xy_dataset`` helper ``classmethod`` to instantiate
:class:`~flash.data.data_module.DataModule` from x, y numpy arrays.
-Here is how it would look like:
+Here is how it would look:
Example::
x, y = ...
- preprocess_cls = ...
- datamodule = NumpyDataModule.from_xy_dataset(x, y, preprocess_cls)
+ preprocess = ...
+ datamodule = NumpyDataModule.from_xy_dataset(x, y, preprocess)
Here is the ``NumpyDataModule`` implementation:
@@ -140,12 +140,12 @@ Example::
cls,
x: ND,
y: ND,
- preprocess_cls: Preprocess = NumpyPreprocess,
+ preprocess: Preprocess = None,
batch_size: int = 64,
num_workers: int = 0
):
- preprocess = preprocess_cls()
+ preprocess = preprocess or NumpyPreprocess()
x_train, x_test, y_train, y_test = train_test_split(
x, y, test_size=.20, random_state=0)
@@ -180,7 +180,7 @@ It allows the user much more granular control over their data processing flow.
.. note::
- Why introducing :class:`~flash.data.process.Preprocess` ?
+ Why introduce :class:`~flash.data.process.Preprocess` ?
The :class:`~flash.data.process.Preprocess` object reduces the engineering overhead to make inference on raw data or
to deploy the model in production environnement compared to traditional
diff --git a/docs/source/general/data.rst b/docs/source/general/data.rst
index 38ca3a580b..1a7e2fe415 100644
--- a/docs/source/general/data.rst
+++ b/docs/source/general/data.rst
@@ -29,7 +29,9 @@ Here are common terms you need to be familiar with:
The :class:`~flash.data.process.Preprocess` hooks covers from data-loading to model forwarding.
* - :class:`~flash.data.process.Postprocess`
- The :class:`~flash.data.process.Postprocess` provides a simple hook-based API to encapsulate your post-processing logic.
- The :class:`~flash.data.process.Postprocess` hooks covers from model outputs to predictions export.
+ The :class:`~flash.data.process.Postprocess` hooks cover from model outputs to predictions export.
+ * - :class:`~flash.data.process.Serializer`
+ - The :class:`~flash.data.process.Serializer` provides a single ``serialize`` method that is used to convert model outputs (after the :class:`~flash.data.process.Postprocess`) to the desired output format during prediction.
*******************************************
How to use out-of-the-box flashdatamodules
@@ -49,7 +51,9 @@ However, after model training, it requires a lot of engineering overhead to make
Usually, extra processing logic should be added to bridge the gap between training data and raw data.
The :class:`~flash.data.process.Preprocess` and :class:`~flash.data.process.Postprocess` classes can be used to
-store the data as well as the preprocessing and postprocessing transforms.
+store the data as well as the preprocessing and postprocessing transforms. The :class:`~flash.data.process.Serializer`
+class provides the logic for converting :class:`~flash.data.process.Postprocess` outputs to the desired predict format
+(e.g. classes, labels, probabilities, etc.).
By providing a series of hooks that can be overridden with custom data processing logic,
the user has much more granular control over their data processing flow.
@@ -122,7 +126,7 @@ Example::
train_folder="data/hymenoptera_data/train/",
val_folder="data/hymenoptera_data/val/",
test_folder="data/hymenoptera_data/test/",
- preprocess_cls=CustomImageClassificationPreprocess
+ preprocess=CustomImageClassificationPreprocess(),
)
@@ -157,7 +161,7 @@ Example::
val_folder="./data/val",
test_folder="./data/test",
predict_folder="./data/predict",
- preprocess=preprocess
+ preprocess=preprocess,
)
model = ImageClassifier(...)
@@ -190,6 +194,7 @@ Example::
**kwargs
):
+ # Set a custom ``Preprocess`` if none was provided
preprocess = preprocess or cls.preprocess_cls()
# {stage}_load_data_input will be given to your
@@ -291,6 +296,18 @@ ___________
:members:
+----------
+
+.. _serializer:
+
+Serializer
+___________
+
+
+.. autoclass:: flash.data.process.Serializer
+ :members:
+
+
----------
.. _datapipeline:
@@ -414,16 +431,18 @@ Example::
predictions = lightning_module(data)
-Postprocess
-___________
+Postprocess and Serializer
+__________________________
-Once the predictions have been generated by the Flash :class:`~flash.core.model.Task`.
-The Flash :class:`~flash.data.data_pipeline.DataPipeline` will behind the scenes execute the :class:`~flash.data.process.Postprocess` hooks.
+Once the predictions have been generated by the Flash :class:`~flash.core.model.Task`, the Flash
+:class:`~flash.data.data_pipeline.DataPipeline` will execute the :class:`~flash.data.process.Postprocess` hooks and the
+:class:`~flash.data.process.Serializer` behind the scenes.
-First, the ``per_batch_transform`` hooks will be applied on the batch predictions.
-Then the ``uncollate`` will split the batch into individual predictions.
-Finally, the ``per_sample_transform`` will be applied on each prediction.
+First, the :meth:`~flash.data.process.Postprocess.per_batch_transform` hooks will be applied on the batch predictions.
+Then, the :meth:`~flash.data.process.Postprocess.uncollate` will split the batch into individual predictions.
+Next, the :meth:`~flash.data.process.Postprocess.per_sample_transform` will be applied on each prediction.
+Finally, the :meth:`~flash.data.process.Serializer.serialize` method will be called to serialize the predictions.
.. note:: The transform can be applied either on device or ``CPU``.
@@ -438,7 +457,9 @@ Example::
samples = uncollate(batch)
- return [per_sample_transform(sample) for sample in samples]
+ samples = [per_sample_transform(sample) for sample in samples]
+ # only if serializers are enabled.
+ return [serialize(sample) for sample in samples]
predictions = lightning_module(data)
return uncollate_fn(predictions)
diff --git a/docs/source/index.rst b/docs/source/index.rst
index b40b69e82b..5cc7636482 100644
--- a/docs/source/index.rst
+++ b/docs/source/index.rst
@@ -22,6 +22,7 @@ Lightning Flash
reference/task
reference/image_classification
reference/image_embedder
+ reference/multi_label_classification
reference/summarization
reference/text_classification
reference/tabular_classification
diff --git a/docs/source/reference/multi_label_classification.rst b/docs/source/reference/multi_label_classification.rst
new file mode 100644
index 0000000000..e31867d159
--- /dev/null
+++ b/docs/source/reference/multi_label_classification.rst
@@ -0,0 +1,212 @@
+
+.. _multi_label_classification:
+
+################################
+Multi-label Image Classification
+################################
+
+********
+The task
+********
+Multi-label classification is the task of assigning a number of labels from a fixed set to each data point, which can be in any modality. In this example, we will look at the task of trying to predict the movie genres from an image of the movie poster.
+
+------
+
+********
+The data
+********
+The data we will use in this example is a subset of the awesome movie poster genre prediction data set from the paper "Movie Genre Classification based on Poster Images with Deep Neural Networks" by Wei-Ta Chu and Hung-Jui Guo, resized to 128 by 128.
+Take a look at their paper (and please consider citing their paper if you use the data) here: `www.cs.ccu.edu.tw/~wtchu/projects/MoviePoster/ <https://www.cs.ccu.edu.tw/~wtchu/projects/MoviePoster/>`_.
+
+------
+
+*********
+Inference
+*********
+
+The :class:`~flash.vision.ImageClassifier` is already pre-trained on `ImageNet <http://www.image-net.org/>`_, a dataset of over 14 million images.
+
+We can use the :class:`~flash.vision.ImageClassifier` model (pretrained on our data) for inference on any movie poster image using :func:`~flash.vision.ImageClassifier.predict`.
+We can also add a simple visualisation by extending :class:`~flash.data.base_viz.BaseVisualization`, like this:
+
+.. code-block:: python
+
+ # import our libraries
+ from typing import Any
+
+ import torchvision.transforms.functional as T
+ from torchvision.utils import make_grid
+
+ from flash import Trainer
+ from flash.data.base_viz import BaseVisualization
+ from flash.data.utils import download_data
+ from flash.vision import ImageClassificationData, ImageClassifier
+
+ # 1. Download the data
+ download_data("https://pl-flash-data.s3.amazonaws.com/movie_posters.zip", "data/")
+
+ # 2. Define our custom visualisation and datamodule
+ class CustomViz(BaseVisualization):
+
+ def show_per_batch_transform(self, batch: Any, _):
+ images = batch[0]
+ image = make_grid(images, nrow=2)
+ image = T.to_pil_image(image, 'RGB')
+ image.show()
+
+
+ # 3. Load the model from a checkpoint
+ model = ImageClassifier.load_from_checkpoint(
+ "https://flash-weights.s3.amazonaws.com/image_classification_multi_label_model.pt",
+ )
+
+ # 4a. Predict the genres of a few movie posters!
+ predictions = model.predict([
+ "data/movie_posters/val/tt0361500.jpg",
+ "data/movie_posters/val/tt0361748.jpg",
+ "data/movie_posters/val/tt0362478.jpg",
+ ])
+ print(predictions)
+
+ # 4b. Or generate predictions with a whole folder!
+ datamodule = ImageClassificationData.from_folders(
+ predict_folder="data/movie_posters/predict/",
+ data_fetcher=CustomViz(),
+ preprocess=model.preprocess,
+ )
+
+ predictions = Trainer().predict(model, datamodule=datamodule)
+ print(predictions)
+
+ # 5. Show some data!
+ datamodule.show_predict_batch()
+
+For more advanced inference options, see :ref:`predictions`.
+
+------
+
+**********
+Finetuning
+**********
+
+Now let's look at how we can finetune a model on the movie poster data.
+Once we download the data using :func:`~flash.data.utils.download_data`, all we need is the train data and validation data folders to create the :class:`~flash.vision.ImageClassificationData`.
+
+.. note:: The dataset contains ``train`` and ``val`` folders; each folder contains images and a ``metadata.csv`` file which stores the labels.
+
+.. code-block::
+
+ movie_posters
+ ├── train
+ │ ├── metadata.csv
+ │ ├── tt0084058.jpg
+ │ ├── tt0084867.jpg
+ │ ...
+ └── val
+ ├── metadata.csv
+ ├── tt0200465.jpg
+ ├── tt0326965.jpg
+ ...
+
+
+The ``metadata.csv`` files in each folder contain our labels, so we need to create a function (``load_data``) to extract the list of images and associated labels:
+
+.. code-block:: python
+
+ # import our libraries
+ import os
+ from typing import List, Tuple
+
+ import pandas as pd
+ import torch
+
+ genres = [
+ "Action", "Adventure", "Animation", "Biography", "Comedy", "Crime", "Documentary", "Drama", "Family", "Fantasy", "History", "Horror", "Music", "Musical", "Mystery", "N/A", "News", "Reality-TV", "Romance", "Sci-Fi", "Short", "Sport", "Thriller", "War", "Western"
+ ]
+
+ def load_data(data: str, root: str = 'data/movie_posters') -> Tuple[List[str], List[List[int]]]:
+ metadata = pd.read_csv(os.path.join(root, data, "metadata.csv"))
+
+ images = []
+ labels = []
+ for _, row in metadata.iterrows():
+ images.append(os.path.join(root, data, row['Id'] + ".jpg"))
+ labels.append([int(row[genre]) for genre in genres])
+
+ return images, labels
+
+Our :class:`~flash.data.process.Preprocess` overrides the :meth:`~flash.data.process.Preprocess.load_data` method to create an iterable of image paths and label tensors. The :class:`~flash.vision.classification.data.ImageClassificationPreprocess` then handles loading and augmenting the images for us!
+Now all we need is three lines of code to build and train our task!
+
+.. note:: We need to set ``multi_label=True`` in both our :class:`~flash.Task` and our :class:`~flash.data.process.Serializer` to use a binary cross entropy loss and to process outputs correctly.
+
+.. code-block:: python
+
+ import flash
+ from flash.core.classification import Labels
+ from flash.core.finetuning import FreezeUnfreeze
+ from flash.data.utils import download_data
+ from flash.vision import ImageClassificationData, ImageClassifier
+ from flash.vision.classification.data import ImageClassificationPreprocess
+
+ # 1. Download the data
+ download_data("https://pl-flash-data.s3.amazonaws.com/movie_posters.zip", "data/")
+
+ # 2. Load the data
+ ImageClassificationPreprocess.image_size = (128, 128)
+
+ train_filepaths, train_labels = load_data('train')
+ val_filepaths, val_labels = load_data('val')
+ test_filepaths, test_labels = load_data('test')
+
+ datamodule = ImageClassificationData.from_filepaths(
+ train_filepaths=train_filepaths,
+ train_labels=train_labels,
+ val_filepaths=val_filepaths,
+ val_labels=val_labels,
+ test_filepaths=test_filepaths,
+ test_labels=test_labels,
+ preprocess=ImageClassificationPreprocess(),
+ )
+
+ # 3. Build the model
+ model = ImageClassifier(
+ backbone="resnet18",
+ num_classes=len(genres),
+ multi_label=True,
+ )
+
+ # 4. Create the trainer.
+ trainer = flash.Trainer(max_epochs=1, limit_train_batches=1, limit_val_batches=1)
+
+ # 5. Train the model
+ trainer.finetune(model, datamodule=datamodule, strategy=FreezeUnfreeze(unfreeze_epoch=1))
+
+ # 6a. Predict what's on a few images!
+
+ # Serialize predictions as labels.
+ model.serializer = Labels(genres, multi_label=True)
+
+ predictions = model.predict([
+ "data/movie_posters/val/tt0361500.jpg",
+ "data/movie_posters/val/tt0361748.jpg",
+ "data/movie_posters/val/tt0362478.jpg",
+ ])
+
+ print(predictions)
+
+ datamodule = ImageClassificationData.from_folders(
+ predict_folder="data/movie_posters/predict/",
+ preprocess=model.preprocess,
+ )
+
+ # 6b. Or generate predictions with a whole folder!
+ predictions = trainer.predict(model, datamodule=datamodule)
+ print(predictions)
+
+ # 7. Save it!
+ trainer.save_checkpoint("image_classification_multi_label_model.pt")
+
+------
+
+For more backbone options, see :ref:`image_classification`.
diff --git a/flash/core/classification.py b/flash/core/classification.py
index 9650eadc34..7e88cc3ded 100644
--- a/flash/core/classification.py
+++ b/flash/core/classification.py
@@ -11,36 +11,124 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-from typing import Any, List, Optional
+from dataclasses import dataclass
+from typing import Any, List, Mapping, Optional, Union
import torch
import torch.nn.functional as F
+from pytorch_lightning.utilities import rank_zero_warn
from flash.core.model import Task
-from flash.data.process import Postprocess, Preprocess
+from flash.data.process import ProcessState, Serializer
-class ClassificationPostprocess(Postprocess):
+@dataclass(unsafe_hash=True, frozen=True)
+class ClassificationState(ProcessState):
- def __init__(self, multi_label: bool = False, save_path: Optional[str] = None):
- super().__init__(save_path=save_path)
- self.multi_label = multi_label
-
- def per_sample_transform(self, samples: Any) -> List[Any]:
- if self.multi_label:
- return F.sigmoid(samples).tolist()
- else:
- return torch.argmax(samples, -1).tolist()
+ labels: Optional[List[str]]
class ClassificationTask(Task):
- postprocess_cls = ClassificationPostprocess
-
- def __init__(self, *args, postprocess: Optional[Preprocess] = None, **kwargs):
- super().__init__(*args, postprocess=postprocess or self.postprocess_cls(), **kwargs)
+ def __init__(
+ self,
+ *args,
+ serializer: Optional[Union[Serializer, Mapping[str, Serializer]]] = None,
+ **kwargs,
+ ) -> None:
+ super().__init__(*args, serializer=serializer or Classes(), **kwargs)
def to_metrics_format(self, x: torch.Tensor) -> torch.Tensor:
if getattr(self.hparams, "multi_label", False):
return F.sigmoid(x)
return F.softmax(x, -1)
+
+
+class ClassificationSerializer(Serializer):
+ """A base class for classification serializers.
+
+ Args:
+ multi_label: If true, treats outputs as multi label logits.
+ """
+
+ def __init__(self, multi_label: bool = False):
+ super().__init__()
+
+ self._mutli_label = multi_label
+
+ @property
+ def multi_label(self) -> bool:
+ return self._mutli_label
+
+
+class Logits(ClassificationSerializer):
+ """A :class:`.Serializer` which simply converts the model outputs (assumed to be logits) to a list."""
+
+ def serialize(self, sample: Any) -> Any:
+ return sample.tolist()
+
+
+class Probabilities(ClassificationSerializer):
+ """A :class:`.Serializer` which applies a softmax (or a sigmoid when ``multi_label`` is set) to the model outputs
+ (assumed to be logits) and converts to a list."""
+
+ def serialize(self, sample: Any) -> Any:
+ if self.multi_label:
+ return torch.sigmoid(sample).tolist()
+ return torch.softmax(sample, -1).tolist()
+
+
+class Classes(ClassificationSerializer):
+ """A :class:`.Serializer` which applies an argmax to the model outputs (either logits or probabilities) and
+ converts to a list. With ``multi_label``, returns the indices whose sigmoid output exceeds ``threshold``."""
+
+ def __init__(self, multi_label: bool = False, threshold: float = 0.5):
+ super().__init__(multi_label)
+
+ self.threshold = threshold
+
+ def serialize(self, sample: Any) -> Union[int, List[int]]:
+ if self.multi_label:
+ one_hot = (sample.sigmoid() > self.threshold).int().tolist()
+ result = []
+ for index, value in enumerate(one_hot):
+ if value == 1:
+ result.append(index)
+ return result
+ return torch.argmax(sample, -1).tolist()
+
+
+class Labels(Classes):
+ """A :class:`.Serializer` which converts the model outputs (either logits or probabilities) to the label of the
+ argmax classification.
+
+ Args:
+ labels: A list of labels, assumed to map the class index to the label for that class. If ``labels`` is not
+ provided, will attempt to get them from the :class:`.ClassificationState`.
+ """
+
+ def __init__(self, labels: Optional[List[str]] = None, multi_label: bool = False):
+ super().__init__(multi_label=multi_label)
+ self._labels = labels
+
+ def serialize(self, sample: Any) -> Union[int, List[int], str, List[str]]:
+ labels = None
+
+ if self._labels is not None:
+ labels = self._labels
+ else:
+ state = self.get_state(ClassificationState)
+ if state is not None:
+ labels = state.labels
+
+ classes = super().serialize(sample)
+
+ if labels is not None:
+ if self.multi_label:
+ return [labels[cls] for cls in classes]
+ return labels[classes]
+ else:
+ rank_zero_warn(
+ "No ClassificationState was found, this serializer will act as a Classes serializer.", UserWarning
+ )
+ return classes
diff --git a/flash/core/model.py b/flash/core/model.py
index 02dc367932..b2bb555816 100644
--- a/flash/core/model.py
+++ b/flash/core/model.py
@@ -12,8 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import functools
-import inspect
-from copy import deepcopy
from typing import Any, Callable, Dict, List, Mapping, Optional, Sequence, Tuple, Type, Union
import torch
@@ -29,7 +27,8 @@
from flash.core.registry import FlashRegistry
from flash.core.schedulers import _SCHEDULERS_REGISTRY
from flash.core.utils import get_callable_dict
-from flash.data.data_pipeline import DataPipeline, Postprocess, Preprocess
+from flash.data.data_pipeline import DataPipeline
+from flash.data.process import Postprocess, Preprocess, Serializer, SerializerMapping
def predict_context(func: Callable) -> Callable:
@@ -80,8 +79,9 @@ def __init__(
scheduler_kwargs: Optional[Dict[str, Any]] = None,
metrics: Union[torchmetrics.Metric, Mapping, Sequence, None] = None,
learning_rate: float = 5e-5,
- preprocess: Preprocess = None,
- postprocess: Postprocess = None,
+ preprocess: Optional[Preprocess] = None,
+ postprocess: Optional[Postprocess] = None,
+ serializer: Optional[Union[Serializer, Mapping[str, Serializer]]] = None,
):
super().__init__()
if model is not None:
@@ -97,8 +97,12 @@ def __init__(
# TODO: should we save more? Bug on some regarding yaml if we save metrics
self.save_hyperparameters("learning_rate", "optimizer")
- self._preprocess = preprocess
- self._postprocess = postprocess
+ self._preprocess: Optional[Preprocess] = preprocess
+ self._postprocess: Optional[Postprocess] = postprocess
+ self._serializer: Optional[Serializer] = None
+
+ # Explicitly set the serializer to call the setter
+ self.serializer = serializer
def step(self, batch: Any, batch_idx: int) -> Any:
"""
@@ -197,22 +201,27 @@ def configure_finetune_callback(self) -> List[Callback]:
def _resolve(
old_preprocess: Optional[Preprocess],
old_postprocess: Optional[Postprocess],
+ old_serializer: Optional[Serializer],
new_preprocess: Optional[Preprocess],
new_postprocess: Optional[Postprocess],
- ) -> Tuple[Optional[Preprocess], Optional[Postprocess]]:
- """Resolves the correct :class:`~flash.data.process.Preprocess` and :class:`~flash.data.process.Postprocess` to use,
- choosing ``new_*`` if it is not None or a base class
- (:class:`~flash.data.process.Preprocess` or :class:`~flash.data.process.Postprocess`)
- and ``old_*`` otherwise.
+ new_serializer: Optional[Serializer],
+ ) -> Tuple[Optional[Preprocess], Optional[Postprocess], Optional[Serializer]]:
+ """Resolves the correct :class:`~flash.data.process.Preprocess`, :class:`~flash.data.process.Postprocess`, and
+ :class:`~flash.data.process.Serializer` to use, choosing ``new_*`` if it is not None or a base class
+ (:class:`~flash.data.process.Preprocess`, :class:`~flash.data.process.Postprocess`, or
+ :class:`~flash.data.process.Serializer`) and ``old_*`` otherwise.
Args:
old_preprocess: :class:`~flash.data.process.Preprocess` to be overridden.
old_postprocess: :class:`~flash.data.process.Postprocess` to be overridden.
+ old_serializer: :class:`~flash.data.process.Serializer` to be overridden.
new_preprocess: :class:`~flash.data.process.Preprocess` to override with.
new_postprocess: :class:`~flash.data.process.Postprocess` to override with.
+ new_serializer: :class:`~flash.data.process.Serializer` to override with.
Returns:
- The resolved :class:`~flash.data.process.Preprocess` and :class:`~flash.data.process.Postprocess`.
+ The resolved :class:`~flash.data.process.Preprocess`, :class:`~flash.data.process.Postprocess`, and
+ :class:`~flash.data.process.Serializer`.
"""
preprocess = old_preprocess
if new_preprocess is not None and type(new_preprocess) != Preprocess:
@@ -222,7 +231,23 @@ def _resolve(
if new_postprocess is not None and type(new_postprocess) != Postprocess:
postprocess = new_postprocess
- return preprocess, postprocess
+ serializer = old_serializer
+ if new_serializer is not None and type(new_serializer) != Serializer:
+ serializer = new_serializer
+
+ return preprocess, postprocess, serializer
+
+ @property
+ def serializer(self) -> Optional[Serializer]:
+ """The current :class:`.Serializer` associated with this model. If this property was set to a mapping
+ (e.g. ``.serializer = {'output1': SerializerOne()}``) then this will be a :class:`.SerializerMapping`."""
+ return self._serializer
+
+ @serializer.setter
+ def serializer(self, serializer: Union[Serializer, Mapping[str, Serializer]]):
+ if isinstance(serializer, Mapping):
+ serializer = SerializerMapping(serializer)
+ self._serializer = serializer
def build_data_pipeline(self, data_pipeline: Optional[DataPipeline] = None) -> Optional[DataPipeline]:
"""Build a :class:`.DataPipeline` incorporating available
@@ -241,32 +266,45 @@ def build_data_pipeline(self, data_pipeline: Optional[DataPipeline] = None) -> O
Returns:
The fully resolved :class:`.DataPipeline`.
"""
- preprocess, postprocess = None, None
+ preprocess, postprocess, serializer = None, None, None
# Datamodule
if self.datamodule is not None and getattr(self.datamodule, 'data_pipeline', None) is not None:
preprocess = getattr(self.datamodule.data_pipeline, '_preprocess_pipeline', None)
postprocess = getattr(self.datamodule.data_pipeline, '_postprocess_pipeline', None)
+ serializer = getattr(self.datamodule.data_pipeline, '_serializer', None)
elif self.trainer is not None and hasattr(
self.trainer, 'datamodule'
) and getattr(self.trainer.datamodule, 'data_pipeline', None) is not None:
preprocess = getattr(self.trainer.datamodule.data_pipeline, '_preprocess_pipeline', None)
postprocess = getattr(self.trainer.datamodule.data_pipeline, '_postprocess_pipeline', None)
+ serializer = getattr(self.trainer.datamodule.data_pipeline, '_serializer', None)
# Defaults / task attributes
- preprocess, postprocess = Task._resolve(preprocess, postprocess, self._preprocess, self._postprocess)
+ preprocess, postprocess, serializer = Task._resolve(
+ preprocess,
+ postprocess,
+ serializer,
+ self._preprocess,
+ self._postprocess,
+ self.serializer,
+ )
# Datapipeline
if data_pipeline is not None:
- preprocess, postprocess = Task._resolve(
+ preprocess, postprocess, serializer = Task._resolve(
preprocess,
postprocess,
+ serializer,
getattr(data_pipeline, '_preprocess_pipeline', None),
getattr(data_pipeline, '_postprocess_pipeline', None),
+ getattr(data_pipeline, '_serializer', None),
)
- return DataPipeline(preprocess, postprocess)
+ data_pipeline = DataPipeline(preprocess, postprocess, serializer)
+ data_pipeline.initialize()
+ return data_pipeline
@property
def data_pipeline(self) -> DataPipeline:
@@ -276,13 +314,23 @@ def data_pipeline(self) -> DataPipeline:
@data_pipeline.setter
def data_pipeline(self, data_pipeline: Optional[DataPipeline]) -> None:
- self._preprocess, self._postprocess = Task._resolve(
+ self._preprocess, self._postprocess, self.serializer = Task._resolve(
self._preprocess,
self._postprocess,
+ self.serializer,
getattr(data_pipeline, '_preprocess_pipeline', None),
getattr(data_pipeline, '_postprocess_pipeline', None),
+ getattr(data_pipeline, '_serializer', None),
)
+ @property
+ def preprocess(self) -> Preprocess:
+ return getattr(self.data_pipeline, '_preprocess_pipeline', None)
+
+ @property
+ def postprocess(self) -> Postprocess:
+ return getattr(self.data_pipeline, '_postprocess_pipeline', None)
+
def on_train_dataloader(self) -> None:
if self.data_pipeline is not None:
self.data_pipeline._detach_from_model(self, RunningStage.TRAINING)
diff --git a/flash/data/batch.py b/flash/data/batch.py
index 3758d78a66..ea6ce1e9ca 100644
--- a/flash/data/batch.py
+++ b/flash/data/batch.py
@@ -189,6 +189,7 @@ def __init__(
uncollate_fn: Callable,
per_batch_transform: Callable,
per_sample_transform: Callable,
+ serializer: Optional[Callable],
save_fn: Optional[Callable] = None,
save_per_sample: bool = False
):
@@ -196,13 +197,14 @@ def __init__(
self.uncollate_fn = convert_to_modules(uncollate_fn)
self.per_batch_transform = convert_to_modules(per_batch_transform)
self.per_sample_transform = convert_to_modules(per_sample_transform)
+ self.serializer = convert_to_modules(serializer)
self.save_fn = convert_to_modules(save_fn)
self.save_per_sample = convert_to_modules(save_per_sample)
def forward(self, batch: Sequence[Any]):
uncollated = self.uncollate_fn(self.per_batch_transform(batch))
- final_preds = type(uncollated)([self.per_sample_transform(sample) for sample in uncollated])
+ final_preds = type(uncollated)([self.serializer(self.per_sample_transform(sample)) for sample in uncollated])
if self.save_fn:
if self.save_per_sample:
diff --git a/flash/data/callback.py b/flash/data/callback.py
index df8ad91600..a479a6e59e 100644
--- a/flash/data/callback.py
+++ b/flash/data/callback.py
@@ -102,7 +102,7 @@ def from_inputs(
test_data: Any,
predict_data: Any) -> "CustomDataModule":
- preprocess = cls.preprocess_cls()
+ preprocess = CustomPreprocess()
return cls.from_load_data_inputs(
train_load_data_input=train_data,
diff --git a/flash/data/data_module.py b/flash/data/data_module.py
index c8986ad024..fea3e7bbc3 100644
--- a/flash/data/data_module.py
+++ b/flash/data/data_module.py
@@ -13,7 +13,7 @@
# limitations under the License.
import os
import platform
-from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union
+from typing import Any, Callable, Dict, Iterable, Optional, Tuple, Union
import pytorch_lightning as pl
import torch
@@ -37,9 +37,8 @@ class DataModule(pl.LightningDataModule):
train_dataset: Dataset for training. Defaults to None.
val_dataset: Dataset for validating model performance during training. Defaults to None.
test_dataset: Dataset to test model performance. Defaults to None.
- predict_dataset: Dataset to predict model performance. Defaults to None.
+ predict_dataset: Dataset for predicting. Defaults to None.
num_workers: The number of workers to use for parallelized loading. Defaults to None.
- predict_ds: Dataset for predicting. Defaults to None.
batch_size: The batch size to be used by the DataLoader. Defaults to 1.
num_workers: The number of workers to use for parallelized loading.
Defaults to None which equals the number of available CPU threads,
@@ -81,7 +80,10 @@ def __init__(
# TODO: figure out best solution for setting num_workers
if num_workers is None:
- num_workers = 0 if platform.system() == "Darwin" else os.cpu_count()
+ if platform.system() == "Darwin" or platform.system() == "Windows":
+ num_workers = 0
+ else:
+ num_workers = os.cpu_count()
self.num_workers = num_workers
self._preprocess: Optional[Preprocess] = None
@@ -310,16 +312,18 @@ def autogenerate_dataset(
or from the provided ``whole_data_load_fn``, ``per_sample_load_fn`` functions directly
"""
+ preprocess = getattr(data_pipeline, '_preprocess_pipeline', None)
+
if whole_data_load_fn is None:
whole_data_load_fn = getattr(
- cls.preprocess_cls,
- DataPipeline._resolve_function_hierarchy('load_data', cls.preprocess_cls, running_stage, Preprocess)
+ preprocess,
+ DataPipeline._resolve_function_hierarchy('load_data', preprocess, running_stage, Preprocess)
)
if per_sample_load_fn is None:
per_sample_load_fn = getattr(
- cls.preprocess_cls,
- DataPipeline._resolve_function_hierarchy('load_sample', cls.preprocess_cls, running_stage, Preprocess)
+ preprocess,
+ DataPipeline._resolve_function_hierarchy('load_sample', preprocess, running_stage, Preprocess)
)
if use_iterable_auto_dataset:
return IterableAutoDataset(
@@ -424,6 +428,7 @@ def from_load_data_inputs(
val_load_data_input: Optional[Any] = None,
test_load_data_input: Optional[Any] = None,
predict_load_data_input: Optional[Any] = None,
+ data_fetcher: BaseDataFetcher = None,
preprocess: Optional[Preprocess] = None,
postprocess: Optional[Postprocess] = None,
use_iterable_auto_dataset: bool = False,
@@ -453,7 +458,7 @@ def from_load_data_inputs(
else:
data_pipeline = cls(**kwargs).data_pipeline
- data_fetcher: BaseDataFetcher = cls.configure_data_fetcher()
+ data_fetcher: BaseDataFetcher = data_fetcher or cls.configure_data_fetcher()
data_fetcher.attach_to_preprocess(data_pipeline._preprocess_pipeline)
diff --git a/flash/data/data_pipeline.py b/flash/data/data_pipeline.py
index fe75404f1c..72389b4996 100644
--- a/flash/data/data_pipeline.py
+++ b/flash/data/data_pipeline.py
@@ -19,20 +19,47 @@
import torch
from pytorch_lightning.trainer.connectors.data_connector import _PatchDataLoader
from pytorch_lightning.trainer.states import RunningStage
-from pytorch_lightning.utilities import imports
+from pytorch_lightning.utilities import rank_zero_warn
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from torch.utils.data import DataLoader, IterableDataset
from torch.utils.data._utils.collate import default_collate, default_convert
from flash.data.auto_dataset import AutoDataset, IterableAutoDataset
from flash.data.batch import _PostProcessor, _PreProcessor, _Sequential
-from flash.data.process import Postprocess, Preprocess
+from flash.data.process import Postprocess, Preprocess, ProcessState, Serializer
from flash.data.utils import _POSTPROCESS_FUNCS, _PREPROCESS_FUNCS, _STAGES_PREFIX
if TYPE_CHECKING:
from flash.core.model import Task
+class DataPipelineState:
+ """A class to store and share all process states once a :class:`.DataPipeline` has been initialized."""
+
+ def __init__(self):
+ self._state: Dict[Type[ProcessState], ProcessState] = {}
+ self._initialized = False
+
+ def set_state(self, state: ProcessState):
+ """Add the given :class:`.ProcessState` to the :class:`.DataPipelineState`."""
+
+ if not self._initialized:
+ self._state[type(state)] = state
+ else:
+ rank_zero_warn(
+ f"Attempted to add a state ({state}) after the data pipeline has already been initialized. This will"
+ " only have an effect when a new data pipeline is created.", UserWarning
+ )
+
+ def get_state(self, state_type: Type[ProcessState]) -> Optional[ProcessState]:
+ """Get the :class:`.ProcessState` of the given type from the :class:`.DataPipelineState`."""
+
+ if state_type in self._state:
+ return self._state[state_type]
+ else:
+ return None
+
+
class DataPipeline:
"""
DataPipeline holds the engineering logic to connect
@@ -59,12 +86,29 @@ class CustomPostprocess(Postprocess):
PREPROCESS_FUNCS: Set[str] = _PREPROCESS_FUNCS
POSTPROCESS_FUNCS: Set[str] = _POSTPROCESS_FUNCS
- def __init__(self, preprocess: Optional[Preprocess] = None, postprocess: Optional[Postprocess] = None) -> None:
+ def __init__(
+ self,
+ preprocess: Optional[Preprocess] = None,
+ postprocess: Optional[Postprocess] = None,
+ serializer: Optional[Serializer] = None,
+ ) -> None:
self._preprocess_pipeline = preprocess or Preprocess()
self._postprocess_pipeline = postprocess or Postprocess()
- self._postprocessor = None
+
+ self._serializer = serializer or Serializer()
+
self._running_stage = None
+ def initialize(self):
+ """Creates the :class:`.DataPipelineState` and gives the reference to the :class:`.Preprocess`,
+ :class:`.Postprocess`, and :class:`.Serializer`. Once this has been called, any attempt to add new state will
+ give a warning."""
+ data_pipeline_state = DataPipelineState()
+ self._preprocess_pipeline.attach_data_pipeline_state(data_pipeline_state)
+ self._postprocess_pipeline.attach_data_pipeline_state(data_pipeline_state)
+ self._serializer.attach_data_pipeline_state(data_pipeline_state)
+ data_pipeline_state._initialized = True
+
@staticmethod
def _is_overriden(method_name: str, process_obj, super_obj: Any, prefix: Optional[str] = None) -> bool:
"""
@@ -79,11 +123,6 @@ def _is_overriden(method_name: str, process_obj, super_obj: Any, prefix: Optiona
return getattr(process_obj, current_method_name).__code__ != getattr(super_obj, method_name).__code__
- @property
- def preprocess_state(self):
- if self._preprocess_pipeline:
- return self._preprocess_pipeline.state
-
@classmethod
def _is_overriden_recursive(
cls, method_name: str, process_obj, super_obj: Any, prefix: Optional[str] = None
@@ -182,7 +221,7 @@ def _create_collate_preprocessors(
)
collate_in_worker_from_transform: Optional[bool] = getattr(
- preprocess, f"_{prefix}_collate_in_worker_from_transform"
+ preprocess, f"_{prefix}_collate_in_worker_from_transform", None
)
if (
@@ -374,6 +413,7 @@ def _create_uncollate_postprocessors(self, stage: RunningStage) -> _PostProcesso
getattr(postprocess, func_names["uncollate"]),
getattr(postprocess, func_names["per_batch_transform"]),
getattr(postprocess, func_names["per_sample_transform"]),
+ serializer=self._serializer,
save_fn=save_fn,
save_per_sample=save_per_sample
)
diff --git a/flash/data/process.py b/flash/data/process.py
index 542ae8f3dc..128fb01b5c 100644
--- a/flash/data/process.py
+++ b/flash/data/process.py
@@ -13,7 +13,7 @@
# limitations under the License.
import os
from dataclasses import dataclass
-from typing import Any, Callable, Dict, List, Mapping, Optional, Sequence, Union
+from typing import Any, Callable, Dict, List, Mapping, Optional, Sequence, Type, TYPE_CHECKING, TypeVar, Union
import torch
from pytorch_lightning.trainer.states import RunningStage
@@ -26,11 +26,46 @@
from flash.data.callback import FlashCallback
from flash.data.utils import _PREPROCESS_FUNCS, _STAGES_PREFIX, convert_to_modules
+if TYPE_CHECKING:
+ from flash.data.data_pipeline import DataPipelineState
+
+
+@dataclass(unsafe_hash=True, frozen=True)
+class ProcessState:
+ """
+ Base class for all process states
+ """
+ pass
+
+
+STATE_TYPE = TypeVar('STATE_TYPE', bound=ProcessState)
+
class Properties:
- _running_stage: Optional[RunningStage] = None
- _current_fn: Optional[str] = None
+ def __init__(self):
+ super().__init__()
+
+ self._running_stage: Optional[RunningStage] = None
+ self._current_fn: Optional[str] = None
+ self._data_pipeline_state: Optional['DataPipelineState'] = None
+ self._state: Dict[Type[ProcessState], ProcessState] = {}
+
+ def get_state(self, state_type: Type[STATE_TYPE]) -> Optional[STATE_TYPE]:
+ if self._data_pipeline_state is not None:
+ return self._data_pipeline_state.get_state(state_type)
+ else:
+ return None
+
+ def set_state(self, state: ProcessState):
+ self._state[type(state)] = state
+ if self._data_pipeline_state is not None:
+ self._data_pipeline_state.set_state(state)
+
+ def attach_data_pipeline_state(self, data_pipeline_state: 'DataPipelineState'):
+ self._data_pipeline_state = data_pipeline_state
+ for state in self._state.values():
+ self._data_pipeline_state.set_state(state)
@property
def current_fn(self) -> Optional[str]:
@@ -93,14 +128,6 @@ def validating(self, val: bool) -> None:
self._running_stage = None
-@dataclass(unsafe_hash=True, frozen=True)
-class PreprocessState:
- """
- Base class for all preprocess states
- """
- pass
-
-
class Preprocess(Properties, Module):
"""
The :class:`~flash.data.process.Preprocess` encapsulates
@@ -351,10 +378,6 @@ def current_transform(self) -> Callable:
else:
return self._identity
- @classmethod
- def from_state(cls, state: PreprocessState) -> 'Preprocess':
- return cls(**vars(state))
-
@property
def callbacks(self) -> List['FlashCallback']:
if not hasattr(self, "_callbacks"):
@@ -483,3 +506,57 @@ def _save_data(self, data: Any) -> None:
def _save_sample(self, sample: Any) -> None:
self.save_sample(sample, self.format_sample_save_path(self._save_path))
+
+
+class Serializer(Properties):
+ """A :class:`.Serializer` encapsulates a single ``serialize`` method which is used to convert the model output into
+ the desired output format when predicting."""
+
+ def __init__(self):
+ super().__init__()
+ self._is_enabled = True
+
+ def enable(self):
+ """Enable serialization."""
+ self._is_enabled = True
+
+ def disable(self):
+ """Disable serialization."""
+ self._is_enabled = False
+
+ def serialize(self, sample: Any) -> Any:
+ """Serialize the given sample into the desired output format.
+
+ Args:
+ sample: The output from the :class:`.Postprocess`.
+
+ Returns:
+ The serialized output.
+ """
+ return sample
+
+ def __call__(self, sample: Any) -> Any:
+ if self._is_enabled:
+ return self.serialize(sample)
+ else:
+ return sample
+
+
+class SerializerMapping(Serializer):
+ """If the model output is a dictionary, then the :class:`.SerializerMapping` enables each entry in the dictionary
+ to be passed to its own :class:`.Serializer`."""
+
+ def __init__(self, serializers: Mapping[str, Serializer]):
+ super().__init__()
+
+ self._serializers = serializers
+
+ def serialize(self, sample: Any) -> Any:
+ if isinstance(sample, Mapping):
+ return {key: serializer.serialize(sample[key]) for key, serializer in self._serializers.items()}
+ else:
+ raise ValueError("The model output must be a mapping when using a SerializerMapping.")
+
+ def attach_data_pipeline_state(self, data_pipeline_state: 'DataPipelineState'):
+ for serializer in self._serializers.values():
+ serializer.attach_data_pipeline_state(data_pipeline_state)
diff --git a/flash/tabular/classification/data/data.py b/flash/tabular/classification/data/data.py
index 58f583e524..d148896ac4 100644
--- a/flash/tabular/classification/data/data.py
+++ b/flash/tabular/classification/data/data.py
@@ -11,77 +11,32 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-from dataclasses import dataclass
-from typing import Any, Callable, Dict, List, Optional, Type, Union
+from typing import Dict, List, Optional, Union
import numpy as np
import pandas as pd
from pandas.core.frame import DataFrame
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from sklearn.model_selection import train_test_split
+from torch.utils.data import Dataset
+from flash.core.classification import ClassificationState
from flash.data.auto_dataset import AutoDataset
from flash.data.data_module import DataModule
-from flash.data.process import Preprocess, PreprocessState
+from flash.data.process import Preprocess
from flash.tabular.classification.data.dataset import (
_compute_normalization,
- _dfs_to_samples,
_generate_codes,
- _impute,
_pre_transform,
_to_cat_vars_numpy,
_to_num_vars_numpy,
- PandasDataset,
)
-@dataclass(unsafe_hash=True, frozen=True)
-class TabularState(PreprocessState):
- cat_cols: List[str] # categorical columns used for training
- num_cols: List[str] # numerical columns used for training
- target_col: str # target column name used for training
- mean: DataFrame # mean DataFrame for categorical columsn on train DataFrame
- std: DataFrame # std DataFrame for categorical columsn on train DataFrame
- codes: Dict # codes for numerical columns used for training
- target_codes: Dict # target codes for target used for training
- num_classes: int # number of classes used for training
- is_regression: bool # whether the task was a is_regression
-
-
class TabularPreprocess(Preprocess):
def __init__(
self,
- cat_cols: List[str],
- num_cols: List[str],
- target_col: str,
- mean: DataFrame,
- std: DataFrame,
- codes: Dict,
- target_codes: Dict,
- num_classes: int,
- is_regression: bool = False,
- ):
- super().__init__()
- self.cat_cols = cat_cols
- self.num_cols = num_cols
- self.target_col = target_col
- self.mean = mean
- self.std = std
- self.codes = codes
- self.target_codes = target_codes
- self.num_classes = num_classes
- self.is_regression = is_regression
-
- @property
- def state(self) -> TabularState:
- return TabularState(
- self.cat_cols, self.num_cols, self.target_col, self.mean, self.std, self.codes, self.target_codes,
- self.num_classes, self.is_regression
- )
-
- @staticmethod
- def generate_state(
train_df: DataFrame,
val_df: Optional[DataFrame],
test_df: Optional[DataFrame],
@@ -90,13 +45,10 @@ def generate_state(
num_cols: List[str],
cat_cols: List[str],
is_regression: bool,
- preprocess_state: Optional[TabularState] = None
):
- if preprocess_state is not None:
- return preprocess_state
-
+ super().__init__()
if train_df is None:
- raise MisconfigurationException("train_df is required to compute the preprocess state")
+ raise MisconfigurationException("train_df is required to instantiate the TabularPreprocess")
dfs = [train_df]
@@ -110,7 +62,9 @@ def generate_state(
dfs += [predict_df]
mean, std = _compute_normalization(dfs[0], num_cols)
- num_classes = len(dfs[0][target_col].unique())
+ classes = dfs[0][target_col].unique()
+ self.set_state(ClassificationState(classes))
+ num_classes = len(classes)
if dfs[0][target_col].dtype == object:
# if the target_col is a category, not an int
target_codes = _generate_codes(dfs, [target_col])
@@ -118,17 +72,15 @@ def generate_state(
target_codes = None
codes = _generate_codes(dfs, cat_cols)
- return TabularState(
- cat_cols,
- num_cols,
- target_col,
- mean,
- std,
- codes,
- target_codes,
- num_classes,
- is_regression,
- )
+ self.cat_cols = cat_cols
+ self.num_cols = num_cols
+ self.target_col = target_col
+ self.mean = mean
+ self.std = std
+ self.codes = codes
+ self.target_codes = target_codes
+ self.num_classes = num_classes
+ self.is_regression = is_regression
def common_load_data(self, df: DataFrame, dataset: AutoDataset):
# impute_data
@@ -162,29 +114,41 @@ class TabularData(DataModule):
preprocess_cls = TabularPreprocess
- @property
- def preprocess_state(self) -> PreprocessState:
- return self._preprocess.state
+ def __init__(
+ self,
+ train_dataset: Optional[Dataset] = None,
+ val_dataset: Optional[Dataset] = None,
+ test_dataset: Optional[Dataset] = None,
+ predict_dataset: Optional[Dataset] = None,
+ batch_size: int = 1,
+ num_workers: Optional[int] = 0,
+ ) -> None:
+ super().__init__(
+ train_dataset,
+ val_dataset,
+ test_dataset,
+ predict_dataset,
+ batch_size=batch_size,
+ num_workers=num_workers,
+ )
- @preprocess_state.setter
- def preprocess_state(self, preprocess_state):
- self._preprocess = self.preprocess_cls.from_state(preprocess_state)
+ self._preprocess: Optional[Preprocess] = None
@property
def codes(self) -> Dict[str, str]:
- return self.preprocess_state.codes
+ return self._preprocess.codes
@property
def num_classes(self) -> int:
- return self.preprocess_state.num_classes
+ return self._preprocess.num_classes
@property
def cat_cols(self) -> Optional[List[str]]:
- return self.preprocess_state.cat_cols
+ return self._preprocess.cat_cols
@property
def num_cols(self) -> Optional[List[str]]:
- return self.preprocess_state.num_cols
+ return self._preprocess.num_cols
@property
def num_features(self) -> int:
@@ -204,8 +168,7 @@ def from_csv(
num_workers: Optional[int] = None,
val_size: Optional[float] = None,
test_size: Optional[float] = None,
- preprocess_cls: Optional[Type[Preprocess]] = None,
- preprocess_state: Optional[TabularState] = None,
+ preprocess: Optional[Preprocess] = None,
**pandas_kwargs,
):
"""Creates a TextClassificationData object from pandas DataFrames.
@@ -223,8 +186,7 @@ def from_csv(
or 0 for Darwin platform.
val_size: Float between 0 and 1 to create a validation dataset from train dataset.
test_size: Float between 0 and 1 to create a test dataset from train validation.
- preprocess_cls: Preprocess class to be used within this DataModule DataPipeline.
- preprocess_state: Used to store the train statistics.
+ preprocess: Preprocess to be used within this DataModule DataPipeline.
Returns:
TabularData: The constructed data module.
@@ -250,8 +212,7 @@ def from_csv(
num_workers,
val_size,
test_size,
- preprocess_state=preprocess_state,
- preprocess_cls=preprocess_cls,
+ preprocess=preprocess,
)
@property
@@ -306,8 +267,7 @@ def from_df(
val_size: float = None,
test_size: float = None,
is_regression: bool = False,
- preprocess_state: Optional[TabularState] = None,
- preprocess_cls: Optional[Type[Preprocess]] = None,
+ preprocess: Optional[Preprocess] = None,
):
"""Creates a TabularData object from pandas DataFrames.
@@ -324,6 +284,7 @@ def from_df(
or 0 for Darwin platform.
val_size: Float between 0 and 1 to create a validation dataset from train dataset.
test_size: Float between 0 and 1 to create a test dataset from train validation.
+ preprocess: Preprocess to be used within this DataModule DataPipeline.
Returns:
TabularData: The constructed data module.
@@ -336,9 +297,7 @@ def from_df(
train_df, val_df, test_df = cls._split_dataframe(train_df, val_df, test_df, val_size, test_size)
- preprocess_cls = preprocess_cls or cls.preprocess_cls
-
- preprocess_state = preprocess_cls.generate_state(
+ preprocess = preprocess or cls.preprocess_cls(
train_df,
val_df,
test_df,
@@ -347,9 +306,7 @@ def from_df(
numerical_cols,
categorical_cols,
is_regression,
- preprocess_state=preprocess_state
)
- preprocess: Preprocess = preprocess_cls.from_state(preprocess_state)
return cls.from_load_data_inputs(
train_load_data_input=train_df,
diff --git a/flash/text/classification/data.py b/flash/text/classification/data.py
index b1ac5b8990..b83f2a5194 100644
--- a/flash/text/classification/data.py
+++ b/flash/text/classification/data.py
@@ -12,9 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import os
-from dataclasses import dataclass
from functools import partial
-from typing import Any, Callable, Dict, List, Mapping, Optional, Tuple, Type, Union
+from typing import Any, Callable, Dict, List, Optional, Tuple, Union
from datasets import DatasetDict, load_dataset
from pytorch_lightning.utilities.exceptions import MisconfigurationException
@@ -22,33 +21,29 @@
from transformers import AutoTokenizer, default_data_collator
from transformers.modeling_outputs import SequenceClassifierOutput
-from flash.core.classification import ClassificationPostprocess
+from flash.core.classification import ClassificationState
from flash.data.auto_dataset import AutoDataset
from flash.data.data_module import DataModule
-from flash.data.process import Preprocess, PreprocessState
-
-
-@dataclass(unsafe_hash=True, frozen=True)
-class TextClassificationState(PreprocessState):
- label_to_class_mapping: Dict[str, int]
+from flash.data.process import Postprocess, Preprocess
class TextClassificationPreprocess(Preprocess):
def __init__(
self,
- tokenizer: AutoTokenizer,
input: str,
+ backbone: str,
max_length: int,
target: str,
filetype: str,
- label_to_class_mapping: Dict[str, int],
+ train_file: Optional[str],
+ label_to_class_mapping: Optional[Dict[str, int]],
):
"""
This class contains the preprocessing logic for text classification
Args:
- tokenizer: Hugging Face Tokenizer.
+ # tokenizer: Hugging Face Tokenizer. # TODO: Add back a tokenizer argument and make backbone optional?
input: The field storing the text to be classified.
max_length: Maximum number of tokens within a single sentence.
target: The field storing the class id of the associated text.
@@ -61,7 +56,16 @@ def __init__(
"""
super().__init__()
- self.tokenizer = tokenizer
+
+ if label_to_class_mapping is None:
+ if train_file is not None:
+ label_to_class_mapping = self.get_label_to_class_mapping(train_file, target, filetype)
+ else:
+ raise MisconfigurationException(
+ "Either ``label_to_class_mapping`` or ``train_file`` needs to be provided"
+ )
+
+ self.tokenizer = AutoTokenizer.from_pretrained(backbone, use_fast=True)
self.input = input
self.filetype = filetype
self.max_length = max_length
@@ -77,9 +81,10 @@ def __init__(
padding="max_length"
)
- @property
- def state(self):
- return TextClassificationState(self.label_to_class_mapping)
+ class_to_label_mapping = ['CLASS_UNKNOWN'] * (max(self.label_to_class_mapping.values()) + 1)
+ for label, cls in self.label_to_class_mapping.items():
+ class_to_label_mapping[cls] = label
+ self.set_state(ClassificationState(class_to_label_mapping))
def per_batch_transform(self, batch: Any) -> Any:
if "labels" not in batch:
@@ -112,11 +117,11 @@ def _transform_label(self, ex: Dict[str, str]):
return ex
@staticmethod
- def generate_state(file: str, target: str, filetype: str) -> TextClassificationState:
+ def get_label_to_class_mapping(file: str, target: str, filetype: str) -> Dict[str, int]:
data_files = {'train': file}
dataset_dict = load_dataset(filetype, data_files=data_files)
label_to_class_mapping = {v: k for k, v in enumerate(list(sorted(list(set(dataset_dict['train'][target])))))}
- return TextClassificationState(label_to_class_mapping)
+ return label_to_class_mapping
def load_data(
self,
@@ -172,7 +177,7 @@ def predict_load_data(self, sample: Any, dataset: AutoDataset):
raise MisconfigurationException("Currently, we support only list of sentences")
-class TextClassificationPostProcess(ClassificationPostprocess):
+class TextClassificationPostProcess(Postprocess):
def per_batch_transform(self, batch: Any) -> Any:
if isinstance(batch, SequenceClassifierOutput):
@@ -182,52 +187,13 @@ def per_batch_transform(self, batch: Any) -> Any:
class TextClassificationData(DataModule):
"""Data Module for text classification tasks"""
+
preprocess_cls = TextClassificationPreprocess
postprocess_cls = TextClassificationPostProcess
- target: Optional[str] = None
-
- @property
- def preprocess_state(self) -> TextClassificationState:
- return self._preprocess.state
@property
def num_classes(self) -> int:
- return len(self.preprocess_state.label_to_class_mapping)
-
- @classmethod
- def instantiate_preprocess(
- cls,
- train_file: Optional[str],
- input: str,
- target: str,
- filetype: str,
- backbone: str,
- max_length: int,
- label_to_class_mapping: Optional[dict] = None,
- preprocess_state: Optional[TextClassificationState] = None,
- preprocess_cls: Optional[Type[Preprocess]] = None,
- ):
- if label_to_class_mapping is None:
- preprocess_cls = preprocess_cls or cls.preprocess_cls
- if train_file is not None:
- preprocess_state = preprocess_cls.generate_state(train_file, target, filetype)
- else:
- if preprocess_state is None:
- raise MisconfigurationException(
- "Either ``preprocess_state`` or ``train_file`` needs to be provided"
- )
- label_to_class_mapping = preprocess_state.label_to_class_mapping
-
- preprocess_cls = preprocess_cls or cls.preprocess_cls
-
- return preprocess_cls(
- AutoTokenizer.from_pretrained(backbone, use_fast=True),
- input,
- max_length,
- target,
- filetype,
- label_to_class_mapping,
- )
+ return len(self._preprocess.label_to_class_mapping)
@classmethod
def from_files(
@@ -244,8 +210,8 @@ def from_files(
label_to_class_mapping: Optional[dict] = None,
batch_size: int = 16,
num_workers: Optional[int] = None,
- preprocess_state: Optional[TextClassificationState] = None,
- preprocess_cls: Optional[Type[Preprocess]] = None,
+ preprocess: Optional[Preprocess] = None,
+ postprocess: Optional[Postprocess] = None,
) -> 'TextClassificationData':
"""Creates a TextClassificationData object from files.
@@ -273,18 +239,18 @@ def from_files(
cat_cols=["account_type"])
"""
- preprocess = cls.instantiate_preprocess(
- train_file,
+ preprocess = preprocess or cls.preprocess_cls(
input,
- target,
- filetype,
backbone,
max_length,
+ target,
+ filetype,
+ train_file,
label_to_class_mapping,
- preprocess_state,
- preprocess_cls,
)
+ postprocess = postprocess or cls.postprocess_cls()
+
return cls.from_load_data_inputs(
train_load_data_input=train_file,
val_load_data_input=val_file,
@@ -292,7 +258,8 @@ def from_files(
predict_load_data_input=predict_file,
batch_size=batch_size,
num_workers=num_workers,
- preprocess=preprocess
+ preprocess=preprocess,
+ postprocess=postprocess,
)
@classmethod
@@ -303,10 +270,11 @@ def from_file(
backbone="bert-base-cased",
filetype="csv",
max_length: int = 128,
- preprocess_state: Optional[TextClassificationState] = None,
label_to_class_mapping: Optional[dict] = None,
batch_size: int = 16,
num_workers: Optional[int] = None,
+ preprocess: Optional[Preprocess] = None,
+ postprocess: Optional[Postprocess] = None,
) -> 'TextClassificationData':
"""Creates a TextClassificationData object from files.
@@ -334,5 +302,6 @@ def from_file(
label_to_class_mapping=label_to_class_mapping,
batch_size=batch_size,
num_workers=num_workers,
- preprocess_state=preprocess_state,
+ preprocess=preprocess,
+ postprocess=postprocess,
)
diff --git a/flash/text/classification/model.py b/flash/text/classification/model.py
index 7fa4e926c7..56d85e8ce5 100644
--- a/flash/text/classification/model.py
+++ b/flash/text/classification/model.py
@@ -13,7 +13,7 @@
# limitations under the License.
import os
import warnings
-from typing import Callable, Mapping, Sequence, Type, Union
+from typing import Callable, Mapping, Optional, Sequence, Type, Union
import torch
from torchmetrics import Accuracy
@@ -21,6 +21,7 @@
from transformers.modeling_outputs import SequenceClassifierOutput
from flash.core.classification import ClassificationTask
+from flash.data.process import Serializer
class TextClassifier(ClassificationTask):
@@ -41,6 +42,7 @@ def __init__(
optimizer: Type[torch.optim.Optimizer] = torch.optim.Adam,
metrics: Union[Callable, Mapping, Sequence, None] = [Accuracy()],
learning_rate: float = 1e-3,
+ serializer: Optional[Serializer] = None,
):
self.save_hyperparameters()
@@ -56,6 +58,7 @@ def __init__(
optimizer=optimizer,
metrics=metrics,
learning_rate=learning_rate,
+ serializer=serializer,
)
self.model = BertForSequenceClassification.from_pretrained(backbone, num_labels=num_classes)
diff --git a/flash/text/seq2seq/core/data.py b/flash/text/seq2seq/core/data.py
index 9757c79516..a0f76544b3 100644
--- a/flash/text/seq2seq/core/data.py
+++ b/flash/text/seq2seq/core/data.py
@@ -13,7 +13,7 @@
# limitations under the License.
import os
from functools import partial
-from typing import Any, Callable, Dict, List, Optional, Type, Union
+from typing import Any, Callable, Dict, List, Optional, Union
import datasets
import torch
@@ -23,7 +23,7 @@
from transformers import AutoTokenizer, default_data_collator
from flash.data.data_module import DataModule
-from flash.data.process import Preprocess
+from flash.data.process import Postprocess, Preprocess
class Seq2SeqPreprocess(Preprocess):
@@ -121,45 +121,6 @@ class Seq2SeqData(DataModule):
preprocess_cls = Seq2SeqPreprocess
- @classmethod
- def instantiate_preprocess(
- cls,
- tokenizer: AutoTokenizer,
- input: str,
- filetype: str,
- target: str,
- max_source_length: int,
- max_target_length: int,
- padding: int,
- preprocess_cls: Optional[Type[Preprocess]] = None
- ) -> Preprocess:
- """
- This function is used to instantiate the ``Seq2SeqPreprocess`` preprocess.
-
- Args:
- tokenizer: Path to training data.
- input: The field storing the source translation text.
- filetype: ``csv`` or ``json`` File
- target: The field storing the target translation text.
- backbone: Tokenizer backbone to use, can use any HuggingFace tokenizer.
- max_source_length: Maximum length of the source text. Any text longer will be truncated.
- max_target_length: Maximum length of the target text. Any text longer will be truncated.
- padding: Padding strategy for batches. Default is pad to maximum length.
- preprocess_cls: Preprocess cls
- """
-
- preprocess_cls = preprocess_cls or cls.preprocess_cls
-
- return preprocess_cls(
- tokenizer=tokenizer,
- input=input,
- filetype=filetype,
- target=target,
- max_source_length=max_source_length,
- max_target_length=max_target_length,
- padding=padding,
- )
-
@classmethod
def from_files(
cls,
@@ -176,7 +137,8 @@ def from_files(
padding: Union[str, bool] = 'max_length',
batch_size: int = 32,
num_workers: Optional[int] = None,
- preprocess_cls: Optional[Type[Preprocess]] = None,
+ preprocess: Optional[Preprocess] = None,
+ postprocess: Optional[Postprocess] = None,
):
"""Creates a Seq2SeqData object from files.
Args:
@@ -204,7 +166,7 @@ def from_files(
cat_cols=["account_type"])
"""
tokenizer = AutoTokenizer.from_pretrained(backbone, use_fast=True)
- preprocess = cls.instantiate_preprocess(
+ preprocess = preprocess or cls.preprocess_cls(
tokenizer,
input,
filetype,
@@ -212,7 +174,6 @@ def from_files(
max_source_length,
max_target_length,
padding,
- preprocess_cls=preprocess_cls
)
return cls.from_load_data_inputs(
@@ -222,7 +183,8 @@ def from_files(
predict_load_data_input=predict_file,
batch_size=batch_size,
num_workers=num_workers,
- preprocess=preprocess
+ preprocess=preprocess,
+ postprocess=postprocess,
)
@classmethod
@@ -238,7 +200,8 @@ def from_file(
padding: Union[str, bool] = 'max_length',
batch_size: int = 32,
num_workers: Optional[int] = None,
- preprocess_cls: Optional[Type[Preprocess]] = None,
+ preprocess: Optional[Preprocess] = None,
+ postprocess: Optional[Postprocess] = None,
):
"""Creates a TextClassificationData object from files.
Args:
@@ -269,5 +232,6 @@ def from_file(
padding=padding,
batch_size=batch_size,
num_workers=num_workers,
- preprocess_cls=preprocess_cls,
+ preprocess=preprocess,
+ postprocess=postprocess,
)
diff --git a/flash/text/seq2seq/summarization/data.py b/flash/text/seq2seq/summarization/data.py
index b6ecfc05df..bcdd2a2ff6 100644
--- a/flash/text/seq2seq/summarization/data.py
+++ b/flash/text/seq2seq/summarization/data.py
@@ -39,13 +39,6 @@ class SummarizationData(Seq2SeqData):
preprocess_cls = Seq2SeqPreprocess
postprocess_cls = SummarizationPostprocess
- @classmethod
- def instantiate_postprocess(
- cls, tokenizer: AutoTokenizer, postprocess_cls: Optional[Type[Postprocess]] = None
- ) -> Postprocess:
- postprocess_cls = postprocess_cls or cls.postprocess_cls
- return postprocess_cls(tokenizer)
-
@classmethod
def from_files(
cls,
@@ -62,8 +55,8 @@ def from_files(
padding: Union[str, bool] = 'max_length',
batch_size: int = 16,
num_workers: Optional[int] = None,
- preprocess_cls: Optional[Type[Preprocess]] = None,
- postprocess_cls: Optional[Type[Postprocess]] = None,
+ preprocess: Optional[Preprocess] = None,
+ postprocess: Optional[Postprocess] = None,
):
"""Creates a SummarizationData object from files.
@@ -95,7 +88,8 @@ def from_files(
"""
tokenizer = AutoTokenizer.from_pretrained(backbone, use_fast=True)
- preprocess = cls.instantiate_preprocess(
+
+ preprocess = preprocess or cls.preprocess_cls(
tokenizer,
input,
filetype,
@@ -103,10 +97,9 @@ def from_files(
max_source_length,
max_target_length,
padding,
- preprocess_cls=preprocess_cls
)
- postprocess = cls.instantiate_postprocess(tokenizer, postprocess_cls=postprocess_cls)
+ postprocess = postprocess or cls.postprocess_cls(tokenizer)
return cls.from_load_data_inputs(
train_load_data_input=train_file,
@@ -132,6 +125,8 @@ def from_file(
padding: Union[str, bool] = 'longest',
batch_size: int = 16,
num_workers: Optional[int] = None,
+ preprocess: Optional[Preprocess] = None,
+ postprocess: Optional[Postprocess] = None,
):
"""Creates a SummarizationData object from files.
@@ -163,5 +158,7 @@ def from_file(
max_target_length=max_target_length,
padding=padding,
batch_size=batch_size,
- num_workers=num_workers
+ num_workers=num_workers,
+ preprocess=preprocess,
+ postprocess=postprocess,
)
diff --git a/flash/text/seq2seq/translation/data.py b/flash/text/seq2seq/translation/data.py
index 46bd9ebe8e..940bae7af8 100644
--- a/flash/text/seq2seq/translation/data.py
+++ b/flash/text/seq2seq/translation/data.py
@@ -11,9 +11,9 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-from typing import Optional, Type, Union
+from typing import Optional, Union
-from flash.data.process import Preprocess
+from flash.data.process import Postprocess, Preprocess
from flash.text.seq2seq.core.data import Seq2SeqData
@@ -36,7 +36,8 @@ def from_files(
padding: Union[str, bool] = 'max_length',
batch_size: int = 8,
num_workers: Optional[int] = None,
- preprocess_cls: Optional[Type[Preprocess]] = None
+ preprocess: Optional[Preprocess] = None,
+ postprocess: Optional[Postprocess] = None,
):
"""Creates a TranslateData object from files.
@@ -82,7 +83,8 @@ def from_files(
padding=padding,
batch_size=batch_size,
num_workers=num_workers,
- preprocess_cls=preprocess_cls
+ preprocess=preprocess,
+ postprocess=postprocess,
)
@classmethod
@@ -98,6 +100,8 @@ def from_file(
padding: Union[str, bool] = 'longest',
batch_size: int = 8,
num_workers: Optional[int] = None,
+ preprocess: Optional[Preprocess] = None,
+ postprocess: Optional[Postprocess] = None,
):
"""Creates a TranslationData object from files.
@@ -129,5 +133,7 @@ def from_file(
max_target_length=max_target_length,
padding=padding,
batch_size=batch_size,
- num_workers=num_workers
+ num_workers=num_workers,
+ preprocess=preprocess,
+ postprocess=postprocess,
)
diff --git a/flash/vision/classification/data.py b/flash/vision/classification/data.py
index db912b4aa6..fc017f0350 100644
--- a/flash/vision/classification/data.py
+++ b/flash/vision/classification/data.py
@@ -13,20 +13,21 @@
# limitations under the License.
import os
import pathlib
-from typing import Any, Callable, Dict, Iterable, List, Optional, Sequence, Tuple, Type, Union
+from typing import Any, Callable, Dict, Iterable, List, Optional, Sequence, Tuple, Union
import torch
import torchvision
from PIL import Image
-from pytorch_lightning.utilities.exceptions import MisconfigurationException
+from pytorch_lightning.trainer.states import RunningStage
from torch import nn
from torch.utils.data import Dataset
from torch.utils.data._utils.collate import default_collate
from torchvision.datasets.folder import has_file_allowed_extension, IMG_EXTENSIONS, make_dataset
+from flash.core.classification import ClassificationState
from flash.data.auto_dataset import AutoDataset
+from flash.data.callback import BaseDataFetcher
from flash.data.data_module import DataModule
-from flash.data.data_pipeline import DataPipeline
from flash.data.process import Preprocess
from flash.utils.imports import _KORNIA_AVAILABLE
@@ -40,6 +41,19 @@
class ImageClassificationPreprocess(Preprocess):
to_tensor = torchvision.transforms.ToTensor()
+ image_size = (196, 196)
+
+ def __init__(
+ self,
+ train_transform: Optional[Union[Dict[str, Callable]]] = None,
+ val_transform: Optional[Union[Dict[str, Callable]]] = None,
+ test_transform: Optional[Union[Dict[str, Callable]]] = None,
+ predict_transform: Optional[Union[Dict[str, Callable]]] = None,
+ ):
+ train_transform, val_transform, test_transform, predict_transform = self._resolve_transforms(
+ train_transform, val_transform, test_transform, predict_transform
+ )
+ super().__init__(train_transform, val_transform, test_transform, predict_transform)
@staticmethod
def _find_classes(dir: str) -> Tuple:
@@ -73,8 +87,77 @@ def _get_predicting_files(samples: Union[Sequence, str]) -> List[str]:
return files
+ def default_train_transforms(self):
+ image_size = self.image_size
+ if _KORNIA_AVAILABLE and not os.getenv("FLASH_TESTING", "0") == "1":
+ # Better approach as all transforms are applied on tensor directly
+ return {
+ "to_tensor_transform": torchvision.transforms.ToTensor(),
+ "post_tensor_transform": nn.Sequential(K.RandomResizedCrop(image_size), K.RandomHorizontalFlip()),
+ "per_batch_transform_on_device": nn.Sequential(
+ K.Normalize(torch.tensor([0.485, 0.456, 0.406]), torch.tensor([0.229, 0.224, 0.225])),
+ )
+ }
+ else:
+ from torchvision import transforms as T # noqa F811
+ return {
+ "pre_tensor_transform": nn.Sequential(T.RandomResizedCrop(image_size), T.RandomHorizontalFlip()),
+ "to_tensor_transform": torchvision.transforms.ToTensor(),
+ "post_tensor_transform": T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
+ }
+
+ def default_val_transforms(self):
+ image_size = self.image_size
+ if _KORNIA_AVAILABLE and not os.getenv("FLASH_TESTING", "0") == "1":
+ # Better approach as all transforms are applied on tensor directly
+ return {
+ "to_tensor_transform": torchvision.transforms.ToTensor(),
+ "post_tensor_transform": nn.Sequential(K.RandomResizedCrop(image_size)),
+ "per_batch_transform_on_device": nn.Sequential(
+ K.Normalize(torch.tensor([0.485, 0.456, 0.406]), torch.tensor([0.229, 0.224, 0.225])),
+ )
+ }
+ else:
+ from torchvision import transforms as T # noqa F811
+ return {
+ "pre_tensor_transform": T.Compose([T.RandomResizedCrop(image_size)]),
+ "to_tensor_transform": torchvision.transforms.ToTensor(),
+ "post_tensor_transform": T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
+ }
+
+ def _resolve_transforms(
+ self,
+ train_transform: Optional[Union[str, Dict]] = 'default',
+ val_transform: Optional[Union[str, Dict]] = 'default',
+ test_transform: Optional[Union[str, Dict]] = 'default',
+ predict_transform: Optional[Union[str, Dict]] = 'default',
+ ):
+
+ if not train_transform or train_transform == 'default':
+ train_transform = self.default_train_transforms()
+
+ if not val_transform or val_transform == 'default':
+ val_transform = self.default_val_transforms()
+
+ if not test_transform or test_transform == 'default':
+ test_transform = self.default_val_transforms()
+
+ if not predict_transform or predict_transform == 'default':
+ predict_transform = self.default_val_transforms()
+
+ return (
+ self._check_transforms(train_transform, RunningStage.TRAINING),
+ self._check_transforms(val_transform, RunningStage.VALIDATING),
+ self._check_transforms(test_transform, RunningStage.TESTING),
+ self._check_transforms(predict_transform, RunningStage.PREDICTING),
+ )
+
@classmethod
- def _load_data_dir(cls, data: Any, dataset: Optional[AutoDataset] = None) -> List[str]:
+ def _load_data_dir(
+ cls,
+ data: Any,
+ dataset: Optional[AutoDataset] = None,
+ ) -> Tuple[Optional[List[str]], List[Tuple[str, int]]]:
if isinstance(data, list):
dataset.num_classes = len(data)
out = []
@@ -85,11 +168,11 @@ def _load_data_dir(cls, data: Any, dataset: Optional[AutoDataset] = None) -> Lis
out.append([os.path.join(p, f), label])
elif os.path.isfile(p) and has_file_allowed_extension(p, IMG_EXTENSIONS):
out.append([p, label])
- return out
+ return None, out
else:
classes, class_to_idx = cls._find_classes(data)
dataset.num_classes = len(classes)
- return make_dataset(data, class_to_idx, IMG_EXTENSIONS, None)
+ return classes, make_dataset(data, class_to_idx, IMG_EXTENSIONS, None)
@classmethod
def _load_data_files_labels(cls, data: Any, dataset: Optional[AutoDataset] = None) -> Any:
@@ -103,11 +186,13 @@ def _load_data_files_labels(cls, data: Any, dataset: Optional[AutoDataset] = Non
return data
- @classmethod
- def load_data(cls, data: Any, dataset: Optional[AutoDataset] = None) -> Iterable:
+ def load_data(self, data: Any, dataset: Optional[AutoDataset] = None) -> Iterable:
if isinstance(data, (str, pathlib.Path, list)):
- return cls._load_data_dir(data=data, dataset=dataset)
- return cls._load_data_files_labels(data=data, dataset=dataset)
+ classes, data = self._load_data_dir(data=data, dataset=dataset)
+ state = ClassificationState(classes)
+ self.set_state(state)
+ return data
+ return self._load_data_files_labels(data=data, dataset=dataset)
@staticmethod
def load_sample(sample) -> Union[Image.Image, torch.Tensor, Tuple[Image.Image, torch.Tensor]]:
@@ -188,9 +273,6 @@ def per_batch_transform_on_device(self, sample: Any) -> Any:
class ImageClassificationData(DataModule):
"""Data module for image classification tasks."""
- preprocess_cls = ImageClassificationPreprocess
- image_size = (196, 196)
-
def __init__(
self,
train_dataset: Optional[Dataset] = None,
@@ -204,7 +286,7 @@ def __init__(
val_split: Optional[Union[float, int]] = None,
test_split: Optional[Union[float, int]] = None,
**kwargs,
- ) -> 'ImageClassificationData':
+ ) -> None:
"""Creates a ImageClassificationData object from lists of image filepaths and labels"""
if train_dataset is not None and train_split is not None or val_split is not None or test_split is not None:
@@ -236,60 +318,6 @@ def __init__(
if self._predict_ds:
self.set_dataset_attribute(self._predict_ds, 'num_classes', self.num_classes)
- @staticmethod
- def _check_transforms(transform: Dict[str, Union[nn.Module, Callable]]) -> Dict[str, Union[nn.Module, Callable]]:
- if transform and not isinstance(transform, Dict):
- raise MisconfigurationException(
- "Transform should be a dict. "
- f"Here are the available keys for your transforms: {DataPipeline.PREPROCESS_FUNCS}."
- )
- if "per_batch_transform" in transform and "per_sample_transform_on_device" in transform:
- raise MisconfigurationException(
- f'{transform}: `per_batch_transform` and `per_sample_transform_on_device` '
- f'are mutually exclusive.'
- )
- return transform
-
- @staticmethod
- def default_train_transforms():
- image_size = ImageClassificationData.image_size
- if _KORNIA_AVAILABLE and not os.getenv("FLASH_TESTING", "0") == "1":
- # Better approach as all transforms are applied on tensor directly
- return {
- "to_tensor_transform": torchvision.transforms.ToTensor(),
- "post_tensor_transform": nn.Sequential(K.RandomResizedCrop(image_size), K.RandomHorizontalFlip()),
- "per_batch_transform_on_device": nn.Sequential(
- K.Normalize(torch.tensor([0.485, 0.456, 0.406]), torch.tensor([0.229, 0.224, 0.225])),
- )
- }
- else:
- from torchvision import transforms as T # noqa F811
- return {
- "pre_tensor_transform": nn.Sequential(T.RandomResizedCrop(image_size), T.RandomHorizontalFlip()),
- "to_tensor_transform": torchvision.transforms.ToTensor(),
- "post_tensor_transform": T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
- }
-
- @staticmethod
- def default_val_transforms():
- image_size = ImageClassificationData.image_size
- if _KORNIA_AVAILABLE and not os.getenv("FLASH_TESTING", "0") == "1":
- # Better approach as all transforms are applied on tensor directly
- return {
- "to_tensor_transform": torchvision.transforms.ToTensor(),
- "post_tensor_transform": nn.Sequential(K.RandomResizedCrop(image_size)),
- "per_batch_transform_on_device": nn.Sequential(
- K.Normalize(torch.tensor([0.485, 0.456, 0.406]), torch.tensor([0.229, 0.224, 0.225])),
- )
- }
- else:
- from torchvision import transforms as T # noqa F811
- return {
- "pre_tensor_transform": T.Compose([T.RandomResizedCrop(image_size)]),
- "to_tensor_transform": torchvision.transforms.ToTensor(),
- "post_tensor_transform": T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
- }
-
@property
def num_classes(self) -> int:
if self._num_classes is None:
@@ -305,72 +333,6 @@ def _get_num_classes(self, dataset: torch.utils.data.Dataset):
return num_classes
- @classmethod
- def instantiate_preprocess(
- cls,
- train_transform: Dict[str, Union[nn.Module, Callable]],
- val_transform: Dict[str, Union[nn.Module, Callable]],
- test_transform: Dict[str, Union[nn.Module, Callable]],
- predict_transform: Dict[str, Union[nn.Module, Callable]],
- preprocess_cls: Type[Preprocess] = None,
- ) -> Preprocess:
- """
- This function is used to instantiate ImageClassificationData preprocess object.
-
- Args:
- train_transform: Train transforms for images.
- val_transform: Validation transforms for images.
- test_transform: Test transforms for images.
- predict_transform: Predict transforms for images.
- preprocess_cls: User provided preprocess_cls.
-
- Example::
-
- train_transform = {
- "per_sample_transform": T.Compose([
- T.RandomResizedCrop(224),
- T.RandomHorizontalFlip(),
- T.ToTensor(),
- T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
- ]),
- "per_batch_transform_on_device": nn.Sequential(K.RandomAffine(360), K.ColorJitter(0.2, 0.3, 0.2, 0.3))
- }
-
- """
- train_transform, val_transform, test_transform, predict_transform = cls._resolve_transforms(
- train_transform, val_transform, test_transform, predict_transform
- )
-
- preprocess_cls = preprocess_cls or cls.preprocess_cls
- preprocess: Preprocess = preprocess_cls(train_transform, val_transform, test_transform, predict_transform)
- return preprocess
-
- @classmethod
- def _resolve_transforms(
- cls,
- train_transform: Optional[Union[str, Dict]] = 'default',
- val_transform: Optional[Union[str, Dict]] = 'default',
- test_transform: Optional[Union[str, Dict]] = 'default',
- predict_transform: Optional[Union[str, Dict]] = 'default',
- ):
-
- if not train_transform or train_transform == 'default':
- train_transform = cls.default_train_transforms()
-
- if not val_transform or val_transform == 'default':
- val_transform = cls.default_val_transforms()
-
- if not test_transform or test_transform == 'default':
- test_transform = cls.default_val_transforms()
-
- if not predict_transform or predict_transform == 'default':
- predict_transform = cls.default_val_transforms()
-
- return (
- cls._check_transforms(train_transform), cls._check_transforms(val_transform),
- cls._check_transforms(test_transform), cls._check_transforms(predict_transform)
- )
-
@classmethod
def from_folders(
cls,
@@ -384,7 +346,8 @@ def from_folders(
predict_transform: Optional[Union[str, Dict]] = 'default',
batch_size: int = 4,
num_workers: Optional[int] = None,
- preprocess_cls: Optional[Type[Preprocess]] = None,
+ data_fetcher: BaseDataFetcher = None,
+ preprocess: Optional[Preprocess] = None,
**kwargs,
) -> 'DataModule':
"""
@@ -418,12 +381,11 @@ def from_folders(
>>> img_data = ImageClassificationData.from_folders("train/") # doctest: +SKIP
"""
- preprocess = cls.instantiate_preprocess(
+ preprocess = preprocess or ImageClassificationPreprocess(
train_transform,
val_transform,
test_transform,
predict_transform,
- preprocess_cls=preprocess_cls,
)
return cls.from_load_data_inputs(
@@ -433,6 +395,7 @@ def from_folders(
predict_load_data_input=predict_folder,
batch_size=batch_size,
num_workers=num_workers,
+ data_fetcher=data_fetcher,
preprocess=preprocess,
**kwargs,
)
@@ -454,7 +417,8 @@ def from_filepaths(
batch_size: int = 64,
num_workers: Optional[int] = None,
seed: Optional[int] = 42,
- preprocess_cls: Optional[Type[Preprocess]] = None,
+ data_fetcher: BaseDataFetcher = None,
+ preprocess: Optional[Preprocess] = None,
**kwargs,
) -> 'ImageClassificationData':
"""
@@ -475,10 +439,14 @@ def from_filepaths(
val_labels: Sequence of labels for validation dataset. Defaults to ``None``.
test_filepaths: String or sequence of file paths for test dataset. Defaults to ``None``.
test_labels: Sequence of labels for test dataset. Defaults to ``None``.
- train_transform: Transforms for training dataset. Defaults to ``default``,
- which loads imagenet transforms.
- val_transform: Transforms for validation and testing dataset.
- Defaults to ``default``, which loads imagenet transforms.
+ train_transform: Image transform to use for the train set. Defaults to ``default``, which loads imagenet
+ transforms.
+ val_transform: Image transform to use for the validation set. Defaults to ``default``, which loads
+ imagenet transforms.
+ test_transform: Image transform to use for the test set. Defaults to ``default``, which loads imagenet
+ transforms.
+ predict_transform: Image transform to use for the predict set. Defaults to ``default``, which loads imagenet
+ transforms.
batch_size: The batchsize to use for parallel loading. Defaults to ``64``.
num_workers: The number of workers to use for parallelized loading.
Defaults to ``None`` which equals the number of available CPU threads.
@@ -507,12 +475,11 @@ def from_filepaths(
else:
test_filepaths = [test_filepaths]
- preprocess = cls.instantiate_preprocess(
+ preprocess = preprocess or ImageClassificationPreprocess(
train_transform,
val_transform,
test_transform,
predict_transform,
- preprocess_cls=preprocess_cls,
)
return cls.from_load_data_inputs(
@@ -522,6 +489,7 @@ def from_filepaths(
predict_load_data_input=predict_filepaths,
batch_size=batch_size,
num_workers=num_workers,
+ data_fetcher=data_fetcher,
preprocess=preprocess,
seed=seed,
**kwargs
diff --git a/flash/vision/classification/model.py b/flash/vision/classification/model.py
index 5b6d9dca30..90cfc5ae15 100644
--- a/flash/vision/classification/model.py
+++ b/flash/vision/classification/model.py
@@ -12,16 +12,18 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from types import FunctionType
-from typing import Any, Callable, Dict, Mapping, Optional, Sequence, Tuple, Type, Union
+from typing import Callable, Dict, Mapping, Optional, Sequence, Tuple, Type, Union
import torch
from torch import nn
from torch.nn import functional as F
from torchmetrics import Accuracy
-from flash.core.classification import ClassificationTask
+from flash.core.classification import Classes, ClassificationTask
from flash.core.registry import FlashRegistry
+from flash.data.process import Preprocess, Serializer
from flash.vision.backbones import IMAGE_CLASSIFIER_BACKBONES
+from flash.vision.classification.data import ImageClassificationPreprocess
def binary_cross_entropy_with_logits(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
@@ -79,6 +81,7 @@ def __init__(
metrics: Optional[Union[Callable, Mapping, Sequence, None]] = None,
learning_rate: float = 1e-3,
multi_label: bool = False,
+ serializer: Optional[Union[Serializer, Mapping[str, Serializer]]] = None,
):
if metrics is None:
@@ -93,7 +96,7 @@ def __init__(
optimizer=optimizer,
metrics=metrics,
learning_rate=learning_rate,
- postprocess=self.postprocess_cls(multi_label)
+ serializer=serializer or Classes(multi_label=multi_label),
)
self.save_hyperparameters()
diff --git a/flash/vision/detection/data.py b/flash/vision/detection/data.py
index d08ac6cdef..6a56ff0093 100644
--- a/flash/vision/detection/data.py
+++ b/flash/vision/detection/data.py
@@ -180,19 +180,6 @@ class ObjectDetectionData(DataModule):
preprocess_cls = ObjectDetectionPreprocess
- @classmethod
- def instantiate_preprocess(
- cls,
- train_transform: Optional[Dict[str, Module]] = None,
- val_transform: Optional[Dict[str, Module]] = None,
- test_transform: Optional[Dict[str, Module]] = None,
- predict_transform: Optional[Dict[str, Module]] = None,
- preprocess_cls: Type[Preprocess] = None,
- ) -> Preprocess:
-
- preprocess_cls = preprocess_cls or cls.preprocess_cls
- return preprocess_cls(train_transform, val_transform, test_transform, predict_transform)
-
@classmethod
def from_coco(
cls,
@@ -208,12 +195,14 @@ def from_coco(
predict_transform: Optional[Dict[str, Module]] = None,
batch_size: int = 4,
num_workers: Optional[int] = None,
- preprocess_cls: Type[Preprocess] = None,
+ preprocess: Preprocess = None,
**kwargs
):
-
- preprocess = cls.instantiate_preprocess(
- train_transform, val_transform, predict_transform, predict_transform, preprocess_cls=preprocess_cls
+ preprocess = preprocess or cls.preprocess_cls(
+ train_transform,
+ val_transform,
+ test_transform,
+ predict_transform,
)
return cls.from_load_data_inputs(
diff --git a/flash/vision/embedding/model.py b/flash/vision/embedding/model.py
index f9fc3d85e2..f43dabcfaa 100644
--- a/flash/vision/embedding/model.py
+++ b/flash/vision/embedding/model.py
@@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-from typing import Any, Callable, Mapping, Optional, Sequence, Type, Union
+from typing import Callable, Mapping, Optional, Sequence, Type, Union
import torch
from pytorch_lightning.utilities.distributed import rank_zero_warn
@@ -66,9 +66,7 @@ def __init__(
optimizer=optimizer,
metrics=metrics,
learning_rate=learning_rate,
- preprocess=ImageClassificationPreprocess(
- predict_transform=ImageClassificationData.default_val_transforms(),
- )
+ preprocess=ImageClassificationPreprocess()
)
self.save_hyperparameters()
diff --git a/flash_examples/finetuning/image_classification.py b/flash_examples/finetuning/image_classification.py
index de19d76d05..4a93ec1785 100644
--- a/flash_examples/finetuning/image_classification.py
+++ b/flash_examples/finetuning/image_classification.py
@@ -16,6 +16,7 @@
import flash
from flash import Trainer
+from flash.core.classification import Labels
from flash.core.finetuning import FreezeUnfreeze
from flash.data.utils import download_data
from flash.vision import ImageClassificationData, ImageClassifier
@@ -56,6 +57,10 @@ def fn_resnet(pretrained: bool = True):
trainer.finetune(model, datamodule=datamodule, strategy=FreezeUnfreeze(unfreeze_epoch=1))
# 7a. Predict what's on a few images! ants or bees?
+
+# Serialize predictions as labels, automatically inferred from the training data in part 2.
+model.serializer = Labels()
+
predictions = model.predict([
"data/hymenoptera_data/val/bees/65038344_52a45d090d.jpg",
"data/hymenoptera_data/val/bees/590318879_68cf112861.jpg",
@@ -64,7 +69,10 @@ def fn_resnet(pretrained: bool = True):
print(predictions)
-datamodule = ImageClassificationData.from_folders(predict_folder="data/hymenoptera_data/predict/")
+datamodule = ImageClassificationData.from_folders(
+ predict_folder="data/hymenoptera_data/predict/",
+ preprocess=model.preprocess,
+)
# 7b. Or generate predictions with a whole folder!
predictions = Trainer().predict(model, datamodule=datamodule)
diff --git a/flash_examples/finetuning/image_classification_multi_label.py b/flash_examples/finetuning/image_classification_multi_label.py
new file mode 100644
index 0000000000..8bc558c04f
--- /dev/null
+++ b/flash_examples/finetuning/image_classification_multi_label.py
@@ -0,0 +1,105 @@
+# Copyright The PyTorch Lightning team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import os
+from typing import List, Tuple
+
+import pandas as pd
+import torch
+
+import flash
+from flash.core.classification import Labels
+from flash.core.finetuning import FreezeUnfreeze
+from flash.data.utils import download_data
+from flash.vision import ImageClassificationData, ImageClassifier
+from flash.vision.classification.data import ImageClassificationPreprocess
+
+# 1. Download the data
+# This is a subset of the movie poster genre prediction data set from the paper
+# “Movie Genre Classification based on Poster Images with Deep Neural Networks” by Wei-Ta Chu and Hung-Jui Guo.
+# Please consider citing their paper if you use it. More here: https://www.cs.ccu.edu.tw/~wtchu/projects/MoviePoster/
+download_data("https://pl-flash-data.s3.amazonaws.com/movie_posters.zip", "data/")
+
+# 2. Load the data
+genres = [
+ "Action", "Adventure", "Animation", "Biography", "Comedy", "Crime", "Documentary", "Drama", "Family", "Fantasy",
+ "History", "Horror", "Music", "Musical", "Mystery", "N/A", "News", "Reality-TV", "Romance", "Sci-Fi", "Short",
+ "Sport", "Thriller", "War", "Western"
+]
+
+
+def load_data(data: str, root: str = 'data/movie_posters') -> Tuple[List[str], List[List[int]]]:
+ metadata = pd.read_csv(os.path.join(root, data, "metadata.csv"))
+
+ images = []
+ labels = []
+ for _, row in metadata.iterrows():
+ images.append(os.path.join(root, data, row['Id'] + ".jpg"))
+ labels.append([int(row[genre]) for genre in genres])
+
+ return images, labels
+
+
+ImageClassificationPreprocess.image_size = (128, 128)
+
+train_filepaths, train_labels = load_data('train')
+val_filepaths, val_labels = load_data('val')
+test_filepaths, test_labels = load_data('test')
+
+datamodule = ImageClassificationData.from_filepaths(
+ train_filepaths=train_filepaths,
+ train_labels=train_labels,
+ val_filepaths=val_filepaths,
+ val_labels=val_labels,
+ test_filepaths=test_filepaths,
+ test_labels=test_labels,
+ preprocess=ImageClassificationPreprocess(),
+)
+
+# 3. Build the model
+model = ImageClassifier(
+ backbone="resnet18",
+ num_classes=len(genres),
+ multi_label=True,
+)
+
+# 4. Create the trainer.
+trainer = flash.Trainer(max_epochs=1, limit_train_batches=1, limit_val_batches=1)
+
+# 5. Train the model
+trainer.finetune(model, datamodule=datamodule, strategy=FreezeUnfreeze(unfreeze_epoch=1))
+
+# 6a. Predict what's on a few images!
+
+# Serialize predictions as labels.
+model.serializer = Labels(genres, multi_label=True)
+
+predictions = model.predict([
+ "data/movie_posters/val/tt0361500.jpg",
+ "data/movie_posters/val/tt0361748.jpg",
+ "data/movie_posters/val/tt0362478.jpg",
+])
+
+print(predictions)
+
+datamodule = ImageClassificationData.from_folders(
+ predict_folder="data/movie_posters/predict/",
+ preprocess=model.preprocess,
+)
+
+# 6b. Or generate predictions with a whole folder!
+predictions = trainer.predict(model, datamodule=datamodule)
+print(predictions)
+
+# 7. Save it!
+trainer.save_checkpoint("image_classification_multi_label_model.pt")
diff --git a/flash_examples/finetuning/text_classification.py b/flash_examples/finetuning/text_classification.py
index f5977ae113..efbcac71ea 100644
--- a/flash_examples/finetuning/text_classification.py
+++ b/flash_examples/finetuning/text_classification.py
@@ -25,7 +25,7 @@
test_file="data/imdb/test.csv",
input="review",
target="sentiment",
- batch_size=16
+ batch_size=16,
)
# 3. Build the model
diff --git a/flash_examples/predict/image_classification.py b/flash_examples/predict/image_classification.py
index fda4a5c71a..fe697b2963 100644
--- a/flash_examples/predict/image_classification.py
+++ b/flash_examples/predict/image_classification.py
@@ -30,6 +30,10 @@
print(predictions)
# 3b. Or generate predictions with a whole folder!
-datamodule = ImageClassificationData.from_folders(predict_folder="data/hymenoptera_data/predict/")
+datamodule = ImageClassificationData.from_folders(
+ predict_folder="data/hymenoptera_data/predict/",
+ preprocess=model.preprocess,
+)
+
predictions = Trainer().predict(model, datamodule=datamodule)
print(predictions)
diff --git a/flash_examples/predict/image_classification_multi_label.py b/flash_examples/predict/image_classification_multi_label.py
new file mode 100644
index 0000000000..64c8ff6faf
--- /dev/null
+++ b/flash_examples/predict/image_classification_multi_label.py
@@ -0,0 +1,65 @@
+# Copyright The PyTorch Lightning team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import Any
+
+import torchvision.transforms.functional as T
+from torchvision.utils import make_grid
+
+from flash import Trainer
+from flash.data.base_viz import BaseVisualization
+from flash.data.utils import download_data
+from flash.vision import ImageClassificationData, ImageClassifier
+
+# 1. Download the data
+# This is a subset of the movie poster genre prediction data set from the paper
+# “Movie Genre Classification based on Poster Images with Deep Neural Networks” by Wei-Ta Chu and Hung-Jui Guo.
+# Please consider citing their paper if you use it. More here: https://www.cs.ccu.edu.tw/~wtchu/projects/MoviePoster/
+download_data("https://pl-flash-data.s3.amazonaws.com/movie_posters.zip", "data/")
+
+
+# 2. Define our custom visualisation
+class CustomViz(BaseVisualization):
+
+ def show_per_batch_transform(self, batch: Any, _) -> None:
+ images = batch[0]
+ image = make_grid(images, nrow=2)
+ image = T.to_pil_image(image, 'RGB')
+ image.show()
+
+
+# 3. Load the model from a checkpoint
+model = ImageClassifier.load_from_checkpoint(
+ "https://flash-weights.s3.amazonaws.com/image_classification_multi_label_model.pt",
+)
+
+# 4a. Predict the genres of a few movie posters!
+predictions = model.predict([
+ "data/movie_posters/val/tt0361500.jpg",
+ "data/movie_posters/val/tt0361748.jpg",
+ "data/movie_posters/val/tt0362478.jpg",
+])
+print(predictions)
+
+# 4b. Or generate predictions with a whole folder!
+datamodule = ImageClassificationData.from_folders(
+ predict_folder="data/movie_posters/predict/",
+ data_fetcher=CustomViz(),
+ preprocess=model.preprocess,
+)
+
+predictions = Trainer().predict(model, datamodule=datamodule)
+print(predictions)
+
+# 5. Show some data!
+datamodule.show_predict_batch()
diff --git a/flash_examples/predict/tabular_classification.py b/flash_examples/predict/tabular_classification.py
index 71094a5e9e..a874d1f99f 100644
--- a/flash_examples/predict/tabular_classification.py
+++ b/flash_examples/predict/tabular_classification.py
@@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+from flash.core.classification import Labels
from flash.data.utils import download_data
from flash.tabular import TabularClassifier
@@ -20,6 +21,8 @@
# 2. Load the model from a checkpoint
model = TabularClassifier.load_from_checkpoint("https://flash-weights.s3.amazonaws.com/tabular_classification_model.pt")
+model.serializer = Labels(['Did not survive', 'Survived'])
+
# 3. Generate predictions from a sheet file! Who would survive?
predictions = model.predict("data/titanic/titanic.csv")
print(predictions)
diff --git a/flash_examples/predict/text_classification.py b/flash_examples/predict/text_classification.py
index 00029e3fae..e81fd17c52 100644
--- a/flash_examples/predict/text_classification.py
+++ b/flash_examples/predict/text_classification.py
@@ -13,6 +13,7 @@
# limitations under the License.
from pytorch_lightning import Trainer
+from flash.core.classification import Labels
from flash.data.utils import download_data
from flash.text import TextClassificationData, TextClassifier
@@ -22,12 +23,14 @@
# 2. Load the model from a checkpoint
model = TextClassifier.load_from_checkpoint("https://flash-weights.s3.amazonaws.com/text_classification_model.pt")
+model.serializer = Labels()
+
# 2a. Classify a few sentences! How was the movie?
predictions = model.predict([
"Turgid dialogue, feeble characterization - Harvey Keitel a judge?.",
"The worst movie in the history of cinema.",
- "I come from Bulgaria where it 's almost impossible to have a tornado."
- "Very, very afraid"
+ "I come from Bulgaria where it 's almost impossible to have a tornado.",
+ "Very, very afraid.",
"This guy has done a great job with this movie!",
])
print(predictions)
@@ -37,7 +40,7 @@
predict_file="data/imdb/predict.csv",
input="review",
# use the same data pre-processing values we used to predict in 2a
- preprocess_state=model.data_pipeline.preprocess_state,
+ preprocess=model.preprocess,
)
predictions = Trainer().predict(model, datamodule=datamodule)
print(predictions)
diff --git a/flash_notebooks/custom_task_tutorial.ipynb b/flash_notebooks/custom_task_tutorial.ipynb
index 5c5c947405..93b214ff62 100644
--- a/flash_notebooks/custom_task_tutorial.ipynb
+++ b/flash_notebooks/custom_task_tutorial.ipynb
@@ -25,7 +25,7 @@
},
{
"cell_type": "code",
- "execution_count": 1,
+ "execution_count": null,
"metadata": {},
"outputs": [],
"source": [
@@ -36,7 +36,7 @@
},
{
"cell_type": "code",
- "execution_count": 6,
+ "execution_count": null,
"metadata": {},
"outputs": [],
"source": [
@@ -56,27 +56,9 @@
},
{
"cell_type": "code",
- "execution_count": 7,
+ "execution_count": null,
"metadata": {},
- "outputs": [
- {
- "name": "stderr",
- "output_type": "stream",
- "text": [
- "Global seed set to 42\n"
- ]
- },
- {
- "data": {
- "text/plain": [
- "42"
- ]
- },
- "execution_count": 7,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
+ "outputs": [],
"source": [
"seed_everything(42)"
]
@@ -92,7 +74,7 @@
},
{
"cell_type": "code",
- "execution_count": 8,
+ "execution_count": null,
"metadata": {},
"outputs": [],
"source": [
@@ -112,14 +94,12 @@
},
{
"cell_type": "code",
- "execution_count": 9,
+ "execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"class LinearRegression(flash.Task):\n",
"\n",
- " postprocess_cls = CustomPostprocess\n",
- "\n",
" def __init__(self, num_inputs, learning_rate=0.001, metrics=None):\n",
" # what kind of model do we want?\n",
" model = nn.Linear(num_inputs, 1)\n",
@@ -136,6 +116,7 @@
" optimizer=optimizer,\n",
" metrics=metrics,\n",
" learning_rate=learning_rate,\n",
+ " postprocess=CustomPostprocess(),\n",
" )\n",
"\n",
" def forward(self, x):\n",
@@ -167,7 +148,7 @@
},
{
"cell_type": "code",
- "execution_count": 10,
+ "execution_count": null,
"metadata": {},
"outputs": [],
"source": [
@@ -193,18 +174,16 @@
},
{
"cell_type": "code",
- "execution_count": 11,
+ "execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"class SklearnDataModule(flash.DataModule):\n",
"\n",
- " preprocess_cls = NumpyRegressionPreprocess\n",
- "\n",
" @classmethod\n",
" def from_dataset(cls, x: np.ndarray, y: np.ndarray, batch_size: int = 64, num_workers: int = 0):\n",
"\n",
- " preprocess = cls.preprocess_cls()\n",
+ " preprocess = NumpyRegressionPreprocess()\n",
"\n",
" x_train, x_test, y_train, y_test = train_test_split(x, y, test_size=.20, random_state=0)\n",
"\n",
@@ -387,7 +366,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
- "version": "3.6.9"
+ "version": "3.8.5"
}
},
"nbformat": 4,
diff --git a/flash_notebooks/tabular_classification.ipynb b/flash_notebooks/tabular_classification.ipynb
index 3932ba7c09..8cbb8470a4 100644
--- a/flash_notebooks/tabular_classification.ipynb
+++ b/flash_notebooks/tabular_classification.ipynb
@@ -2,7 +2,7 @@
"cells": [
{
"cell_type": "markdown",
- "id": "upper-receipt",
+ "id": "twelve-miracle",
"metadata": {},
"source": [
"\n",
@@ -12,7 +12,7 @@
},
{
"cell_type": "markdown",
- "id": "herbal-commissioner",
+ "id": "genuine-elephant",
"metadata": {},
"source": [
"In this notebook, we'll go over the basics of lightning Flash by training a TabularClassifier on [Titanic Dataset](https://www.kaggle.com/c/titanic).\n",
@@ -26,7 +26,7 @@
},
{
"cell_type": "markdown",
- "id": "married-failing",
+ "id": "sorted-dancing",
"metadata": {},
"source": [
"# Training"
@@ -35,7 +35,7 @@
{
"cell_type": "code",
"execution_count": null,
- "id": "innocent-bhutan",
+ "id": "caring-appreciation",
"metadata": {},
"outputs": [],
"source": [
@@ -46,7 +46,7 @@
{
"cell_type": "code",
"execution_count": null,
- "id": "expensive-chassis",
+ "id": "sexual-diabetes",
"metadata": {},
"outputs": [],
"source": [
@@ -59,7 +59,7 @@
},
{
"cell_type": "markdown",
- "id": "virtual-supplier",
+ "id": "boxed-harvest",
"metadata": {},
"source": [
"### 1. Download the data\n",
@@ -69,7 +69,7 @@
{
"cell_type": "code",
"execution_count": null,
- "id": "documented-humanitarian",
+ "id": "backed-render",
"metadata": {},
"outputs": [],
"source": [
@@ -78,7 +78,7 @@
},
{
"cell_type": "markdown",
- "id": "german-grill",
+ "id": "young-arthritis",
"metadata": {},
"source": [
"### 2. Load the data\n",
@@ -90,7 +90,7 @@
{
"cell_type": "code",
"execution_count": null,
- "id": "exempt-cholesterol",
+ "id": "ultimate-bunny",
"metadata": {},
"outputs": [],
"source": [
@@ -106,7 +106,7 @@
},
{
"cell_type": "markdown",
- "id": "mineral-remove",
+ "id": "brutal-hypothesis",
"metadata": {},
"source": [
"### 3. Build the model\n",
@@ -117,7 +117,7 @@
{
"cell_type": "code",
"execution_count": null,
- "id": "functioning-compilation",
+ "id": "practical-perry",
"metadata": {},
"outputs": [],
"source": [
@@ -126,7 +126,7 @@
},
{
"cell_type": "markdown",
- "id": "practical-highland",
+ "id": "dietary-bowling",
"metadata": {},
"source": [
"### 4. Create the trainer. Run 10 times on data"
@@ -135,7 +135,7 @@
{
"cell_type": "code",
"execution_count": null,
- "id": "pretty-layer",
+ "id": "integral-interface",
"metadata": {},
"outputs": [],
"source": [
@@ -144,7 +144,7 @@
},
{
"cell_type": "markdown",
- "id": "proprietary-mitchell",
+ "id": "liable-remains",
"metadata": {},
"source": [
"### 5. Train the model"
@@ -153,7 +153,7 @@
{
"cell_type": "code",
"execution_count": null,
- "id": "advised-contact",
+ "id": "controversial-newcastle",
"metadata": {},
"outputs": [],
"source": [
@@ -162,7 +162,7 @@
},
{
"cell_type": "markdown",
- "id": "parental-norwegian",
+ "id": "fluid-franchise",
"metadata": {},
"source": [
"### 6. Test model"
@@ -171,7 +171,7 @@
{
"cell_type": "code",
"execution_count": null,
- "id": "protective-scholar",
+ "id": "therapeutic-bidder",
"metadata": {},
"outputs": [],
"source": [
@@ -180,7 +180,7 @@
},
{
"cell_type": "markdown",
- "id": "operating-incident",
+ "id": "genuine-pilot",
"metadata": {},
"source": [
"### 7. Save it!"
@@ -189,7 +189,7 @@
{
"cell_type": "code",
"execution_count": null,
- "id": "following-journalist",
+ "id": "alien-stand",
"metadata": {},
"outputs": [],
"source": [
@@ -198,7 +198,7 @@
},
{
"cell_type": "markdown",
- "id": "pointed-hunter",
+ "id": "conventional-travel",
"metadata": {},
"source": [
"# Predicting"
@@ -206,7 +206,7 @@
},
{
"cell_type": "markdown",
- "id": "homeless-warrior",
+ "id": "coated-insulation",
"metadata": {},
"source": [
"### 8. Load the model from a checkpoint\n",
@@ -217,7 +217,7 @@
{
"cell_type": "code",
"execution_count": null,
- "id": "personalized-panel",
+ "id": "alpine-drilling",
"metadata": {},
"outputs": [],
"source": [
@@ -227,7 +227,7 @@
},
{
"cell_type": "markdown",
- "id": "soviet-theta",
+ "id": "painted-assistant",
"metadata": {},
"source": [
"### 9. Generate predictions from a sheet file! Who would survive?\n",
@@ -238,7 +238,7 @@
{
"cell_type": "code",
"execution_count": null,
- "id": "minor-siemens",
+ "id": "located-cable",
"metadata": {},
"outputs": [],
"source": [
@@ -248,7 +248,7 @@
{
"cell_type": "code",
"execution_count": null,
- "id": "martial-hundred",
+ "id": "realistic-infection",
"metadata": {},
"outputs": [],
"source": [
@@ -257,7 +257,7 @@
},
{
"cell_type": "markdown",
- "id": "portuguese-ordering",
+ "id": "classified-casino",
"metadata": {},
"source": [
"\n",
@@ -313,7 +313,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
- "version": "3.6.8"
+ "version": "3.8.5"
}
},
"nbformat": 4,
diff --git a/tests/core/test_classification.py b/tests/core/test_classification.py
new file mode 100644
index 0000000000..fd4de14d7e
--- /dev/null
+++ b/tests/core/test_classification.py
@@ -0,0 +1,39 @@
+# Copyright The PyTorch Lightning team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import torch
+
+from flash.core.classification import Classes, Labels, Logits, Probabilities
+
+
+def test_classification_serializers():
+    example_output = torch.tensor([-0.1, 0.2, 0.3])  # 3 classes; the largest value (0.3) sits at index 2
+    labels = ['class_1', 'class_2', 'class_3']
+    # Single-label case: every serializer output is compared with a reference derived directly from the logits.
+    assert torch.allclose(torch.tensor(Logits().serialize(example_output)), example_output)
+    assert torch.allclose(torch.tensor(Probabilities().serialize(example_output)), torch.softmax(example_output, -1))
+    assert Classes().serialize(example_output) == 2  # expected: index of the largest value
+    assert Labels(labels).serialize(example_output) == 'class_3'  # expected: labels[2]
+
+
+def test_classification_serializers_multi_label():
+    example_output = torch.tensor([-0.1, 0.2, 0.3])  # 3 classes; indices 1 and 2 hold the positive values
+    labels = ['class_1', 'class_2', 'class_3']
+    # Multi-label case: same checks with ``multi_label=True``; probabilities are compared with a sigmoid reference.
+    assert torch.allclose(torch.tensor(Logits(multi_label=True).serialize(example_output)), example_output)
+    assert torch.allclose(
+        torch.tensor(Probabilities(multi_label=True).serialize(example_output)),
+        torch.sigmoid(example_output),
+    )
+    assert Classes(multi_label=True).serialize(example_output) == [1, 2]  # expected: positions of positive values
+    assert Labels(labels, multi_label=True).serialize(example_output) == ['class_2', 'class_3']  # labels 1 and 2
diff --git a/tests/core/test_model.py b/tests/core/test_model.py
index 450b662dbd..4246da4f43 100644
--- a/tests/core/test_model.py
+++ b/tests/core/test_model.py
@@ -27,6 +27,7 @@
import flash
from flash.core.classification import ClassificationTask
+from flash.data.process import Postprocess
from flash.tabular import TabularClassifier
from flash.text import SummarizationTask, TextClassifier
from flash.utils.imports import _TRANSFORMERS_AVAILABLE
@@ -50,6 +51,11 @@ def __getitem__(self, index: int) -> Tensor:
return torch.rand(1, 28, 28)
+class DummyPostprocess(Postprocess):
+    # Empty ``Postprocess`` subclass; ``test_task_datapipeline_save`` passes an instance as ``postprocess=``.
+    pass
+
+
# ================================
@@ -116,10 +122,10 @@ def test_classification_task_trainer_predict(tmpdir):
def test_task_datapipeline_save(tmpdir):
model = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 10), nn.Softmax())
train_dl = torch.utils.data.DataLoader(DummyDataset())
- task = ClassificationTask(model, F.nll_loss)
+ task = ClassificationTask(model, F.nll_loss, postprocess=DummyPostprocess())
# to check later
- task._postprocess.test = True
+ task.postprocess.test = True
# generate a checkpoint
trainer = pl.Trainer(
@@ -136,7 +142,7 @@ def test_task_datapipeline_save(tmpdir):
# load from file
task = ClassificationTask.load_from_checkpoint(path, model=model)
- assert task._postprocess.test
+ assert task.postprocess.test
@pytest.mark.parametrize(
diff --git a/tests/data/test_callbacks.py b/tests/data/test_callbacks.py
index df1bba9939..ad2e2bfb61 100644
--- a/tests/data/test_callbacks.py
+++ b/tests/data/test_callbacks.py
@@ -25,6 +25,7 @@
from flash.data.base_viz import BaseVisualization
from flash.data.callback import BaseDataFetcher
from flash.data.data_module import DataModule
+from flash.data.process import Preprocess
from flash.data.utils import _STAGES_PREFIX
from flash.vision import ImageClassificationData
@@ -57,7 +58,7 @@ def configure_data_fetcher():
@classmethod
def from_inputs(cls, train_data: Any, val_data: Any, test_data: Any, predict_data: Any) -> "CustomDataModule":
- preprocess = cls.preprocess_cls()
+ preprocess = Preprocess()
return cls.from_load_data_inputs(
train_load_data_input=train_data,
diff --git a/tests/data/test_data_pipeline.py b/tests/data/test_data_pipeline.py
index 9d15247cf2..147b7b37f9 100644
--- a/tests/data/test_data_pipeline.py
+++ b/tests/data/test_data_pipeline.py
@@ -665,14 +665,11 @@ def predict_step(self, batch, batch_idx, dataloader_idx):
return tensor([0, 0, 0])
-class CustomDataModule(DataModule):
-
- preprocess_cls = TestPreprocessTransformations
-
-
def test_datapipeline_transformations(tmpdir):
- datamodule = CustomDataModule.from_load_data_inputs(1, 1, 1, 1, batch_size=2, num_workers=0)
+ datamodule = DataModule.from_load_data_inputs(
+ 1, 1, 1, 1, batch_size=2, num_workers=0, preprocess=TestPreprocessTransformations()
+ )
assert datamodule.train_dataloader().dataset[0] == (0, 1, 2, 3)
batch = next(iter(datamodule.train_dataloader()))
@@ -683,8 +680,9 @@ def test_datapipeline_transformations(tmpdir):
with pytest.raises(MisconfigurationException, match="When ``to_tensor_transform``"):
batch = next(iter(datamodule.val_dataloader()))
- CustomDataModule.preprocess_cls = TestPreprocessTransformations2
- datamodule = CustomDataModule.from_load_data_inputs(1, 1, 1, 1, batch_size=2, num_workers=0)
+ datamodule = DataModule.from_load_data_inputs(
+ 1, 1, 1, 1, batch_size=2, num_workers=0, preprocess=TestPreprocessTransformations2()
+ )
batch = next(iter(datamodule.val_dataloader()))
assert torch.equal(batch["a"], tensor([0, 1]))
assert torch.equal(batch["b"], tensor([1, 2]))
diff --git a/tests/data/test_process.py b/tests/data/test_process.py
new file mode 100644
index 0000000000..eeba5acfe6
--- /dev/null
+++ b/tests/data/test_process.py
@@ -0,0 +1,103 @@
+# Copyright The PyTorch Lightning team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from unittest.mock import Mock
+
+import pytest
+
+from flash.data.data_pipeline import DataPipelineState
+from flash.data.process import ProcessState, Properties, Serializer, SerializerMapping
+
+
+def test_properties_data_pipeline_state():
+    """Tests that ``get_state`` and ``set_state`` work for properties and that ``DataPipelineState`` is attached
+    correctly."""
+
+    class MyProcessState1(ProcessState):
+        pass
+
+    class MyProcessState2(ProcessState):
+        pass
+
+    class OtherProcessState(ProcessState):
+        pass
+    # Nothing attached yet: state is kept in ``_state`` keyed by the state's class; an unset type yields ``None``.
+    my_properties = Properties()
+    my_properties.set_state(MyProcessState1())
+    assert my_properties._state == {MyProcessState1: MyProcessState1()}
+    assert my_properties.get_state(OtherProcessState) is None
+    # After attaching, state that was set on the ``DataPipelineState`` is readable through the properties object.
+    data_pipeline_state = DataPipelineState()
+    data_pipeline_state.set_state(OtherProcessState())
+    my_properties.attach_data_pipeline_state(data_pipeline_state)
+    assert my_properties.get_state(OtherProcessState) == OtherProcessState()
+    # ...and state set on the properties object is readable from it. NOTE(review): the ``==`` checks in this test
+    my_properties.set_state(MyProcessState2())  # compare fresh instances; presumably value-equal -- confirm
+    assert data_pipeline_state.get_state(MyProcessState2) == MyProcessState2()
+
+
+def test_serializer():
+    """Tests that ``Serializer`` can be enabled and disabled correctly."""
+
+    my_serializer = Serializer()
+    # The un-mocked ``serialize`` is expected to hand its input back unchanged.
+    assert my_serializer.serialize('test') == 'test'
+    my_serializer.serialize = Mock()  # swap in a mock so calls to ``serialize`` can be counted
+    # Disabled: calling the serializer returns the input and must not reach ``serialize``.
+    my_serializer.disable()
+    assert my_serializer('test') == 'test'
+    my_serializer.serialize.assert_not_called()
+    # Enabled again: calling the serializer must invoke ``serialize`` exactly once.
+    my_serializer.enable()
+    my_serializer('test')
+    my_serializer.serialize.assert_called_once()
+
+
+def test_serializer_mapping():
+    """Tests that ``SerializerMapping`` correctly passes its inputs to the underlying serializers. Also checks that
+    state is retrieved / loaded correctly."""
+
+    serializer1 = Serializer()
+    serializer1.serialize = Mock(return_value='test1')  # mocked so the received argument can be asserted
+
+    class Serializer1State(ProcessState):
+        pass
+
+    serializer2 = Serializer()
+    serializer2.serialize = Mock(return_value='test2')  # mocked so the received argument can be asserted
+
+    class Serializer2State(ProcessState):
+        pass
+    # Calling the mapping should send each value to the serializer stored under the same key.
+    serializer_mapping = SerializerMapping({'key1': serializer1, 'key2': serializer2})
+    assert serializer_mapping({'key1': 'serializer1', 'key2': 'serializer2'}) == {'key1': 'test1', 'key2': 'test2'}
+    serializer1.serialize.assert_called_once_with('serializer1')
+    serializer2.serialize.assert_called_once_with('serializer2')
+    # A non-mapping input is expected to raise ``ValueError``.
+    with pytest.raises(ValueError, match='output must be a mapping'):
+        serializer_mapping('not a mapping')
+    # Give each serializer its own state before any ``DataPipelineState`` is attached.
+    serializer1_state = Serializer1State()
+    serializer2_state = Serializer2State()
+
+    serializer1.set_state(serializer1_state)
+    serializer2.set_state(serializer2_state)
+    # Attaching through the mapping should attach the same ``DataPipelineState`` to both serializers...
+    data_pipeline_state = DataPipelineState()
+    serializer_mapping.attach_data_pipeline_state(data_pipeline_state)
+
+    assert serializer1._data_pipeline_state is data_pipeline_state
+    assert serializer2._data_pipeline_state is data_pipeline_state
+    # ...and the state set earlier on each serializer should be retrievable from it (same objects, checked by ``is``).
+    assert data_pipeline_state.get_state(Serializer1State) is serializer1_state
+    assert data_pipeline_state.get_state(Serializer2State) is serializer2_state
diff --git a/tests/data/test_serialization.py b/tests/data/test_serialization.py
index bc35fc0eb4..fda5cb7643 100644
--- a/tests/data/test_serialization.py
+++ b/tests/data/test_serialization.py
@@ -57,11 +57,11 @@ def test_serialization_data_pipeline(tmpdir):
trainer.fit(model, dummy_data)
assert model.data_pipeline
- assert isinstance(model._preprocess, CustomPreprocess)
+ assert isinstance(model.preprocess, CustomPreprocess)
trainer.save_checkpoint(checkpoint_file)
loaded_model = CustomModel.load_from_checkpoint(checkpoint_file)
assert loaded_model.data_pipeline
- assert isinstance(loaded_model._preprocess, CustomPreprocess)
+ assert isinstance(loaded_model.preprocess, CustomPreprocess)
for file in os.listdir(tmpdir):
if file.endswith('.ckpt'):
os.remove(os.path.join(tmpdir, file))
diff --git a/tests/vision/classification/test_model.py b/tests/vision/classification/test_model.py
index 067fa994b8..0aa3ab1835 100644
--- a/tests/vision/classification/test_model.py
+++ b/tests/vision/classification/test_model.py
@@ -16,6 +16,7 @@
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from flash import Trainer
+from flash.core.classification import Probabilities
from flash.vision import ImageClassifier
# ======== Mock functions ========
@@ -85,7 +86,7 @@ def test_multilabel(tmpdir):
num_classes = 4
ds = DummyMultiLabelDataset(num_classes)
- model = ImageClassifier(num_classes, multi_label=True)
+ model = ImageClassifier(num_classes, multi_label=True, serializer=Probabilities(multi_label=True))
train_dl = torch.utils.data.DataLoader(ds, batch_size=2)
trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True)
trainer.finetune(model, train_dl, strategy="freeze_unfreeze")