Skip to content
This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

Adding support for loading datasets and visualizing model predictions via FiftyOne #360

Merged
merged 64 commits into from
Jun 15, 2021

Conversation

ehofesmann
Copy link
Contributor

@ehofesmann ehofesmann commented Jun 4, 2021

What does this PR do?

Integrates Lightning Flash with FiftyOne, the open source dataset and model analysis library!

Loading FiftyOne data into Flash

This PR adds FiftyOneDataSources for image/video classification, object detection, semantic segmentation, and image embedding tasks that load FiftyOne Datasets into Flash.

Loading Flash predictions into FiftyOne

This PR adds Serializer implementations that can convert classification/detection/segmentation model outputs into the appropriate FiftyOne label types so that they can be added to FiftyOne datasets and visualized.

Note

This PR requires a source install of FiftyOne on this branch voxel51/fiftyone#1059 in order to function.

git clone https://github.com/voxel51/fiftyone
cd fiftyone
git checkout --track origin/flash-video
bash install.bash

The above branch also contains a parallel integration that enables FiftyOne users to add predictions from any Flash model to their datasets 😄

Points of discussion

  1. It'd be great if these examples could be integrated into the Flash documentation/README in the appropriate places 😄

  2. The new FiftyoneDataSource classes introduced in this PR require a label_field argument to specify which field of the FiftyOne dataset should be used as the label field. To enable this, we added **data_source_kwargs to Flash's processor interface. Perhaps there's a better way to support this?

  3. When serializing object detections, Flash models seem to return bounding boxes in absolute coordinates, but FiftyOne expects bounding boxes in relative coordinates. Is it possible for FiftyOneDetectionLabels to access the dimensions of the current image when serialize() is called? Perhaps using set_state() as is done for class labels? The current implementation requires fiftyone.utils.flash.normalize_detections() to be manually called to convert to relative coordinates for import into FiftyOne, but it would be much cleaner if this could be done natively within FiftyOneDetectionLabels...

Basic patterns

The following subsections show the basic patterns enabled by this integration. See the next section for concrete examples of each task type.

Loading data from FiftyOne into Flash

FiftyOne users can load their datasets into Flash Data Sources via the pattern below:

from flash.image import ImageClassificationData

import fiftyone as fo

train_dataset = fo.Dataset.from_dir(
    "/path/to/train",
    fo.types.ImageClassificationDirectoryTree,
    label_field="ground_truth",
)

val_dataset = fo.Dataset.from_dir(
    "/path/to/val",
    fo.types.ImageClassificationDirectoryTree,
    label_field="ground_truth",
)

datamodule = ImageClassificationData.from_fiftyone(
    train_dataset=train_dataset,
    val_dataset=val_dataset,
    label_field="ground_truth",
)

Visualizing Flash predictions in FiftyOne

Flash users can swap out the serializer on their model with the corresponding FiftyOne serializer for the task type, and then visualize their predictions in the FiftyOne App via the pattern below:

from flash import Trainer
from flash.core.classification import FiftyOneLabels
from flash.core.integrations.fiftyone import visualize
from flash.video import VideoClassificationData, VideoClassifier

classifier = VideoClassifier.load_from_checkpoint(...)

# Option 1: Generate predictions using a Trainer and datamodule
datamodule = VideoClassificationData.from_folders(
    predict_folder="/path/to/folder",
    ...
)
trainer = Trainer()
classifier.serializer = FiftyOneLabels(return_filepath=True)
predictions = trainer.predict(classifier, datamodule=datamodule)

session = visualize(predictions) # Launch FiftyOne

# Option 2: Generate predictions from model using filepaths
filepaths = ["list", "of", "filepaths"]
predictions = classifier.predict(filepaths)
classifier.serializer = FiftyOneLabels()

session = visualize(predictions, filepaths=filepaths) # Launch FiftyOne

Applying Flash models to FiftyOne datasets

