diff --git a/aiohttp/web_reqrep.py b/aiohttp/web_reqrep.py index 64cadc77d3e..b81ddb2e5c4 100644 --- a/aiohttp/web_reqrep.py +++ b/aiohttp/web_reqrep.py @@ -1,4 +1,4 @@ -__all__ = ('Request', 'StreamResponse', 'Response') +__all__ = ('ContentCoding', 'Request', 'StreamResponse', 'Response') import asyncio import binascii @@ -12,6 +12,11 @@ import time import warnings +try: + import enum +except ImportError: + from flufl import enum + from email.utils import parsedate from types import MappingProxyType from urllib.parse import urlsplit, parse_qsl, unquote @@ -72,6 +77,16 @@ def content_length(self, _CONTENT_LENGTH=hdrs.CONTENT_LENGTH): FileField = collections.namedtuple('Field', 'name filename file content_type') +class ContentCoding(enum.Enum): + # The content codings that we have support for. + # + # Additional registered codings are listed at: + # https://www.iana.org/assignments/http-parameters/http-parameters.xhtml#content-coding + deflate = 'deflate' + gzip = 'gzip' + identity = 'identity' + + ############################################################ # HTTP Request ############################################################ @@ -436,8 +451,12 @@ def enable_chunked_encoding(self, chunk_size=None): self._chunked = True self._chunk_size = chunk_size - def enable_compression(self, force=False): - """Enables response compression with `deflate` encoding.""" + def enable_compression(self, force=None): + """Enables response compression encoding.""" + # Backwards compatibility for when force was a bool <0.17. + if type(force) == bool: + force = ContentCoding.deflate if force else ContentCoding.identity + self._compression = True self._compression_force = force @@ -577,6 +596,22 @@ def _start_pre_check(self, request): else: return None + def _start_compression(self, request): + def start(coding): + if coding != ContentCoding.identity: + self.headers[hdrs.CONTENT_ENCODING] = coding.value + self._resp_impl.add_compression_filter(coding.value) + + if self._compression_force: + start(self._compression_force) + else: + accept_encoding = request.headers.get( + hdrs.ACCEPT_ENCODING, '').lower() + for coding in ContentCoding: + if coding.value in accept_encoding: + start(coding) + return + def start(self, request): resp_impl = self._start_pre_check(request) if resp_impl is not None: @@ -598,10 +633,7 @@ def start(self, request): self._copy_cookies() if self._compression: - if (self._compression_force or - 'deflate' in request.headers.get( - hdrs.ACCEPT_ENCODING, '')): - resp_impl.add_compression_filter() + self._start_compression(request) if self._chunked: resp_impl.enable_chunked_encoding() diff --git a/docs/web_reference.rst b/docs/web_reference.rst index 94af32a70f6..98b1092adf3 100644 --- a/docs/web_reference.rst +++ b/docs/web_reference.rst @@ -382,14 +382,15 @@ StreamResponse .. seealso:: :meth:`enable_compression` - .. method:: enable_compression(force=False) + .. method:: enable_compression(force=None) Enable compression. - When *force* is ``False`` (default) compression is used only - when *deflate* is in *Accept-Encoding* request's header. + When *force* is unset compression encoding is selected based on + the request's *Accept-Encoding* header. - *Accept-Encoding* is not checked if *force* is ``True``. + *Accept-Encoding* is not checked if *force* is set to a + :class:`ContentCoding`. .. versionadded:: 0.14 @@ -1217,3 +1218,17 @@ Utilities *MIME type* of uploaded file, ``'text/plain'`` by default. .. seealso:: :ref:`aiohttp-web-file-upload` + + +Constants +--------- + +.. class:: ContentCoding + + An :class:`enum.Enum` class of available Content Codings. + + .. attribute:: deflate + + .. attribute:: gzip + + .. attribute:: identity diff --git a/setup.py b/setup.py index e03069934cd..2f9a0177fd3 100644 --- a/setup.py +++ b/setup.py @@ -56,7 +56,7 @@ def build_extension(self, ext): install_requires = ['chardet'] if sys.version_info < (3, 4): - install_requires += ['asyncio'] + install_requires += ['asyncio', 'flufl.enum'] tests_require = install_requires + ['nose', 'gunicorn'] diff --git a/tests/test_web_response.py b/tests/test_web_response.py index 44537a135fa..585448d40aa 100644 --- a/tests/test_web_response.py +++ b/tests/test_web_response.py @@ -4,7 +4,7 @@ from unittest import mock from aiohttp import hdrs from aiohttp.multidict import CIMultiDict -from aiohttp.web import Request, StreamResponse, Response +from aiohttp.web import ContentCoding, Request, StreamResponse, Response from aiohttp.protocol import HttpVersion, HttpVersion11, HttpVersion10 from aiohttp.protocol import RawRequestMessage @@ -197,7 +197,7 @@ def test_compression_no_accept(self, ResponseImpl): self.assertFalse(msg.add_compression_filter.called) @mock.patch('aiohttp.web_reqrep.ResponseImpl') - def test_force_compression_no_accept(self, ResponseImpl): + def test_force_compression_no_accept_backwards_compat(self, ResponseImpl): req = self.make_request('GET', '/') resp = StreamResponse() self.assertFalse(resp.chunked) @@ -211,7 +211,19 @@ def test_force_compression_no_accept(self, ResponseImpl): self.assertIsNotNone(msg.filter) @mock.patch('aiohttp.web_reqrep.ResponseImpl') - def test_compression(self, ResponseImpl): + def test_force_compression_false_backwards_compat(self, ResponseImpl): + req = self.make_request('GET', '/') + resp = StreamResponse() + + self.assertFalse(resp.compression) + resp.enable_compression(force=False) + self.assertTrue(resp.compression) + + msg = resp.start(req) + self.assertFalse(msg.add_compression_filter.called) + + @mock.patch('aiohttp.web_reqrep.ResponseImpl') + def test_compression_default_coding(self, ResponseImpl): req = self.make_request( 'GET', '/', headers=CIMultiDict({hdrs.ACCEPT_ENCODING: 'gzip, deflate'})) @@ -223,9 +235,62 @@ def test_compression(self, ResponseImpl): self.assertTrue(resp.compression) msg = resp.start(req) - self.assertTrue(msg.add_compression_filter.called) + msg.add_compression_filter.assert_called_with('deflate') + self.assertEqual('deflate', resp.headers.get(hdrs.CONTENT_ENCODING)) self.assertIsNotNone(msg.filter) + @mock.patch('aiohttp.web_reqrep.ResponseImpl') + def test_force_compression_deflate(self, ResponseImpl): + req = self.make_request( + 'GET', '/', + headers=CIMultiDict({hdrs.ACCEPT_ENCODING: 'gzip, deflate'})) + resp = StreamResponse() + + resp.enable_compression(ContentCoding.deflate) + self.assertTrue(resp.compression) + + msg = resp.start(req) + msg.add_compression_filter.assert_called_with('deflate') + self.assertEqual('deflate', resp.headers.get(hdrs.CONTENT_ENCODING)) + + @mock.patch('aiohttp.web_reqrep.ResponseImpl') + def test_force_compression_no_accept_deflate(self, ResponseImpl): + req = self.make_request('GET', '/') + resp = StreamResponse() + + resp.enable_compression(ContentCoding.deflate) + self.assertTrue(resp.compression) + + msg = resp.start(req) + msg.add_compression_filter.assert_called_with('deflate') + self.assertEqual('deflate', resp.headers.get(hdrs.CONTENT_ENCODING)) + + @mock.patch('aiohttp.web_reqrep.ResponseImpl') + def test_force_compression_gzip(self, ResponseImpl): + req = self.make_request( + 'GET', '/', + headers=CIMultiDict({hdrs.ACCEPT_ENCODING: 'gzip, deflate'})) + resp = StreamResponse() + + resp.enable_compression(ContentCoding.gzip) + self.assertTrue(resp.compression) + + msg = resp.start(req) + msg.add_compression_filter.assert_called_with('gzip') + self.assertEqual('gzip', resp.headers.get(hdrs.CONTENT_ENCODING)) + + @mock.patch('aiohttp.web_reqrep.ResponseImpl') + def test_force_compression_no_accept_gzip(self, ResponseImpl): + req = self.make_request('GET', '/') + resp = StreamResponse() + + resp.enable_compression(ContentCoding.gzip) + self.assertTrue(resp.compression) + + msg = resp.start(req) + msg.add_compression_filter.assert_called_with('gzip') + self.assertEqual('gzip', resp.headers.get(hdrs.CONTENT_ENCODING)) + def test_write_non_byteish(self): resp = StreamResponse() resp.start(self.make_request('GET', '/'))