Skip to content
This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

Add semantic segmentation task #239

Merged
merged 58 commits into from
May 10, 2021
Merged
Show file tree
Hide file tree
Changes from 51 commits
Commits
Show all changes
58 commits
Select commit Hold shift + click to select a range
e2f5f20
semantic segmentation skeleton
edgarriba Apr 22, 2021
f3ce4c7
expose and add smoke tests for preproces and datamodule
edgarriba Apr 23, 2021
1ef1b40
data module connections working
edgarriba Apr 23, 2021
7f17fb2
preprocess not crashing(wip)
edgarriba Apr 23, 2021
7d9d46c
implement segmentation sequential
edgarriba Apr 26, 2021
498e278
implement torchvision backbone model
edgarriba Apr 26, 2021
56fa4d5
model working
edgarriba Apr 26, 2021
950252e
implement labels mapping
edgarriba Apr 26, 2021
6a75245
add map labels tests
edgarriba Apr 26, 2021
7a7f855
from filepaths training test not crashing
edgarriba Apr 26, 2021
def1ea0
non working visualiser
edgarriba Apr 27, 2021
ed17eb0
fix visualiser
edgarriba Apr 27, 2021
3eb6417
training working
edgarriba Apr 27, 2021
d529d9e
training not crashing
edgarriba Apr 27, 2021
13095e6
cleanup example and move serializer to core
edgarriba Apr 28, 2021
2f9ede5
cleanup model code, tests and docs
edgarriba Apr 28, 2021
dc9b2b8
move transforms apart
edgarriba Apr 28, 2021
e767a53
implement ApplytransformsToKey augmentations
edgarriba Apr 28, 2021
f268b62
relative path
edgarriba Apr 28, 2021
99b99f0
fix load from pretrained and add resnet 101
edgarriba Apr 28, 2021
d1a91fd
create segmentation keys enum
edgarriba Apr 28, 2021
7343887
sync with master and fix val_split
edgarriba Apr 28, 2021
febe7f0
move apart segmentation backbones
edgarriba Apr 29, 2021
3891920
Merge branch 'master' into feat/segmentation
edgarriba Apr 29, 2021
248145b
fix tests
edgarriba Apr 29, 2021
6d635db
fix tests
edgarriba Apr 29, 2021
ca97034
fix tests
edgarriba Apr 29, 2021
da83251
fix memory leak issues
edgarriba Apr 29, 2021
f1e76f9
Merge branch 'master' into feat/segmentation
edgarriba Apr 29, 2021
2ef8c88
undo function filtering
edgarriba Apr 29, 2021
87a92f3
fix import
edgarriba Apr 29, 2021
73d462b
more fixes for memory leaks
edgarriba Apr 29, 2021
8b971a4
add segmentation to docs
edgarriba Apr 30, 2021
69358a6
add inference example
edgarriba Apr 30, 2021
caabfb6
add image to docs and update with AdamW
edgarriba Apr 30, 2021
78301ff
Merge branch 'master' into feat/segmentation
ethanwharris May 7, 2021
e8e92d1
Make pretrained arg kwarg
ethanwharris May 7, 2021
cf430f3
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] May 7, 2021
74ce6dc
Data sources initial commit
ethanwharris May 7, 2021
df2b989
Update transforms
ethanwharris May 7, 2021
bb95b8f
Updates
ethanwharris May 7, 2021
3a8842d
Fixes
ethanwharris May 8, 2021
3596e16
Fix tests
ethanwharris May 8, 2021
859a0ef
Fixes
ethanwharris May 8, 2021
268b4c8
Fixes
ethanwharris May 8, 2021
0b67769
Merge branch 'master' into feat/segmentation
ethanwharris May 8, 2021
091e50d
Merge branch 'master' into feat/segmentation
ethanwharris May 10, 2021
4aa3716
Add tests
ethanwharris May 10, 2021
f50c200
Update docs/source/reference/semantic_segmentation.rst
ethanwharris May 10, 2021
11ed7c5
Update docs/source/reference/semantic_segmentation.rst
ethanwharris May 10, 2021
0fc3581
Add a check
ethanwharris May 10, 2021
1875967
Move KorniaParallelTransforms and add docstring
ethanwharris May 10, 2021
e9dee30
implement quick test for segmentation labels
edgarriba May 10, 2021
0049197
add small assertion tests
edgarriba May 10, 2021
4c75774
Rename test_serialisation.py to test_serialization.py
ethanwharris May 10, 2021
2d76f37
Switch to exception
ethanwharris May 10, 2021
5f254b6
Fix
ethanwharris May 10, 2021
8745191
Fixes
ethanwharris May 10, 2021
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -153,3 +153,5 @@ wmt_en_ro
action_youtube_naudio
kinetics
movie_posters
CameraRGB
CameraSeg
2 changes: 1 addition & 1 deletion docs/source/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ Lightning Flash
reference/translation
reference/object_detection
reference/video_classification

