Skip to content

Commit

Permalink
[AIR] Add TorchDetectionPredictor (#32199)
Browse files Browse the repository at this point in the history
TorchPredictor doesn't work with TorchVision detection models because they return List[Dict[str, torch.Tensor]] instead of torch.Tensor. This PR adds a TorchDetectionPredictor so users don't have to extend TorchPredictor themselves.

Signed-off-by: Balaji Veeramani <balaji@anyscale.com>
  • Loading branch information
bveeramani authored Feb 8, 2023
1 parent 468e606 commit 53260af
Show file tree
Hide file tree
Showing 6 changed files with 296 additions and 12 deletions.
14 changes: 12 additions & 2 deletions doc/source/train/api.rst
Original file line number Diff line number Diff line change
Expand Up @@ -89,11 +89,21 @@ PyTorch
``TorchPredictor``
******************

.. automodule:: ray.train.torch
.. autoclass:: ray.train.torch.TorchPredictor
:members:
:exclude-members: TorchTrainer
:show-inheritance:

.. automethod:: __init__

``TorchDetectionPredictor``
***************************

.. autoclass:: ray.train.torch.TorchDetectionPredictor
    :members:
    :show-inheritance:

    .. automethod:: __init__

Horovod
~~~~~~~

Expand Down
8 changes: 8 additions & 0 deletions python/ray/train/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -496,6 +496,14 @@ py_test(
deps = [":train_lib", ":conftest"]
)

# Runs tests/test_torch_detection_predictor.py against the train library.
py_test(
name = "test_torch_detection_predictor",
size = "small",
srcs = ["tests/test_torch_detection_predictor.py"],
tags = ["team:ml", "exclusive", "ray_air", "gpu"],
deps = [":train_lib", ":conftest"]
)

py_test(
name = "test_torch_trainer",
size = "medium",
Expand Down
118 changes: 118 additions & 0 deletions python/ray/train/tests/test_torch_detection_predictor.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,118 @@
import numpy as np
import pytest
from torchvision import models

import ray
from ray.air.util.tensor_extensions.utils import create_ragged_ndarray
from ray.train.batch_predictor import BatchPredictor
from ray.train.torch import TorchCheckpoint, TorchDetectionPredictor


@pytest.fixture(name="predictor")
def predictor_fixture():
    """Provide a `TorchDetectionPredictor` wrapping a Mask R-CNN model.

    The model is constructed with torchvision's default arguments.
    """
    detection_model = models.detection.maskrcnn_resnet50_fpn()
    detector = TorchDetectionPredictor(model=detection_model)
    yield detector


@pytest.mark.parametrize(
    "data",
    [
        np.zeros((1, 3, 32, 32), dtype=np.float32),
        {"image": np.zeros((1, 3, 32, 32), dtype=np.float32)},
        create_ragged_ndarray(
            [
                np.zeros((3, 32, 32), dtype=np.float32),
                np.zeros((3, 64, 64), dtype=np.float32),
            ]
        ),
    ],
)
def test_predict(predictor, data):
    """Check that `predict` returns one entry per image with valid shapes.

    The parametrization covers a plain array batch, a single-column dict
    batch, and a ragged array holding images of different sizes.
    """
    predictions = predictor.predict(data)

    # For dict input, `len(data)` is the number of columns rather than the
    # number of images, so measure the batch size on the image column itself.
    images = next(iter(data.values())) if isinstance(data, dict) else data
    batch_size = len(images)

    assert all(len(value) == batch_size for value in predictions.values())
    # Boxes should have shape `(# detections, 4)`.
    assert all(boxes.ndim == 2 for boxes in predictions["pred_boxes"])
    assert all(boxes.shape[-1] == 4 for boxes in predictions["pred_boxes"])
    # Labels should have shape `(# detections,)`.
    assert all(labels.ndim == 1 for labels in predictions["pred_labels"])
    # Scores should have shape `(# detections,)`.
    assert all(scores.ndim == 1 for scores in predictions["pred_scores"])


def test_predict_tensor_dataset():
    """Run batch prediction on a dataset built from one bare image array.

    Verifies that the per-row boxes, labels and scores have the expected
    dimensionality.
    """
    model = models.detection.maskrcnn_resnet50_fpn()
    checkpoint = TorchCheckpoint.from_model(model)
    predictor = BatchPredictor.from_checkpoint(checkpoint, TorchDetectionPredictor)
    dataset = ray.data.from_items([np.zeros((3, 32, 32), dtype=np.float32)])

    predictions = predictor.predict(dataset)

    # Fetch the rows once instead of once per checked column.
    rows = predictions.take_all()

    # Boxes should have shape `(# detections, 4)`.
    pred_boxes = [row["pred_boxes"] for row in rows]
    assert all(boxes.ndim == 2 for boxes in pred_boxes)
    assert all(boxes.shape[-1] == 4 for boxes in pred_boxes)
    # Labels should have shape `(# detections,)`.
    pred_labels = [row["pred_labels"] for row in rows]
    assert all(labels.ndim == 1 for labels in pred_labels)
    # Scores should have shape `(# detections,)`.
    pred_scores = [row["pred_scores"] for row in rows]
    assert all(scores.ndim == 1 for scores in pred_scores)


@pytest.mark.parametrize(
    "items",
    [
        [{"image": np.zeros((3, 32, 32), dtype=np.float32)}],
        [
            {"image": np.zeros((3, 32, 32), dtype=np.float32)},
            {"image": np.zeros((3, 64, 64), dtype=np.float32)},
        ],
    ],
)
def test_predict_tabular_dataset(items):
    """Run batch prediction on a dataset of `{"image": ...}` records.

    Verifies that one prediction row is produced per input item and that the
    per-row boxes, labels and scores have the expected dimensionality.
    """
    model = models.detection.maskrcnn_resnet50_fpn()
    checkpoint = TorchCheckpoint.from_model(model)
    predictor = BatchPredictor.from_checkpoint(checkpoint, TorchDetectionPredictor)
    dataset = ray.data.from_items(items)

    predictions = predictor.predict(dataset)

    assert predictions.count() == len(items)

    # Fetch the rows once instead of once per checked column.
    rows = predictions.take_all()

    # Boxes should have shape `(# detections, 4)`.
    pred_boxes = [row["pred_boxes"] for row in rows]
    assert all(boxes.ndim == 2 for boxes in pred_boxes)
    assert all(boxes.shape[-1] == 4 for boxes in pred_boxes)
    # Labels should have shape `(# detections,)`.
    pred_labels = [row["pred_labels"] for row in rows]
    assert all(labels.ndim == 1 for labels in pred_labels)
    # Scores should have shape `(# detections,)`.
    pred_scores = [row["pred_scores"] for row in rows]
    assert all(scores.ndim == 1 for scores in pred_scores)


def test_multi_column_batch_raises_value_error(predictor):
    """A batch with more than one column must be rejected with `ValueError`."""
    images = np.zeros((2, 3, 32, 32), dtype=np.float32)
    boxes = np.zeros((2, 0, 4), dtype=np.float32)
    labels = np.zeros((2, 0), dtype=np.int64)
    multi_column_batch = dict(image=images, boxes=boxes, labels=labels)

    # The batch should only contain one key. Otherwise,
    # `TorchDetectionPredictor` doesn't know which column contains the input
    # images.
    with pytest.raises(ValueError):
        predictor.predict(multi_column_batch)


def test_invalid_dtype_raises_value_error(predictor):
    """Passing a non-`torch.dtype` value as `dtype` must raise `ValueError`."""
    batch = np.zeros(shape=(1, 3, 32, 32), dtype=np.float32)
    # `np.float32` is a NumPy type, whereas `dtype` should be a single
    # `torch.dtype`.
    with pytest.raises(ValueError):
        predictor.predict(batch, dtype=np.float32)


if __name__ == "__main__":
    import sys

    # `pytest` is already imported at module level, so only `sys` is needed here.
    # Run this file's tests verbosely and stop at the first failure.
    sys.exit(pytest.main(["-v", "-x", __file__]))
4 changes: 3 additions & 1 deletion python/ray/train/torch/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,9 @@
)
# isort: on

from ray.train.torch.torch_checkpoint import TorchCheckpoint
from ray.train.torch.config import TorchConfig
from ray.train.torch.torch_checkpoint import TorchCheckpoint
from ray.train.torch.torch_detection_predictor import TorchDetectionPredictor
from ray.train.torch.torch_predictor import TorchPredictor
from ray.train.torch.torch_trainer import TorchTrainer
from ray.train.torch.train_loop_utils import (
Expand All @@ -33,4 +34,5 @@
"backward",
"enable_reproducibility",
"TorchPredictor",
"TorchDetectionPredictor",
]
140 changes: 140 additions & 0 deletions python/ray/train/torch/torch_detection_predictor.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,140 @@
import collections
from typing import Dict, List, Optional, Union

import numpy as np
import torch

from ray.air.util.tensor_extensions.utils import create_ragged_ndarray
from ray.train._internal.dl_predictor import TensorDtype
from ray.train.torch.torch_predictor import TorchPredictor
from ray.util.annotations import PublicAPI


@PublicAPI(stability="alpha")
class TorchDetectionPredictor(TorchPredictor):
    """A predictor for TorchVision detection models.

    Unlike other Torch models, instance segmentation models return
    `List[Dict[str, Tensor]]`. This predictor extends :class:`TorchPredictor` to support
    the non-standard outputs.

    To learn more about instance segmentation models, read
    `Instance segmentation models <https://pytorch.org/vision/main/auto_examples/plot_visualization_utils.html#instance-seg-output>`_.

    Example:

        .. testcode::

            import numpy as np
            from torchvision import models

            from ray.train.torch import TorchDetectionPredictor

            model = models.detection.fasterrcnn_resnet50_fpn_v2(pretrained=True)

            predictor = TorchDetectionPredictor(model=model)
            predictions = predictor.predict(np.zeros((4, 3, 32, 32), dtype=np.float32))

            print(predictions.keys())

        .. testoutput::

            dict_keys(['pred_boxes', 'pred_labels', 'pred_scores'])

        .. testcode::

            import numpy as np
            from torchvision import models

            import ray
            from ray.train.batch_predictor import BatchPredictor
            from ray.train.torch import TorchCheckpoint, TorchDetectionPredictor

            dataset = ray.data.from_items([{"image": np.zeros((3, 32, 32), dtype=np.float32)}])
            model = models.detection.fasterrcnn_resnet50_fpn_v2(pretrained=True)
            checkpoint = TorchCheckpoint.from_model(model)
            predictor = BatchPredictor.from_checkpoint(checkpoint, TorchDetectionPredictor)

            predictions = predictor.predict(dataset, feature_columns=["image"])

            print(predictions.take(1))

        .. testoutput::

            [{'pred_boxes': array([], shape=(0, 4), dtype=float32), 'pred_labels': array([], dtype=int64), 'pred_scores': array([], dtype=float32)}]
    """  # noqa: E501

    def _predict_numpy(
        self,
        data: Union[np.ndarray, Dict[str, np.ndarray]],
        dtype: Optional[Union[TensorDtype, Dict[str, TensorDtype]]],
    ) -> Dict[str, np.ndarray]:
        """Run the detection model on a batch of images.

        Args:
            data: Either an iterable batch of images, or a dict with exactly
                one key whose value is that batch.
            dtype: `None`, or a single `torch.dtype` used when converting
                each image to a tensor.

        Returns:
            The batched model outputs, with every output key prefixed by
            ``"pred_"``.

        Raises:
            ValueError: If `data` is a dict that doesn't contain exactly one
                key, or if `dtype` is neither `None` nor a `torch.dtype`.
        """
        data_is_dict = isinstance(data, dict)

        if data_is_dict and len(data) != 1:
            # The message lines are kept flush-left so the raised text carries
            # no extra leading whitespace.
            raise ValueError(
                f"""Expected input to contain one key, but got {len(data)} instead.
If you're using `BatchPredictor`, pass a one-element list to
`feature_columns`.
---
predictor = BatchPredictor(checkpoint, TorchDetectionPredictor)
predictor.predict(dataset, feature_columns=["image"])
---
"""
            )

        if not (dtype is None or isinstance(dtype, torch.dtype)):
            raise ValueError(
                "Expected `dtype` to be a `torch.dtype`, but got a "
                f"{type(dtype).__name__} instead."
            )

        # A dict batch has exactly one column at this point; it holds the images.
        images = next(iter(data.values())) if data_is_dict else data

        # The model receives a list with one tensor per image, each moved to
        # the predictor's device.
        inputs = []
        for image in images:
            tensor = torch.as_tensor(image, dtype=dtype)
            inputs.append(tensor.to(self.device))

        raw_outputs = self.call_model(inputs)
        batched = _convert_outputs_to_ndarray_batch(raw_outputs)
        return {"pred_" + name: array for name, array in batched.items()}


def _convert_outputs_to_ndarray_batch(
outputs: List[Dict[str, torch.Tensor]],
) -> Dict[str, np.ndarray]:
"""Batch detection model outputs.
TorchVision detection models return `List[Dict[Tensor]]`. Each `Dict` contain
'boxes', 'labels, and 'scores'.
>>> import torch
>>> from torchvision import models
>>> model = models.detection.fasterrcnn_resnet50_fpn_v2()
>>> model.eval() # doctest: +ELLIPSIS
FasterRCNN(...)
>>> outputs = model(torch.zeros((2, 3, 32, 32)))
>>> len(outputs)
2
>>> outputs[0].keys()
dict_keys(['boxes', 'labels', 'scores'])
This function batches values and returns a `Dict[str, np.ndarray]`.
>>> from ray.train.torch.torch_detection_predictor import _convert_outputs_to_ndarray_batch
>>> batch = _convert_outputs_to_ndarray_batch(outputs)
>>> batch.keys()
dict_keys(['boxes', 'labels', 'scores'])
>>> batch["boxes"].shape
(2,)
""" # noqa: E501
batch = collections.defaultdict(list)
for output in outputs:
for key, value in output.items():
batch[key].append(value.cpu().detach().numpy())
for key, value in batch.items():
batch[key] = create_ragged_ndarray(value)
return batch
24 changes: 15 additions & 9 deletions python/ray/train/torch/torch_predictor.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,12 @@
import numpy as np
import torch

from ray.util import log_once
from ray.train.predictor import DataBatchType
from ray.air.checkpoint import Checkpoint
from ray.air._internal.torch_utils import convert_ndarray_batch_to_torch_tensor_batch
from ray.train.torch.torch_checkpoint import TorchCheckpoint
from ray.air.checkpoint import Checkpoint
from ray.train._internal.dl_predictor import DLPredictor
from ray.train.predictor import DataBatchType
from ray.train.torch.torch_checkpoint import TorchCheckpoint
from ray.util import log_once
from ray.util.annotations import DeveloperAPI, PublicAPI

if TYPE_CHECKING:
Expand Down Expand Up @@ -38,12 +38,16 @@ def __init__(
):
self.model = model
self.model.eval()

# TODO (jiaodong): #26249 Use multiple GPU devices with sharded input
self.use_gpu = use_gpu

if use_gpu:
# Ensure input tensor and model live on GPU for GPU inference
self.model.to(torch.device("cuda"))
# TODO (jiaodong): #26249 Use multiple GPU devices with sharded input
self.device = torch.device("cuda")
else:
self.device = torch.device("cpu")

# Ensure input tensor and model live on the same device
self.model.to(self.device)

if (
not use_gpu
Expand Down Expand Up @@ -117,6 +121,8 @@ def call_model(
.. testcode::
from ray.train.torch import TorchPredictor
# List outputs are not supported by default TorchPredictor.
# So let's define a custom TorchPredictor and override call_model
class MyModel(torch.nn.Module):
Expand Down Expand Up @@ -231,7 +237,7 @@ def _arrays_to_tensors(
return convert_ndarray_batch_to_torch_tensor_batch(
numpy_arrays,
dtypes=dtype,
device="cuda" if self.use_gpu else None,
device=self.device,
)

def _tensor_to_array(self, tensor: torch.Tensor) -> np.ndarray:
Expand Down

0 comments on commit 53260af

Please sign in to comment.