Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Use ParamSpec in a few places #12667

Merged
merged 13 commits into from
May 9, 2022
1 change: 1 addition & 0 deletions changelog.d/12667.misc
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Use `ParamSpec` to refine type hints.
14 changes: 10 additions & 4 deletions synapse/app/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@

from cryptography.utils import CryptographyDeprecationWarning
from matrix_common.versionstring import get_distribution_version_string
from typing_extensions import ParamSpec

import twisted
from twisted.internet import defer, error, reactor as _reactor
Expand Down Expand Up @@ -81,19 +82,22 @@

# list of tuples of function, args list, kwargs dict
_sighup_callbacks: List[
Tuple[Callable[..., None], Tuple[Any, ...], Dict[str, Any]]
Tuple[Callable[..., None], Tuple[object, ...], Dict[str, object]]
] = []
P = ParamSpec("P")


def register_sighup(func: Callable[..., None], *args: Any, **kwargs: Any) -> None:
def register_sighup(func: Callable[P, None], *args: P.args, **kwargs: P.kwargs) -> None:
"""
Register a function to be called when a SIGHUP occurs.

Args:
func: Function to be called when sent a SIGHUP signal.
*args, **kwargs: args and kwargs to be passed to the target function.
"""
_sighup_callbacks.append((func, args, kwargs))
# This type-ignore should be redundant once we use a mypy release with
# https://github.com/python/mypy/pull/12668.
_sighup_callbacks.append((func, args, kwargs)) # type: ignore[arg-type]