In addition to this PR, voxel51/fiftyone#1059 adds a parallel integration in the FiftyOne library that enables FiftyOne users to add predictions from any Flash model to their datasets via the pattern below:

from flash.image import ObjectDetector

import fiftyone as fo
import fiftyone.zoo as foz

dataset = foz.load_zoo_dataset("quickstart", max_samples=10)

model = ObjectDetector.load_from_checkpoint("https://flash-weights.s3.amazonaws.com/object_detection_model.pt")

dataset.apply_model(model, label_field="predictions")

session = fo.launch_app(dataset)

Task examples

The subsections below demonstrate both (a) FiftyOne dataset -> Flash, and (b) Flash predictions -> FiftyOne for each task type.

Video classification

from torch.utils.data.sampler import RandomSampler

import flash
from flash.core.classification import FiftyOneLabels
from flash.core.data.utils import download_data
from flash.video import VideoClassificationData, VideoClassifier

import fiftyone as fo

# 1. Download data
download_data("https://pl-flash-data.s3.amazonaws.com/kinetics.zip")

# 2. Load data into FiftyOne
# Here we use different datasets for each split, but you can also
# use views into the same dataset
train_dataset = fo.Dataset.from_dir(
    "data/kinetics/train",
    fo.types.VideoClassificationDirectoryTree,
    label_field="ground_truth",
    max_samples=5,
)

val_dataset = fo.Dataset.from_dir(
    "data/kinetics/val",
    fo.types.VideoClassificationDirectoryTree,
    label_field="ground_truth",
    max_samples=5,
)

predict_dataset = fo.Dataset.from_dir(
    "data/kinetics/predict",
    fo.types.VideoDirectory,
    max_samples=5,
)

# 3. Finetune a model
classifier = VideoClassifier.load_from_checkpoint(
  "https://flash-weights.s3.amazonaws.com/video_classification.pt",
  pretrained=False,
)

datamodule = VideoClassificationData.from_fiftyone(
    train_dataset=train_dataset,
    val_dataset=val_dataset,
    predict_dataset=predict_dataset,
    label_field="ground_truth",
    batch_size=8,
    clip_sampler="uniform",
    clip_duration=1,
    video_sampler=RandomSampler,
    decode_audio=False,
    num_workers=8,
)

trainer = flash.Trainer(max_epochs=1, fast_dev_run=1)
trainer.finetune(classifier, datamodule=datamodule)
trainer.save_checkpoint("video_classification.pt")

# 4. Predict from checkpoint
classifier = VideoClassifier.load_from_checkpoint(
  "https://flash-weights.s3.amazonaws.com/video_classification.pt",
  pretrained=False,
)

classifier.serializer = FiftyOneLabels()

filepaths = predict_dataset.values("filepath")
predictions = classifier.predict(filepaths)

predict_dataset.set_values("predictions", predictions)

# 5. Visualize in FiftyOne App
session = fo.launch_app(predict_dataset)

Image classification

from itertools import chain

import fiftyone as fo
import fiftyone.zoo as foz

from flash import Trainer
from flash.core.classification import FiftyOneLabels
from flash.core.finetuning import FreezeUnfreeze
from flash.image import ImageClassificationData, ImageClassifier

# 1. Load your FiftyOne dataset
# Here we use views into one dataset, but you can also create a
# different dataset for each split
dataset = foz.load_zoo_dataset("cifar10", split="test", max_samples=40)
train_dataset = dataset.shuffle(seed=51)[:20]
test_dataset = dataset.shuffle(seed=51)[20:25]
val_dataset = dataset.shuffle(seed=51)[25:30]
predict_dataset = dataset.shuffle(seed=51)[30:40]

# 2. Load the Datamodule
datamodule = ImageClassificationData.from_fiftyone(
    train_dataset = train_dataset,
    test_dataset = test_dataset,
    val_dataset = val_dataset,
    predict_dataset = predict_dataset,
    label_field = "ground_truth",
    batch_size=4,
    num_workers=4,
)

