Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Core raw streaming #17920

Merged
merged 37 commits into from
May 7, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
37 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion sdk/core/azure-core/CLIENT_LIBRARY_DEVELOPER.md
Original file line number Diff line number Diff line change
Expand Up @@ -269,7 +269,7 @@ class HttpResponse(object):
def text(self, encoding=None):
"""Return the whole body as a string."""

def stream_download(self, chunk_size=None, callback=None):
def stream_download(self, pipeline, **kwargs):
johanste marked this conversation as resolved.
Show resolved Hide resolved
"""Generator for streaming request body data.
Should be implemented by sub-classes if streaming download
is supported.
Expand Down
36 changes: 26 additions & 10 deletions sdk/core/azure-core/azure/core/pipeline/transport/_aiohttp.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,6 @@
CONTENT_CHUNK_SIZE = 10 * 1024
_LOGGER = logging.getLogger(__name__)


class AioHttpTransport(AsyncHttpTransport):
"""AioHttp HTTP sender implementation.

Expand Down Expand Up @@ -89,7 +88,8 @@ async def open(self):
self.session = aiohttp.ClientSession(
loop=self._loop,
trust_env=self._use_env_settings,
cookie_jar=jar
cookie_jar=jar,
auto_decompress=False,
)
if self.session is not None:
await self.session.__aenter__()
Expand Down Expand Up @@ -191,22 +191,24 @@ async def send(self, request: HttpRequest, **config: Any) -> Optional[AsyncHttpR
raise ServiceResponseError(err, error=err) from err
return response


class AioHttpStreamDownloadGenerator(AsyncIterator):
"""Streams the response body data.

:param pipeline: The pipeline object
:param response: The client response object.
:param block_size: block size of data sent over connection.
:type block_size: int
:keyword bool decompress: If True which is default, will attempt to decode the body based
on the ‘content-encoding’ header.
"""
def __init__(self, pipeline: Pipeline, response: AsyncHttpResponse, **kwargs) -> None:
    """Create the streaming generator for an aiohttp response.

    :param pipeline: The pipeline object.
    :param response: The client response object.
    :keyword bool decompress: If True (the default), will attempt to decode the
     body based on the 'content-encoding' header.
    :raises TypeError: If an unexpected keyword argument is passed.
    """
    self.pipeline = pipeline
    self.request = response.request
    self.response = response
    self.block_size = response.block_size
    self._decompress = kwargs.pop("decompress", True)
    # Reject any keyword other than 'decompress' so typos fail loudly.
    if len(kwargs) > 0:
        raise TypeError("Got an unexpected keyword argument: {}".format(list(kwargs.keys())[0]))
    self.content_length = int(response.internal_response.headers.get('Content-Length', 0))
    self.downloaded = 0
    # zlib decompressor is created lazily on the first chunk when the body is
    # gzip/deflate encoded and decompression was requested.
    self._decompressor = None

def __len__(self):
    # Length is the advertised Content-Length (0 when the header was absent).
    return self.content_length
Expand All @@ -216,6 +218,18 @@ async def __anext__(self):
chunk = await self.response.internal_response.content.read(self.block_size)
if not chunk:
raise _ResponseStopIteration()
if not self._decompress:
return chunk
enc = self.response.internal_response.headers.get('Content-Encoding')
if not enc:
return chunk
enc = enc.lower()
if enc in ("gzip", "deflate"):
if not self._decompressor:
import zlib
zlib_mode = 16 + zlib.MAX_WBITS if enc == "gzip" else zlib.MAX_WBITS
self._decompressor = zlib.decompressobj(wbits=zlib_mode)
chunk = self._decompressor.decompress(chunk)
return chunk
except _ResponseStopIteration:
self.response.internal_response.close()
Expand Down Expand Up @@ -269,13 +283,15 @@ async def load_body(self) -> None:
"""Load in memory the body, so it could be accessible from sync methods."""
self._body = await self.internal_response.read()

def stream_download(self, pipeline, **kwargs) -> AsyncIteratorType[bytes]:
    """Generator for streaming response body data.

    :param pipeline: The pipeline object
    :type pipeline: azure.core.pipeline.Pipeline
    :keyword bool decompress: If True (the default), will attempt to decode the
     body based on the 'content-encoding' header.
    :rtype: AsyncIterator[bytes]
    """
    # Forward keyword options (e.g. decompress) straight to the generator.
    return AioHttpStreamDownloadGenerator(pipeline, self, **kwargs)

def __getstate__(self):
# Be sure body is loaded in memory, otherwise not pickable and let it throw
Expand Down
4 changes: 2 additions & 2 deletions sdk/core/azure-core/azure/core/pipeline/transport/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -580,8 +580,8 @@ def __repr__(self):


class HttpResponse(_HttpResponseBase): # pylint: disable=abstract-method
def stream_download(self, pipeline):
# type: (PipelineType) -> Iterator[bytes]
def stream_download(self, pipeline, **kwargs):
# type: (PipelineType, **Any) -> Iterator[bytes]
"""Generator for streaming request body data.

Should be implemented by sub-classes if streaming download
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -124,14 +124,16 @@ class AsyncHttpResponse(_HttpResponseBase): # pylint: disable=abstract-method
Allows for the asynchronous streaming of data from the response.
"""

def stream_download(self, pipeline, **kwargs) -> AsyncIteratorType[bytes]:
    """Generator for streaming response body data.

    Should be implemented by sub-classes if streaming download
    is supported. Will return an asynchronous generator.

    :param pipeline: The pipeline object
    :type pipeline: azure.core.pipeline.Pipeline
    :keyword bool decompress: If True (the default), will attempt to decode the
     body based on the 'content-encoding' header.
    """

def parts(self) -> AsyncIterator:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@
AsyncHttpResponse,
_ResponseStopIteration,
_iterate_response_content)
from ._requests_basic import RequestsTransportResponse
from ._requests_basic import RequestsTransportResponse, _read_raw_stream
from ._base_requests_async import RequestsAsyncTransportBase


Expand Down Expand Up @@ -138,17 +138,22 @@ class AsyncioStreamDownloadGenerator(AsyncIterator):

:param pipeline: The pipeline object
:param response: The response object.
:param generator iter_content_func: Iterator for response data.
:param int content_length: size of body in bytes.
:keyword bool decompress: If True which is default, will attempt to decode the body based
on the ‘content-encoding’ header.
"""
def __init__(self, pipeline: Pipeline, response: AsyncHttpResponse, **kwargs) -> None:
    """Create the streaming generator for a requests response driven by asyncio.

    :param pipeline: The pipeline object.
    :param response: The response object.
    :keyword bool decompress: If True (the default), will attempt to decode the
     body based on the 'content-encoding' header.
    :raises TypeError: If an unexpected keyword argument is passed.
    """
    self.pipeline = pipeline
    self.request = response.request
    self.response = response
    self.block_size = response.block_size
    decompress = kwargs.pop("decompress", True)
    # Reject any keyword other than 'decompress' so typos fail loudly.
    if len(kwargs) > 0:
        raise TypeError("Got an unexpected keyword argument: {}".format(list(kwargs.keys())[0]))
    if decompress:
        # Let requests decode the body per the 'content-encoding' header.
        self.iter_content_func = self.response.internal_response.iter_content(self.block_size)
    else:
        # Bypass requests' decoding and stream the raw, undecoded bytes.
        self.iter_content_func = _read_raw_stream(self.response.internal_response, self.block_size)
    self.content_length = int(response.headers.get('Content-Length', 0))
    self.downloaded = 0

def __len__(self):
    # Length is the advertised Content-Length (0 when the header was absent).
    return self.content_length
Expand Down Expand Up @@ -178,6 +183,6 @@ async def __anext__(self):
class AsyncioRequestsTransportResponse(AsyncHttpResponse, RequestsTransportResponse): # type: ignore
"""Asynchronous streaming of data from the response.
"""
def stream_download(self, pipeline, **kwargs) -> AsyncIteratorType[bytes]:  # type: ignore
    """Generator for streaming response body data.

    :param pipeline: The pipeline object
    :type pipeline: azure.core.pipeline.Pipeline
    :keyword bool decompress: If True (the default), will attempt to decode the
     body based on the 'content-encoding' header.
    """
    return AsyncioStreamDownloadGenerator(pipeline, self, **kwargs)  # type: ignore
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,9 @@
from typing import Iterator, Optional, Any, Union, TypeVar
import urllib3 # type: ignore
from urllib3.util.retry import Retry # type: ignore
from urllib3.exceptions import (
DecodeError, ReadTimeoutError, ProtocolError
)
import requests

from azure.core.configuration import ConnectionConfiguration
Expand All @@ -48,6 +51,25 @@

_LOGGER = logging.getLogger(__name__)

def _read_raw_stream(response, chunk_size=1):
    """Yield raw (undecoded) chunks of a requests response body.

    Mirrors requests' own body iteration but forces ``decode_content=False``
    so the bytes are NOT decompressed — callers receive the wire payload as-is.

    :param response: A requests response whose ``raw`` attribute is either a
     ``urllib3.HTTPResponse`` or a plain file-like object.
    :param int chunk_size: Number of bytes to read per chunk.
    :raises requests.exceptions.ChunkedEncodingError: On urllib3 protocol errors.
    :raises requests.exceptions.ContentDecodingError: On urllib3 decode errors.
    :raises requests.exceptions.ConnectionError: On urllib3 read timeouts.
    """
    # Special case for urllib3: stream without decoding, translating urllib3
    # exceptions into the requests exceptions callers already handle.
    if hasattr(response.raw, 'stream'):
        try:
            for chunk in response.raw.stream(chunk_size, decode_content=False):
                yield chunk
        except ProtocolError as e:
            raise requests.exceptions.ChunkedEncodingError(e) from e
        except DecodeError as e:
            raise requests.exceptions.ContentDecodingError(e) from e
        except ReadTimeoutError as e:
            raise requests.exceptions.ConnectionError(e) from e
    else:
        # Standard file-like object: read fixed-size chunks until EOF.
        while True:
            chunk = response.raw.read(chunk_size)
            if not chunk:
                break
            yield chunk

class _RequestsTransportResponseBase(_HttpResponseBase):
"""Base class for accessing response data.
Expand Down Expand Up @@ -98,13 +120,21 @@ class StreamDownloadGenerator(object):

:param pipeline: The pipeline object
:param response: The response object.
:keyword bool decompress: If True which is default, will attempt to decode the body based
on the ‘content-encoding’ header.
"""
def __init__(self, pipeline, response, **kwargs):
    """Create the streaming generator.

    :param pipeline: The pipeline object.
    :param response: The response object.
    :keyword bool decompress: If True (the default), will attempt to decode the
     body based on the 'content-encoding' header.
    :raises TypeError: If an unexpected keyword argument is passed.
    """
    self.pipeline = pipeline
    self.request = response.request
    self.response = response
    self.block_size = response.block_size
    decompress = kwargs.pop("decompress", True)
    # Reject any keyword other than 'decompress' so typos fail loudly.
    if len(kwargs) > 0:
        raise TypeError("Got an unexpected keyword argument: {}".format(list(kwargs.keys())[0]))
    if decompress:
        # Let requests decode the body per the 'content-encoding' header.
        self.iter_content_func = self.response.internal_response.iter_content(self.block_size)
    else:
        # Bypass requests' decoding and stream the raw, undecoded bytes.
        self.iter_content_func = _read_raw_stream(self.response.internal_response, self.block_size)
    self.content_length = int(response.headers.get('Content-Length', 0))

def __len__(self):
Expand Down Expand Up @@ -134,10 +164,10 @@ def __next__(self):
class RequestsTransportResponse(HttpResponse, _RequestsTransportResponseBase):
"""Streaming of data from the response.
"""
def stream_download(self, pipeline, **kwargs):
    # type: (PipelineType, **Any) -> Iterator[bytes]
    """Generator for streaming response body data.

    :param pipeline: The pipeline object
    :type pipeline: azure.core.pipeline.Pipeline
    :keyword bool decompress: If True (the default), will attempt to decode the
     body based on the 'content-encoding' header.
    """
    return StreamDownloadGenerator(pipeline, self, **kwargs)


class RequestsTransport(HttpTransport):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@
AsyncHttpResponse,
_ResponseStopIteration,
_iterate_response_content)
from ._requests_basic import RequestsTransportResponse
from ._requests_basic import RequestsTransportResponse, _read_raw_stream
from ._base_requests_async import RequestsAsyncTransportBase


Expand All @@ -54,15 +54,22 @@ class TrioStreamDownloadGenerator(AsyncIterator):

:param pipeline: The pipeline object
:param response: The response object.
:keyword bool decompress: If True which is default, will attempt to decode the body based
on the ‘content-encoding’ header.
"""
def __init__(self, pipeline: Pipeline, response: AsyncHttpResponse, **kwargs) -> None:
    """Create the streaming generator for a requests response driven by trio.

    :param pipeline: The pipeline object.
    :param response: The response object.
    :keyword bool decompress: If True (the default), will attempt to decode the
     body based on the 'content-encoding' header.
    :raises TypeError: If an unexpected keyword argument is passed.
    """
    self.pipeline = pipeline
    self.request = response.request
    self.response = response
    self.block_size = response.block_size
    decompress = kwargs.pop("decompress", True)
    # Reject any keyword other than 'decompress' so typos fail loudly.
    if len(kwargs) > 0:
        raise TypeError("Got an unexpected keyword argument: {}".format(list(kwargs.keys())[0]))
    if decompress:
        # Let requests decode the body per the 'content-encoding' header.
        self.iter_content_func = self.response.internal_response.iter_content(self.block_size)
    else:
        # Bypass requests' decoding and stream the raw, undecoded bytes.
        self.iter_content_func = _read_raw_stream(self.response.internal_response, self.block_size)
    self.content_length = int(response.headers.get('Content-Length', 0))
    self.downloaded = 0

def __len__(self):
    # Length is the advertised Content-Length (0 when the header was absent).
    return self.content_length
Expand Down Expand Up @@ -95,10 +102,10 @@ async def __anext__(self):
class TrioRequestsTransportResponse(AsyncHttpResponse, RequestsTransportResponse): # type: ignore
"""Asynchronous streaming of data from the response.
"""
def stream_download(self, pipeline, **kwargs) -> AsyncIteratorType[bytes]:  # type: ignore
    """Generator for streaming response data.

    :param pipeline: The pipeline object
    :type pipeline: azure.core.pipeline.Pipeline
    :keyword bool decompress: If True (the default), will attempt to decode the
     body based on the 'content-encoding' header.
    """
    return TrioStreamDownloadGenerator(pipeline, self, **kwargs)


class TrioRequestsTransport(RequestsAsyncTransportBase): # type: ignore
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,18 @@

@pytest.mark.asyncio
async def test_connection_error_response():
class MockSession(object):
    """Minimal stand-in for an aiohttp ClientSession exposing auto_decompress.

    The original assigned ``self.auto_decompress`` in ``__init__`` while also
    declaring a setter-less property of the same name: instantiation raised
    AttributeError, and the getter recursed on itself. Store the value in a
    private backing field instead.
    """
    def __init__(self):
        # Backing field; exposed read-only via the property below.
        self._auto_decompress = True

    @property
    def auto_decompress(self):
        return self._auto_decompress

class MockTransport(AsyncHttpTransport):
def __init__(self):
    # Number of send attempts observed so far.
    self._count = 0
    # NOTE(review): this binds the MockSession CLASS, not an instance — any
    # attribute access then resolves against the class object. Confirm whether
    # MockSession() was intended.
    self.session = MockSession

async def __aexit__(self, exc_type, exc_val, exc_tb):
pass
Expand Down Expand Up @@ -60,7 +69,7 @@ async def __call__(self, *args, **kwargs):
pipeline = AsyncPipeline(MockTransport())
http_response = AsyncHttpResponse(http_request, None)
http_response.internal_response = MockInternalResponse()
stream = AioHttpStreamDownloadGenerator(pipeline, http_response)
stream = AioHttpStreamDownloadGenerator(pipeline, http_response, decompress=False)
with mock.patch('asyncio.sleep', new_callable=AsyncMock):
with pytest.raises(ConnectionError):
await stream.__anext__()
Expand All @@ -75,6 +84,8 @@ async def test_response_streaming_error_behavior():

class FakeStreamWithConnectionError:
# fake object for urllib3.response.HTTPResponse
def __init__(self):
    # Total number of body bytes the fake will serve before erroring.
    self.total_response_size = 500

def stream(self, chunk_size, decode_content=False):
assert chunk_size == block_size
Expand All @@ -86,6 +97,15 @@ def stream(self, chunk_size, decode_content=False):
left -= len(data)
yield data

def read(self, chunk_size, decode_content=False):
    # Fake urllib3 read(): serve blocks of b"X" and raise ConnectionError once
    # the remaining payload would fit in a single block.
    assert chunk_size == block_size  # block_size comes from the enclosing test scope
    if self.total_response_size > 0:
        if self.total_response_size <= block_size:
            raise requests.exceptions.ConnectionError()
        data = b"X" * min(chunk_size, self.total_response_size)
        self.total_response_size -= len(data)
        return data
    # Falls through returning None once exhausted — signals EOF to the caller.

def close(self):
    # Nothing to release in the fake; present to satisfy the stream interface.
    pass

Expand Down
Loading