Skip to content
This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

Commit

Permalink
Docstring for object detection data from FiftyOne
Browse files Browse the repository at this point in the history
  • Loading branch information
ethanwharris committed Jan 25, 2022
1 parent ccbdf28 commit 1d3cebc
Show file tree
Hide file tree
Showing 4 changed files with 98 additions and 5 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

- Fixed a bug where `from_fiftyone` classmethods did not work correctly with a `predict_dataset` ([#1136](https://github.com/PyTorchLightning/lightning-flash/pull/1136))

- Fixed a bug where the `labels` property would return `None` when using `ObjectDetectionData.from_fiftyone` ([#1136](https://github.com/PyTorchLightning/lightning-flash/pull/1136))

### Removed

- Removed the `Seq2SeqData` base class (use `TranslationData` or `SummarizationData` directly) ([#1128](https://github.com/PyTorchLightning/lightning-flash/pull/1128))
Expand Down
95 changes: 90 additions & 5 deletions flash/image/detection/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -598,26 +598,111 @@ def from_fiftyone(
val_dataset: Optional[SampleCollection] = None,
test_dataset: Optional[SampleCollection] = None,
predict_dataset: Optional[SampleCollection] = None,
label_field: str = "ground_truth",
iscrowd: str = "iscrowd",
train_transform: INPUT_TRANSFORM_TYPE = IceVisionInputTransform,
val_transform: INPUT_TRANSFORM_TYPE = IceVisionInputTransform,
test_transform: INPUT_TRANSFORM_TYPE = IceVisionInputTransform,
predict_transform: INPUT_TRANSFORM_TYPE = IceVisionInputTransform,
label_field: str = "ground_truth",
iscrowd: str = "iscrowd",
input_cls: Type[Input] = ObjectDetectionFiftyOneInput,
transform_kwargs: Optional[Dict] = None,
**data_module_kwargs: Any,
) -> "ObjectDetectionData":
"""Load the :class:`~flash.image.detection.data.ObjectDetectionData` from FiftyOne ``SampleCollection``
objects.
Targets will be extracted from the ``label_field`` in the ``SampleCollection`` objects.
To learn how to customize the transforms applied for each stage, read our
:ref:`customizing transforms guide <customizing_transforms>`.
Args:
train_dataset: The ``SampleCollection`` to use when training.
val_dataset: The ``SampleCollection`` to use when validating.
test_dataset: The ``SampleCollection`` to use when testing.
predict_dataset: The ``SampleCollection`` to use when predicting.
label_field: The field in the ``SampleCollection`` objects containing the targets.
iscrowd: The field in the ``SampleCollection`` objects containing the ``iscrowd`` annotation (if required).
train_transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use when training.
val_transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use when validating.
test_transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use when testing.
predict_transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use when
predicting.
input_cls: The :class:`~flash.core.data.io.input.Input` type to use for loading the data.
transform_kwargs: Dict of keyword arguments to be provided when instantiating the transforms.
data_module_kwargs: Additional keyword arguments to provide to the
:class:`~flash.core.data.data_module.DataModule` constructor.
Returns:
The constructed :class:`~flash.image.detection.data.ObjectDetectionData`.
Examples
________
.. testsetup::
>>> import numpy as np
>>> from PIL import Image
>>> rand_image = Image.fromarray(np.random.randint(0, 255, (64, 64, 3), dtype="uint8"))
>>> _ = [rand_image.save(f"image_{i}.png") for i in range(1, 4)]
>>> _ = [rand_image.save(f"predict_image_{i}.png") for i in range(1, 4)]
.. doctest::
>>> import numpy as np
>>> import fiftyone as fo
>>> from flash import Trainer
>>> from flash.image import ObjectDetector, ObjectDetectionData
>>> train_dataset = fo.Dataset.from_images(
... ["image_1.png", "image_2.png", "image_3.png"]
... ) # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE
33%...
>>> samples = [train_dataset[filepath] for filepath in train_dataset.values("filepath")]
>>> for sample, label, bounding_box in zip(
... samples,
... ["cat", "dog", "cat"],
... [[0.1, 0.2, 0.15, 0.3], [0.2, 0.3, 0.3, 0.4], [0.1, 0.2, 0.15, 0.45]],
... ):
... sample["ground_truth"] = fo.Detections(
... detections=[fo.Detection(label=label, bounding_box=bounding_box)],
... )
... sample.save()
...
>>> predict_dataset = fo.Dataset.from_images(
... ["predict_image_1.png", "predict_image_2.png", "predict_image_3.png"]
... ) # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE
33%...
>>> datamodule = ObjectDetectionData.from_fiftyone(
... train_dataset=train_dataset,
... predict_dataset=predict_dataset,
... transform_kwargs=dict(image_size=(128, 128)),
... batch_size=2,
... ) # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE
Computing...
>>> datamodule.num_classes
3
>>> datamodule.labels
['background', 'cat', 'dog']
>>> model = ObjectDetector(num_classes=datamodule.num_classes)
>>> trainer = Trainer(fast_dev_run=True)
>>> trainer.fit(model, datamodule=datamodule) # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE
Training...
>>> trainer.predict(model, datamodule=datamodule) # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE
Predicting...
.. testcleanup::
>>> import os
>>> _ = [os.remove(f"image_{i}.png") for i in range(1, 4)]
>>> _ = [os.remove(f"predict_image_{i}.png") for i in range(1, 4)]
"""

ds_kw = dict(data_pipeline_state=DataPipelineState(), transform_kwargs=transform_kwargs)

return cls(
input_cls(RunningStage.TRAINING, train_dataset, label_field, iscrowd, transform=train_transform, **ds_kw),
input_cls(RunningStage.VALIDATING, val_dataset, label_field, iscrowd, transform=val_transform, **ds_kw),
input_cls(RunningStage.TESTING, test_dataset, label_field, iscrowd, transform=test_transform, **ds_kw),
input_cls(
RunningStage.PREDICTING, predict_dataset, label_field, iscrowd, transform=predict_transform, **ds_kw
),
input_cls(RunningStage.PREDICTING, predict_dataset, transform=predict_transform, **ds_kw),
**data_module_kwargs,
)

Expand Down
3 changes: 3 additions & 0 deletions flash/image/detection/input.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
# limitations under the License.
from typing import Any, Dict, Hashable, Sequence

from flash.core.data.io.classification_input import ClassificationState
from flash.core.data.io.input import DataKeys
from flash.core.integrations.fiftyone.utils import FiftyOneLabelUtilities
from flash.core.integrations.icevision.data import IceVisionInput
Expand Down Expand Up @@ -110,6 +111,8 @@ def load_data(
classes = label_utilities.get_classes(sample_collection)
class_map = ClassMap(classes)
self.num_classes = len(class_map)
self.labels = [class_map.get_by_id(i) for i in range(self.num_classes)]
self.set_state(ClassificationState(self.labels))

parser = FiftyOneParser(sample_collection, class_map, label_field, iscrowd)
records = parser.parse(data_splitter=SingleSplitSplitter())
Expand Down
3 changes: 3 additions & 0 deletions flash/image/segmentation/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -555,6 +555,9 @@ def from_fiftyone(
predict_transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use when
predicting.
input_cls: The :class:`~flash.core.data.io.input.Input` type to use for loading the data.
num_classes: The number of segmentation classes.
labels_map: An optional mapping from class to RGB tuple indicating the colour to use when visualizing masks.
If not provided, a random mapping will be used.
transform_kwargs: Dict of keyword arguments to be provided when instantiating the transforms.
data_module_kwargs: Additional keyword arguments to provide to the
:class:`~flash.core.data.data_module.DataModule` constructor.
Expand Down

0 comments on commit 1d3cebc

Please sign in to comment.