diff --git a/CHANGES/8977.bugfix.rst b/CHANGES/8977.bugfix.rst new file mode 100644 index 0000000000..7d21fe0c3f --- /dev/null +++ b/CHANGES/8977.bugfix.rst @@ -0,0 +1 @@ +Made ``TestClient.app`` a ``Generic`` so type checkers will know the correct type (avoiding unneeded ``client.app is not None`` checks) -- by :user:`Dreamsorcerer`. diff --git a/CHANGES/9018.bugfix.rst b/CHANGES/9018.bugfix.rst new file mode 100644 index 0000000000..2de6d14290 --- /dev/null +++ b/CHANGES/9018.bugfix.rst @@ -0,0 +1 @@ +Updated Python parser to reject messages after a close message, matching C parser behaviour -- by :user:`Dreamsorcerer`. diff --git a/CHANGES/9033.misc.rst b/CHANGES/9033.misc.rst new file mode 100644 index 0000000000..07a017ffdd --- /dev/null +++ b/CHANGES/9033.misc.rst @@ -0,0 +1 @@ +Changed web entry point to not listen on TCP when only a Unix path is passed -- by :user:`Dreamsorcerer`. diff --git a/CHANGES/9063.bugfix.rst b/CHANGES/9063.bugfix.rst new file mode 100644 index 0000000000..e512677b9c --- /dev/null +++ b/CHANGES/9063.bugfix.rst @@ -0,0 +1 @@ +Fixed ``If-None-Match`` not using weak comparison -- by :user:`Dreamsorcerer`. diff --git a/CHANGES/9109.breaking.rst b/CHANGES/9109.breaking.rst new file mode 100644 index 0000000000..ecbce187c9 --- /dev/null +++ b/CHANGES/9109.breaking.rst @@ -0,0 +1 @@ +Changed default value to ``compress`` from ``None`` to ``False`` (``None`` is no longer an expected value) -- by :user:`Dreamsorcerer`. diff --git a/CHANGES/9200.breaking.rst b/CHANGES/9200.breaking.rst new file mode 100644 index 0000000000..0282e165c4 --- /dev/null +++ b/CHANGES/9200.breaking.rst @@ -0,0 +1,3 @@ +Improved middleware performance -- by :user:`bdraco`. + +The ``set_current_app`` method was removed from ``UrlMappingMatchInfo`` because it is no longer used, and it was unlikely external caller would ever use it. diff --git a/CHANGES/9204.misc.rst b/CHANGES/9204.misc.rst new file mode 100644 index 0000000000..da12a7df6f --- /dev/null +++ b/CHANGES/9204.misc.rst @@ -0,0 +1 @@ +Significantly speed up filtering cookies -- by :user:`bdraco`. diff --git a/CHANGES/9241.misc.rst b/CHANGES/9241.misc.rst new file mode 120000 index 0000000000..d6a2f2aaaa --- /dev/null +++ b/CHANGES/9241.misc.rst @@ -0,0 +1 @@ +9174.misc.rst \ No newline at end of file diff --git a/aiohttp/client.py b/aiohttp/client.py index 9c2fd8073a..26af65cb7e 100644 --- a/aiohttp/client.py +++ b/aiohttp/client.py @@ -173,7 +173,7 @@ class _RequestOptions(TypedDict, total=False): auth: Union[BasicAuth, None] allow_redirects: bool max_redirects: int - compress: Union[str, bool, None] + compress: Union[str, bool] chunked: Union[bool, None] expect100: bool raise_for_status: Union[None, bool, Callable[[ClientResponse], Awaitable[None]]] @@ -418,7 +418,7 @@ async def _request( auth: Optional[BasicAuth] = None, allow_redirects: bool = True, max_redirects: int = 10, - compress: Union[str, bool, None] = None, + compress: Union[str, bool] = False, chunked: Optional[bool] = None, expect100: bool = False, raise_for_status: Union[ @@ -1372,7 +1372,7 @@ def request( auth: Optional[BasicAuth] = None, allow_redirects: bool = True, max_redirects: int = 10, - compress: Optional[str] = None, + compress: Union[str, bool] = False, chunked: Optional[bool] = None, expect100: bool = False, raise_for_status: Optional[bool] = None, diff --git a/aiohttp/client_reqrep.py b/aiohttp/client_reqrep.py index e961776288..00bed0e8f8 100644 --- a/aiohttp/client_reqrep.py +++ b/aiohttp/client_reqrep.py @@ -203,7 +203,7 @@ def __init__( cookies: Optional[LooseCookies] = None, auth: Optional[BasicAuth] = None, version: http.HttpVersion = http.HttpVersion11, - compress: Union[str, bool, None] = None, + compress: Union[str, bool] = False, chunked: Optional[bool] = None, expect100: bool = False, loop: asyncio.AbstractEventLoop, @@ -235,7 +235,6 @@ def __init__( self.url = url.with_fragment(None) self.method = method.upper() self.chunked = chunked - self.compress = compress self.loop = loop self.length = None if response_class is None: @@ -255,7 +254,7 @@ def __init__( self.update_headers(headers) self.update_auto_headers(skip_auto_headers) self.update_cookies(cookies) - self.update_content_encoding(data) + self.update_content_encoding(data, compress) self.update_auth(auth, trust_env) self.update_proxy(proxy, proxy_auth, proxy_headers) @@ -422,22 +421,19 @@ def update_cookies(self, cookies: Optional[LooseCookies]) -> None: self.headers[hdrs.COOKIE] = c.output(header="", sep=";").strip() - def update_content_encoding(self, data: Any) -> None: + def update_content_encoding(self, data: Any, compress: Union[bool, str]) -> None: """Set request content encoding.""" + self.compress = None if not data: - # Don't compress an empty body. - self.compress = None return - enc = self.headers.get(hdrs.CONTENT_ENCODING, "").lower() - if enc: - if self.compress: + if self.headers.get(hdrs.CONTENT_ENCODING): + if compress: raise ValueError( "compress can not be set if Content-Encoding header is set" ) - elif self.compress: - if not isinstance(self.compress, str): - self.compress = "deflate" + elif compress: + self.compress = compress if isinstance(compress, str) else "deflate" self.headers[hdrs.CONTENT_ENCODING] = self.compress self.chunked = True # enable chunked, no need to deal with length @@ -640,7 +636,7 @@ async def send(self, conn: "Connection") -> "ClientResponse": ) if self.compress: - writer.enable_compression(self.compress) # type: ignore[arg-type] + writer.enable_compression(self.compress) if self.chunked is not None: writer.enable_chunking() diff --git a/aiohttp/cookiejar.py b/aiohttp/cookiejar.py index 9b7f38d668..85fd7716b5 100644 --- a/aiohttp/cookiejar.py +++ b/aiohttp/cookiejar.py @@ -94,6 +94,9 @@ def __init__( self._cookies: DefaultDict[Tuple[str, str], SimpleCookie] = defaultdict( SimpleCookie ) + self._morsel_cache: DefaultDict[Tuple[str, str], Dict[str, Morsel[str]]] = ( + defaultdict(dict) + ) self._host_only_cookies: Set[Tuple[str, str]] = set() self._unsafe = unsafe self._quote_cookie = quote_cookie @@ -129,6 +132,7 @@ def clear(self, predicate: Optional[ClearCookiePredicate] = None) -> None: if predicate is None: self._expire_heap.clear() self._cookies.clear() + self._morsel_cache.clear() self._host_only_cookies.clear() self._expirations.clear() return @@ -210,6 +214,7 @@ def _delete_cookies(self, to_del: List[Tuple[str, str, str]]) -> None: for domain, path, name in to_del: self._host_only_cookies.discard((domain, name)) self._cookies[(domain, path)].pop(name, None) + self._morsel_cache[(domain, path)].pop(name, None) self._expirations.pop((domain, path, name), None) def _expire_cookie(self, when: float, domain: str, path: str, name: str) -> None: @@ -285,7 +290,12 @@ def update_cookies(self, cookies: LooseCookies, response_url: URL = URL()) -> No else: cookie["expires"] = "" - self._cookies[(domain, path)][name] = cookie + key = (domain, path) + if self._cookies[key].get(name) != cookie: + # Don't blow away the cache if the same + # cookie gets set again + self._cookies[key][name] = cookie + self._morsel_cache[key].pop(name, None) self._do_expiration() @@ -337,30 +347,33 @@ def filter_cookies(self, request_url: URL) -> "BaseCookie[str]": # Create every combination of (domain, path) pairs. pairs = itertools.product(domains, paths) - # Point 2: https://www.rfc-editor.org/rfc/rfc6265.html#section-5.4 - cookies = itertools.chain.from_iterable( - self._cookies[p].values() for p in pairs - ) path_len = len(request_url.path) - for cookie in cookies: - name = cookie.key - domain = cookie["domain"] + # Point 2: https://www.rfc-editor.org/rfc/rfc6265.html#section-5.4 + for p in pairs: + for name, cookie in self._cookies[p].items(): + domain = cookie["domain"] - if (domain, name) in self._host_only_cookies and domain != hostname: - continue + if (domain, name) in self._host_only_cookies and domain != hostname: + continue - # Skip edge case when the cookie has a trailing slash but request doesn't. - if len(cookie["path"]) > path_len: - continue + # Skip edge case when the cookie has a trailing slash but request doesn't. + if len(cookie["path"]) > path_len: + continue - if is_not_secure and cookie["secure"]: - continue + if is_not_secure and cookie["secure"]: + continue + + # We already built the Morsel so reuse it here + if name in self._morsel_cache[p]: + filtered[name] = self._morsel_cache[p][name] + continue - # It's critical we use the Morsel so the coded_value - # (based on cookie version) is preserved - mrsl_val = cast("Morsel[str]", cookie.get(cookie.key, Morsel())) - mrsl_val.set(cookie.key, cookie.value, cookie.coded_value) - filtered[name] = mrsl_val + # It's critical we use the Morsel so the coded_value + # (based on cookie version) is preserved + mrsl_val = cast("Morsel[str]", cookie.get(cookie.key, Morsel())) + mrsl_val.set(cookie.key, cookie.value, cookie.coded_value) + self._morsel_cache[p][name] = mrsl_val + filtered[name] = mrsl_val return filtered diff --git a/aiohttp/http_parser.py b/aiohttp/http_parser.py index a79710414a..024e53f4d4 100644 --- a/aiohttp/http_parser.py +++ b/aiohttp/http_parser.py @@ -306,6 +306,7 @@ def feed_data( start_pos = 0 loop = self.loop + should_close = False while start_pos < data_len: # read HTTP message (request/response line + headers), \r\n\r\n # and split by lines @@ -317,6 +318,9 @@ def feed_data( continue if pos >= start_pos: + if should_close: + raise BadHttpMessage("Data after `Connection: close`") + # line found line = data[start_pos:pos] if SEP == b"\n": # For lax response parsing @@ -426,6 +430,7 @@ def get_content_length() -> Optional[int]: payload = EMPTY_PAYLOAD messages.append((msg, payload)) + should_close = msg.should_close else: self._tail = data[start_pos:] data = EMPTY diff --git a/aiohttp/http_writer.py b/aiohttp/http_writer.py index f54fa0f077..dc07a358c7 100644 --- a/aiohttp/http_writer.py +++ b/aiohttp/http_writer.py @@ -71,8 +71,8 @@ def _write(self, chunk: bytes) -> None: size = len(chunk) self.buffer_size += size self.output_size += size - transport = self.transport - if not self._protocol.connected or transport is None or transport.is_closing(): + transport = self._protocol.transport + if transport is None or transport.is_closing(): raise ClientConnectionResetError("Cannot write to closing transport") transport.write(chunk) diff --git a/aiohttp/pytest_plugin.py b/aiohttp/pytest_plugin.py index f5cdc8b186..503bb6a902 100644 --- a/aiohttp/pytest_plugin.py +++ b/aiohttp/pytest_plugin.py @@ -47,7 +47,7 @@ async def __call__( *, server_kwargs: Optional[Dict[str, Any]] = None, **kwargs: Any - ) -> TestClient[Request]: ... + ) -> TestClient[Request, Application]: ... @overload async def __call__( self, @@ -55,7 +55,7 @@ async def __call__( *, server_kwargs: Optional[Dict[str, Any]] = None, **kwargs: Any - ) -> TestClient[_Request]: ... + ) -> TestClient[_Request, None]: ... class AiohttpServer(Protocol): @@ -349,7 +349,7 @@ async def finalize() -> None: @pytest.fixture -def aiohttp_client_cls() -> Type[TestClient[Any]]: +def aiohttp_client_cls() -> Type[TestClient[Any, Any]]: """ Client class to use in ``aiohttp_client`` factory. @@ -377,7 +377,7 @@ def test_login(aiohttp_client): @pytest.fixture def aiohttp_client( - loop: asyncio.AbstractEventLoop, aiohttp_client_cls: Type[TestClient[Any]] + loop: asyncio.AbstractEventLoop, aiohttp_client_cls: Type[TestClient[Any, Any]] ) -> Iterator[AiohttpClient]: """Factory to create a TestClient instance. @@ -393,20 +393,20 @@ async def go( *, server_kwargs: Optional[Dict[str, Any]] = None, **kwargs: Any - ) -> TestClient[Request]: ... + ) -> TestClient[Request, Application]: ... @overload async def go( __param: BaseTestServer[_Request], *, server_kwargs: Optional[Dict[str, Any]] = None, **kwargs: Any - ) -> TestClient[_Request]: ... + ) -> TestClient[_Request, None]: ... async def go( __param: Union[Application, BaseTestServer[Any]], *, server_kwargs: Optional[Dict[str, Any]] = None, **kwargs: Any - ) -> TestClient[Any]: + ) -> TestClient[Any, Any]: if isinstance(__param, Application): server_kwargs = server_kwargs or {} server = TestServer(__param, **server_kwargs) diff --git a/aiohttp/test_utils.py b/aiohttp/test_utils.py index b02b666453..88dcf8ebf1 100644 --- a/aiohttp/test_utils.py +++ b/aiohttp/test_utils.py @@ -23,6 +23,7 @@ TypeVar, Union, cast, + overload, ) from unittest import IsolatedAsyncioTestCase, mock @@ -72,6 +73,7 @@ else: Self = Any +_ApplicationNone = TypeVar("_ApplicationNone", Application, None) _Request = TypeVar("_Request", bound=BaseRequest) REUSE_ADDRESS = os.name == "posix" and sys.platform != "cygwin" @@ -251,7 +253,7 @@ async def _make_runner(self, **kwargs: Any) -> ServerRunner: return ServerRunner(srv, **kwargs) -class TestClient(Generic[_Request]): +class TestClient(Generic[_Request, _ApplicationNone]): """ A test client implementation. @@ -261,7 +263,23 @@ class TestClient(Generic[_Request]): __test__ = False + @overload def __init__( + self: "TestClient[Request, Application]", + server: TestServer, + *, + cookie_jar: Optional[AbstractCookieJar] = None, + **kwargs: Any, + ) -> None: ... + @overload + def __init__( + self: "TestClient[_Request, None]", + server: BaseTestServer[_Request], + *, + cookie_jar: Optional[AbstractCookieJar] = None, + **kwargs: Any, + ) -> None: ... + def __init__( # type: ignore[misc] self, server: BaseTestServer[_Request], *, @@ -300,8 +318,8 @@ def server(self) -> BaseTestServer[_Request]: return self._server @property - def app(self) -> Optional[Application]: - return cast(Optional[Application], getattr(self._server, "app", None)) + def app(self) -> _ApplicationNone: + return getattr(self._server, "app", None) # type: ignore[return-value] @property def session(self) -> ClientSession: @@ -505,7 +523,7 @@ async def get_server(self, app: Application) -> TestServer: """Return a TestServer instance.""" return TestServer(app) - async def get_client(self, server: TestServer) -> TestClient[Request]: + async def get_client(self, server: TestServer) -> TestClient[Request, Application]: """Return a TestClient instance.""" return TestClient(server) diff --git a/aiohttp/web.py b/aiohttp/web.py index 650fe3417d..39b9b6bfde 100644 --- a/aiohttp/web.py +++ b/aiohttp/web.py @@ -528,21 +528,21 @@ def main(argv: List[str]) -> None: arg_parser.add_argument( "-H", "--hostname", - help="TCP/IP hostname to serve on (default: %(default)r)", - default="localhost", + help="TCP/IP hostname to serve on (default: localhost)", + default=None, ) arg_parser.add_argument( "-P", "--port", help="TCP/IP port to serve on (default: %(default)r)", type=int, - default="8080", + default=8080, ) arg_parser.add_argument( "-U", "--path", - help="Unix file system path to serve on. Specifying a path will cause " - "hostname and port arguments to be ignored.", + help="Unix file system path to serve on. Can be combined with hostname " + "to serve on both Unix and TCP.", ) args, extra_argv = arg_parser.parse_known_args(argv) @@ -569,8 +569,14 @@ def main(argv: List[str]) -> None: logging.basicConfig(level=logging.DEBUG) + if args.path and args.hostname is None: + host = port = None + else: + host = args.hostname or "localhost" + port = args.port + app = func(extra_argv) - run_app(app, host=args.hostname, port=args.port, path=args.path) + run_app(app, host=host, port=port, path=args.path) arg_parser.exit(message="Stopped\n") diff --git a/aiohttp/web_app.py b/aiohttp/web_app.py index febe51de2a..96b1b93cac 100644 --- a/aiohttp/web_app.py +++ b/aiohttp/web_app.py @@ -386,22 +386,20 @@ async def _handle(self, request: Request) -> StreamResponse: match_info.add_app(self) match_info.freeze() - resp = None request._match_info = match_info - expect = request.headers.get(hdrs.EXPECT) - if expect: + + if request.headers.get(hdrs.EXPECT): resp = await match_info.expect_handler(request) await request.writer.drain() + if resp is not None: + return resp - if resp is None: - handler = match_info.handler - - if self._run_middlewares: - handler = _build_middlewares(handler, match_info.apps) + handler = match_info.handler - resp = await handler(request) + if self._run_middlewares: + handler = _build_middlewares(handler, match_info.apps) - return resp + return await handler(request) def __call__(self) -> "Application": """gunicorn compatibility""" diff --git a/aiohttp/web_fileresponse.py b/aiohttp/web_fileresponse.py index 8c49af6b78..df6cb7c8b7 100644 --- a/aiohttp/web_fileresponse.py +++ b/aiohttp/web_fileresponse.py @@ -129,10 +129,12 @@ async def _sendfile( return writer @staticmethod - def _strong_etag_match(etag_value: str, etags: Tuple[ETag, ...]) -> bool: + def _etag_match(etag_value: str, etags: Tuple[ETag, ...], *, weak: bool) -> bool: if len(etags) == 1 and etags[0].value == ETAG_ANY: return True - return any(etag.value == etag_value for etag in etags if not etag.is_weak) + return any( + etag.value == etag_value for etag in etags if weak or not etag.is_weak + ) async def _not_modified( self, request: "BaseRequest", etag_value: str, last_modified: float @@ -201,9 +203,11 @@ async def prepare(self, request: "BaseRequest") -> Optional[AbstractStreamWriter etag_value = f"{st.st_mtime_ns:x}-{st.st_size:x}" last_modified = st.st_mtime - # https://tools.ietf.org/html/rfc7232#section-6 + # https://www.rfc-editor.org/rfc/rfc9110#section-13.1.1-2 ifmatch = request.if_match - if ifmatch is not None and not self._strong_etag_match(etag_value, ifmatch): + if ifmatch is not None and not self._etag_match( + etag_value, ifmatch, weak=False + ): return await self._precondition_failed(request) unmodsince = request.if_unmodified_since @@ -214,8 +218,11 @@ async def prepare(self, request: "BaseRequest") -> Optional[AbstractStreamWriter ): return await self._precondition_failed(request) + # https://www.rfc-editor.org/rfc/rfc9110#section-13.1.2-2 ifnonematch = request.if_none_match - if ifnonematch is not None and self._strong_etag_match(etag_value, ifnonematch): + if ifnonematch is not None and self._etag_match( + etag_value, ifnonematch, weak=True + ): return await self._not_modified(request, etag_value, last_modified) modsince = request.if_modified_since diff --git a/aiohttp/web_middlewares.py b/aiohttp/web_middlewares.py index 922dee3f7a..22e63f872c 100644 --- a/aiohttp/web_middlewares.py +++ b/aiohttp/web_middlewares.py @@ -115,7 +115,12 @@ async def impl(request: Request, handler: Handler) -> StreamResponse: def _fix_request_current_app(app: "Application") -> Middleware: async def impl(request: Request, handler: Handler) -> StreamResponse: - with request.match_info.set_current_app(app): + match_info = request.match_info + prev = match_info.current_app + match_info.current_app = app + try: return await handler(request) + finally: + match_info.current_app = prev return impl diff --git a/aiohttp/web_protocol.py b/aiohttp/web_protocol.py index 15b5a82974..876985fb0e 100644 --- a/aiohttp/web_protocol.py +++ b/aiohttp/web_protocol.py @@ -28,7 +28,7 @@ from .abc import AbstractAccessLogger, AbstractAsyncAccessLogger, AbstractStreamWriter from .base_protocol import BaseProtocol -from .helpers import ceil_timeout, set_exception +from .helpers import ceil_timeout from .http import ( HttpProcessingError, HttpRequestParser, @@ -90,6 +90,9 @@ class PayloadAccessError(Exception): """Payload was accessed after response was sent.""" +_PAYLOAD_ACCESS_ERROR = PayloadAccessError() + + class AccessLoggerWrapper(AbstractAsyncAccessLogger): """Wrap an AbstractAccessLogger so it behaves like an AbstractAsyncAccessLogger.""" @@ -617,7 +620,8 @@ async def start(self) -> None: self.log_debug("Uncompleted request.") self.close() - set_exception(payload, PayloadAccessError()) + payload.set_exception(_PAYLOAD_ACCESS_ERROR) + except asyncio.CancelledError: self.log_debug("Ignored premature client disconnection") raise diff --git a/aiohttp/web_urldispatcher.py b/aiohttp/web_urldispatcher.py index c0ef3ac796..ce6f17e736 100644 --- a/aiohttp/web_urldispatcher.py +++ b/aiohttp/web_urldispatcher.py @@ -8,7 +8,6 @@ import os import re import sys -from contextlib import contextmanager from pathlib import Path from types import MappingProxyType from typing import ( @@ -271,8 +270,8 @@ def current_app(self) -> "Application": assert app is not None return app - @contextmanager - def set_current_app(self, app: "Application") -> Generator[None, None, None]: + @current_app.setter + def current_app(self, app: "Application") -> None: if DEBUG: # pragma: no cover if app not in self._apps: raise RuntimeError( @@ -280,12 +279,7 @@ def set_current_app(self, app: "Application") -> Generator[None, None, None]: self._apps, app ) ) - prev = self._current_app self._current_app = app - try: - yield - finally: - self._current_app = prev def freeze(self) -> None: self._frozen = True diff --git a/docs/web_quickstart.rst b/docs/web_quickstart.rst index a47ec771da..5c565dfc5a 100644 --- a/docs/web_quickstart.rst +++ b/docs/web_quickstart.rst @@ -90,6 +90,10 @@ accepts a list of any non-parsed command-line arguments and returns an return app +.. note:: + For local development we typically recommend using + `aiohttp-devtools `_. + .. _aiohttp-web-handler: Handler diff --git a/docs/web_reference.rst b/docs/web_reference.rst index dea27c362d..72305130fd 100644 --- a/docs/web_reference.rst +++ b/docs/web_reference.rst @@ -1740,7 +1740,10 @@ Application and Router Use :meth:`add_static` for development only. In production, static content should be processed by web servers like *nginx* - or *apache*. + or *apache*. Such web servers will be able to provide significantly + better performance and security for static assets. Several past security + vulnerabilities in aiohttp only affected applications using + :meth:`add_static`. :param str prefix: URL path prefix for handled static files diff --git a/tests/test_client_functional.py b/tests/test_client_functional.py index 437936e97b..cdd573e8f3 100644 --- a/tests/test_client_functional.py +++ b/tests/test_client_functional.py @@ -3795,7 +3795,7 @@ async def handler(request: web.Request) -> web.Response: await resp.write(b"1" * 1000) await asyncio.sleep(0.01) - async def request(client: TestClient[web.Request]) -> None: + async def request(client: TestClient[web.Request, web.Application]) -> None: timeout = aiohttp.ClientTimeout(total=0.5) async with await client.get("/", timeout=timeout) as resp: with pytest.raises(asyncio.TimeoutError): diff --git a/tests/test_client_request.py b/tests/test_client_request.py index 4ba08e96c8..433f174950 100644 --- a/tests/test_client_request.py +++ b/tests/test_client_request.py @@ -959,7 +959,7 @@ async def test_precompressed_data_stays_intact(loop: asyncio.AbstractEventLoop) URL("http://python.org/"), data=data, headers={"CONTENT-ENCODING": "deflate"}, - compress=None, + compress=False, loop=loop, ) assert not req.compress diff --git a/tests/test_http_parser.py b/tests/test_http_parser.py index 23ad9bf775..96a3f11d1c 100644 --- a/tests/test_http_parser.py +++ b/tests/test_http_parser.py @@ -851,6 +851,14 @@ def test_http_request_bad_status_line_whitespace(parser: HttpRequestParser) -> N parser.feed_data(text) +def test_http_request_message_after_close(parser: HttpRequestParser) -> None: + text = b"GET / HTTP/1.1\r\nConnection: close\r\n\r\nInvalid\r\n\r\n" + with pytest.raises( + http_exceptions.BadHttpMessage, match="Data after `Connection: close`" + ): + parser.feed_data(text) + + def test_http_request_upgrade(parser: HttpRequestParser) -> None: text = ( b"GET /test HTTP/1.1\r\n" diff --git a/tests/test_test_utils.py b/tests/test_test_utils.py index 92cc1a61ed..718b95b651 100644 --- a/tests/test_test_utils.py +++ b/tests/test_test_utils.py @@ -1,6 +1,7 @@ import asyncio import gzip import socket +import sys from typing import Callable, Iterator, Mapping, NoReturn from unittest import mock @@ -20,7 +21,10 @@ make_mocked_request, ) -_TestClient = TestClient[web.Request] +if sys.version_info >= (3, 11): + from typing import assert_type + +_TestClient = TestClient[web.Request, web.Application] _hello_world_str = "Hello, world" _hello_world_bytes = _hello_world_str.encode("utf-8") @@ -71,7 +75,7 @@ def app() -> web.Application: def test_client( loop: asyncio.AbstractEventLoop, app: web.Application ) -> Iterator[_TestClient]: - async def make_client() -> TestClient[web.Request]: + async def make_client() -> TestClient[web.Request, web.Application]: return TestClient(TestServer(app)) client = loop.run_until_complete(make_client()) @@ -239,6 +243,8 @@ async def test_test_client_props() -> None: async with client: assert isinstance(client.port, int) assert client.server is not None + if sys.version_info >= (3, 11): + assert_type(client.app, web.Application) assert client.app is not None assert client.port == 0 @@ -255,6 +261,8 @@ async def hello(request: web.BaseRequest) -> NoReturn: async with client: assert isinstance(client.port, int) assert client.server is not None + if sys.version_info >= (3, 11): + assert_type(client.app, None) assert client.app is None assert client.port == 0 @@ -271,7 +279,7 @@ async def test_test_server_context_manager(loop: asyncio.AbstractEventLoop) -> N def test_client_unsupported_arg() -> None: with pytest.raises(TypeError) as e: - TestClient("string") # type: ignore[arg-type] + TestClient("string") # type: ignore[call-overload] assert ( str(e.value) == "server must be TestServer instance, found type: " diff --git a/tests/test_web_cli.py b/tests/test_web_cli.py index 0f87cb8c65..7728245d64 100644 --- a/tests/test_web_cli.py +++ b/tests/test_web_cli.py @@ -1,4 +1,6 @@ +import sys from typing import Any +from unittest import mock import pytest @@ -82,6 +84,32 @@ def test_entry_func_non_existent_attribute(mocker: Any) -> None: ) +@pytest.mark.skipif(sys.platform.startswith("win32"), reason="Windows not Unix") +def test_path_no_host(mocker: Any, monkeypatch: Any) -> None: + argv = "--path=test_path.sock alpha.beta:func".split() + mocker.patch("aiohttp.web.import_module") + + run_app = mocker.patch("aiohttp.web.run_app") + with pytest.raises(SystemExit): + web.main(argv) + + run_app.assert_called_with(mock.ANY, path="test_path.sock", host=None, port=None) + + +@pytest.mark.skipif(sys.platform.startswith("win32"), reason="Windows not Unix") +def test_path_and_host(mocker: Any, monkeypatch: Any) -> None: + argv = "--path=test_path.sock --host=localhost --port=8000 alpha.beta:func".split() + mocker.patch("aiohttp.web.import_module") + + run_app = mocker.patch("aiohttp.web.run_app") + with pytest.raises(SystemExit): + web.main(argv) + + run_app.assert_called_with( + mock.ANY, path="test_path.sock", host="localhost", port=8000 + ) + + def test_path_when_unsupported(mocker: Any, monkeypatch: Any) -> None: argv = "--path=test_path.sock alpha.beta:func".split() mocker.patch("aiohttp.web.import_module") diff --git a/tests/test_web_middleware.py b/tests/test_web_middleware.py index 3159a9d6f3..7a7b2c8e91 100644 --- a/tests/test_web_middleware.py +++ b/tests/test_web_middleware.py @@ -9,7 +9,9 @@ from aiohttp.test_utils import TestClient from aiohttp.typedefs import Handler, Middleware -CLI = Callable[[Iterable[Middleware]], Awaitable[TestClient[web.Request]]] +CLI = Callable[ + [Iterable[Middleware]], Awaitable[TestClient[web.Request, web.Application]] +] async def test_middleware_modifies_response( @@ -169,7 +171,7 @@ async def handler(request: web.Request) -> web.Response: def wrapper( extra_middlewares: Iterable[Middleware], - ) -> Awaitable[TestClient[web.Request]]: + ) -> Awaitable[TestClient[web.Request, web.Application]]: app = web.Application() app.router.add_route("GET", "/resource1", handler) app.router.add_route("GET", "/resource2/", handler) diff --git a/tests/test_web_sendfile_functional.py b/tests/test_web_sendfile_functional.py index 7f8ae587ba..f55d329c36 100644 --- a/tests/test_web_sendfile_functional.py +++ b/tests/test_web_sendfile_functional.py @@ -505,10 +505,9 @@ async def test_static_file_if_none_match( resp = await client.get("/") assert 200 == resp.status - original_etag = resp.headers.get("ETag") + original_etag = resp.headers["ETag"] assert resp.headers.get("Last-Modified") is not None - assert original_etag is not None resp.close() resp.release() @@ -547,6 +546,39 @@ async def test_static_file_if_none_match_star( await client.close() +@pytest.mark.parametrize("if_modified_since", ("", "Fri, 31 Dec 9999 23:59:59 GMT")) +async def test_static_file_if_none_match_weak( + aiohttp_client: Any, + app_with_static_route: web.Application, + if_modified_since: str, +) -> None: + client = await aiohttp_client(app_with_static_route) + + resp = await client.get("/") + assert 200 == resp.status + original_etag = resp.headers["ETag"] + + assert resp.headers.get("Last-Modified") is not None + resp.close() + resp.release() + + weak_etag = f"W/{original_etag}" + + resp = await client.get( + "/", + headers={"If-None-Match": weak_etag, "If-Modified-Since": if_modified_since}, + ) + body = await resp.read() + assert 304 == resp.status + assert resp.headers.get("Content-Length") is None + assert resp.headers.get("ETag") == original_etag + assert b"" == body + resp.close() + resp.release() + + await client.close() + + @pytest.mark.skipif(not ssl, reason="ssl not supported") async def test_static_file_ssl( aiohttp_server: Any,