diff --git a/.github/workflows/ci-testing.yml b/.github/workflows/ci-testing.yml index 21ac8fbd45..254234c8fd 100644 --- a/.github/workflows/ci-testing.yml +++ b/.github/workflows/ci-testing.yml @@ -137,7 +137,7 @@ jobs: run: | sudo apt-get install libsndfile1 pip install matplotlib - pip install '.[image]' --pre --upgrade + pip install '.[audio,image]' --pre --upgrade - name: Cache datasets uses: actions/cache@v2 diff --git a/.gitignore b/.gitignore index 8f9c8b29a2..9ab9838b44 100644 --- a/.gitignore +++ b/.gitignore @@ -161,7 +161,7 @@ jigsaw_toxic_comments flash_examples/serve/tabular_classification/data logs/cache/* flash_examples/data -flash_examples/cli/*/data +flash_examples/checkpoints timit/ urban8k_images/ __MACOSX diff --git a/CHANGELOG.md b/CHANGELOG.md index a27635e797..7674cd349c 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -40,6 +40,12 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Added option to pass a `resolver` to the `from_csv` and `from_pandas` methods of `ImageClassificationData`, which is used to resolve filenames given IDs ([#651](https://github.com/PyTorchLightning/lightning-flash/pull/651)) +- Added integration with IceVision for the `ObjectDetector` ([#608](https://github.com/PyTorchLightning/lightning-flash/pull/608)) + +- Added keypoint detection task ([#608](https://github.com/PyTorchLightning/lightning-flash/pull/608)) + +- Added instance segmentation task ([#608](https://github.com/PyTorchLightning/lightning-flash/pull/608)) + ### Changed - Changed how pretrained flag works for loading weights for ImageClassifier task ([#560](https://github.com/PyTorchLightning/lightning-flash/pull/560)) @@ -48,6 +54,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Changed the behaviour of the `sampler` argument of the `DataModule` to take a `Sampler` type rather than instantiated object ([#651](https://github.com/PyTorchLightning/lightning-flash/pull/651)) +- Changed arguments to `ObjectDetector`, use `head` instead of `model` and append `_fpn` to the backbone name instead of the `fpn` argument ([#608](https://github.com/PyTorchLightning/lightning-flash/pull/608)) + ### Fixed - Fixed a bug where serve sanity checking would not be triggered using the latest PyTorchLightning version ([#493](https://github.com/PyTorchLightning/lightning-flash/pull/493)) diff --git a/docs/source/api/core.rst b/docs/source/api/core.rst index 5b8674c37a..1b80d0e2c1 100644 --- a/docs/source/api/core.rst +++ b/docs/source/api/core.rst @@ -7,6 +7,17 @@ flash.core :local: :backlinks: top +flash.core.adapter +__________________ + +.. autosummary:: + :toctree: generated/ + :nosignatures: + :template: classtemplate.rst + + ~flash.core.adapter.Adapter + ~flash.core.adapter.AdapterTask + flash.core.classification _________________________ @@ -56,6 +67,8 @@ ________________ ~flash.core.model.BenchmarkConvergenceCI ~flash.core.model.CheckDependenciesMeta + ~flash.core.model.ModuleWrapperBase + ~flash.core.model.DatasetProcessor ~flash.core.model.Task flash.core.registry diff --git a/docs/source/api/image.rst b/docs/source/api/image.rst index 0877655db8..34d44164a8 100644 --- a/docs/source/api/image.rst +++ b/docs/source/api/image.rst @@ -31,8 +31,8 @@ ______________ classification.transforms.default_transforms classification.transforms.train_default_transforms -Detection -_________ +Object Detection +________________ .. 
autosummary:: :toctree: generated/ @@ -42,21 +42,37 @@ _________ ~detection.model.ObjectDetector ~detection.data.ObjectDetectionData - detection.data.COCODataSource + detection.data.FiftyOneParser detection.data.ObjectDetectionFiftyOneDataSource detection.data.ObjectDetectionPreprocess - detection.finetuning.ObjectDetectionFineTuning - detection.model.ObjectDetector detection.serialization.DetectionLabels detection.serialization.FiftyOneDetectionLabels +Keypoint Detection +__________________ + .. autosummary:: :toctree: generated/ :nosignatures: - :template: + :template: classtemplate.rst + + ~keypoint_detection.model.KeypointDetector + ~keypoint_detection.data.KeypointDetectionData + + keypoint_detection.data.KeypointDetectionPreprocess + +Instance Segmentation +_____________________ + +.. autosummary:: + :toctree: generated/ + :nosignatures: + :template: classtemplate.rst + + ~instance_segmentation.model.InstanceSegmentation + ~instance_segmentation.data.InstanceSegmentationData - detection.transforms.collate - detection.transforms.default_transforms + instance_segmentation.data.InstanceSegmentationPreprocess Embedding _________ diff --git a/docs/source/index.rst b/docs/source/index.rst index 05293b3d76..95c7e2933f 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -37,6 +37,8 @@ Lightning Flash reference/image_classification_multi_label reference/image_embedder reference/object_detection + reference/keypoint_detection + reference/instance_segmentation reference/semantic_segmentation reference/style_transfer reference/video_classification diff --git a/docs/source/reference/instance_segmentation.rst b/docs/source/reference/instance_segmentation.rst new file mode 100644 index 0000000000..75408dc3fa --- /dev/null +++ b/docs/source/reference/instance_segmentation.rst @@ -0,0 +1,31 @@ + +.. _instance_segmentation: + +##################### +Instance Segmentation +##################### + +******** +The Task +******** + +Instance segmentation is the task of segmenting objects images and determining their associated classes. + +The :class:`~flash.image.instance_segmentation.model.InstanceSegmentation` and :class:`~flash.image.instance_segmentation.data.InstanceSegmentationData` classes internally rely on `IceVision `_. + +------ + +******* +Example +******* + +Let's look at instance segmentation with `The Oxford-IIIT Pet Dataset `_ from `IceData `_. +Once we've downloaded the data, we can create the :class:`~flash.image.instance_segmentation.data.InstanceSegmentationData`. +We select a ``mask_rcnn`` with a ``resnet18_fpn`` backbone to use for our :class:`~flash.image.instance_segmentation.model.InstanceSegmentation` and fine-tune on the pets data. +We then use the trained :class:`~flash.image.instance_segmentation.model.InstanceSegmentation` for inference. +Finally, we save the model. +Here's the full example: + +.. literalinclude:: ../../../flash_examples/instance_segmentation.py + :language: python + :lines: 14- diff --git a/docs/source/reference/keypoint_detection.rst b/docs/source/reference/keypoint_detection.rst new file mode 100644 index 0000000000..76fd0dcdf5 --- /dev/null +++ b/docs/source/reference/keypoint_detection.rst @@ -0,0 +1,31 @@ + +.. _keypoint_detection: + +################## +Keypoint Detection +################## + +******** +The Task +******** + +Keypoint detection is the task of identifying keypoints in images and their associated classes. 
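Both new reference pages point to example scripts under flash_examples/ that are not reproduced in this diff. A rough, hypothetical sketch of the fine-tuning workflow they describe, using the instance segmentation task: the "mask_rcnn" head and "resnet18_fpn" backbone are the names given in the page above, while the DataModule constructor and its arguments are assumptions rather than the actual example code.

    import flash
    from flash.image import InstanceSegmentation, InstanceSegmentationData

    # Hypothetical sketch only, not the flash_examples script referenced above.
    # The DataModule constructor and its arguments are assumptions; the real example
    # loads the Oxford-IIIT Pet data through IceData.
    datamodule = InstanceSegmentationData.from_folders(  # assumed constructor/arguments
        train_folder="data/pets",
        val_split=0.1,
    )

    # `head` selects the IceVision model type and `backbone` its backbone,
    # matching the names used in the page above.
    model = InstanceSegmentation(
        head="mask_rcnn",
        backbone="resnet18_fpn",
        num_classes=datamodule.num_classes,
    )

    # Fine-tune with a frozen backbone, then save the checkpoint.
    trainer = flash.Trainer(max_epochs=1)
    trainer.finetune(model, datamodule=datamodule, strategy="freeze")
    trainer.save_checkpoint("instance_segmentation_model.pt")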
+ +The :class:`~flash.image.keypoint_detection.model.KeypointDetector` and :class:`~flash.image.keypoint_detection.data.KeypointDetectionData` classes internally rely on `IceVision `_. + +------ + +******* +Example +******* + +Let's look at keypoint detection with `BIWI Sample Keypoints (center of face) `_ from `IceData `_. +Once we've downloaded the data, we can create the :class:`~flash.image.keypoint_detection.data.KeypointDetectionData`. +We select a ``keypoint_rcnn`` with a ``resnet18_fpn`` backbone to use for our :class:`~flash.image.keypoint_detection.model.KeypointDetector` and fine-tune on the BIWI data. +We then use the trained :class:`~flash.image.keypoint_detection.model.KeypointDetector` for inference. +Finally, we save the model. +Here's the full example: + +.. literalinclude:: ../../../flash_examples/keypoint_detection.py + :language: python + :lines: 14- diff --git a/docs/source/reference/object_detection.rst b/docs/source/reference/object_detection.rst index d0e2baf74d..0bf34c07c3 100644 --- a/docs/source/reference/object_detection.rst +++ b/docs/source/reference/object_detection.rst @@ -11,6 +11,8 @@ The Task Object detection is the task of identifying objects in images and their associated classes and bounding boxes. +The :class:`~flash.image.detection.model.ObjectDetector` and :class:`~flash.image.detection.data.ObjectDetectionData` classes internally rely on `IceVision `_. + ------ ******* diff --git a/flash/__about__.py b/flash/__about__.py index e57715c058..eab8629bc9 100644 --- a/flash/__about__.py +++ b/flash/__about__.py @@ -1,4 +1,4 @@ -__version__ = "0.4.1dev" +__version__ = "0.5.0dev" __author__ = "PyTorchLightning et al." __author_email__ = "name@pytorchlightning.ai" __license__ = "Apache-2.0" diff --git a/flash/core/adapter.py b/flash/core/adapter.py new file mode 100644 index 0000000000..c7557b1977 --- /dev/null +++ b/flash/core/adapter.py @@ -0,0 +1,162 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from abc import abstractmethod +from typing import Any, Callable, Optional + +from torch import nn +from torch.utils.data import DataLoader, Sampler + +import flash +from flash.core.data.auto_dataset import BaseAutoDataset +from flash.core.model import DatasetProcessor, ModuleWrapperBase, Task + + +class Adapter(DatasetProcessor, ModuleWrapperBase, nn.Module): + """The ``Adapter`` is a lightweight interface that can be used to encapsulate the logic from a particular + provider within a :class:`~flash.core.model.Task`.""" + + @classmethod + @abstractmethod + def from_task(cls, task: "flash.Task", **kwargs) -> "Adapter": + """Instantiate the adapter from the given :class:`~flash.core.model.Task`. + + This includes resolution / creation of backbones / heads and any other provider specific options. 
+ """ + + def forward(self, x: Any) -> Any: + pass + + def training_step(self, batch: Any, batch_idx: int) -> Any: + pass + + def validation_step(self, batch: Any, batch_idx: int) -> None: + pass + + def test_step(self, batch: Any, batch_idx: int) -> None: + pass + + def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> Any: + pass + + def training_epoch_end(self, outputs) -> None: + pass + + def validation_epoch_end(self, outputs) -> None: + pass + + def test_epoch_end(self, outputs) -> None: + pass + + +class AdapterTask(Task): + """The ``AdapterTask`` is a :class:`~flash.core.model.Task` which wraps an :class:`~flash.core.adapter.Adapter` + and forwards all of the hooks. + + Args: + adapter: The :class:`~flash.core.adapter.Adapter` to wrap. + kwargs: Keyword arguments to be passed to the base :class:`~flash.core.model.Task`. + """ + + def __init__(self, adapter: Adapter, **kwargs): + super().__init__(**kwargs) + + self.adapter = adapter + + @property + def backbone(self) -> nn.Module: + return self.adapter.backbone + + def forward(self, x: Any) -> Any: + return self.adapter.forward(x) + + def training_step(self, batch: Any, batch_idx: int) -> Any: + return self.adapter.training_step(batch, batch_idx) + + def validation_step(self, batch: Any, batch_idx: int) -> None: + return self.adapter.validation_step(batch, batch_idx) + + def test_step(self, batch: Any, batch_idx: int) -> None: + return self.adapter.test_step(batch, batch_idx) + + def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> Any: + return self.adapter.predict_step(batch, batch_idx, dataloader_idx=dataloader_idx) + + def training_epoch_end(self, outputs) -> None: + return self.adapter.training_epoch_end(outputs) + + def validation_epoch_end(self, outputs) -> None: + return self.adapter.validation_epoch_end(outputs) + + def test_epoch_end(self, outputs) -> None: + return self.adapter.test_epoch_end(outputs) + + def process_train_dataset( + self, + dataset: BaseAutoDataset, + batch_size: int, + num_workers: int, + pin_memory: bool, + collate_fn: Callable, + shuffle: bool = False, + drop_last: bool = True, + sampler: Optional[Sampler] = None, + ) -> DataLoader: + return self.adapter.process_train_dataset( + dataset, batch_size, num_workers, pin_memory, collate_fn, shuffle, drop_last, sampler + ) + + def process_val_dataset( + self, + dataset: BaseAutoDataset, + batch_size: int, + num_workers: int, + pin_memory: bool, + collate_fn: Callable, + shuffle: bool = False, + drop_last: bool = False, + sampler: Optional[Sampler] = None, + ) -> DataLoader: + return self.adapter.process_val_dataset( + dataset, batch_size, num_workers, pin_memory, collate_fn, shuffle, drop_last, sampler + ) + + def process_test_dataset( + self, + dataset: BaseAutoDataset, + batch_size: int, + num_workers: int, + pin_memory: bool, + collate_fn: Callable, + shuffle: bool = False, + drop_last: bool = False, + sampler: Optional[Sampler] = None, + ) -> DataLoader: + return self.adapter.process_test_dataset( + dataset, batch_size, num_workers, pin_memory, collate_fn, shuffle, drop_last, sampler + ) + + def process_predict_dataset( + self, + dataset: BaseAutoDataset, + batch_size: int = 1, + num_workers: int = 0, + pin_memory: bool = False, + collate_fn: Callable = lambda x: x, + shuffle: bool = False, + drop_last: bool = True, + sampler: Optional[Sampler] = None, + ) -> DataLoader: + return self.adapter.process_predict_dataset( + dataset, batch_size, num_workers, pin_memory, collate_fn, shuffle, drop_last, sampler 
+ ) diff --git a/flash/core/data/data_module.py b/flash/core/data/data_module.py index 02ef13e86e..d1ebac04a8 100644 --- a/flash/core/data/data_module.py +++ b/flash/core/data/data_module.py @@ -377,7 +377,7 @@ def _predict_dataloader(self) -> DataLoader: pin_memory = True if isinstance(getattr(self, "trainer", None), pl.Trainer): - return self.trainer.lightning_module.process_test_dataset( + return self.trainer.lightning_module.process_predict_dataset( predict_ds, batch_size=batch_size, num_workers=self.num_workers, diff --git a/flash/core/data/data_pipeline.py b/flash/core/data/data_pipeline.py index 4c707ef8c2..d00618ff05 100644 --- a/flash/core/data/data_pipeline.py +++ b/flash/core/data/data_pipeline.py @@ -164,8 +164,10 @@ def _identity(samples: Sequence[Any]) -> Sequence[Any]: def deserialize_processor(self) -> _DeserializeProcessor: return self._create_collate_preprocessors(RunningStage.PREDICTING)[0] - def worker_preprocessor(self, running_stage: RunningStage, is_serving: bool = False) -> _Preprocessor: - return self._create_collate_preprocessors(running_stage, is_serving=is_serving)[1] + def worker_preprocessor( + self, running_stage: RunningStage, collate_fn: Optional[Callable] = None, is_serving: bool = False + ) -> _Preprocessor: + return self._create_collate_preprocessors(running_stage, collate_fn=collate_fn, is_serving=is_serving)[1] def device_preprocessor(self, running_stage: RunningStage) -> _Preprocessor: return self._create_collate_preprocessors(running_stage)[2] diff --git a/flash/core/integrations/icevision/__init__.py b/flash/core/integrations/icevision/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/flash/core/integrations/icevision/adapter.py b/flash/core/integrations/icevision/adapter.py new file mode 100644 index 0000000000..af95da9a52 --- /dev/null +++ b/flash/core/integrations/icevision/adapter.py @@ -0,0 +1,202 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+import functools +from typing import Any, Callable, Dict, List, Optional + +from torch.utils.data import DataLoader, Sampler + +from flash.core.adapter import Adapter +from flash.core.data.auto_dataset import BaseAutoDataset +from flash.core.data.data_source import DefaultDataKeys +from flash.core.integrations.icevision.transforms import to_icevision_record +from flash.core.model import Task +from flash.core.utilities.imports import _ICEVISION_AVAILABLE +from flash.core.utilities.url_error import catch_url_error + +if _ICEVISION_AVAILABLE: + from icevision.metrics import COCOMetric + from icevision.metrics import Metric as IceVisionMetric +else: + COCOMetric = object + + +class SimpleCOCOMetric(COCOMetric): + def finalize(self) -> Dict[str, float]: + logs = super().finalize() + return { + "Precision (IoU=0.50:0.95,area=all)": logs["AP (IoU=0.50:0.95) area=all"], + "Recall (IoU=0.50:0.95,area=all,maxDets=100)": logs["AR (IoU=0.50:0.95) area=all maxDets=100"], + } + + +class IceVisionAdapter(Adapter): + """The ``IceVisionAdapter`` is an :class:`~flash.core.adapter.Adapter` for integrating with IceVision.""" + + required_extras: str = "image" + + def __init__(self, model_type, model, icevision_adapter, backbone): + super().__init__() + + self.model_type = model_type + self.model = model + self.icevision_adapter = icevision_adapter + self.backbone = backbone + + @classmethod + @catch_url_error + def from_task( + cls, + task: Task, + num_classes: int, + backbone: str, + head: str, + pretrained: bool = True, + metrics: Optional["IceVisionMetric"] = None, + image_size: Optional = None, + **kwargs, + ) -> Adapter: + metadata = task.heads.get(head, with_metadata=True) + backbones = metadata["metadata"]["backbones"] + backbone_config = backbones.get(backbone)(pretrained) + model_type, model, icevision_adapter, backbone = metadata["fn"]( + backbone_config, + num_classes, + image_size=image_size, + **kwargs, + ) + icevision_adapter = icevision_adapter(model=model, metrics=metrics) + return cls(model_type, model, icevision_adapter, backbone) + + @staticmethod + def _collate_fn(collate_fn, samples, metadata: Optional[List[Dict[str, Any]]] = None): + metadata = metadata or [None] * len(samples) + return collate_fn( + [to_icevision_record({**sample, DefaultDataKeys.METADATA: m}) for sample, m in zip(samples, metadata)] + ) + + def process_train_dataset( + self, + dataset: BaseAutoDataset, + batch_size: int, + num_workers: int, + pin_memory: bool, + collate_fn: Optional[Callable] = None, + shuffle: bool = False, + drop_last: bool = False, + sampler: Optional[Sampler] = None, + ) -> DataLoader: + data_loader = self.model_type.train_dl( + dataset, + batch_size=batch_size, + num_workers=num_workers, + pin_memory=pin_memory, + shuffle=shuffle, + drop_last=drop_last, + sampler=sampler, + ) + data_loader.collate_fn = functools.partial(self._collate_fn, data_loader.collate_fn) + return data_loader + + def process_val_dataset( + self, + dataset: BaseAutoDataset, + batch_size: int, + num_workers: int, + pin_memory: bool, + collate_fn: Optional[Callable] = None, + shuffle: bool = False, + drop_last: bool = False, + sampler: Optional[Sampler] = None, + ) -> DataLoader: + data_loader = self.model_type.valid_dl( + dataset, + batch_size=batch_size, + num_workers=num_workers, + pin_memory=pin_memory, + shuffle=shuffle, + drop_last=drop_last, + sampler=sampler, + ) + data_loader.collate_fn = functools.partial(self._collate_fn, data_loader.collate_fn) + return data_loader + + def process_test_dataset( + self, + dataset: 
BaseAutoDataset, + batch_size: int, + num_workers: int, + pin_memory: bool, + collate_fn: Optional[Callable] = None, + shuffle: bool = False, + drop_last: bool = False, + sampler: Optional[Sampler] = None, + ) -> DataLoader: + data_loader = self.model_type.valid_dl( + dataset, + batch_size=batch_size, + num_workers=num_workers, + pin_memory=pin_memory, + shuffle=shuffle, + drop_last=drop_last, + sampler=sampler, + ) + data_loader.collate_fn = functools.partial(self._collate_fn, data_loader.collate_fn) + return data_loader + + def process_predict_dataset( + self, + dataset: BaseAutoDataset, + batch_size: int = 1, + num_workers: int = 0, + pin_memory: bool = False, + collate_fn: Callable = lambda x: x, + shuffle: bool = False, + drop_last: bool = True, + sampler: Optional[Sampler] = None, + ) -> DataLoader: + data_loader = self.model_type.infer_dl( + dataset, + batch_size=batch_size, + num_workers=num_workers, + pin_memory=pin_memory, + shuffle=shuffle, + drop_last=drop_last, + sampler=sampler, + ) + data_loader.collate_fn = functools.partial(self._collate_fn, data_loader.collate_fn) + return data_loader + + def training_step(self, batch, batch_idx) -> Any: + return self.icevision_adapter.training_step(batch, batch_idx) + + def validation_step(self, batch, batch_idx): + return self.icevision_adapter.validation_step(batch, batch_idx) + + def test_step(self, batch, batch_idx): + return self.icevision_adapter.validation_step(batch, batch_idx) + + def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> Any: + return self(batch) + + def forward(self, batch: Any) -> Any: + return self.model_type.predict_from_dl(self.model, [batch], show_pbar=False) + + def training_epoch_end(self, outputs) -> None: + return self.icevision_adapter.training_epoch_end(outputs) + + def validation_epoch_end(self, outputs) -> None: + return self.icevision_adapter.validation_epoch_end(outputs) + + def test_epoch_end(self, outputs) -> None: + return self.icevision_adapter.validation_epoch_end(outputs) diff --git a/flash/core/integrations/icevision/backbones.py b/flash/core/integrations/icevision/backbones.py new file mode 100644 index 0000000000..dd30d3be56 --- /dev/null +++ b/flash/core/integrations/icevision/backbones.py @@ -0,0 +1,63 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+from inspect import getmembers + +from torch import nn + +from flash.core.registry import FlashRegistry +from flash.core.utilities.imports import _ICEVISION_AVAILABLE + +if _ICEVISION_AVAILABLE: + from icevision.backbones import BackboneConfig + + +def icevision_model_adapter(model_type): + class IceVisionModelAdapter(model_type.lightning.ModelAdapter): + def log(self, name, value, **kwargs): + if "prog_bar" not in kwargs: + kwargs["prog_bar"] = True + return super().log(name.split("/")[-1], value, **kwargs) + + return IceVisionModelAdapter + + +def load_icevision(adapter, model_type, backbone, num_classes, **kwargs): + model = model_type.model(backbone=backbone, num_classes=num_classes, **kwargs) + + backbone = nn.Module() + params = model.param_groups()[0] + for i, param in enumerate(params): + backbone.register_parameter(f"backbone_{i}", param) + + return model_type, model, adapter(model_type), backbone + + +def load_icevision_ignore_image_size(adapter, model_type, backbone, num_classes, image_size=None, **kwargs): + return load_icevision(adapter, model_type, backbone, num_classes, **kwargs) + + +def load_icevision_with_image_size(adapter, model_type, backbone, num_classes, image_size=None, **kwargs): + kwargs["img_size"] = image_size + return load_icevision(adapter, model_type, backbone, num_classes, **kwargs) + + +def get_backbones(model_type): + _BACKBONES = FlashRegistry("backbones") + + for backbone_name, backbone_config in getmembers(model_type.backbones, lambda x: isinstance(x, BackboneConfig)): + _BACKBONES( + backbone_config, + name=backbone_name, + ) + return _BACKBONES diff --git a/flash/core/integrations/icevision/data.py b/flash/core/integrations/icevision/data.py new file mode 100644 index 0000000000..80ce622616 --- /dev/null +++ b/flash/core/integrations/icevision/data.py @@ -0,0 +1,79 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+from typing import Any, Callable, Dict, Optional, Sequence, Tuple, Type + +import numpy as np + +from flash.core.data.data_source import DefaultDataKeys +from flash.core.integrations.icevision.transforms import from_icevision_record +from flash.core.utilities.imports import _ICEVISION_AVAILABLE +from flash.image.data import ImagePathsDataSource + +if _ICEVISION_AVAILABLE: + from icevision.core.record import BaseRecord + from icevision.core.record_components import ClassMapRecordComponent, ImageRecordComponent, tasks + from icevision.data.data_splitter import SingleSplitSplitter + from icevision.parsers.parser import Parser + + +class IceVisionPathsDataSource(ImagePathsDataSource): + def predict_load_data(self, data: Tuple[str, str], dataset: Optional[Any] = None) -> Sequence[Dict[str, Any]]: + return super().predict_load_data(data, dataset) + + def load_sample(self, sample: Dict[str, Any]) -> Dict[str, Any]: + record = sample[DefaultDataKeys.INPUT].load() + return from_icevision_record(record) + + def predict_load_sample(self, sample: Dict[str, Any]) -> Dict[str, Any]: + sample = super().load_sample(sample) + image = np.array(sample[DefaultDataKeys.INPUT]) + record = BaseRecord([ImageRecordComponent()]) + + record.set_img(image) + record.add_component(ClassMapRecordComponent(task=tasks.detection)) + return from_icevision_record(record) + + +class IceVisionParserDataSource(IceVisionPathsDataSource): + def __init__(self, parser: Optional[Type["Parser"]] = None): + super().__init__() + self.parser = parser + + def load_data(self, data: Tuple[str, str], dataset: Optional[Any] = None) -> Sequence[Dict[str, Any]]: + root, ann_file = data + + if self.parser is not None: + parser = self.parser(ann_file, root) + dataset.num_classes = len(parser.class_map) + records = parser.parse(data_splitter=SingleSplitSplitter()) + return [{DefaultDataKeys.INPUT: record} for record in records[0]] + else: + raise ValueError("The parser type must be provided") + + +class IceDataParserDataSource(IceVisionPathsDataSource): + def __init__(self, parser: Optional[Callable] = None): + super().__init__() + self.parser = parser + + def load_data(self, data: Tuple[str, str], dataset: Optional[Any] = None) -> Sequence[Dict[str, Any]]: + root = data + + if self.parser is not None: + parser = self.parser(root) + dataset.num_classes = len(parser.class_map) + records = parser.parse(data_splitter=SingleSplitSplitter()) + return [{DefaultDataKeys.INPUT: record} for record in records[0]] + else: + raise ValueError("The parser must be provided") diff --git a/flash/core/integrations/icevision/transforms.py b/flash/core/integrations/icevision/transforms.py new file mode 100644 index 0000000000..c5a5968160 --- /dev/null +++ b/flash/core/integrations/icevision/transforms.py @@ -0,0 +1,198 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+from typing import Any, Callable, Dict, Tuple + +from torch import nn + +from flash.core.data.data_source import DefaultDataKeys +from flash.core.utilities.imports import _ICEVISION_AVAILABLE, requires_extras + +if _ICEVISION_AVAILABLE: + from icevision.core import tasks + from icevision.core.bbox import BBox + from icevision.core.keypoints import KeyPoints + from icevision.core.mask import EncodedRLEs, MaskArray + from icevision.core.record import BaseRecord + from icevision.core.record_components import ( + BBoxesRecordComponent, + ClassMapRecordComponent, + FilepathRecordComponent, + ImageRecordComponent, + InstancesLabelsRecordComponent, + KeyPointsRecordComponent, + MasksRecordComponent, + RecordIDRecordComponent, + ) + from icevision.tfms import A + + +def to_icevision_record(sample: Dict[str, Any]): + record = BaseRecord([]) + + metadata = sample.get(DefaultDataKeys.METADATA, None) or {} + + if "image_id" in metadata: + record_id_component = RecordIDRecordComponent() + record_id_component.set_record_id(metadata["image_id"]) + + component = ClassMapRecordComponent(tasks.detection) + component.set_class_map(metadata.get("class_map", None)) + record.add_component(component) + + if "labels" in sample[DefaultDataKeys.TARGET]: + labels_component = InstancesLabelsRecordComponent() + labels_component.add_labels_by_id(sample[DefaultDataKeys.TARGET]["labels"]) + record.add_component(labels_component) + + if "bboxes" in sample[DefaultDataKeys.TARGET]: + bboxes = [ + BBox.from_xywh(bbox["xmin"], bbox["ymin"], bbox["width"], bbox["height"]) + for bbox in sample[DefaultDataKeys.TARGET]["bboxes"] + ] + component = BBoxesRecordComponent() + component.set_bboxes(bboxes) + record.add_component(component) + + if "masks" in sample[DefaultDataKeys.TARGET]: + mask_array = MaskArray(sample[DefaultDataKeys.TARGET]["masks"]) + component = MasksRecordComponent() + component.set_masks(mask_array) + record.add_component(component) + + if "keypoints" in sample[DefaultDataKeys.TARGET]: + keypoints = [] + + for keypoints_list, keypoints_metadata in zip( + sample[DefaultDataKeys.TARGET]["keypoints"], sample[DefaultDataKeys.TARGET]["keypoints_metadata"] + ): + xyv = [] + for keypoint in keypoints_list: + xyv.extend((keypoint["x"], keypoint["y"], keypoint["visible"])) + + keypoints.append(KeyPoints.from_xyv(xyv, keypoints_metadata)) + component = KeyPointsRecordComponent() + component.set_keypoints(keypoints) + record.add_component(component) + + if isinstance(sample[DefaultDataKeys.INPUT], str): + input_component = FilepathRecordComponent() + input_component.set_filepath(sample[DefaultDataKeys.INPUT]) + else: + if "filepath" in metadata: + input_component = FilepathRecordComponent() + input_component.filepath = metadata["filepath"] + else: + input_component = ImageRecordComponent() + input_component.composite = record + input_component.set_img(sample[DefaultDataKeys.INPUT]) + record.add_component(input_component) + + return record + + +def from_icevision_record(record: "BaseRecord"): + sample = { + DefaultDataKeys.METADATA: { + "image_id": record.record_id, + } + } + + if record.img is not None: + sample[DefaultDataKeys.INPUT] = record.img + filepath = getattr(record, "filepath", None) + if filepath is not None: + sample[DefaultDataKeys.METADATA]["filepath"] = filepath + elif record.filepath is not None: + sample[DefaultDataKeys.INPUT] = record.filepath + + sample[DefaultDataKeys.TARGET] = {} + + if hasattr(record.detection, "bboxes"): + sample[DefaultDataKeys.TARGET]["bboxes"] = [] + for bbox in 
record.detection.bboxes: + bbox_list = list(bbox.xywh) + bbox_dict = { + "xmin": bbox_list[0], + "ymin": bbox_list[1], + "width": bbox_list[2], + "height": bbox_list[3], + } + sample[DefaultDataKeys.TARGET]["bboxes"].append(bbox_dict) + + if hasattr(record.detection, "masks"): + masks = record.detection.masks + + if isinstance(masks, EncodedRLEs): + masks = masks.to_mask(record.height, record.width) + + if isinstance(masks, MaskArray): + sample[DefaultDataKeys.TARGET]["masks"] = masks.data + else: + raise RuntimeError("Masks are expected to be a MaskArray or EncodedRLEs.") + + if hasattr(record.detection, "keypoints"): + keypoints = record.detection.keypoints + + sample[DefaultDataKeys.TARGET]["keypoints"] = [] + sample[DefaultDataKeys.TARGET]["keypoints_metadata"] = [] + + for keypoint in keypoints: + keypoints_list = [] + for x, y, v in keypoint.xyv: + keypoints_list.append( + { + "x": x, + "y": y, + "visible": v, + } + ) + sample[DefaultDataKeys.TARGET]["keypoints"].append(keypoints_list) + + # TODO: Unpack keypoints_metadata + sample[DefaultDataKeys.TARGET]["keypoints_metadata"].append(keypoint.metadata) + + if getattr(record.detection, "label_ids", None) is not None: + sample[DefaultDataKeys.TARGET]["labels"] = list(record.detection.label_ids) + + if getattr(record.detection, "class_map", None) is not None: + sample[DefaultDataKeys.METADATA]["class_map"] = record.detection.class_map + + return sample + + +class IceVisionTransformAdapter(nn.Module): + def __init__(self, transform): + super().__init__() + self.transform = transform + + def forward(self, x): + record = to_icevision_record(x) + record = self.transform(record) + return from_icevision_record(record) + + +@requires_extras("image") +def default_transforms(image_size: Tuple[int, int]) -> Dict[str, Callable]: + """The default transforms from IceVision.""" + return { + "pre_tensor_transform": IceVisionTransformAdapter(A.Adapter([*A.resize_and_pad(image_size), A.Normalize()])), + } + + +@requires_extras("image") +def train_default_transforms(image_size: Tuple[int, int]) -> Dict[str, Callable]: + """The default augmentations from IceVision.""" + return { + "pre_tensor_transform": IceVisionTransformAdapter(A.Adapter([*A.aug_tfms(size=image_size), A.Normalize()])), + } diff --git a/flash/core/model.py b/flash/core/model.py index 059089b299..282a3130e0 100644 --- a/flash/core/model.py +++ b/flash/core/model.py @@ -13,6 +13,7 @@ # limitations under the License. import functools import inspect +import pickle from abc import ABCMeta from copy import deepcopy from importlib import import_module @@ -21,9 +22,10 @@ import pytorch_lightning as pl import torch import torchmetrics -from pytorch_lightning import LightningModule +from pytorch_lightning import LightningModule, Trainer from pytorch_lightning.callbacks import Callback from pytorch_lightning.trainer.states import RunningStage +from pytorch_lightning.utilities import rank_zero_warn from pytorch_lightning.utilities.exceptions import MisconfigurationException from torch import nn from torch.optim.lr_scheduler import _LRScheduler @@ -50,6 +52,173 @@ from flash.core.utilities.imports import requires_extras +class ModuleWrapperBase: + """The ``ModuleWrapperBase`` is a base for classes which wrap a ``LightningModule`` or an instance of + ``ModuleWrapperBase``. + + This class ensures that trainer attributes are forwarded to any wrapped or nested + ``LightningModule`` instances so that nested calls to ``.log`` are handled correctly. 
The ``ModuleWrapperBase`` is + also stateful, meaning that a :class:`~flash.core.data.data_pipeline.DataPipelineState` can be attached. Attached + state will be forwarded to any nested ``ModuleWrapperBase`` instances. + """ + + def __init__(self): + super().__init__() + + self._children = [] + + # TODO: create enum values to define what are the exact states + self._data_pipeline_state: Optional[DataPipelineState] = None + + # model own internal state shared with the data pipeline. + self._state: Dict[Type[ProcessState], ProcessState] = {} + + def __setattr__(self, key, value): + if isinstance(value, (LightningModule, ModuleWrapperBase)): + self._children.append(key) + patched_attributes = ["_current_fx_name", "_current_hook_fx_name", "_results", "_data_pipeline_state"] + if isinstance(value, Trainer) or key in patched_attributes: + if hasattr(self, "_children"): + for child in self._children: + setattr(getattr(self, child), key, value) + super().__setattr__(key, value) + + def get_state(self, state_type): + if state_type in self._state: + return self._state[state_type] + if self._data_pipeline_state is not None: + return self._data_pipeline_state.get_state(state_type) + return None + + def set_state(self, state: ProcessState): + self._state[type(state)] = state + if self._data_pipeline_state is not None: + self._data_pipeline_state.set_state(state) + + def attach_data_pipeline_state(self, data_pipeline_state: "DataPipelineState"): + for state in self._state.values(): + data_pipeline_state.set_state(state) + for child in self._children: + child = getattr(self, child) + if hasattr(child, "attach_data_pipeline_state"): + child.attach_data_pipeline_state(data_pipeline_state) + + +class DatasetProcessor: + """The ``DatasetProcessor`` mixin provides hooks for classes which need custom logic for producing the data + loaders for each running stage given the corresponding dataset.""" + + def _process_dataset( + self, + dataset: BaseAutoDataset, + batch_size: int, + num_workers: int, + pin_memory: bool, + collate_fn: Callable, + shuffle: bool = False, + drop_last: bool = True, + sampler: Optional[Sampler] = None, + ) -> DataLoader: + return DataLoader( + dataset, + batch_size=batch_size, + num_workers=num_workers, + pin_memory=pin_memory, + shuffle=shuffle, + drop_last=drop_last, + sampler=sampler, + collate_fn=collate_fn, + ) + + def process_train_dataset( + self, + dataset: BaseAutoDataset, + batch_size: int, + num_workers: int, + pin_memory: bool, + collate_fn: Callable, + shuffle: bool = False, + drop_last: bool = True, + sampler: Optional[Sampler] = None, + ) -> DataLoader: + return self._process_dataset( + dataset, + batch_size=batch_size, + num_workers=num_workers, + pin_memory=pin_memory, + collate_fn=collate_fn, + shuffle=shuffle, + drop_last=drop_last, + sampler=sampler, + ) + + def process_val_dataset( + self, + dataset: BaseAutoDataset, + batch_size: int, + num_workers: int, + pin_memory: bool, + collate_fn: Callable, + shuffle: bool = False, + drop_last: bool = False, + sampler: Optional[Sampler] = None, + ) -> DataLoader: + return self._process_dataset( + dataset, + batch_size=batch_size, + num_workers=num_workers, + pin_memory=pin_memory, + collate_fn=collate_fn, + shuffle=shuffle, + drop_last=drop_last, + sampler=sampler, + ) + + def process_test_dataset( + self, + dataset: BaseAutoDataset, + batch_size: int, + num_workers: int, + pin_memory: bool, + collate_fn: Callable, + shuffle: bool = False, + drop_last: bool = True, + sampler: Optional[Sampler] = None, + ) -> DataLoader: + 
return self._process_dataset( + dataset, + batch_size=batch_size, + num_workers=num_workers, + pin_memory=pin_memory, + collate_fn=collate_fn, + shuffle=shuffle, + drop_last=drop_last, + sampler=sampler, + ) + + def process_predict_dataset( + self, + dataset: BaseAutoDataset, + batch_size: int = 1, + num_workers: int = 0, + pin_memory: bool = False, + collate_fn: Callable = None, + shuffle: bool = False, + drop_last: bool = True, + sampler: Optional[Sampler] = None, + ) -> DataLoader: + return self._process_dataset( + dataset, + batch_size=batch_size, + num_workers=num_workers, + pin_memory=pin_memory, + collate_fn=collate_fn, + shuffle=shuffle, + drop_last=drop_last, + sampler=sampler, + ) + + class BenchmarkConvergenceCI(Callback): def __init__(self): self.history = [] @@ -98,7 +267,7 @@ def __new__(mcs, *args, **kwargs): return result -class Task(LightningModule, metaclass=CheckDependenciesMeta): +class Task(DatasetProcessor, ModuleWrapperBase, LightningModule, metaclass=CheckDependenciesMeta): """A general Task. Args: @@ -150,28 +319,10 @@ def __init__( self._postprocess: Optional[Postprocess] = postprocess self._serializer: Optional[Serializer] = None - # TODO: create enum values to define what are the exact states - self._data_pipeline_state: Optional[DataPipelineState] = None - - # model own internal state shared with the data pipeline. - self._state: Dict[Type[ProcessState], ProcessState] = {} - # Explicitly set the serializer to call the setter self.deserializer = deserializer self.serializer = serializer - self._children = [] - - def __setattr__(self, key, value): - if isinstance(value, LightningModule): - self._children.append(key) - patched_attributes = ["_current_fx_name", "_current_hook_fx_name", "_results"] - if isinstance(value, pl.Trainer) or key in patched_attributes: - if hasattr(self, "_children"): - for child in self._children: - setattr(getattr(self, child), key, value) - super().__setattr__(key, value) - def step(self, batch: Any, batch_idx: int, metrics: nn.ModuleDict) -> Any: """The training/validation/test step. @@ -262,8 +413,9 @@ def predict( data_pipeline = self.build_data_pipeline(data_source or "default", deserializer, data_pipeline) dataset = data_pipeline.data_source.generate_dataset(x, running_stage) - x = list(self.process_predict_dataset(dataset, convert_to_dataloader=False)) - x = data_pipeline.worker_preprocessor(running_stage)(x) + dataloader = self.process_predict_dataset(dataset) + x = list(dataloader.dataset) + x = data_pipeline.worker_preprocessor(running_stage, collate_fn=dataloader.collate_fn)(x) # todo (tchaton): Remove this when sync with Lightning master. 
if len(inspect.signature(self.transfer_batch_to_device).parameters) == 3: x = self.transfer_batch_to_device(x, self.device, 0) @@ -539,7 +691,11 @@ def on_save_checkpoint(self, checkpoint: Dict[str, Any]) -> None: # This may be an issue since here we create the same problems with pickle as in # https://pytorch.org/docs/stable/notes/serialization.html if self.data_pipeline is not None and "data_pipeline" not in checkpoint: - checkpoint["data_pipeline"] = self.data_pipeline + try: + pickle.dumps(self.data_pipeline) # TODO: DataPipeline not always pickleable + checkpoint["data_pipeline"] = self.data_pipeline + except AttributeError: + rank_zero_warn("DataPipeline couldn't be added to the checkpoint.") if self._data_pipeline_state is not None and "_data_pipeline_state" not in checkpoint: checkpoint["_data_pipeline_state"] = self._data_pipeline_state super().on_save_checkpoint(checkpoint) @@ -552,11 +708,27 @@ def on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None: self._data_pipeline_state = checkpoint["_data_pipeline_state"] @classmethod - def available_backbones(cls) -> List[str]: - registry: Optional[FlashRegistry] = getattr(cls, "backbones", None) - if registry is None: - return [] - return registry.available_keys() + def available_backbones(cls, head: Optional[str] = None) -> Union[Dict[str, List[str]], List[str]]: + if head is None: + registry: Optional[FlashRegistry] = getattr(cls, "backbones", None) + if registry is not None: + return registry.available_keys() + heads = cls.available_heads() + else: + heads = [head] + + result = {} + for head in heads: + metadata = cls.heads.get(head, with_metadata=True)["metadata"] + if "backbones" in metadata: + backbones = metadata["backbones"].available_keys() + else: + backbones = cls.available_backbones() + result[head] = backbones + + if len(result) == 1: + result = next(iter(result.values())) + return result @classmethod def available_heads(cls) -> List[str]: @@ -697,134 +869,3 @@ def serve(self, host: str = "127.0.0.1", port: int = 8000, sanity_check: bool = composition = Composition(predict=comp, TESTING=flash._IS_TESTING) composition.serve(host=host, port=port) return composition - - def get_state(self, state_type): - if state_type in self._state: - return self._state[state_type] - if self._data_pipeline_state is not None: - return self._data_pipeline_state.get_state(state_type) - return None - - def set_state(self, state: ProcessState): - self._state[type(state)] = state - if self._data_pipeline_state is not None: - self._data_pipeline_state.set_state(state) - - def attach_data_pipeline_state(self, data_pipeline_state: "DataPipelineState"): - for state in self._state.values(): - data_pipeline_state.set_state(state) - - def _process_dataset( - self, - dataset: BaseAutoDataset, - batch_size: int, - num_workers: int, - pin_memory: bool, - collate_fn: Callable, - shuffle: bool = False, - drop_last: bool = True, - sampler: Optional[Sampler] = None, - convert_to_dataloader: bool = True, - ) -> DataLoader: - if convert_to_dataloader: - return DataLoader( - dataset, - batch_size=batch_size, - num_workers=num_workers, - pin_memory=pin_memory, - shuffle=shuffle, - drop_last=drop_last, - collate_fn=collate_fn, - sampler=sampler, - ) - return dataset - - def process_train_dataset( - self, - dataset: BaseAutoDataset, - batch_size: int, - num_workers: int, - pin_memory: bool, - collate_fn: Callable, - shuffle: bool = False, - drop_last: bool = True, - sampler: Optional[Sampler] = None, - ) -> DataLoader: - return self._process_dataset( - 
dataset, - batch_size=batch_size, - num_workers=num_workers, - pin_memory=pin_memory, - collate_fn=collate_fn, - shuffle=shuffle, - drop_last=drop_last, - sampler=sampler, - ) - - def process_val_dataset( - self, - dataset: BaseAutoDataset, - batch_size: int, - num_workers: int, - pin_memory: bool, - collate_fn: Callable, - shuffle: bool = False, - drop_last: bool = False, - sampler: Optional[Sampler] = None, - ) -> DataLoader: - return self._process_dataset( - dataset, - batch_size=batch_size, - num_workers=num_workers, - pin_memory=pin_memory, - collate_fn=collate_fn, - shuffle=shuffle, - drop_last=drop_last, - sampler=sampler, - ) - - def process_test_dataset( - self, - dataset: BaseAutoDataset, - batch_size: int, - num_workers: int, - pin_memory: bool, - collate_fn: Callable, - shuffle: bool = False, - drop_last: bool = False, - sampler: Optional[Sampler] = None, - ) -> DataLoader: - return self._process_dataset( - dataset, - batch_size=batch_size, - num_workers=num_workers, - pin_memory=pin_memory, - collate_fn=collate_fn, - shuffle=shuffle, - drop_last=drop_last, - sampler=sampler, - ) - - def process_predict_dataset( - self, - dataset: BaseAutoDataset, - batch_size: int = 1, - num_workers: int = 0, - pin_memory: bool = False, - collate_fn: Callable = lambda x: x, - shuffle: bool = False, - drop_last: bool = False, - sampler: Optional[Sampler] = None, - convert_to_dataloader: bool = True, - ) -> Union[DataLoader, BaseAutoDataset]: - return self._process_dataset( - dataset, - batch_size=batch_size, - num_workers=num_workers, - pin_memory=pin_memory, - collate_fn=collate_fn, - shuffle=shuffle, - drop_last=drop_last, - sampler=sampler, - convert_to_dataloader=convert_to_dataloader, - ) diff --git a/flash/core/registry.py b/flash/core/registry.py index e35e3e3379..1f97f2a664 100644 --- a/flash/core/registry.py +++ b/flash/core/registry.py @@ -11,8 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from functools import partial -from types import FunctionType +import functools +from dataclasses import dataclass from typing import Any, Callable, Dict, List, Optional, Union from pytorch_lightning.utilities import rank_zero_info @@ -21,6 +21,33 @@ _REGISTERED_FUNCTION = Dict[str, Any] +@dataclass +class Provider: + + name: str + url: str + + def __str__(self): + return f"{self.name} ({self.url})" + + +def print_provider_info(name, providers, func): + if not isinstance(providers, List): + providers = [providers] + providers = list(providers) + if len(providers) > 1: + providers[-2] = f"{str(providers[-2])} and {str(providers[-1])}" + providers = providers[:-1] + message = f"Using '{name}' provided by {', '.join(str(provider) for provider in providers)}." 
+ + @functools.wraps(func) + def wrapper(*args, **kwargs): + rank_zero_info(message) + return func(*args, **kwargs) + + return wrapper + + class FlashRegistry: """This class is used to register function or :class:`functools.partial` class to a registry.""" @@ -75,14 +102,18 @@ def _register_function( override: bool = False, metadata: Optional[Dict[str, Any]] = None, ): - if not isinstance(fn, FunctionType) and not isinstance(fn, partial): - raise MisconfigurationException(f"You can only register a function, found: {fn}") + if not callable(fn): + raise MisconfigurationException(f"You can only register a callable, found: {fn}") name = name or fn.__name__ if self._verbose: rank_zero_info(f"Registering: {fn.__name__} function with name: {name} and metadata: {metadata}") + if "providers" in metadata: + providers = metadata["providers"] + fn = print_provider_info(name, providers, fn) + item = {"fn": fn, "name": name, "metadata": metadata or {}} matching_index = self._find_matching_index(item) @@ -102,12 +133,20 @@ def _find_matching_index(self, item: _REGISTERED_FUNCTION) -> Optional[int]: return idx def __call__( - self, fn: Optional[Callable[..., Any]] = None, name: Optional[str] = None, override: bool = False, **metadata + self, + fn: Optional[Callable[..., Any]] = None, + name: Optional[str] = None, + override: bool = False, + providers: Optional[Union[Provider, List[Provider]]] = None, + **metadata, ) -> Callable: """This function is used to register new functions to the registry along their metadata. Functions can be filtered using metadata using the ``get`` function. """ + if providers is not None: + metadata["providers"] = providers + if fn is not None: self._register_function(fn=fn, name=name, override=override, metadata=metadata) return fn diff --git a/flash/core/serve/core.py b/flash/core/serve/core.py index e05717212a..563c0d580e 100644 --- a/flash/core/serve/core.py +++ b/flash/core/serve/core.py @@ -83,7 +83,7 @@ def __call__(self, *args, **kwargs): class Servable: - """Wrapper around a model object to enable serving at scale. + """ModuleWrapperBase around a model object to enable serving at scale. Create a ``Servable`` from either (LM, LOCATION) or (LOCATION,) diff --git a/flash/core/utilities/imports.py b/flash/core/utilities/imports.py index 9c542ecb23..1a4837c68b 100644 --- a/flash/core/utilities/imports.py +++ b/flash/core/utilities/imports.py @@ -95,6 +95,7 @@ def _compare_version(package: str, op, version) -> bool: _ROUGE_SCORE_AVAILABLE = _module_available("rouge_score") _SENTENCEPIECE_AVAILABLE = _module_available("sentencepiece") _DATASETS_AVAILABLE = _module_available("datasets") +_ICEVISION_AVAILABLE = _module_available("icevision") if Version: _TORCHVISION_GREATER_EQUAL_0_9 = _compare_version("torchvision", operator.ge, "0.9.0") @@ -117,6 +118,7 @@ def _compare_version(package: str, op, version) -> bool: _KORNIA_AVAILABLE, _PYSTICHE_AVAILABLE, _SEGMENTATION_MODELS_AVAILABLE, + _ICEVISION_AVAILABLE, ] ) _SERVE_AVAILABLE = _FASTAPI_AVAILABLE and _PYDANTIC_AVAILABLE and _CYTOOLZ_AVAILABLE and _UVICORN_AVAILABLE @@ -171,6 +173,10 @@ def requires_extras(extras: Union[str, List]): ) +def example_requires(extras: Union[str, List[str]]): + return requires_extras(extras)(lambda: None)() + + def lazy_import(module_name, callback=None): """Returns a proxy module object that will lazily import the given module the first time it is used. 
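The providers metadata added to FlashRegistry above is what produces the "Using '<name>' provided by ..." messages: when a function is registered with providers=..., print_provider_info wraps it so that calling the registered function logs the attribution via rank_zero_info. A minimal sketch of that behaviour, assuming only what this diff adds (FlashRegistry, Provider, and the providers argument); the registry name, head name, and factory function below are hypothetical.

    from flash.core.registry import FlashRegistry, Provider

    MY_HEADS = FlashRegistry("heads")
    _EXAMPLE_PROVIDER = Provider("airctic/IceVision", "https://github.com/airctic/icevision")


    def dummy_head(num_classes: int):
        """Hypothetical head factory, used only to show the provider wrapping."""
        return num_classes


    # Registering with `providers=` stores the metadata and wraps `dummy_head`
    # with `print_provider_info`.
    MY_HEADS(dummy_head, name="dummy_head", providers=_EXAMPLE_PROVIDER)

    # `get` returns the wrapped function; calling it logs:
    # "Using 'dummy_head' provided by airctic/IceVision (https://github.com/airctic/icevision)."
    head_fn = MY_HEADS.get("dummy_head")
    head_fn(num_classes=2)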
diff --git a/flash/image/detection/finetuning.py b/flash/core/utilities/providers.py similarity index 54% rename from flash/image/detection/finetuning.py rename to flash/core/utilities/providers.py index 7294be86f4..ff464e690c 100644 --- a/flash/image/detection/finetuning.py +++ b/flash/core/utilities/providers.py @@ -11,17 +11,10 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import pytorch_lightning as pl +from flash.core.registry import Provider -from flash.core.finetuning import FlashBaseFinetuning - - -class ObjectDetectionFineTuning(FlashBaseFinetuning): - """Freezes the backbone during Detector training.""" - - def __init__(self, train_bn: bool = True) -> None: - super().__init__(train_bn=train_bn) - - def freeze_before_training(self, pl_module: pl.LightningModule) -> None: - model = pl_module.model - self.freeze(modules=model.backbone, train_bn=self.train_bn) +_ICEVISION = Provider("airctic/IceVision", "https://github.com/airctic/icevision") +_TORCHVISION = Provider("PyTorch/torchvision", "https://github.com/pytorch/vision") +_ULTRALYTICS = Provider("Ultralytics/YOLOV5", "https://github.com/ultralytics/yolov5") +_MMDET = Provider("OpenMMLab/MMDetection", "https://github.com/open-mmlab/mmdetection") +_EFFDET = Provider("rwightman/efficientdet-pytorch", "https://github.com/rwightman/efficientdet-pytorch") diff --git a/flash/core/utilities/url_error.py b/flash/core/utilities/url_error.py index 83559131c9..6f0d28676a 100644 --- a/flash/core/utilities/url_error.py +++ b/flash/core/utilities/url_error.py @@ -23,6 +23,9 @@ def wrapper(*args, pretrained=False, **kwargs): try: return fn(*args, pretrained=pretrained, **kwargs) except urllib.error.URLError: + # Hack for icevision/efficientdet to work without internet access + if "efficientdet" in kwargs.get("head", ""): + kwargs["pretrained_backbone"] = False result = fn(*args, pretrained=False, **kwargs) rank_zero_warn( "Failed to download pretrained weights for the selected backbone. The backbone has been created with" diff --git a/flash/image/__init__.py b/flash/image/__init__.py index 352cbaff8e..b3ac7f10b6 100644 --- a/flash/image/__init__.py +++ b/flash/image/__init__.py @@ -1,4 +1,3 @@ -from flash.image.backbones import OBJ_DETECTION_BACKBONES # noqa: F401 from flash.image.classification import ( # noqa: F401 ImageClassificationData, ImageClassificationPreprocess, @@ -7,6 +6,8 @@ from flash.image.classification.backbones import IMAGE_CLASSIFIER_BACKBONES # noqa: F401 from flash.image.detection import ObjectDetectionData, ObjectDetector # noqa: F401 from flash.image.embedding import ImageEmbedder # noqa: F401 +from flash.image.instance_segmentation import InstanceSegmentation, InstanceSegmentationData # noqa: F401 +from flash.image.keypoint_detection import KeypointDetectionData, KeypointDetector # noqa: F401 from flash.image.segmentation import ( # noqa: F401 SemanticSegmentation, SemanticSegmentationData, diff --git a/flash/image/backbones.py b/flash/image/backbones.py deleted file mode 100644 index 82bb8dc8a6..0000000000 --- a/flash/image/backbones.py +++ /dev/null @@ -1,47 +0,0 @@ -# Copyright The PyTorch Lightning team. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -from functools import partial -from typing import Tuple - -from torch import nn - -from flash.core.registry import FlashRegistry -from flash.core.utilities.imports import _TORCHVISION_AVAILABLE -from flash.core.utilities.url_error import catch_url_error - -if _TORCHVISION_AVAILABLE: - from torchvision.models.detection.backbone_utils import resnet_fpn_backbone - -RESNET_MODELS = ["resnet18", "resnet34", "resnet50", "resnet101", "resnet152", "resnext50_32x4d", "resnext101_32x8d"] - -OBJ_DETECTION_BACKBONES = FlashRegistry("backbones") - -if _TORCHVISION_AVAILABLE: - - def _fn_resnet_fpn( - model_name: str, - pretrained: bool = True, - trainable_layers: bool = True, - **kwargs, - ) -> Tuple[nn.Module, int]: - backbone = resnet_fpn_backbone(model_name, pretrained=pretrained, trainable_layers=trainable_layers, **kwargs) - return backbone, 256 - - for model_name in RESNET_MODELS: - OBJ_DETECTION_BACKBONES( - fn=catch_url_error(partial(_fn_resnet_fpn, model_name)), - name=model_name, - package="torchvision", - type="resnet-fpn", - ) diff --git a/flash/image/detection/backbones.py b/flash/image/detection/backbones.py new file mode 100644 index 0000000000..c3e9d5cfad --- /dev/null +++ b/flash/image/detection/backbones.py @@ -0,0 +1,122 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+from functools import partial +from typing import Optional + +import torch + +from flash.core.adapter import Adapter +from flash.core.integrations.icevision.adapter import IceVisionAdapter, SimpleCOCOMetric +from flash.core.integrations.icevision.backbones import ( + get_backbones, + icevision_model_adapter, + load_icevision_ignore_image_size, + load_icevision_with_image_size, +) +from flash.core.model import Task +from flash.core.registry import FlashRegistry +from flash.core.utilities.imports import _ICEVISION_AVAILABLE, _module_available, _TORCHVISION_AVAILABLE +from flash.core.utilities.providers import _EFFDET, _ICEVISION, _MMDET, _TORCHVISION, _ULTRALYTICS + +if _ICEVISION_AVAILABLE: + from icevision import models as icevision_models + from icevision.metrics import COCOMetricType + from icevision.metrics import Metric as IceVisionMetric + +OBJECT_DETECTION_HEADS = FlashRegistry("heads") + + +class IceVisionObjectDetectionAdapter(IceVisionAdapter): + @classmethod + def from_task( + cls, + task: Task, + num_classes: int, + backbone: str = "resnet18_fpn", + head: str = "retinanet", + pretrained: bool = True, + metrics: Optional["IceVisionMetric"] = None, + image_size: Optional = None, + **kwargs, + ) -> Adapter: + return super().from_task( + task, + num_classes=num_classes, + backbone=backbone, + head=head, + pretrained=pretrained, + metrics=metrics or [SimpleCOCOMetric(COCOMetricType.bbox)], + image_size=image_size, + **kwargs, + ) + + +if _ICEVISION_AVAILABLE: + if _TORCHVISION_AVAILABLE: + for model_type in [icevision_models.torchvision.retinanet, icevision_models.torchvision.faster_rcnn]: + OBJECT_DETECTION_HEADS( + partial(load_icevision_ignore_image_size, icevision_model_adapter, model_type), + model_type.__name__.split(".")[-1], + backbones=get_backbones(model_type), + adapter=IceVisionObjectDetectionAdapter, + providers=[_ICEVISION, _TORCHVISION], + ) + + if _module_available("yolov5"): + model_type = icevision_models.ultralytics.yolov5 + OBJECT_DETECTION_HEADS( + partial(load_icevision_with_image_size, icevision_model_adapter, model_type), + model_type.__name__.split(".")[-1], + backbones=get_backbones(model_type), + adapter=IceVisionObjectDetectionAdapter, + providers=[_ICEVISION, _ULTRALYTICS], + ) + + if _module_available("mmdet"): + for model_type in [ + icevision_models.mmdet.faster_rcnn, + icevision_models.mmdet.retinanet, + icevision_models.mmdet.fcos, + icevision_models.mmdet.sparse_rcnn, + ]: + OBJECT_DETECTION_HEADS( + partial(load_icevision_ignore_image_size, icevision_model_adapter, model_type), + f"mmdet_{model_type.__name__.split('.')[-1]}", + backbones=get_backbones(model_type), + adapter=IceVisionObjectDetectionAdapter, + providers=[_ICEVISION, _MMDET], + ) + + if _module_available("effdet"): + + def _icevision_effdet_model_adapter(model_type): + class IceVisionEffdetModelAdapter(icevision_model_adapter(model_type)): + def validation_step(self, batch, batch_idx): + images = batch[0][0] + batch[0][1]["img_scale"] = torch.ones_like(images[:, 0, 0, 0]).unsqueeze(1) + batch[0][1]["img_size"] = ( + (torch.ones_like(images[:, 0, 0, 0]) * images[0].shape[-1]).unsqueeze(1).repeat(1, 2) + ) + return super().validation_step(batch, batch_idx) + + return IceVisionEffdetModelAdapter + + model_type = icevision_models.ross.efficientdet + OBJECT_DETECTION_HEADS( + partial(load_icevision_with_image_size, _icevision_effdet_model_adapter, model_type), + model_type.__name__.split(".")[-1], + backbones=get_backbones(model_type), + adapter=IceVisionObjectDetectionAdapter, + 
providers=[_ICEVISION, _EFFDET], + ) diff --git a/flash/image/detection/data.py b/flash/image/detection/data.py index d19ec4f2e3..d75ff23430 100644 --- a/flash/image/detection/data.py +++ b/flash/image/detection/data.py @@ -11,25 +11,19 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import os -from typing import Any, Callable, Dict, Optional, Sequence, Tuple, TYPE_CHECKING +from typing import Any, Callable, Dict, Hashable, Optional, Sequence, Tuple, TYPE_CHECKING from flash.core.data.callback import BaseDataFetcher from flash.core.data.data_module import DataModule -from flash.core.data.data_source import DataSource, DefaultDataKeys, DefaultDataSources, FiftyOneDataSource +from flash.core.data.data_source import DefaultDataKeys, DefaultDataSources, FiftyOneDataSource from flash.core.data.process import Preprocess -from flash.core.utilities.imports import ( - _COCO_AVAILABLE, - _FIFTYONE_AVAILABLE, - _TORCHVISION_AVAILABLE, - lazy_import, - requires, +from flash.core.integrations.icevision.data import ( + IceDataParserDataSource, + IceVisionParserDataSource, + IceVisionPathsDataSource, ) -from flash.image.data import ImagePathsDataSource -from flash.image.detection.transforms import default_transforms - -if _COCO_AVAILABLE: - from pycocotools.coco import COCO +from flash.core.integrations.icevision.transforms import default_transforms +from flash.core.utilities.imports import _FIFTYONE_AVAILABLE, _ICEVISION_AVAILABLE, lazy_import, requires SampleCollection = None if _FIFTYONE_AVAILABLE: @@ -39,159 +33,105 @@ else: foc, fol = None, None -if _TORCHVISION_AVAILABLE: - from torchvision.datasets.folder import default_loader +if _ICEVISION_AVAILABLE: + from icevision.core import BBox, ClassMap, IsCrowdsRecordComponent, ObjectDetectionRecord + from icevision.data import SingleSplitSplitter + from icevision.parsers import COCOBBoxParser, Parser, VIABBoxParser, VOCBBoxParser + from icevision.utils import ImgSize +else: + Parser = object -class COCODataSource(DataSource[Tuple[str, str]]): - @requires("pycocotools") - def load_data(self, data: Tuple[str, str], dataset: Optional[Any] = None) -> Sequence[Dict[str, Any]]: - root, ann_file = data +class FiftyOneParser(Parser): + def __init__(self, data, class_map, label_field, iscrowd): + template_record = ObjectDetectionRecord() + template_record.add_component(IsCrowdsRecordComponent()) + super().__init__(template_record=template_record) - coco = COCO(ann_file) + data = data + label_field = label_field + iscrowd = iscrowd - categories = coco.loadCats(coco.getCatIds()) - if categories: - dataset.num_classes = categories[-1]["id"] + 1 + self.data = [] + self.class_map = class_map - img_ids = list(sorted(coco.imgs.keys())) - paths = coco.loadImgs(img_ids) + for fp, w, h, sample_labs, sample_boxes, sample_iscrowd in zip( + data.values("filepath"), + data.values("metadata.width"), + data.values("metadata.height"), + data.values(label_field + ".detections.label"), + data.values(label_field + ".detections.bounding_box"), + data.values(label_field + ".detections." 
+ iscrowd), + ): + for lab, box, iscrowd in zip(sample_labs, sample_boxes, sample_iscrowd): + self.data.append((fp, w, h, lab, box, iscrowd)) - data = [] + def __iter__(self) -> Any: + return iter(self.data) - for img_id, path in zip(img_ids, paths): - path = path["file_name"] + def __len__(self) -> int: + return len(self.data) - ann_ids = coco.getAnnIds(imgIds=img_id) - annotations = coco.loadAnns(ann_ids) + def record_id(self, o) -> Hashable: + return o[0] - boxes, labels, areas, iscrowd = [], [], [], [] + def parse_fields(self, o, record, is_new): + fp, w, h, lab, box, iscrowd = o - # Ref: https://github.com/pytorch/vision/blob/master/references/detection/coco_utils.py - if self.training and all(any(o <= 1 for o in obj["bbox"][2:]) for obj in annotations): - continue + if iscrowd is None: + iscrowd = 0 - for obj in annotations: - xmin = obj["bbox"][0] - ymin = obj["bbox"][1] - xmax = xmin + obj["bbox"][2] - ymax = ymin + obj["bbox"][3] + if is_new: + record.set_filepath(fp) + record.set_img_size(ImgSize(width=w, height=h)) + record.detection.set_class_map(self.class_map) - bbox = [xmin, ymin, xmax, ymax] - keep = (bbox[3] > bbox[1]) & (bbox[2] > bbox[0]) - if keep: - boxes.append(bbox) - labels.append(obj["category_id"]) - areas.append(obj["area"]) - iscrowd.append(obj["iscrowd"]) + box = self._reformat_bbox(*box, w, h) - data.append( - dict( - input=os.path.join(root, path), - target=dict( - boxes=boxes, - labels=labels, - image_id=img_id, - area=areas, - iscrowd=iscrowd, - ), - ) - ) - return data + record.detection.add_bboxes([BBox.from_xyxy(*box)]) + record.detection.add_labels([lab]) + record.detection.add_iscrowds([iscrowd]) - def load_sample(self, sample: Dict[str, Any]) -> Dict[str, Any]: - filepath = sample[DefaultDataKeys.INPUT] - img = default_loader(filepath) - sample[DefaultDataKeys.INPUT] = img - w, h = img.size # WxH - sample[DefaultDataKeys.METADATA] = { - "filepath": filepath, - "size": (h, w), - } - return sample + @staticmethod + def _reformat_bbox(xmin, ymin, box_w, box_h, img_w, img_h): + xmin *= img_w + ymin *= img_h + box_w *= img_w + box_h *= img_h + xmax = xmin + box_w + ymax = ymin + box_h + output_bbox = [xmin, ymin, xmax, ymax] + return output_bbox -class ObjectDetectionFiftyOneDataSource(FiftyOneDataSource): +class ObjectDetectionFiftyOneDataSource(IceVisionPathsDataSource, FiftyOneDataSource): def __init__(self, label_field: str = "ground_truth", iscrowd: str = "iscrowd"): - super().__init__(label_field=label_field) + super().__init__() + self.label_field = label_field self.iscrowd = iscrowd @property + @requires("fiftyone") def label_cls(self): return fol.Detections + @requires("fiftyone") def load_data(self, data: SampleCollection, dataset: Optional[Any] = None) -> Sequence[Dict[str, Any]]: self._validate(data) data.compute_metadata() - - filepaths = data.values("filepath") - widths = data.values("metadata.width") - heights = data.values("metadata.height") - labels = data.values(self.label_field + ".detections.label") - bboxes = data.values(self.label_field + ".detections.bounding_box") - iscrowds = data.values(self.label_field + ".detections." 
+ self.iscrowd) - classes = self._get_classes(data) - class_to_idx = {cls_name: i for i, cls_name in enumerate(classes)} - if dataset is not None: - dataset.num_classes = len(classes) + class_map = ClassMap(classes) + dataset.num_classes = len(class_map) - output_data = [] - img_id = 1 - for fp, w, h, sample_labs, sample_boxes, sample_iscrowd in zip( - filepaths, widths, heights, labels, bboxes, iscrowds - ): - output_boxes = [] - output_labs = [] - output_iscrowd = [] - output_areas = [] - for lab, box, iscrowd in zip(sample_labs, sample_boxes, sample_iscrowd): - output_box, output_area = self._reformat_bbox(box[0], box[1], box[2], box[3], w, h) - output_areas.append(output_area) - output_labs.append(class_to_idx[lab]) - output_boxes.append(output_box) - if iscrowd is None: - iscrowd = 0 - output_iscrowd.append(iscrowd) - output_data.append( - dict( - input=fp, - target=dict( - boxes=output_boxes, - labels=output_labs, - image_id=img_id, - area=output_areas, - iscrowd=output_iscrowd, - ), - ) - ) - img_id += 1 - - return output_data + parser = FiftyOneParser(data, class_map, self.label_field, self.iscrowd) + records = parser.parse(data_splitter=SingleSplitSplitter()) + return [{DefaultDataKeys.INPUT: record} for record in records[0]] @staticmethod - def load_sample(sample: Dict[str, Any], dataset: Optional[Any] = None) -> Dict[str, Any]: - filepath = sample[DefaultDataKeys.INPUT] - img = default_loader(filepath) - sample[DefaultDataKeys.INPUT] = img - w, h = img.size # WxH - sample[DefaultDataKeys.METADATA] = { - "filepath": filepath, - "size": (h, w), - } - return sample - - @staticmethod - def _reformat_bbox(xmin, ymin, box_w, box_h, img_w, img_h): - xmin *= img_w - ymin *= img_h - box_w *= img_w - box_h *= img_h - xmax = xmin + box_w - ymax = ymin + box_h - output_bbox = [xmin, ymin, xmax, ymax] - return output_bbox, box_w * box_h + @requires("fiftyone") + def predict_load_data(data: SampleCollection, dataset: Optional[Any] = None) -> Sequence[Dict[str, Any]]: + return [{DefaultDataKeys.INPUT: f} for f in data.values("filepath")] class ObjectDetectionPreprocess(Preprocess): @@ -201,22 +141,30 @@ def __init__( val_transform: Optional[Dict[str, Callable]] = None, test_transform: Optional[Dict[str, Callable]] = None, predict_transform: Optional[Dict[str, Callable]] = None, + image_size: Tuple[int, int] = (128, 128), + parser: Optional[Callable] = None, **data_source_kwargs: Any, ): + self.image_size = image_size + super().__init__( train_transform=train_transform, val_transform=val_transform, test_transform=test_transform, predict_transform=predict_transform, data_sources={ + "coco": IceVisionParserDataSource(parser=COCOBBoxParser), + "via": IceVisionParserDataSource(parser=VIABBoxParser), + "voc": IceVisionParserDataSource(parser=VOCBBoxParser), + DefaultDataSources.FILES: IceVisionPathsDataSource(), + DefaultDataSources.FOLDERS: IceDataParserDataSource(parser=parser), DefaultDataSources.FIFTYONE: ObjectDetectionFiftyOneDataSource(**data_source_kwargs), - DefaultDataSources.FILES: ImagePathsDataSource(), - DefaultDataSources.FOLDERS: ImagePathsDataSource(), - "coco": COCODataSource(), }, default_data_source=DefaultDataSources.FILES, ) + self._default_collate = self._identity + def get_state_dict(self) -> Dict[str, Any]: return {**self.transforms} @@ -225,7 +173,10 @@ def load_state_dict(cls, state_dict: Dict[str, Any], strict: bool = False): return cls(**state_dict) def default_transforms(self) -> Optional[Dict[str, Callable]]: - return default_transforms() + return 
default_transforms(self.image_size) + + def train_default_transforms(self) -> Optional[Dict[str, Callable]]: + return default_transforms(self.image_size) class ObjectDetectionData(DataModule): @@ -233,7 +184,6 @@ class ObjectDetectionData(DataModule): preprocess_cls = ObjectDetectionPreprocess @classmethod - @requires("pycocotools") def from_coco( cls, train_folder: Optional[str] = None, @@ -242,9 +192,11 @@ def from_coco( val_ann_file: Optional[str] = None, test_folder: Optional[str] = None, test_ann_file: Optional[str] = None, + predict_folder: Optional[str] = None, train_transform: Optional[Dict[str, Callable]] = None, val_transform: Optional[Dict[str, Callable]] = None, test_transform: Optional[Dict[str, Callable]] = None, + predict_transform: Optional[Dict[str, Callable]] = None, data_fetcher: Optional[BaseDataFetcher] = None, preprocess: Optional[Preprocess] = None, val_split: Optional[float] = None, @@ -253,7 +205,7 @@ def from_coco( **preprocess_kwargs: Any, ): """Creates a :class:`~flash.image.detection.data.ObjectDetectionData` object from the given data folders - and corresponding target folders. + and annotation files in the COCO format. Args: train_folder: The folder containing the train data. @@ -262,12 +214,15 @@ def from_coco( val_ann_file: The COCO format annotation file. test_folder: The folder containing the test data. test_ann_file: The COCO format annotation file. + predict_folder: The folder containing the predict data. train_transform: The dictionary of transforms to use during training which maps :class:`~flash.core.data.process.Preprocess` hook names to callable transforms. val_transform: The dictionary of transforms to use during validation which maps :class:`~flash.core.data.process.Preprocess` hook names to callable transforms. test_transform: The dictionary of transforms to use during testing which maps :class:`~flash.core.data.process.Preprocess` hook names to callable transforms. + predict_transform: The dictionary of transforms to use during predicting which maps + :class:`~flash.core.data.process.Preprocess` hook names to callable transforms. data_fetcher: The :class:`~flash.core.data.callback.BaseDataFetcher` to pass to the :class:`~flash.core.data.data_module.DataModule`. 
preprocess: The :class:`~flash.core.data.data.Preprocess` to pass to the @@ -284,7 +239,7 @@ def from_coco( Examples:: - data_module = SemanticSegmentationData.from_coco( + data_module = ObjectDetectionData.from_coco( train_folder="train_folder", train_ann_file="annotations.json", ) @@ -294,9 +249,169 @@ def from_coco( (train_folder, train_ann_file) if train_folder else None, (val_folder, val_ann_file) if val_folder else None, (test_folder, test_ann_file) if test_folder else None, + predict_folder, + train_transform=train_transform, + val_transform=val_transform, + test_transform=test_transform, + predict_transform=predict_transform, + data_fetcher=data_fetcher, + preprocess=preprocess, + val_split=val_split, + batch_size=batch_size, + num_workers=num_workers, + **preprocess_kwargs, + ) + + @classmethod + def from_voc( + cls, + train_folder: Optional[str] = None, + train_ann_file: Optional[str] = None, + val_folder: Optional[str] = None, + val_ann_file: Optional[str] = None, + test_folder: Optional[str] = None, + test_ann_file: Optional[str] = None, + predict_folder: Optional[str] = None, + train_transform: Optional[Dict[str, Callable]] = None, + val_transform: Optional[Dict[str, Callable]] = None, + test_transform: Optional[Dict[str, Callable]] = None, + predict_transform: Optional[Dict[str, Callable]] = None, + data_fetcher: Optional[BaseDataFetcher] = None, + preprocess: Optional[Preprocess] = None, + val_split: Optional[float] = None, + batch_size: int = 4, + num_workers: Optional[int] = None, + **preprocess_kwargs: Any, + ): + """Creates a :class:`~flash.image.detection.data.ObjectDetectionData` object from the given data folders + and annotation files in the VOC format. + + Args: + train_folder: The folder containing the train data. + train_ann_file: The COCO format annotation file. + val_folder: The folder containing the validation data. + val_ann_file: The COCO format annotation file. + test_folder: The folder containing the test data. + test_ann_file: The COCO format annotation file. + predict_folder: The folder containing the predict data. + train_transform: The dictionary of transforms to use during training which maps + :class:`~flash.core.data.process.Preprocess` hook names to callable transforms. + val_transform: The dictionary of transforms to use during validation which maps + :class:`~flash.core.data.process.Preprocess` hook names to callable transforms. + test_transform: The dictionary of transforms to use during testing which maps + :class:`~flash.core.data.process.Preprocess` hook names to callable transforms. + predict_transform: The dictionary of transforms to use during predicting which maps + :class:`~flash.core.data.process.Preprocess` hook names to callable transforms. + data_fetcher: The :class:`~flash.core.data.callback.BaseDataFetcher` to pass to the + :class:`~flash.core.data.data_module.DataModule`. + preprocess: The :class:`~flash.core.data.data.Preprocess` to pass to the + :class:`~flash.core.data.data_module.DataModule`. If ``None``, ``cls.preprocess_cls`` + will be constructed and used. + val_split: The ``val_split`` argument to pass to the :class:`~flash.core.data.data_module.DataModule`. + batch_size: The ``batch_size`` argument to pass to the :class:`~flash.core.data.data_module.DataModule`. + num_workers: The ``num_workers`` argument to pass to the :class:`~flash.core.data.data_module.DataModule`. + preprocess_kwargs: Additional keyword arguments to use when constructing the preprocess. Will only be used + if ``preprocess = None``. 
+ + Returns: + The constructed data module. + + Examples:: + + data_module = ObjectDetectionData.from_voc( + train_folder="train_folder", + train_ann_file="annotations.json", + ) + """ + return cls.from_data_source( + "voc", + (train_folder, train_ann_file) if train_folder else None, + (val_folder, val_ann_file) if val_folder else None, + (test_folder, test_ann_file) if test_folder else None, + predict_folder, + train_transform=train_transform, + val_transform=val_transform, + test_transform=test_transform, + predict_transform=predict_transform, + data_fetcher=data_fetcher, + preprocess=preprocess, + val_split=val_split, + batch_size=batch_size, + num_workers=num_workers, + **preprocess_kwargs, + ) + + @classmethod + def from_via( + cls, + train_folder: Optional[str] = None, + train_ann_file: Optional[str] = None, + val_folder: Optional[str] = None, + val_ann_file: Optional[str] = None, + test_folder: Optional[str] = None, + test_ann_file: Optional[str] = None, + predict_folder: Optional[str] = None, + train_transform: Optional[Dict[str, Callable]] = None, + val_transform: Optional[Dict[str, Callable]] = None, + test_transform: Optional[Dict[str, Callable]] = None, + predict_transform: Optional[Dict[str, Callable]] = None, + data_fetcher: Optional[BaseDataFetcher] = None, + preprocess: Optional[Preprocess] = None, + val_split: Optional[float] = None, + batch_size: int = 4, + num_workers: Optional[int] = None, + **preprocess_kwargs: Any, + ): + """Creates a :class:`~flash.image.detection.data.ObjectDetectionData` object from the given data folders + and annotation files in the VIA format. + + Args: + train_folder: The folder containing the train data. + train_ann_file: The COCO format annotation file. + val_folder: The folder containing the validation data. + val_ann_file: The COCO format annotation file. + test_folder: The folder containing the test data. + test_ann_file: The COCO format annotation file. + predict_folder: The folder containing the predict data. + train_transform: The dictionary of transforms to use during training which maps + :class:`~flash.core.data.process.Preprocess` hook names to callable transforms. + val_transform: The dictionary of transforms to use during validation which maps + :class:`~flash.core.data.process.Preprocess` hook names to callable transforms. + test_transform: The dictionary of transforms to use during testing which maps + :class:`~flash.core.data.process.Preprocess` hook names to callable transforms. + predict_transform: The dictionary of transforms to use during predicting which maps + :class:`~flash.core.data.process.Preprocess` hook names to callable transforms. + data_fetcher: The :class:`~flash.core.data.callback.BaseDataFetcher` to pass to the + :class:`~flash.core.data.data_module.DataModule`. + preprocess: The :class:`~flash.core.data.data.Preprocess` to pass to the + :class:`~flash.core.data.data_module.DataModule`. If ``None``, ``cls.preprocess_cls`` + will be constructed and used. + val_split: The ``val_split`` argument to pass to the :class:`~flash.core.data.data_module.DataModule`. + batch_size: The ``batch_size`` argument to pass to the :class:`~flash.core.data.data_module.DataModule`. + num_workers: The ``num_workers`` argument to pass to the :class:`~flash.core.data.data_module.DataModule`. + preprocess_kwargs: Additional keyword arguments to use when constructing the preprocess. Will only be used + if ``preprocess = None``. + + Returns: + The constructed data module. 
+ + Examples:: + + data_module = ObjectDetectionData.from_via( + train_folder="train_folder", + train_ann_file="annotations.json", + ) + """ + return cls.from_data_source( + "via", + (train_folder, train_ann_file) if train_folder else None, + (val_folder, val_ann_file) if val_folder else None, + (test_folder, test_ann_file) if test_folder else None, + predict_folder, train_transform=train_transform, val_transform=val_transform, test_transform=test_transform, + predict_transform=predict_transform, data_fetcher=data_fetcher, preprocess=preprocess, val_split=val_split, diff --git a/flash/image/detection/model.py b/flash/image/detection/model.py index 320f64bbee..c2bcd606f6 100644 --- a/flash/image/detection/model.py +++ b/flash/image/detection/model.py @@ -11,53 +11,25 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Callable, Dict, List, Mapping, Optional, Sequence, Type, Union +from typing import Any, Dict, List, Mapping, Optional, Type, Union import torch -from torch import nn, tensor from torch.optim import Optimizer -from flash.core.data.data_source import DefaultDataKeys +from flash.core.adapter import AdapterTask from flash.core.data.process import Serializer -from flash.core.model import Task from flash.core.registry import FlashRegistry -from flash.core.utilities.imports import _TORCHVISION_AVAILABLE -from flash.image.backbones import OBJ_DETECTION_BACKBONES -from flash.image.detection.finetuning import ObjectDetectionFineTuning -from flash.image.detection.serialization import DetectionLabels +from flash.image.detection.backbones import OBJECT_DETECTION_HEADS -if _TORCHVISION_AVAILABLE: - import torchvision - from torchvision.models.detection.faster_rcnn import FasterRCNN, FastRCNNPredictor - from torchvision.models.detection.retinanet import RetinaNet, RetinaNetHead - from torchvision.models.detection.rpn import AnchorGenerator - from torchvision.ops import box_iou - _models = { - "fasterrcnn": torchvision.models.detection.fasterrcnn_resnet50_fpn, - "retinanet": torchvision.models.detection.retinanet_resnet50_fpn, - } - -else: - AnchorGenerator = None - - -def _evaluate_iou(target, pred): - """Evaluate intersection over union (IOU) for target from dataset and output prediction from model.""" - if pred["boxes"].shape[0] == 0: - # no box detected, 0 IOU - return tensor(0.0, device=pred["boxes"].device) - return box_iou(target["boxes"], pred["boxes"]).diag().mean() - - -class ObjectDetector(Task): +class ObjectDetector(AdapterTask): """The ``ObjectDetector`` is a :class:`~flash.Task` for detecting objects in images. For more details, see :ref:`object_detection`. Args: num_classes: the number of classes for detection, including background model: a string of :attr`_models`. Defaults to 'fasterrcnn'. - backbone: Pretained backbone CNN architecture. Constructs a model with a + backbone: Pretrained backbone CNN architecture. Constructs a model with a ResNet-50-FPN backbone when no backbone is specified. fpn: If True, creates a Feature Pyramind Network on top of Resnet based CNNs. 
pretrained: if true, returns a model pre-trained on COCO train2017 @@ -74,144 +46,40 @@ class ObjectDetector(Task): """ - backbones: FlashRegistry = OBJ_DETECTION_BACKBONES + heads: FlashRegistry = OBJECT_DETECTION_HEADS required_extras: str = "image" def __init__( self, num_classes: int, - model: str = "fasterrcnn", - backbone: Optional[str] = None, - fpn: bool = True, + backbone: Optional[str] = "resnet18_fpn", + head: Optional[str] = "retinanet", pretrained: bool = True, - pretrained_backbone: bool = True, - trainable_backbone_layers: int = 3, - anchor_generator: Optional[Type["AnchorGenerator"]] = None, - loss=None, - metrics: Union[Callable, nn.Module, Mapping, Sequence, None] = None, - optimizer: Type[Optimizer] = torch.optim.AdamW, - learning_rate: float = 1e-3, + optimizer: Type[Optimizer] = torch.optim.Adam, + learning_rate: float = 5e-4, serializer: Optional[Union[Serializer, Mapping[str, Serializer]]] = None, **kwargs: Any, ): self.save_hyperparameters() - if model in _models: - model = ObjectDetector.get_model( - model, - num_classes, - backbone, - fpn, - pretrained, - pretrained_backbone, - trainable_backbone_layers, - anchor_generator, - **kwargs, - ) - else: - ValueError(f"{model} is not supported yet.") + metadata = self.heads.get(head, with_metadata=True) + adapter = metadata["metadata"]["adapter"].from_task( + self, + num_classes=num_classes, + backbone=backbone, + head=head, + pretrained=pretrained, + **kwargs, + ) super().__init__( - model=model, - loss_fn=loss, - metrics=metrics, + adapter, learning_rate=learning_rate, optimizer=optimizer, - serializer=serializer or DetectionLabels(), + serializer=serializer, ) - @staticmethod - def get_model( - model_name, - num_classes, - backbone, - fpn, - pretrained, - pretrained_backbone, - trainable_backbone_layers, - anchor_generator, - **kwargs, - ): - if backbone is None: - # Constructs a model with a ResNet-50-FPN backbone when no backbone is specified. - if model_name == "fasterrcnn": - model = _models[model_name]( - pretrained=pretrained, - pretrained_backbone=pretrained_backbone, - trainable_backbone_layers=trainable_backbone_layers, - ) - in_features = model.roi_heads.box_predictor.cls_score.in_features - head = FastRCNNPredictor(in_features, num_classes) - model.roi_heads.box_predictor = head - else: - model = _models[model_name](pretrained=pretrained, pretrained_backbone=pretrained_backbone) - model.head = RetinaNetHead( - in_channels=model.backbone.out_channels, - num_anchors=model.head.classification_head.num_anchors, - num_classes=num_classes, - **kwargs, - ) - else: - backbone_model, num_features = ObjectDetector.backbones.get(backbone)( - pretrained=pretrained_backbone, - trainable_layers=trainable_backbone_layers, - **kwargs, - ) - backbone_model.out_channels = num_features - if anchor_generator is None: - anchor_generator = ( - AnchorGenerator(sizes=((32, 64, 128, 256, 512),), aspect_ratios=((0.5, 1.0, 2.0),)) - if not hasattr(backbone_model, "fpn") - else None - ) - - if model_name == "fasterrcnn": - model = FasterRCNN(backbone_model, num_classes=num_classes, rpn_anchor_generator=anchor_generator) - else: - model = RetinaNet(backbone_model, num_classes=num_classes, anchor_generator=anchor_generator) - return model - - def forward(self, x: List[torch.Tensor]) -> Any: - return self.model(x) - - def training_step(self, batch, batch_idx) -> Any: - """The training step. 
- - Overrides ``Task.training_step`` - """ - images, targets = batch[DefaultDataKeys.INPUT], batch[DefaultDataKeys.TARGET] - targets = [dict(t.items()) for t in targets] - - # fasterrcnn takes both images and targets for training, returns loss_dict - loss_dict = self.model(images, targets) - loss = sum(loss_dict.values()) - self.log_dict({f"train_{k}": v for k, v in loss_dict.items()}, on_step=True, on_epoch=True, prog_bar=True) - return loss - - def validation_step(self, batch, batch_idx): - images, targets = batch[DefaultDataKeys.INPUT], batch[DefaultDataKeys.TARGET] - # fasterrcnn takes only images for eval() mode - outs = self(images) - iou = torch.stack([_evaluate_iou(t, o) for t, o in zip(targets, outs)]).mean() - self.log("val_iou", iou) - - def test_step(self, batch, batch_idx): - images, targets = batch[DefaultDataKeys.INPUT], batch[DefaultDataKeys.TARGET] - # fasterrcnn takes only images for eval() mode - outs = self(images) - iou = torch.stack([_evaluate_iou(t, o) for t, o in zip(targets, outs)]).mean() - self.log("test_iou", iou) - - def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> Any: - images = batch[DefaultDataKeys.INPUT] - batch[DefaultDataKeys.PREDS] = self(images) - return batch - - def configure_finetune_callback(self): - return [ObjectDetectionFineTuning(train_bn=True)] - def _ci_benchmark_fn(self, history: List[Dict[str, Any]]) -> None: """This function is used only for debugging usage with CI.""" - # todo (tchaton) Improve convergence - # history[-1]["val_iou"] + # todo diff --git a/flash/image/detection/transforms.py b/flash/image/detection/transforms.py deleted file mode 100644 index 5179f1f8a7..0000000000 --- a/flash/image/detection/transforms.py +++ /dev/null @@ -1,48 +0,0 @@ -# Copyright The PyTorch Lightning team. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
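Since the adapter now owns the training, validation, and prediction steps that were deleted above, using the detector comes down to choosing a ``head`` and ``backbone`` from the registry. A hedged end-to-end sketch, with placeholder data paths::

    import flash
    from flash.image import ObjectDetectionData, ObjectDetector

    datamodule = ObjectDetectionData.from_coco(
        train_folder="train_folder",          # hypothetical COCO-format dataset
        train_ann_file="annotations.json",
        val_split=0.1,
        image_size=(128, 128),                # forwarded to ObjectDetectionPreprocess
        batch_size=2,
    )

    # ``head`` replaces the old ``model`` argument and backbone names now carry the ``_fpn`` suffix
    model = ObjectDetector(
        num_classes=datamodule.num_classes,
        head="retinanet",
        backbone="resnet18_fpn",
    )

    trainer = flash.Trainer(max_epochs=1)
    trainer.finetune(model, datamodule=datamodule, strategy="freeze")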
-from typing import Any, Callable, Dict, Sequence - -import torch -from torch import nn - -from flash.core.data.transforms import ApplyToKeys -from flash.core.utilities.imports import _TORCHVISION_AVAILABLE - -if _TORCHVISION_AVAILABLE: - import torchvision - - -def collate(samples: Sequence[Dict[str, Any]]) -> Dict[str, Sequence[Any]]: - return {key: [sample[key] for sample in samples] for key in samples[0]} - - -def default_transforms() -> Dict[str, Callable]: - """The default transforms for object detection: convert the image and targets to a tensor, collate the - batch.""" - return { - "to_tensor_transform": nn.Sequential( - ApplyToKeys("input", torchvision.transforms.ToTensor()), - ApplyToKeys( - "target", - nn.Sequential( - ApplyToKeys("boxes", torch.as_tensor), - ApplyToKeys("labels", torch.as_tensor), - ApplyToKeys("image_id", torch.as_tensor), - ApplyToKeys("area", torch.as_tensor), - ApplyToKeys("iscrowd", torch.as_tensor), - ), - ), - ), - "collate": collate, - } diff --git a/flash/image/instance_segmentation/__init__.py b/flash/image/instance_segmentation/__init__.py new file mode 100644 index 0000000000..c5659822c8 --- /dev/null +++ b/flash/image/instance_segmentation/__init__.py @@ -0,0 +1,2 @@ +from flash.image.instance_segmentation.data import InstanceSegmentationData # noqa: F401 +from flash.image.instance_segmentation.model import InstanceSegmentation # noqa: F401 diff --git a/flash/image/instance_segmentation/backbones.py b/flash/image/instance_segmentation/backbones.py new file mode 100644 index 0000000000..9811d6fa78 --- /dev/null +++ b/flash/image/instance_segmentation/backbones.py @@ -0,0 +1,81 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
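The removed ``collate`` simply regrouped per-sample dictionaries into a dictionary of lists; the IceVision-backed preprocesses set ``self._default_collate = self._identity`` instead and leave batching to the integration. For reference, the old behaviour amounted to::

    samples = [
        {"input": "img_0.png", "target": {"labels": [1]}},
        {"input": "img_1.png", "target": {"labels": [2, 3]}},
    ]
    batch = {key: [sample[key] for sample in samples] for key in samples[0]}
    assert batch["input"] == ["img_0.png", "img_1.png"]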
+from functools import partial +from typing import Optional + +from flash.core.adapter import Adapter +from flash.core.integrations.icevision.adapter import IceVisionAdapter, SimpleCOCOMetric +from flash.core.integrations.icevision.backbones import ( + get_backbones, + icevision_model_adapter, + load_icevision_ignore_image_size, +) +from flash.core.model import Task +from flash.core.registry import FlashRegistry +from flash.core.utilities.imports import _ICEVISION_AVAILABLE, _module_available, _TORCHVISION_AVAILABLE +from flash.core.utilities.providers import _ICEVISION, _MMDET, _TORCHVISION + +if _ICEVISION_AVAILABLE: + from icevision import models as icevision_models + from icevision.metrics import COCOMetricType + from icevision.metrics import Metric as IceVisionMetric + +INSTANCE_SEGMENTATION_HEADS = FlashRegistry("heads") + + +class IceVisionInstanceSegmentationAdapter(IceVisionAdapter): + @classmethod + def from_task( + cls, + task: Task, + num_classes: int, + backbone: str = "resnet18_fpn", + head: str = "mask_rcnn", + pretrained: bool = True, + metrics: Optional["IceVisionMetric"] = None, + image_size: Optional = None, + **kwargs, + ) -> Adapter: + return super().from_task( + task, + num_classes=num_classes, + backbone=backbone, + head=head, + pretrained=pretrained, + metrics=metrics or [SimpleCOCOMetric(COCOMetricType.mask)], + image_size=image_size, + **kwargs, + ) + + +if _ICEVISION_AVAILABLE: + if _TORCHVISION_AVAILABLE: + model_type = icevision_models.torchvision.mask_rcnn + INSTANCE_SEGMENTATION_HEADS( + partial(load_icevision_ignore_image_size, icevision_model_adapter, model_type), + model_type.__name__.split(".")[-1], + backbones=get_backbones(model_type), + adapter=IceVisionInstanceSegmentationAdapter, + providers=[_ICEVISION, _TORCHVISION], + ) + + if _module_available("mmdet"): + model_type = icevision_models.mmdet.mask_rcnn + INSTANCE_SEGMENTATION_HEADS( + partial(load_icevision_ignore_image_size, icevision_model_adapter, model_type), + f"mmdet_{model_type.__name__.split('.')[-1]}", + backbones=get_backbones(model_type), + adapter=IceVisionInstanceSegmentationAdapter, + providers=[_ICEVISION, _MMDET], + ) diff --git a/flash/image/instance_segmentation/data.py b/flash/image/instance_segmentation/data.py new file mode 100644 index 0000000000..b67e606683 --- /dev/null +++ b/flash/image/instance_segmentation/data.py @@ -0,0 +1,234 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
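As with object detection, the instance segmentation task can be backed by different providers selected purely by head name: the torchvision Mask R-CNN registers as ``mask_rcnn`` and, when ``mmdet`` is installed, the MMDetection variant as ``mmdet_mask_rcnn``. A small introspection sketch, assuming ``FlashRegistry.available_keys`` and the ``backbones`` metadata behave as in the other registries::

    from flash.image.instance_segmentation.backbones import INSTANCE_SEGMENTATION_HEADS

    print(INSTANCE_SEGMENTATION_HEADS.available_keys())

    # each head also records the backbones it accepts and the adapter used to build it
    entry = INSTANCE_SEGMENTATION_HEADS.get("mask_rcnn", with_metadata=True)
    print(entry["metadata"]["backbones"].available_keys())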
+from typing import Any, Callable, Dict, Optional, Tuple + +from flash.core.data.callback import BaseDataFetcher +from flash.core.data.data_module import DataModule +from flash.core.data.data_source import DefaultDataSources +from flash.core.data.process import Preprocess +from flash.core.integrations.icevision.data import ( + IceDataParserDataSource, + IceVisionParserDataSource, + IceVisionPathsDataSource, +) +from flash.core.integrations.icevision.transforms import default_transforms +from flash.core.utilities.imports import _ICEVISION_AVAILABLE + +if _ICEVISION_AVAILABLE: + from icevision.parsers import COCOMaskParser, VOCMaskParser + + +class InstanceSegmentationPreprocess(Preprocess): + def __init__( + self, + train_transform: Optional[Dict[str, Callable]] = None, + val_transform: Optional[Dict[str, Callable]] = None, + test_transform: Optional[Dict[str, Callable]] = None, + predict_transform: Optional[Dict[str, Callable]] = None, + image_size: Tuple[int, int] = (128, 128), + parser: Optional[Callable] = None, + ): + self.image_size = image_size + + super().__init__( + train_transform=train_transform, + val_transform=val_transform, + test_transform=test_transform, + predict_transform=predict_transform, + data_sources={ + "coco": IceVisionParserDataSource(parser=COCOMaskParser), + "voc": IceVisionParserDataSource(parser=VOCMaskParser), + DefaultDataSources.FILES: IceVisionPathsDataSource(), + DefaultDataSources.FOLDERS: IceDataParserDataSource(parser=parser), + }, + default_data_source=DefaultDataSources.FILES, + ) + + self._default_collate = self._identity + + def get_state_dict(self) -> Dict[str, Any]: + return {**self.transforms} + + @classmethod + def load_state_dict(cls, state_dict: Dict[str, Any], strict: bool = False): + return cls(**state_dict) + + def default_transforms(self) -> Optional[Dict[str, Callable]]: + return default_transforms(self.image_size) + + def train_default_transforms(self) -> Optional[Dict[str, Callable]]: + return default_transforms(self.image_size) + + +class InstanceSegmentationData(DataModule): + + preprocess_cls = InstanceSegmentationPreprocess + + @classmethod + def from_coco( + cls, + train_folder: Optional[str] = None, + train_ann_file: Optional[str] = None, + val_folder: Optional[str] = None, + val_ann_file: Optional[str] = None, + test_folder: Optional[str] = None, + test_ann_file: Optional[str] = None, + predict_folder: Optional[str] = None, + train_transform: Optional[Dict[str, Callable]] = None, + val_transform: Optional[Dict[str, Callable]] = None, + test_transform: Optional[Dict[str, Callable]] = None, + predict_transform: Optional[Dict[str, Callable]] = None, + data_fetcher: Optional[BaseDataFetcher] = None, + preprocess: Optional[Preprocess] = None, + val_split: Optional[float] = None, + batch_size: int = 4, + num_workers: Optional[int] = None, + **preprocess_kwargs: Any, + ): + """Creates a :class:`~flash.image.instance_segmentation.data.InstanceSegmentationData` object from the + given data folders and annotation files in the COCO format. + + Args: + train_folder: The folder containing the train data. + train_ann_file: The COCO format annotation file. + val_folder: The folder containing the validation data. + val_ann_file: The COCO format annotation file. + test_folder: The folder containing the test data. + test_ann_file: The COCO format annotation file. + predict_folder: The folder containing the predict data. 
+ train_transform: The dictionary of transforms to use during training which maps + :class:`~flash.core.data.process.Preprocess` hook names to callable transforms. + val_transform: The dictionary of transforms to use during validation which maps + :class:`~flash.core.data.process.Preprocess` hook names to callable transforms. + test_transform: The dictionary of transforms to use during testing which maps + :class:`~flash.core.data.process.Preprocess` hook names to callable transforms. + predict_transform: The dictionary of transforms to use during predicting which maps + :class:`~flash.core.data.process.Preprocess` hook names to callable transforms. + data_fetcher: The :class:`~flash.core.data.callback.BaseDataFetcher` to pass to the + :class:`~flash.core.data.data_module.DataModule`. + preprocess: The :class:`~flash.core.data.data.Preprocess` to pass to the + :class:`~flash.core.data.data_module.DataModule`. If ``None``, ``cls.preprocess_cls`` + will be constructed and used. + val_split: The ``val_split`` argument to pass to the :class:`~flash.core.data.data_module.DataModule`. + batch_size: The ``batch_size`` argument to pass to the :class:`~flash.core.data.data_module.DataModule`. + num_workers: The ``num_workers`` argument to pass to the :class:`~flash.core.data.data_module.DataModule`. + preprocess_kwargs: Additional keyword arguments to use when constructing the preprocess. Will only be used + if ``preprocess = None``. + + Returns: + The constructed data module. + + Examples:: + + data_module = InstanceSegmentationData.from_coco( + train_folder="train_folder", + train_ann_file="annotations.json", + ) + """ + return cls.from_data_source( + "coco", + (train_folder, train_ann_file) if train_folder else None, + (val_folder, val_ann_file) if val_folder else None, + (test_folder, test_ann_file) if test_folder else None, + predict_folder, + train_transform=train_transform, + val_transform=val_transform, + test_transform=test_transform, + predict_transform=predict_transform, + data_fetcher=data_fetcher, + preprocess=preprocess, + val_split=val_split, + batch_size=batch_size, + num_workers=num_workers, + **preprocess_kwargs, + ) + + @classmethod + def from_voc( + cls, + train_folder: Optional[str] = None, + train_ann_file: Optional[str] = None, + val_folder: Optional[str] = None, + val_ann_file: Optional[str] = None, + test_folder: Optional[str] = None, + test_ann_file: Optional[str] = None, + predict_folder: Optional[str] = None, + train_transform: Optional[Dict[str, Callable]] = None, + val_transform: Optional[Dict[str, Callable]] = None, + test_transform: Optional[Dict[str, Callable]] = None, + predict_transform: Optional[Dict[str, Callable]] = None, + data_fetcher: Optional[BaseDataFetcher] = None, + preprocess: Optional[Preprocess] = None, + val_split: Optional[float] = None, + batch_size: int = 4, + num_workers: Optional[int] = None, + **preprocess_kwargs: Any, + ): + """Creates a :class:`~flash.image.instance_segmentation.data.InstanceSegmentationData` object from the + given data folders and annotation files in the VOC format. + + Args: + train_folder: The folder containing the train data. + train_ann_file: The COCO format annotation file. + val_folder: The folder containing the validation data. + val_ann_file: The COCO format annotation file. + test_folder: The folder containing the test data. + test_ann_file: The COCO format annotation file. + predict_folder: The folder containing the predict data. 
+ train_transform: The dictionary of transforms to use during training which maps + :class:`~flash.core.data.process.Preprocess` hook names to callable transforms. + val_transform: The dictionary of transforms to use during validation which maps + :class:`~flash.core.data.process.Preprocess` hook names to callable transforms. + test_transform: The dictionary of transforms to use during testing which maps + :class:`~flash.core.data.process.Preprocess` hook names to callable transforms. + predict_transform: The dictionary of transforms to use during predicting which maps + :class:`~flash.core.data.process.Preprocess` hook names to callable transforms. + data_fetcher: The :class:`~flash.core.data.callback.BaseDataFetcher` to pass to the + :class:`~flash.core.data.data_module.DataModule`. + preprocess: The :class:`~flash.core.data.data.Preprocess` to pass to the + :class:`~flash.core.data.data_module.DataModule`. If ``None``, ``cls.preprocess_cls`` + will be constructed and used. + val_split: The ``val_split`` argument to pass to the :class:`~flash.core.data.data_module.DataModule`. + batch_size: The ``batch_size`` argument to pass to the :class:`~flash.core.data.data_module.DataModule`. + num_workers: The ``num_workers`` argument to pass to the :class:`~flash.core.data.data_module.DataModule`. + preprocess_kwargs: Additional keyword arguments to use when constructing the preprocess. Will only be used + if ``preprocess = None``. + + Returns: + The constructed data module. + + Examples:: + + data_module = InstanceSegmentationData.from_voc( + train_folder="train_folder", + train_ann_file="annotations.json", + ) + """ + return cls.from_data_source( + "voc", + (train_folder, train_ann_file) if train_folder else None, + (val_folder, val_ann_file) if val_folder else None, + (test_folder, test_ann_file) if test_folder else None, + predict_folder, + train_transform=train_transform, + val_transform=val_transform, + test_transform=test_transform, + predict_transform=predict_transform, + data_fetcher=data_fetcher, + preprocess=preprocess, + val_split=val_split, + batch_size=batch_size, + num_workers=num_workers, + **preprocess_kwargs, + ) diff --git a/flash/image/instance_segmentation/model.py b/flash/image/instance_segmentation/model.py new file mode 100644 index 0000000000..52f2706554 --- /dev/null +++ b/flash/image/instance_segmentation/model.py @@ -0,0 +1,85 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Any, Dict, List, Mapping, Optional, Type, Union + +import torch +from torch.optim import Optimizer + +from flash.core.adapter import AdapterTask +from flash.core.data.process import Serializer +from flash.core.registry import FlashRegistry +from flash.image.instance_segmentation.backbones import INSTANCE_SEGMENTATION_HEADS + + +class InstanceSegmentation(AdapterTask): + """The ``InstanceSegmentation`` is a :class:`~flash.Task` for detecting objects in images. For more details, see + :ref:`object_detection`. 
+ + Args: + num_classes: the number of classes for detection, including background + model: a string of :attr`_models`. Defaults to 'fasterrcnn'. + backbone: Pretained backbone CNN architecture. Constructs a model with a + ResNet-50-FPN backbone when no backbone is specified. + fpn: If True, creates a Feature Pyramind Network on top of Resnet based CNNs. + pretrained: if true, returns a model pre-trained on COCO train2017 + pretrained_backbone: if true, returns a model with backbone pre-trained on Imagenet + trainable_backbone_layers: number of trainable resnet layers starting from final block. + Only applicable for `fasterrcnn`. + loss: the function(s) to update the model with. Has no effect for torchvision detection models. + metrics: The provided metrics. All metrics here will be logged to progress bar and the respective logger. + Changing this argument currently has no effect. + optimizer: The optimizer to use for training. Can either be the actual class or the class name. + pretrained: Whether the model from torchvision should be loaded with it's pretrained weights. + Has no effect for custom models. + learning_rate: The learning rate to use for training + + """ + + heads: FlashRegistry = INSTANCE_SEGMENTATION_HEADS + + required_extras: str = "image" + + def __init__( + self, + num_classes: int, + backbone: Optional[str] = "resnet18_fpn", + head: Optional[str] = "mask_rcnn", + pretrained: bool = True, + optimizer: Type[Optimizer] = torch.optim.Adam, + learning_rate: float = 5e-4, + serializer: Optional[Union[Serializer, Mapping[str, Serializer]]] = None, + **kwargs: Any, + ): + self.save_hyperparameters() + + metadata = self.heads.get(head, with_metadata=True) + adapter = metadata["metadata"]["adapter"].from_task( + self, + num_classes=num_classes, + backbone=backbone, + head=head, + pretrained=pretrained, + **kwargs, + ) + + super().__init__( + adapter, + learning_rate=learning_rate, + optimizer=optimizer, + serializer=serializer, + ) + + def _ci_benchmark_fn(self, history: List[Dict[str, Any]]) -> None: + """This function is used only for debugging usage with CI.""" + # todo diff --git a/flash/image/keypoint_detection/__init__.py b/flash/image/keypoint_detection/__init__.py new file mode 100644 index 0000000000..d397086e24 --- /dev/null +++ b/flash/image/keypoint_detection/__init__.py @@ -0,0 +1,2 @@ +from flash.image.keypoint_detection.data import KeypointDetectionData # noqa: F401 +from flash.image.keypoint_detection.model import KeypointDetector # noqa: F401 diff --git a/flash/image/keypoint_detection/backbones.py b/flash/image/keypoint_detection/backbones.py new file mode 100644 index 0000000000..72334761f2 --- /dev/null +++ b/flash/image/keypoint_detection/backbones.py @@ -0,0 +1,72 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
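Putting the new data module and task together, a hedged fine-tuning sketch for instance segmentation on a hypothetical COCO-format dataset::

    import flash
    from flash.image import InstanceSegmentation, InstanceSegmentationData

    datamodule = InstanceSegmentationData.from_coco(
        train_folder="train_folder",          # hypothetical paths
        train_ann_file="annotations.json",
        val_split=0.1,
    )

    model = InstanceSegmentation(
        num_classes=datamodule.num_classes,
        head="mask_rcnn",
        backbone="resnet18_fpn",
    )

    trainer = flash.Trainer(max_epochs=1)
    trainer.finetune(model, datamodule=datamodule, strategy="freeze")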
+from functools import partial +from typing import Optional + +from flash.core.adapter import Adapter +from flash.core.integrations.icevision.adapter import IceVisionAdapter +from flash.core.integrations.icevision.backbones import ( + get_backbones, + icevision_model_adapter, + load_icevision_ignore_image_size, +) +from flash.core.model import Task +from flash.core.registry import FlashRegistry +from flash.core.utilities.imports import _ICEVISION_AVAILABLE, _TORCHVISION_AVAILABLE +from flash.core.utilities.providers import _ICEVISION, _TORCHVISION + +if _ICEVISION_AVAILABLE: + from icevision import models as icevision_models + from icevision.metrics import Metric as IceVisionMetric + +KEYPOINT_DETECTION_HEADS = FlashRegistry("heads") + + +class IceVisionKeypointDetectionAdapter(IceVisionAdapter): + @classmethod + def from_task( + cls, + task: Task, + num_keypoints: int, + num_classes: int = 2, + backbone: str = "resnet18_fpn", + head: str = "keypoint_rcnn", + pretrained: bool = True, + metrics: Optional["IceVisionMetric"] = None, + image_size: Optional = None, + **kwargs, + ) -> Adapter: + return super().from_task( + task, + num_keypoints=num_keypoints, + num_classes=num_classes, + backbone=backbone, + head=head, + pretrained=pretrained, + metrics=metrics, + image_size=image_size, + **kwargs, + ) + + +if _ICEVISION_AVAILABLE: + if _TORCHVISION_AVAILABLE: + model_type = icevision_models.torchvision.keypoint_rcnn + KEYPOINT_DETECTION_HEADS( + partial(load_icevision_ignore_image_size, icevision_model_adapter, model_type), + model_type.__name__.split(".")[-1], + backbones=get_backbones(model_type), + adapter=IceVisionKeypointDetectionAdapter, + providers=[_ICEVISION, _TORCHVISION], + ) diff --git a/flash/image/keypoint_detection/data.py b/flash/image/keypoint_detection/data.py new file mode 100644 index 0000000000..48e4b06a44 --- /dev/null +++ b/flash/image/keypoint_detection/data.py @@ -0,0 +1,154 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
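The keypoint head registry is consumed the same way as the others: the task looks up the head together with its metadata and hands its hyper-parameters to the stored adapter. Roughly::

    from flash.image.keypoint_detection.backbones import KEYPOINT_DETECTION_HEADS

    metadata = KEYPOINT_DETECTION_HEADS.get("keypoint_rcnn", with_metadata=True)
    adapter_cls = metadata["metadata"]["adapter"]  # IceVisionKeypointDetectionAdapter

    # the KeypointDetector constructor below effectively calls
    # adapter_cls.from_task(task, num_keypoints=..., num_classes=..., backbone=..., head=...)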
+from typing import Any, Callable, Dict, Optional, Tuple + +from flash.core.data.callback import BaseDataFetcher +from flash.core.data.data_module import DataModule +from flash.core.data.data_source import DefaultDataSources +from flash.core.data.process import Preprocess +from flash.core.integrations.icevision.data import ( + IceDataParserDataSource, + IceVisionParserDataSource, + IceVisionPathsDataSource, +) +from flash.core.integrations.icevision.transforms import default_transforms +from flash.core.utilities.imports import _ICEVISION_AVAILABLE + +if _ICEVISION_AVAILABLE: + from icevision.parsers import COCOKeyPointsParser + + +class KeypointDetectionPreprocess(Preprocess): + def __init__( + self, + train_transform: Optional[Dict[str, Callable]] = None, + val_transform: Optional[Dict[str, Callable]] = None, + test_transform: Optional[Dict[str, Callable]] = None, + predict_transform: Optional[Dict[str, Callable]] = None, + image_size: Tuple[int, int] = (128, 128), + parser: Optional[Callable] = None, + ): + self.image_size = image_size + + super().__init__( + train_transform=train_transform, + val_transform=val_transform, + test_transform=test_transform, + predict_transform=predict_transform, + data_sources={ + "coco": IceVisionParserDataSource(parser=COCOKeyPointsParser), + DefaultDataSources.FILES: IceVisionPathsDataSource(), + DefaultDataSources.FOLDERS: IceDataParserDataSource(parser=parser), + }, + default_data_source=DefaultDataSources.FILES, + ) + + self._default_collate = self._identity + + def get_state_dict(self) -> Dict[str, Any]: + return {**self.transforms} + + @classmethod + def load_state_dict(cls, state_dict: Dict[str, Any], strict: bool = False): + return cls(**state_dict) + + def default_transforms(self) -> Optional[Dict[str, Callable]]: + return default_transforms(self.image_size) + + def train_default_transforms(self) -> Optional[Dict[str, Callable]]: + return default_transforms(self.image_size) + + +class KeypointDetectionData(DataModule): + + preprocess_cls = KeypointDetectionPreprocess + + @classmethod + def from_coco( + cls, + train_folder: Optional[str] = None, + train_ann_file: Optional[str] = None, + val_folder: Optional[str] = None, + val_ann_file: Optional[str] = None, + test_folder: Optional[str] = None, + test_ann_file: Optional[str] = None, + predict_folder: Optional[str] = None, + train_transform: Optional[Dict[str, Callable]] = None, + val_transform: Optional[Dict[str, Callable]] = None, + test_transform: Optional[Dict[str, Callable]] = None, + predict_transform: Optional[Dict[str, Callable]] = None, + data_fetcher: Optional[BaseDataFetcher] = None, + preprocess: Optional[Preprocess] = None, + val_split: Optional[float] = None, + batch_size: int = 4, + num_workers: Optional[int] = None, + **preprocess_kwargs: Any, + ): + """Creates a :class:`~flash.image.keypoint_detection.data.KeypointDetectionData` object from the given data + folders and annotation files in the COCO format. + + Args: + train_folder: The folder containing the train data. + train_ann_file: The COCO format annotation file. + val_folder: The folder containing the validation data. + val_ann_file: The COCO format annotation file. + test_folder: The folder containing the test data. + test_ann_file: The COCO format annotation file. + predict_folder: The folder containing the predict data. + train_transform: The dictionary of transforms to use during training which maps + :class:`~flash.core.data.process.Preprocess` hook names to callable transforms. 
+            val_transform: The dictionary of transforms to use during validation which maps
+                :class:`~flash.core.data.process.Preprocess` hook names to callable transforms.
+            test_transform: The dictionary of transforms to use during testing which maps
+                :class:`~flash.core.data.process.Preprocess` hook names to callable transforms.
+            predict_transform: The dictionary of transforms to use during predicting which maps
+                :class:`~flash.core.data.process.Preprocess` hook names to callable transforms.
+            data_fetcher: The :class:`~flash.core.data.callback.BaseDataFetcher` to pass to the
+                :class:`~flash.core.data.data_module.DataModule`.
+            preprocess: The :class:`~flash.core.data.process.Preprocess` to pass to the
+                :class:`~flash.core.data.data_module.DataModule`. If ``None``, ``cls.preprocess_cls``
+                will be constructed and used.
+            val_split: The ``val_split`` argument to pass to the :class:`~flash.core.data.data_module.DataModule`.
+            batch_size: The ``batch_size`` argument to pass to the :class:`~flash.core.data.data_module.DataModule`.
+            num_workers: The ``num_workers`` argument to pass to the :class:`~flash.core.data.data_module.DataModule`.
+            preprocess_kwargs: Additional keyword arguments to use when constructing the preprocess. Will only be used
+                if ``preprocess = None``.
+
+        Returns:
+            The constructed data module.
+
+        Examples::
+
+            data_module = KeypointDetectionData.from_coco(
+                train_folder="train_folder",
+                train_ann_file="annotations.json",
+            )
+        """
+        return cls.from_data_source(
+            "coco",
+            (train_folder, train_ann_file) if train_folder else None,
+            (val_folder, val_ann_file) if val_folder else None,
+            (test_folder, test_ann_file) if test_folder else None,
+            predict_folder,
+            train_transform=train_transform,
+            val_transform=val_transform,
+            test_transform=test_transform,
+            predict_transform=predict_transform,
+            data_fetcher=data_fetcher,
+            preprocess=preprocess,
+            val_split=val_split,
+            batch_size=batch_size,
+            num_workers=num_workers,
+            **preprocess_kwargs,
+        )
diff --git a/flash/image/keypoint_detection/model.py b/flash/image/keypoint_detection/model.py
new file mode 100644
index 0000000000..b85177d083
--- /dev/null
+++ b/flash/image/keypoint_detection/model.py
@@ -0,0 +1,87 @@
+# Copyright The PyTorch Lightning team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import Any, Dict, List, Mapping, Optional, Type, Union
+
+import torch
+from torch.optim import Optimizer
+
+from flash.core.adapter import AdapterTask
+from flash.core.data.process import Serializer
+from flash.core.registry import FlashRegistry
+from flash.image.keypoint_detection.backbones import KEYPOINT_DETECTION_HEADS
+
+
+class KeypointDetector(AdapterTask):
+    """The ``KeypointDetector`` is a :class:`~flash.Task` for detecting keypoints in images. For more details, see
+    :ref:`keypoint_detection`.
+
+    Args:
+        num_keypoints: The number of keypoints to detect per instance.
+        num_classes: The number of classes for detection, including background.
+        backbone: Pretrained backbone CNN architecture. Defaults to ``"resnet18_fpn"``.
+        head: The detection head to use. Defaults to ``"keypoint_rcnn"``.
+        pretrained: Whether the model should be loaded with its pretrained weights.
+        optimizer: The optimizer to use for training. Can either be the actual class or the class name.
+        learning_rate: The learning rate to use for training.
+        serializer: The :class:`~flash.core.data.process.Serializer` (or mapping of serializers) to use when
+            serializing prediction outputs.
+        kwargs: Additional keyword arguments forwarded to the adapter.
+
+    """
+
+    heads: FlashRegistry = KEYPOINT_DETECTION_HEADS
+
+    required_extras: str = "image"
+
+    def __init__(
+        self,
+        num_keypoints: int,
+        num_classes: int = 2,
+        backbone: Optional[str] = "resnet18_fpn",
+        head: Optional[str] = "keypoint_rcnn",
+        pretrained: bool = True,
+        optimizer: Type[Optimizer] = torch.optim.Adam,
+        learning_rate: float = 5e-4,
+        serializer: Optional[Union[Serializer, Mapping[str, Serializer]]] = None,
+        **kwargs: Any,
+    ):
+        self.save_hyperparameters()
+
+        metadata = self.heads.get(head, with_metadata=True)
+        adapter = metadata["metadata"]["adapter"].from_task(
+            self,
+            num_keypoints=num_keypoints,
+            num_classes=num_classes,
+            backbone=backbone,
+            head=head,
+            pretrained=pretrained,
+            **kwargs,
+        )
+
+        super().__init__(
+            adapter,
+            learning_rate=learning_rate,
+            optimizer=optimizer,
+            serializer=serializer,
+        )
+
+    def _ci_benchmark_fn(self, history: List[Dict[str, Any]]) -> None:
+        """This function is used only for debugging usage with CI."""
+        # todo
diff --git a/flash/pointcloud/detection/data.py b/flash/pointcloud/detection/data.py
index 8931cf26b8..40349b8653 100644
--- a/flash/pointcloud/detection/data.py
+++ b/flash/pointcloud/detection/data.py
@@ -6,7 +6,7 @@
 from flash.core.data.data_module import DataModule
 from flash.core.data.data_source import BaseDataFormat, DataSource, DefaultDataKeys, DefaultDataSources
 from flash.core.data.process import Deserializer, Preprocess
-from flash.core.utilities.imports import _POINTCLOUD_AVAILABLE
+from flash.core.utilities.imports import _POINTCLOUD_AVAILABLE, requires_extras
 
 if _POINTCLOUD_AVAILABLE:
     from flash.pointcloud.detection.open3d_ml.data_sources import (
@@ -14,7 +14,7 @@
         PointCloudObjectDetectorFoldersDataSource,
     )
 else:
-    PointCloudObjectDetectorFoldersDataSource = object()
+    PointCloudObjectDetectorFoldersDataSource = object
 
     class PointCloudObjectDetectionDataFormat:
         KITTI = None
@@ -44,6 +44,7 @@ def load_sample(self, index: int, dataset: Optional[Any] = None) -> Any:
 
 class PointCloudObjectDetectorPreprocess(Preprocess):
+    @requires_extras("pointcloud")
     def __init__(
         self,
         train_transform: Optional[Dict[str, Callable]] = None,
diff --git a/flash/pointcloud/detection/model.py b/flash/pointcloud/detection/model.py
index b17adb67ba..155126d785 100644
--- a/flash/pointcloud/detection/model.py
+++ b/flash/pointcloud/detection/model.py
@@ -163,8 +163,7 @@ def _process_dataset(
         shuffle: bool = False,
         drop_last: bool = True,
         sampler: Optional[Sampler] = None,
-        convert_to_dataloader: bool = True,
-    ) -> Union[DataLoader, BaseAutoDataset]:
+    )
-> DataLoader: if not _POINTCLOUD_AVAILABLE: raise ModuleNotFoundError("Please, run `pip install flash[pointcloud]`.") @@ -172,17 +171,13 @@ def _process_dataset( dataset.preprocess_fn = self.model.preprocess dataset.transform_fn = self.model.transform - if convert_to_dataloader: - return DataLoader( - dataset, - batch_size=batch_size, - num_workers=num_workers, - pin_memory=pin_memory, - collate_fn=collate_fn, - shuffle=shuffle, - drop_last=drop_last, - sampler=sampler, - ) - - else: - return dataset + return DataLoader( + dataset, + batch_size=batch_size, + num_workers=num_workers, + pin_memory=pin_memory, + collate_fn=collate_fn, + shuffle=shuffle, + drop_last=drop_last, + sampler=sampler, + ) diff --git a/flash/pointcloud/segmentation/model.py b/flash/pointcloud/segmentation/model.py index 7098aea98e..9342a61758 100644 --- a/flash/pointcloud/segmentation/model.py +++ b/flash/pointcloud/segmentation/model.py @@ -192,8 +192,7 @@ def _process_dataset( shuffle: bool = False, drop_last: bool = True, sampler: Optional[Sampler] = None, - convert_to_dataloader: bool = True, - ) -> Union[DataLoader, BaseAutoDataset]: + ) -> DataLoader: if not _POINTCLOUD_AVAILABLE: raise ModuleNotFoundError("Please, run `pip install flash[pointcloud]`.") @@ -207,20 +206,16 @@ def _process_dataset( use_cache=False, ) - if convert_to_dataloader: - return DataLoader( - dataset, - batch_size=batch_size, - num_workers=num_workers, - pin_memory=pin_memory, - collate_fn=collate_fn, - shuffle=shuffle, - drop_last=drop_last, - sampler=sampler, - ) - - else: - return dataset + return DataLoader( + dataset, + batch_size=batch_size, + num_workers=num_workers, + pin_memory=pin_memory, + collate_fn=collate_fn, + shuffle=shuffle, + drop_last=drop_last, + sampler=sampler, + ) def configure_finetune_callback(self) -> List[Callback]: return [PointCloudSegmentationFinetuning()] diff --git a/flash_examples/graph_classification.py b/flash_examples/graph_classification.py index 68c01e700e..4519f70c33 100644 --- a/flash_examples/graph_classification.py +++ b/flash_examples/graph_classification.py @@ -14,13 +14,12 @@ import torch import flash -from flash.core.utilities.imports import _TORCH_GEOMETRIC_AVAILABLE +from flash.core.utilities.imports import example_requires from flash.graph import GraphClassificationData, GraphClassifier -if _TORCH_GEOMETRIC_AVAILABLE: - from torch_geometric.datasets import TUDataset -else: - raise ModuleNotFoundError("Please, pip install -e '.[graph]'") +example_requires("graph") + +from torch_geometric.datasets import TUDataset # noqa: E402 # 1. Create the DataModule dataset = TUDataset(root="data", name="KKI") diff --git a/flash_examples/instance_segmentation.py b/flash_examples/instance_segmentation.py new file mode 100644 index 0000000000..16e5699d14 --- /dev/null +++ b/flash_examples/instance_segmentation.py @@ -0,0 +1,56 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
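+# This example fine-tunes a ``mask_rcnn`` head with a ``resnet18_fpn`` backbone on the IceData pets dataset
+# (loaded with ``icedata.pets.load_data``), runs inference on a few images, and saves a checkpoint.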
+from functools import partial
+
+import flash
+from flash.core.utilities.imports import example_requires
+from flash.image import InstanceSegmentation, InstanceSegmentationData
+
+example_requires("image")
+
+import icedata  # noqa: E402
+
+# 1. Create the DataModule
+data_dir = icedata.pets.load_data()
+
+datamodule = InstanceSegmentationData.from_folders(
+    train_folder=data_dir,
+    val_split=0.1,
+    image_size=128,
+    parser=partial(icedata.pets.parser, mask=True),
+)
+
+# 2. Build the task
+model = InstanceSegmentation(
+    head="mask_rcnn",
+    backbone="resnet18_fpn",
+    num_classes=datamodule.num_classes,
+)
+
+# 3. Create the trainer and finetune the model
+trainer = flash.Trainer(max_epochs=1)
+trainer.finetune(model, datamodule=datamodule, strategy="freeze")
+
+# 4. Detect objects in a few images!
+predictions = model.predict(
+    [
+        str(data_dir / "images/yorkshire_terrier_9.jpg"),
+        str(data_dir / "images/english_cocker_spaniel_1.jpg"),
+        str(data_dir / "images/scottish_terrier_1.jpg"),
+    ]
+)
+print(predictions)
+
+# 5. Save the model!
+trainer.save_checkpoint("instance_segmentation_model.pt")
diff --git a/flash_examples/keypoint_detection.py b/flash_examples/keypoint_detection.py
new file mode 100644
index 0000000000..731f0a8125
--- /dev/null
+++ b/flash_examples/keypoint_detection.py
@@ -0,0 +1,55 @@
+# Copyright The PyTorch Lightning team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import flash
+from flash.core.utilities.imports import example_requires
+from flash.image import KeypointDetectionData, KeypointDetector
+
+example_requires("image")
+
+import icedata  # noqa: E402
+
+# 1. Create the DataModule
+data_dir = icedata.biwi.load_data()
+
+datamodule = KeypointDetectionData.from_folders(
+    train_folder=data_dir,
+    val_split=0.1,
+    image_size=128,
+    parser=icedata.biwi.parser,
+)
+
+# 2. Build the task
+model = KeypointDetector(
+    head="keypoint_rcnn",
+    backbone="resnet18_fpn",
+    num_keypoints=1,
+    num_classes=datamodule.num_classes,
+)
+
+# 3. Create the trainer and finetune the model
+trainer = flash.Trainer(max_epochs=1)
+trainer.finetune(model, datamodule=datamodule, strategy="freeze")
+
+# 4. Detect keypoints in a few images!
+predictions = model.predict(
+    [
+        str(data_dir / "biwi_sample/images/0.jpg"),
+        str(data_dir / "biwi_sample/images/1.jpg"),
+        str(data_dir / "biwi_sample/images/10.jpg"),
+    ]
+)
+print(predictions)
+
+# 5. Save the model!
+trainer.save_checkpoint("keypoint_detection_model.pt")
diff --git a/flash_examples/object_detection.py b/flash_examples/object_detection.py
index 790193e67c..1a5dddbce9 100644
--- a/flash_examples/object_detection.py
+++ b/flash_examples/object_detection.py
@@ -11,8 +11,6 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-import torch - import flash from flash.core.data.utils import download_data from flash.image import ObjectDetectionData, ObjectDetector @@ -25,15 +23,15 @@ train_folder="data/coco128/images/train2017/", train_ann_file="data/coco128/annotations/instances_train2017.json", val_split=0.1, - batch_size=2, + image_size=128, ) # 2. Build the task -model = ObjectDetector(model="retinanet", num_classes=datamodule.num_classes) +model = ObjectDetector(head="efficientdet", backbone="d0", num_classes=datamodule.num_classes, image_size=128) # 3. Create the trainer and finetune the model -trainer = flash.Trainer(max_epochs=3, gpus=torch.cuda.device_count()) -trainer.finetune(model, datamodule=datamodule) +trainer = flash.Trainer(max_epochs=1) +trainer.finetune(model, datamodule=datamodule, strategy="freeze") # 4. Detect objects in a few images! predictions = model.predict( diff --git a/requirements/datatype_image.txt b/requirements/datatype_image.txt index 3be9ed638d..aa9fe14c15 100644 --- a/requirements/datatype_image.txt +++ b/requirements/datatype_image.txt @@ -5,3 +5,6 @@ Pillow>=7.2 kornia>=0.5.1,<0.5.4 pystiche==1.* segmentation-models-pytorch +icevision>=0.8 +icedata +effdet diff --git a/requirements/datatype_image_extras.txt b/requirements/datatype_image_extras.txt index 7e7370035f..f61e3f9c25 100644 --- a/requirements/datatype_image_extras.txt +++ b/requirements/datatype_image_extras.txt @@ -1,3 +1,2 @@ matplotlib -pycocotools>=2.0.2 ; python_version >= "3.7" fiftyone diff --git a/tests/core/data/test_callback.py b/tests/core/data/test_callback.py index e9b6b853a2..5db55dee08 100644 --- a/tests/core/data/test_callback.py +++ b/tests/core/data/test_callback.py @@ -23,8 +23,9 @@ from flash.core.trainer import Trainer +@mock.patch("pickle.dumps") # need to mock pickle or we get pickle error @mock.patch("torch.save") # need to mock torch.save or we get pickle error -def test_flash_callback(_, tmpdir): +def test_flash_callback(_, __, tmpdir): """Test the callback hook system for fit.""" callback_mock = MagicMock() diff --git a/tests/core/test_model.py b/tests/core/test_model.py index a94861c2be..23c08d96a0 100644 --- a/tests/core/test_model.py +++ b/tests/core/test_model.py @@ -28,6 +28,7 @@ from torch.utils.data import DataLoader import flash +from flash.core.adapter import Adapter from flash.core.classification import ClassificationTask from flash.core.data.process import DefaultPreprocess, Postprocess from flash.core.utilities.imports import _PIL_AVAILABLE, _TABULAR_AVAILABLE, _TEXT_AVAILABLE @@ -118,6 +119,30 @@ def __init__(self, child): super().__init__(Parent(child)) +class BasicAdapter(Adapter): + def __init__(self, child): + super().__init__() + + self.child = child + + def training_step(self, batch, batch_idx): + return self.child.training_step(batch, batch_idx) + + def validation_step(self, batch, batch_idx): + return self.child.validation_step(batch, batch_idx) + + def test_step(self, batch, batch_idx): + return self.child.test_step(batch, batch_idx) + + def forward(self, x): + return self.child(x) + + +class AdapterParent(Parent): + def __init__(self, child): + super().__init__(BasicAdapter(child)) + + # ================================ @@ -133,7 +158,7 @@ def test_classificationtask_train(tmpdir: str, metrics: Any): assert "test_nll_loss" in result[0] -@pytest.mark.parametrize("task", [Parent, GrandParent]) +@pytest.mark.parametrize("task", [Parent, GrandParent, AdapterParent]) def test_nested_tasks(tmpdir, task): model = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 10), 
nn.Softmax()) train_dl = torch.utils.data.DataLoader(DummyDataset()) @@ -259,7 +284,7 @@ def test_available_backbones(): class Foo(ImageClassifier): backbones = None - assert Foo.available_backbones() == [] + assert Foo.available_backbones() == {} def test_optimization(tmpdir): diff --git a/tests/core/test_registry.py b/tests/core/test_registry.py index 674a3a4616..a230b869c0 100644 --- a/tests/core/test_registry.py +++ b/tests/core/test_registry.py @@ -27,8 +27,8 @@ def test_registry_raises(): def my_model(nc_input=5, nc_output=6): return nn.Linear(nc_input, nc_output), nc_input, nc_output - with pytest.raises(MisconfigurationException, match="You can only register a function, found: Linear"): - backbones(nn.Linear(1, 1), name="foo") + with pytest.raises(MisconfigurationException, match="You can only register a callable, found: 3"): + backbones(3, name="foo") backbones(my_model, name="foo", override=True) diff --git a/tests/image/detection/test_data.py b/tests/image/detection/test_data.py index 2c5b670671..50ce9fb196 100644 --- a/tests/image/detection/test_data.py +++ b/tests/image/detection/test_data.py @@ -1,3 +1,16 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. import json import os from pathlib import Path @@ -135,15 +148,13 @@ def test_image_detector_data_from_coco(tmpdir): train_folder, coco_ann_path = _create_synth_coco_dataset(tmpdir) - datamodule = ObjectDetectionData.from_coco(train_folder=train_folder, train_ann_file=coco_ann_path, batch_size=1) + datamodule = ObjectDetectionData.from_coco( + train_folder=train_folder, train_ann_file=coco_ann_path, batch_size=1, image_size=128 + ) data = next(iter(datamodule.train_dataloader())) - imgs, labels = data[DefaultDataKeys.INPUT], data[DefaultDataKeys.TARGET] - - assert len(imgs) == 1 - assert imgs[0].shape == (3, 1080, 1920) - assert len(labels) == 1 - assert list(labels[0].keys()) == ["boxes", "labels", "image_id", "area", "iscrowd"] + sample = data[0] + assert sample[DefaultDataKeys.INPUT].shape == (128, 128, 3) assert datamodule.val_dataloader() is None assert datamodule.test_dataloader() is None @@ -157,23 +168,17 @@ def test_image_detector_data_from_coco(tmpdir): test_ann_file=coco_ann_path, batch_size=1, num_workers=0, + image_size=128, ) data = next(iter(datamodule.val_dataloader())) - imgs, labels = data[DefaultDataKeys.INPUT], data[DefaultDataKeys.TARGET] - assert len(imgs) == 1 - assert imgs[0].shape == (3, 1080, 1920) - assert len(labels) == 1 - assert list(labels[0].keys()) == ["boxes", "labels", "image_id", "area", "iscrowd"] + sample = data[0] + assert sample[DefaultDataKeys.INPUT].shape == (128, 128, 3) data = next(iter(datamodule.test_dataloader())) - imgs, labels = data[DefaultDataKeys.INPUT], data[DefaultDataKeys.TARGET] - - assert len(imgs) == 1 - assert imgs[0].shape == (3, 1080, 1920) - assert len(labels) == 1 - assert list(labels[0].keys()) == ["boxes", "labels", "image_id", "area", "iscrowd"] + sample = data[0] + assert sample[DefaultDataKeys.INPUT].shape == (128, 128, 
3) @pytest.mark.skipif(not _IMAGE_AVAILABLE, reason="image libraries aren't installed.") @@ -182,15 +187,11 @@ def test_image_detector_data_from_fiftyone(tmpdir): train_dataset = _create_synth_fiftyone_dataset(tmpdir) - datamodule = ObjectDetectionData.from_fiftyone(train_dataset=train_dataset, batch_size=1) + datamodule = ObjectDetectionData.from_fiftyone(train_dataset=train_dataset, batch_size=1, image_size=128) data = next(iter(datamodule.train_dataloader())) - imgs, labels = data[DefaultDataKeys.INPUT], data[DefaultDataKeys.TARGET] - - assert len(imgs) == 1 - assert imgs[0].shape == (3, 1080, 1920) - assert len(labels) == 1 - assert list(labels[0].keys()) == ["boxes", "labels", "image_id", "area", "iscrowd"] + sample = data[0] + assert sample[DefaultDataKeys.INPUT].shape == (128, 128, 3) assert datamodule.val_dataloader() is None assert datamodule.test_dataloader() is None @@ -201,20 +202,13 @@ def test_image_detector_data_from_fiftyone(tmpdir): test_dataset=train_dataset, batch_size=1, num_workers=0, + image_size=128, ) data = next(iter(datamodule.val_dataloader())) - imgs, labels = data[DefaultDataKeys.INPUT], data[DefaultDataKeys.TARGET] - - assert len(imgs) == 1 - assert imgs[0].shape == (3, 1080, 1920) - assert len(labels) == 1 - assert list(labels[0].keys()) == ["boxes", "labels", "image_id", "area", "iscrowd"] + sample = data[0] + assert sample[DefaultDataKeys.INPUT].shape == (128, 128, 3) data = next(iter(datamodule.test_dataloader())) - imgs, labels = data[DefaultDataKeys.INPUT], data[DefaultDataKeys.TARGET] - - assert len(imgs) == 1 - assert imgs[0].shape == (3, 1080, 1920) - assert len(labels) == 1 - assert list(labels[0].keys()) == ["boxes", "labels", "image_id", "area", "iscrowd"] + sample = data[0] + assert sample[DefaultDataKeys.INPUT].shape == (128, 128, 3) diff --git a/tests/image/detection/test_data_model_integration.py b/tests/image/detection/test_data_model_integration.py index 51895a601c..1a9d47b9f0 100644 --- a/tests/image/detection/test_data_model_integration.py +++ b/tests/image/detection/test_data_model_integration.py @@ -20,6 +20,7 @@ from flash.core.utilities.imports import _COCO_AVAILABLE, _FIFTYONE_AVAILABLE, _IMAGE_AVAILABLE, _PIL_AVAILABLE from flash.image import ObjectDetector from flash.image.detection import ObjectDetectionData +from tests.helpers.utils import _IMAGE_TESTING if _PIL_AVAILABLE: from PIL import Image @@ -33,19 +34,18 @@ from tests.image.detection.test_data import _create_synth_fiftyone_dataset -@pytest.mark.skipif(not _IMAGE_AVAILABLE, reason="image libraries aren't installed.") -@pytest.mark.skipif(not _COCO_AVAILABLE, reason="pycocotools is not installed for testing") -@pytest.mark.parametrize(["model", "backbone"], [("fasterrcnn", "resnet18")]) -def test_detection(tmpdir, model, backbone): +@pytest.mark.skipif(not _IMAGE_TESTING, reason="image libraries aren't installed.") +@pytest.mark.parametrize(["head", "backbone"], [("retinanet", "resnet18_fpn")]) +def test_detection(tmpdir, head, backbone): train_folder, coco_ann_path = _create_synth_coco_dataset(tmpdir) data = ObjectDetectionData.from_coco(train_folder=train_folder, train_ann_file=coco_ann_path, batch_size=1) - model = ObjectDetector(model=model, backbone=backbone, num_classes=data.num_classes) + model = ObjectDetector(head=head, backbone=backbone, num_classes=data.num_classes) trainer = flash.Trainer(fast_dev_run=True, gpus=torch.cuda.device_count()) - trainer.finetune(model, data) + trainer.finetune(model, data, strategy="freeze") test_image_one = os.fspath(tmpdir / 
"test_one.png") test_image_two = os.fspath(tmpdir / "test_two.png") @@ -59,17 +59,17 @@ def test_detection(tmpdir, model, backbone): @pytest.mark.skipif(not _IMAGE_AVAILABLE, reason="image libraries aren't installed.") @pytest.mark.skipif(not _FIFTYONE_AVAILABLE, reason="fiftyone is not installed for testing") -@pytest.mark.parametrize(["model", "backbone"], [("fasterrcnn", "resnet18")]) -def test_detection_fiftyone(tmpdir, model, backbone): +@pytest.mark.parametrize(["head", "backbone"], [("retinanet", "resnet18_fpn")]) +def test_detection_fiftyone(tmpdir, head, backbone): train_dataset = _create_synth_fiftyone_dataset(tmpdir) data = ObjectDetectionData.from_fiftyone(train_dataset=train_dataset, batch_size=1) - model = ObjectDetector(model=model, backbone=backbone, num_classes=data.num_classes) + model = ObjectDetector(head=head, backbone=backbone, num_classes=data.num_classes) trainer = flash.Trainer(fast_dev_run=True, gpus=torch.cuda.device_count()) - trainer.finetune(model, data) + trainer.finetune(model, data, strategy="freeze") test_image_one = os.fspath(tmpdir / "test_one.png") test_image_two = os.fspath(tmpdir / "test_two.png") diff --git a/tests/image/detection/test_model.py b/tests/image/detection/test_model.py index cae495794a..f3ed0dc445 100644 --- a/tests/image/detection/test_model.py +++ b/tests/image/detection/test_model.py @@ -11,21 +11,25 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import os +import random import re from unittest import mock +import numpy as np import pytest import torch from pytorch_lightning import Trainer -from torch.utils.data import DataLoader, Dataset +from torch.utils.data import Dataset from flash.__main__ import main from flash.core.data.data_source import DefaultDataKeys -from flash.core.utilities.imports import _COCO_AVAILABLE, _IMAGE_AVAILABLE +from flash.core.utilities.imports import _ICEVISION_AVAILABLE, _IMAGE_AVAILABLE from flash.image import ObjectDetector from tests.helpers.utils import _IMAGE_TESTING +if _ICEVISION_AVAILABLE: + from icevision.data import Prediction + def collate_fn(samples): return {key: [sample[key] for sample in samples] for key in samples[0]} @@ -46,13 +50,25 @@ def _random_bbox(self): c, h, w = self.img_shape xs = torch.randint(w - 1, (2,)) ys = torch.randint(h - 1, (2,)) - return [min(xs), min(ys), max(xs) + 1, max(ys) + 1] + return {"xmin": min(xs), "ymin": min(ys), "width": max(xs) - min(xs) + 1, "height": max(ys) - min(ys) + 1} def __getitem__(self, idx): - img = torch.rand(self.img_shape) - boxes = torch.tensor([self._random_bbox() for _ in range(self.num_boxes)]) - labels = torch.randint(self.num_classes, (self.num_boxes,)) - return {DefaultDataKeys.INPUT: img, DefaultDataKeys.TARGET: {"boxes": boxes, "labels": labels}} + sample = {} + + img = np.random.rand(*self.img_shape).astype(np.float32) + + sample[DefaultDataKeys.INPUT] = img + + sample[DefaultDataKeys.TARGET] = { + "bboxes": [], + "labels": [], + } + + for i in range(self.num_boxes): + sample[DefaultDataKeys.TARGET]["bboxes"].append(self._random_bbox()) + sample[DefaultDataKeys.TARGET]["labels"].append(random.randint(0, self.num_classes - 1)) + + return sample @pytest.mark.skipif(not _IMAGE_TESTING, reason="image libraries aren't installed.") @@ -61,45 +77,45 @@ def test_init(): model.eval() batch_size = 2 - ds = DummyDetectionDataset((3, 224, 224), 1, 2, 10) - dl = DataLoader(ds, collate_fn=collate_fn, 
batch_size=batch_size) + ds = DummyDetectionDataset((128, 128, 3), 1, 2, 10) + dl = model.process_predict_dataset(ds, batch_size=batch_size) data = next(iter(dl)) - img = data[DefaultDataKeys.INPUT] - out = model(img) + out = model(data) assert len(out) == batch_size - assert {"boxes", "labels", "scores"} <= out[0].keys() + assert all(isinstance(res, Prediction) for res in out) -@pytest.mark.parametrize("model", ["fasterrcnn", "retinanet"]) +@pytest.mark.parametrize("head", ["faster_rcnn", "retinanet"]) @pytest.mark.skipif(not _IMAGE_TESTING, reason="image libraries aren't installed.") -def test_training(tmpdir, model): - model = ObjectDetector(num_classes=2, model=model, pretrained=False, pretrained_backbone=False) - ds = DummyDetectionDataset((3, 224, 224), 1, 2, 10) - dl = DataLoader(ds, collate_fn=collate_fn) +def test_training(tmpdir, head): + model = ObjectDetector(num_classes=2, head=head, pretrained=False) + ds = DummyDetectionDataset((128, 128, 3), 1, 2, 10) + dl = model.process_train_dataset(ds, 2, 0, False, None) trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True) trainer.fit(model, dl) -@pytest.mark.skipif(not _IMAGE_TESTING, reason="image libraries aren't installed.") -def test_jit(tmpdir): - path = os.path.join(tmpdir, "test.pt") - - model = ObjectDetector(2) - model.eval() - - model = torch.jit.script(model) # torch.jit.trace doesn't work with torchvision RCNN - - torch.jit.save(model, path) - model = torch.jit.load(path) - - out = model([torch.rand(3, 32, 32)]) - - # torchvision RCNN always returns a (Losses, Detections) tuple in scripting - out = out[1] - - assert {"boxes", "labels", "scores"} <= out[0].keys() +# TODO: resolve JIT issues +# @pytest.mark.skipif(not _IMAGE_TESTING, reason="image libraries aren't installed.") +# def test_jit(tmpdir): +# path = os.path.join(tmpdir, "test.pt") +# +# model = ObjectDetector(2) +# model.eval() +# +# model = torch.jit.script(model) # torch.jit.trace doesn't work with torchvision RCNN +# +# torch.jit.save(model, path) +# model = torch.jit.load(path) +# +# out = model([torch.rand(3, 32, 32)]) +# +# # torchvision RCNN always returns a (Losses, Detections) tuple in scripting +# out = out[1] +# +# assert {"boxes", "labels", "scores"} <= out[0].keys() @pytest.mark.skipif(_IMAGE_AVAILABLE, reason="image libraries are installed.") @@ -109,7 +125,7 @@ def test_load_from_checkpoint_dependency_error(): @pytest.mark.skipif(not _IMAGE_AVAILABLE, reason="image libraries aren't installed.") -@pytest.mark.skipif(not _COCO_AVAILABLE, reason="pycocotools is not installed for testing.") +@pytest.mark.skipif(not _ICEVISION_AVAILABLE, reason="icevision is not installed.") def test_cli(): cli_args = ["flash", "object_detection", "--trainer.fast_dev_run", "True"] with mock.patch("sys.argv", cli_args): diff --git a/tests/image/test_backbones.py b/tests/image/test_backbones.py index 88888988fd..c751426c76 100644 --- a/tests/image/test_backbones.py +++ b/tests/image/test_backbones.py @@ -14,21 +14,18 @@ import urllib.error import pytest -from pytorch_lightning.utilities import _TORCHVISION_AVAILABLE -from flash.core.utilities.imports import _TIMM_AVAILABLE from flash.core.utilities.url_error import catch_url_error from flash.image.classification.backbones import IMAGE_CLASSIFIER_BACKBONES +from tests.helpers.utils import _IMAGE_TESTING @pytest.mark.parametrize( ["backbone", "expected_num_features"], [ - pytest.param("resnet34", 512, marks=pytest.mark.skipif(not _TORCHVISION_AVAILABLE, reason="No torchvision")), - pytest.param("mobilenetv2_100", 
1280, marks=pytest.mark.skipif(not _TIMM_AVAILABLE, reason="No timm")), - pytest.param( - "mobilenet_v2", 1280, marks=pytest.mark.skipif(not _TORCHVISION_AVAILABLE, reason="No torchvision") - ), + pytest.param("resnet34", 512, marks=pytest.mark.skipif(not _IMAGE_TESTING, reason="No torchvision")), + pytest.param("mobilenetv2_100", 1280, marks=pytest.mark.skipif(not _IMAGE_TESTING, reason="No timm")), + pytest.param("mobilenet_v2", 1280, marks=pytest.mark.skipif(not _IMAGE_TESTING, reason="No torchvision")), ], ) def test_image_classifier_backbones_registry(backbone, expected_num_features): @@ -45,11 +42,9 @@ def test_image_classifier_backbones_registry(backbone, expected_num_features): "resnet50", "supervised", 2048, - marks=pytest.mark.skipif(not _TORCHVISION_AVAILABLE, reason="No torchvision"), - ), - pytest.param( - "resnet50", "simclr", 2048, marks=pytest.mark.skipif(not _TORCHVISION_AVAILABLE, reason="No torchvision") + marks=pytest.mark.skipif(not _IMAGE_TESTING, reason="No torchvision"), ), + pytest.param("resnet50", "simclr", 2048, marks=pytest.mark.skipif(not _IMAGE_TESTING, reason="No torchvision")), ], ) def test_pretrained_weights_registry(backbone, pretrained, expected_num_features):
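
For reference, the IceVision-backed tasks above resolve their ``head`` string through a ``FlashRegistry`` whose entries carry an ``adapter`` class in their metadata: ``KeypointDetector.__init__`` calls ``self.heads.get(head, with_metadata=True)`` and then ``metadata["metadata"]["adapter"].from_task(...)``. The following is a minimal sketch of that lookup using hypothetical names (``DEMO_HEADS``, ``demo_head``, and ``DemoAdapter`` are illustrative, not part of the Flash API)::

    from flash.core.adapter import Adapter
    from flash.core.registry import FlashRegistry

    DEMO_HEADS = FlashRegistry("heads")  # hypothetical registry, mirroring KEYPOINT_DETECTION_HEADS


    class DemoAdapter(Adapter):
        # A real adapter would wrap an IceVision model type and implement the train/val/test steps.
        @classmethod
        def from_task(cls, task, **kwargs) -> "DemoAdapter":
            return cls()


    def demo_head(backbone: str = "resnet18_fpn", pretrained: bool = True):
        # A real registry entry would build and return the backbone/head for the task.
        return backbone, pretrained


    # Keyword arguments given at registration time are kept as the entry's metadata,
    # which is what ``heads.get(name, with_metadata=True)`` returns alongside the function.
    DEMO_HEADS(demo_head, name="demo_rcnn", adapter=DemoAdapter)

    entry = DEMO_HEADS.get("demo_rcnn", with_metadata=True)
    adapter = entry["metadata"]["adapter"].from_task(task=None)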