diff --git a/changelog.d/12734.misc b/changelog.d/12734.misc new file mode 100644 index 000000000000..ffbfb0d63233 --- /dev/null +++ b/changelog.d/12734.misc @@ -0,0 +1 @@ +Tidy up and type-hint the database engine modules. diff --git a/mypy.ini b/mypy.ini index 9ae7ad211c54..b5b907973ffc 100644 --- a/mypy.ini +++ b/mypy.ini @@ -232,6 +232,9 @@ disallow_untyped_defs = True [mypy-synapse.storage.databases.main.user_erasure_store] disallow_untyped_defs = True +[mypy-synapse.storage.engines.*] +disallow_untyped_defs = True + [mypy-synapse.storage.prepare_database] disallow_untyped_defs = True diff --git a/synapse/storage/engines/__init__.py b/synapse/storage/engines/__init__.py index afb7d5054db8..f51b3d228ee7 100644 --- a/synapse/storage/engines/__init__.py +++ b/synapse/storage/engines/__init__.py @@ -11,25 +11,21 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from typing import Any, Mapping from ._base import BaseDatabaseEngine, IncorrectDatabaseSetup from .postgres import PostgresEngine from .sqlite import Sqlite3Engine -def create_engine(database_config) -> BaseDatabaseEngine: +def create_engine(database_config: Mapping[str, Any]) -> BaseDatabaseEngine: name = database_config["name"] if name == "sqlite3": - import sqlite3 - - return Sqlite3Engine(sqlite3, database_config) + return Sqlite3Engine(database_config) if name == "psycopg2": - # Note that psycopg2cffi-compat provides the psycopg2 module on pypy. - import psycopg2 - - return PostgresEngine(psycopg2, database_config) + return PostgresEngine(database_config) raise RuntimeError("Unsupported database engine '%s'" % (name,)) diff --git a/synapse/storage/engines/_base.py b/synapse/storage/engines/_base.py index 143cd98ca292..971ff8269323 100644 --- a/synapse/storage/engines/_base.py +++ b/synapse/storage/engines/_base.py @@ -13,9 +13,12 @@ # limitations under the License. import abc from enum import IntEnum -from typing import Generic, Optional, TypeVar +from typing import TYPE_CHECKING, Any, Generic, Mapping, Optional, TypeVar -from synapse.storage.types import Connection +from synapse.storage.types import Connection, Cursor, DBAPI2Module + +if TYPE_CHECKING: + from synapse.storage.database import LoggingDatabaseConnection class IsolationLevel(IntEnum): @@ -32,7 +35,7 @@ class IncorrectDatabaseSetup(RuntimeError): class BaseDatabaseEngine(Generic[ConnectionType], metaclass=abc.ABCMeta): - def __init__(self, module, database_config: dict): + def __init__(self, module: DBAPI2Module, config: Mapping[str, Any]): self.module = module @property @@ -69,7 +72,7 @@ def check_database( ... @abc.abstractmethod - def check_new_database(self, txn) -> None: + def check_new_database(self, txn: Cursor) -> None: """Gets called when setting up a brand new database. This allows us to apply stricter checks on new databases versus existing database. """ @@ -79,8 +82,11 @@ def check_new_database(self, txn) -> None: def convert_param_style(self, sql: str) -> str: ... + # This method would ideally take a plain ConnectionType, but it seems that + # the Sqlite engine expects to use LoggingDatabaseConnection.cursor + # instead of sqlite3.Connection.cursor: only the former takes a txn_name. @abc.abstractmethod - def on_new_connection(self, db_conn: ConnectionType) -> None: + def on_new_connection(self, db_conn: "LoggingDatabaseConnection") -> None: ... @abc.abstractmethod @@ -92,7 +98,7 @@ def is_connection_closed(self, conn: ConnectionType) -> bool: ... @abc.abstractmethod - def lock_table(self, txn, table: str) -> None: + def lock_table(self, txn: Cursor, table: str) -> None: ... @property @@ -102,12 +108,12 @@ def server_version(self) -> str: ... @abc.abstractmethod - def in_transaction(self, conn: Connection) -> bool: + def in_transaction(self, conn: ConnectionType) -> bool: """Whether the connection is currently in a transaction.""" ... @abc.abstractmethod - def attempt_to_set_autocommit(self, conn: Connection, autocommit: bool): + def attempt_to_set_autocommit(self, conn: ConnectionType, autocommit: bool) -> None: """Attempt to set the connections autocommit mode. When True queries are run outside of transactions. @@ -119,8 +125,8 @@ def attempt_to_set_autocommit(self, conn: Connection, autocommit: bool): @abc.abstractmethod def attempt_to_set_isolation_level( - self, conn: Connection, isolation_level: Optional[int] - ): + self, conn: ConnectionType, isolation_level: Optional[int] + ) -> None: """Attempt to set the connections isolation level. Note: This has no effect on SQLite3, as transactions are SERIALIZABLE by default. diff --git a/synapse/storage/engines/postgres.py b/synapse/storage/engines/postgres.py index e8d29e287004..391f8ed24a3d 100644 --- a/synapse/storage/engines/postgres.py +++ b/synapse/storage/engines/postgres.py @@ -13,39 +13,47 @@ # limitations under the License. import logging -from typing import Mapping, Optional +from typing import TYPE_CHECKING, Any, Mapping, NoReturn, Optional, Tuple, cast from synapse.storage.engines._base import ( BaseDatabaseEngine, IncorrectDatabaseSetup, IsolationLevel, ) -from synapse.storage.types import Connection +from synapse.storage.types import Cursor + +if TYPE_CHECKING: + import psycopg2 # noqa: F401 + + from synapse.storage.database import LoggingDatabaseConnection + logger = logging.getLogger(__name__) -class PostgresEngine(BaseDatabaseEngine): - def __init__(self, database_module, database_config): - super().__init__(database_module, database_config) - self.module.extensions.register_type(self.module.extensions.UNICODE) +class PostgresEngine(BaseDatabaseEngine["psycopg2.connection"]): + def __init__(self, database_config: Mapping[str, Any]): + import psycopg2.extensions + + super().__init__(psycopg2, database_config) + psycopg2.extensions.register_type(psycopg2.extensions.UNICODE) # Disables passing `bytes` to txn.execute, c.f. #6186. If you do # actually want to use bytes than wrap it in `bytearray`. - def _disable_bytes_adapter(_): + def _disable_bytes_adapter(_: bytes) -> NoReturn: raise Exception("Passing bytes to DB is disabled.") - self.module.extensions.register_adapter(bytes, _disable_bytes_adapter) - self.synchronous_commit = database_config.get("synchronous_commit", True) - self._version = None # unknown as yet + psycopg2.extensions.register_adapter(bytes, _disable_bytes_adapter) + self.synchronous_commit: bool = database_config.get("synchronous_commit", True) + self._version: Optional[int] = None # unknown as yet self.isolation_level_map: Mapping[int, int] = { - IsolationLevel.READ_COMMITTED: self.module.extensions.ISOLATION_LEVEL_READ_COMMITTED, - IsolationLevel.REPEATABLE_READ: self.module.extensions.ISOLATION_LEVEL_REPEATABLE_READ, - IsolationLevel.SERIALIZABLE: self.module.extensions.ISOLATION_LEVEL_SERIALIZABLE, + IsolationLevel.READ_COMMITTED: psycopg2.extensions.ISOLATION_LEVEL_READ_COMMITTED, + IsolationLevel.REPEATABLE_READ: psycopg2.extensions.ISOLATION_LEVEL_REPEATABLE_READ, + IsolationLevel.SERIALIZABLE: psycopg2.extensions.ISOLATION_LEVEL_SERIALIZABLE, } self.default_isolation_level = ( - self.module.extensions.ISOLATION_LEVEL_REPEATABLE_READ + psycopg2.extensions.ISOLATION_LEVEL_REPEATABLE_READ ) self.config = database_config @@ -53,19 +61,21 @@ def _disable_bytes_adapter(_): def single_threaded(self) -> bool: return False - def get_db_locale(self, txn): + def get_db_locale(self, txn: Cursor) -> Tuple[str, str]: txn.execute( "SELECT datcollate, datctype FROM pg_database WHERE datname = current_database()" ) - collation, ctype = txn.fetchone() + collation, ctype = cast(Tuple[str, str], txn.fetchone()) return collation, ctype - def check_database(self, db_conn, allow_outdated_version: bool = False): + def check_database( + self, db_conn: "psycopg2.connection", allow_outdated_version: bool = False + ) -> None: # Get the version of PostgreSQL that we're using. As per the psycopg2 # docs: The number is formed by converting the major, minor, and # revision numbers into two-decimal-digit numbers and appending them # together. For example, version 8.1.5 will be returned as 80105 - self._version = db_conn.server_version + self._version = cast(int, db_conn.server_version) allow_unsafe_locale = self.config.get("allow_unsafe_locale", False) # Are we on a supported PostgreSQL version? @@ -108,7 +118,7 @@ def check_database(self, db_conn, allow_outdated_version: bool = False): ctype, ) - def check_new_database(self, txn): + def check_new_database(self, txn: Cursor) -> None: """Gets called when setting up a brand new database. This allows us to apply stricter checks on new databases versus existing database. """ @@ -129,10 +139,10 @@ def check_new_database(self, txn): "See docs/postgres.md for more information." % ("\n".join(errors)) ) - def convert_param_style(self, sql): + def convert_param_style(self, sql: str) -> str: return sql.replace("?", "%s") - def on_new_connection(self, db_conn): + def on_new_connection(self, db_conn: "LoggingDatabaseConnection") -> None: db_conn.set_isolation_level(self.default_isolation_level) # Set the bytea output to escape, vs the default of hex @@ -149,14 +159,14 @@ def on_new_connection(self, db_conn): db_conn.commit() @property - def can_native_upsert(self): + def can_native_upsert(self) -> bool: """ Can we use native UPSERTs? """ return True @property - def supports_using_any_list(self): + def supports_using_any_list(self) -> bool: """Do we support using `a = ANY(?)` and passing a list""" return True @@ -165,27 +175,25 @@ def supports_returning(self) -> bool: """Do we support the `RETURNING` clause in insert/update/delete?""" return True - def is_deadlock(self, error): - if isinstance(error, self.module.DatabaseError): + def is_deadlock(self, error: Exception) -> bool: + import psycopg2.extensions + + if isinstance(error, psycopg2.DatabaseError): # https://www.postgresql.org/docs/current/static/errcodes-appendix.html # "40001" serialization_failure # "40P01" deadlock_detected return error.pgcode in ["40001", "40P01"] return False - def is_connection_closed(self, conn): + def is_connection_closed(self, conn: "psycopg2.connection") -> bool: return bool(conn.closed) - def lock_table(self, txn, table): + def lock_table(self, txn: Cursor, table: str) -> None: txn.execute("LOCK TABLE %s in EXCLUSIVE MODE" % (table,)) @property - def server_version(self): - """Returns a string giving the server version. For example: '8.1.5' - - Returns: - string - """ + def server_version(self) -> str: + """Returns a string giving the server version. For example: '8.1.5'.""" # note that this is a bit of a hack because it relies on check_database # having been called. Still, that should be a safe bet here. numver = self._version @@ -197,17 +205,21 @@ def server_version(self): else: return "%i.%i.%i" % (numver / 10000, (numver % 10000) / 100, numver % 100) - def in_transaction(self, conn: Connection) -> bool: - return conn.status != self.module.extensions.STATUS_READY # type: ignore + def in_transaction(self, conn: "psycopg2.connection") -> bool: + import psycopg2.extensions + + return conn.status != psycopg2.extensions.STATUS_READY - def attempt_to_set_autocommit(self, conn: Connection, autocommit: bool): - return conn.set_session(autocommit=autocommit) # type: ignore + def attempt_to_set_autocommit( + self, conn: "psycopg2.connection", autocommit: bool + ) -> None: + return conn.set_session(autocommit=autocommit) def attempt_to_set_isolation_level( - self, conn: Connection, isolation_level: Optional[int] - ): + self, conn: "psycopg2.connection", isolation_level: Optional[int] + ) -> None: if isolation_level is None: isolation_level = self.default_isolation_level else: isolation_level = self.isolation_level_map[isolation_level] - return conn.set_isolation_level(isolation_level) # type: ignore + return conn.set_isolation_level(isolation_level) diff --git a/synapse/storage/engines/sqlite.py b/synapse/storage/engines/sqlite.py index 6c19e55999bd..621f2c5efe28 100644 --- a/synapse/storage/engines/sqlite.py +++ b/synapse/storage/engines/sqlite.py @@ -12,21 +12,21 @@ # See the License for the specific language governing permissions and # limitations under the License. import platform +import sqlite3 import struct import threading -import typing -from typing import Optional +from typing import TYPE_CHECKING, Any, List, Mapping, Optional from synapse.storage.engines import BaseDatabaseEngine -from synapse.storage.types import Connection +from synapse.storage.types import Cursor -if typing.TYPE_CHECKING: - import sqlite3 # noqa: F401 +if TYPE_CHECKING: + from synapse.storage.database import LoggingDatabaseConnection -class Sqlite3Engine(BaseDatabaseEngine["sqlite3.Connection"]): - def __init__(self, database_module, database_config): - super().__init__(database_module, database_config) +class Sqlite3Engine(BaseDatabaseEngine[sqlite3.Connection]): + def __init__(self, database_config: Mapping[str, Any]): + super().__init__(sqlite3, database_config) database = database_config.get("args", {}).get("database") self._is_in_memory = database in ( @@ -37,7 +37,7 @@ def __init__(self, database_module, database_config): if platform.python_implementation() == "PyPy": # pypy's sqlite3 module doesn't handle bytearrays, convert them # back to bytes. - database_module.register_adapter(bytearray, lambda array: bytes(array)) + sqlite3.register_adapter(bytearray, lambda array: bytes(array)) # The current max state_group, or None if we haven't looked # in the DB yet. @@ -49,41 +49,43 @@ def single_threaded(self) -> bool: return True @property - def can_native_upsert(self): + def can_native_upsert(self) -> bool: """ Do we support native UPSERTs? This requires SQLite3 3.24+, plus some more work we haven't done yet to tell what was inserted vs updated. """ - return self.module.sqlite_version_info >= (3, 24, 0) + return sqlite3.sqlite_version_info >= (3, 24, 0) @property - def supports_using_any_list(self): + def supports_using_any_list(self) -> bool: """Do we support using `a = ANY(?)` and passing a list""" return False @property def supports_returning(self) -> bool: """Do we support the `RETURNING` clause in insert/update/delete?""" - return self.module.sqlite_version_info >= (3, 35, 0) + return sqlite3.sqlite_version_info >= (3, 35, 0) - def check_database(self, db_conn, allow_outdated_version: bool = False): + def check_database( + self, db_conn: sqlite3.Connection, allow_outdated_version: bool = False + ) -> None: if not allow_outdated_version: - version = self.module.sqlite_version_info + version = sqlite3.sqlite_version_info # Synapse is untested against older SQLite versions, and we don't want # to let users upgrade to a version of Synapse with broken support for their # sqlite version, because it risks leaving them with a half-upgraded db. if version < (3, 22, 0): raise RuntimeError("Synapse requires sqlite 3.22 or above.") - def check_new_database(self, txn): + def check_new_database(self, txn: Cursor) -> None: """Gets called when setting up a brand new database. This allows us to apply stricter checks on new databases versus existing database. """ - def convert_param_style(self, sql): + def convert_param_style(self, sql: str) -> str: return sql - def on_new_connection(self, db_conn): + def on_new_connection(self, db_conn: "LoggingDatabaseConnection") -> None: # We need to import here to avoid an import loop. from synapse.storage.prepare_database import prepare_database @@ -97,48 +99,46 @@ def on_new_connection(self, db_conn): db_conn.execute("PRAGMA foreign_keys = ON;") db_conn.commit() - def is_deadlock(self, error): + def is_deadlock(self, error: Exception) -> bool: return False - def is_connection_closed(self, conn): + def is_connection_closed(self, conn: sqlite3.Connection) -> bool: return False - def lock_table(self, txn, table): + def lock_table(self, txn: Cursor, table: str) -> None: return @property - def server_version(self): - """Gets a string giving the server version. For example: '3.22.0' + def server_version(self) -> str: + """Gets a string giving the server version. For example: '3.22.0'.""" + return "%i.%i.%i" % sqlite3.sqlite_version_info - Returns: - string - """ - return "%i.%i.%i" % self.module.sqlite_version_info - - def in_transaction(self, conn: Connection) -> bool: - return conn.in_transaction # type: ignore + def in_transaction(self, conn: sqlite3.Connection) -> bool: + return conn.in_transaction - def attempt_to_set_autocommit(self, conn: Connection, autocommit: bool): + def attempt_to_set_autocommit( + self, conn: sqlite3.Connection, autocommit: bool + ) -> None: # Twisted doesn't let us set attributes on the connections, so we can't # set the connection to autocommit mode. pass def attempt_to_set_isolation_level( - self, conn: Connection, isolation_level: Optional[int] - ): - # All transactions are SERIALIZABLE by default in sqllite + self, conn: sqlite3.Connection, isolation_level: Optional[int] + ) -> None: + # All transactions are SERIALIZABLE by default in sqlite pass # Following functions taken from: https://github.com/coleifer/peewee -def _parse_match_info(buf): +def _parse_match_info(buf: bytes) -> List[int]: bufsize = len(buf) return [struct.unpack("@I", buf[i : i + 4])[0] for i in range(0, bufsize, 4)] -def _rank(raw_match_info): +def _rank(raw_match_info: bytes) -> float: """Handle match_info called w/default args 'pcx' - based on the example rank function http://sqlite.org/fts3.html#appendix_a """ diff --git a/synapse/storage/types.py b/synapse/storage/types.py index 40536c183005..0031df1e0649 100644 --- a/synapse/storage/types.py +++ b/synapse/storage/types.py @@ -94,3 +94,73 @@ def __exit__( traceback: Optional[TracebackType], ) -> Optional[bool]: ... + + +class DBAPI2Module(Protocol): + """The module-level attributes that we use from PEP 249. + + This is NOT a comprehensive stub for the entire DBAPI2.""" + + __name__: str + + # Exceptions. See https://peps.python.org/pep-0249/#exceptions + + # For our specific drivers: + # - Python's sqlite3 module doesn't contains the same descriptions as the + # DBAPI2 spec, see https://docs.python.org/3/library/sqlite3.html#exceptions + # - Psycopg2 maps every Postgres error code onto a unique exception class which + # extends from this hierarchy. See + # https://docs.python.org/3/library/sqlite3.html?highlight=sqlite3#exceptions + # https://www.postgresql.org/docs/current/errcodes-appendix.html#ERRCODES-TABLE + Warning: Type[Exception] + Error: Type[Exception] + + # Errors are divided into `InterfaceError`s (something went wrong in the database + # driver) and `DatabaseError`s (something went wrong in the database). These are + # both subclasses of `Error`, but we can't currently express this in type + # annotations due to https://github.com/python/mypy/issues/8397 + InterfaceError: Type[Exception] + DatabaseError: Type[Exception] + + # Everything below is a subclass of `DatabaseError`. + + # Roughly: the database rejected a nonsensical value. Examples: + # - An integer was too big for its data type. + # - An invalid date time was provided. + # - A string contained a null code point. + DataError: Type[Exception] + + # Roughly: something went wrong in the database, but it's not within the application + # programmer's control. Examples: + # - We failed to establish a connection to the database. + # - The connection to the database was lost. + # - A deadlock was detected. + # - A serialisation failure occurred. + # - The database ran out of resources, such as storage, memory, connections, etc. + # - The database encountered an error from the operating system. + OperationalError: Type[Exception] + + # Roughly: we've given the database data which breaks a rule we asked it to enforce. + # Examples: + # - Stop, criminal scum! You violated the foreign key constraint + # - Also check constraints, non-null constraints, etc. + IntegrityError: Type[Exception] + + # Roughly: something went wrong within the database server itself. + InternalError: Type[Exception] + + # Roughly: the application did something silly that needs to be fixed. Examples: + # - We don't have permissions to do something. + # - We tried to create a table with duplicate column names. + # - We tried to use a reserved name. + # - We referred to a column that doesn't exist. + ProgrammingError: Type[Exception] + + # Roughly: we've tried to do something that this database doesn't support. + NotSupportedError: Type[Exception] + + def connect(self, **parameters: object) -> Connection: + ... + + +__all__ = ["Cursor", "Connection", "DBAPI2Module"]