Skip to content
This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

Commit

Permalink
Docstrings for from_fiftyone (#1136)
Browse files Browse the repository at this point in the history
  • Loading branch information
ethanwharris authored Jan 26, 2022
1 parent fd1a2e4 commit 76d6816
Show file tree
Hide file tree
Showing 10 changed files with 555 additions and 61 deletions.
4 changes: 4 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,10 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

- Fixed a bug where loading Seq2Seq data for prediction would not work if the target field was not present ([#1128](https://github.com/PyTorchLightning/lightning-flash/pull/1128))

- Fixed a bug where `from_fiftyone` classmethods did not work correctly with a `predict_dataset` ([#1136](https://github.com/PyTorchLightning/lightning-flash/pull/1136))

- Fixed a bug where the `labels` property would return `None` when using `ObjectDetectionData.from_fiftyone` ([#1136](https://github.com/PyTorchLightning/lightning-flash/pull/1136))

### Removed

- Removed the `Seq2SeqData` base class (use `TranslationData` or `SummarizationData` directly) ([#1128](https://github.com/PyTorchLightning/lightning-flash/pull/1128))
Expand Down
4 changes: 4 additions & 0 deletions flash/core/utilities/imports.py
Original file line number Diff line number Diff line change
Expand Up @@ -277,7 +277,9 @@ def _import_module(self):

# Global variables used for testing purposes (e.g. to only run doctests in the correct CI job)
_IMAGE_TESTING = _IMAGE_AVAILABLE
_IMAGE_EXTRAS_TESTING = False # Not for normal use
_VIDEO_TESTING = _VIDEO_AVAILABLE
_VIDEO_EXTRAS_TESTING = False # Not for normal use
_TABULAR_TESTING = _TABULAR_AVAILABLE
_TEXT_TESTING = _TEXT_AVAILABLE
_SERVE_TESTING = _SERVE_AVAILABLE
Expand All @@ -288,7 +290,9 @@ def _import_module(self):
if "FLASH_TEST_TOPIC" in os.environ:
topic = os.environ["FLASH_TEST_TOPIC"]
_IMAGE_TESTING = topic == "image"
_IMAGE_EXTRAS_TESTING = topic == "image,image_extras"
_VIDEO_TESTING = topic == "video"
_VIDEO_EXTRAS_TESTING = topic == "video,video_extras"
_TABULAR_TESTING = topic == "tabular"
_TEXT_TESTING = topic == "text"
_SERVE_TESTING = topic == "serve"
Expand Down
119 changes: 110 additions & 9 deletions flash/image/classification/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,14 @@
from flash.core.data.utilities.paths import PATH_TYPE
from flash.core.integrations.labelstudio.input import _parse_labelstudio_arguments, LabelStudioImageClassificationInput
from flash.core.registry import FlashRegistry
from flash.core.utilities.imports import _FIFTYONE_AVAILABLE, _IMAGE_TESTING, _MATPLOTLIB_AVAILABLE, Image, requires
from flash.core.utilities.imports import (
_FIFTYONE_AVAILABLE,
_IMAGE_EXTRAS_TESTING,
_IMAGE_TESTING,
_MATPLOTLIB_AVAILABLE,
Image,
requires,
)
from flash.core.utilities.stages import RunningStage
from flash.image.classification.input import (
ImageClassificationCSVInput,
Expand All @@ -45,15 +52,26 @@
else:
SampleCollection = None

# Skip doctests if requirements aren't available
if not _IMAGE_TESTING:
__doctest_skip__ = ["ImageClassificationData", "ImageClassificationData.*"]

if _MATPLOTLIB_AVAILABLE:
import matplotlib.pyplot as plt
else:
plt = None

# Skip doctests if requirements aren't available
__doctest_skip__ = []
if not _IMAGE_TESTING:
__doctest_skip__ += [
"ImageClassificationData",
"ImageClassificationData.from_files",
"ImageClassificationData.from_folders",
"ImageClassificationData.from_numpy",
"ImageClassificationData.from_tensors",
"ImageClassificationData.from_data_frame",
"ImageClassificationData.from_csv",
]
if not _IMAGE_EXTRAS_TESTING:
__doctest_skip__ += ["ImageClassificationData.from_fiftyone"]


class ImageClassificationData(DataModule):
"""The ``ImageClassificationData`` class is a :class:`~flash.core.data.data_module.DataModule` with a set of
Expand Down Expand Up @@ -784,18 +802,101 @@ def from_fiftyone(
transform_kwargs: Optional[Dict] = None,
**data_module_kwargs,
) -> "ImageClassificationData":
"""Load the :class:`~flash.image.classification.data.ImageClassificationData` from FiftyOne
``SampleCollection`` objects.
The supported file extensions are: ``.jpg``, ``.jpeg``, ``.png``, ``.ppm``, ``.bmp``, ``.pgm``, ``.tif``,
``.tiff``, ``.webp``, and ``.npy``.
The targets will be extracted from the ``label_field`` in the ``SampleCollection`` objects and can be in any
of our :ref:`supported classification target formats <formatting_classification_targets>`.
To learn how to customize the transforms applied for each stage, read our
:ref:`customizing transforms guide <customizing_transforms>`.
Args:
train_dataset: The ``SampleCollection`` to use when training.
val_dataset: The ``SampleCollection`` to use when validating.
test_dataset: The ``SampleCollection`` to use when testing.
predict_dataset: The ``SampleCollection`` to use when predicting.
label_field: The field in the ``SampleCollection`` objects containing the targets.
train_transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use when training.
val_transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use when validating.
test_transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use when testing.
predict_transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use when
predicting.
input_cls: The :class:`~flash.core.data.io.input.Input` type to use for loading the data.
transform_kwargs: Dict of keyword arguments to be provided when instantiating the transforms.
data_module_kwargs: Additional keyword arguments to provide to the
:class:`~flash.core.data.data_module.DataModule` constructor.
Returns:
The constructed :class:`~flash.image.classification.data.ImageClassificationData`.
Examples
________
.. testsetup::
>>> from PIL import Image
>>> rand_image = Image.fromarray(np.random.randint(0, 255, (64, 64, 3), dtype="uint8"))
>>> _ = [rand_image.save(f"image_{i}.png") for i in range(1, 4)]
>>> _ = [rand_image.save(f"predict_image_{i}.png") for i in range(1, 4)]
.. doctest::
>>> import fiftyone as fo
>>> from flash import Trainer
>>> from flash.image import ImageClassifier, ImageClassificationData
>>> train_dataset = fo.Dataset.from_images(
... ["image_1.png", "image_2.png", "image_3.png"]
... ) # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE
<BLANKLINE>
...
>>> samples = [train_dataset[filepath] for filepath in train_dataset.values("filepath")]
>>> for sample, label in zip(samples, ["cat", "dog", "cat"]):
... sample["ground_truth"] = fo.Classification(label=label)
... sample.save()
...
>>> predict_dataset = fo.Dataset.from_images(
... ["predict_image_1.png", "predict_image_2.png", "predict_image_3.png"]
... ) # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE
<BLANKLINE>
...
>>> datamodule = ImageClassificationData.from_fiftyone(
... train_dataset=train_dataset,
... predict_dataset=predict_dataset,
... transform_kwargs=dict(image_size=(128, 128)),
... batch_size=2,
... )
>>> datamodule.num_classes
2
>>> datamodule.labels
['cat', 'dog']
>>> model = ImageClassifier(backbone="resnet18", num_classes=datamodule.num_classes)
>>> trainer = Trainer(fast_dev_run=True)
>>> trainer.fit(model, datamodule=datamodule) # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE
Training...
>>> trainer.predict(model, datamodule=datamodule) # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE
Predicting...
.. testcleanup::
>>> import os
>>> _ = [os.remove(f"image_{i}.png") for i in range(1, 4)]
>>> _ = [os.remove(f"predict_image_{i}.png") for i in range(1, 4)]
"""

ds_kw = dict(
data_pipeline_state=DataPipelineState(),
transform_kwargs=transform_kwargs,
input_transforms_registry=cls.input_transforms_registry,
label_field=label_field,
)

return cls(
input_cls(RunningStage.TRAINING, train_dataset, transform=train_transform, **ds_kw),
input_cls(RunningStage.VALIDATING, val_dataset, transform=val_transform, **ds_kw),
input_cls(RunningStage.TESTING, test_dataset, transform=test_transform, **ds_kw),
input_cls(
RunningStage.TRAINING, train_dataset, transform=train_transform, label_field=label_field, **ds_kw
),
input_cls(RunningStage.VALIDATING, val_dataset, transform=val_transform, label_field=label_field, **ds_kw),
input_cls(RunningStage.TESTING, test_dataset, transform=test_transform, label_field=label_field, **ds_kw),
input_cls(RunningStage.PREDICTING, predict_dataset, transform=predict_transform, **ds_kw),
**data_module_kwargs,
)
Expand Down
3 changes: 1 addition & 2 deletions flash/image/classification/input.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,9 +68,8 @@ def load_data(self, sample_collection: SampleCollection, label_field: str = "gro

return super().load_data(filepaths, targets)

@staticmethod
@requires("fiftyone")
def predict_load_data(data: SampleCollection) -> List[Dict[str, Any]]:
def predict_load_data(self, data: SampleCollection) -> List[Dict[str, Any]]:
return super().load_data(data.values("filepath"))


Expand Down
102 changes: 95 additions & 7 deletions flash/image/detection/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
from flash.core.data.utilities.sort import sorted_alphanumeric
from flash.core.integrations.icevision.data import IceVisionInput
from flash.core.integrations.icevision.transforms import IceVisionInputTransform
from flash.core.utilities.imports import _FIFTYONE_AVAILABLE, _ICEVISION_AVAILABLE, requires
from flash.core.utilities.imports import _FIFTYONE_AVAILABLE, _ICEVISION_AVAILABLE, _IMAGE_EXTRAS_TESTING, requires
from flash.core.utilities.stages import RunningStage
from flash.core.utilities.types import INPUT_TRANSFORM_TYPE
from flash.image.detection.input import ObjectDetectionFiftyOneInput
Expand All @@ -40,7 +40,7 @@
Parser = object

# Skip doctests if requirements aren't available
if not _ICEVISION_AVAILABLE:
if not _IMAGE_EXTRAS_TESTING:
__doctest_skip__ = ["ObjectDetectionData", "ObjectDetectionData.*"]


Expand Down Expand Up @@ -598,26 +598,114 @@ def from_fiftyone(
val_dataset: Optional[SampleCollection] = None,
test_dataset: Optional[SampleCollection] = None,
predict_dataset: Optional[SampleCollection] = None,
label_field: str = "ground_truth",
iscrowd: str = "iscrowd",
train_transform: INPUT_TRANSFORM_TYPE = IceVisionInputTransform,
val_transform: INPUT_TRANSFORM_TYPE = IceVisionInputTransform,
test_transform: INPUT_TRANSFORM_TYPE = IceVisionInputTransform,
predict_transform: INPUT_TRANSFORM_TYPE = IceVisionInputTransform,
label_field: str = "ground_truth",
iscrowd: str = "iscrowd",
input_cls: Type[Input] = ObjectDetectionFiftyOneInput,
transform_kwargs: Optional[Dict] = None,
**data_module_kwargs: Any,
) -> "ObjectDetectionData":
"""Load the :class:`~flash.image.detection.data.ObjectDetectionData` from FiftyOne ``SampleCollection``
objects.
Targets will be extracted from the ``label_field`` in the ``SampleCollection`` objects.
To learn how to customize the transforms applied for each stage, read our
:ref:`customizing transforms guide <customizing_transforms>`.
Args:
train_dataset: The ``SampleCollection`` to use when training.
val_dataset: The ``SampleCollection`` to use when validating.
test_dataset: The ``SampleCollection`` to use when testing.
predict_dataset: The ``SampleCollection`` to use when predicting.
label_field: The field in the ``SampleCollection`` objects containing the targets.
iscrowd: The field in the ``SampleCollection`` objects containing the ``iscrowd`` annotation (if required).
train_transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use when training.
val_transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use when validating.
test_transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use when testing.
predict_transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use when
predicting.
input_cls: The :class:`~flash.core.data.io.input.Input` type to use for loading the data.
transform_kwargs: Dict of keyword arguments to be provided when instantiating the transforms.
data_module_kwargs: Additional keyword arguments to provide to the
:class:`~flash.core.data.data_module.DataModule` constructor.
Returns:
The constructed :class:`~flash.image.detection.data.ObjectDetectionData`.
Examples
________
.. testsetup::
>>> import numpy as np
>>> from PIL import Image
>>> rand_image = Image.fromarray(np.random.randint(0, 255, (64, 64, 3), dtype="uint8"))
>>> _ = [rand_image.save(f"image_{i}.png") for i in range(1, 4)]
>>> _ = [rand_image.save(f"predict_image_{i}.png") for i in range(1, 4)]
.. doctest::
>>> import numpy as np
>>> import fiftyone as fo
>>> from flash import Trainer
>>> from flash.image import ObjectDetector, ObjectDetectionData
>>> train_dataset = fo.Dataset.from_images(
... ["image_1.png", "image_2.png", "image_3.png"]
... ) # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE
<BLANKLINE>
...
>>> samples = [train_dataset[filepath] for filepath in train_dataset.values("filepath")]
>>> for sample, label, bounding_box in zip(
... samples,
... ["cat", "dog", "cat"],
... [[0.1, 0.2, 0.15, 0.3], [0.2, 0.3, 0.3, 0.4], [0.1, 0.2, 0.15, 0.45]],
... ):
... sample["ground_truth"] = fo.Detections(
... detections=[fo.Detection(label=label, bounding_box=bounding_box)],
... )
... sample.save()
...
>>> predict_dataset = fo.Dataset.from_images(
... ["predict_image_1.png", "predict_image_2.png", "predict_image_3.png"]
... ) # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE
<BLANKLINE>
...
>>> datamodule = ObjectDetectionData.from_fiftyone(
... train_dataset=train_dataset,
... predict_dataset=predict_dataset,
... transform_kwargs=dict(image_size=(128, 128)),
... batch_size=2,
... ) # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE
<BLANKLINE>
...
>>> datamodule.num_classes
3
>>> datamodule.labels
['background', 'cat', 'dog']
>>> model = ObjectDetector(num_classes=datamodule.num_classes)
>>> trainer = Trainer(fast_dev_run=True)
>>> trainer.fit(model, datamodule=datamodule) # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE
Training...
>>> trainer.predict(model, datamodule=datamodule) # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE
Predicting...
.. testcleanup::
>>> import os
>>> _ = [os.remove(f"image_{i}.png") for i in range(1, 4)]
>>> _ = [os.remove(f"predict_image_{i}.png") for i in range(1, 4)]
"""

ds_kw = dict(data_pipeline_state=DataPipelineState(), transform_kwargs=transform_kwargs)

return cls(
input_cls(RunningStage.TRAINING, train_dataset, label_field, iscrowd, transform=train_transform, **ds_kw),
input_cls(RunningStage.VALIDATING, val_dataset, label_field, iscrowd, transform=val_transform, **ds_kw),
input_cls(RunningStage.TESTING, test_dataset, label_field, iscrowd, transform=test_transform, **ds_kw),
input_cls(
RunningStage.PREDICTING, predict_dataset, label_field, iscrowd, transform=predict_transform, **ds_kw
),
input_cls(RunningStage.PREDICTING, predict_dataset, transform=predict_transform, **ds_kw),
**data_module_kwargs,
)

Expand Down
3 changes: 3 additions & 0 deletions flash/image/detection/input.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
# limitations under the License.
from typing import Any, Dict, Hashable, Sequence

from flash.core.data.io.classification_input import ClassificationState
from flash.core.data.io.input import DataKeys
from flash.core.integrations.fiftyone.utils import FiftyOneLabelUtilities
from flash.core.integrations.icevision.data import IceVisionInput
Expand Down Expand Up @@ -110,6 +111,8 @@ def load_data(
classes = label_utilities.get_classes(sample_collection)
class_map = ClassMap(classes)
self.num_classes = len(class_map)
self.labels = [class_map.get_by_id(i) for i in range(self.num_classes)]
self.set_state(ClassificationState(self.labels))

parser = FiftyOneParser(sample_collection, class_map, label_field, iscrowd)
records = parser.parse(data_splitter=SingleSplitSplitter())
Expand Down
Loading

0 comments on commit 76d6816

Please sign in to comment.