
[AIR] Add TorchDetectionPredictor #32064

Closed
bveeramani opened this issue Jan 30, 2023 · 0 comments · Fixed by #32199
Labels
enhancement Request for new feature and/or capability

Comments

@bveeramani
Member

Description

Add a class TorchDetectionPredictor that works with TorchVision detection models.

Use case

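TorchPredictor doesn't work with TorchVision detection models because these models return a list of dictionaries (one per input image) instead of a tensor. As a workaround, you currently need to subclass TorchPredictor and override _predict_numpy: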

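import collections
from typing import Dict

import numpy as np
import torch
from ray.train.torch import TorchPredictor
# Private helper; its module path may differ between Ray versions.
from ray.air.util.tensor_extensions.utils import _create_possibly_ragged_ndarray


class CustomTorchPredictor(TorchPredictor):
    def _predict_numpy(
        self, data: Dict[str, np.ndarray], dtype: torch.dtype
    ) -> Dict[str, np.ndarray]:
        # Detection models take a list of 3D (C, H, W) tensors, one per image.
        inputs = [torch.as_tensor(image, dtype=dtype) for image in data["image"]]
        assert all(image.dim() == 3 for image in inputs)
        # The model returns one dict per image (e.g. "boxes", "labels", "scores").
        outputs = self.call_model(inputs)
        # Regroup the per-image dicts into one list of arrays per key.
        predictions = collections.defaultdict(list)
        for output in outputs:
            for key, value in output.items():
                predictions[key].append(value.detach().cpu().numpy())
        # Images can have different numbers of detections, so arrays may be ragged.
        return {
            "pred_" + key: _create_possibly_ragged_ndarray(values)
            for key, values in predictions.items()
        }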
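The core of the workaround is regrouping the model's per-image list of dictionaries into a single dictionary of columns. The standalone sketch below reproduces that step with NumPy only, so it runs without Ray or TorchVision. Both `outputs_to_batch` and `make_ragged` are hypothetical helpers; `make_ragged` is an assumption about what `_create_possibly_ragged_ndarray` does (stack when all shapes match, otherwise build a 1-D object array), not its actual implementation.

```python
from collections import defaultdict
from typing import Dict, List

import numpy as np


def make_ragged(arrays: List[np.ndarray]) -> np.ndarray:
    """Hypothetical stand-in: stack same-shaped arrays, else use an object array."""
    if len({a.shape for a in arrays}) == 1:
        return np.stack(arrays)
    out = np.empty(len(arrays), dtype=object)
    for i, a in enumerate(arrays):
        out[i] = a  # assign one element at a time so NumPy does not broadcast
    return out


def outputs_to_batch(outputs: List[Dict[str, np.ndarray]]) -> Dict[str, np.ndarray]:
    """Hypothetical helper: turn per-image detection dicts into one dict of columns."""
    columns = defaultdict(list)
    for output in outputs:
        for key, value in output.items():
            columns[key].append(value)
    return {"pred_" + key: make_ragged(values) for key, values in columns.items()}


# Two images with a different number of detections each.
outputs = [
    {"boxes": np.zeros((2, 4)), "labels": np.array([1, 3])},
    {"boxes": np.zeros((1, 4)), "labels": np.array([2])},
]
batch = outputs_to_batch(outputs)
print({key: value.shape for key, value in batch.items()})
```

Because the two images yield two and one detections, every column here becomes an object array of length 2 whose elements keep their per-image shapes; if all images had the same number of detections, the columns would instead be stacked into regular dense arrays.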
@bveeramani bveeramani added enhancement Request for new feature and/or capability air labels Jan 30, 2023
@bveeramani bveeramani self-assigned this Jan 30, 2023