Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix deserialization error for LRO which has discriminator #2589

Closed
wants to merge 52 commits into from
Closed
Show file tree
Hide file tree
Changes from 40 commits
Commits
Show all changes
52 commits
Select commit Hold shift + click to select a range
4c12c62
code
msyyc May 22, 2024
49ab757
fix for legacy test
msyyc May 22, 2024
3c22f8e
inv and black
msyyc May 22, 2024
7adc1c8
fix mypy
msyyc May 22, 2024
ca98be8
fix pyright error
msyyc May 22, 2024
e142d48
fix pylint
msyyc May 22, 2024
1857677
inv
msyyc May 22, 2024
dccdd16
update
msyyc May 23, 2024
c97c28c
review
msyyc May 23, 2024
2eea3a4
fix
msyyc May 23, 2024
a718d6a
fix
msyyc May 23, 2024
ebf0e16
Fix test
msyyc May 23, 2024
738caab
fix multiapi test
msyyc May 23, 2024
76c8d03
disable deserialize for all initial operation
msyyc May 24, 2024
d163f3d
review
msyyc May 24, 2024
133bf7b
inv
msyyc May 24, 2024
f830afe
Merge branch 'main' of https://github.com/Azure/autorest.python into …
msyyc May 24, 2024
f464d47
update changelog
msyyc May 24, 2024
0647135
inv
msyyc May 24, 2024
c0f0610
Merge branch 'deserialization-fix' of https://github.com/Azure/autore…
msyyc May 24, 2024
148e934
inv
msyyc May 24, 2024
52c13db
Merge branch 'main' into deserialization-fix
msyyc May 28, 2024
8b7f073
Merge branch 'main' into deserialization-fix
msyyc May 29, 2024
24d19a1
force initial operation to return stream
May 29, 2024
fa2c5ce
revert extra changes in builder_serializer
May 29, 2024
41f1d09
regen
May 29, 2024
3a04d08
regen lropaging
May 29, 2024
c92b238
regen with load_body for aiohttp
May 29, 2024
f71903c
fix
msyyc May 30, 2024
49d85c3
inv
msyyc May 30, 2024
1602dc1
use pipeline_response.http_response for legacy
msyyc May 30, 2024
015d844
fix test
msyyc May 30, 2024
bcbc2d3
inv
msyyc May 30, 2024
44017d5
Merge branch 'main' of https://github.com/Azure/autorest.python into …
Jun 4, 2024
f7e5be6
Merge branch 'deserialization-fix' of https://github.com/Azure/autore…
Jun 4, 2024
f0eae62
read in response
Jun 4, 2024
d8ef34f
inv
msyyc Jun 5, 2024
b086c62
fix multiapi test
msyyc Jun 5, 2024
066eb98
inv
msyyc Jun 5, 2024
cfb50f5
fix pyright
msyyc Jun 5, 2024
df5cf4c
simplify code
Jun 5, 2024
cfe3dbd
generate
Jun 5, 2024
e805530
regen
Jun 5, 2024
eb9250a
regenerate
Jun 5, 2024
f6577fb
black
Jun 5, 2024
d70e0e9
Merge branch 'main' of https://github.com/Azure/autorest.python into …
Jun 6, 2024
d71c77b
regen to revert changes
Jun 6, 2024
a3f8f22
revert changes
Jun 6, 2024
91d8643
regen
Jun 6, 2024
6681433
Merge branch 'main' of https://github.com/Azure/autorest.python into …
Jun 7, 2024
01b89e8
regen
Jun 7, 2024
c0e2522
revert tasks change
Jun 7, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
The table of contents is too big for display.
Diff view
Diff view
  •  
  •  
  •  
8 changes: 8 additions & 0 deletions .chronus/changes/deserialization-fix-2024-4-24-16-48-41.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
---
changeKind: fix
packages:
- "@autorest/python"
- "@azure-tools/typespec-python"
---

