This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Convert additional database stores to async/await #8045

Merged
7 commits merged on Aug 7, 2020
1 change: 1 addition & 0 deletions changelog.d/8045.misc
@@ -0,0 +1 @@
Convert various parts of the codebase to async/await.
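
The conversion pattern applied throughout this PR is mechanical: drop the `@defer.inlineCallbacks` decorator, declare the method `async def`, and replace each `yield` on a Deferred with `await`. Below is a minimal, self-contained sketch of that pattern on a simplified background-update method; `FakeDatabasePool`, `ExampleBackgroundUpdateStore`, and the method body are stand-ins invented for illustration, not Synapse's real classes.

```python
import asyncio


class FakeDatabasePool:
    """Stand-in for synapse.storage.database.DatabasePool (illustrative only)."""

    async def runInteraction(self, desc, func, *args):
        # The real method runs `func` inside a database transaction on a
        # thread pool; here we just invoke it directly.
        return func(None, *args)


# Old style (pre-PR), shown as a comment for comparison:
#
#     @defer.inlineCallbacks
#     def _analyze_user_ip(self, progress, batch_size):
#         yield self.db_pool.runInteraction("user_ips_analyze", user_ips_analyze)
#         return 1
#
# New style (this PR): no decorator, `async def`, and `await` instead of `yield`.
class ExampleBackgroundUpdateStore:
    def __init__(self, db_pool):
        self.db_pool = db_pool

    async def _analyze_user_ip(self, progress, batch_size):
        def user_ips_analyze(txn):
            # A real transaction would run txn.execute("ANALYZE user_ips").
            return None

        await self.db_pool.runInteraction("user_ips_analyze", user_ips_analyze)
        return 1


async def main():
    store = ExampleBackgroundUpdateStore(FakeDatabasePool())
    print(await store._analyze_user_ip(progress={}, batch_size=100))


asyncio.run(main())
```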
54 changes: 24 additions & 30 deletions synapse/storage/databases/main/client_ips.py
@@ -14,8 +14,7 @@
# limitations under the License.

import logging

from twisted.internet import defer
from typing import Dict, Optional, Tuple

from synapse.metrics.background_process_metrics import wrap_as_background_process
from synapse.storage._base import SQLBaseStore
@@ -82,21 +81,19 @@ def __init__(self, database: DatabasePool, db_conn, hs):
"devices_last_seen", self._devices_last_seen_update
)

@defer.inlineCallbacks
def _remove_user_ip_nonunique(self, progress, batch_size):
async def _remove_user_ip_nonunique(self, progress, batch_size):
def f(conn):
txn = conn.cursor()
txn.execute("DROP INDEX IF EXISTS user_ips_user_ip")
txn.close()

yield self.db_pool.runWithConnection(f)
yield self.db_pool.updates._end_background_update(
await self.db_pool.runWithConnection(f)
await self.db_pool.updates._end_background_update(
"user_ips_drop_nonunique_index"
)
return 1

@defer.inlineCallbacks
def _analyze_user_ip(self, progress, batch_size):
async def _analyze_user_ip(self, progress, batch_size):
# Background update to analyze user_ips table before we run the
# deduplication background update. The table may not have been analyzed
# for ages due to the table locks.
@@ -106,14 +103,13 @@ def _analyze_user_ip(self, progress, batch_size):
def user_ips_analyze(txn):
txn.execute("ANALYZE user_ips")

yield self.db_pool.runInteraction("user_ips_analyze", user_ips_analyze)
await self.db_pool.runInteraction("user_ips_analyze", user_ips_analyze)

yield self.db_pool.updates._end_background_update("user_ips_analyze")
await self.db_pool.updates._end_background_update("user_ips_analyze")

return 1

@defer.inlineCallbacks
def _remove_user_ip_dupes(self, progress, batch_size):
async def _remove_user_ip_dupes(self, progress, batch_size):
# This function works by scanning the user_ips table in batches
# based on `last_seen`. For each row in a batch it searches the rest of
# the table to see if there are any duplicates, if there are then they
@@ -140,7 +136,7 @@ def get_last_seen(txn):
return None

# Get a last seen that has roughly `batch_size` since `begin_last_seen`
end_last_seen = yield self.db_pool.runInteraction(
end_last_seen = await self.db_pool.runInteraction(
"user_ips_dups_get_last_seen", get_last_seen
)

@@ -275,15 +271,14 @@ def remove(txn):
txn, "user_ips_remove_dupes", {"last_seen": end_last_seen}
)

yield self.db_pool.runInteraction("user_ips_dups_remove", remove)
await self.db_pool.runInteraction("user_ips_dups_remove", remove)

if last:
yield self.db_pool.updates._end_background_update("user_ips_remove_dupes")
await self.db_pool.updates._end_background_update("user_ips_remove_dupes")

return batch_size

@defer.inlineCallbacks
def _devices_last_seen_update(self, progress, batch_size):
async def _devices_last_seen_update(self, progress, batch_size):
"""Background update to insert last seen info into devices table
"""

@@ -346,12 +341,12 @@ def _devices_last_seen_update_txn(txn):

return len(rows)

updated = yield self.db_pool.runInteraction(
updated = await self.db_pool.runInteraction(
"_devices_last_seen_update", _devices_last_seen_update_txn
)

if not updated:
yield self.db_pool.updates._end_background_update("devices_last_seen")
await self.db_pool.updates._end_background_update("devices_last_seen")

return updated

@@ -460,25 +455,25 @@ def _update_client_ips_batch_txn(self, txn, to_update):
# Failed to upsert, log and continue
logger.error("Failed to insert client IP %r: %r", entry, e)

@defer.inlineCallbacks
def get_last_client_ip_by_device(self, user_id, device_id):
async def get_last_client_ip_by_device(
self, user_id: str, device_id: Optional[str]
) -> Dict[Tuple[str, str], dict]:
"""For each device_id listed, give the user_ip it was last seen on

Args:
user_id (str)
device_id (str): If None fetches all devices for the user
user_id: The user to fetch devices for.
device_id: If None fetches all devices for the user

Returns:
defer.Deferred: resolves to a dict, where the keys
are (user_id, device_id) tuples. The values are also dicts, with
keys giving the column names
A dictionary mapping a tuple of (user_id, device_id) to dicts, with
keys giving the column names from the devices table.
"""

keyvalues = {"user_id": user_id}
if device_id is not None:
keyvalues["device_id"] = device_id

res = yield self.db_pool.simple_select_list(
res = await self.db_pool.simple_select_list(
table="devices",
keyvalues=keyvalues,
retcols=("user_id", "ip", "user_agent", "device_id", "last_seen"),
@@ -500,8 +495,7 @@ def get_last_client_ip_by_device(self, user_id, device_id):
}
return ret

@defer.inlineCallbacks
def get_user_ip_and_agents(self, user):
async def get_user_ip_and_agents(self, user):
user_id = user.to_string()
results = {}

@@ -511,7 +505,7 @@ def get_user_ip_and_agents(self, user):
user_agent, _, last_seen = self._batch_row_update[key]
results[(access_token, ip)] = (user_agent, last_seen)

rows = yield self.db_pool.simple_select_list(
rows = await self.db_pool.simple_select_list(
table="user_ips",
keyvalues={"user_id": user_id},
retcols=["access_token", "ip", "user_agent", "last_seen"],
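
For reference, a small self-contained sketch of how a caller might consume the newly annotated return type of `get_last_client_ip_by_device`: a dict keyed by `(user_id, device_id)` tuples whose values are row dicts. The `most_recent_device` helper and the example data are invented for illustration and are not part of the PR.

```python
from typing import Dict, Tuple

# Shape documented by the new annotation: (user_id, device_id) -> row dict.
LastSeenByDevice = Dict[Tuple[str, str], dict]


def most_recent_device(by_device: LastSeenByDevice) -> Tuple[str, str]:
    """Return the (user_id, device_id) key with the newest last_seen value."""
    return max(by_device, key=lambda key: by_device[key]["last_seen"] or 0)


# Example data in the documented shape (all values made up).
example: LastSeenByDevice = {
    ("@alice:example.com", "DEVICE_A"): {
        "user_id": "@alice:example.com",
        "device_id": "DEVICE_A",
        "ip": "10.0.0.1",
        "user_agent": "clientA",
        "last_seen": 100,
    },
    ("@alice:example.com", "DEVICE_B"): {
        "user_id": "@alice:example.com",
        "device_id": "DEVICE_B",
        "ip": "10.0.0.2",
        "user_agent": "clientB",
        "last_seen": 200,
    },
}

print(most_recent_device(example))  # ('@alice:example.com', 'DEVICE_B')
```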
69 changes: 35 additions & 34 deletions synapse/storage/databases/main/search.py
@@ -16,8 +16,7 @@
import logging
import re
from collections import namedtuple

from twisted.internet import defer
from typing import List, Optional

from synapse.api.errors import SynapseError
from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause
@@ -114,8 +113,7 @@ def __init__(self, database: DatabasePool, db_conn, hs):
self.EVENT_SEARCH_USE_GIN_POSTGRES_NAME, self._background_reindex_gin_search
)

@defer.inlineCallbacks
def _background_reindex_search(self, progress, batch_size):
async def _background_reindex_search(self, progress, batch_size):
# we work through the events table from highest stream id to lowest
target_min_stream_id = progress["target_min_stream_id_inclusive"]
max_stream_id = progress["max_stream_id_exclusive"]
@@ -206,19 +204,18 @@ def reindex_search_txn(txn):

return len(event_search_rows)

result = yield self.db_pool.runInteraction(
result = await self.db_pool.runInteraction(
self.EVENT_SEARCH_UPDATE_NAME, reindex_search_txn
)

if not result:
yield self.db_pool.updates._end_background_update(
await self.db_pool.updates._end_background_update(
self.EVENT_SEARCH_UPDATE_NAME
)

return result

@defer.inlineCallbacks
def _background_reindex_gin_search(self, progress, batch_size):
async def _background_reindex_gin_search(self, progress, batch_size):
"""This handles old synapses which used GIST indexes, if any;
converting them back to be GIN as per the actual schema.
"""
@@ -255,15 +252,14 @@ def create_index(conn):
conn.set_session(autocommit=False)

if isinstance(self.database_engine, PostgresEngine):
yield self.db_pool.runWithConnection(create_index)
await self.db_pool.runWithConnection(create_index)

yield self.db_pool.updates._end_background_update(
await self.db_pool.updates._end_background_update(
self.EVENT_SEARCH_USE_GIN_POSTGRES_NAME
)
return 1

@defer.inlineCallbacks
def _background_reindex_search_order(self, progress, batch_size):
async def _background_reindex_search_order(self, progress, batch_size):
target_min_stream_id = progress["target_min_stream_id_inclusive"]
max_stream_id = progress["max_stream_id_exclusive"]
rows_inserted = progress.get("rows_inserted", 0)
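
Both `create_index` helpers in this file temporarily put the connection into autocommit mode because Postgres will not run `CREATE INDEX CONCURRENTLY` inside a transaction block. A minimal sketch of that pattern, assuming a psycopg2 connection; the `create_gin_index` name, the index name, and the indexed column are illustrative assumptions based on Synapse's event_search schema, not copied from this diff.

```python
def create_gin_index(conn):
    # Make sure no transaction is open before changing session settings,
    # then switch to autocommit so CONCURRENTLY is allowed.
    conn.rollback()
    conn.set_session(autocommit=True)
    try:
        cur = conn.cursor()
        # CREATE INDEX CONCURRENTLY may not run inside a transaction block,
        # which is why autocommit is switched on first.
        cur.execute(
            "CREATE INDEX CONCURRENTLY IF NOT EXISTS event_search_fts_idx"
            " ON event_search USING GIN (vector)"
        )
        cur.close()
    finally:
        conn.set_session(autocommit=False)
```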
@@ -288,12 +284,12 @@ def create_index(conn):
)
conn.set_session(autocommit=False)

yield self.db_pool.runWithConnection(create_index)
await self.db_pool.runWithConnection(create_index)

pg = dict(progress)
pg["have_added_indexes"] = True

yield self.db_pool.runInteraction(
await self.db_pool.runInteraction(
self.EVENT_SEARCH_ORDER_UPDATE_NAME,
self.db_pool.updates._background_update_progress_txn,
self.EVENT_SEARCH_ORDER_UPDATE_NAME,
@@ -331,12 +327,12 @@ def reindex_search_txn(txn):

return len(rows), True

num_rows, finished = yield self.db_pool.runInteraction(
num_rows, finished = await self.db_pool.runInteraction(
self.EVENT_SEARCH_ORDER_UPDATE_NAME, reindex_search_txn
)

if not finished:
yield self.db_pool.updates._end_background_update(
await self.db_pool.updates._end_background_update(
self.EVENT_SEARCH_ORDER_UPDATE_NAME
)

@@ -347,8 +343,7 @@ class SearchStore(SearchBackgroundUpdateStore):
def __init__(self, database: DatabasePool, db_conn, hs):
super(SearchStore, self).__init__(database, db_conn, hs)

@defer.inlineCallbacks
def search_msgs(self, room_ids, search_term, keys):
async def search_msgs(self, room_ids, search_term, keys):
"""Performs a full text search over events with given keys.

Args:
@@ -425,15 +420,15 @@ def search_msgs(self, room_ids, search_term, keys):
# entire table from the database.
sql += " ORDER BY rank DESC LIMIT 500"

results = yield self.db_pool.execute(
results = await self.db_pool.execute(
"search_msgs", self.db_pool.cursor_to_dict, sql, *args
)

results = list(filter(lambda row: row["room_id"] in room_ids, results))

# We set redact_behaviour to BLOCK here to prevent redacted events being returned in
# search results (which is a data leak)
events = yield self.get_events_as_list(
events = await self.get_events_as_list(
[r["event_id"] for r in results],
redact_behaviour=EventRedactBehaviour.BLOCK,
)
@@ -442,11 +437,11 @@ def search_msgs(self, room_ids, search_term, keys):

highlights = None
if isinstance(self.database_engine, PostgresEngine):
highlights = yield self._find_highlights_in_postgres(search_query, events)
highlights = await self._find_highlights_in_postgres(search_query, events)

count_sql += " GROUP BY room_id"

count_results = yield self.db_pool.execute(
count_results = await self.db_pool.execute(
"search_rooms_count", self.db_pool.cursor_to_dict, count_sql, *count_args
)

@@ -462,19 +457,25 @@ def search_msgs(self, room_ids, search_term, keys):
"count": count,
}

@defer.inlineCallbacks
def search_rooms(self, room_ids, search_term, keys, limit, pagination_token=None):
async def search_rooms(
self,
room_ids: List[str],
search_term: str,
keys: List[str],
limit,
pagination_token: Optional[str] = None,
) -> List[dict]:
"""Performs a full text search over events with given keys.

Args:
room_id (list): The room_ids to search in
search_term (str): Search term to search for
keys (list): List of keys to search in, currently supports
"content.body", "content.name", "content.topic"
pagination_token (str): A pagination token previously returned
room_ids: The room_ids to search in
search_term: Search term to search for
keys: List of keys to search in, currently supports "content.body",
"content.name", "content.topic"
pagination_token: A pagination token previously returned

Returns:
list of dicts
Each match as a dictionary.
"""
clauses = []

@@ -577,15 +578,15 @@ def search_rooms(self, room_ids, search_term, keys, limit, pagination_token=None

args.append(limit)

results = yield self.db_pool.execute(
results = await self.db_pool.execute(
"search_rooms", self.db_pool.cursor_to_dict, sql, *args
)

results = list(filter(lambda row: row["room_id"] in room_ids, results))

# We set redact_behaviour to BLOCK here to prevent redacted events being returned in
# search results (which is a data leak)
events = yield self.get_events_as_list(
events = await self.get_events_as_list(
[r["event_id"] for r in results],
redact_behaviour=EventRedactBehaviour.BLOCK,
)
@@ -594,11 +595,11 @@ def search_rooms(self, room_ids, search_term, keys, limit, pagination_token=None

highlights = None
if isinstance(self.database_engine, PostgresEngine):
highlights = yield self._find_highlights_in_postgres(search_query, events)
highlights = await self._find_highlights_in_postgres(search_query, events)

count_sql += " GROUP BY room_id"

count_results = yield self.db_pool.execute(
count_results = await self.db_pool.execute(
"search_rooms_count", self.db_pool.cursor_to_dict, count_sql, *count_args
)

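
After the SQL query returns, both `search_msgs` and `search_rooms` filter the rows down to the requested `room_ids` and total up the per-room counts produced by the `GROUP BY room_id` count query. A self-contained sketch of that post-processing step follows; the summing line is elided from the excerpt above, so its exact form here is an assumption, and the sample rows are made up.

```python
from typing import Iterable, List, Tuple


def filter_and_count(
    rows: List[dict], count_rows: List[dict], room_ids: Iterable[str]
) -> Tuple[List[dict], int]:
    """Keep only rows from the requested rooms and sum their per-room counts."""
    wanted = set(room_ids)
    results = [row for row in rows if row["room_id"] in wanted]
    count = sum(row["count"] for row in count_rows if row["room_id"] in wanted)
    return results, count


# Made-up query output in the shape returned by cursor_to_dict.
rows = [
    {"event_id": "$a", "room_id": "!room1:example.com", "rank": 0.9},
    {"event_id": "$b", "room_id": "!other:example.com", "rank": 0.5},
]
count_rows = [
    {"room_id": "!room1:example.com", "count": 1},
    {"room_id": "!other:example.com", "count": 1},
]
print(filter_and_count(rows, count_rows, ["!room1:example.com"]))
```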
7 changes: 2 additions & 5 deletions synapse/storage/databases/main/signatures.py
@@ -15,8 +15,6 @@

from unpaddedbase64 import encode_base64

from twisted.internet import defer

from synapse.storage._base import SQLBaseStore
from synapse.util.caches.descriptors import cached, cachedList

@@ -40,9 +38,8 @@ def f(txn):

return self.db_pool.runInteraction("get_event_reference_hashes", f)

@defer.inlineCallbacks
def add_event_hashes(self, event_ids):
hashes = yield self.get_event_reference_hashes(event_ids)
async def add_event_hashes(self, event_ids):
hashes = await self.get_event_reference_hashes(event_ids)
hashes = {
e_id: {k: encode_base64(v) for k, v in h.items() if k == "sha256"}
for e_id, h in hashes.items()
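
The converted `add_event_hashes` awaits `get_event_reference_hashes` and then applies the same dict comprehension as before: keep only the `sha256` entry for each event and encode it as unpadded base64. A standalone sketch of that transformation, using the same `encode_base64` helper the store imports; the event ID and content bytes are made up.

```python
import hashlib

from unpaddedbase64 import encode_base64  # same helper the store imports

# Made-up input in the shape returned by get_event_reference_hashes:
# event_id -> {hash algorithm -> raw digest bytes}.
hashes = {
    "$example_event:example.com": {
        "sha256": hashlib.sha256(b"example event content").digest(),
    },
}

encoded = {
    e_id: {k: encode_base64(v) for k, v in h.items() if k == "sha256"}
    for e_id, h in hashes.items()
}
print(encoded)
```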