Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Use ParamSpec in a few places #12667

Merged
merged 13 commits into from
May 9, 2022
1 change: 1 addition & 0 deletions changelog.d/12667.misc
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Use `ParamSpec` to refine type hints.
2 changes: 1 addition & 1 deletion poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

4 changes: 3 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -143,7 +143,9 @@ netaddr = ">=0.7.18"
Jinja2 = ">=3.0"
bleach = ">=1.4.3"
# We use `ParamSpec` and `Concatenate`, which were added in `typing-extensions` 3.10.0.0.
typing-extensions = ">=3.10.0"
# Additionally we need https://github.com/python/typing/pull/817 to allow types to be
# generic over ParamSpecs.
typing-extensions = ">=3.10.0.1"
# We enforce that we have a `cryptography` version that bundles an `openssl`
# with the latest security patches.
cryptography = ">=3.4.7"
Expand Down
14 changes: 10 additions & 4 deletions synapse/app/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@

from cryptography.utils import CryptographyDeprecationWarning
from matrix_common.versionstring import get_distribution_version_string
from typing_extensions import ParamSpec

import twisted
from twisted.internet import defer, error, reactor as _reactor
Expand Down Expand Up @@ -81,19 +82,22 @@

# list of tuples of function, args list, kwargs dict
_sighup_callbacks: List[
Tuple[Callable[..., None], Tuple[Any, ...], Dict[str, Any]]
Tuple[Callable[..., None], Tuple[object, ...], Dict[str, object]]
] = []
P = ParamSpec("P")


def register_sighup(func: Callable[..., None], *args: Any, **kwargs: Any) -> None:
def register_sighup(func: Callable[P, None], *args: P.args, **kwargs: P.kwargs) -> None:
"""
Register a function to be called when a SIGHUP occurs.

Args:
func: Function to be called when sent a SIGHUP signal.
*args, **kwargs: args and kwargs to be passed to the target function.
"""
_sighup_callbacks.append((func, args, kwargs))
# This type-ignore should be redundant once we use a mypy release with
# https://github.com/python/mypy/pull/12668.
_sighup_callbacks.append((func, args, kwargs)) # type: ignore[arg-type]


def start_worker_reactor(
Expand Down Expand Up @@ -214,7 +218,9 @@ def redirect_stdio_to_logs() -> None:
print("Redirected stdout/stderr to logs")


def register_start(cb: Callable[..., Awaitable], *args: Any, **kwargs: Any) -> None:
def register_start(
cb: Callable[P, Awaitable], *args: P.args, **kwargs: P.kwargs
) -> None:
"""Register a callback with the reactor, to be called once it is running

This can be used to initialise parts of the system which require an asynchronous
Expand Down
15 changes: 12 additions & 3 deletions synapse/events/presence_router.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,12 @@
List,
Optional,
Set,
TypeVar,
Union,
)

from typing_extensions import ParamSpec

from synapse.api.presence import UserPresenceState
from synapse.util.async_helpers import maybe_awaitable

Expand All @@ -40,6 +43,10 @@
logger = logging.getLogger(__name__)


P = ParamSpec("P")
R = TypeVar("R")


def load_legacy_presence_router(hs: "HomeServer") -> None:
"""Wrapper that loads a presence router module configured using the old
configuration, and registers the hooks they implement.
Expand All @@ -63,13 +70,15 @@ def load_legacy_presence_router(hs: "HomeServer") -> None:

# All methods that the module provides should be async, but this wasn't enforced
# in the old module system, so we wrap them if needed
def async_wrapper(f: Optional[Callable]) -> Optional[Callable[..., Awaitable]]:
def async_wrapper(
f: Optional[Callable[P, R]]
) -> Optional[Callable[P, Awaitable[R]]]:
# f might be None if the callback isn't implemented by the module. In this
# case we don't want to register a callback at all so we return None.
if f is None:
return None

def run(*args: Any, **kwargs: Any) -> Awaitable:
def run(*args: P.args, **kwargs: P.kwargs) -> Awaitable[R]:
# Assertion required because mypy can't prove we won't change `f`
# back to `None`. See
# https://mypy.readthedocs.io/en/latest/common_issues.html#narrowing-and-inner-functions
Expand All @@ -80,7 +89,7 @@ def run(*args: Any, **kwargs: Any) -> Awaitable:
return run

# Register the hooks through the module API.
hooks = {
hooks: Dict[str, Optional[Callable[..., Any]]] = {
hook: async_wrapper(getattr(presence_router, hook, None))
for hook in presence_router_methods
}
Expand Down
17 changes: 10 additions & 7 deletions synapse/module_api/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@

import attr
import jinja2
from typing_extensions import ParamSpec

from twisted.internet import defer
from twisted.web.resource import Resource
Expand Down Expand Up @@ -129,6 +130,7 @@


T = TypeVar("T")
P = ParamSpec("P")

"""
This package defines the 'stable' API which can be used by extension modules which
Expand Down Expand Up @@ -799,9 +801,9 @@ def invalidate_access_token(
def run_db_interaction(
self,
desc: str,
func: Callable[..., T],
*args: Any,
**kwargs: Any,
func: Callable[P, T],
*args: P.args,
**kwargs: P.kwargs,
) -> "defer.Deferred[T]":
"""Run a function with a database connection

Expand All @@ -817,8 +819,9 @@ def run_db_interaction(
Returns:
Deferred[object]: result of func
"""
# type-ignore: See https://github.com/python/mypy/issues/8862
return defer.ensureDeferred(
self._store.db_pool.runInteraction(desc, func, *args, **kwargs)
self._store.db_pool.runInteraction(desc, func, *args, **kwargs) # type: ignore[arg-type]
)

def complete_sso_login(
Expand Down Expand Up @@ -1296,9 +1299,9 @@ async def get_room_state(

async def defer_to_thread(
self,
f: Callable[..., T],
*args: Any,
**kwargs: Any,
f: Callable[P, T],
*args: P.args,
**kwargs: P.kwargs,
) -> T:
"""Runs the given function in a separate thread from Synapse's thread pool.

Expand Down
4 changes: 1 addition & 3 deletions synapse/rest/client/knock.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,6 @@
import logging
from typing import TYPE_CHECKING, Awaitable, Dict, List, Optional, Tuple

from twisted.web.server import Request

from synapse.api.constants import Membership
from synapse.api.errors import SynapseError
from synapse.http.server import HttpServer
Expand Down Expand Up @@ -97,7 +95,7 @@ async def on_POST(
return 200, {"room_id": room_id}

def on_PUT(
self, request: Request, room_identifier: str, txn_id: str
self, request: SynapseRequest, room_identifier: str, txn_id: str
) -> Awaitable[Tuple[int, JsonDict]]:
set_tag("txn_id", txn_id)

Expand Down
19 changes: 12 additions & 7 deletions synapse/rest/client/transactions.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,9 @@
"""This module contains logic for storing HTTP PUT transactions. This is used
to ensure idempotency when performing PUTs using the REST API."""
import logging
from typing import TYPE_CHECKING, Any, Awaitable, Callable, Dict, Tuple
from typing import TYPE_CHECKING, Awaitable, Callable, Dict, Tuple

from typing_extensions import ParamSpec

from twisted.python.failure import Failure
from twisted.web.server import Request
Expand All @@ -32,6 +34,9 @@
CLEANUP_PERIOD_MS = 1000 * 60 * 30 # 30 mins


P = ParamSpec("P")


class HttpTransactionCache:
def __init__(self, hs: "HomeServer"):
self.hs = hs
Expand Down Expand Up @@ -65,9 +70,9 @@ def _get_transaction_key(self, request: Request) -> str:
def fetch_or_execute_request(
self,
request: Request,
fn: Callable[..., Awaitable[Tuple[int, JsonDict]]],
*args: Any,
**kwargs: Any,
fn: Callable[P, Awaitable[Tuple[int, JsonDict]]],
*args: P.args,
**kwargs: P.kwargs,
) -> Awaitable[Tuple[int, JsonDict]]:
"""A helper function for fetch_or_execute which extracts
a transaction key from the given request.
Expand All @@ -82,9 +87,9 @@ def fetch_or_execute_request(
def fetch_or_execute(
self,
txn_key: str,
fn: Callable[..., Awaitable[Tuple[int, JsonDict]]],
*args: Any,
**kwargs: Any,
fn: Callable[P, Awaitable[Tuple[int, JsonDict]]],
*args: P.args,
**kwargs: P.kwargs,
) -> Awaitable[Tuple[int, JsonDict]]:
"""Fetches the response for this transaction, or executes the given function
to produce a response for this transaction.
Expand Down
31 changes: 20 additions & 11 deletions synapse/storage/database.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,7 +192,7 @@ def __getattr__(self, name):


# The type of entry which goes on our after_callbacks and exception_callbacks lists.
_CallbackListEntry = Tuple[Callable[..., object], Iterable[Any], Dict[str, Any]]
_CallbackListEntry = Tuple[Callable[..., object], Tuple[object, ...], Dict[str, object]]

P = ParamSpec("P")
R = TypeVar("R")
Expand Down Expand Up @@ -239,7 +239,9 @@ def __init__(
self.after_callbacks = after_callbacks
self.exception_callbacks = exception_callbacks

def call_after(self, callback: Callable[..., object], *args: Any, **kwargs: Any):
def call_after(
self, callback: Callable[P, object], *args: P.args, **kwargs: P.kwargs
) -> None:
"""Call the given callback on the main twisted thread after the transaction has
finished.

Expand All @@ -256,11 +258,12 @@ def call_after(self, callback: Callable[..., object], *args: Any, **kwargs: Any)
# LoggingTransaction isn't expecting there to be any callbacks; assert that
# is not the case.
assert self.after_callbacks is not None
self.after_callbacks.append((callback, args, kwargs))
# type-ignore: need mypy containing https://github.com/python/mypy/pull/12668
self.after_callbacks.append((callback, args, kwargs)) # type: ignore[arg-type]

def call_on_exception(
self, callback: Callable[..., object], *args: Any, **kwargs: Any
):
self, callback: Callable[P, object], *args: P.args, **kwargs: P.kwargs
) -> None:
"""Call the given callback on the main twisted thread after the transaction has
failed.

Expand All @@ -274,7 +277,8 @@ def call_on_exception(
# LoggingTransaction isn't expecting there to be any callbacks; assert that
# is not the case.
assert self.exception_callbacks is not None
self.exception_callbacks.append((callback, args, kwargs))
# type-ignore: need mypy containing https://github.com/python/mypy/pull/12668
self.exception_callbacks.append((callback, args, kwargs)) # type: ignore[arg-type]

def fetchone(self) -> Optional[Tuple]:
return self.txn.fetchone()
Expand Down Expand Up @@ -549,9 +553,9 @@ def new_transaction(
desc: str,
after_callbacks: List[_CallbackListEntry],
exception_callbacks: List[_CallbackListEntry],
func: Callable[..., R],
*args: Any,
**kwargs: Any,
func: Callable[Concatenate[LoggingTransaction, P], R],
*args: P.args,
**kwargs: P.kwargs,
) -> R:
"""Start a new database transaction with the given connection.

Expand Down Expand Up @@ -581,15 +585,20 @@ def new_transaction(
# will fail if we have to repeat the transaction.
# For now, we just log an error, and hope that it works on the first attempt.
# TODO: raise an exception.
for i, arg in enumerate(args):

# Type-ignore Mypy doesn't yet consider ParamSpec.args to be iterable; see
# https://github.com/python/mypy/pull/12668
for i, arg in enumerate(args): # type: ignore[arg-type, var-annotated]
if inspect.isgenerator(arg):
logger.error(
"Programming error: generator passed to new_transaction as "
"argument %i to function %s",
i,
func,
)
for name, val in kwargs.items():
# Type-ignore Mypy doesn't yet consider ParamSpec.args to be a mapping; see
# https://github.com/python/mypy/pull/12668
for name, val in kwargs.items(): # type: ignore[attr-defined]
if inspect.isgenerator(val):
logger.error(
"Programming error: generator passed to new_transaction as "
Expand Down
8 changes: 6 additions & 2 deletions synapse/storage/databases/main/events.py
Original file line number Diff line number Diff line change
Expand Up @@ -1648,8 +1648,12 @@ def prefill():
txn.call_after(prefill)

def _store_redaction(self, txn: LoggingTransaction, event: EventBase) -> None:
# Invalidate the caches for the redacted event, note that these caches
# are also cleared as part of event replication in _invalidate_caches_for_event.
"""Invalidate the caches for the redacted event.

Note that these caches are also cleared as part of event replication in
_invalidate_caches_for_event.
"""
assert event.redacts is not None
txn.call_after(self.store._invalidate_get_event_cache, event.redacts)
txn.call_after(self.store.get_relations_for_event.invalidate, (event.redacts,))
txn.call_after(self.store.get_applicable_edit.invalidate, (event.redacts,))
Expand Down
26 changes: 19 additions & 7 deletions synapse/util/async_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@
)

import attr
from typing_extensions import AsyncContextManager, Literal
from typing_extensions import AsyncContextManager, Concatenate, Literal, ParamSpec

from twisted.internet import defer
from twisted.internet.defer import CancelledError
Expand Down Expand Up @@ -237,9 +237,16 @@ async def _concurrently_execute_inner(value: T) -> None:
)


P = ParamSpec("P")
R = TypeVar("R")


async def yieldable_gather_results(
func: Callable[..., Awaitable[T]], iter: Iterable, *args: Any, **kwargs: Any
) -> List[T]:
func: Callable[Concatenate[T, P], Awaitable[R]],
iter: Iterable[T],
*args: P.args,
**kwargs: P.kwargs,
) -> List[R]:
"""Executes the function with each argument concurrently.

Args:
Expand All @@ -255,7 +262,15 @@ async def yieldable_gather_results(
try:
return await make_deferred_yieldable(
defer.gatherResults(
[run_in_background(func, item, *args, **kwargs) for item in iter],
# type-ignore: mypy reports two errors:
# error: Argument 1 to "run_in_background" has incompatible type
# "Callable[[T, **P], Awaitable[R]]"; expected
# "Callable[[T, **P], Awaitable[R]]" [arg-type]
# error: Argument 2 to "run_in_background" has incompatible type
# "T"; expected "[T, **P.args]" [arg-type]
# The former looks like a mypy bug, and the latter looks like a
# false positive.
[run_in_background(func, item, *args, **kwargs) for item in iter], # type: ignore[arg-type]
consumeErrors=True,
)
)
Expand Down Expand Up @@ -577,9 +592,6 @@ async def _ctx_manager() -> AsyncIterator[None]:
return _ctx_manager()


R = TypeVar("R")


def timeout_deferred(
deferred: "defer.Deferred[_T]", timeout: float, reactor: IReactorTime
) -> "defer.Deferred[_T]":
Expand Down
Loading