Skip to content

Commit

Permalink
fix: disambiguate method names that are reserved in transport classes
Browse files Browse the repository at this point in the history
In addition to the method specific stubs they provide, the generated
transports expose other methods to e.g. create a gRPC channel.
This presents the opportunity for a naming collision if an API has a
CreateChannel method.

This PR disambiguates colliding method names at the transport
level. Client level method names are unchanged for ergonomic reasons.
  • Loading branch information
software-dov committed Feb 3, 2022
1 parent e527089 commit ae9a5c8
Show file tree
Hide file tree
Showing 20 changed files with 627 additions and 85 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -443,7 +443,7 @@ class {{ service.client_name }}(metaclass={{ service.client_name }}Meta):

# Wrap the RPC method; this adds retry and timeout information,
# and friendly error handling.
rpc = self._transport._wrapped_methods[self._transport.{{ method.name|snake_case}}]
rpc = self._transport._wrapped_methods[self._transport.{{ method.transport_safe_name|snake_case}}]
{% if method.field_headers %}

# Certain fields should be provided within the metadata header;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -120,8 +120,8 @@ class {{ service.name }}Transport(abc.ABC):
# Precompute the wrapped methods.
self._wrapped_methods = {
{% for method in service.methods.values() %}
self.{{ method.name|snake_case }}: gapic_v1.method.wrap_method(
self.{{ method.name|snake_case }},
self.{{ method.transport_safe_name|snake_case }}: gapic_v1.method.wrap_method(
self.{{ method.transport_safe_name|snake_case }},
{% if method.retry %}
default_retry=retries.Retry(
{% if method.retry.initial_backoff %}initial={{ method.retry.initial_backoff }},{% endif %}
Expand Down Expand Up @@ -160,7 +160,7 @@ class {{ service.name }}Transport(abc.ABC):
{% for method in service.methods.values() %}

@property
def {{ method.name|snake_case }}(self) -> Callable[
def {{ method.transport_safe_name|snake_case }}(self) -> Callable[
[{{ method.input.ident }}],
Union[
{{ method.output.ident }},
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -249,7 +249,7 @@ class {{ service.name }}GrpcTransport({{ service.name }}Transport):
{% for method in service.methods.values() %}

@property
def {{ method.name|snake_case }}(self) -> Callable[
def {{ method.transport_safe_name|snake_case }}(self) -> Callable[
[{{ method.input.ident }}],
{{ method.output.ident }}]:
r"""Return a callable for the{{ ' ' }}
Expand All @@ -269,13 +269,13 @@ class {{ service.name }}GrpcTransport({{ service.name }}Transport):
# the request.
# gRPC handles serialization and deserialization, so we just need
# to pass in the functions for each.
if '{{ method.name|snake_case }}' not in self._stubs:
self._stubs['{{ method.name|snake_case }}'] = self.grpc_channel.{{ method.grpc_stub_type }}(
if '{{ method.transport_safe_name|snake_case }}' not in self._stubs:
self._stubs['{{ method.transport_safe_name|snake_case }}'] = self.grpc_channel.{{ method.grpc_stub_type }}(
'/{{ '.'.join(method.meta.address.package) }}.{{ service.name }}/{{ method.name }}',
request_serializer={{ method.input.ident }}.{% if method.input.ident.python_import.module.endswith('_pb2') %}SerializeToString{% else %}serialize{% endif %},
response_deserializer={{ method.output.ident }}.{% if method.output.ident.python_import.module.endswith('_pb2') %}FromString{% else %}deserialize{% endif %},
)
return self._stubs['{{ method.name|snake_case }}']
return self._stubs['{{ method.transport_safe_name|snake_case }}']
{% endfor %}

{% if opts.add_iam_methods %}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -396,12 +396,12 @@ class {{service.name}}RestTransport({{service.name}}Transport):
{% for method in service.methods.values()|sort(attribute="name") %}

@property
def {{method.name | snake_case}}(self) -> Callable[
def {{method.transport_safe_name | snake_case}}(self) -> Callable[
[{{method.input.ident}}],
{{method.output.ident}}]:
stub = self._STUBS.get("{{method.name | snake_case}}")
stub = self._STUBS.get("{{method.transport_safe_name | snake_case}}")
if not stub:
stub = self._STUBS["{{method.name | snake_case}}"] = self._{{method.name}}(self._session, self._host, self._interceptor)
stub = self._STUBS["{{method.transport_safe_name | snake_case}}"] = self._{{method.name}}(self._session, self._host, self._interceptor)

# The return type is fine, but mypy isn't sophisticated enough to determine what's going on here.
# In C++ this would require a dynamic_cast
Expand Down
2 changes: 1 addition & 1 deletion gapic/ads-templates/setup.py.j2
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ setuptools.setup(
install_requires=(
{# TODO(dovs): remove when 1.x deprecation is complete #}
{% if 'rest' in opts.transport %}
'google-api-core[grpc] >= 2.1.0, < 3.0.0dev',
'google-api-core[grpc] >= 2.4.0, < 3.0.0dev',
{% else %}
'google-api-core[grpc] >= 1.28.0, < 3.0.0dev',
{% endif %}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -494,7 +494,7 @@ def test_{{ method_name }}(request_type, transport: str = 'grpc'):

# Mock the actual call within the gRPC stub, and fake the request.
with mock.patch.object(
type(client.transport.{{ method.name|snake_case }}),
type(client.transport.{{ method.transport_safe_name|snake_case }}),
'__call__') as call:
# Designate an appropriate return value for the call.
{% if method.void %}
Expand Down Expand Up @@ -571,7 +571,7 @@ def test_{{ method_name }}_empty_call():

# Mock the actual call within the gRPC stub, and fake the request.
with mock.patch.object(
type(client.transport.{{ method.name|snake_case }}),
type(client.transport.{{ method.transport_safe_name|snake_case }}),
'__call__') as call:
client.{{ method_name }}()
call.assert_called()
Expand Down Expand Up @@ -600,7 +600,7 @@ def test_{{ method_name }}_field_headers():

# Mock the actual call within the gRPC stub, and fake the request.
with mock.patch.object(
type(client.transport.{{ method.name|snake_case }}),
type(client.transport.{{ method.transport_safe_name|snake_case }}),
'__call__') as call:
{% if method.void %}
call.return_value = None
Expand Down Expand Up @@ -638,7 +638,7 @@ def test_{{ method_name }}_from_dict_foreign():
)
# Mock the actual call within the gRPC stub, and fake the request.
with mock.patch.object(
type(client.transport.{{ method.name|snake_case }}),
type(client.transport.{{ method.transport_safe_name|snake_case }}),
'__call__') as call:
# Designate an appropriate return value for the call.
{% if method.void %}
Expand Down Expand Up @@ -668,7 +668,7 @@ def test_{{ method_name }}_flattened():

# Mock the actual call within the gRPC stub, and fake the request.
with mock.patch.object(
type(client.transport.{{ method.name|snake_case }}),
type(client.transport.{{ method.transport_safe_name|snake_case }}),
'__call__') as call:
# Designate an appropriate return value for the call.
{% if method.void %}
Expand Down Expand Up @@ -746,7 +746,7 @@ def test_{{ method_name }}_pager(transport_name: str = "grpc"):

# Mock the actual call within the gRPC stub, and fake the request.
with mock.patch.object(
type(client.transport.{{ method.name|snake_case }}),
type(client.transport.{{ method.transport_safe_name|snake_case }}),
'__call__') as call:
# Set the response to a series of pages.
call.side_effect = (
Expand Down Expand Up @@ -808,7 +808,7 @@ def test_{{ method_name }}_pages(transport_name: str = "grpc"):

# Mock the actual call within the gRPC stub, and fake the request.
with mock.patch.object(
type(client.transport.{{ method.name|snake_case }}),
type(client.transport.{{ method.transport_safe_name|snake_case }}),
'__call__') as call:
# Set the response to a series of pages.
{% if method.paged_result_field.map%}
Expand Down Expand Up @@ -1184,7 +1184,7 @@ def test_{{ method_name }}_rest_required_fields(request_type={{ method.input.ide
def test_{{ method_name }}_rest_unset_required_fields():
transport = transports.{{ service.rest_transport_name }}(credentials=ga_credentials.AnonymousCredentials)

unset_fields = transport.{{ method.name|snake_case }}._get_unset_required_fields({})
unset_fields = transport.{{ method.transport_safe_name|snake_case }}._get_unset_required_fields({})
assert set(unset_fields) == (set(({% for param in method.query_params|sort %}"{{ param|camel_case }}", {% endfor %})) & set(({% for param in method.input.required_fields %}"{{param.name|camel_case}}", {% endfor %})))


Expand Down Expand Up @@ -1645,7 +1645,7 @@ def test_{{ service.name|snake_case }}_base_transport():
# raise NotImplementedError.
methods = (
{% for method in service.methods.values() %}
'{{ method.name|snake_case }}',
'{{ method.transport_safe_name|snake_case }}',
{% endfor %}
{% if opts.add_iam_methods %}
'set_iam_policy',
Expand Down
13 changes: 13 additions & 0 deletions gapic/schema/wrappers.py
Original file line number Diff line number Diff line change
Expand Up @@ -945,9 +945,22 @@ class Method:
default_factory=metadata.Metadata,
)

# These names conflict with other methods in the transport.
# We don't want to disambiguate the names at the client level
# because the disambiguated name is less convenient and user friendly.
TRANSPORT_UNSAFE_NAMES: ClassVar[FrozenSet] = frozenset({
"CreateChannel",
"GrpcChannel",
"OperationsClient",
})

def __getattr__(self, name):
return getattr(self.method_pb, name)

@property
def transport_safe_name(self) -> str:
return f"{self.name}_" if self.name in self.TRANSPORT_UNSAFE_NAMES else self.name

@property
def is_operation_polling_method(self):
return self.output.is_extended_operation and self.options.Extensions[ex_ops_pb2.operation_polling_method]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -92,15 +92,15 @@ class {{ service.async_client_name }}:

@classmethod
def get_mtls_endpoint_and_cert_source(cls, client_options: Optional[ClientOptions] = None):
"""Return the API endpoint and client cert source for mutual TLS.
"""Return the API endpoint and client cert source for mutual TLS.

The client cert source is determined in the following order:
(1) if `GOOGLE_API_USE_CLIENT_CERTIFICATE` environment variable is not "true", the
client cert source is None.
(2) if `client_options.client_cert_source` is provided, use the provided one; if the
default client cert source exists, use the default one; otherwise the client cert
source is None.

The API endpoint is determined in the following order:
(1) if `client_options.api_endpoint` if provided, use the provided one.
(2) if `GOOGLE_API_USE_CLIENT_CERTIFICATE` environment variable is "always", use the
Expand All @@ -118,7 +118,7 @@ class {{ service.async_client_name }}:
Returns:
Tuple[str, Callable[[], Tuple[bytes, bytes]]]: returns the API endpoint and the
client cert source to use.

Raises:
google.auth.exceptions.MutualTLSChannelError: If any errors happen.
"""
Expand Down Expand Up @@ -302,7 +302,7 @@ class {{ service.async_client_name }}:
# Wrap the RPC method; this adds retry and timeout information,
# and friendly error handling.
rpc = gapic_v1.method_async.wrap_method(
self._client._transport.{{ method.name|snake_case }},
self._client._transport.{{ method.transport_safe_name|snake_case }},
{% if method.retry %}
default_retry=retries.Retry(
{% if method.retry.initial_backoff %}initial={{ method.retry.initial_backoff }},{% endif %}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -477,7 +477,7 @@ class {{ service.client_name }}(metaclass={{ service.client_name }}Meta):

# Wrap the RPC method; this adds retry and timeout information,
# and friendly error handling.
rpc = self._transport._wrapped_methods[self._transport.{{ method.name|snake_case}}]
rpc = self._transport._wrapped_methods[self._transport.{{ method.transport_safe_name|snake_case}}]

{% if method.explicit_routing %}
header_params = {}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -120,8 +120,8 @@ class {{ service.name }}Transport(abc.ABC):
# Precompute the wrapped methods.
self._wrapped_methods = {
{% for method in service.methods.values() %}
self.{{ method.name|snake_case }}: gapic_v1.method.wrap_method(
self.{{ method.name|snake_case }},
self.{{ method.transport_safe_name|snake_case }}: gapic_v1.method.wrap_method(
self.{{ method.transport_safe_name|snake_case }},
{% if method.retry %}
default_retry=retries.Retry(
{% if method.retry.initial_backoff %}initial={{ method.retry.initial_backoff }},{% endif %}
Expand Down Expand Up @@ -160,7 +160,7 @@ class {{ service.name }}Transport(abc.ABC):
{% for method in service.methods.values() %}

@property
def {{ method.name|snake_case }}(self) -> Callable[
def {{ method.transport_safe_name|snake_case }}(self) -> Callable[
[{{ method.input.ident }}],
Union[
{{ method.output.ident }},
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -249,7 +249,7 @@ class {{ service.name }}GrpcTransport({{ service.name }}Transport):
{% for method in service.methods.values() %}

@property
def {{ method.name|snake_case }}(self) -> Callable[
def {{ method.transport_safe_name|snake_case }}(self) -> Callable[
[{{ method.input.ident }}],
{{ method.output.ident }}]:
r"""Return a callable for the{{ ' ' }}
Expand All @@ -269,13 +269,13 @@ class {{ service.name }}GrpcTransport({{ service.name }}Transport):
# the request.
# gRPC handles serialization and deserialization, so we just need
# to pass in the functions for each.
if '{{ method.name|snake_case }}' not in self._stubs:
self._stubs['{{ method.name|snake_case }}'] = self.grpc_channel.{{ method.grpc_stub_type }}(
if '{{ method.transport_safe_name|snake_case }}' not in self._stubs:
self._stubs['{{ method.transport_safe_name|snake_case }}'] = self.grpc_channel.{{ method.grpc_stub_type }}(
'/{{ '.'.join(method.meta.address.package) }}.{{ service.name }}/{{ method.name }}',
request_serializer={{ method.input.ident }}.{% if method.input.ident.python_import.module.endswith('_pb2') %}SerializeToString{% else %}serialize{% endif %},
response_deserializer={{ method.output.ident }}.{% if method.output.ident.python_import.module.endswith('_pb2') %}FromString{% else %}deserialize{% endif %},
)
return self._stubs['{{ method.name|snake_case }}']
return self._stubs['{{ method.transport_safe_name|snake_case }}']
{% endfor %}

{% if opts.add_iam_methods %}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -250,7 +250,7 @@ class {{ service.grpc_asyncio_transport_name }}({{ service.name }}Transport):
{% for method in service.methods.values() %}

@property
def {{ method.name|snake_case }}(self) -> Callable[
def {{ method.transport_safe_name|snake_case }}(self) -> Callable[
[{{ method.input.ident }}],
Awaitable[{{ method.output.ident }}]]:
r"""Return a callable for the{{ ' ' }}
Expand All @@ -270,13 +270,13 @@ class {{ service.grpc_asyncio_transport_name }}({{ service.name }}Transport):
# the request.
# gRPC handles serialization and deserialization, so we just need
# to pass in the functions for each.
if '{{ method.name|snake_case }}' not in self._stubs:
self._stubs['{{ method.name|snake_case }}'] = self.grpc_channel.{{ method.grpc_stub_type }}(
if '{{ method.transport_safe_name|snake_case }}' not in self._stubs:
self._stubs['{{ method.transport_safe_name|snake_case }}'] = self.grpc_channel.{{ method.grpc_stub_type }}(
'/{{ '.'.join(method.meta.address.package) }}.{{ service.name }}/{{ method.name }}',
request_serializer={{ method.input.ident }}.{% if method.input.ident.python_import.module.endswith('_pb2') %}SerializeToString{% else %}serialize{% endif %},
response_deserializer={{ method.output.ident }}.{% if method.output.ident.python_import.module.endswith('_pb2') %}FromString{% else %}deserialize{% endif %},
)
return self._stubs['{{ method.name|snake_case }}']
return self._stubs['{{ method.transport_safe_name|snake_case }}']
{% endfor %}

{% if opts.add_iam_methods %}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -396,12 +396,12 @@ class {{service.name}}RestTransport({{service.name}}Transport):
{% for method in service.methods.values()|sort(attribute="name") %}

@property
def {{method.name | snake_case}}(self) -> Callable[
def {{method.transport_safe_name|snake_case}}(self) -> Callable[
[{{method.input.ident}}],
{{method.output.ident}}]:
stub = self._STUBS.get("{{method.name | snake_case}}")
stub = self._STUBS.get("{{method.transport_safe_name|snake_case}}")
if not stub:
stub = self._STUBS["{{method.name | snake_case}}"] = self._{{method.name}}(self._session, self._host, self._interceptor)
stub = self._STUBS["{{method.transport_safe_name|snake_case}}"] = self._{{method.name}}(self._session, self._host, self._interceptor)

# The return type is fine, but mypy isn't sophisticated enough to determine what's going on here.
# In C++ this would require a dynamic_cast
Expand Down
Loading

0 comments on commit ae9a5c8

Please sign in to comment.