Skip to content
This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

[Feat] Add PointCloud ObjectDetection #600

Merged
merged 28 commits into from
Jul 16, 2021
Merged
Show file tree
Hide file tree
Changes from 13 commits
Commits
Show all changes
28 commits
Select commit Hold shift + click to select a range
2b8c5c9
wip
tchaton Jul 15, 2021
dc975fe
wip
tchaton Jul 16, 2021
bd98e1b
wip
tchaton Jul 16, 2021
cde483e
add tests
tchaton Jul 16, 2021
abdfb56
add docs
tchaton Jul 16, 2021
bb13ec7
update changelog
tchaton Jul 16, 2021
203ab2d
update
tchaton Jul 16, 2021
2b183fd
update
tchaton Jul 16, 2021
72dd0de
update
tchaton Jul 16, 2021
b6d5aa1
update
tchaton Jul 16, 2021
a4b7d09
update
tchaton Jul 16, 2021
c2b1913
update
tchaton Jul 16, 2021
196989a
update
tchaton Jul 16, 2021
4959772
update
tchaton Jul 16, 2021
c06ff5b
update
tchaton Jul 16, 2021
8326f8c
update
tchaton Jul 16, 2021
f4d1729
Merge branch 'master' into pointcloud_obj
mergify[bot] Jul 16, 2021
6c6f678
update
tchaton Jul 16, 2021
67896fd
Merge branch 'pointcloud_obj' of https://github.com/PyTorchLightning/…
tchaton Jul 16, 2021
b3d4f2d
Update tests/pointcloud/detection/test_data.py
ethanwharris Jul 16, 2021
30cc226
Apply suggestions from code review
ethanwharris Jul 16, 2021
ac7d52b
Update tests/pointcloud/detection/test_data.py
ethanwharris Jul 16, 2021
aa4b9c1
Update tests/pointcloud/detection/test_data.py
ethanwharris Jul 16, 2021
ac88b88
Update tests/pointcloud/detection/test_data.py
ethanwharris Jul 16, 2021
8b1ad92
Update tests/pointcloud/detection/test_data.py
ethanwharris Jul 16, 2021
4e94b3e
resolve test
tchaton Jul 16, 2021
f3d0272
Merge branch 'pointcloud_obj' of https://github.com/PyTorchLightning/…
tchaton Jul 16, 2021
b3813a4
Update tests/pointcloud/detection/test_data.py
ethanwharris Jul 16, 2021
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

- Added `PointCloudSegmentation` Task ([#566](https://github.com/PyTorchLightning/lightning-flash/pull/566))

- Added `PointCloudObjectDetection` Task ([#600](https://github.com/PyTorchLightning/lightning-flash/pull/600))

- Added a `GraphClassifier` task ([#73](https://github.com/PyTorchLightning/lightning-flash/pull/73))

- Added the option to pass `pretrained` as a string to `SemanticSegmentation` to change pretrained weights to load from `segmentation-models.pytorch` ([#587](https://github.com/PyTorchLightning/lightning-flash/pull/587))
Expand Down
16 changes: 16 additions & 0 deletions docs/source/api/pointcloud.rst
Original file line number Diff line number Diff line change
Expand Up @@ -23,3 +23,19 @@ ____________
segmentation.data.PointCloudSegmentationPreprocess
segmentation.data.PointCloudSegmentationFoldersDataSource
segmentation.data.PointCloudSegmentationDatasetDataSource


Object Detection
________________

.. autosummary::
:toctree: generated/
:nosignatures:
:template: classtemplate.rst

~detection.model.PointCloudObjectDetector
~detection.data.PointCloudObjectDetectorData

detection.data.PointCloudObjectDetectorPreprocess
detection.data.PointCloudObjectDetectorFoldersDataSource
detection.data.PointCloudObjectDetectorDatasetDataSource
1 change: 1 addition & 0 deletions docs/source/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ Lightning Flash
:caption: Point Cloud

reference/pointcloud_segmentation
reference/pointcloud_object_detection

.. toctree::
:maxdepth: 1
Expand Down
82 changes: 82 additions & 0 deletions docs/source/reference/pointcloud_object_detection.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@

.. _pointcloud_object_detection:

############################
Point Cloud Object Detection
############################

********
The Task
********

A Point Cloud is a set of data points in space, usually describes by ``x``, ``y`` and ``z`` coordinates.

PointCloud Object Detection is the task of identifying 3D objects in point clouds and their associated classes and 3D bounding boxes.

The current integration builds on top `Open3D-ML <https://github.com/intel-isl/Open3D-ML>`_.

------

*******
Example
*******

Let's look at an example using a data set generated from the `KITTI Vision Benchmark <http://www.semantic-kitti.org/dataset.html>`_.
The data are a tiny subset of the original dataset and contains sequences of point clouds.

The data contains:
* one folder for scans
* one folder for scan calibrations
* one folder for labels
* a meta.yaml file describing the classes and their official associated color map.

Here's the structure:

.. code-block::

data
├── meta.yaml
├── train
│ ├── scans
| | ├── 00000.bin
| | ├── 00001.bin
| | ...
│ ├── calibs
| | ├── 00000.txt
| | ├── 00001.txt
| | ...
│ ├── labels
| | ├── 00000.txt
| | ├── 00001.txt
│ ...
├── val
│ ...
├── predict
├── scans
| ├── 00000.bin
| ├── 00001.bin
|
├── calibs
| ├── 00000.txt
| ├── 00001.txt
├── meta.yaml



Learn more: http://www.semantic-kitti.org/dataset.html


Once we've downloaded the data using :func:`~flash.core.data.download_data`, we create the :class:`~flash.image.detection.data.PointCloudObjectDetectorData`.
We select a pre-trained ``randlanet_semantic_kitti`` backbone for our :class:`~flash.image.detection.model.PointCloudObjectDetector` task.
We then use the trained :class:`~flash.image.detection.model.PointCloudObjectDetector` for inference.
Finally, we save the model.
Here's the full example:

.. literalinclude:: ../../../flash_examples/pointcloud_detection.py
:language: python
:lines: 14-



.. image:: https://raw.githubusercontent.com/intel-isl/Open3D-ML/master/docs/images/visualizer_BoundingBoxes.png
:width: 100%
8 changes: 8 additions & 0 deletions flash/core/data/data_source.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,6 +176,14 @@ def __hash__(self) -> int:
return hash(self.value)


class BaseDataFormat(LightningEnum):

pass
tchaton marked this conversation as resolved.
Show resolved Hide resolved

def __hash__(self) -> int:
return hash(self.value)


class MockDataset:
"""The ``MockDataset`` catches any metadata that is attached through ``__setattr__``. This is passed to
:meth:`~flash.core.data.data_source.DataSource.load_data` so that attributes can be set on the generated
Expand Down
18 changes: 18 additions & 0 deletions flash/core/data/states.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,24 @@
from flash.core.data.properties import ProcessState


@dataclass(unsafe_hash=True, frozen=True)
class PreTensorTransform(ProcessState):

transform: Optional[Callable] = None


@dataclass(unsafe_hash=True, frozen=True)
class ToTensorTransform(ProcessState):

transform: Optional[Callable] = None


@dataclass(unsafe_hash=True, frozen=True)
class PostTensorTransform(ProcessState):

transform: Optional[Callable] = None


@dataclass(unsafe_hash=True, frozen=True)
class CollateFn(ProcessState):

Expand Down
17 changes: 14 additions & 3 deletions flash/core/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,21 +188,32 @@ def step(self, batch: Any, batch_idx: int, metrics: nn.ModuleDict) -> Any:
losses = {name: l_fn(y_hat, y) for name, l_fn in self.loss_fn.items()}
logs = {}
y_hat = self.to_metrics_format(output["y_hat"])

logs = {}

for name, metric in metrics.items():
if isinstance(metric, torchmetrics.metric.Metric):
metric(y_hat, y)
logs[name] = metric # log the metric itself if it is of type Metric
else:
logs[name] = metric(y_hat, y)
logs.update(losses)

if len(losses.values()) > 1:
logs["total_loss"] = sum(losses.values())
return logs["total_loss"], logs
output["loss"] = list(losses.values())[0]
output["logs"] = logs

output["loss"] = self.compute_loss(losses)
output["logs"] = self.compute_logs(logs, losses)
output["y"] = y
return output

def compute_loss(self, losses: Dict[str, torch.Tensor]) -> torch.Tensor:
return list(losses.values())[0]

def compute_logs(self, logs: Dict[str, Any], losses: Dict[str, torch.Tensor]):
logs.update(losses)
return logs

@staticmethod
def apply_filtering(y: torch.Tensor, y_hat: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
"""This function is used to filter some labels or predictions which aren't conform."""
Expand Down
3 changes: 2 additions & 1 deletion flash/pointcloud/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from flash.pointcloud.detection.data import PointCloudObjectDetectorData # noqa: F401
from flash.pointcloud.detection.model import PointCloudObjectDetector # noqa: F401
from flash.pointcloud.segmentation.data import PointCloudSegmentationData # noqa: F401
from flash.pointcloud.segmentation.model import PointCloudSegmentation # noqa: F401
from flash.pointcloud.segmentation.open3d_ml.app import launch_app # noqa: F401
3 changes: 3 additions & 0 deletions flash/pointcloud/detection/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
from flash.pointcloud.detection.data import PointCloudObjectDetectorData # noqa: F401
from flash.pointcloud.detection.model import PointCloudObjectDetector # noqa: F401
from flash.pointcloud.detection.open3d_ml.app import launch_app # noqa: F401
19 changes: 19 additions & 0 deletions flash/pointcloud/detection/backbones.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from flash.core.registry import FlashRegistry
from flash.pointcloud.detection.open3d_ml.backbones import register_open_3d_ml

POINTCLOUD_OBJECT_DETECTION_BACKBONES = FlashRegistry("backbones")

register_open_3d_ml(POINTCLOUD_OBJECT_DETECTION_BACKBONES)
Loading