Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Add some type annotations in synapse.storage #6987

Merged
merged 7 commits into from
Feb 27, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changelog.d/6987.misc
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Add some type annotations to the database storage classes.
143 changes: 84 additions & 59 deletions synapse/storage/database.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import sys
import time
from typing import Iterable, Tuple
from time import monotonic as monotonic_time
from typing import Any, Callable, Dict, Iterable, Iterator, List, Optional, Tuple

from six import iteritems, iterkeys, itervalues
from six.moves import intern, range
Expand All @@ -32,24 +32,14 @@
from synapse.logging.context import LoggingContext, make_deferred_yieldable
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.storage.background_updates import BackgroundUpdater
from synapse.storage.engines import PostgresEngine, Sqlite3Engine
from synapse.storage.engines import BaseDatabaseEngine, PostgresEngine, Sqlite3Engine
from synapse.storage.types import Connection, Cursor
from synapse.util.stringutils import exception_to_unicode

# import a function which will return a monotonic time, in seconds
try:
# on python 3, use time.monotonic, since time.clock can go backwards
from time import monotonic as monotonic_time
except ImportError:
# ... but python 2 doesn't have it
from time import clock as monotonic_time

logger = logging.getLogger(__name__)

try:
MAX_TXN_ID = sys.maxint - 1
except AttributeError:
# python 3 does not have a maximum int value
MAX_TXN_ID = 2 ** 63 - 1
# python 3 does not have a maximum int value
MAX_TXN_ID = 2 ** 63 - 1

sql_logger = logging.getLogger("synapse.storage.SQL")
transaction_logger = logging.getLogger("synapse.storage.txn")
Expand Down Expand Up @@ -77,7 +67,7 @@


def make_pool(
reactor, db_config: DatabaseConnectionConfig, engine
reactor, db_config: DatabaseConnectionConfig, engine: BaseDatabaseEngine
) -> adbapi.ConnectionPool:
"""Get the connection pool for the database.
"""
Expand All @@ -90,7 +80,9 @@ def make_pool(
)


def make_conn(db_config: DatabaseConnectionConfig, engine):
def make_conn(
db_config: DatabaseConnectionConfig, engine: BaseDatabaseEngine
) -> Connection:
"""Make a new connection to the database and return it.

Returns:
Expand All @@ -107,20 +99,27 @@ def make_conn(db_config: DatabaseConnectionConfig, engine):
return db_conn


class LoggingTransaction(object):
# The type of entry which goes on our after_callbacks and exception_callbacks lists.
#
# Python 3.5.2 doesn't support Callable with an ellipsis, so we wrap it in quotes so
# that mypy sees the type but the runtime python doesn't.
_CallbackListEntry = Tuple["Callable[..., None]", Iterable[Any], Dict[str, Any]]


class LoggingTransaction:
"""An object that almost-transparently proxies for the 'txn' object
passed to the constructor. Adds logging and metrics to the .execute()
method.

Args:
txn: The database transcation object to wrap.
name (str): The name of this transactions for logging.
database_engine (Sqlite3Engine|PostgresEngine)
after_callbacks(list|None): A list that callbacks will be appended to
name: The name of this transactions for logging.
database_engine
after_callbacks: A list that callbacks will be appended to
that have been added by `call_after` which should be run on
successful completion of the transaction. None indicates that no
callbacks should be allowed to be scheduled to run.
exception_callbacks(list|None): A list that callbacks will be appended
exception_callbacks: A list that callbacks will be appended
to that have been added by `call_on_exception` which should be run
if transaction ends with an error. None indicates that no callbacks
should be allowed to be scheduled to run.
Expand All @@ -135,46 +134,67 @@ class LoggingTransaction(object):
]

def __init__(
self, txn, name, database_engine, after_callbacks=None, exception_callbacks=None
self,
txn: Cursor,
name: str,
database_engine: BaseDatabaseEngine,
after_callbacks: Optional[List[_CallbackListEntry]] = None,
exception_callbacks: Optional[List[_CallbackListEntry]] = None,
):
object.__setattr__(self, "txn", txn)
object.__setattr__(self, "name", name)
object.__setattr__(self, "database_engine", database_engine)
object.__setattr__(self, "after_callbacks", after_callbacks)
object.__setattr__(self, "exception_callbacks", exception_callbacks)
self.txn = txn
self.name = name
self.database_engine = database_engine
self.after_callbacks = after_callbacks
self.exception_callbacks = exception_callbacks

