Skip to content
This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

Commit

Permalink
Add face detection CLI (#1170)
Browse files Browse the repository at this point in the history
  • Loading branch information
ethanwharris committed Feb 14, 2022
1 parent 3c84e94 commit 4c62482
Show file tree
Hide file tree
Showing 8 changed files with 78 additions and 14 deletions.
1 change: 1 addition & 0 deletions flash/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ def wrapper(cli_args):
"flash.graph.classification",
"flash.image.classification",
"flash.image.detection",
"flash.image.face_detection",
"flash.image.instance_segmentation",
"flash.image.keypoint_detection",
"flash.image.segmentation",
Expand Down
5 changes: 4 additions & 1 deletion flash/image/classification/integrations/baal/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,10 @@ def train_dataloader(self) -> "DataLoader":
if self.has_labelled_data and self.val_split:
self.val_dataloader = self._val_dataloader

return self.labelled.train_dataloader()
if self.has_labelled_data:
return self.labelled.train_dataloader()
# Return a dummy dataloader, will be replaced by the loop
return DataLoader(["dummy"])

def _val_dataloader(self) -> "DataLoader":
self.labelled._val_input = train_val_split(self._dataset, self.val_split)[1]
Expand Down
1 change: 1 addition & 0 deletions flash/image/classification/integrations/baal/loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,6 +185,7 @@ def _reset_dataloader_for_stage(self, running_state: RunningStage):
if is_overridden(dataloader_name, self.trainer.datamodule)
else None
)

if dataloader:
if _PL_GREATER_EQUAL_1_5_0:
setattr(
Expand Down
54 changes: 54 additions & 0 deletions flash/image/face_detection/cli.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from flash.core.utilities.flash_cli import FlashCLI
from flash.image.face_detection.data import FaceDetectionData
from flash.image.face_detection.model import FaceDetector

__all__ = ["face_detection"]


def from_fddb(
batch_size: int = 1,
**data_module_kwargs,
) -> FaceDetectionData:
"""Downloads and loads the FDDB data set."""
import fastface as ff

train_dataset = ff.dataset.FDDBDataset(source_dir="data/", phase="train")
val_dataset = ff.dataset.FDDBDataset(source_dir="data/", phase="val")

return FaceDetectionData.from_datasets(
train_dataset=train_dataset,
val_dataset=val_dataset,
batch_size=batch_size,
**data_module_kwargs,
)


def face_detection():
"""Detect faces in images."""
cli = FlashCLI(
FaceDetector,
FaceDetectionData,
default_datamodule_builder=from_fddb,
default_arguments={
"trainer.max_epochs": 3,
},
)

cli.trainer.save_checkpoint("face_detection_model.pt")


if __name__ == "__main__":
face_detection()
2 changes: 1 addition & 1 deletion flash/image/face_detection/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ def __init__(
learning_rate=learning_rate,
optimizer=optimizer,
lr_scheduler=lr_scheduler,
output_transform=FaceDetectionOutputTransform,
output_transform=FaceDetectionOutputTransform(),
)

@staticmethod
Expand Down
3 changes: 1 addition & 2 deletions flash/image/face_detection/output_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,5 @@ def per_batch_transform(batch: Any) -> Any:
# preds: list of torch.Tensor(N, 5) as x1, y1, x2, y2, score
preds = [preds[preds[:, 5] == batch_idx, :5] for batch_idx in range(len(preds))]
preds = ff.utils.preprocess.adjust_results(preds, scales, paddings)
batch[DataKeys.PREDS] = preds

return batch
return preds
3 changes: 3 additions & 0 deletions requirements/datatype_image_extras.txt
Original file line number Diff line number Diff line change
Expand Up @@ -11,3 +11,6 @@ structlog==21.1.0 # remove when baal resolved its dependency.
baal
fastface
fairscale

# pin PL for testing, remove when fastface is updated
pytorch-lightning<1.5.0
23 changes: 13 additions & 10 deletions tests/image/face_detection/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,12 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from unittest import mock

import pytest
import torch

import flash
from flash.__main__ import main
from flash.core.registry import FlashRegistry
from flash.core.utilities.imports import _FASTFACE_AVAILABLE
from flash.image import FaceDetectionData, FaceDetector
Expand Down Expand Up @@ -42,15 +44,6 @@ def test_fastface_training():
trainer.predict(model, datamodule=datamodule)


@pytest.mark.skipif(not _FASTFACE_AVAILABLE, reason="fastface not installed.")
def test_fastface_forward():
model = FaceDetector(model="lffd_slim")
mock_batch = torch.randn(2, 3, 256, 256)

# test model forward (tests: _prepare_batch, logits_to_preds, _output_transform from ff)
model(mock_batch)


@pytest.mark.skipif(not _FASTFACE_AVAILABLE, reason="fastface not installed.")
def test_fastface_backbones_registry():
backbones = FACE_DETECTION_BACKBONES.available_keys()
Expand All @@ -59,3 +52,13 @@ def test_fastface_backbones_registry():

backbone, _ = FACE_DETECTION_BACKBONES.get("lffd_original")(pretrained=False)
assert isinstance(backbone, LFFD)


@pytest.mark.skipif(not _FASTFACE_AVAILABLE, reason="fastface not installed.")
def test_cli():
cli_args = ["flash", "face_detection", "--trainer.fast_dev_run", "True"]
with mock.patch("sys.argv", cli_args):
try:
main()
except SystemExit:
pass

0 comments on commit 4c62482

Please sign in to comment.