diff --git a/flash/__main__.py b/flash/__main__.py index 1f1eba0580..1f521bb2a8 100644 --- a/flash/__main__.py +++ b/flash/__main__.py @@ -44,6 +44,7 @@ def wrapper(cli_args): "flash.graph.classification", "flash.image.classification", "flash.image.detection", + "flash.image.face_detection", "flash.image.instance_segmentation", "flash.image.keypoint_detection", "flash.image.segmentation", diff --git a/flash/image/classification/integrations/baal/data.py b/flash/image/classification/integrations/baal/data.py index 12524bb170..954b937a25 100644 --- a/flash/image/classification/integrations/baal/data.py +++ b/flash/image/classification/integrations/baal/data.py @@ -141,7 +141,10 @@ def train_dataloader(self) -> "DataLoader": if self.has_labelled_data and self.val_split: self.val_dataloader = self._val_dataloader - return self.labelled.train_dataloader() + if self.has_labelled_data: + return self.labelled.train_dataloader() + # Return a dummy dataloader, will be replaced by the loop + return DataLoader(["dummy"]) def _val_dataloader(self) -> "DataLoader": self.labelled._val_input = train_val_split(self._dataset, self.val_split)[1] diff --git a/flash/image/classification/integrations/baal/loop.py b/flash/image/classification/integrations/baal/loop.py index fa4fad7e60..1497f6be42 100644 --- a/flash/image/classification/integrations/baal/loop.py +++ b/flash/image/classification/integrations/baal/loop.py @@ -185,6 +185,7 @@ def _reset_dataloader_for_stage(self, running_state: RunningStage): if is_overridden(dataloader_name, self.trainer.datamodule) else None ) + if dataloader: if _PL_GREATER_EQUAL_1_5_0: setattr( diff --git a/flash/image/face_detection/cli.py b/flash/image/face_detection/cli.py new file mode 100644 index 0000000000..59f9530700 --- /dev/null +++ b/flash/image/face_detection/cli.py @@ -0,0 +1,54 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from flash.core.utilities.flash_cli import FlashCLI +from flash.image.face_detection.data import FaceDetectionData +from flash.image.face_detection.model import FaceDetector + +__all__ = ["face_detection"] + + +def from_fddb( + batch_size: int = 1, + **data_module_kwargs, +) -> FaceDetectionData: + """Downloads and loads the FDDB data set.""" + import fastface as ff + + train_dataset = ff.dataset.FDDBDataset(source_dir="data/", phase="train") + val_dataset = ff.dataset.FDDBDataset(source_dir="data/", phase="val") + + return FaceDetectionData.from_datasets( + train_dataset=train_dataset, + val_dataset=val_dataset, + batch_size=batch_size, + **data_module_kwargs, + ) + + +def face_detection(): + """Detect faces in images.""" + cli = FlashCLI( + FaceDetector, + FaceDetectionData, + default_datamodule_builder=from_fddb, + default_arguments={ + "trainer.max_epochs": 3, + }, + ) + + cli.trainer.save_checkpoint("face_detection_model.pt") + + +if __name__ == "__main__": + face_detection() diff --git a/flash/image/face_detection/model.py b/flash/image/face_detection/model.py index 9be3f51dcc..ba009ffce3 100644 --- a/flash/image/face_detection/model.py +++ b/flash/image/face_detection/model.py @@ -71,7 +71,7 @@ def __init__( learning_rate=learning_rate, optimizer=optimizer, lr_scheduler=lr_scheduler, - output_transform=FaceDetectionOutputTransform, + output_transform=FaceDetectionOutputTransform(), ) @staticmethod diff --git a/flash/image/face_detection/output_transform.py b/flash/image/face_detection/output_transform.py index 0308d82dfb..2ee34502ad 100644 --- a/flash/image/face_detection/output_transform.py +++ b/flash/image/face_detection/output_transform.py @@ -37,6 +37,5 @@ def per_batch_transform(batch: Any) -> Any: # preds: list of torch.Tensor(N, 5) as x1, y1, x2, y2, score preds = [preds[preds[:, 5] == batch_idx, :5] for batch_idx in range(len(preds))] preds = ff.utils.preprocess.adjust_results(preds, scales, paddings) - batch[DataKeys.PREDS] = preds - return batch + return preds diff --git a/requirements/datatype_image_extras.txt b/requirements/datatype_image_extras.txt index cd957f0e62..c00028b4cf 100644 --- a/requirements/datatype_image_extras.txt +++ b/requirements/datatype_image_extras.txt @@ -11,3 +11,6 @@ structlog==21.1.0 # remove when baal resolved its dependency. baal fastface fairscale + +# pin PL for testing, remove when fastface is updated +pytorch-lightning<1.5.0 diff --git a/tests/image/face_detection/test_model.py b/tests/image/face_detection/test_model.py index a60d14d2cf..d57b8b8590 100644 --- a/tests/image/face_detection/test_model.py +++ b/tests/image/face_detection/test_model.py @@ -11,10 +11,12 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from unittest import mock + import pytest -import torch import flash +from flash.__main__ import main from flash.core.registry import FlashRegistry from flash.core.utilities.imports import _FASTFACE_AVAILABLE from flash.image import FaceDetectionData, FaceDetector @@ -42,15 +44,6 @@ def test_fastface_training(): trainer.predict(model, datamodule=datamodule) -@pytest.mark.skipif(not _FASTFACE_AVAILABLE, reason="fastface not installed.") -def test_fastface_forward(): - model = FaceDetector(model="lffd_slim") - mock_batch = torch.randn(2, 3, 256, 256) - - # test model forward (tests: _prepare_batch, logits_to_preds, _output_transform from ff) - model(mock_batch) - - @pytest.mark.skipif(not _FASTFACE_AVAILABLE, reason="fastface not installed.") def test_fastface_backbones_registry(): backbones = FACE_DETECTION_BACKBONES.available_keys() @@ -59,3 +52,13 @@ def test_fastface_backbones_registry(): backbone, _ = FACE_DETECTION_BACKBONES.get("lffd_original")(pretrained=False) assert isinstance(backbone, LFFD) + + +@pytest.mark.skipif(not _FASTFACE_AVAILABLE, reason="fastface not installed.") +def test_cli(): + cli_args = ["flash", "face_detection", "--trainer.fast_dev_run", "True"] + with mock.patch("sys.argv", cli_args): + try: + main() + except SystemExit: + pass