# 3. Build the model
model = ImageClassifier(
    backbone="resnet18",
    num_classes=datamodule.num_classes,
    serializer=FiftyOneLabels(),
)

# 4. Create the trainer
trainer = Trainer(
    max_epochs=1,
    limit_train_batches=1,
    limit_val_batches=1,
)

# 5. Finetune the model
trainer.finetune(
    model,
    datamodule=datamodule,
    strategy=FreezeUnfreeze(unfreeze_epoch=1),
)

# 6. Save it!
trainer.save_checkpoint("image_classification_model.pt")

# 7. Generate predictions
model = ImageClassifier.load_from_checkpoint(
  "https://flash-weights.s3.amazonaws.com/image_classification_model.pt"
)
model.serializer = FiftyOneLabels()

predictions = trainer.predict(model, datamodule=datamodule)

predictions = list(chain.from_iterable(predictions)) # flatten batches

# 8. Add predictions to dataset and analyze
predict_dataset.set_values("flash_predictions", predictions)
session = fo.launch_app(view=predict_dataset)

Object detection

from itertools import chain

import fiftyone as fo
import fiftyone.zoo as foz

from flash import Trainer
from flash.image import ObjectDetectionData, ObjectDetector
from flash.image.detection.serialization import FiftyOneDetectionLabels

# 1. Load your FiftyOne dataset
# Here we use views into one dataset, but you can also create a
# different dataset for each split
dataset = foz.load_zoo_dataset("quickstart", max_samples=40)
train_dataset = dataset.shuffle(seed=51)[:20]
test_dataset = dataset.shuffle(seed=51)[20:25]
val_dataset = dataset.shuffle(seed=51)[25:30]
predict_dataset = dataset.shuffle(seed=51)[30:40]

# 2. Load the Datamodule
datamodule = ObjectDetectionData.from_fiftyone(
    train_dataset = train_dataset,
    test_dataset = test_dataset,
    val_dataset = val_dataset,
    predict_dataset = predict_dataset,
    label_field = "ground_truth",
    batch_size=4,
    num_workers=4,
)

# 3. Build the model
model = ObjectDetector(
    model="retinanet",
    num_classes=datamodule.num_classes,
    serializer=FiftyOneDetectionLabels(),
)

# 4. Create the trainer
trainer = Trainer(
    max_epochs=1,
    limit_train_batches=1,
    limit_val_batches=1,
)

# 5. Finetune the model
trainer.finetune(model, datamodule=datamodule)

# 6. Save it!
trainer.save_checkpoint("object_detection_model.pt")

# 7. Generate predictions
model = ObjectDetector.load_from_checkpoint(
  "https://flash-weights.s3.amazonaws.com/object_detection_model.pt"
)
model.serializer = FiftyOneDetectionLabels()

predictions = trainer.predict(model, datamodule=datamodule)

predictions = list(chain.from_iterable(predictions)) # flatten batches

# 8. Add predictions to dataset and analyze
predict_dataset.set_values("flash_predictions", predictions)
session = fo.launch_app(view=predict_dataset)

Semantic segmentation

from itertools import chain

import fiftyone as fo
import fiftyone.zoo as foz

from flash import Trainer
from flash.core.data.utils import download_data
from flash.image import SemanticSegmentation, SemanticSegmentationData
from flash.image.segmentation.serialization import FiftyOneSegmentationLabels

# 1. Load your FiftyOne dataset
# This is a Dataset with Semantic Segmentation Labels generated via CARLA
self-driving simulator.
# The data was generated as part of the Lyft Udacity Challenge.
# More info here:
https://www.kaggle.com/kumaresanmanickavelu/lyft-udacity-challenge
download_data(
  "https://github.com/ongchinkiat/LyftPerceptionChallenge/releases/download/v0.1/carla-capture-20180513A.zip",
  "data/"
)

