diff --git a/.coveragerc b/.coveragerc index 0b5d5bf0ad4..7792266b114 100644 --- a/.coveragerc +++ b/.coveragerc @@ -6,3 +6,6 @@ omit = site-packages [report] exclude_also = if TYPE_CHECKING + assert False + : \.\.\.(\s*#.*)?$ + ^ +\.\.\.$ diff --git a/.github/workflows/ci-cd.yml b/.github/workflows/ci-cd.yml index f072a12aa34..93d4575da2d 100644 --- a/.github/workflows/ci-cd.yml +++ b/.github/workflows/ci-cd.yml @@ -45,7 +45,7 @@ jobs: - name: Setup Python uses: actions/setup-python@v5 with: - python-version: 3.9 + python-version: 3.11 - name: Cache PyPI uses: actions/cache@v4.0.2 with: diff --git a/CHANGES/8634.misc.rst b/CHANGES/8634.misc.rst new file mode 100644 index 00000000000..cf4c68d5119 --- /dev/null +++ b/CHANGES/8634.misc.rst @@ -0,0 +1 @@ +Minor improvements to various type annotations -- by :user:`Dreamsorcerer`. diff --git a/aiohttp/client.py b/aiohttp/client.py index c70ad65c59e..1d4ccc0814a 100644 --- a/aiohttp/client.py +++ b/aiohttp/client.py @@ -9,7 +9,7 @@ import traceback import warnings from contextlib import suppress -from types import SimpleNamespace, TracebackType +from types import TracebackType from typing import ( TYPE_CHECKING, Any, @@ -155,7 +155,7 @@ class _RequestOptions(TypedDict, total=False): - params: Union[Mapping[str, str], None] + params: Union[Mapping[str, Union[str, int]], str, None] data: Any json: Any cookies: Union[LooseCookies, None] @@ -175,7 +175,7 @@ class _RequestOptions(TypedDict, total=False): ssl: Union[SSLContext, bool, Fingerprint] server_hostname: Union[str, None] proxy_headers: Union[LooseHeaders, None] - trace_request_ctx: Union[SimpleNamespace, None] + trace_request_ctx: Union[Mapping[str, str], None] read_bufsize: Union[int, None] auto_decompress: Union[bool, None] max_line_size: Union[int, None] @@ -422,11 +422,22 @@ def __del__(self, _warnings: Any = warnings) -> None: context["source_traceback"] = self._source_traceback self._loop.call_exception_handler(context) - def request( - self, method: str, url: StrOrURL, **kwargs: Any - ) -> "_RequestContextManager": - """Perform HTTP request.""" - return _RequestContextManager(self._request(method, url, **kwargs)) + if sys.version_info >= (3, 11) and TYPE_CHECKING: + + def request( + self, + method: str, + url: StrOrURL, + **kwargs: Unpack[_RequestOptions], + ) -> "_RequestContextManager": ... + + else: + + def request( + self, method: str, url: StrOrURL, **kwargs: Any + ) -> "_RequestContextManager": + """Perform HTTP request.""" + return _RequestContextManager(self._request(method, url, **kwargs)) def _build_url(self, str_or_url: StrOrURL) -> URL: url = URL(str_or_url) @@ -466,7 +477,7 @@ async def _request( ssl: Union[SSLContext, bool, Fingerprint] = True, server_hostname: Optional[str] = None, proxy_headers: Optional[LooseHeaders] = None, - trace_request_ctx: Optional[SimpleNamespace] = None, + trace_request_ctx: Optional[Mapping[str, str]] = None, read_bufsize: Optional[int] = None, auto_decompress: Optional[bool] = None, max_line_size: Optional[int] = None, diff --git a/aiohttp/client_exceptions.py b/aiohttp/client_exceptions.py index f15a9ee3d3e..ff29b3d3ca9 100644 --- a/aiohttp/client_exceptions.py +++ b/aiohttp/client_exceptions.py @@ -100,7 +100,7 @@ def __str__(self) -> str: return "{}, message={!r}, url={!r}".format( self.status, self.message, - self.request_info.real_url, + str(self.request_info.real_url), ) def __repr__(self) -> str: diff --git a/aiohttp/client_reqrep.py b/aiohttp/client_reqrep.py index 37d14e107fd..2c10da4ff81 100644 --- a/aiohttp/client_reqrep.py +++ b/aiohttp/client_reqrep.py @@ -245,7 +245,8 @@ class ClientRequest: hdrs.ACCEPT_ENCODING: _gen_default_accept_encoding(), } - body = b"" + # Type of body depends on PAYLOAD_REGISTRY, which is dynamic. + body: Any = b"" auth = None response = None @@ -441,7 +442,7 @@ def update_headers(self, headers: Optional[LooseHeaders]) -> None: if headers: if isinstance(headers, (dict, MultiDictProxy, MultiDict)): - headers = headers.items() # type: ignore[assignment] + headers = headers.items() for key, value in headers: # type: ignore[misc] # A special case for Host header @@ -597,6 +598,10 @@ def update_proxy( raise ValueError("proxy_auth must be None or BasicAuth() tuple") self.proxy = proxy self.proxy_auth = proxy_auth + if proxy_headers is not None and not isinstance( + proxy_headers, (MultiDict, MultiDictProxy) + ): + proxy_headers = CIMultiDict(proxy_headers) self.proxy_headers = proxy_headers def keep_alive(self) -> bool: @@ -632,10 +637,10 @@ async def write_bytes( await self.body.write(writer) else: if isinstance(self.body, (bytes, bytearray)): - self.body = (self.body,) # type: ignore[assignment] + self.body = (self.body,) for chunk in self.body: - await writer.write(chunk) # type: ignore[arg-type] + await writer.write(chunk) except OSError as underlying_exc: reraised_exc = underlying_exc diff --git a/aiohttp/connector.py b/aiohttp/connector.py index cd89ea641d3..2e07395aece 100644 --- a/aiohttp/connector.py +++ b/aiohttp/connector.py @@ -23,6 +23,7 @@ List, Literal, Optional, + Sequence, Set, Tuple, Type, @@ -833,7 +834,7 @@ def clear_dns_cache( self._cached_hosts.clear() async def _resolve_host( - self, host: str, port: int, traces: Optional[List["Trace"]] = None + self, host: str, port: int, traces: Optional[Sequence["Trace"]] = None ) -> List[ResolveResult]: """Resolve host and return list of addresses.""" if is_ip_address(host): @@ -902,7 +903,7 @@ async def _resolve_host_with_throttle( key: Tuple[str, int], host: str, port: int, - traces: Optional[List["Trace"]], + traces: Optional[Sequence["Trace"]], ) -> List[ResolveResult]: """Resolve host with a dns events throttle.""" if key in self._throttle_dns_events: diff --git a/aiohttp/pytest_plugin.py b/aiohttp/pytest_plugin.py index 6225fdf2be0..c862b409566 100644 --- a/aiohttp/pytest_plugin.py +++ b/aiohttp/pytest_plugin.py @@ -2,7 +2,17 @@ import contextlib import inspect import warnings -from typing import Any, Awaitable, Callable, Dict, Iterator, Optional, Type, Union +from typing import ( + Any, + Awaitable, + Callable, + Dict, + Iterator, + Optional, + Protocol, + Type, + Union, +) import pytest @@ -24,9 +34,23 @@ except ImportError: # pragma: no cover uvloop = None # type: ignore[assignment] -AiohttpClient = Callable[[Union[Application, BaseTestServer]], Awaitable[TestClient]] AiohttpRawServer = Callable[[Application], Awaitable[RawTestServer]] -AiohttpServer = Callable[[Application], Awaitable[TestServer]] + + +class AiohttpClient(Protocol): + def __call__( + self, + __param: Union[Application, BaseTestServer], + *, + server_kwargs: Optional[Dict[str, Any]] = None, + **kwargs: Any + ) -> Awaitable[TestClient]: ... + + +class AiohttpServer(Protocol): + def __call__( + self, app: Application, *, port: Optional[int] = None, **kwargs: Any + ) -> Awaitable[TestServer]: ... def pytest_addoption(parser): # type: ignore[no-untyped-def] @@ -262,7 +286,9 @@ def aiohttp_server(loop: asyncio.AbstractEventLoop) -> Iterator[AiohttpServer]: """ servers = [] - async def go(app, *, port=None, **kwargs): # type: ignore[no-untyped-def] + async def go( + app: Application, *, port: Optional[int] = None, **kwargs: Any + ) -> TestServer: server = TestServer(app, port=port) await server.start_server(loop=loop, **kwargs) servers.append(server) diff --git a/aiohttp/test_utils.py b/aiohttp/test_utils.py index a36e8599689..97c1469dd2a 100644 --- a/aiohttp/test_utils.py +++ b/aiohttp/test_utils.py @@ -11,17 +11,7 @@ import warnings from abc import ABC, abstractmethod from types import TracebackType -from typing import ( - TYPE_CHECKING, - Any, - Callable, - Iterator, - List, - Optional, - Type, - Union, - cast, -) +from typing import TYPE_CHECKING, Any, Callable, Iterator, List, Optional, Type, cast from unittest import IsolatedAsyncioTestCase, mock from aiosignal import Signal @@ -29,7 +19,11 @@ from yarl import URL import aiohttp -from aiohttp.client import _RequestContextManager, _WSRequestContextManager +from aiohttp.client import ( + _RequestContextManager, + _RequestOptions, + _WSRequestContextManager, +) from . import ClientSession, hdrs from .abc import AbstractCookieJar @@ -55,6 +49,9 @@ else: SSLContext = None +if sys.version_info >= (3, 11) and TYPE_CHECKING: + from typing import Unpack + REUSE_ADDRESS = os.name == "posix" and sys.platform != "cygwin" @@ -90,7 +87,7 @@ class BaseTestServer(ABC): def __init__( self, *, - scheme: Union[str, object] = sentinel, + scheme: str = "", loop: Optional[asyncio.AbstractEventLoop] = None, host: str = "127.0.0.1", port: Optional[int] = None, @@ -135,12 +132,8 @@ async def start_server( sockets = server.sockets # type: ignore[attr-defined] assert sockets is not None self.port = sockets[0].getsockname()[1] - if self.scheme is sentinel: - if self._ssl: - scheme = "https" - else: - scheme = "http" - self.scheme = scheme + if not self.scheme: + self.scheme = "https" if self._ssl else "http" self._root = URL(f"{self.scheme}://{self.host}:{self.port}") @abstractmethod # pragma: no cover @@ -222,7 +215,7 @@ def __init__( self, app: Application, *, - scheme: Union[str, object] = sentinel, + scheme: str = "", host: str = "127.0.0.1", port: Optional[int] = None, **kwargs: Any, @@ -239,7 +232,7 @@ def __init__( self, handler: _RequestHandler, *, - scheme: Union[str, object] = sentinel, + scheme: str = "", host: str = "127.0.0.1", port: Optional[int] = None, **kwargs: Any, @@ -324,45 +317,101 @@ async def _request( self._responses.append(resp) return resp - def request( - self, method: str, path: StrOrURL, **kwargs: Any - ) -> _RequestContextManager: - """Routes a request to tested http server. + if sys.version_info >= (3, 11) and TYPE_CHECKING: + + def request( + self, method: str, path: StrOrURL, **kwargs: Unpack[_RequestOptions] + ) -> _RequestContextManager: ... + + def get( + self, + path: StrOrURL, + **kwargs: Unpack[_RequestOptions], + ) -> _RequestContextManager: ... + + def options( + self, + path: StrOrURL, + **kwargs: Unpack[_RequestOptions], + ) -> _RequestContextManager: ... + + def head( + self, + path: StrOrURL, + **kwargs: Unpack[_RequestOptions], + ) -> _RequestContextManager: ... + + def post( + self, + path: StrOrURL, + **kwargs: Unpack[_RequestOptions], + ) -> _RequestContextManager: ... + + def put( + self, + path: StrOrURL, + **kwargs: Unpack[_RequestOptions], + ) -> _RequestContextManager: ... + + def patch( + self, + path: StrOrURL, + **kwargs: Unpack[_RequestOptions], + ) -> _RequestContextManager: ... + + def delete( + self, + path: StrOrURL, + **kwargs: Unpack[_RequestOptions], + ) -> _RequestContextManager: ... - The interface is identical to aiohttp.ClientSession.request, - except the loop kwarg is overridden by the instance used by the - test server. + else: - """ - return _RequestContextManager(self._request(method, path, **kwargs)) + def request( + self, method: str, path: StrOrURL, **kwargs: Any + ) -> _RequestContextManager: + """Routes a request to tested http server. - def get(self, path: StrOrURL, **kwargs: Any) -> _RequestContextManager: - """Perform an HTTP GET request.""" - return _RequestContextManager(self._request(hdrs.METH_GET, path, **kwargs)) + The interface is identical to aiohttp.ClientSession.request, + except the loop kwarg is overridden by the instance used by the + test server. - def post(self, path: StrOrURL, **kwargs: Any) -> _RequestContextManager: - """Perform an HTTP POST request.""" - return _RequestContextManager(self._request(hdrs.METH_POST, path, **kwargs)) + """ + return _RequestContextManager(self._request(method, path, **kwargs)) - def options(self, path: StrOrURL, **kwargs: Any) -> _RequestContextManager: - """Perform an HTTP OPTIONS request.""" - return _RequestContextManager(self._request(hdrs.METH_OPTIONS, path, **kwargs)) + def get(self, path: StrOrURL, **kwargs: Any) -> _RequestContextManager: + """Perform an HTTP GET request.""" + return _RequestContextManager(self._request(hdrs.METH_GET, path, **kwargs)) - def head(self, path: StrOrURL, **kwargs: Any) -> _RequestContextManager: - """Perform an HTTP HEAD request.""" - return _RequestContextManager(self._request(hdrs.METH_HEAD, path, **kwargs)) + def post(self, path: StrOrURL, **kwargs: Any) -> _RequestContextManager: + """Perform an HTTP POST request.""" + return _RequestContextManager(self._request(hdrs.METH_POST, path, **kwargs)) - def put(self, path: StrOrURL, **kwargs: Any) -> _RequestContextManager: - """Perform an HTTP PUT request.""" - return _RequestContextManager(self._request(hdrs.METH_PUT, path, **kwargs)) + def options(self, path: StrOrURL, **kwargs: Any) -> _RequestContextManager: + """Perform an HTTP OPTIONS request.""" + return _RequestContextManager( + self._request(hdrs.METH_OPTIONS, path, **kwargs) + ) + + def head(self, path: StrOrURL, **kwargs: Any) -> _RequestContextManager: + """Perform an HTTP HEAD request.""" + return _RequestContextManager(self._request(hdrs.METH_HEAD, path, **kwargs)) - def patch(self, path: StrOrURL, **kwargs: Any) -> _RequestContextManager: - """Perform an HTTP PATCH request.""" - return _RequestContextManager(self._request(hdrs.METH_PATCH, path, **kwargs)) + def put(self, path: StrOrURL, **kwargs: Any) -> _RequestContextManager: + """Perform an HTTP PUT request.""" + return _RequestContextManager(self._request(hdrs.METH_PUT, path, **kwargs)) - def delete(self, path: StrOrURL, **kwargs: Any) -> _RequestContextManager: - """Perform an HTTP PATCH request.""" - return _RequestContextManager(self._request(hdrs.METH_DELETE, path, **kwargs)) + def patch(self, path: StrOrURL, **kwargs: Any) -> _RequestContextManager: + """Perform an HTTP PATCH request.""" + return _RequestContextManager( + self._request(hdrs.METH_PATCH, path, **kwargs) + ) + + def delete(self, path: StrOrURL, **kwargs: Any) -> _RequestContextManager: + """Perform an HTTP PATCH request.""" + return _RequestContextManager( + self._request(hdrs.METH_DELETE, path, **kwargs) + ) def ws_connect(self, path: StrOrURL, **kwargs: Any) -> _WSRequestContextManager: """Initiate websocket connection. diff --git a/aiohttp/tracing.py b/aiohttp/tracing.py index 66007cbeb2c..012ed7bdaf6 100644 --- a/aiohttp/tracing.py +++ b/aiohttp/tracing.py @@ -1,5 +1,5 @@ from types import SimpleNamespace -from typing import TYPE_CHECKING, Awaitable, Optional, Protocol, Type, TypeVar +from typing import TYPE_CHECKING, Awaitable, Mapping, Optional, Protocol, Type, TypeVar import attr from aiosignal import Signal @@ -101,7 +101,7 @@ def __init__( self._trace_config_ctx_factory = trace_config_ctx_factory def trace_config_ctx( - self, trace_request_ctx: Optional[SimpleNamespace] = None + self, trace_request_ctx: Optional[Mapping[str, str]] = None ) -> SimpleNamespace: """Return a new trace_config_ctx instance""" return self._trace_config_ctx_factory(trace_request_ctx=trace_request_ctx) diff --git a/aiohttp/typedefs.py b/aiohttp/typedefs.py index 80dd26e80bd..9fb21c15f83 100644 --- a/aiohttp/typedefs.py +++ b/aiohttp/typedefs.py @@ -35,7 +35,13 @@ Byteish = Union[bytes, bytearray, memoryview] JSONEncoder = Callable[[Any], str] JSONDecoder = Callable[[str], Any] -LooseHeaders = Union[Mapping[Union[str, istr], str], _CIMultiDict, _CIMultiDictProxy] +LooseHeaders = Union[ + Mapping[str, str], + Mapping[istr, str], + _CIMultiDict, + _CIMultiDictProxy, + Iterable[Tuple[Union[str, istr], str]], +] RawHeaders = Tuple[Tuple[bytes, bytes], ...] StrOrURL = Union[str, URL] diff --git a/aiohttp/web_request.py b/aiohttp/web_request.py index d059a166884..28d9ef3d10b 100644 --- a/aiohttp/web_request.py +++ b/aiohttp/web_request.py @@ -239,7 +239,8 @@ def clone( # a copy semantic dct["headers"] = CIMultiDictProxy(CIMultiDict(headers)) dct["raw_headers"] = tuple( - (k.encode("utf-8"), v.encode("utf-8")) for k, v in headers.items() + (k.encode("utf-8"), v.encode("utf-8")) + for k, v in dct["headers"].items() ) message = self._message._replace(**dct) diff --git a/requirements/lint.in b/requirements/lint.in index 98910e21f0e..0d46809a083 100644 --- a/requirements/lint.in +++ b/requirements/lint.in @@ -1,8 +1,11 @@ aiodns aioredis +freezegun mypy; implementation_name == "cpython" pre-commit pytest +pytest-mock python-on-whales slotscheck +trustme uvloop; platform_system != "Windows" diff --git a/requirements/lint.txt b/requirements/lint.txt index 85b96964c05..97809fe3dde 100644 --- a/requirements/lint.txt +++ b/requirements/lint.txt @@ -30,6 +30,8 @@ exceptiongroup==1.1.2 # via pytest filelock==3.12.2 # via virtualenv +freezegun==1.5.1 + # via -r requirements/lint.in identify==2.5.26 # via pre-commit idna==3.7 @@ -66,6 +68,8 @@ pygments==2.17.2 # via rich pytest==8.3.2 # via -r requirements/lint.in +pytest-mock==3.14.0 + # via -r requirements/lint.in python-on-whales==0.72.0 # via -r requirements/lint.in pyyaml==6.0.1 @@ -85,6 +89,8 @@ tomli==2.0.1 # slotscheck tqdm==4.66.2 # via python-on-whales +trustme==1.1.0 + # via -r requirements/lint.in typer==0.12.3 # via python-on-whales typing-extensions==4.11.0