Fix deserialization error for lro when return type has discriminator and succeed in initial response
8 changes: 8 additions & 0 deletions packages/autorest.python/autorest/codegen/models/operation.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,14 @@ def __init__(
self.has_etag: bool = self.yaml_data.get("hasEtag", False)
self.cross_language_definition_id: Optional[str] = self.yaml_data.get("crossLanguageDefinitionId")

@property
def stream_value(self) -> Union[str, bool]:
return (
f'kwargs.pop("stream", {self.has_stream_response})'
if self.expose_stream_keyword and self.has_response_body
else self.has_stream_response
)

@property
def has_form_data_body(self):
return self.parameters.has_form_data_body
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -560,13 +560,8 @@ def example_template(self, builder: OperationType) -> List[str]:

def make_pipeline_call(self, builder: OperationType) -> List[str]:
type_ignore = self.async_mode and builder.group_name == "" # is in a mixin
stream_value = (
f'kwargs.pop("stream", {builder.has_stream_response})'
if builder.expose_stream_keyword and builder.has_response_body
else builder.has_stream_response
)
return [
f"_stream = {stream_value}",
f"_stream = {builder.stream_value}",
f"pipeline_response: PipelineResponse = {self._call_method}self._client.{self.pipeline_name}.run( "
+ f"{'# type: ignore' if type_ignore else ''} # pylint: disable=protected-access",
" _request,",
Expand Down Expand Up @@ -925,7 +920,7 @@ def response_headers_and_deserialization(
if self.code_model.options["models_mode"] == "msrest":
deserialize_code.append("deserialized = self._deserialize(")
deserialize_code.append(f" '{response.serialization_type}',{pylint_disable}")
deserialize_code.append(" pipeline_response")
deserialize_code.append(" pipeline_response.http_response")
deserialize_code.append(")")
elif self.code_model.options["models_mode"] == "dpg":
if builder.has_stream_response:
Expand Down Expand Up @@ -964,12 +959,11 @@ def response_headers_and_deserialization(
def handle_error_response(self, builder: OperationType) -> List[str]:
async_await = "await " if self.async_mode else ""
retval = [f"if response.status_code not in {str(builder.success_status_codes)}:"]
retval.extend(
[
" if _stream:",
f" {async_await} response.read() # Load the body in memory and close the socket",
]
)
response_read = f" {async_await}response.read() # Load the body in memory and close the socket"
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we can always call read, regardless of stream or not

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It is very strange to judge _stream when its value is True, so I think it is better to keep the logic.

if builder.stream_value is True: # _stream is True so no need to judge it
retval.append(response_read)
elif isinstance(builder.stream_value, str): # _stream is not sure, so we need to judge it
retval.extend([" if _stream:", f" {response_read}"])
type_ignore = " # type: ignore" if _need_type_ignore(builder) else ""
retval.append(
f" map_error(status_code=response.status_code, response=response, error_map=error_map){type_ignore}"
Expand Down Expand Up @@ -1218,12 +1212,15 @@ def _extract_data_callback(self, builder: PagingOperationType) -> List[str]:
response = builder.responses[0]
deserialized = "pipeline_response.http_response.json()"
if self.code_model.options["models_mode"] == "msrest":
suffix = ".http_response" if hasattr(builder, "initial_operation") else ""
deserialize_type = response.serialization_type
pylint_disable = " # pylint: disable=protected-access"
if isinstance(response.type, ModelType) and not response.type.internal:
deserialize_type = f'"{response.serialization_type}"'
pylint_disable = ""
deserialized = f"self._deserialize(\n {deserialize_type},{pylint_disable}\n pipeline_response\n)"
deserialized = (
f"self._deserialize(\n {deserialize_type},{pylint_disable}\n pipeline_response{suffix}\n)"
)
retval.append(f" deserialized = {deserialized}")
elif self.code_model.options["models_mode"] == "dpg":
# we don't want to generate paging models for DPG
Expand Down Expand Up @@ -1318,6 +1315,8 @@ def initial_call(self, builder: LROOperationType) -> List[str]:
retval.append(" params=_params,")
retval.append(" **kwargs")
retval.append(" )")
retval.append(f" {'await ' if self.async_mode else ''}raw_result.http_response.read() # type: ignore")

retval.append("kwargs.pop('error_map', None)")
return retval

Expand Down
7 changes: 7 additions & 0 deletions packages/autorest.python/autorest/preprocess/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -412,12 +412,18 @@ def update_lro_operation(
yaml_data: Dict[str, Any],
is_overload: bool = False,
) -> None:
def convert_initial_operation_response_type(data: Dict[str, Any]) -> None:
for response in data.get("responses", []):
response["type"] = KNOWN_TYPES["binary"]

self.update_operation(code_model, yaml_data, is_overload=is_overload)
self.update_operation(code_model, yaml_data["initialOperation"], is_overload=is_overload)
convert_initial_operation_response_type(yaml_data["initialOperation"])
self._update_lro_operation_helper(yaml_data)
for overload in yaml_data.get("overloads", []):
self._update_lro_operation_helper(overload)
self.update_operation(code_model, overload["initialOperation"], is_overload=True)
convert_initial_operation_response_type(overload["initialOperation"])

def update_paging_operation(
self,
Expand Down Expand Up @@ -466,6 +472,7 @@ def update_operation_groups(self, code_model: Dict[str, Any], client: Dict[str,
def update_yaml(self, yaml_data: Dict[str, Any]) -> None:
"""Convert in place the YAML str."""
self.update_types(yaml_data["types"])
yaml_data["types"] += KNOWN_TYPES.values()
for client in yaml_data["clients"]:
self.update_client(client)
self.update_operation_groups(yaml_data, client)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -87,8 +87,6 @@ async def head200(self, **kwargs: Any) -> None: # pylint: disable=inconsistent-
response = pipeline_response.http_response

if response.status_code not in [200, 404]:
if _stream:
await response.read() # Load the body in memory and close the socket
map_error(status_code=response.status_code, response=response, error_map=error_map)
raise HttpResponseError(response=response)

Expand Down Expand Up @@ -130,8 +128,6 @@ async def head204(self, **kwargs: Any) -> None: # pylint: disable=inconsistent-
response = pipeline_response.http_response

if response.status_code not in [204, 404]:
if _stream:
await response.read() # Load the body in memory and close the socket
map_error(status_code=response.status_code, response=response, error_map=error_map)
raise HttpResponseError(response=response)

Expand Down Expand Up @@ -173,8 +169,6 @@ async def head404(self, **kwargs: Any) -> None: # pylint: disable=inconsistent-
response = pipeline_response.http_response

if response.status_code not in [204, 404]:
if _stream:
await response.read() # Load the body in memory and close the socket
map_error(status_code=response.status_code, response=response, error_map=error_map)
raise HttpResponseError(response=response)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -107,8 +107,6 @@ def head200(self, **kwargs: Any) -> None: # pylint: disable=inconsistent-return
response = pipeline_response.http_response

if response.status_code not in [200, 404]:
if _stream:
response.read() # Load the body in memory and close the socket
map_error(status_code=response.status_code, response=response, error_map=error_map)
raise HttpResponseError(response=response)

Expand Down Expand Up @@ -150,8 +148,6 @@ def head204(self, **kwargs: Any) -> None: # pylint: disable=inconsistent-return
response = pipeline_response.http_response

if response.status_code not in [204, 404]:
if _stream:
response.read() # Load the body in memory and close the socket
map_error(status_code=response.status_code, response=response, error_map=error_map)
raise HttpResponseError(response=response)

Expand Down Expand Up @@ -193,8 +189,6 @@ def head404(self, **kwargs: Any) -> None: # pylint: disable=inconsistent-return
response = pipeline_response.http_response

if response.status_code not in [204, 404]:
if _stream:
response.read() # Load the body in memory and close the socket
map_error(status_code=response.status_code, response=response, error_map=error_map)
raise HttpResponseError(response=response)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -87,8 +87,6 @@ async def head200(self, **kwargs: Any) -> None: # pylint: disable=inconsistent-
response = pipeline_response.http_response

if response.status_code not in [200, 404]:
if _stream:
await response.read() # Load the body in memory and close the socket
map_error(status_code=response.status_code, response=response, error_map=error_map)
raise HttpResponseError(response=response)

Expand Down Expand Up @@ -130,8 +128,6 @@ async def head204(self, **kwargs: Any) -> None: # pylint: disable=inconsistent-
response = pipeline_response.http_response

if response.status_code not in [204, 404]:
if _stream:
await response.read() # Load the body in memory and close the socket
map_error(status_code=response.status_code, response=response, error_map=error_map)
raise HttpResponseError(response=response)

Expand Down Expand Up @@ -173,8 +169,6 @@ async def head404(self, **kwargs: Any) -> None: # pylint: disable=inconsistent-
response = pipeline_response.http_response

if response.status_code not in [204, 404]:
if _stream:
await response.read() # Load the body in memory and close the socket
map_error(status_code=response.status_code, response=response, error_map=error_map)
raise HttpResponseError(response=response)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -107,8 +107,6 @@ def head200(self, **kwargs: Any) -> None: # pylint: disable=inconsistent-return
response = pipeline_response.http_response

if response.status_code not in [200, 404]:
if _stream:
response.read() # Load the body in memory and close the socket
map_error(status_code=response.status_code, response=response, error_map=error_map)
raise HttpResponseError(response=response)

Expand Down Expand Up @@ -150,8 +148,6 @@ def head204(self, **kwargs: Any) -> None: # pylint: disable=inconsistent-return
response = pipeline_response.http_response

if response.status_code not in [204, 404]:
if _stream:
response.read() # Load the body in memory and close the socket
map_error(status_code=response.status_code, response=response, error_map=error_map)
raise HttpResponseError(response=response)

Expand Down Expand Up @@ -193,8 +189,6 @@ def head404(self, **kwargs: Any) -> None: # pylint: disable=inconsistent-return
response = pipeline_response.http_response

if response.status_code not in [204, 404]:
if _stream:
response.read() # Load the body in memory and close the socket
map_error(status_code=response.status_code, response=response, error_map=error_map)
raise HttpResponseError(response=response)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
# --------------------------------------------------------------------------
from io import IOBase
import sys
from typing import Any, Callable, Dict, IO, Iterable, Optional, Type, TypeVar, Union, cast, overload
from typing import Any, Callable, Dict, IO, Iterable, Iterator, Optional, Type, TypeVar, Union, cast, overload

from my.library import CustomDefaultPollingMethod, CustomPager, CustomPoller

Expand Down Expand Up @@ -74,7 +74,9 @@ def build_polling_paging_example_basic_paging_request(**kwargs: Any) -> HttpRequ

class PollingPagingExampleOperationsMixin(PollingPagingExampleMixinABC):

def _basic_polling_initial(self, product: Optional[Union[JSON, IO[bytes]]] = None, **kwargs: Any) -> Optional[JSON]:
def _basic_polling_initial(
self, product: Optional[Union[JSON, IO[bytes]]] = None, **kwargs: Any
) -> Iterator[bytes]:
error_map: MutableMapping[int, Type[HttpResponseError]] = {
401: ClientAuthenticationError,
404: ResourceNotFoundError,
Expand All @@ -87,7 +89,7 @@ def _basic_polling_initial(self, product: Optional[Union[JSON, IO[bytes]]] = Non
_params = kwargs.pop("params", {}) or {}

content_type: Optional[str] = kwargs.pop("content_type", _headers.pop("Content-Type", None))
cls: ClsType[Optional[JSON]] = kwargs.pop("cls", None)
cls: ClsType[Iterator[bytes]] = kwargs.pop("cls", None)

content_type = content_type or "application/json"
_json = None
Expand All @@ -109,30 +111,28 @@ def _basic_polling_initial(self, product: Optional[Union[JSON, IO[bytes]]] = Non
)
_request.url = self._client.format_url(_request.url)

_stream = False
_stream = True
pipeline_response: PipelineResponse = self._client._pipeline.run( # pylint: disable=protected-access
_request, stream=_stream, **kwargs
)

response = pipeline_response.http_response

if response.status_code not in [200, 204]:
if _stream:
response.read() # Load the body in memory and close the socket
response.read() # Load the body in memory and close the socket
map_error(status_code=response.status_code, response=response, error_map=error_map)
raise HttpResponseError(response=response)

deserialized = None
if response.status_code == 200:
if response.content:
deserialized = response.json()
else:
deserialized = None
deserialized = response.iter_bytes()

if response.status_code == 204:
deserialized = response.iter_bytes()

if cls:
return cls(pipeline_response, deserialized, {}) # type: ignore
return cls(pipeline_response, cast(Iterator[bytes], deserialized), {}) # type: ignore

return deserialized # type: ignore
return cast(Iterator[bytes], deserialized) # type: ignore

@overload
def begin_basic_polling(
Expand Down Expand Up @@ -245,6 +245,7 @@ def begin_basic_polling(
params=_params,
**kwargs
)
raw_result.http_response.read() # type: ignore
kwargs.pop("error_map", None)

def get_long_running_output(pipeline_response):
Expand Down Expand Up @@ -336,8 +337,6 @@ def get_next(next_link=None):
response = pipeline_response.http_response

if response.status_code not in [200]:
if _stream:
response.read() # Load the body in memory and close the socket
map_error(status_code=response.status_code, response=response, error_map=error_map)
raise HttpResponseError(response=response)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
# --------------------------------------------------------------------------
from io import IOBase
import sys
from typing import Any, AsyncIterable, Callable, Dict, IO, Optional, Type, TypeVar, Union, cast, overload
from typing import Any, AsyncIterable, AsyncIterator, Callable, Dict, IO, Optional, Type, TypeVar, Union, cast, overload

from my.library.aio import AsyncCustomDefaultPollingMethod, AsyncCustomPager, AsyncCustomPoller

Expand Down Expand Up @@ -47,7 +47,7 @@ class PollingPagingExampleOperationsMixin(PollingPagingExampleMixinABC):

async def _basic_polling_initial(
self, product: Optional[Union[JSON, IO[bytes]]] = None, **kwargs: Any
) -> Optional[JSON]:
) -> AsyncIterator[bytes]:
error_map: MutableMapping[int, Type[HttpResponseError]] = {
401: ClientAuthenticationError,
404: ResourceNotFoundError,
Expand All @@ -60,7 +60,7 @@ async def _basic_polling_initial(
_params = kwargs.pop("params", {}) or {}

content_type: Optional[str] = kwargs.pop("content_type", _headers.pop("Content-Type", None))
cls: ClsType[Optional[JSON]] = kwargs.pop("cls", None)
cls: ClsType[AsyncIterator[bytes]] = kwargs.pop("cls", None)

content_type = content_type or "application/json"
_json = None
Expand All @@ -82,30 +82,28 @@ async def _basic_polling_initial(
)
_request.url = self._client.format_url(_request.url)

_stream = False
_stream = True
pipeline_response: PipelineResponse = await self._client._pipeline.run( # type: ignore # pylint: disable=protected-access
_request, stream=_stream, **kwargs
)

response = pipeline_response.http_response

if response.status_code not in [200, 204]:
if _stream:
await response.read() # Load the body in memory and close the socket
await response.read() # Load the body in memory and close the socket
map_error(status_code=response.status_code, response=response, error_map=error_map)
raise HttpResponseError(response=response)

deserialized = None
if response.status_code == 200:
if response.content:
deserialized = response.json()
else:
deserialized = None
deserialized = response.iter_bytes()

if response.status_code == 204:
deserialized = response.iter_bytes()

if cls:
return cls(pipeline_response, deserialized, {}) # type: ignore
return cls(pipeline_response, cast(AsyncIterator[bytes], deserialized), {}) # type: ignore

return deserialized # type: ignore
return cast(AsyncIterator[bytes], deserialized) # type: ignore

@overload
async def begin_basic_polling(
Expand Down Expand Up @@ -218,6 +216,7 @@ async def begin_basic_polling(
params=_params,
**kwargs
)
await raw_result.http_response.read() # type: ignore
kwargs.pop("error_map", None)

def get_long_running_output(pipeline_response):
Expand Down Expand Up @@ -313,8 +312,6 @@ async def get_next(next_link=None):
response = pipeline_response.http_response

if response.status_code not in [200]:
if _stream:
await response.read() # Load the body in memory and close the socket
map_error(status_code=response.status_code, response=response, error_map=error_map)
raise HttpResponseError(response=response)

Expand Down
Loading
Loading