def start_worker_reactor(
Expand Down Expand Up @@ -214,7 +218,9 @@ def redirect_stdio_to_logs() -> None:
print("Redirected stdout/stderr to logs")


def register_start(cb: Callable[..., Awaitable], *args: Any, **kwargs: Any) -> None:
def register_start(
cb: Callable[P, Awaitable], *args: P.args, **kwargs: P.kwargs
) -> None:
"""Register a callback with the reactor, to be called once it is running

This can be used to initialise parts of the system which require an asynchronous
Expand Down
15 changes: 12 additions & 3 deletions synapse/events/presence_router.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,12 @@
List,
Optional,
Set,
TypeVar,
Union,
)

from typing_extensions import ParamSpec

from synapse.api.presence import UserPresenceState
from synapse.util.async_helpers import maybe_awaitable

Expand All @@ -40,6 +43,10 @@
logger = logging.getLogger(__name__)


P = ParamSpec("P")
R = TypeVar("R")


def load_legacy_presence_router(hs: "HomeServer") -> None:
"""Wrapper that loads a presence router module configured using the old
configuration, and registers the hooks they implement.
Expand All @@ -63,13 +70,15 @@ def load_legacy_presence_router(hs: "HomeServer") -> None:

# All methods that the module provides should be async, but this wasn't enforced
# in the old module system, so we wrap them if needed
def async_wrapper(f: Optional[Callable]) -> Optional[Callable[..., Awaitable]]:
def async_wrapper(
f: Optional[Callable[P, R]]
) -> Optional[Callable[P, Awaitable[R]]]:
# f might be None if the callback isn't implemented by the module. In this
# case we don't want to register a callback at all so we return None.
if f is None:
return None

def run(*args: Any, **kwargs: Any) -> Awaitable:
def run(*args: P.args, **kwargs: P.kwargs) -> Awaitable[R]:
# Assertion required because mypy can't prove we won't change `f`
# back to `None`. See
# https://mypy.readthedocs.io/en/latest/common_issues.html#narrowing-and-inner-functions
Expand All @@ -80,7 +89,7 @@ def run(*args: Any, **kwargs: Any) -> Awaitable:
return run

# Register the hooks through the module API.
hooks = {
hooks: Dict[str, Optional[Callable[..., Any]]] = {
hook: async_wrapper(getattr(presence_router, hook, None))
for hook in presence_router_methods
}
Expand Down
17 changes: 10 additions & 7 deletions synapse/module_api/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@

import attr
import jinja2
from typing_extensions import ParamSpec

from twisted.internet import defer
from twisted.web.resource import Resource
Expand Down Expand Up @@ -129,6 +130,7 @@


T = TypeVar("T")
P = ParamSpec("P")

"""
This package defines the 'stable' API which can be used by extension modules which
Expand Down Expand Up @@ -799,9 +801,9 @@ def invalidate_access_token(
def run_db_interaction(
self,
desc: str,
func: Callable[..., T],
*args: Any,
**kwargs: Any,
func: Callable[P, T],
*args: P.args,
**kwargs: P.kwargs,
) -> "defer.Deferred[T]":
"""Run a function with a database connection

Expand All @@ -817,8 +819,9 @@ def run_db_interaction(
Returns:
Deferred[object]: result of func
"""
# type-ignore: See https://github.com/python/mypy/issues/8862
return defer.ensureDeferred(
self._store.db_pool.runInteraction(desc, func, *args, **kwargs)
self._store.db_pool.runInteraction(desc, func, *args, **kwargs) # type: ignore[arg-type]
)

def complete_sso_login(
Expand Down Expand Up @@ -1296,9 +1299,9 @@ async def get_room_state(

async def defer_to_thread(
self,
f: Callable[..., T],
*args: Any,
**kwargs: Any,
f: Callable[P, T],
*args: P.args,
**kwargs: P.kwargs,
) -> T:
"""Runs the given function in a separate thread from Synapse's thread pool.

Expand Down
4 changes: 1 addition & 3 deletions synapse/rest/client/knock.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,6 @@
import logging
from typing import TYPE_CHECKING, Awaitable, Dict, List, Optional, Tuple

from twisted.web.server import Request

from synapse.api.constants import Membership
from synapse.api.errors import SynapseError
from synapse.http.server import HttpServer
Expand Down Expand Up @@ -97,7 +95,7 @@ async def on_POST(
return 200, {"room_id": room_id}

def on_PUT(
self, request: Request, room_identifier: str, txn_id: str
self, request: SynapseRequest, room_identifier: str, txn_id: str
) -> Awaitable[Tuple[int, JsonDict]]:
set_tag("txn_id", txn_id)

Expand Down
19 changes: 12 additions & 7 deletions synapse/rest/client/transactions.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,9 @@
"""This module contains logic for storing HTTP PUT transactions. This is used
to ensure idempotency when performing PUTs using the REST API."""
import logging
from typing import TYPE_CHECKING, Any, Awaitable, Callable, Dict, Tuple
from typing import TYPE_CHECKING, Awaitable, Callable, Dict, Tuple

from typing_extensions import ParamSpec

from twisted.python.failure import Failure
from twisted.web.server import Request
Expand All @@ -32,6 +34,9 @@
CLEANUP_PERIOD_MS = 1000 * 60 * 30 # 30 mins


P = ParamSpec("P")


class HttpTransactionCache:
def __init__(self, hs: "HomeServer"):
self.hs = hs
Expand Down Expand Up @@ -65,9 +70,9 @@ def _get_transaction_key(self, request: Request) -> str:
def fetch_or_execute_request(
self,
request: Request,
fn: Callable[..., Awaitable[Tuple[int, JsonDict]]],
*args: Any,
**kwargs: Any,
fn: Callable[P, Awaitable[Tuple[int, JsonDict]]],
*args: P.args,
**kwargs: P.kwargs,
) -> Awaitable[Tuple[int, JsonDict]]:
"""A helper function for fetch_or_execute which extracts
a transaction key from the given request.
Expand All @@ -82,9 +87,9 @@ def fetch_or_execute_request(
def fetch_or_execute(
self,
txn_key: str,
fn: Callable[..., Awaitable[Tuple[int, JsonDict]]],
*args: Any,
**kwargs: Any,
fn: Callable[P, Awaitable[Tuple[int, JsonDict]]],
*args: P.args,
**kwargs: P.kwargs,
) -> Awaitable[Tuple[int, JsonDict]]:
"""Fetches the response for this transaction, or executes the given function
to produce a response for this transaction.
Expand Down
50 changes: 34 additions & 16 deletions synapse/storage/database.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@

import attr
from prometheus_client import Histogram
from typing_extensions import Literal
from typing_extensions import Concatenate, Literal, ParamSpec

from twisted.enterprise import adbapi

Expand Down Expand Up @@ -192,9 +192,9 @@ def __getattr__(self, name):


# The type of entry which goes on our after_callbacks and exception_callbacks lists.
_CallbackListEntry = Tuple[Callable[..., object], Iterable[Any], Dict[str, Any]]

_CallbackListEntry = Tuple[Callable[..., object], Tuple[object, ...], Dict[str, object]]

P = ParamSpec("P")
R = TypeVar("R")


Expand Down Expand Up @@ -239,7 +239,9 @@ def __init__(
self.after_callbacks = after_callbacks
self.exception_callbacks = exception_callbacks

def call_after(self, callback: Callable[..., object], *args: Any, **kwargs: Any):
def call_after(
self, callback: Callable[P, object], *args: P.args, **kwargs: P.kwargs
) -> None:
"""Call the given callback on the main twisted thread after the transaction has
finished.

Expand All @@ -256,11 +258,12 @@ def call_after(self, callback: Callable[..., object], *args: Any, **kwargs: Any)
# LoggingTransaction isn't expecting there to be any callbacks; assert that
# is not the case.
assert self.after_callbacks is not None
self.after_callbacks.append((callback, args, kwargs))
# type-ignore: need mypy containing https://github.com/python/mypy/pull/12668
self.after_callbacks.append((callback, args, kwargs)) # type: ignore[arg-type]

def call_on_exception(
self, callback: Callable[..., object], *args: Any, **kwargs: Any
):
self, callback: Callable[P, object], *args: P.args, **kwargs: P.kwargs
) -> None:
"""Call the given callback on the main twisted thread after the transaction has
failed.

Expand All @@ -274,7 +277,8 @@ def call_on_exception(
# LoggingTransaction isn't expecting there to be any callbacks; assert that
# is not the case.
assert self.exception_callbacks is not None
self.exception_callbacks.append((callback, args, kwargs))
# type-ignore: need mypy containing https://github.com/python/mypy/pull/12668
self.exception_callbacks.append((callback, args, kwargs)) # type: ignore[arg-type]

def fetchone(self) -> Optional[Tuple]:
return self.txn.fetchone()
Expand Down Expand Up @@ -339,7 +343,13 @@ def _make_sql_one_line(self, sql: str) -> str:
"Strip newlines out of SQL so that the loggers in the DB are on one line"
return " ".join(line.strip() for line in sql.splitlines() if line.strip())

def _do_execute(self, func: Callable[..., R], sql: str, *args: Any) -> R:
def _do_execute(
self,
func: Callable[Concatenate[str, P], R],
sql: str,
*args: P.args,
**kwargs: P.kwargs,
) -> R:
sql = self._make_sql_one_line(sql)

# TODO(paul): Maybe use 'info' and 'debug' for values?
Expand All @@ -348,7 +358,10 @@ def _do_execute(self, func: Callable[..., R], sql: str, *args: Any) -> R:
sql = self.database_engine.convert_param_style(sql)
if args:
try:
sql_logger.debug("[SQL values] {%s} %r", self.name, args[0])
# The type-ignore should be redundant once mypy releases a version with
# https://github.com/python/mypy/pull/12668. (`args` might be empty,
# (but we'll catch the index error if so.)
sql_logger.debug("[SQL values] {%s} %r", self.name, args[0]) # type: ignore[index]
except Exception:
# Don't let logging failures stop SQL from working
pass
Expand All @@ -363,7 +376,7 @@ def _do_execute(self, func: Callable[..., R], sql: str, *args: Any) -> R:
opentracing.tags.DATABASE_STATEMENT: sql,
},
):
return func(sql, *args)
return func(sql, *args, **kwargs)
except Exception as e:
sql_logger.debug("[SQL FAIL] {%s} %s", self.name, e)
raise
Expand Down Expand Up @@ -540,9 +553,9 @@ def new_transaction(
desc: str,
after_callbacks: List[_CallbackListEntry],
exception_callbacks: List[_CallbackListEntry],
func: Callable[..., R],
*args: Any,
**kwargs: Any,
func: Callable[Concatenate[LoggingTransaction, P], R],
*args: P.args,
**kwargs: P.kwargs,
) -> R:
"""Start a new database transaction with the given connection.

Expand Down Expand Up @@ -572,15 +585,20 @@ def new_transaction(
# will fail if we have to repeat the transaction.
# For now, we just log an error, and hope that it works on the first attempt.
# TODO: raise an exception.
for i, arg in enumerate(args):

# Type-ignore Mypy doesn't yet consider ParamSpec.args to be iterable; see
# https://github.com/python/mypy/pull/12668
for i, arg in enumerate(args): # type: ignore[arg-type, var-annotated]
if inspect.isgenerator(arg):
logger.error(
"Programming error: generator passed to new_transaction as "
"argument %i to function %s",
i,
func,
)
for name, val in kwargs.items():
# Type-ignore Mypy doesn't yet consider ParamSpec.args to be a mapping; see
# https://github.com/python/mypy/pull/12668
for name, val in kwargs.items(): # type: ignore[attr-defined]
if inspect.isgenerator(val):
logger.error(
"Programming error: generator passed to new_transaction as "
Expand Down
12 changes: 9 additions & 3 deletions synapse/storage/databases/main/events.py
Original file line number Diff line number Diff line change
Expand Up @@ -1648,9 +1648,15 @@ def prefill():
txn.call_after(prefill)

def _store_redaction(self, txn: LoggingTransaction, event: EventBase) -> None:
# Invalidate the caches for the redacted event, note that these caches
# are also cleared as part of event replication in _invalidate_caches_for_event.
txn.call_after(self.store._invalidate_get_event_cache, event.redacts)
"""Invalidate the caches for the redacted event.

Note that these caches are also cleared as part of event replication in
_invalidate_caches_for_event.
"""

# type-ignore: mypy detects that event.redacts may be None. Presumably the
# application has ensured that this is not the case if we call this function.
txn.call_after(self.store._invalidate_get_event_cache, event.redacts) # type: ignore[arg-type]
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think I'd prefer an assert events.redacts is not None? We do check that before we call this function, so it shouldn't be a risky thing to do

txn.call_after(self.store.get_relations_for_event.invalidate, (event.redacts,))
txn.call_after(self.store.get_applicable_edit.invalidate, (event.redacts,))

Expand Down
Loading