From b08c4653b0551d3d943387fe73bf30a1f1798497 Mon Sep 17 00:00:00 2001 From: Anthonios Partheniou Date: Fri, 20 Sep 2024 17:26:13 +0000 Subject: [PATCH] feat: Add OperationsRestAsyncTransport to support long running operations --- google/api_core/operations_v1/__init__.py | 14 + .../operations_v1/transports/__init__.py | 14 + .../api_core/operations_v1/transports/base.py | 12 +- .../operations_v1/transports/rest_asyncio.py | 504 ++++++++++++++++++ noxfile.py | 2 +- .../test_operations_rest_client.py | 49 +- 6 files changed, 580 insertions(+), 15 deletions(-) create mode 100644 google/api_core/operations_v1/transports/rest_asyncio.py diff --git a/google/api_core/operations_v1/__init__.py b/google/api_core/operations_v1/__init__.py index 8b75426b9..d4c35ecac 100644 --- a/google/api_core/operations_v1/__init__.py +++ b/google/api_core/operations_v1/__init__.py @@ -21,9 +21,23 @@ from google.api_core.operations_v1.operations_client import OperationsClient from google.api_core.operations_v1.transports.rest import OperationsRestTransport +try: + from google.api_core.operations_v1.transports.rest_asyncio import ( + OperationsRestAsyncTransport, + ) + + HAS_ASYNC_TRANSPORT = True +except ImportError as e: + # Don't raise an exception if `OperationsRestAsyncTransport` cannot be imported + # so that OperationsRestTransport can still be imported + HAS_ASYNC_TRANSPORT = False + __all__ = [ "AbstractOperationsClient", "OperationsAsyncClient", "OperationsClient", "OperationsRestTransport", ] + +if HAS_ASYNC_TRANSPORT: + __all__.append("OperationsRestAsyncTransport") diff --git a/google/api_core/operations_v1/transports/__init__.py b/google/api_core/operations_v1/transports/__init__.py index df53e15e4..4619d74dc 100644 --- a/google/api_core/operations_v1/transports/__init__.py +++ b/google/api_core/operations_v1/transports/__init__.py @@ -18,12 +18,26 @@ from .base import OperationsTransport from .rest import OperationsRestTransport +try: + from .rest_asyncio import OperationsRestAsyncTransport + + HAS_ASYNC_TRANSPORT = True +except ImportError as e: + # Don't raise an exception if `OperationsRestAsyncTransport` cannot be imported + # so that OperationsRestTransport can still be imported + HAS_ASYNC_TRANSPORT = False # Compile a registry of transports. _transport_registry = OrderedDict() _transport_registry["rest"] = OperationsRestTransport +if HAS_ASYNC_TRANSPORT: + _transport_registry["rest_asyncio"] = OperationsRestAsyncTransport + __all__ = ( "OperationsTransport", "OperationsRestTransport", ) + +if HAS_ASYNC_TRANSPORT: + __all__.append("OperationsRestAsyncTransport") diff --git a/google/api_core/operations_v1/transports/base.py b/google/api_core/operations_v1/transports/base.py index fb1d4fc9e..78382a528 100644 --- a/google/api_core/operations_v1/transports/base.py +++ b/google/api_core/operations_v1/transports/base.py @@ -20,6 +20,7 @@ from google.api_core import exceptions as core_exceptions # type: ignore from google.api_core import gapic_v1 # type: ignore from google.api_core import retry as retries # type: ignore +from google.api_core import retry_async as retries_async # type: ignore from google.api_core import version import google.auth # type: ignore from google.auth import credentials as ga_credentials # type: ignore @@ -115,12 +116,13 @@ def __init__( # Save the credentials. self._credentials = credentials - def _prep_wrapped_messages(self, client_info): + def _prep_wrapped_messages(self, client_info, is_async=False): # Precompute the wrapped methods. + retry_class = retries_async.AsyncRetry if is_async else retries.Retry self._wrapped_methods = { self.list_operations: gapic_v1.method.wrap_method( self.list_operations, - default_retry=retries.Retry( + default_retry=retry_class( initial=0.5, maximum=10.0, multiplier=2.0, @@ -135,7 +137,7 @@ def _prep_wrapped_messages(self, client_info): ), self.get_operation: gapic_v1.method.wrap_method( self.get_operation, - default_retry=retries.Retry( + default_retry=retry_class( initial=0.5, maximum=10.0, multiplier=2.0, @@ -150,7 +152,7 @@ def _prep_wrapped_messages(self, client_info): ), self.delete_operation: gapic_v1.method.wrap_method( self.delete_operation, - default_retry=retries.Retry( + default_retry=retry_class( initial=0.5, maximum=10.0, multiplier=2.0, @@ -165,7 +167,7 @@ def _prep_wrapped_messages(self, client_info): ), self.cancel_operation: gapic_v1.method.wrap_method( self.cancel_operation, - default_retry=retries.Retry( + default_retry=retry_class( initial=0.5, maximum=10.0, multiplier=2.0, diff --git a/google/api_core/operations_v1/transports/rest_asyncio.py b/google/api_core/operations_v1/transports/rest_asyncio.py new file mode 100644 index 000000000..74fb53762 --- /dev/null +++ b/google/api_core/operations_v1/transports/rest_asyncio.py @@ -0,0 +1,504 @@ +# -*- coding: utf-8 -*- +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import re +from typing import Callable, Dict, Optional, Sequence, Tuple, Union + +from requests import __version__ as requests_version + +try: + import google.auth.aio.transport +except ImportError as e: # pragma: NO COVER + raise ImportError( + "google-auth>=2.35.0 is required to use OperationsRestAsyncTransport." + ) from e + +from google.api_core import exceptions as core_exceptions # type: ignore +from google.api_core import gapic_v1 # type: ignore +from google.api_core import path_template # type: ignore +from google.api_core import rest_helpers # type: ignore +from google.api_core import retry_async as retries # type: ignore +from google.auth.aio import credentials as ga_credentials # type: ignore +from google.auth.aio.transport.sessions import AsyncAuthorizedSession # type: ignore +from google.longrunning import operations_pb2 # type: ignore +from google.protobuf import empty_pb2 # type: ignore +from google.protobuf import json_format # type: ignore +import google.protobuf + +import grpc +from .base import DEFAULT_CLIENT_INFO as BASE_DEFAULT_CLIENT_INFO, OperationsTransport + +PROTOBUF_VERSION = google.protobuf.__version__ + +OptionalRetry = Union[retries.AsyncRetry, object] + +DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo( + gapic_version=BASE_DEFAULT_CLIENT_INFO.gapic_version, + grpc_version=None, + rest_version=requests_version, +) + + +class OperationsRestAsyncTransport(OperationsTransport): + """REST Async backend transport for Operations. + + Manages async long-running operations with an API service. + + When an API method normally takes long time to complete, it can be + designed to return [Operation][google.api_core.operations_v1.Operation] to the + client, and the client can use this interface to receive the real + response asynchronously by polling the operation resource, or pass + the operation resource to another API (such as Google Cloud Pub/Sub + API) to receive the response. Any API service that returns + long-running operations should implement the ``Operations`` + interface so developers can have a consistent client experience. + + This class defines the same methods as the primary client, so the + primary client can load the underlying transport implementation + and call it. + + It sends JSON representations of protocol buffers over HTTP/1.1 + """ + + def __init__( + self, + *, + host: str = "longrunning.googleapis.com", + credentials: Optional[ga_credentials.Credentials] = None, + client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + always_use_jwt_access: Optional[bool] = False, + url_scheme: str = "https", + http_options: Optional[Dict] = None, + path_prefix: str = "v1", + ) -> None: + """Instantiate the transport. + + Args: + host (Optional[str]): + The hostname to connect to. + credentials (Optional[google.auth.credentials.Credentials]): The + authorization credentials to attach to requests. These + credentials identify the application to the service; if none + are specified, the client will attempt to ascertain the + credentials from the environment. + scopes (Optional(Sequence[str])): A list of scopes. This argument is + ignored if ``channel`` is provided. + client_cert_source_for_mtls (Callable[[], Tuple[bytes, bytes]]): Client + certificate to configure mutual TLS HTTP channel. It is ignored + if ``channel`` is provided. + quota_project_id (Optional[str]): An optional project to use for billing + and quota. + client_info (google.api_core.gapic_v1.client_info.ClientInfo): + The client info used to send a user-agent string along with + API requests. If ``None``, then default info will be used. + Generally, you only need to set this if you're developing + your own client library. + always_use_jwt_access (Optional[bool]): Whether self signed JWT should + be used for service account credentials. + url_scheme: the protocol scheme for the API endpoint. Normally + "https", but for testing or local servers, + "http" can be specified. + http_options: a dictionary of http_options for transcoding, to override + the defaults from operations.proto. Each method has an entry + with the corresponding http rules as value. + path_prefix: path prefix (usually represents API version). Set to + "v1" by default. + + """ + # Run the base constructor + # TODO(yon-mg): resolve other ctor params i.e. scopes, quota, etc. + # TODO: When custom host (api_endpoint) is set, `scopes` must *also* be set on the + # credentials object + maybe_url_match = re.match("^(?Phttp(?:s)?://)?(?P.*)$", host) + if maybe_url_match is None: + raise ValueError( + f"Unexpected hostname structure: {host}" + ) # pragma: NO COVER + + url_match_items = maybe_url_match.groupdict() + + host = f"{url_scheme}://{host}" if not url_match_items["scheme"] else host + + super().__init__( + host=host, + credentials=credentials, + client_info=client_info, + always_use_jwt_access=always_use_jwt_access, + ) + # TODO Add support for default_host + self._session = AsyncAuthorizedSession(self._credentials) + # TODO REMOVE MTLS CODE + # if client_cert_source_for_mtls: + # self._session.configure_mtls_channel(client_cert_source_for_mtls) + self._prep_wrapped_messages(client_info, is_async=True) + self._http_options = http_options or {} + self._path_prefix = path_prefix + + async def _list_operations( + self, + request: operations_pb2.ListOperationsRequest, + *, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Optional[float] = None, + compression: Optional[grpc.Compression] = gapic_v1.method.DEFAULT, + metadata: Sequence[Tuple[str, str]] = (), + ) -> operations_pb2.ListOperationsResponse: + r"""Call the list operations method over HTTP. + + Args: + request (~.operations_pb2.ListOperationsRequest): + The request object. The request message for + [Operations.ListOperations][google.api_core.operations_v1.Operations.ListOperations]. + + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + ~.operations_pb2.ListOperationsResponse: + The response message for + [Operations.ListOperations][google.api_core.operations_v1.Operations.ListOperations]. + + """ + + http_options = [ + { + "method": "get", + "uri": "/{}/{{name=**}}/operations".format(self._path_prefix), + }, + ] + if "google.longrunning.Operations.ListOperations" in self._http_options: + http_options = self._http_options[ + "google.longrunning.Operations.ListOperations" + ] + + request_kwargs = self._convert_protobuf_message_to_dict(request) + transcoded_request = path_template.transcode(http_options, **request_kwargs) + + uri = transcoded_request["uri"] + method = transcoded_request["method"] + + # Jsonify the query params + query_params_request = operations_pb2.ListOperationsRequest() + json_format.ParseDict(transcoded_request["query_params"], query_params_request) + query_params = json_format.MessageToDict( + query_params_request, + preserving_proto_field_name=False, + use_integers_for_enums=False, + ) + + # Send the request + headers = dict(metadata) + headers["Content-Type"] = "application/json" + response = await getattr(self._session, method)( + "{host}{uri}".format(host=self._host, uri=uri), + timeout=timeout, + headers=headers, + params=rest_helpers.flatten_query_params(query_params), + ) + + # In case of error, raise the appropriate core_exceptions.GoogleAPICallError exception + # subclass. + if response.status_code >= 400: + raise core_exceptions.from_http_response(response) + + # Return the response + api_response = operations_pb2.ListOperationsResponse() + json_format.Parse(response.content, api_response, ignore_unknown_fields=False) + return api_response + + async def _get_operation( + self, + request: operations_pb2.GetOperationRequest, + *, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Optional[float] = None, + compression: Optional[grpc.Compression] = gapic_v1.method.DEFAULT, + metadata: Sequence[Tuple[str, str]] = (), + ) -> operations_pb2.Operation: + r"""Call the get operation method over HTTP. + + Args: + request (~.operations_pb2.GetOperationRequest): + The request object. The request message for + [Operations.GetOperation][google.api_core.operations_v1.Operations.GetOperation]. + + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + ~.operations_pb2.Operation: + This resource represents a long- + running operation that is the result of a + network API call. + + """ + + http_options = [ + { + "method": "get", + "uri": "/{}/{{name=**/operations/*}}".format(self._path_prefix), + }, + ] + if "google.longrunning.Operations.GetOperation" in self._http_options: + http_options = self._http_options[ + "google.longrunning.Operations.GetOperation" + ] + + request_kwargs = self._convert_protobuf_message_to_dict(request) + transcoded_request = path_template.transcode(http_options, **request_kwargs) + + uri = transcoded_request["uri"] + method = transcoded_request["method"] + + # Jsonify the query params + query_params_request = operations_pb2.GetOperationRequest() + json_format.ParseDict(transcoded_request["query_params"], query_params_request) + query_params = json_format.MessageToDict( + query_params_request, + preserving_proto_field_name=False, + use_integers_for_enums=False, + ) + + # Send the request + headers = dict(metadata) + headers["Content-Type"] = "application/json" + response = await getattr(self._session, method)( + "{host}{uri}".format(host=self._host, uri=uri), + timeout=timeout, + headers=headers, + params=rest_helpers.flatten_query_params(query_params), + ) + + # In case of error, raise the appropriate core_exceptions.GoogleAPICallError exception + # subclass. + if response.status_code >= 400: + raise core_exceptions.from_http_response(response) + + # Return the response + api_response = operations_pb2.Operation() + json_format.Parse( + await response.read(), api_response, ignore_unknown_fields=False + ) + return api_response + + async def _delete_operation( + self, + request: operations_pb2.DeleteOperationRequest, + *, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Optional[float] = None, + compression: Optional[grpc.Compression] = gapic_v1.method.DEFAULT, + metadata: Sequence[Tuple[str, str]] = (), + ) -> empty_pb2.Empty: + r"""Call the delete operation method over HTTP. + + Args: + request (~.operations_pb2.DeleteOperationRequest): + The request object. The request message for + [Operations.DeleteOperation][google.api_core.operations_v1.Operations.DeleteOperation]. + + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + """ + + http_options = [ + { + "method": "delete", + "uri": "/{}/{{name=**/operations/*}}".format(self._path_prefix), + }, + ] + if "google.longrunning.Operations.DeleteOperation" in self._http_options: + http_options = self._http_options[ + "google.longrunning.Operations.DeleteOperation" + ] + + request_kwargs = self._convert_protobuf_message_to_dict(request) + transcoded_request = path_template.transcode(http_options, **request_kwargs) + + uri = transcoded_request["uri"] + method = transcoded_request["method"] + + # Jsonify the query params + query_params_request = operations_pb2.DeleteOperationRequest() + json_format.ParseDict(transcoded_request["query_params"], query_params_request) + query_params = json_format.MessageToDict( + query_params_request, + preserving_proto_field_name=False, + use_integers_for_enums=False, + ) + + # Send the request + headers = dict(metadata) + headers["Content-Type"] = "application/json" + response = await getattr(self._session, method)( + "{host}{uri}".format(host=self._host, uri=uri), + timeout=timeout, + headers=headers, + params=rest_helpers.flatten_query_params(query_params), + ) + + # In case of error, raise the appropriate core_exceptions.GoogleAPICallError exception + # subclass. + if response.status_code >= 400: + raise core_exceptions.from_http_response(response) + + return empty_pb2.Empty() + + async def _cancel_operation( + self, + request: operations_pb2.CancelOperationRequest, + *, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Optional[float] = None, + compression: Optional[grpc.Compression] = gapic_v1.method.DEFAULT, + metadata: Sequence[Tuple[str, str]] = (), + ) -> empty_pb2.Empty: + r"""Call the cancel operation method over HTTP. + + Args: + request (~.operations_pb2.CancelOperationRequest): + The request object. The request message for + [Operations.CancelOperation][google.api_core.operations_v1.Operations.CancelOperation]. + + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + """ + + http_options = [ + { + "method": "post", + "uri": "/{}/{{name=**/operations/*}}:cancel".format(self._path_prefix), + "body": "*", + }, + ] + if "google.longrunning.Operations.CancelOperation" in self._http_options: + http_options = self._http_options[ + "google.longrunning.Operations.CancelOperation" + ] + + request_kwargs = self._convert_protobuf_message_to_dict(request) + transcoded_request = path_template.transcode(http_options, **request_kwargs) + + # Jsonify the request body + body_request = operations_pb2.CancelOperationRequest() + json_format.ParseDict(transcoded_request["body"], body_request) + body = json_format.MessageToDict( + body_request, + preserving_proto_field_name=False, + use_integers_for_enums=False, + ) + uri = transcoded_request["uri"] + method = transcoded_request["method"] + + # Jsonify the query params + query_params_request = operations_pb2.CancelOperationRequest() + json_format.ParseDict(transcoded_request["query_params"], query_params_request) + query_params = json_format.MessageToDict( + query_params_request, + preserving_proto_field_name=False, + use_integers_for_enums=False, + ) + + # Send the request + headers = dict(metadata) + headers["Content-Type"] = "application/json" + response = await getattr(self._session, method)( + "{host}{uri}".format(host=self._host, uri=uri), + timeout=timeout, + headers=headers, + params=rest_helpers.flatten_query_params(query_params), + data=body, + ) + + # In case of error, raise the appropriate core_exceptions.GoogleAPICallError exception + # subclass. + if response.status_code >= 400: + raise core_exceptions.from_http_response(response) + + return empty_pb2.Empty() + + def _convert_protobuf_message_to_dict( + self, message: google.protobuf.message.Message + ): + r"""Converts protobuf message to a dictionary. + + When the dictionary is encoded to JSON, it conforms to proto3 JSON spec. + + Args: + message(google.protobuf.message.Message): The protocol buffers message + instance to serialize. + + Returns: + A dict representation of the protocol buffer message. + """ + # For backwards compatibility with protobuf 3.x 4.x + # Remove once support for protobuf 3.x and 4.x is dropped + # https://github.com/googleapis/python-api-core/issues/643 + if PROTOBUF_VERSION[0:2] in ["3.", "4."]: + result = json_format.MessageToDict( + message, + preserving_proto_field_name=True, + including_default_value_fields=True, # type: ignore # backward compatibility + ) + else: + result = json_format.MessageToDict( + message, + preserving_proto_field_name=True, + always_print_fields_with_no_presence=True, + ) + + return result + + @property + def list_operations( + self, + ) -> Callable[ + [operations_pb2.ListOperationsRequest], operations_pb2.ListOperationsResponse + ]: + return self._list_operations + + @property + def get_operation( + self, + ) -> Callable[[operations_pb2.GetOperationRequest], operations_pb2.Operation]: + return self._get_operation + + @property + def delete_operation( + self, + ) -> Callable[[operations_pb2.DeleteOperationRequest], empty_pb2.Empty]: + return self._delete_operation + + @property + def cancel_operation( + self, + ) -> Callable[[operations_pb2.CancelOperationRequest], empty_pb2.Empty]: + return self._cancel_operation + + +__all__ = ("OperationsRestTransport",) diff --git a/noxfile.py b/noxfile.py index 3fc4a722c..f4235a26e 100644 --- a/noxfile.py +++ b/noxfile.py @@ -147,7 +147,7 @@ def default(session, install_grpc=True, prerelease=False, install_auth_aio=False if install_auth_aio: session.install( - "google-auth @ git+https://git@github.com/googleapis/google-auth-library-python@8833ad6f92c3300d6645355994c7db2356bd30ad" + "google-auth[aiohttp]" ) # Print out package versions of dependencies diff --git a/tests/unit/operations_v1/test_operations_rest_client.py b/tests/unit/operations_v1/test_operations_rest_client.py index 4ab4f1f75..8f4507dd3 100644 --- a/tests/unit/operations_v1/test_operations_rest_client.py +++ b/tests/unit/operations_v1/test_operations_rest_client.py @@ -33,6 +33,15 @@ from google.api_core.operations_v1 import transports import google.auth from google.auth import credentials as ga_credentials +from google.auth.aio import credentials as ga_credentials_async + +try: + import aiohttp + + AIOHTTP_INSTALLED = True +except ImportError: + AIOHTTP_INSTALLED = False + from google.auth.exceptions import MutualTLSChannelError from google.longrunning import operations_pb2 from google.oauth2 import service_account @@ -125,11 +134,14 @@ def test_operations_client_from_service_account_info(client_class): @pytest.mark.parametrize( - "transport_class,transport_name", [(transports.OperationsRestTransport, "rest")] + "transport_class", + [ + transports.OperationsRestTransport, + # TODO: Add support for service account credentials + # transports.OperationsRestAsyncTransport, + ], ) -def test_operations_client_service_account_always_use_jwt( - transport_class, transport_name -): +def test_operations_client_service_account_always_use_jwt(transport_class): with mock.patch.object( service_account.Credentials, "with_always_use_jwt_access", create=True ) as use_jwt: @@ -712,11 +724,22 @@ def test_transport_instance(): assert client.transport is transport -@pytest.mark.parametrize("transport_class", [transports.OperationsRestTransport]) -def test_transport_adc(transport_class): +@pytest.mark.parametrize( + "transport_class,credentials", + [ + (transports.OperationsRestTransport, ga_credentials.AnonymousCredentials()), + ( + transports.OperationsRestAsyncTransport, + ga_credentials_async.AnonymousCredentials(), + ), + ], +) +def test_transport_adc(transport_class, credentials): + if not AIOHTTP_INSTALLED: + pytest.skip("aiohttp extra is required") # Test default credentials are used if not provided. with mock.patch.object(google.auth, "default") as adc: - adc.return_value = (ga_credentials.AnonymousCredentials(), None) + adc.return_value = (credentials, None) transport_class() adc.assert_called_once() @@ -800,12 +823,20 @@ def test_operations_auth_adc(): ) -def test_operations_http_transport_client_cert_source_for_mtls(): +@pytest.mark.parametrize( + "transport_class", + [ + transports.OperationsRestTransport, + # TODO + # transports.OperationsRestAsyncTransport, + ], +) +def test_operations_http_transport_client_cert_source_for_mtls(transport_class): cred = ga_credentials.AnonymousCredentials() with mock.patch( "google.auth.transport.requests.AuthorizedSession.configure_mtls_channel" ) as mock_configure_mtls_channel: - transports.OperationsRestTransport( + transport_class( credentials=cred, client_cert_source_for_mtls=client_cert_source_callback ) mock_configure_mtls_channel.assert_called_once_with(client_cert_source_callback)