diff --git a/vision/google/cloud/vision/_gax.py b/vision/google/cloud/vision/_gax.py index 44a55e0f09e5..fd4dda68f9e2 100644 --- a/vision/google/cloud/vision/_gax.py +++ b/vision/google/cloud/vision/_gax.py @@ -14,6 +14,8 @@ """GAX Client for interacting with the Google Cloud Vision API.""" +from google.gax.errors import RetryError + from google.cloud.gapic.vision.v1 import image_annotator_client from google.cloud.proto.vision.v1 import image_annotator_pb2 @@ -66,7 +68,12 @@ def annotate(self, images=None, requests_pb=None): requests = requests_pb annotator_client = self._annotator_client - responses = annotator_client.batch_annotate_images(requests).responses + try: + api_result = annotator_client.batch_annotate_images(requests) + responses = api_result.responses + except RetryError as exception: + raise exception.cause.exception() + return [Annotations.from_pb(response) for response in responses] diff --git a/vision/unit_tests/test__gax.py b/vision/unit_tests/test__gax.py index ee73cbdce0b2..5ca790063cef 100644 --- a/vision/unit_tests/test__gax.py +++ b/vision/unit_tests/test__gax.py @@ -221,6 +221,32 @@ def test_annotate_with_pb_requests_results(self): self.assertIsInstance(annotation, Annotations) gax_api._annotator_client.batch_annotate_images.assert_called() + def test_handle_retry_error(self): + from grpc._channel import _Rendezvous + from google.gax.errors import GaxError + from google.gax.errors import RetryError + from google.cloud.proto.vision.v1 import image_annotator_pb2 + + client = mock.Mock(spec_set=['_credentials']) + request = image_annotator_pb2.AnnotateImageRequest() + + # In the case of a RetryError being raised, we want to get to the root + # exception. i.e. RetryError.cause.exception() + mock_rendezvous = mock.Mock(spec=_Rendezvous) + mock_rendezvous.exception.return_value = GaxError('gax') + + with mock.patch('google.cloud.vision._gax.image_annotator_client.' + 'ImageAnnotatorClient'): + gax_api = self._make_one(client) + + gax_api._annotator_client = mock.Mock( + spec_set=['batch_annotate_images']) + + gax_api._annotator_client.batch_annotate_images = mock.Mock( + side_effect=RetryError('retry', cause=mock_rendezvous)) + with self.assertRaises(GaxError): + gax_api.annotate(requests_pb=[request]) + class Test__to_gapic_feature(unittest.TestCase): def _call_fut(self, feature):