This repository has been archived by the owner on Oct 9, 2023. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 212
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add instance segmentation and keypoint detection to flash zero (#672)
* Add instance segmentation and keypoint detection to flash zero * Add instance segmentation and keypoint detection to flash zero * Add docs * Uodate CHANGELOG.md * Fixes
- Loading branch information
1 parent
741a838
commit 67b227f
Showing
14 changed files
with
236 additions
and
5 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,66 @@ | ||
# Copyright The PyTorch Lightning team. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
from functools import partial | ||
from typing import Callable, Optional | ||
|
||
from flash.core.utilities.flash_cli import FlashCLI | ||
from flash.core.utilities.imports import _ICEDATA_AVAILABLE, requires_extras | ||
from flash.image import InstanceSegmentation, InstanceSegmentationData | ||
|
||
if _ICEDATA_AVAILABLE: | ||
import icedata | ||
|
||
__all__ = ["instance_segmentation"] | ||
|
||
|
||
@requires_extras("image") | ||
def from_pets( | ||
val_split: float = 0.1, | ||
batch_size: int = 4, | ||
num_workers: Optional[int] = None, | ||
parser: Optional[Callable] = None, | ||
**preprocess_kwargs, | ||
) -> InstanceSegmentationData: | ||
"""Downloads and loads the pets data set from icedata.""" | ||
data_dir = icedata.pets.load_data() | ||
|
||
if parser is None: | ||
parser = partial(icedata.pets.parser, mask=True) | ||
|
||
return InstanceSegmentationData.from_folders( | ||
train_folder=data_dir, | ||
val_split=val_split, | ||
batch_size=batch_size, | ||
num_workers=num_workers, | ||
parser=parser, | ||
**preprocess_kwargs, | ||
) | ||
|
||
|
||
def instance_segmentation(): | ||
"""Segment object instances in images.""" | ||
cli = FlashCLI( | ||
InstanceSegmentation, | ||
InstanceSegmentationData, | ||
default_datamodule_builder=from_pets, | ||
default_arguments={ | ||
"trainer.max_epochs": 3, | ||
}, | ||
) | ||
|
||
cli.trainer.save_checkpoint("instance_segmentation_model.pt") | ||
|
||
|
||
if __name__ == "__main__": | ||
instance_segmentation() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,66 @@ | ||
# Copyright The PyTorch Lightning team. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
from typing import Callable, Optional | ||
|
||
from flash.core.utilities.flash_cli import FlashCLI | ||
from flash.core.utilities.imports import _ICEDATA_AVAILABLE, requires_extras | ||
from flash.image import KeypointDetectionData, KeypointDetector | ||
|
||
if _ICEDATA_AVAILABLE: | ||
import icedata | ||
|
||
__all__ = ["keypoint_detection"] | ||
|
||
|
||
@requires_extras("image") | ||
def from_biwi( | ||
val_split: float = 0.1, | ||
batch_size: int = 4, | ||
num_workers: Optional[int] = None, | ||
parser: Optional[Callable] = None, | ||
**preprocess_kwargs, | ||
) -> KeypointDetectionData: | ||
"""Downloads and loads the BIWI data set from icedata.""" | ||
data_dir = icedata.biwi.load_data() | ||
|
||
if parser is None: | ||
parser = icedata.biwi.parser | ||
|
||
return KeypointDetectionData.from_folders( | ||
train_folder=data_dir, | ||
val_split=val_split, | ||
batch_size=batch_size, | ||
num_workers=num_workers, | ||
parser=parser, | ||
**preprocess_kwargs, | ||
) | ||
|
||
|
||
def keypoint_detection(): | ||
"""Detect keypoints in images.""" | ||
cli = FlashCLI( | ||
KeypointDetector, | ||
KeypointDetectionData, | ||
default_datamodule_builder=from_biwi, | ||
default_arguments={ | ||
"model.num_keypoints": 1, | ||
"trainer.max_epochs": 3, | ||
}, | ||
) | ||
|
||
cli.trainer.save_checkpoint("keypoint_detection_model.pt") | ||
|
||
|
||
if __name__ == "__main__": | ||
keypoint_detection() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,29 @@ | ||
# Copyright The PyTorch Lightning team. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
from unittest import mock | ||
|
||
import pytest | ||
|
||
from flash.__main__ import main | ||
from tests.helpers.utils import _IMAGE_TESTING | ||
|
||
|
||
@pytest.mark.skipif(not _IMAGE_TESTING, reason="image libraries aren't installed.") | ||
def test_cli(): | ||
cli_args = ["flash", "instance_segmentation", "--trainer.fast_dev_run", "True"] | ||
with mock.patch("sys.argv", cli_args): | ||
try: | ||
main() | ||
except SystemExit: | ||
pass |
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,29 @@ | ||
# Copyright The PyTorch Lightning team. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
from unittest import mock | ||
|
||
import pytest | ||
|
||
from flash.__main__ import main | ||
from tests.helpers.utils import _IMAGE_TESTING | ||
|
||
|
||
@pytest.mark.skipif(not _IMAGE_TESTING, reason="image libraries aren't installed.") | ||
def test_cli(): | ||
cli_args = ["flash", "keypoint_detection", "--trainer.fast_dev_run", "True"] | ||
with mock.patch("sys.argv", cli_args): | ||
try: | ||
main() | ||
except SystemExit: | ||
pass |