Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Use Concatenate to annotate do_execute #12666

Merged
merged 2 commits into from
May 9, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changelog.d/12666.misc
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Use `Concatenate` to better annotate `_do_execute`.
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,7 @@ netaddr = ">=0.7.18"
# add a lower bound to the Jinja2 dependency.
Jinja2 = ">=3.0"
bleach = ">=1.4.3"
# We use `ParamSpec`, which was added in `typing-extensions` 3.10.0.0.
# We use `ParamSpec` and `Concatenate`, which were added in `typing-extensions` 3.10.0.0.
typing-extensions = ">=3.10.0"
# We enforce that we have a `cryptography` version that bundles an `openssl`
# with the latest security patches.
Expand Down
19 changes: 14 additions & 5 deletions synapse/storage/database.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@

import attr
from prometheus_client import Histogram
from typing_extensions import Literal
from typing_extensions import Concatenate, Literal, ParamSpec

from twisted.enterprise import adbapi

Expand Down Expand Up @@ -194,7 +194,7 @@ def __getattr__(self, name):
# The type of entry which goes on our after_callbacks and exception_callbacks lists.
_CallbackListEntry = Tuple[Callable[..., object], Iterable[Any], Dict[str, Any]]


P = ParamSpec("P")
R = TypeVar("R")


Expand Down Expand Up @@ -339,7 +339,13 @@ def _make_sql_one_line(self, sql: str) -> str:
"Strip newlines out of SQL so that the loggers in the DB are on one line"
return " ".join(line.strip() for line in sql.splitlines() if line.strip())

def _do_execute(self, func: Callable[..., R], sql: str, *args: Any) -> R:
def _do_execute(
self,
func: Callable[Concatenate[str, P], R],
sql: str,
*args: P.args,
**kwargs: P.kwargs,
) -> R:
sql = self._make_sql_one_line(sql)

# TODO(paul): Maybe use 'info' and 'debug' for values?
Expand All @@ -348,7 +354,10 @@ def _do_execute(self, func: Callable[..., R], sql: str, *args: Any) -> R:
sql = self.database_engine.convert_param_style(sql)
if args:
try:
sql_logger.debug("[SQL values] {%s} %r", self.name, args[0])
# The type-ignore should be redundant once mypy releases a version with
# https://github.com/python/mypy/pull/12668. (`args` might be empty,
# (but we'll catch the index error if so.)
sql_logger.debug("[SQL values] {%s} %r", self.name, args[0]) # type: ignore[index]
except Exception:
# Don't let logging failures stop SQL from working
pass
Expand All @@ -363,7 +372,7 @@ def _do_execute(self, func: Callable[..., R], sql: str, *args: Any) -> R:
opentracing.tags.DATABASE_STATEMENT: sql,
},
):
return func(sql, *args)
return func(sql, *args, **kwargs)
except Exception as e:
sql_logger.debug("[SQL FAIL] {%s} %s", self.name, e)
raise
Expand Down