def call_after(self, callback, *args, **kwargs):
def call_after(self, callback: "Callable[..., None]", *args, **kwargs):
"""Call the given callback on the main twisted thread after the
transaction has finished. Used to invalidate the caches on the
correct thread.
"""
# if self.after_callbacks is None, that means that whatever constructed the
# LoggingTransaction isn't expecting there to be any callbacks; assert that
# is not the case.
assert self.after_callbacks is not None
richvdh marked this conversation as resolved.
Show resolved Hide resolved
self.after_callbacks.append((callback, args, kwargs))

def call_on_exception(self, callback, *args, **kwargs):
def call_on_exception(self, callback: "Callable[..., None]", *args, **kwargs):
# if self.exception_callbacks is None, that means that whatever constructed the
# LoggingTransaction isn't expecting there to be any callbacks; assert that
# is not the case.
assert self.exception_callbacks is not None
self.exception_callbacks.append((callback, args, kwargs))

def __getattr__(self, name):
return getattr(self.txn, name)
def fetchall(self) -> List[Tuple]:
return self.txn.fetchall()

def __setattr__(self, name, value):
setattr(self.txn, name, value)
def fetchone(self) -> Tuple:
return self.txn.fetchone()

def __iter__(self):
def __iter__(self) -> Iterator[Tuple]:
return self.txn.__iter__()

@property
def rowcount(self) -> int:
return self.txn.rowcount

@property
def description(self) -> Any:
return self.txn.description

def execute_batch(self, sql, args):
if isinstance(self.database_engine, PostgresEngine):
from psycopg2.extras import execute_batch
from psycopg2.extras import execute_batch # type: ignore

self._do_execute(lambda *x: execute_batch(self.txn, *x), sql, args)
else:
for val in args:
self.execute(sql, val)

def execute(self, sql, *args):
def execute(self, sql: str, *args: Any):
self._do_execute(self.txn.execute, sql, *args)

def executemany(self, sql, *args):
def executemany(self, sql: str, *args: Any):
self._do_execute(self.txn.executemany, sql, *args)

def _make_sql_one_line(self, sql):
Expand Down Expand Up @@ -207,6 +227,9 @@ def _do_execute(self, func, sql, *args):
sql_logger.debug("[SQL time] {%s} %f sec", self.name, secs)
sql_query_timer.labels(sql.split()[0]).observe(secs)

def close(self):
self.txn.close()


class PerformanceCounters(object):
def __init__(self):
Expand Down Expand Up @@ -251,17 +274,19 @@ class Database(object):

_TXN_ID = 0

def __init__(self, hs, database_config: DatabaseConnectionConfig, engine):
def __init__(
self, hs, database_config: DatabaseConnectionConfig, engine: BaseDatabaseEngine
):
self.hs = hs
self._clock = hs.get_clock()
self._database_config = database_config
self._db_pool = make_pool(hs.get_reactor(), database_config, engine)

self.updates = BackgroundUpdater(hs, self)

self._previous_txn_total_time = 0
self._current_txn_total_time = 0
self._previous_loop_ts = 0
self._previous_txn_total_time = 0.0
self._current_txn_total_time = 0.0
self._previous_loop_ts = 0.0

# TODO(paul): These can eventually be removed once the metrics code
# is running in mainline, and we have some nice monitoring frontends
Expand Down Expand Up @@ -463,23 +488,23 @@ def new_transaction(
sql_txn_timer.labels(desc).observe(duration)

@defer.inlineCallbacks
def runInteraction(self, desc, func, *args, **kwargs):
def runInteraction(self, desc: str, func: Callable, *args: Any, **kwargs: Any):
"""Starts a transaction on the database and runs a given function

Arguments:
desc (str): description of the transaction, for logging and metrics
func (func): callback function, which will be called with a
desc: description of the transaction, for logging and metrics
func: callback function, which will be called with a
database transaction (twisted.enterprise.adbapi.Transaction) as
its first argument, followed by `args` and `kwargs`.

args (list): positional args to pass to `func`
kwargs (dict): named args to pass to `func`
args: positional args to pass to `func`
kwargs: named args to pass to `func`

Returns:
Deferred: The result of func
"""
after_callbacks = []
exception_callbacks = []
after_callbacks = [] # type: List[_CallbackListEntry]
exception_callbacks = [] # type: List[_CallbackListEntry]

if LoggingContext.current_context() == LoggingContext.sentinel:
logger.warning("Starting db txn '%s' from sentinel context", desc)
Expand All @@ -505,15 +530,15 @@ def runInteraction(self, desc, func, *args, **kwargs):
return result

@defer.inlineCallbacks
def runWithConnection(self, func, *args, **kwargs):
def runWithConnection(self, func: Callable, *args: Any, **kwargs: Any):
"""Wraps the .runWithConnection() method on the underlying db_pool.

Arguments:
func (func): callback function, which will be called with a
func: callback function, which will be called with a
database connection (twisted.enterprise.adbapi.Connection) as
its first argument, followed by `args` and `kwargs`.
args (list): positional args to pass to `func`
kwargs (dict): named args to pass to `func`
args: positional args to pass to `func`
kwargs: named args to pass to `func`

Returns:
Deferred: The result of func
Expand Down Expand Up @@ -800,7 +825,7 @@ def _getwhere(key):
return False

# We didn't find any existing rows, so insert a new one
allvalues = {}
allvalues = {} # type: Dict[str, Any]
allvalues.update(keyvalues)
allvalues.update(values)
allvalues.update(insertion_values)
Expand Down Expand Up @@ -829,7 +854,7 @@ def simple_upsert_txn_native_upsert(
Returns:
None
"""
allvalues = {}
allvalues = {} # type: Dict[str, Any]
allvalues.update(keyvalues)
allvalues.update(insertion_values)

Expand Down Expand Up @@ -916,7 +941,7 @@ def simple_upsert_many_txn_native_upsert(
Returns:
None
"""
allnames = []
allnames = [] # type: List[str]
allnames.extend(key_names)
allnames.extend(value_names)

Expand Down Expand Up @@ -1100,7 +1125,7 @@ def simple_select_many_batch(
keyvalues : dict of column names and values to select the rows with
retcols : list of strings giving the names of the columns to return
"""
results = []
results = [] # type: List[Dict[str, Any]]

if not iterable:
return results
Expand Down Expand Up @@ -1439,7 +1464,7 @@ def simple_select_list_paginate_txn(
raise ValueError("order_direction must be one of 'ASC' or 'DESC'.")

where_clause = "WHERE " if filters or keyvalues else ""
arg_list = []
arg_list = [] # type: List[Any]
if filters:
where_clause += " AND ".join("%s LIKE ?" % (k,) for k in filters)
arg_list += list(filters.values())
Expand Down
28 changes: 15 additions & 13 deletions synapse/storage/engines/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,29 +12,31 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import importlib
import platform

from ._base import IncorrectDatabaseSetup
from ._base import BaseDatabaseEngine, IncorrectDatabaseSetup
from .postgres import PostgresEngine
from .sqlite import Sqlite3Engine

SUPPORTED_MODULE = {"sqlite3": Sqlite3Engine, "psycopg2": PostgresEngine}


def create_engine(database_config):
def create_engine(database_config) -> BaseDatabaseEngine:
name = database_config["name"]
engine_class = SUPPORTED_MODULE.get(name, None)

if engine_class:
if name == "sqlite3":
import sqlite3

return Sqlite3Engine(sqlite3, database_config)

if name == "psycopg2":
# pypy requires psycopg2cffi rather than psycopg2
if name == "psycopg2" and platform.python_implementation() == "PyPy":
name = "psycopg2cffi"
module = importlib.import_module(name)
return engine_class(module, database_config)
if platform.python_implementation() == "PyPy":
import psycopg2cffi as psycopg2 # type: ignore
else:
import psycopg2 # type: ignore

return PostgresEngine(psycopg2, database_config)

raise RuntimeError("Unsupported database engine '%s'" % (name,))


__all__ = ["create_engine", "IncorrectDatabaseSetup"]
__all__ = ["create_engine", "BaseDatabaseEngine", "IncorrectDatabaseSetup"]
Loading