reference/semantic_segmentation

.. toctree::
:maxdepth: 1
Expand Down
151 changes: 151 additions & 0 deletions docs/source/reference/semantic_segmentation.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,151 @@

.. _semantinc_segmentation:

######################
Semantinc Segmentation
######################

********
The task
********
Semantic segmentation, or image segmentation, is the task of performing classification at a pixel-level, meaning each pixel will associated to a given class. The model output shape is ``(batch_size, num_classes, heigh, width)``.

See more: https://paperswithcode.com/task/semantic-segmentation

.. raw:: html

<p>
<a href="https://i2.wp.com/syncedreview.com/wp-content/uploads/2019/12/image-9-1.png" >
<img src="https://i2.wp.com/syncedreview.com/wp-content/uploads/2019/12/image-9-1.png"/>
</a>
</p>

------

*********
Inference
*********

A :class:`~flash.vision.SemanticSegmentation` `fcn_resnet50` pre-trained on `CARLA <http://carla.org/>`_ simulator is provided for the inference example.


Use the :class:`~flash.vision.SemanticSegmentation` pretrained model for inference on any string sequence using :func:`~flash.vision.SemanticSegmentation.predict`:

.. code-block:: python

# import our libraries
from flash.data.utils import download_data
from flash.vision import SemanticSegmentation
from flash.vision.segmentation.serialization import SegmentationLabels

# 1. Download the data
download_data(
"https://github.com/ongchinkiat/LyftPerceptionChallenge/releases/download/v0.1/carla-capture-20180513A.zip",
"data/"
)

# 2. Load the model from a checkpoint
model = SemanticSegmentation.load_from_checkpoint(
"https://flash-weights.s3.amazonaws.com/semantic_segmentation_model.pt"
)
model.serializer = SegmentationLabels(visualize=True)

# 3. Predict what's on a few images and visualize!
predictions = model.predict([
'data/CameraRGB/F61-1.png',
'data/CameraRGB/F62-1.png',
'data/CameraRGB/F63-1.png',
])

For more advanced inference options, see :ref:`predictions`.

------

**********
Finetuning
**********

you now want to customise your model with new data using the same dataset.
Once we download the data using :func:`~flash.data.download_data`, all we need is the train data and validation data folders to create the :class:`~flash.vision.SemanticSegmentationData`.

.. note:: the dataset is structured in a way that each sample (an image and its corresponding labels) is stored in separated directories but keeping the same filename.

.. code-block::

data
├── CameraRGB
│ ├── F61-1.png
│ ├── F61-2.png
│ ...
└── CameraSeg
├── F61-1.png
├── F61-2.png
...


Now all we need is three lines of code to build to train our task!

.. code-block:: python

import flash
from flash.data.utils import download_data
from flash.vision import SemanticSegmentation, SemanticSegmentationData
from flash.vision.segmentation.serialization import SegmentationLabels

# 1. Download the data
download_data(
"https://github.com/ongchinkiat/LyftPerceptionChallenge/releases/download/v0.1/carla-capture-20180513A.zip",
"data/"
)

# 2.1 Load the data
datamodule = SemanticSegmentationData.from_folders(
train_folder="data/CameraRGB",
train_target_folder="data/CameraSeg",
batch_size=4,
val_split=0.3,
image_size=(200, 200), # (600, 800)
)

# 2.2 Visualise the samples
labels_map = SegmentationLabels.create_random_labels_map(num_classes=21)
datamodule.set_labels_map(labels_map)
datamodule.show_train_batch(["load_sample", "post_tensor_transform"])

# 3. Build the model
model = SemanticSegmentation(backbone="torchvision/fcn_resnet50", num_classes=21)

# 4. Create the trainer.
trainer = flash.Trainer(max_epochs=1)

# 5. Train the model
trainer.finetune(model, datamodule=datamodule, strategy='freeze')

# 7. Save it!
trainer.save_checkpoint("semantic_segmentation_model.pt")

------

*************
API reference
*************

.. _segmentation:

SemanticSegmentation
--------------------

.. autoclass:: flash.vision.SemanticSegmentation
:members:
:exclude-members: forward

.. _segmentation_data:

SemanticSegmentationData
------------------------

.. autoclass:: flash.vision.SemanticSegmentationData

.. automethod:: flash.vision.SemanticSegmentationData.from_folders

