Skip to content

Commit

Permalink
Revert "switch over completely to rest requests and responses (#2605)" (
Browse files Browse the repository at this point in the history
  • Loading branch information
iscai-msft authored Jun 5, 2024
1 parent 4d9448e commit fddb389
Show file tree
Hide file tree
Showing 522 changed files with 5,193 additions and 5,223 deletions.
7 changes: 0 additions & 7 deletions .chronus/changes/switch_to_rest-2024-4-30-12-59-16.md

This file was deleted.

12 changes: 12 additions & 0 deletions packages/autorest.python/autorest/codegen/models/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -253,6 +253,18 @@ def link_lro_initial_operations(self) -> None:
if isinstance(operation, (LROOperation, LROPagingOperation)):
operation.initial_operation = self.lookup_operation(id(operation.yaml_data["initialOperation"]))

@property
def need_request_converter(self) -> bool:
"""
Whether we need to convert our created azure.core.rest.HttpRequest to
azure.core.pipeline.transport.HttpRequest
"""
return (
self.code_model.options["show_operations"]
and bool(self.request_builders)
and not self.code_model.options["version_tolerant"]
)

@property
def has_abstract_operations(self) -> bool:
"""Whether there is abstract operation in any operation group."""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,11 @@ def need_vendored_code(self, async_mode: bool) -> bool:
return True
if async_mode:
return self.need_mixin_abc
return self.need_mixin_abc or self.has_etag or self.has_form_data
return self.need_request_converter or self.need_mixin_abc or self.has_etag or self.has_form_data

@property
def need_request_converter(self) -> bool:
return any(c for c in self.clients if c.need_request_converter)

@property
def need_mixin_abc(self) -> bool:
Expand Down
38 changes: 27 additions & 11 deletions packages/autorest.python/autorest/codegen/models/operation.py
Original file line number Diff line number Diff line change
Expand Up @@ -366,6 +366,8 @@ def imports( # pylint: disable=too-many-branches, disable=too-many-statements
file_import.add_import("warnings", ImportType.STDLIB)

relative_path = "..." if async_mode else ".."
if self.code_model.need_request_converter:
file_import.add_submodule_import(f"{relative_path}_vendor", "_convert_request", ImportType.LOCAL)
if self.has_etag:
file_import.add_submodule_import(
"exceptions",
Expand All @@ -375,18 +377,32 @@ def imports( # pylint: disable=too-many-branches, disable=too-many-statements
if not async_mode:
file_import.add_submodule_import(f"{relative_path}_vendor", "prep_if_match", ImportType.LOCAL)
file_import.add_submodule_import(f"{relative_path}_vendor", "prep_if_none_match", ImportType.LOCAL)
if async_mode:
file_import.add_submodule_import(
"rest",
"AsyncHttpResponse",
ImportType.SDKCORE,
)
if self.code_model.need_request_converter:
if async_mode:
file_import.add_submodule_import(
"azure.core.pipeline.transport",
"AsyncHttpResponse",
ImportType.SDKCORE,
)
else:
file_import.add_submodule_import(
"azure.core.pipeline.transport",
"HttpResponse",
ImportType.SDKCORE,
)
else:
file_import.add_submodule_import(
"rest",
"HttpResponse",
ImportType.SDKCORE,
)
if async_mode:
file_import.add_submodule_import(
"rest",
"AsyncHttpResponse",
ImportType.SDKCORE,
)
else:
file_import.add_submodule_import(
"rest",
"HttpResponse",
ImportType.SDKCORE,
)
if self.code_model.options["builders_visibility"] == "embedded" and not async_mode:
file_import.merge(self.request_builder.imports())
file_import.add_submodule_import(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -846,6 +846,11 @@ def _create_request_builder_call(

def _postprocess_http_request(self, builder: OperationType, template_url: Optional[str] = None) -> List[str]:
retval: List[str] = []
if not self.code_model.options["version_tolerant"]:
pass_files = ""
if builder.parameters.has_body and builder.parameters.body_parameter.client_name == "files":
pass_files = ", _files"
retval.append(f"_request = _convert_request(_request{pass_files})")
if builder.parameters.path:
retval.extend(self.serialize_path(builder))
url_to_format = "_request.url"
Expand Down Expand Up @@ -964,12 +969,13 @@ def response_headers_and_deserialization(
def handle_error_response(self, builder: OperationType) -> List[str]:
async_await = "await " if self.async_mode else ""
retval = [f"if response.status_code not in {str(builder.success_status_codes)}:"]
retval.extend(
[
" if _stream:",
f" {async_await} response.read() # Load the body in memory and close the socket",
]
)
if not self.code_model.need_request_converter:
retval.extend(
[
" if _stream:",
f" {async_await} response.read() # Load the body in memory and close the socket",
]
)
type_ignore = " # type: ignore" if _need_type_ignore(builder) else ""
retval.append(
f" map_error(status_code=response.status_code, response=response, error_map=error_map){type_ignore}"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,13 @@ def serialize_vendor_file(self, clients: List[Client]) -> str:

# configure imports
file_import = FileImport(self.code_model)
if self.code_model.need_request_converter:
file_import.add_submodule_import(
"azure.core.pipeline.transport",
"HttpRequest",
ImportType.SDKCORE,
)

if self.code_model.need_mixin_abc:
file_import.add_submodule_import(
"abc",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,14 @@

{{ imports }}

{% if code_model.need_request_converter and not async_mode %}
def _convert_request(request, files=None):
data = request.content if not files else None
request = HttpRequest(method=request.method, url=request.url, headers=request.headers, data=data)
if files:
request.set_formdata_body(files)
return request
{% endif %}
{% if code_model.need_mixin_abc %}
{% for client in clients | selectattr("has_mixin") %}
{% set pylint_disable = "# pylint: disable=name-too-long" if (client.name | length) + ("MixinABC" | length) > 40 else "" %}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
# --------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See License.txt in the project root for license information.
# Code generated by Microsoft (R) AutoRest Code Generator.
# Changes may cause incorrect behavior and will be lost if the code is regenerated.
# --------------------------------------------------------------------------

from azure.core.pipeline.transport import HttpRequest


def _convert_request(request, files=None):
data = request.content if not files else None
request = HttpRequest(method=request.method, url=request.url, headers=request.headers, data=data)
if files:
request.set_formdata_body(files)
return request
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,12 @@
map_error,
)
from azure.core.pipeline import PipelineResponse
from azure.core.rest import AsyncHttpResponse, HttpRequest
from azure.core.pipeline.transport import AsyncHttpResponse
from azure.core.rest import HttpRequest
from azure.core.tracing.decorator_async import distributed_trace_async
from azure.mgmt.core.exceptions import ARMErrorFormat

from ..._vendor import _convert_request
from ...operations._http_success_operations import build_head200_request, build_head204_request, build_head404_request

if sys.version_info >= (3, 9):
Expand Down Expand Up @@ -74,6 +76,7 @@ async def head200(self, **kwargs: Any) -> bool:
headers=_headers,
params=_params,
)
_request = _convert_request(_request)
_request.url = self._client.format_url(_request.url)

_stream = False
Expand All @@ -84,8 +87,6 @@ async def head200(self, **kwargs: Any) -> bool:
response = pipeline_response.http_response

if response.status_code not in [200, 404]:
if _stream:
await response.read() # Load the body in memory and close the socket
map_error(status_code=response.status_code, response=response, error_map=error_map)
raise HttpResponseError(response=response, error_format=ARMErrorFormat)

Expand Down Expand Up @@ -118,6 +119,7 @@ async def head204(self, **kwargs: Any) -> bool:
headers=_headers,
params=_params,
)
_request = _convert_request(_request)
_request.url = self._client.format_url(_request.url)

_stream = False
Expand All @@ -128,8 +130,6 @@ async def head204(self, **kwargs: Any) -> bool:
response = pipeline_response.http_response

if response.status_code not in [204, 404]:
if _stream:
await response.read() # Load the body in memory and close the socket
map_error(status_code=response.status_code, response=response, error_map=error_map)
raise HttpResponseError(response=response, error_format=ARMErrorFormat)

Expand Down Expand Up @@ -162,6 +162,7 @@ async def head404(self, **kwargs: Any) -> bool:
headers=_headers,
params=_params,
)
_request = _convert_request(_request)
_request.url = self._client.format_url(_request.url)

_stream = False
Expand All @@ -172,8 +173,6 @@ async def head404(self, **kwargs: Any) -> bool:
response = pipeline_response.http_response

if response.status_code not in [204, 404]:
if _stream:
await response.read() # Load the body in memory and close the socket
map_error(status_code=response.status_code, response=response, error_map=error_map)
raise HttpResponseError(response=response, error_format=ARMErrorFormat)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,13 @@
map_error,
)
from azure.core.pipeline import PipelineResponse
from azure.core.rest import HttpRequest, HttpResponse
from azure.core.pipeline.transport import HttpResponse
from azure.core.rest import HttpRequest
from azure.core.tracing.decorator import distributed_trace
from azure.mgmt.core.exceptions import ARMErrorFormat

from .._serialization import Serializer
from .._vendor import _convert_request

if sys.version_info >= (3, 9):
from collections.abc import MutableMapping
Expand Down Expand Up @@ -98,6 +100,7 @@ def head200(self, **kwargs: Any) -> bool:
headers=_headers,
params=_params,
)
_request = _convert_request(_request)
_request.url = self._client.format_url(_request.url)

_stream = False
Expand All @@ -108,8 +111,6 @@ def head200(self, **kwargs: Any) -> bool:
response = pipeline_response.http_response

if response.status_code not in [200, 404]:
if _stream:
response.read() # Load the body in memory and close the socket
map_error(status_code=response.status_code, response=response, error_map=error_map)
raise HttpResponseError(response=response, error_format=ARMErrorFormat)

Expand Down Expand Up @@ -142,6 +143,7 @@ def head204(self, **kwargs: Any) -> bool:
headers=_headers,
params=_params,
)
_request = _convert_request(_request)
_request.url = self._client.format_url(_request.url)

_stream = False
Expand All @@ -152,8 +154,6 @@ def head204(self, **kwargs: Any) -> bool:
response = pipeline_response.http_response

if response.status_code not in [204, 404]:
if _stream:
response.read() # Load the body in memory and close the socket
map_error(status_code=response.status_code, response=response, error_map=error_map)
raise HttpResponseError(response=response, error_format=ARMErrorFormat)

Expand Down Expand Up @@ -186,6 +186,7 @@ def head404(self, **kwargs: Any) -> bool:
headers=_headers,
params=_params,
)
_request = _convert_request(_request)
_request.url = self._client.format_url(_request.url)

_stream = False
Expand All @@ -196,8 +197,6 @@ def head404(self, **kwargs: Any) -> bool:
response = pipeline_response.http_response

if response.status_code not in [204, 404]:
if _stream:
response.read() # Load the body in memory and close the socket
map_error(status_code=response.status_code, response=response, error_map=error_map)
raise HttpResponseError(response=response, error_format=ARMErrorFormat)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@
from abc import ABC
from typing import TYPE_CHECKING

from azure.core.pipeline.transport import HttpRequest

from ._configuration import MultiapiServiceClientConfiguration

if TYPE_CHECKING:
Expand All @@ -17,6 +19,14 @@
from .._serialization import Deserializer, Serializer


def _convert_request(request, files=None):
data = request.content if not files else None
request = HttpRequest(method=request.method, url=request.url, headers=request.headers, data=data)
if files:
request.set_formdata_body(files)
return request


class MultiapiServiceClientMixinABC(ABC):
"""DO NOT use this class. It is for internal typing use only."""

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@
from abc import ABC
from typing import TYPE_CHECKING

from azure.core.pipeline.transport import HttpRequest

from ._configuration import MultiapiServiceClientConfiguration

if TYPE_CHECKING:
Expand Down
Loading

0 comments on commit fddb389

Please sign in to comment.