Skip to content
This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

Commit

Permalink
Refactor preprocess_cls to preprocess, add Serializer, add `Dat…
Browse files Browse the repository at this point in the history
…aPipelineState` (#229)

* Initial commit

* Initial commit

* Small fixes

* Small fixes

* Fix a small bug

* Update docs

* Update notebook

* Update finetuning image classification

* Updates

* Update docs and serializer mapping

* Fix missed merge conflicts

* Fix some broken tests

* Fix a test

* Fix a test

* Fix some tests

* Fix a test

* Update examples

* Update examples

* Pre-commit

* Pre-commit

* Update text classification

* Add a test

* Multi-label example initial commit

* Add predict example for multi_label

* Remove unused imports

* Update predict example

* Update examples

* Add multi-label Labels suport

* Update test_classification

* Update .gitignore

* Add some tests

* Fix broken test

* Update test_process

* Some docs updates

* Update docs

* Fix some tests

* Add back some process_cls

* Add types

* Update docs/source/general/data.rst

Co-authored-by: thomas chaton <thomas@grid.ai>

* Update flash/data/process.py

Co-authored-by: Edgar Riba <edgar.riba@gmail.com>

* Add comment

* Update example

* Remove state checkpoint not needed

* Fix doctest

* Update image_classification example

* Update following fix

* Fix num-workers in test_examples

* Update example predict

* Better fix for windows error

Co-authored-by: thomas chaton <thomas@grid.ai>
Co-authored-by: Edgar Riba <edgar.riba@gmail.com>
  • Loading branch information
3 people authored Apr 22, 2021
1 parent b7436c4 commit 1ab7346
Show file tree
Hide file tree
Showing 38 changed files with 1,228 additions and 563 deletions.
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
.DS_Store
.lock
lightning_logs

Expand Down Expand Up @@ -149,3 +150,4 @@ xsum
coco128
wmt_en_ro
kinetics
movie_posters
12 changes: 6 additions & 6 deletions docs/source/custom_task.rst
Original file line number Diff line number Diff line change
Expand Up @@ -115,13 +115,13 @@ We will define a custom ``NumpyDataModule`` class subclassing :class:`~flash.dat
This ``NumpyDataModule`` class will provide a ``from_xy_dataset`` helper ``classmethod`` to instantiate
:class:`~flash.data.data_module.DataModule` from x, y numpy arrays.

Here is how it would look like:
Here is how it would look:

Example::

x, y = ...
preprocess_cls = ...
datamodule = NumpyDataModule.from_xy_dataset(x, y, preprocess_cls)
preprocess = ...
datamodule = NumpyDataModule.from_xy_dataset(x, y, preprocess)

Here is the ``NumpyDataModule`` implementation:

Expand All @@ -140,12 +140,12 @@ Example::
cls,
x: ND,
y: ND,
preprocess_cls: Preprocess = NumpyPreprocess,
preprocess: Preprocess = None,
batch_size: int = 64,
num_workers: int = 0
):

preprocess = preprocess_cls()
preprocess = preprocess or NumpyPreprocess()

x_train, x_test, y_train, y_test = train_test_split(
x, y, test_size=.20, random_state=0)
Expand Down Expand Up @@ -180,7 +180,7 @@ It allows the user much more granular control over their data processing flow.

.. note::

Why introducing :class:`~flash.data.process.Preprocess` ?
Why introduce :class:`~flash.data.process.Preprocess` ?

The :class:`~flash.data.process.Preprocess` object reduces the engineering overhead to make inference on raw data or
to deploy the model in production environnement compared to traditional
Expand Down
45 changes: 33 additions & 12 deletions docs/source/general/data.rst
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,9 @@ Here are common terms you need to be familiar with:
The :class:`~flash.data.process.Preprocess` hooks covers from data-loading to model forwarding.
* - :class:`~flash.data.process.Postprocess`
- The :class:`~flash.data.process.Postprocess` provides a simple hook-based API to encapsulate your post-processing logic.
The :class:`~flash.data.process.Postprocess` hooks covers from model outputs to predictions export.
The :class:`~flash.data.process.Postprocess` hooks cover from model outputs to predictions export.
* - :class:`~flash.data.process.Serializer`
- The :class:`~flash.data.process.Serializer` provides a single ``serialize`` method that is used to convert model outputs (after the :class:`~flash.data.process.Postprocess`) to the desired output format during prediction.

*******************************************
How to use out-of-the-box flashdatamodules
Expand All @@ -49,7 +51,9 @@ However, after model training, it requires a lot of engineering overhead to make
Usually, extra processing logic should be added to bridge the gap between training data and raw data.

The :class:`~flash.data.process.Preprocess` and :class:`~flash.data.process.Postprocess` classes can be used to
store the data as well as the preprocessing and postprocessing transforms.
store the data as well as the preprocessing and postprocessing transforms. The :class:`~flash.data.process.Serializer`
class provides the logic for converting :class:`~flash.data.process.Postprocess` outputs to the desired predict format
(e.g. classes, labels, probabilites, etc.).

By providing a series of hooks that can be overridden with custom data processing logic,
the user has much more granular control over their data processing flow.
Expand Down Expand Up @@ -122,7 +126,7 @@ Example::
train_folder="data/hymenoptera_data/train/",
val_folder="data/hymenoptera_data/val/",
test_folder="data/hymenoptera_data/test/",
preprocess_cls=CustomImageClassificationPreprocess
preprocess=CustomImageClassificationPreprocess(),
)


Expand Down Expand Up @@ -157,7 +161,7 @@ Example::
val_folder="./data/val",
test_folder="./data/test",
predict_folder="./data/predict",
preprocess=preprocess
preprocess=preprocess,
)

model = ImageClassifier(...)
Expand Down Expand Up @@ -190,6 +194,7 @@ Example::
**kwargs
):

# Set a custom ``Preprocess`` if none was provided
preprocess = preprocess or cls.preprocess_cls()

# {stage}_load_data_input will be given to your
Expand Down Expand Up @@ -291,6 +296,18 @@ ___________
:members:


----------

.. _serializer:

Serializer
___________


.. autoclass:: flash.data.process.Serializer
:members:


----------

.. _datapipeline:
Expand Down Expand Up @@ -414,16 +431,18 @@ Example::
predictions = lightning_module(data)


Postprocess
___________
Postprocess and Serializer
__________________________


Once the predictions have been generated by the Flash :class:`~flash.core.model.Task`.
The Flash :class:`~flash.data.data_pipeline.DataPipeline` will behind the scenes execute the :class:`~flash.data.process.Postprocess` hooks.
Once the predictions have been generated by the Flash :class:`~flash.core.model.Task`, the Flash
:class:`~flash.data.data_pipeline.DataPipeline` will execute the :class:`~flash.data.process.Postprocess` hooks and the
:class:`~flash.data.process.Serializer` behind the scenes.

First, the ``per_batch_transform`` hooks will be applied on the batch predictions.
Then the ``uncollate`` will split the batch into individual predictions.
Finally, the ``per_sample_transform`` will be applied on each prediction.
First, the :meth:`~flash.data.process.Postprocess.per_batch_transform` hooks will be applied on the batch predictions.
Then, the :meth:`~flash.data.process.Postprocess.uncollate` will split the batch into individual predictions.
Next, the :meth:`~flash.data.process.Postprocess.per_sample_transform` will be applied on each prediction.
Finally, the :meth:`~flash.data.process.Serializer.serialize` method will be called to serialize the predictions.

.. note:: The transform can be applied either on device or ``CPU``.

Expand All @@ -438,7 +457,9 @@ Example::

samples = uncollate(batch)

return [per_sample_transform(sample) for sample in samples]
samples = [per_sample_transform(sample) for sample in samples]
# only if serializers are enabled.
return [serialize(sample) for sample in samples]

predictions = lightning_module(data)
return uncollate_fn(predictions)
1 change: 1 addition & 0 deletions docs/source/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ Lightning Flash
reference/task
reference/image_classification
reference/image_embedder
reference/multi_label_classification
reference/summarization
reference/text_classification
reference/tabular_classification
Expand Down
212 changes: 212 additions & 0 deletions docs/source/reference/multi_label_classification.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,212 @@

.. _multi_label_classification:

################################
Multi-label Image Classification
################################

********
The task
********
Multi-label classification is the task of assigning a number of labels from a fixed set to each data point, which can be in any modality. In this example, we will look at the task of trying to predict the movie genres from an image of the movie poster.

------

********
The data
********
The data we will use in this example is a subset of the awesome movie poster genre prediction data set from the paper "Movie Genre Classification based on Poster Images with Deep Neural Networks" by Wei-Ta Chu and Hung-Jui Guo, resized to 128 by 128.
Take a look at their paper (and please consider citing their paper if you use the data) here: `www.cs.ccu.edu.tw/~wtchu/projects/MoviePoster/ <https://www.cs.ccu.edu.tw/~wtchu/projects/MoviePoster/>`_.

------

*********
Inference
*********

The :class:`~flash.vision.ImageClassifier` is already pre-trained on `ImageNet <http://www.image-net.org/>`_, a dataset of over 14 million images.

We can use the :class:`~flash.vision.ImageClassifier` model (pretrained on our data) for inference on any string sequence using :func:`~flash.vision.ImageClassifier.predict`.
We can also add a simple visualisation by extending :class:`~flash.data.base_viz.BaseVisualization`, like this:

.. code-block:: python
# import our libraries
from typing import Any
import torchvision.transforms.functional as T
from torchvision.utils import make_grid
from flash import Trainer
from flash.data.base_viz import BaseVisualization
from flash.data.utils import download_data
from flash.vision import ImageClassificationData, ImageClassifier
# 1. Download the data
download_data("https://pl-flash-data.s3.amazonaws.com/movie_posters.zip", "data/")
# 2. Define our custom visualisation and datamodule
class CustomViz(BaseVisualization):
def show_per_batch_transform(self, batch: Any, _):
images = batch[0]
image = make_grid(images, nrow=2)
image = T.to_pil_image(image, 'RGB')
image.show()
# 3. Load the model from a checkpoint
model = ImageClassifier.load_from_checkpoint(
"https://flash-weights.s3.amazonaws.com/image_classification_multi_label_model.pt",
)
# 4a. Predict the genres of a few movie posters!
predictions = model.predict([
"data/movie_posters/val/tt0361500.jpg",
"data/movie_posters/val/tt0361748.jpg",
"data/movie_posters/val/tt0362478.jpg",
])
print(predictions)
# 4b. Or generate predictions with a whole folder!
datamodule = ImageClassificationData.from_folders(
predict_folder="data/movie_posters/predict/",
data_fetcher=CustomViz(),
preprocess=model.preprocess,
)
predictions = Trainer().predict(model, datamodule=datamodule)
print(predictions)
# 5. Show some data!
datamodule.show_predict_batch()
For more advanced inference options, see :ref:`predictions`.

------

**********
Finetuning
**********

Now let's look at how we can finetune a model on the movie poster data.
Once we download the data using :func:`~flash.data.download_data`, all we need is the train data and validation data folders to create the :class:`~flash.vision.ImageClassificationData`.

.. note:: The dataset contains ``train`` and ``validation`` folders, and then each folder contains images and a ``metadata.csv`` which stores the labels.

.. code-block::
movie_posters
├── train
│ ├── metadata.csv
│ ├── tt0084058.jpg
│ ├── tt0084867.jpg
│ ...
└── val
├── metadata.csv
├── tt0200465.jpg
├── tt0326965.jpg
...
The ``metadata.csv`` files in each folder contain our labels, so we need to create a function (``load_data``) to extract the list of images and associated labels:

.. code-block:: python
# import our libraries
import os
from typing import List, Tuple
import pandas as pd
import torch
genres = [
"Action", "Adventure", "Animation", "Biography", "Comedy", "Crime", "Documentary", "Drama", "Family", "Fantasy", "History", "Horror", "Music", "Musical", "Mystery", "N/A", "News", "Reality-TV", "Romance", "Sci-Fi", "Short", "Sport", "Thriller", "War", "Western"
]
def load_data(data: str, root: str = 'data/movie_posters') -> Tuple[List[str], List[List[int]]]:
metadata = pd.read_csv(os.path.join(root, data, "metadata.csv"))
images = []
labels = []
for _, row in metadata.iterrows():
images.append(os.path.join(root, data, row['Id'] + ".jpg"))
labels.append([int(row[genre]) for genre in genres])
return images, labels
Our :class:`~flash.data.process.Preprocess` overrides the :meth:`~flash.data.process.Preprocess.load_data` method to create an iterable of image paths and label tensors. The :class:`~flash.vision.classification.data.ImageClassificationPreprocess` then handles loading and augmenting the images for us!
Now all we need is three lines of code to build to train our task!

.. note:: We need set `multi_label=True` in both our :class:`~flash.Task` and our :class:`~flash.data.process.Serializer` to use a binary cross entropy loss and to process outputs correctly.

.. code-block:: python
import flash
from flash.core.classification import Labels
from flash.core.finetuning import FreezeUnfreeze
from flash.data.utils import download_data
from flash.vision import ImageClassificationData, ImageClassifier
from flash.vision.classification.data import ImageClassificationPreprocess
# 1. Download the data
download_data("https://pl-flash-data.s3.amazonaws.com/movie_posters.zip", "data/")
# 2. Load the data
ImageClassificationPreprocess.image_size = (128, 128)
train_filepaths, train_labels = load_data('train')
val_filepaths, val_labels = load_data('val')
test_filepaths, test_labels = load_data('test')
datamodule = ImageClassificationData.from_filepaths(
train_filepaths=train_filepaths,
train_labels=train_labels,
val_filepaths=val_filepaths,
val_labels=val_labels,
test_filepaths=test_filepaths,
test_labels=test_labels,
preprocess=ImageClassificationPreprocess(),
)
# 3. Build the model
model = ImageClassifier(
backbone="resnet18",
num_classes=len(genres),
multi_label=True,
)
# 4. Create the trainer.
trainer = flash.Trainer(max_epochs=1, limit_train_batches=1, limit_val_batches=1)
# 5. Train the model
trainer.finetune(model, datamodule=datamodule, strategy=FreezeUnfreeze(unfreeze_epoch=1))
# 6a. Predict what's on a few images!
# Serialize predictions as labels.
model.serializer = Labels(genres, multi_label=True)
predictions = model.predict([
"data/movie_posters/val/tt0361500.jpg",
"data/movie_posters/val/tt0361748.jpg",
"data/movie_posters/val/tt0362478.jpg",
])
print(predictions)
datamodule = ImageClassificationData.from_folders(
predict_folder="data/movie_posters/predict/",
preprocess=model.preprocess,
)
# 6b. Or generate predictions with a whole folder!
predictions = trainer.predict(model, datamodule=datamodule)
print(predictions)
# 7. Save it!
trainer.save_checkpoint("image_classification_multi_label_model.pt")
------

For more backbone options, see :ref:`image_classification`.
Loading

0 comments on commit 1ab7346

Please sign in to comment.