# Here we use views into one dataset, but you can also create a
# different dataset for each split
dataset = fo.Dataset.from_dir(
    dataset_dir = "data",
    data_path = "CameraRGB",
    labels_path = "CameraSeg",
    max_samples = 40,
    force_grayscale = True,
    dataset_type=fo.types.ImageSegmentationDirectory,
)
train_dataset = dataset.shuffle(seed=51)[:20]
test_dataset = dataset.shuffle(seed=51)[20:25]
val_dataset = dataset.shuffle(seed=51)[25:30]
predict_dataset = dataset.shuffle(seed=51)[30:40]

# 2. Load the Datamodule
datamodule = SemanticSegmentationData.from_fiftyone(
    train_dataset = train_dataset,
    test_dataset = test_dataset,
    val_dataset = val_dataset,
    predict_dataset = predict_dataset,
    label_field = "ground_truth",
    batch_size=4,
    num_workers=4,
    num_classes=21,
)

# 3. Build the model
model = SemanticSegmentation(
    backbone="resnet50",
    num_classes=datamodule.num_classes,
    serializer=FiftyOneSegmentationLabels(),
)

# 4. Create the trainer
trainer = Trainer(
    max_epochs=1,
    fast_dev_run=1,
)

# 5. Finetune the model
trainer.finetune(model, datamodule=datamodule, strategy="freeze")

# 6. Save it!
trainer.save_checkpoint("semantic_segmentation_model.pt")

# 7. Generate predictions
model = ObjectDetector.load_from_checkpoint(
  "https://flash-weights.s3.amazonaws.com/semantic_segmentation_model.pt"
)
model.serializer = FiftyOneSegmentationLabels()

predictions = trainer.predict(model, datamodule=datamodule)

predictions = list(chain.from_iterable(predictions)) # flatten batches

# 8. Add predictions to dataset and analyze
predict_dataset.set_values("flash_predictions", predictions)
session = fo.launch_app(view=predict_dataset)

Image embeddings

import numpy as np
import torch

from flash.core.data.utils import download_data
from flash.image import ImageEmbedder

import fiftyone as fo
import fiftyone.brain as fob

# 1 Download data
download_data(
    "https://pl-flash-data.s3.amazonaws.com/hymenoptera_data.zip"
)

# 2 Load data into FiftyOne
dataset = fo.Dataset.from_dir(
    "data/hymenoptera_data/test/",
    fo.types.ImageClassificationDirectoryTree,
)

# 3 Load model
embedder = ImageEmbedder(backbone="swav-imagenet", embedding_dim=128)

# 4 Generate embeddings
filepaths = dataset.values("filepath")
embeddings = np.stack(embedder.predict(filepaths))

# 5 Visualize in FiftyOne App
results = fob.compute_visualization(dataset, embeddings=embeddings)

session = fo.launch_app(dataset)

plot = results.visualize(labels="ground_truth.label")
plot.show()

Before submitting

  • (This PR was discussed face-to-face) Was this discussed/approved via a Github issue? (no need for typos and docs improvements)
  • Did you read the contributor guideline, Pull Request section?
  • Did you make sure your PR does only one thing, instead of bundling different changes together?
  • Did you make sure to update the documentation with your changes?
  • Did you write any new necessary tests? [not needed for typos/docs]
  • Did you verify new and existing tests pass locally with your changes?
  • If you made a notable change (that affects users), did you update the CHANGELOG?

PR review

  • Is this pull request ready for review? (if not, please submit in draft mode)

Anyone in the community is free to review the PR once the tests have passed.
If we didn't discuss your PR in Github issues there's a high chance it will not be merged.

Did you have fun?

Make sure you had fun coding 🙃

@ehofesmann
Copy link
Contributor Author

Alright! This PR is ready for you guys to take a final pass over @ethanwharris @tchaton

There is one thing I need your assistance with. In the integration docs I added a video and an image. They would likely be best hosted by you, but I am just unsure what your process is for hosting documentation media.

They currently link here:
https://user-images.githubusercontent.com/21222883/121972505-45114b00-cd49-11eb-9ef5-9a69fd90bf59.png
https://voxel51.com/images/fiftyone_long_sizzle.mp4

