From 84110a265c3ef0a20fb3fd5a4d17c10e4f6fa843 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff Date: Tue, 25 Feb 2020 12:12:16 +0000 Subject: [PATCH 1/7] Add BaseDatabaseEngine type --- synapse/storage/engines/__init__.py | 7 ++- synapse/storage/engines/_base.py | 79 +++++++++++++++++++++++++++++ synapse/storage/engines/postgres.py | 6 +-- synapse/storage/engines/sqlite.py | 6 ++- 4 files changed, 89 insertions(+), 9 deletions(-) diff --git a/synapse/storage/engines/__init__.py b/synapse/storage/engines/__init__.py index 9d2d51992217..918e6c568fbb 100644 --- a/synapse/storage/engines/__init__.py +++ b/synapse/storage/engines/__init__.py @@ -12,18 +12,17 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - import importlib import platform -from ._base import IncorrectDatabaseSetup +from ._base import BaseDatabaseEngine, IncorrectDatabaseSetup from .postgres import PostgresEngine from .sqlite import Sqlite3Engine SUPPORTED_MODULE = {"sqlite3": Sqlite3Engine, "psycopg2": PostgresEngine} -def create_engine(database_config): +def create_engine(database_config) -> BaseDatabaseEngine: name = database_config["name"] engine_class = SUPPORTED_MODULE.get(name, None) @@ -37,4 +36,4 @@ def create_engine(database_config): raise RuntimeError("Unsupported database engine '%s'" % (name,)) -__all__ = ["create_engine", "IncorrectDatabaseSetup"] +__all__ = ["create_engine", "BaseDatabaseEngine", "IncorrectDatabaseSetup"] diff --git a/synapse/storage/engines/_base.py b/synapse/storage/engines/_base.py index ec5a4d198be2..5e6f356b513f 100644 --- a/synapse/storage/engines/_base.py +++ b/synapse/storage/engines/_base.py @@ -12,7 +12,86 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import abc class IncorrectDatabaseSetup(RuntimeError): pass + + +class BaseDatabaseEngine(metaclass=abc.ABCMeta): + def __init__(self, module, database_config: dict): + self.module = module + + @abc.abstractmethod + @property + def single_threaded(self) -> bool: + ... + + @abc.abstractmethod + @property + def can_native_upsert(self) -> bool: + """ + Do we support native UPSERTs? + """ + ... + + @abc.abstractmethod + @property + def supports_tuple_comparison(self) -> bool: + """ + Do we support comparing tuples, i.e. `(a, b) > (c, d)`? + """ + ... + + @abc.abstractmethod + @property + def supports_using_any_list(self) -> bool: + """ + Do we support using `a = ANY(?)` and passing a list + """ + ... + + @abc.abstractmethod + def check_database(self, db_conn, allow_outdated_version: bool = False) -> None: + ... + + @abc.abstractmethod + def check_new_database(self, txn) -> None: + """Gets called when setting up a brand new database. This allows us to + apply stricter checks on new databases versus existing database. + """ + ... + + @abc.abstractmethod + def convert_param_style(self, sql: str) -> str: + ... + + @abc.abstractmethod + def on_new_connection(self, db_conn) -> None: + ... + + @abc.abstractmethod + def is_deadlock(self, error: Exception) -> bool: + ... + + @abc.abstractmethod + def is_connection_closed(self, conn) -> bool: + ... + + @abc.abstractmethod + def lock_table(self, txn, table: str) -> None: + ... + + @abc.abstractmethod + def get_next_state_group_id(self, txn) -> int: + """Returns an int that can be used as a new state_group ID + """ + ... + + @abc.abstractmethod + @property + def server_version(self) -> str: + """Gets a string giving the server version. For example: '3.22.0' + """ + ... diff --git a/synapse/storage/engines/postgres.py b/synapse/storage/engines/postgres.py index a077345960a9..2e9d8bf843da 100644 --- a/synapse/storage/engines/postgres.py +++ b/synapse/storage/engines/postgres.py @@ -15,16 +15,16 @@ import logging -from ._base import IncorrectDatabaseSetup +from ._base import BaseDatabaseEngine, IncorrectDatabaseSetup logger = logging.getLogger(__name__) -class PostgresEngine(object): +class PostgresEngine(BaseDatabaseEngine): single_threaded = False def __init__(self, database_module, database_config): - self.module = database_module + super().__init__(database_module, database_config) self.module.extensions.register_type(self.module.extensions.UNICODE) # Disables passing `bytes` to txn.execute, c.f. #6186. If you do diff --git a/synapse/storage/engines/sqlite.py b/synapse/storage/engines/sqlite.py index 641e49069758..b3732c75906d 100644 --- a/synapse/storage/engines/sqlite.py +++ b/synapse/storage/engines/sqlite.py @@ -16,12 +16,14 @@ import struct import threading +from synapse.storage.engines import BaseDatabaseEngine -class Sqlite3Engine(object): + +class Sqlite3Engine(BaseDatabaseEngine): single_threaded = True def __init__(self, database_module, database_config): - self.module = database_module + super().__init__(database_module, database_config) database = database_config.get("args", {}).get("database") self._is_in_memory = database in (None, ":memory:",) From 8a32aece9f31f529a40dd07d889dc1c3f98eb3c4 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff Date: Tue, 25 Feb 2020 14:27:55 +0000 Subject: [PATCH 2/7] Database type annotations --- synapse/storage/database.py | 132 +++++++++++++++------------- synapse/storage/engines/__init__.py | 21 +++-- synapse/storage/engines/_base.py | 26 ++++-- synapse/storage/engines/postgres.py | 6 +- synapse/storage/engines/sqlite.py | 11 +-- synapse/storage/types.py | 70 +++++++++++++++ tox.ini | 5 +- 7 files changed, 185 insertions(+), 86 deletions(-) create mode 100644 synapse/storage/types.py diff --git a/synapse/storage/database.py b/synapse/storage/database.py index 1953614401c9..7d5abc6408ee 100644 --- a/synapse/storage/database.py +++ b/synapse/storage/database.py @@ -15,9 +15,10 @@ # See the License for the specific language governing permissions and # limitations under the License. import logging -import sys import time -from typing import Iterable, Tuple +# on python 3, use time.monotonic, since time.clock can go backwards +from time import monotonic as monotonic_time +from typing import Any, Callable, Dict, Iterable, Iterator, List, Optional, Tuple from six import iteritems, iterkeys, itervalues from six.moves import intern, range @@ -32,24 +33,14 @@ from synapse.logging.context import LoggingContext, make_deferred_yieldable from synapse.metrics.background_process_metrics import run_as_background_process from synapse.storage.background_updates import BackgroundUpdater -from synapse.storage.engines import PostgresEngine, Sqlite3Engine +from synapse.storage.engines import BaseDatabaseEngine, PostgresEngine, Sqlite3Engine +from synapse.storage.types import Connection, Cursor, Row from synapse.util.stringutils import exception_to_unicode -# import a function which will return a monotonic time, in seconds -try: - # on python 3, use time.monotonic, since time.clock can go backwards - from time import monotonic as monotonic_time -except ImportError: - # ... but python 2 doesn't have it - from time import clock as monotonic_time - logger = logging.getLogger(__name__) -try: - MAX_TXN_ID = sys.maxint - 1 -except AttributeError: - # python 3 does not have a maximum int value - MAX_TXN_ID = 2 ** 63 - 1 +# python 3 does not have a maximum int value +MAX_TXN_ID = 2 ** 63 - 1 sql_logger = logging.getLogger("synapse.storage.SQL") transaction_logger = logging.getLogger("synapse.storage.txn") @@ -77,7 +68,7 @@ def make_pool( - reactor, db_config: DatabaseConnectionConfig, engine + reactor, db_config: DatabaseConnectionConfig, engine: BaseDatabaseEngine ) -> adbapi.ConnectionPool: """Get the connection pool for the database. """ @@ -90,7 +81,9 @@ def make_pool( ) -def make_conn(db_config: DatabaseConnectionConfig, engine): +def make_conn( + db_config: DatabaseConnectionConfig, engine: BaseDatabaseEngine +) -> Connection: """Make a new connection to the database and return it. Returns: @@ -107,20 +100,24 @@ def make_conn(db_config: DatabaseConnectionConfig, engine): return db_conn -class LoggingTransaction(object): +# the type of entry which goes on our after_callbacks ane exception_callbacks lists +_CallbackType = Tuple[Callable[..., None], Iterable[Any], Dict[str, Any]] + + +class LoggingTransaction(Cursor): """An object that almost-transparently proxies for the 'txn' object passed to the constructor. Adds logging and metrics to the .execute() method. Args: txn: The database transcation object to wrap. - name (str): The name of this transactions for logging. - database_engine (Sqlite3Engine|PostgresEngine) - after_callbacks(list|None): A list that callbacks will be appended to + name: The name of this transactions for logging. + database_engine + after_callbacks: A list that callbacks will be appended to that have been added by `call_after` which should be run on successful completion of the transaction. None indicates that no callbacks should be allowed to be scheduled to run. - exception_callbacks(list|None): A list that callbacks will be appended + exception_callbacks: A list that callbacks will be appended to that have been added by `call_on_exception` which should be run if transaction ends with an error. None indicates that no callbacks should be allowed to be scheduled to run. @@ -135,46 +132,61 @@ class LoggingTransaction(object): ] def __init__( - self, txn, name, database_engine, after_callbacks=None, exception_callbacks=None + self, + txn: Cursor, + name: str, + database_engine: BaseDatabaseEngine, + after_callbacks: Optional[List[_CallbackType]] = None, + exception_callbacks: Optional[List[_CallbackType]] = None, ): - object.__setattr__(self, "txn", txn) - object.__setattr__(self, "name", name) - object.__setattr__(self, "database_engine", database_engine) - object.__setattr__(self, "after_callbacks", after_callbacks) - object.__setattr__(self, "exception_callbacks", exception_callbacks) + self.txn = txn + self.name = name + self.database_engine = database_engine + self.after_callbacks = after_callbacks + self.exception_callbacks = exception_callbacks - def call_after(self, callback, *args, **kwargs): + def call_after(self, callback: Callable[..., None], *args, **kwargs): """Call the given callback on the main twisted thread after the transaction has finished. Used to invalidate the caches on the correct thread. """ + assert self.after_callbacks is not None self.after_callbacks.append((callback, args, kwargs)) - def call_on_exception(self, callback, *args, **kwargs): + def call_on_exception(self, callback: Callable[..., None], *args, **kwargs): + assert self.exception_callbacks is not None self.exception_callbacks.append((callback, args, kwargs)) - def __getattr__(self, name): - return getattr(self.txn, name) + def fetchall(self) -> List[Row]: + return self.txn.fetchall() - def __setattr__(self, name, value): - setattr(self.txn, name, value) + def fetchone(self) -> Row: + return self.txn.fetchone() - def __iter__(self): + def __iter__(self) -> Iterator[Row]: return self.txn.__iter__() + @property + def rowcount(self) -> int: + return self.txn.rowcount + + @property + def description(self) -> Any: + return self.txn.description + def execute_batch(self, sql, args): if isinstance(self.database_engine, PostgresEngine): - from psycopg2.extras import execute_batch + from psycopg2.extras import execute_batch # type: ignore self._do_execute(lambda *x: execute_batch(self.txn, *x), sql, args) else: for val in args: self.execute(sql, val) - def execute(self, sql, *args): + def execute(self, sql: str, *args: Any): self._do_execute(self.txn.execute, sql, *args) - def executemany(self, sql, *args): + def executemany(self, sql: str, *args: Any): self._do_execute(self.txn.executemany, sql, *args) def _make_sql_one_line(self, sql): @@ -251,7 +263,9 @@ class Database(object): _TXN_ID = 0 - def __init__(self, hs, database_config: DatabaseConnectionConfig, engine): + def __init__( + self, hs, database_config: DatabaseConnectionConfig, engine: BaseDatabaseEngine + ): self.hs = hs self._clock = hs.get_clock() self._database_config = database_config @@ -259,9 +273,9 @@ def __init__(self, hs, database_config: DatabaseConnectionConfig, engine): self.updates = BackgroundUpdater(hs, self) - self._previous_txn_total_time = 0 - self._current_txn_total_time = 0 - self._previous_loop_ts = 0 + self._previous_txn_total_time = 0.0 + self._current_txn_total_time = 0.0 + self._previous_loop_ts = 0.0 # TODO(paul): These can eventually be removed once the metrics code # is running in mainline, and we have some nice monitoring frontends @@ -463,23 +477,23 @@ def new_transaction( sql_txn_timer.labels(desc).observe(duration) @defer.inlineCallbacks - def runInteraction(self, desc, func, *args, **kwargs): + def runInteraction(self, desc: str, func: Callable, *args: Any, **kwargs: Any): """Starts a transaction on the database and runs a given function Arguments: - desc (str): description of the transaction, for logging and metrics - func (func): callback function, which will be called with a + desc: description of the transaction, for logging and metrics + func: callback function, which will be called with a database transaction (twisted.enterprise.adbapi.Transaction) as its first argument, followed by `args` and `kwargs`. - args (list): positional args to pass to `func` - kwargs (dict): named args to pass to `func` + args: positional args to pass to `func` + kwargs: named args to pass to `func` Returns: Deferred: The result of func """ - after_callbacks = [] - exception_callbacks = [] + after_callbacks = [] # type: List[_CallbackType] + exception_callbacks = [] # type: List[_CallbackType] if LoggingContext.current_context() == LoggingContext.sentinel: logger.warning("Starting db txn '%s' from sentinel context", desc) @@ -505,15 +519,15 @@ def runInteraction(self, desc, func, *args, **kwargs): return result @defer.inlineCallbacks - def runWithConnection(self, func, *args, **kwargs): + def runWithConnection(self, func: Callable, *args: Any, **kwargs: Any): """Wraps the .runWithConnection() method on the underlying db_pool. Arguments: - func (func): callback function, which will be called with a + func: callback function, which will be called with a database connection (twisted.enterprise.adbapi.Connection) as its first argument, followed by `args` and `kwargs`. - args (list): positional args to pass to `func` - kwargs (dict): named args to pass to `func` + args: positional args to pass to `func` + kwargs: named args to pass to `func` Returns: Deferred: The result of func @@ -800,7 +814,7 @@ def _getwhere(key): return False # We didn't find any existing rows, so insert a new one - allvalues = {} + allvalues = {} # type: Dict[str, Any] allvalues.update(keyvalues) allvalues.update(values) allvalues.update(insertion_values) @@ -829,7 +843,7 @@ def simple_upsert_txn_native_upsert( Returns: None """ - allvalues = {} + allvalues = {} # type: Dict[str, Any] allvalues.update(keyvalues) allvalues.update(insertion_values) @@ -916,7 +930,7 @@ def simple_upsert_many_txn_native_upsert( Returns: None """ - allnames = [] + allnames = [] # type: List[str] allnames.extend(key_names) allnames.extend(value_names) @@ -1100,7 +1114,7 @@ def simple_select_many_batch( keyvalues : dict of column names and values to select the rows with retcols : list of strings giving the names of the columns to return """ - results = [] + results = [] # type: List[Dict[str, Any]] if not iterable: return results @@ -1439,7 +1453,7 @@ def simple_select_list_paginate_txn( raise ValueError("order_direction must be one of 'ASC' or 'DESC'.") where_clause = "WHERE " if filters or keyvalues else "" - arg_list = [] + arg_list = [] # type: List[Any] if filters: where_clause += " AND ".join("%s LIKE ?" % (k,) for k in filters) arg_list += list(filters.values()) diff --git a/synapse/storage/engines/__init__.py b/synapse/storage/engines/__init__.py index 918e6c568fbb..035f9ea6e98b 100644 --- a/synapse/storage/engines/__init__.py +++ b/synapse/storage/engines/__init__.py @@ -12,26 +12,29 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import importlib import platform from ._base import BaseDatabaseEngine, IncorrectDatabaseSetup from .postgres import PostgresEngine from .sqlite import Sqlite3Engine -SUPPORTED_MODULE = {"sqlite3": Sqlite3Engine, "psycopg2": PostgresEngine} - def create_engine(database_config) -> BaseDatabaseEngine: name = database_config["name"] - engine_class = SUPPORTED_MODULE.get(name, None) - if engine_class: + if name == "sqlite3": + import sqlite3 + + return Sqlite3Engine(sqlite3, database_config) + + if name == "psycopg2": # pypy requires psycopg2cffi rather than psycopg2 - if name == "psycopg2" and platform.python_implementation() == "PyPy": - name = "psycopg2cffi" - module = importlib.import_module(name) - return engine_class(module, database_config) + if platform.python_implementation() == "PyPy": + import psycopg2cffi as psycopg2 # type: ignore + else: + import psycopg2 # type: ignore + + return PostgresEngine(psycopg2, database_config) raise RuntimeError("Unsupported database engine '%s'" % (name,)) diff --git a/synapse/storage/engines/_base.py b/synapse/storage/engines/_base.py index 5e6f356b513f..ab0bbe4bd364 100644 --- a/synapse/storage/engines/_base.py +++ b/synapse/storage/engines/_base.py @@ -13,39 +13,45 @@ # See the License for the specific language governing permissions and # limitations under the License. import abc +from typing import Generic, TypeVar + +from synapse.storage.types import Connection class IncorrectDatabaseSetup(RuntimeError): pass -class BaseDatabaseEngine(metaclass=abc.ABCMeta): +ConnectionType = TypeVar("ConnectionType", bound=Connection) + + +class BaseDatabaseEngine(Generic[ConnectionType], metaclass=abc.ABCMeta): def __init__(self, module, database_config: dict): self.module = module - @abc.abstractmethod @property + @abc.abstractmethod def single_threaded(self) -> bool: ... - @abc.abstractmethod @property + @abc.abstractmethod def can_native_upsert(self) -> bool: """ Do we support native UPSERTs? """ ... - @abc.abstractmethod @property + @abc.abstractmethod def supports_tuple_comparison(self) -> bool: """ Do we support comparing tuples, i.e. `(a, b) > (c, d)`? """ ... - @abc.abstractmethod @property + @abc.abstractmethod def supports_using_any_list(self) -> bool: """ Do we support using `a = ANY(?)` and passing a list @@ -53,7 +59,9 @@ def supports_using_any_list(self) -> bool: ... @abc.abstractmethod - def check_database(self, db_conn, allow_outdated_version: bool = False) -> None: + def check_database( + self, db_conn: ConnectionType, allow_outdated_version: bool = False + ) -> None: ... @abc.abstractmethod @@ -68,7 +76,7 @@ def convert_param_style(self, sql: str) -> str: ... @abc.abstractmethod - def on_new_connection(self, db_conn) -> None: + def on_new_connection(self, db_conn: ConnectionType) -> None: ... @abc.abstractmethod @@ -76,7 +84,7 @@ def is_deadlock(self, error: Exception) -> bool: ... @abc.abstractmethod - def is_connection_closed(self, conn) -> bool: + def is_connection_closed(self, conn: ConnectionType) -> bool: ... @abc.abstractmethod @@ -89,8 +97,8 @@ def get_next_state_group_id(self, txn) -> int: """ ... - @abc.abstractmethod @property + @abc.abstractmethod def server_version(self) -> str: """Gets a string giving the server version. For example: '3.22.0' """ diff --git a/synapse/storage/engines/postgres.py b/synapse/storage/engines/postgres.py index 2e9d8bf843da..48f103bb5a4f 100644 --- a/synapse/storage/engines/postgres.py +++ b/synapse/storage/engines/postgres.py @@ -21,8 +21,6 @@ class PostgresEngine(BaseDatabaseEngine): - single_threaded = False - def __init__(self, database_module, database_config): super().__init__(database_module, database_config) self.module.extensions.register_type(self.module.extensions.UNICODE) @@ -36,6 +34,10 @@ def _disable_bytes_adapter(_): self.synchronous_commit = database_config.get("synchronous_commit", True) self._version = None # unknown as yet + @property + def single_threaded(self) -> bool: + return False + def check_database(self, db_conn, allow_outdated_version: bool = False): # Get the version of PostgreSQL that we're using. As per the psycopg2 # docs: The number is formed by converting the major, minor, and diff --git a/synapse/storage/engines/sqlite.py b/synapse/storage/engines/sqlite.py index b3732c75906d..2bfeefd54ed2 100644 --- a/synapse/storage/engines/sqlite.py +++ b/synapse/storage/engines/sqlite.py @@ -12,16 +12,14 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - +import sqlite3 import struct import threading from synapse.storage.engines import BaseDatabaseEngine -class Sqlite3Engine(BaseDatabaseEngine): - single_threaded = True - +class Sqlite3Engine(BaseDatabaseEngine[sqlite3.Connection]): def __init__(self, database_module, database_config): super().__init__(database_module, database_config) @@ -33,6 +31,10 @@ def __init__(self, database_module, database_config): self._current_state_group_id = None self._current_state_group_id_lock = threading.Lock() + @property + def single_threaded(self) -> bool: + return True + @property def can_native_upsert(self): """ @@ -70,7 +72,6 @@ def convert_param_style(self, sql): return sql def on_new_connection(self, db_conn): - # We need to import here to avoid an import loop. from synapse.storage.prepare_database import prepare_database diff --git a/synapse/storage/types.py b/synapse/storage/types.py new file mode 100644 index 000000000000..9e364fd29b83 --- /dev/null +++ b/synapse/storage/types.py @@ -0,0 +1,70 @@ +# -*- coding: utf-8 -*- +# Copyright 2020 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Any, Iterable, Iterator, List + +from typing_extensions import Protocol + + +""" +Some very basic protocol definitions for the DB-API2 classes specified in PEP-249 +""" + + +class Row(Protocol): + # todo: make this stronger + pass + + +class Cursor(Protocol): + def execute(self, sql: str, parameters: Iterable[Any] = ...) -> Any: + ... + + def executemany(self, sql: str, parameters: Iterable[Iterable[Any]]) -> Any: + ... + + def fetchall(self) -> List[Row]: + ... + + def fetchone(self) -> Row: + ... + + @property + def description(self) -> Any: + return None + + @property + def rowcount(self) -> int: + return 0 + + def __iter__(self) -> Iterator[Row]: + ... + + def close(self) -> None: + ... + + +class Connection(Protocol): + def cursor(self) -> Cursor: + ... + + def close(self) -> None: + ... + + def commit(self) -> None: + ... + + def rollback(self, *args, **kwargs) -> None: + ... diff --git a/tox.ini b/tox.ini index b715ea0bff40..c89bf20d4092 100644 --- a/tox.ini +++ b/tox.ini @@ -168,7 +168,6 @@ commands= coverage html [testenv:mypy] -basepython = python3.7 skip_install = True deps = {[base]deps} @@ -179,7 +178,8 @@ env = extras = all commands = mypy \ synapse/api \ - synapse/config/ \ + synapse/appservice \ + synapse/config \ synapse/events/spamcheck.py \ synapse/federation/sender \ synapse/federation/transport \ @@ -191,6 +191,7 @@ commands = mypy \ synapse/rest \ synapse/spam_checker_api \ synapse/storage/engines \ + synapse/storage/database.py \ synapse/streams # To find all folders that pass mypy you run: From 32996341db78738e0e4048097830c1f8348bee9a Mon Sep 17 00:00:00 2001 From: Richard van der Hoff Date: Tue, 25 Feb 2020 14:28:33 +0000 Subject: [PATCH 3/7] changelog --- changelog.d/6987.misc | 1 + 1 file changed, 1 insertion(+) create mode 100644 changelog.d/6987.misc diff --git a/changelog.d/6987.misc b/changelog.d/6987.misc new file mode 100644 index 000000000000..7ff74cda5533 --- /dev/null +++ b/changelog.d/6987.misc @@ -0,0 +1 @@ +Add some type annotations to the database storage classes. From ec58dce48ef31a5f324456ca3734694b9ffcfed3 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff Date: Tue, 25 Feb 2020 14:42:07 +0000 Subject: [PATCH 4/7] fix lint --- synapse/storage/database.py | 1 - 1 file changed, 1 deletion(-) diff --git a/synapse/storage/database.py b/synapse/storage/database.py index 7d5abc6408ee..c69cd43a8f8c 100644 --- a/synapse/storage/database.py +++ b/synapse/storage/database.py @@ -16,7 +16,6 @@ # limitations under the License. import logging import time -# on python 3, use time.monotonic, since time.clock can go backwards from time import monotonic as monotonic_time from typing import Any, Callable, Dict, Iterable, Iterator, List, Optional, Tuple From 7a63448f44af26235fdb51ba9a8ef4df02eb7d1a Mon Sep 17 00:00:00 2001 From: Richard van der Hoff Date: Tue, 25 Feb 2020 14:45:38 +0000 Subject: [PATCH 5/7] more lint --- synapse/storage/types.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/synapse/storage/types.py b/synapse/storage/types.py index 9e364fd29b83..03e2f28733ff 100644 --- a/synapse/storage/types.py +++ b/synapse/storage/types.py @@ -19,7 +19,7 @@ """ -Some very basic protocol definitions for the DB-API2 classes specified in PEP-249 +Some very basic protocol definitions for the DB-API2 classes specified in PEP-249 """ From 68da6027bebdcad11f84b4f5e11da082a394870a Mon Sep 17 00:00:00 2001 From: Richard van der Hoff Date: Wed, 26 Feb 2020 07:23:09 +0000 Subject: [PATCH 6/7] fix various test failures --- synapse/storage/database.py | 22 +++++++++++++--------- 1 file changed, 13 insertions(+), 9 deletions(-) diff --git a/synapse/storage/database.py b/synapse/storage/database.py index c69cd43a8f8c..ef9874ce5098 100644 --- a/synapse/storage/database.py +++ b/synapse/storage/database.py @@ -99,11 +99,12 @@ def make_conn( return db_conn -# the type of entry which goes on our after_callbacks ane exception_callbacks lists -_CallbackType = Tuple[Callable[..., None], Iterable[Any], Dict[str, Any]] +# the type of entry which goes on our after_callbacks and exception_callbacks lists +# use of a string here because python 3.5.2 doesn't support Callable([...]). +_CallbackListEntry = Tuple["Callable[..., None]", Iterable[Any], Dict[str, Any]] -class LoggingTransaction(Cursor): +class LoggingTransaction: """An object that almost-transparently proxies for the 'txn' object passed to the constructor. Adds logging and metrics to the .execute() method. @@ -135,8 +136,8 @@ def __init__( txn: Cursor, name: str, database_engine: BaseDatabaseEngine, - after_callbacks: Optional[List[_CallbackType]] = None, - exception_callbacks: Optional[List[_CallbackType]] = None, + after_callbacks: Optional[List[_CallbackListEntry]] = None, + exception_callbacks: Optional[List[_CallbackListEntry]] = None, ): self.txn = txn self.name = name @@ -144,7 +145,7 @@ def __init__( self.after_callbacks = after_callbacks self.exception_callbacks = exception_callbacks - def call_after(self, callback: Callable[..., None], *args, **kwargs): + def call_after(self, callback: "Callable[..., None]", *args, **kwargs): """Call the given callback on the main twisted thread after the transaction has finished. Used to invalidate the caches on the correct thread. @@ -152,7 +153,7 @@ def call_after(self, callback: Callable[..., None], *args, **kwargs): assert self.after_callbacks is not None self.after_callbacks.append((callback, args, kwargs)) - def call_on_exception(self, callback: Callable[..., None], *args, **kwargs): + def call_on_exception(self, callback: "Callable[..., None]", *args, **kwargs): assert self.exception_callbacks is not None self.exception_callbacks.append((callback, args, kwargs)) @@ -218,6 +219,9 @@ def _do_execute(self, func, sql, *args): sql_logger.debug("[SQL time] {%s} %f sec", self.name, secs) sql_query_timer.labels(sql.split()[0]).observe(secs) + def close(self): + self.txn.close() + class PerformanceCounters(object): def __init__(self): @@ -491,8 +495,8 @@ def runInteraction(self, desc: str, func: Callable, *args: Any, **kwargs: Any): Returns: Deferred: The result of func """ - after_callbacks = [] # type: List[_CallbackType] - exception_callbacks = [] # type: List[_CallbackType] + after_callbacks = [] # type: List[_CallbackListEntry] + exception_callbacks = [] # type: List[_CallbackListEntry] if LoggingContext.current_context() == LoggingContext.sentinel: logger.warning("Starting db txn '%s' from sentinel context", desc) From 4170b2e468ceccccc889166920655bfde74c2fac Mon Sep 17 00:00:00 2001 From: Richard van der Hoff Date: Wed, 26 Feb 2020 16:54:38 +0000 Subject: [PATCH 7/7] address review comments --- synapse/storage/database.py | 20 ++++++++++++++------ synapse/storage/types.py | 13 ++++--------- 2 files changed, 18 insertions(+), 15 deletions(-) diff --git a/synapse/storage/database.py b/synapse/storage/database.py index ef9874ce5098..609db406167e 100644 --- a/synapse/storage/database.py +++ b/synapse/storage/database.py @@ -33,7 +33,7 @@ from synapse.metrics.background_process_metrics import run_as_background_process from synapse.storage.background_updates import BackgroundUpdater from synapse.storage.engines import BaseDatabaseEngine, PostgresEngine, Sqlite3Engine -from synapse.storage.types import Connection, Cursor, Row +from synapse.storage.types import Connection, Cursor from synapse.util.stringutils import exception_to_unicode logger = logging.getLogger(__name__) @@ -99,8 +99,10 @@ def make_conn( return db_conn -# the type of entry which goes on our after_callbacks and exception_callbacks lists -# use of a string here because python 3.5.2 doesn't support Callable([...]). +# The type of entry which goes on our after_callbacks and exception_callbacks lists. +# +# Python 3.5.2 doesn't support Callable with an ellipsis, so we wrap it in quotes so +# that mypy sees the type but the runtime python doesn't. _CallbackListEntry = Tuple["Callable[..., None]", Iterable[Any], Dict[str, Any]] @@ -150,20 +152,26 @@ def call_after(self, callback: "Callable[..., None]", *args, **kwargs): transaction has finished. Used to invalidate the caches on the correct thread. """ + # if self.after_callbacks is None, that means that whatever constructed the + # LoggingTransaction isn't expecting there to be any callbacks; assert that + # is not the case. assert self.after_callbacks is not None self.after_callbacks.append((callback, args, kwargs)) def call_on_exception(self, callback: "Callable[..., None]", *args, **kwargs): + # if self.exception_callbacks is None, that means that whatever constructed the + # LoggingTransaction isn't expecting there to be any callbacks; assert that + # is not the case. assert self.exception_callbacks is not None self.exception_callbacks.append((callback, args, kwargs)) - def fetchall(self) -> List[Row]: + def fetchall(self) -> List[Tuple]: return self.txn.fetchall() - def fetchone(self) -> Row: + def fetchone(self) -> Tuple: return self.txn.fetchone() - def __iter__(self) -> Iterator[Row]: + def __iter__(self) -> Iterator[Tuple]: return self.txn.__iter__() @property diff --git a/synapse/storage/types.py b/synapse/storage/types.py index 03e2f28733ff..daff81c5ee23 100644 --- a/synapse/storage/types.py +++ b/synapse/storage/types.py @@ -13,7 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Iterable, Iterator, List +from typing import Any, Iterable, Iterator, List, Tuple from typing_extensions import Protocol @@ -23,11 +23,6 @@ """ -class Row(Protocol): - # todo: make this stronger - pass - - class Cursor(Protocol): def execute(self, sql: str, parameters: Iterable[Any] = ...) -> Any: ... @@ -35,10 +30,10 @@ def execute(self, sql: str, parameters: Iterable[Any] = ...) -> Any: def executemany(self, sql: str, parameters: Iterable[Iterable[Any]]) -> Any: ... - def fetchall(self) -> List[Row]: + def fetchall(self) -> List[Tuple]: ... - def fetchone(self) -> Row: + def fetchone(self) -> Tuple: ... @property @@ -49,7 +44,7 @@ def description(self) -> Any: def rowcount(self) -> int: return 0 - def __iter__(self) -> Iterator[Row]: + def __iter__(self) -> Iterator[Tuple]: ... def close(self) -> None: