Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: disambiguate method names that are reserved in transport classes #1187

Merged
merged 1 commit into from
Feb 3, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -443,7 +443,7 @@ class {{ service.client_name }}(metaclass={{ service.client_name }}Meta):

# Wrap the RPC method; this adds retry and timeout information,
# and friendly error handling.
rpc = self._transport._wrapped_methods[self._transport.{{ method.name|snake_case}}]
rpc = self._transport._wrapped_methods[self._transport.{{ method.transport_safe_name|snake_case}}]
{% if method.field_headers %}

# Certain fields should be provided within the metadata header;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -120,8 +120,8 @@ class {{ service.name }}Transport(abc.ABC):
# Precompute the wrapped methods.
self._wrapped_methods = {
{% for method in service.methods.values() %}
self.{{ method.name|snake_case }}: gapic_v1.method.wrap_method(
self.{{ method.name|snake_case }},
self.{{ method.transport_safe_name|snake_case }}: gapic_v1.method.wrap_method(
self.{{ method.transport_safe_name|snake_case }},
{% if method.retry %}
default_retry=retries.Retry(
{% if method.retry.initial_backoff %}initial={{ method.retry.initial_backoff }},{% endif %}
Expand Down Expand Up @@ -160,7 +160,7 @@ class {{ service.name }}Transport(abc.ABC):
{% for method in service.methods.values() %}

@property
def {{ method.name|snake_case }}(self) -> Callable[
def {{ method.transport_safe_name|snake_case }}(self) -> Callable[
[{{ method.input.ident }}],
Union[
{{ method.output.ident }},
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -249,7 +249,7 @@ class {{ service.name }}GrpcTransport({{ service.name }}Transport):
{% for method in service.methods.values() %}

@property
def {{ method.name|snake_case }}(self) -> Callable[
def {{ method.transport_safe_name|snake_case }}(self) -> Callable[
[{{ method.input.ident }}],
{{ method.output.ident }}]:
r"""Return a callable for the{{ ' ' }}
Expand All @@ -269,13 +269,13 @@ class {{ service.name }}GrpcTransport({{ service.name }}Transport):
# the request.
# gRPC handles serialization and deserialization, so we just need
# to pass in the functions for each.
if '{{ method.name|snake_case }}' not in self._stubs:
self._stubs['{{ method.name|snake_case }}'] = self.grpc_channel.{{ method.grpc_stub_type }}(
if '{{ method.transport_safe_name|snake_case }}' not in self._stubs:
self._stubs['{{ method.transport_safe_name|snake_case }}'] = self.grpc_channel.{{ method.grpc_stub_type }}(
'/{{ '.'.join(method.meta.address.package) }}.{{ service.name }}/{{ method.name }}',
request_serializer={{ method.input.ident }}.{% if method.input.ident.python_import.module.endswith('_pb2') %}SerializeToString{% else %}serialize{% endif %},
response_deserializer={{ method.output.ident }}.{% if method.output.ident.python_import.module.endswith('_pb2') %}FromString{% else %}deserialize{% endif %},
)
return self._stubs['{{ method.name|snake_case }}']
return self._stubs['{{ method.transport_safe_name|snake_case }}']
{% endfor %}

{% if opts.add_iam_methods %}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -396,12 +396,12 @@ class {{service.name}}RestTransport({{service.name}}Transport):
{% for method in service.methods.values()|sort(attribute="name") %}

@property
def {{method.name | snake_case}}(self) -> Callable[
def {{method.transport_safe_name | snake_case}}(self) -> Callable[
[{{method.input.ident}}],
{{method.output.ident}}]:
stub = self._STUBS.get("{{method.name | snake_case}}")
stub = self._STUBS.get("{{method.transport_safe_name | snake_case}}")
if not stub:
stub = self._STUBS["{{method.name | snake_case}}"] = self._{{method.name}}(self._session, self._host, self._interceptor)
stub = self._STUBS["{{method.transport_safe_name | snake_case}}"] = self._{{method.name}}(self._session, self._host, self._interceptor)

# The return type is fine, but mypy isn't sophisticated enough to determine what's going on here.
# In C++ this would require a dynamic_cast
Expand Down
2 changes: 1 addition & 1 deletion gapic/ads-templates/setup.py.j2
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ setuptools.setup(
install_requires=(
{# TODO(dovs): remove when 1.x deprecation is complete #}
{% if 'rest' in opts.transport %}
'google-api-core[grpc] >= 2.1.0, < 3.0.0dev',
'google-api-core[grpc] >= 2.4.0, < 3.0.0dev',
{% else %}
'google-api-core[grpc] >= 1.28.0, < 3.0.0dev',
{% endif %}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -494,7 +494,7 @@ def test_{{ method_name }}(request_type, transport: str = 'grpc'):

# Mock the actual call within the gRPC stub, and fake the request.
with mock.patch.object(
type(client.transport.{{ method.name|snake_case }}),
type(client.transport.{{ method.transport_safe_name|snake_case }}),
'__call__') as call:
# Designate an appropriate return value for the call.
{% if method.void %}
Expand Down Expand Up @@ -571,7 +571,7 @@ def test_{{ method_name }}_empty_call():

# Mock the actual call within the gRPC stub, and fake the request.
with mock.patch.object(
type(client.transport.{{ method.name|snake_case }}),
type(client.transport.{{ method.transport_safe_name|snake_case }}),
'__call__') as call:
client.{{ method_name }}()
call.assert_called()
Expand Down Expand Up @@ -600,7 +600,7 @@ def test_{{ method_name }}_field_headers():

# Mock the actual call within the gRPC stub, and fake the request.
with mock.patch.object(
type(client.transport.{{ method.name|snake_case }}),
type(client.transport.{{ method.transport_safe_name|snake_case }}),
'__call__') as call:
{% if method.void %}
call.return_value = None
Expand Down Expand Up @@ -638,7 +638,7 @@ def test_{{ method_name }}_from_dict_foreign():
)
# Mock the actual call within the gRPC stub, and fake the request.
with mock.patch.object(
type(client.transport.{{ method.name|snake_case }}),
type(client.transport.{{ method.transport_safe_name|snake_case }}),
'__call__') as call:
# Designate an appropriate return value for the call.
{% if method.void %}
Expand Down Expand Up @@ -668,7 +668,7 @@ def test_{{ method_name }}_flattened():

# Mock the actual call within the gRPC stub, and fake the request.
with mock.patch.object(
type(client.transport.{{ method.name|snake_case }}),
type(client.transport.{{ method.transport_safe_name|snake_case }}),
'__call__') as call:
# Designate an appropriate return value for the call.
{% if method.void %}
Expand Down Expand Up @@ -746,7 +746,7 @@ def test_{{ method_name }}_pager(transport_name: str = "grpc"):

# Mock the actual call within the gRPC stub, and fake the request.
with mock.patch.object(
type(client.transport.{{ method.name|snake_case }}),
type(client.transport.{{ method.transport_safe_name|snake_case }}),
'__call__') as call:
# Set the response to a series of pages.
call.side_effect = (
Expand Down Expand Up @@ -808,7 +808,7 @@ def test_{{ method_name }}_pages(transport_name: str = "grpc"):

# Mock the actual call within the gRPC stub, and fake the request.
with mock.patch.object(
type(client.transport.{{ method.name|snake_case }}),
type(client.transport.{{ method.transport_safe_name|snake_case }}),
'__call__') as call:
# Set the response to a series of pages.
{% if method.paged_result_field.map%}
Expand Down Expand Up @@ -1184,7 +1184,7 @@ def test_{{ method_name }}_rest_required_fields(request_type={{ method.input.ide
def test_{{ method_name }}_rest_unset_required_fields():
transport = transports.{{ service.rest_transport_name }}(credentials=ga_credentials.AnonymousCredentials)

unset_fields = transport.{{ method.name|snake_case }}._get_unset_required_fields({})
unset_fields = transport.{{ method.transport_safe_name|snake_case }}._get_unset_required_fields({})
assert set(unset_fields) == (set(({% for param in method.query_params|sort %}"{{ param|camel_case }}", {% endfor %})) & set(({% for param in method.input.required_fields %}"{{param.name|camel_case}}", {% endfor %})))


Expand Down Expand Up @@ -1645,7 +1645,7 @@ def test_{{ service.name|snake_case }}_base_transport():
# raise NotImplementedError.
methods = (
{% for method in service.methods.values() %}
'{{ method.name|snake_case }}',
'{{ method.transport_safe_name|snake_case }}',
{% endfor %}
{% if opts.add_iam_methods %}
'set_iam_policy',
Expand Down
15 changes: 15 additions & 0 deletions gapic/schema/wrappers.py
Original file line number Diff line number Diff line change
Expand Up @@ -948,6 +948,21 @@ class Method:
def __getattr__(self, name):
return getattr(self.method_pb, name)

@property
def transport_safe_name(self) -> str:
# These names conflict with other methods in the transport.
# We don't want to disambiguate the names at the client level
# because the disambiguated name is less convenient and user friendly.
#
# Note: this should really be a class variable,
# but python 3.6 can't handle that.
TRANSPORT_UNSAFE_NAMES = {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks like it is the second in a row disambiguating PR (the other one: #1178). Can we (should we) have some sort of centralized name-conflict resolution mechanism? I think it is fine for this PR to be as is, but in case more things like that happen, it would be good to have a unified name-conflict resolution unility class or something like that. WDYT?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Different names that refer to different data structures need to be disambiguated against different sets of names. It doesn't matter too much if there's a field named create_channel. I'd argue that there are costs to being too aggressive disambiguating names, and that it makes more sense to deal with them in the Method or Field class when collisions occur.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I agree that they have to be disamgiguated against different sets of names, but that set of names can be a standardized parameter to the generic disambiguating mechanism. Basically some mechanism though which all names go through, and get disambiguated depending on the context. But yeah, this is beyond the scope of this PR and we have many more important tasks to work on.

"CreateChannel",
"GrpcChannel",
"OperationsClient",
}
return f"{self.name}_" if self.name in TRANSPORT_UNSAFE_NAMES else self.name

@property
def is_operation_polling_method(self):
return self.output.is_extended_operation and self.options.Extensions[ex_ops_pb2.operation_polling_method]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -92,15 +92,15 @@ class {{ service.async_client_name }}:

@classmethod
def get_mtls_endpoint_and_cert_source(cls, client_options: Optional[ClientOptions] = None):
"""Return the API endpoint and client cert source for mutual TLS.
"""Return the API endpoint and client cert source for mutual TLS.

The client cert source is determined in the following order:
(1) if `GOOGLE_API_USE_CLIENT_CERTIFICATE` environment variable is not "true", the
client cert source is None.
(2) if `client_options.client_cert_source` is provided, use the provided one; if the
default client cert source exists, use the default one; otherwise the client cert
source is None.

The API endpoint is determined in the following order:
(1) if `client_options.api_endpoint` if provided, use the provided one.
(2) if `GOOGLE_API_USE_CLIENT_CERTIFICATE` environment variable is "always", use the
Expand All @@ -118,7 +118,7 @@ class {{ service.async_client_name }}:
Returns:
Tuple[str, Callable[[], Tuple[bytes, bytes]]]: returns the API endpoint and the
client cert source to use.

Raises:
google.auth.exceptions.MutualTLSChannelError: If any errors happen.
"""
Expand Down Expand Up @@ -302,7 +302,7 @@ class {{ service.async_client_name }}:
# Wrap the RPC method; this adds retry and timeout information,
# and friendly error handling.
rpc = gapic_v1.method_async.wrap_method(
self._client._transport.{{ method.name|snake_case }},
self._client._transport.{{ method.transport_safe_name|snake_case }},
{% if method.retry %}
default_retry=retries.Retry(
{% if method.retry.initial_backoff %}initial={{ method.retry.initial_backoff }},{% endif %}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -477,7 +477,7 @@ class {{ service.client_name }}(metaclass={{ service.client_name }}Meta):

# Wrap the RPC method; this adds retry and timeout information,
# and friendly error handling.
rpc = self._transport._wrapped_methods[self._transport.{{ method.name|snake_case}}]
rpc = self._transport._wrapped_methods[self._transport.{{ method.transport_safe_name|snake_case}}]

{% if method.explicit_routing %}
header_params = {}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -120,8 +120,8 @@ class {{ service.name }}Transport(abc.ABC):
# Precompute the wrapped methods.
self._wrapped_methods = {
{% for method in service.methods.values() %}
self.{{ method.name|snake_case }}: gapic_v1.method.wrap_method(
self.{{ method.name|snake_case }},
self.{{ method.transport_safe_name|snake_case }}: gapic_v1.method.wrap_method(
self.{{ method.transport_safe_name|snake_case }},
{% if method.retry %}
default_retry=retries.Retry(
{% if method.retry.initial_backoff %}initial={{ method.retry.initial_backoff }},{% endif %}
Expand Down Expand Up @@ -160,7 +160,7 @@ class {{ service.name }}Transport(abc.ABC):
{% for method in service.methods.values() %}

@property
def {{ method.name|snake_case }}(self) -> Callable[
def {{ method.transport_safe_name|snake_case }}(self) -> Callable[
[{{ method.input.ident }}],
Union[
{{ method.output.ident }},
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -249,7 +249,7 @@ class {{ service.name }}GrpcTransport({{ service.name }}Transport):
{% for method in service.methods.values() %}

@property
def {{ method.name|snake_case }}(self) -> Callable[
def {{ method.transport_safe_name|snake_case }}(self) -> Callable[
[{{ method.input.ident }}],
{{ method.output.ident }}]:
r"""Return a callable for the{{ ' ' }}
Expand All @@ -269,13 +269,13 @@ class {{ service.name }}GrpcTransport({{ service.name }}Transport):
# the request.
# gRPC handles serialization and deserialization, so we just need
# to pass in the functions for each.
if '{{ method.name|snake_case }}' not in self._stubs:
self._stubs['{{ method.name|snake_case }}'] = self.grpc_channel.{{ method.grpc_stub_type }}(
if '{{ method.transport_safe_name|snake_case }}' not in self._stubs:
self._stubs['{{ method.transport_safe_name|snake_case }}'] = self.grpc_channel.{{ method.grpc_stub_type }}(
'/{{ '.'.join(method.meta.address.package) }}.{{ service.name }}/{{ method.name }}',
request_serializer={{ method.input.ident }}.{% if method.input.ident.python_import.module.endswith('_pb2') %}SerializeToString{% else %}serialize{% endif %},
response_deserializer={{ method.output.ident }}.{% if method.output.ident.python_import.module.endswith('_pb2') %}FromString{% else %}deserialize{% endif %},
)
return self._stubs['{{ method.name|snake_case }}']
return self._stubs['{{ method.transport_safe_name|snake_case }}']
{% endfor %}

{% if opts.add_iam_methods %}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -250,7 +250,7 @@ class {{ service.grpc_asyncio_transport_name }}({{ service.name }}Transport):
{% for method in service.methods.values() %}

@property
def {{ method.name|snake_case }}(self) -> Callable[
def {{ method.transport_safe_name|snake_case }}(self) -> Callable[
[{{ method.input.ident }}],
Awaitable[{{ method.output.ident }}]]:
r"""Return a callable for the{{ ' ' }}
Expand All @@ -270,13 +270,13 @@ class {{ service.grpc_asyncio_transport_name }}({{ service.name }}Transport):
# the request.
# gRPC handles serialization and deserialization, so we just need
# to pass in the functions for each.
if '{{ method.name|snake_case }}' not in self._stubs:
self._stubs['{{ method.name|snake_case }}'] = self.grpc_channel.{{ method.grpc_stub_type }}(
if '{{ method.transport_safe_name|snake_case }}' not in self._stubs:
self._stubs['{{ method.transport_safe_name|snake_case }}'] = self.grpc_channel.{{ method.grpc_stub_type }}(
'/{{ '.'.join(method.meta.address.package) }}.{{ service.name }}/{{ method.name }}',
request_serializer={{ method.input.ident }}.{% if method.input.ident.python_import.module.endswith('_pb2') %}SerializeToString{% else %}serialize{% endif %},
response_deserializer={{ method.output.ident }}.{% if method.output.ident.python_import.module.endswith('_pb2') %}FromString{% else %}deserialize{% endif %},
)
return self._stubs['{{ method.name|snake_case }}']
return self._stubs['{{ method.transport_safe_name|snake_case }}']
{% endfor %}

{% if opts.add_iam_methods %}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -396,12 +396,12 @@ class {{service.name}}RestTransport({{service.name}}Transport):
{% for method in service.methods.values()|sort(attribute="name") %}

@property
def {{method.name | snake_case}}(self) -> Callable[
def {{method.transport_safe_name|snake_case}}(self) -> Callable[
[{{method.input.ident}}],
{{method.output.ident}}]:
stub = self._STUBS.get("{{method.name | snake_case}}")
stub = self._STUBS.get("{{method.transport_safe_name|snake_case}}")
if not stub:
stub = self._STUBS["{{method.name | snake_case}}"] = self._{{method.name}}(self._session, self._host, self._interceptor)
stub = self._STUBS["{{method.transport_safe_name|snake_case}}"] = self._{{method.name}}(self._session, self._host, self._interceptor)

# The return type is fine, but mypy isn't sophisticated enough to determine what's going on here.
# In C++ this would require a dynamic_cast
Expand Down
Loading