.. autoclass:: flash.vision.SemanticSegmentationPreprocess
3 changes: 2 additions & 1 deletion flash/core/classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,8 @@ def __init__(
def to_metrics_format(self, x: torch.Tensor) -> torch.Tensor:
if getattr(self.hparams, "multi_label", False):
return torch.sigmoid(x)
return torch.softmax(x, -1)
# we'll assume that the data always comes as `(B, C, ...)`
return torch.softmax(x, dim=1)


class ClassificationSerializer(Serializer):
Expand Down
6 changes: 6 additions & 0 deletions flash/data/batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,12 @@ def __init__(
self._per_batch_transform_context = CurrentFuncContext(f"per_batch_transform{extension}", preprocess)

def forward(self, samples: Sequence[Any]) -> Any:
# we create a new dict to prevent from potential memory leaks
# assuming that the dictionary samples are stored in between and
# potentially modified before the transforms are applied.
if isinstance(samples, dict):
samples = dict(samples.items())

with self._current_stage_context:

if self.apply_per_sample_transform:
Expand Down
2 changes: 2 additions & 0 deletions flash/data/data_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,6 +190,8 @@ def _show_batch(self, stage: str, func_names: Union[str, List[str]], reset: bool
_ = next(iter_dataloader)
data_fetcher: BaseVisualization = self.data_fetcher
data_fetcher._show(stage, func_names)
if reset:
self.data_fetcher.batches[stage] = {}

def show_train_batch(self, hooks_names: Union[str, List[str]] = 'load_sample', reset: bool = True) -> None:
"""This function is used to visualize a batch from the train dataloader."""
Expand Down
11 changes: 7 additions & 4 deletions flash/data/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,15 +27,18 @@ def __init__(self, keys: Union[str, Sequence[str]], *args):
self.keys = keys

def forward(self, x: Mapping[str, Any]) -> Mapping[str, Any]:
inputs = [x[key] for key in filter(lambda key: key in x, self.keys)]
keys = list(filter(lambda key: key in x, self.keys))
inputs = [x[key] for key in keys]
if len(inputs) > 0:
outputs = super().forward(*inputs)
if not isinstance(outputs, tuple):
if len(inputs) == 1:
inputs = inputs[0]
outputs = super().forward(inputs)
if not isinstance(outputs, Sequence):
outputs = (outputs, )

result = {}
result.update(x)
for i, key in enumerate(self.keys):
for i, key in enumerate(keys):
result[key] = outputs[i]
return result
return x
1 change: 1 addition & 0 deletions flash/vision/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,3 +2,4 @@
from flash.vision.classification import ImageClassificationData, ImageClassificationPreprocess, ImageClassifier
from flash.vision.detection import ObjectDetectionData, ObjectDetector
from flash.vision.embedding import ImageEmbedder
from flash.vision.segmentation import SemanticSegmentation, SemanticSegmentationData, SemanticSegmentationPreprocess
2 changes: 1 addition & 1 deletion flash/vision/classification/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ def collate(self, samples: Sequence[Dict[str, Any]]) -> Any:
for key in sample.keys():
if torch.is_tensor(sample[key]):
sample[key] = sample[key].squeeze(0)
return default_collate(samples)
return super().collate(samples)

@property
def default_train_transforms(self) -> Optional[Dict[str, Callable]]:
Expand Down
2 changes: 2 additions & 0 deletions flash/vision/segmentation/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
from flash.vision.segmentation.data import SemanticSegmentationData, SemanticSegmentationPreprocess
from flash.vision.segmentation.model import SemanticSegmentation
36 changes: 36 additions & 0 deletions flash/vision/segmentation/backbones.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch.nn as nn

from flash.core.registry import FlashRegistry
from flash.utils.imports import _TORCHVISION_AVAILABLE

if _TORCHVISION_AVAILABLE:
import torchvision

SEMANTIC_SEGMENTATION_BACKBONES = FlashRegistry("backbones")


@SEMANTIC_SEGMENTATION_BACKBONES(name="torchvision/fcn_resnet50")
def load_torchvision_fcn_resnet50(num_classes: int, pretrained: bool = True) -> nn.Module:
model = torchvision.models.segmentation.fcn_resnet50(pretrained=pretrained)
model.classifier[-1] = nn.Conv2d(512, num_classes, kernel_size=(1, 1), stride=(1, 1))
return model


@SEMANTIC_SEGMENTATION_BACKBONES(name="torchvision/fcn_resnet101")
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Any library should we integrate there ? Like IceVision ?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Torchvision should be good enough for now.
We already have heavy dependencies.

def load_torchvision_fcn_resnet101(num_classes: int, pretrained: bool = True) -> nn.Module:
model = torchvision.models.segmentation.fcn_resnet101(pretrained=pretrained)
model.classifier[-1] = nn.Conv2d(512, num_classes, kernel_size=(1, 1), stride=(1, 1))
return model
Loading