Copy link
Collaborator

@ethanwharris ethanwharris left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@ehofesmann Awesome! Just a few comments, I've added your assets to our S3 and recommended changes with the new links 😃 Could you also update with master and get the tests to pass? Then will be ready to merge

CHANGELOG.md Outdated Show resolved Hide resolved
docs/source/integrations/fiftyone.rst Outdated Show resolved Hide resolved
docs/source/integrations/fiftyone.rst Outdated Show resolved Hide resolved
flash/video/classification/data.py Outdated Show resolved Hide resolved
flash_examples/finetuning/semantic_segmentation.py Outdated Show resolved Hide resolved
test.py Outdated Show resolved Hide resolved
docs/source/integrations/fiftyone.rst Outdated Show resolved Hide resolved
docs/source/integrations/fiftyone.rst Outdated Show resolved Hide resolved
docs/source/integrations/fiftyone.rst Outdated Show resolved Hide resolved
@mergify mergify bot removed the has conflicts label Jun 15, 2021
@ethanwharris
Copy link
Collaborator

Hey @ehofesmann, just an FYI, you're getting some test failures due to an unrelated dependency issue. I'm working on a fix, I'll let you know when it's in master 😃

@ehofesmann
Copy link
Contributor Author

Hey @ehofesmann, just an FYI, you're getting some test failures due to an unrelated dependency issue. I'm working on a fix, I'll let you know when it's in master

Gotcha! I was racking my brain trying to figure out what I broke. I'll wait until it's in master 👍

@ethanwharris
Copy link
Collaborator

@ehofesmann If you update to master the CI should be working properly again 😃

Copy link
Collaborator

@ethanwharris ethanwharris left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Tests are passing, LGTM 😃

@ehofesmann
Copy link
Contributor Author

Tests are passing, LGTM

Awesome! We're just taking a final pass over everything. I'll let you know once we're done and then it's all yours.

@ehofesmann
Copy link
Contributor Author

@ethanwharris and @tchaton

It would be great to make it easier for users to find these visualization capabilities. What do you guys think about me merging in this section to the README?

voxel51#4

The proposed content:

Visualization

Predictions from image and video tasks can be visualized through our integration with FiftyOne
allowing you to better understand and analyze how your model is performing.

from flash.core.data.utils import download_data
from flash.core.integrations.fiftyone import visualize
from flash.image import ObjectDetector
from flash.image.detection.serialization import FiftyOneDetectionLabels

# 1. Download the data
# Dataset Credit: https://www.kaggle.com/ultralytics/coco128
download_data(
    "https://github.com/zhiqwang/yolov5-rt-stack/releases/download/v0.3.0/coco128.zip",
    "data/",
)

# 2. Load the model from a checkpoint and use the FiftyOne serializer
model = ObjectDetector.load_from_checkpoint(
    "https://flash-weights.s3.amazonaws.com/object_detection_model.pt"
)
model.serializer = FiftyOneDetectionLabels()

# 3. Detect the object on the images
filepaths = [
    "data/coco128/images/train2017/000000000025.jpg",
    "data/coco128/images/train2017/000000000520.jpg",
    "data/coco128/images/train2017/000000000532.jpg",
]
predictions = model.predict(filepaths)

# 4. Visualize predictions
session = visualize(predictions, filepaths=filepaths)

root = Path(__file__).parent.parent.parent


@mock.patch.dict(os.environ, {"FLASH_TESTING": "1"})
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks like the same file as tests/core/test_integrations.py

@tchaton tchaton merged commit 25d6633 into Lightning-Universe:master Jun 15, 2021
@tchaton
Copy link
Contributor

tchaton commented Jun 15, 2021

Hey @ehofesmann , awesome work with this integration.

@ethanwharris ethanwharris mentioned this pull request Jun 16, 2021
8 tasks
Sign up for free to subscribe to this conversation on GitHub. Already have an account? Sign in.
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

5 participants