-
Notifications
You must be signed in to change notification settings - Fork 6k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[AIR] Add
TorchDetectionPredictor
(#32199)
TorchPredictor doesn't work with TorchVision detection models because they return List[Dict[str, torch.Tensor]] instead of torch.Tensor. This PR adds a TorchDetectionPredictor so users don't have to extend TorchPredictor themselves. Signed-off-by: Balaji Veeramani <balaji@anyscale.com>
- Loading branch information
1 parent
468e606
commit 53260af
Showing
6 changed files
with
296 additions
and
12 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
118 changes: 118 additions & 0 deletions
118
python/ray/train/tests/test_torch_detection_predictor.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,118 @@ | ||
import numpy as np | ||
import pytest | ||
from torchvision import models | ||
|
||
import ray | ||
from ray.air.util.tensor_extensions.utils import create_ragged_ndarray | ||
from ray.train.batch_predictor import BatchPredictor | ||
from ray.train.torch import TorchCheckpoint, TorchDetectionPredictor | ||
|
||
|
||
@pytest.fixture(name="predictor") | ||
def predictor_fixture(): | ||
model = models.detection.maskrcnn_resnet50_fpn() | ||
yield TorchDetectionPredictor(model=model) | ||
|
||
|
||
@pytest.mark.parametrize( | ||
"data", | ||
[ | ||
np.zeros((1, 3, 32, 32), dtype=np.float32), | ||
{"image": np.zeros((1, 3, 32, 32), dtype=np.float32)}, | ||
create_ragged_ndarray( | ||
[ | ||
np.zeros((3, 32, 32), dtype=np.float32), | ||
np.zeros((3, 64, 64), dtype=np.float32), | ||
] | ||
), | ||
], | ||
) | ||
def test_predict(predictor, data): | ||
predictions = predictor.predict(data) | ||
|
||
assert all(len(value) == len(data) for value in predictions.values()) | ||
# Boxes should have shape `(# detections, 4)`. | ||
assert all(boxes.ndim == 2 for boxes in predictions["pred_boxes"]) | ||
assert all(boxes.shape[-1] == 4 for boxes in predictions["pred_boxes"]) | ||
# Labels should have shape `(# detections,)`. | ||
assert all(labels.ndim == 1 for labels in predictions["pred_labels"]) | ||
# Scores should have shape `(# detections,)`. | ||
assert all(scores.ndim == 1 for scores in predictions["pred_scores"]) | ||
|
||
|
||
def test_predict_tensor_dataset(): | ||
model = models.detection.maskrcnn_resnet50_fpn() | ||
checkpoint = TorchCheckpoint.from_model(model) | ||
predictor = BatchPredictor.from_checkpoint(checkpoint, TorchDetectionPredictor) | ||
dataset = ray.data.from_items([np.zeros((3, 32, 32), dtype=np.float32)]) | ||
|
||
predictions = predictor.predict(dataset) | ||
|
||
# Boxes should have shape `(# detections, 4)`. | ||
pred_boxes = [row["pred_boxes"] for row in predictions.take_all()] | ||
assert all(boxes.ndim == 2 for boxes in pred_boxes) | ||
assert all(boxes.shape[-1] == 4 for boxes in pred_boxes) | ||
# Labels should have shape `(# detections,)`. | ||
pred_labels = [row["pred_labels"] for row in predictions.take_all()] | ||
assert all(labels.ndim == 1 for labels in pred_labels) | ||
# Scores should have shape `(# detections,)`. | ||
pred_scores = [row["pred_scores"] for row in predictions.take_all()] | ||
assert all(scores.ndim == 1 for scores in pred_scores) | ||
|
||
|
||
@pytest.mark.parametrize( | ||
"items", | ||
[ | ||
[{"image": np.zeros((3, 32, 32), dtype=np.float32)}], | ||
[ | ||
{"image": np.zeros((3, 32, 32), dtype=np.float32)}, | ||
{"image": np.zeros((3, 64, 64), dtype=np.float32)}, | ||
], | ||
], | ||
) | ||
def test_predict_tabular_dataset(items): | ||
model = models.detection.maskrcnn_resnet50_fpn() | ||
checkpoint = TorchCheckpoint.from_model(model) | ||
predictor = BatchPredictor.from_checkpoint(checkpoint, TorchDetectionPredictor) | ||
dataset = ray.data.from_items(items) | ||
|
||
predictions = predictor.predict(dataset) | ||
|
||
assert predictions.count() == len(items) | ||
# Boxes should have shape `(# detections, 4)`. | ||
pred_boxes = [row["pred_boxes"] for row in predictions.take_all()] | ||
assert all(boxes.ndim == 2 for boxes in pred_boxes) | ||
assert all(boxes.shape[-1] == 4 for boxes in pred_boxes) | ||
# Labels should have shape `(# detections,)`. | ||
pred_labels = [row["pred_labels"] for row in predictions.take_all()] | ||
assert all(labels.ndim == 1 for labels in pred_labels) | ||
# Scores should have shape `(# detections,)`. | ||
pred_scores = [row["pred_scores"] for row in predictions.take_all()] | ||
assert all(scores.ndim == 1 for scores in pred_scores) | ||
|
||
|
||
def test_multi_column_batch_raises_value_error(predictor): | ||
data = { | ||
"image": np.zeros((2, 3, 32, 32), dtype=np.float32), | ||
"boxes": np.zeros((2, 0, 4), dtype=np.float32), | ||
"labels": np.zeros((2, 0), dtype=np.int64), | ||
} | ||
with pytest.raises(ValueError): | ||
# `data` should only contain one key. Otherwise, `TorchDetectionPredictor` | ||
# doesn't know which column contains the input images. | ||
predictor.predict(data) | ||
|
||
|
||
def test_invalid_dtype_raises_value_error(predictor): | ||
data = np.zeros((1, 3, 32, 32), dtype=np.float32) | ||
with pytest.raises(ValueError): | ||
# `dtype` should be a single `torch.dtype`. | ||
predictor.predict(data, dtype=np.float32) | ||
|
||
|
||
if __name__ == "__main__": | ||
import sys | ||
|
||
import pytest | ||
|
||
sys.exit(pytest.main(["-v", "-x", __file__])) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,140 @@ | ||
import collections | ||
from typing import Dict, List, Optional, Union | ||
|
||
import numpy as np | ||
import torch | ||
|
||
from ray.air.util.tensor_extensions.utils import create_ragged_ndarray | ||
from ray.train._internal.dl_predictor import TensorDtype | ||
from ray.train.torch.torch_predictor import TorchPredictor | ||
from ray.util.annotations import PublicAPI | ||
|
||
|
||
@PublicAPI(stability="alpha") | ||
class TorchDetectionPredictor(TorchPredictor): | ||
"""A predictor for TorchVision detection models. | ||
Unlike other Torch models, instance segmentation models return | ||
`List[Dict[str, Tensor]]`. This predictor extends :class:`TorchPredictor` to support | ||
the non-standard outputs. | ||
To learn more about instance segmentation models, read | ||
`Instance segmentation models <https://pytorch.org/vision/main/auto_examples/plot_visualization_utils.html#instance-seg-output>`_. | ||
Example: | ||
.. testcode:: | ||
import numpy as np | ||
from torchvision import models | ||
from ray.train.torch import TorchDetectionPredictor | ||
model = models.detection.fasterrcnn_resnet50_fpn_v2(pretrained=True) | ||
predictor = TorchDetectionPredictor(model=model) | ||
predictions = predictor.predict(np.zeros((4, 3, 32, 32), dtype=np.float32)) | ||
print(predictions.keys()) | ||
.. testoutput:: | ||
dict_keys(['pred_boxes', 'pred_labels', 'pred_scores']) | ||
.. testcode:: | ||
import numpy as np | ||
from torchvision import models | ||
import ray | ||
from ray.train.batch_predictor import BatchPredictor | ||
from ray.train.torch import TorchCheckpoint, TorchDetectionPredictor | ||
dataset = ray.data.from_items([{"image": np.zeros((3, 32, 32), dtype=np.float32)}]) | ||
model = models.detection.fasterrcnn_resnet50_fpn_v2(pretrained=True) | ||
checkpoint = TorchCheckpoint.from_model(model) | ||
predictor = BatchPredictor.from_checkpoint(checkpoint, TorchDetectionPredictor) | ||
predictions = predictor.predict(dataset, feature_columns=["image"]) | ||
print(predictions.take(1)) | ||
.. testoutput:: | ||
[{'pred_boxes': array([], shape=(0, 4), dtype=float32), 'pred_labels': array([], dtype=int64), 'pred_scores': array([], dtype=float32)}] | ||
""" # noqa: E501 | ||
|
||
def _predict_numpy( | ||
self, | ||
data: Union[np.ndarray, Dict[str, np.ndarray]], | ||
dtype: Optional[Union[TensorDtype, Dict[str, TensorDtype]]], | ||
) -> Dict[str, np.ndarray]: | ||
if isinstance(data, dict) and len(data) != 1: | ||
raise ValueError( | ||
f"""Expected input to contain one key, but got {len(data)} instead. | ||
If you're using `BatchPredictor`, pass a one-element list to | ||
`feature_columns`. | ||
--- | ||
predictor = BatchPredictor(checkpoint, TorchDetectionPredictor) | ||
predictor.predict(dataset, feature_columns=["image"]) | ||
--- | ||
""" | ||
) | ||
|
||
if dtype is not None and not isinstance(dtype, torch.dtype): | ||
raise ValueError( | ||
"Expected `dtype` to be a `torch.dtype`, but got a " | ||
f"{type(dtype).__name__} instead." | ||
) | ||
|
||
if isinstance(data, dict): | ||
images = next(iter(data.values())) | ||
else: | ||
images = data | ||
|
||
inputs = [ | ||
torch.as_tensor(image, dtype=dtype).to(self.device) for image in images | ||
] | ||
outputs = self.call_model(inputs) | ||
outputs = _convert_outputs_to_ndarray_batch(outputs) | ||
outputs = {"pred_" + key: value for key, value in outputs.items()} | ||
|
||
return outputs | ||
|
||
|
||
def _convert_outputs_to_ndarray_batch( | ||
outputs: List[Dict[str, torch.Tensor]], | ||
) -> Dict[str, np.ndarray]: | ||
"""Batch detection model outputs. | ||
TorchVision detection models return `List[Dict[Tensor]]`. Each `Dict` contain | ||
'boxes', 'labels, and 'scores'. | ||
>>> import torch | ||
>>> from torchvision import models | ||
>>> model = models.detection.fasterrcnn_resnet50_fpn_v2() | ||
>>> model.eval() # doctest: +ELLIPSIS | ||
FasterRCNN(...) | ||
>>> outputs = model(torch.zeros((2, 3, 32, 32))) | ||
>>> len(outputs) | ||
2 | ||
>>> outputs[0].keys() | ||
dict_keys(['boxes', 'labels', 'scores']) | ||
This function batches values and returns a `Dict[str, np.ndarray]`. | ||
>>> from ray.train.torch.torch_detection_predictor import _convert_outputs_to_ndarray_batch | ||
>>> batch = _convert_outputs_to_ndarray_batch(outputs) | ||
>>> batch.keys() | ||
dict_keys(['boxes', 'labels', 'scores']) | ||
>>> batch["boxes"].shape | ||
(2,) | ||
""" # noqa: E501 | ||
batch = collections.defaultdict(list) | ||
for output in outputs: | ||
for key, value in output.items(): | ||
batch[key].append(value.cpu().detach().numpy()) | ||
for key, value in batch.items(): | ||
batch[key] = create_ragged_ndarray(value) | ||
return batch |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters