Skip to content

Commit

Permalink
fix deserialization for lro which has discriminator
Browse files Browse the repository at this point in the history
  • Loading branch information
msyyc committed May 20, 2024
1 parent 70826da commit b4e7fff
Show file tree
Hide file tree
Showing 75 changed files with 2,757 additions and 1,280 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -1320,6 +1320,7 @@ def initial_call(self, builder: LROOperationType) -> List[str]:
[f" {parameter.client_name}={parameter.client_name}," for parameter in builder.parameters.method]
)
retval.append(" cls=lambda x,y,z: x,")
retval.append(" stream=True,")
retval.append(" headers=_headers,")
retval.append(" params=_params,")
retval.append(" **kwargs")
Expand Down
4 changes: 4 additions & 0 deletions packages/autorest.python/autorest/preprocess/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -419,6 +419,10 @@ def update_lro_operation(
self._update_lro_operation_helper(overload)
self.update_operation(code_model, overload["initialOperation"], is_overload=True)

# for lro initial reponse, there is no need to make deserialization so we mark it
# as stream operation by default which will not make deserialization by default
yaml_data["initialOperation"]["exposeStreamKeyword"] = True

def update_paging_operation(
self,
code_model: Dict[str, Any],
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ def _basic_polling_initial(self, product: Optional[Union[JSON, IO[bytes]]] = Non
)
_request.url = self._client.format_url(_request.url)

_stream = False
_stream = kwargs.pop("stream", False)
pipeline_response: PipelineResponse = self._client._pipeline.run( # pylint: disable=protected-access
_request, stream=_stream, **kwargs
)
Expand All @@ -124,10 +124,13 @@ def _basic_polling_initial(self, product: Optional[Union[JSON, IO[bytes]]] = Non

deserialized = None
if response.status_code == 200:
if response.content:
deserialized = response.json()
if _stream:
deserialized = response.iter_bytes()
else:
deserialized = None
if response.content:
deserialized = response.json()
else:
deserialized = None

if cls:
return cls(pipeline_response, deserialized, {}) # type: ignore
Expand Down Expand Up @@ -241,6 +244,7 @@ def begin_basic_polling(
product=product,
content_type=content_type,
cls=lambda x, y, z: x,
stream=True,
headers=_headers,
params=_params,
**kwargs
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ async def _basic_polling_initial(
)
_request.url = self._client.format_url(_request.url)

_stream = False
_stream = kwargs.pop("stream", False)
pipeline_response: PipelineResponse = await self._client._pipeline.run( # type: ignore # pylint: disable=protected-access
_request, stream=_stream, **kwargs
)
Expand All @@ -97,10 +97,13 @@ async def _basic_polling_initial(

deserialized = None
if response.status_code == 200:
if response.content:
deserialized = response.json()
if _stream:
deserialized = response.iter_bytes()
else:
deserialized = None
if response.content:
deserialized = response.json()
else:
deserialized = None

if cls:
return cls(pipeline_response, deserialized, {}) # type: ignore
Expand Down Expand Up @@ -214,6 +217,7 @@ async def begin_basic_polling(
product=product,
content_type=content_type,
cls=lambda x, y, z: x,
stream=True,
headers=_headers,
params=_params,
**kwargs
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,10 @@
# --------------------------------------------------------------------------

from ._multiapi_service_client import MultiapiServiceClient

__all__ = ["MultiapiServiceClient"]
__all__ = ['MultiapiServiceClient']

try:
from ._patch import patch_sdk # type: ignore

patch_sdk()
except ImportError:
pass
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@
# pylint: disable=unused-import,ungrouped-imports
from azure.core.credentials import TokenCredential


class MultiapiServiceClientConfiguration:
"""Configuration for MultiapiServiceClient.
Expand All @@ -30,27 +29,32 @@ class MultiapiServiceClientConfiguration:
:type credential: ~azure.core.credentials.TokenCredential
"""

def __init__(self, credential: "TokenCredential", **kwargs: Any):
def __init__(
self,
credential: "TokenCredential",
**kwargs: Any
):
if credential is None:
raise ValueError("Parameter 'credential' must not be None.")

self.credential = credential
self.credential_scopes = kwargs.pop("credential_scopes", ["https://management.azure.com/.default"])
kwargs.setdefault("sdk_moniker", "azure-multiapi-sample/{}".format(VERSION))
self.credential_scopes = kwargs.pop('credential_scopes', ['https://management.azure.com/.default'])
kwargs.setdefault('sdk_moniker', 'azure-multiapi-sample/{}'.format(VERSION))
self.polling_interval = kwargs.get("polling_interval", 30)
self._configure(**kwargs)

def _configure(self, **kwargs: Any):
self.user_agent_policy = kwargs.get("user_agent_policy") or policies.UserAgentPolicy(**kwargs)
self.headers_policy = kwargs.get("headers_policy") or policies.HeadersPolicy(**kwargs)
self.proxy_policy = kwargs.get("proxy_policy") or policies.ProxyPolicy(**kwargs)
self.logging_policy = kwargs.get("logging_policy") or policies.NetworkTraceLoggingPolicy(**kwargs)
self.http_logging_policy = kwargs.get("http_logging_policy") or ARMHttpLoggingPolicy(**kwargs)
self.retry_policy = kwargs.get("retry_policy") or policies.RetryPolicy(**kwargs)
self.custom_hook_policy = kwargs.get("custom_hook_policy") or policies.CustomHookPolicy(**kwargs)
self.redirect_policy = kwargs.get("redirect_policy") or policies.RedirectPolicy(**kwargs)
self.authentication_policy = kwargs.get("authentication_policy")
def _configure(
self,
**kwargs: Any
):
self.user_agent_policy = kwargs.get('user_agent_policy') or policies.UserAgentPolicy(**kwargs)
self.headers_policy = kwargs.get('headers_policy') or policies.HeadersPolicy(**kwargs)
self.proxy_policy = kwargs.get('proxy_policy') or policies.ProxyPolicy(**kwargs)
self.logging_policy = kwargs.get('logging_policy') or policies.NetworkTraceLoggingPolicy(**kwargs)
self.http_logging_policy = kwargs.get('http_logging_policy') or ARMHttpLoggingPolicy(**kwargs)
self.retry_policy = kwargs.get('retry_policy') or policies.RetryPolicy(**kwargs)
self.custom_hook_policy = kwargs.get('custom_hook_policy') or policies.CustomHookPolicy(**kwargs)
self.redirect_policy = kwargs.get('redirect_policy') or policies.RedirectPolicy(**kwargs)
self.authentication_policy = kwargs.get('authentication_policy')
if self.credential and not self.authentication_policy:
self.authentication_policy = ARMChallengeAuthenticationPolicy(
self.credential, *self.credential_scopes, **kwargs
)
self.authentication_policy = ARMChallengeAuthenticationPolicy(self.credential, *self.credential_scopes, **kwargs)
Original file line number Diff line number Diff line change
Expand Up @@ -25,15 +25,13 @@
# pylint: disable=unused-import,ungrouped-imports
from azure.core.credentials import TokenCredential


class _SDKClient(object):
def __init__(self, *args, **kwargs):
"""This is a fake class to support current implemetation of MultiApiClientMixin."
Will be removed in final version of multiapi azure-core based client
"""
pass


class MultiapiServiceClient(MultiapiServiceClientOperationsMixin, MultiApiClientMixin, _SDKClient):
"""Service client for multiapi client testing.
Expand All @@ -56,30 +54,28 @@ class MultiapiServiceClient(MultiapiServiceClientOperationsMixin, MultiApiClient
:keyword int polling_interval: Default waiting time between two polls for LRO operations if no Retry-After header is present.
"""

DEFAULT_API_VERSION = "3.0.0"
DEFAULT_API_VERSION = '3.0.0'
_PROFILE_TAG = "azure.multiapi.sample.MultiapiServiceClient"
LATEST_PROFILE = ProfileDefinition(
{
_PROFILE_TAG: {
None: DEFAULT_API_VERSION,
"begin_test_lro": "1.0.0",
"begin_test_lro_and_paging": "1.0.0",
"test_one": "2.0.0",
}
},
_PROFILE_TAG + " latest",
LATEST_PROFILE = ProfileDefinition({
_PROFILE_TAG: {
None: DEFAULT_API_VERSION,
'begin_test_lro': '1.0.0',
'begin_test_lro_and_paging': '1.0.0',
'test_one': '2.0.0',
}},
_PROFILE_TAG + " latest"
)

def __init__(
self,
credential: "TokenCredential",
api_version: Optional[str] = None,
api_version: Optional[str]=None,
base_url: str = "http://localhost:3000",
profile: KnownProfiles = KnownProfiles.default,
profile: KnownProfiles=KnownProfiles.default,
**kwargs: Any
):
if api_version:
kwargs.setdefault("api_version", api_version)
kwargs.setdefault('api_version', api_version)
self._config = MultiapiServiceClientConfiguration(credential, **kwargs)
_policies = kwargs.pop("policies", None)
if _policies is None:
Expand All @@ -100,7 +96,10 @@ def __init__(
self._config.http_logging_policy,
]
self._client = ARMPipelineClient(base_url=base_url, policies=_policies, **kwargs)
super(MultiapiServiceClient, self).__init__(api_version=api_version, profile=profile)
super(MultiapiServiceClient, self).__init__(
api_version=api_version,
profile=profile
)

@classmethod
def _models_dict(cls, api_version):
Expand All @@ -110,79 +109,62 @@ def _models_dict(cls, api_version):
def models(cls, api_version=DEFAULT_API_VERSION):
"""Module depends on the API version:
* 1.0.0: :mod:`v1.models<azure.multiapi.sample.v1.models>`
* 2.0.0: :mod:`v2.models<azure.multiapi.sample.v2.models>`
* 3.0.0: :mod:`v3.models<azure.multiapi.sample.v3.models>`
* 1.0.0: :mod:`v1.models<azure.multiapi.sample.v1.models>`
* 2.0.0: :mod:`v2.models<azure.multiapi.sample.v2.models>`
* 3.0.0: :mod:`v3.models<azure.multiapi.sample.v3.models>`
"""
if api_version == "1.0.0":
if api_version == '1.0.0':
from .v1 import models

return models
elif api_version == "2.0.0":
elif api_version == '2.0.0':
from .v2 import models

return models
elif api_version == "3.0.0":
elif api_version == '3.0.0':
from .v3 import models

return models
raise ValueError("API version {} is not available".format(api_version))

@property
def operation_group_one(self):
"""Instance depends on the API version:
* 1.0.0: :class:`OperationGroupOneOperations<azure.multiapi.sample.v1.operations.OperationGroupOneOperations>`
* 2.0.0: :class:`OperationGroupOneOperations<azure.multiapi.sample.v2.operations.OperationGroupOneOperations>`
* 3.0.0: :class:`OperationGroupOneOperations<azure.multiapi.sample.v3.operations.OperationGroupOneOperations>`
* 1.0.0: :class:`OperationGroupOneOperations<azure.multiapi.sample.v1.operations.OperationGroupOneOperations>`
* 2.0.0: :class:`OperationGroupOneOperations<azure.multiapi.sample.v2.operations.OperationGroupOneOperations>`
* 3.0.0: :class:`OperationGroupOneOperations<azure.multiapi.sample.v3.operations.OperationGroupOneOperations>`
"""
api_version = self._get_api_version("operation_group_one")
if api_version == "1.0.0":
api_version = self._get_api_version('operation_group_one')
if api_version == '1.0.0':
from .v1.operations import OperationGroupOneOperations as OperationClass
elif api_version == "2.0.0":
elif api_version == '2.0.0':
from .v2.operations import OperationGroupOneOperations as OperationClass
elif api_version == "3.0.0":
elif api_version == '3.0.0':
from .v3.operations import OperationGroupOneOperations as OperationClass
else:
raise ValueError("API version {} does not have operation group 'operation_group_one'".format(api_version))
self._config.api_version = api_version
return OperationClass(
self._client,
self._config,
Serializer(self._models_dict(api_version)),
Deserializer(self._models_dict(api_version)),
api_version,
)
return OperationClass(self._client, self._config, Serializer(self._models_dict(api_version)), Deserializer(self._models_dict(api_version)), api_version)

@property
def operation_group_two(self):
"""Instance depends on the API version:
* 2.0.0: :class:`OperationGroupTwoOperations<azure.multiapi.sample.v2.operations.OperationGroupTwoOperations>`
* 3.0.0: :class:`OperationGroupTwoOperations<azure.multiapi.sample.v3.operations.OperationGroupTwoOperations>`
* 2.0.0: :class:`OperationGroupTwoOperations<azure.multiapi.sample.v2.operations.OperationGroupTwoOperations>`
* 3.0.0: :class:`OperationGroupTwoOperations<azure.multiapi.sample.v3.operations.OperationGroupTwoOperations>`
"""
api_version = self._get_api_version("operation_group_two")
if api_version == "2.0.0":
api_version = self._get_api_version('operation_group_two')
if api_version == '2.0.0':
from .v2.operations import OperationGroupTwoOperations as OperationClass
elif api_version == "3.0.0":
elif api_version == '3.0.0':
from .v3.operations import OperationGroupTwoOperations as OperationClass
else:
raise ValueError("API version {} does not have operation group 'operation_group_two'".format(api_version))
self._config.api_version = api_version
return OperationClass(
self._client,
self._config,
Serializer(self._models_dict(api_version)),
Deserializer(self._models_dict(api_version)),
api_version,
)
return OperationClass(self._client, self._config, Serializer(self._models_dict(api_version)), Deserializer(self._models_dict(api_version)), api_version)

def close(self):
self._client.close()

def __enter__(self):
self._client.__enter__()
return self

def __exit__(self, *exc_details):
self._client.__exit__(*exc_details)
Loading

0 comments on commit b4e7fff

Please sign in to comment.