
Convert additional database stores to async/await (#8045)
clokep authored Aug 7, 2020
1 parent 1048ed2 commit f3fe696
Showing 6 changed files with 107 additions and 152 deletions.
1 change: 1 addition & 0 deletions changelog.d/8045.misc
@@ -0,0 +1 @@
Convert various parts of the codebase to async/await.
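All of the conversions below follow the same mechanical pattern: drop the @defer.inlineCallbacks decorator, declare the method with async def, and replace each yield on a Deferred with await. A minimal before/after sketch of that pattern, using a hypothetical store method rather than one from this diff:

    # Before: a Twisted inlineCallbacks generator that returns a Deferred.
    from twisted.internet import defer

    @defer.inlineCallbacks
    def get_thing(self, thing_id):
        row = yield self.db_pool.simple_select_one(
            table="things", keyvalues={"id": thing_id}, retcols=("name",)
        )
        return row

    # After: a native coroutine; callers await it directly.
    async def get_thing(self, thing_id: str) -> dict:
        return await self.db_pool.simple_select_one(
            table="things", keyvalues={"id": thing_id}, retcols=("name",)
        )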
54 changes: 24 additions & 30 deletions synapse/storage/databases/main/client_ips.py
@@ -14,8 +14,7 @@
# limitations under the License.

import logging

from twisted.internet import defer
from typing import Dict, Optional, Tuple

from synapse.metrics.background_process_metrics import wrap_as_background_process
from synapse.storage._base import SQLBaseStore
@@ -82,21 +81,19 @@ def __init__(self, database: DatabasePool, db_conn, hs):
"devices_last_seen", self._devices_last_seen_update
)

@defer.inlineCallbacks
def _remove_user_ip_nonunique(self, progress, batch_size):
async def _remove_user_ip_nonunique(self, progress, batch_size):
def f(conn):
txn = conn.cursor()
txn.execute("DROP INDEX IF EXISTS user_ips_user_ip")
txn.close()

yield self.db_pool.runWithConnection(f)
yield self.db_pool.updates._end_background_update(
await self.db_pool.runWithConnection(f)
await self.db_pool.updates._end_background_update(
"user_ips_drop_nonunique_index"
)
return 1

@defer.inlineCallbacks
def _analyze_user_ip(self, progress, batch_size):
async def _analyze_user_ip(self, progress, batch_size):
# Background update to analyze user_ips table before we run the
# deduplication background update. The table may not have been analyzed
# for ages due to the table locks.
@@ -106,14 +103,13 @@ def _analyze_user_ip(self, progress, batch_size):
def user_ips_analyze(txn):
txn.execute("ANALYZE user_ips")

yield self.db_pool.runInteraction("user_ips_analyze", user_ips_analyze)
await self.db_pool.runInteraction("user_ips_analyze", user_ips_analyze)

yield self.db_pool.updates._end_background_update("user_ips_analyze")
await self.db_pool.updates._end_background_update("user_ips_analyze")

return 1

@defer.inlineCallbacks
def _remove_user_ip_dupes(self, progress, batch_size):
async def _remove_user_ip_dupes(self, progress, batch_size):
# This function works by scanning the user_ips table in batches
# based on `last_seen`. For each row in a batch it searches the rest of
# the table to see if there are any duplicates, if there are then they
@@ -140,7 +136,7 @@ def get_last_seen(txn):
return None

# Get a last seen that has roughly `batch_size` since `begin_last_seen`
end_last_seen = yield self.db_pool.runInteraction(
end_last_seen = await self.db_pool.runInteraction(
"user_ips_dups_get_last_seen", get_last_seen
)

@@ -275,15 +271,14 @@ def remove(txn):
txn, "user_ips_remove_dupes", {"last_seen": end_last_seen}
)

yield self.db_pool.runInteraction("user_ips_dups_remove", remove)
await self.db_pool.runInteraction("user_ips_dups_remove", remove)

if last:
yield self.db_pool.updates._end_background_update("user_ips_remove_dupes")
await self.db_pool.updates._end_background_update("user_ips_remove_dupes")

return batch_size

@defer.inlineCallbacks
def _devices_last_seen_update(self, progress, batch_size):
async def _devices_last_seen_update(self, progress, batch_size):
"""Background update to insert last seen info into devices table
"""

@@ -346,12 +341,12 @@ def _devices_last_seen_update_txn(txn):

return len(rows)

updated = yield self.db_pool.runInteraction(
updated = await self.db_pool.runInteraction(
"_devices_last_seen_update", _devices_last_seen_update_txn
)

if not updated:
yield self.db_pool.updates._end_background_update("devices_last_seen")
await self.db_pool.updates._end_background_update("devices_last_seen")

return updated

@@ -460,25 +455,25 @@ def _update_client_ips_batch_txn(self, txn, to_update):
# Failed to upsert, log and continue
logger.error("Failed to insert client IP %r: %r", entry, e)

@defer.inlineCallbacks
def get_last_client_ip_by_device(self, user_id, device_id):
async def get_last_client_ip_by_device(
self, user_id: str, device_id: Optional[str]
) -> Dict[Tuple[str, str], dict]:
"""For each device_id listed, give the user_ip it was last seen on
Args:
user_id (str)
device_id (str): If None fetches all devices for the user
user_id: The user to fetch devices for.
device_id: If None fetches all devices for the user
Returns:
defer.Deferred: resolves to a dict, where the keys
are (user_id, device_id) tuples. The values are also dicts, with
keys giving the column names
A dictionary mapping a tuple of (user_id, device_id) to dicts, with
keys giving the column names from the devices table.
"""

keyvalues = {"user_id": user_id}
if device_id is not None:
keyvalues["device_id"] = device_id

res = yield self.db_pool.simple_select_list(
res = await self.db_pool.simple_select_list(
table="devices",
keyvalues=keyvalues,
retcols=("user_id", "ip", "user_agent", "device_id", "last_seen"),
@@ -500,8 +495,7 @@ def get_last_client_ip_by_device(self, user_id, device_id):
}
return ret

@defer.inlineCallbacks
def get_user_ip_and_agents(self, user):
async def get_user_ip_and_agents(self, user):
user_id = user.to_string()
results = {}

@@ -511,7 +505,7 @@ def get_user_ip_and_agents(self, user):
user_agent, _, last_seen = self._batch_row_update[key]
results[(access_token, ip)] = (user_agent, last_seen)

rows = yield self.db_pool.simple_select_list(
rows = await self.db_pool.simple_select_list(
table="user_ips",
keyvalues={"user_id": user_id},
retcols=["access_token", "ip", "user_agent", "last_seen"],
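Each of the client_ips.py functions converted above is a registered background update, and after conversion they all share the same coroutine shape: process one batch inside a database interaction, mark the update finished once nothing is left to do, and return a count so the updater can schedule the next batch. A minimal sketch of that shape, with illustrative names and SQL not taken from this diff:

    async def _example_background_update(self, progress, batch_size):
        def process_batch_txn(txn):
            # Rewrite at most `batch_size` rows in a single transaction.
            txn.execute("SELECT thing_id FROM things LIMIT ?", (batch_size,))
            rows = txn.fetchall()
            # ... migrate the selected rows ...
            return len(rows)

        processed = await self.db_pool.runInteraction(
            "example_background_update", process_batch_txn
        )

        if not processed:
            # Nothing left to do: unregister the background update.
            await self.db_pool.updates._end_background_update(
                "example_background_update"
            )

        return processed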
69 changes: 35 additions & 34 deletions synapse/storage/databases/main/search.py
@@ -16,8 +16,7 @@
import logging
import re
from collections import namedtuple

from twisted.internet import defer
from typing import List, Optional

from synapse.api.errors import SynapseError
from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause
@@ -114,8 +113,7 @@ def __init__(self, database: DatabasePool, db_conn, hs):
self.EVENT_SEARCH_USE_GIN_POSTGRES_NAME, self._background_reindex_gin_search
)

@defer.inlineCallbacks
def _background_reindex_search(self, progress, batch_size):
async def _background_reindex_search(self, progress, batch_size):
# we work through the events table from highest stream id to lowest
target_min_stream_id = progress["target_min_stream_id_inclusive"]
max_stream_id = progress["max_stream_id_exclusive"]
@@ -206,19 +204,18 @@ def reindex_search_txn(txn):

return len(event_search_rows)

result = yield self.db_pool.runInteraction(
result = await self.db_pool.runInteraction(
self.EVENT_SEARCH_UPDATE_NAME, reindex_search_txn
)

if not result:
yield self.db_pool.updates._end_background_update(
await self.db_pool.updates._end_background_update(
self.EVENT_SEARCH_UPDATE_NAME
)

return result

@defer.inlineCallbacks
def _background_reindex_gin_search(self, progress, batch_size):
async def _background_reindex_gin_search(self, progress, batch_size):
"""This handles old synapses which used GIST indexes, if any;
converting them back to be GIN as per the actual schema.
"""
@@ -255,15 +252,14 @@ def create_index(conn):
conn.set_session(autocommit=False)

if isinstance(self.database_engine, PostgresEngine):
yield self.db_pool.runWithConnection(create_index)
await self.db_pool.runWithConnection(create_index)

yield self.db_pool.updates._end_background_update(
await self.db_pool.updates._end_background_update(
self.EVENT_SEARCH_USE_GIN_POSTGRES_NAME
)
return 1

@defer.inlineCallbacks
def _background_reindex_search_order(self, progress, batch_size):
async def _background_reindex_search_order(self, progress, batch_size):
target_min_stream_id = progress["target_min_stream_id_inclusive"]
max_stream_id = progress["max_stream_id_exclusive"]
rows_inserted = progress.get("rows_inserted", 0)
@@ -288,12 +284,12 @@ def create_index(conn):
)
conn.set_session(autocommit=False)

yield self.db_pool.runWithConnection(create_index)
await self.db_pool.runWithConnection(create_index)

pg = dict(progress)
pg["have_added_indexes"] = True

yield self.db_pool.runInteraction(
await self.db_pool.runInteraction(
self.EVENT_SEARCH_ORDER_UPDATE_NAME,
self.db_pool.updates._background_update_progress_txn,
self.EVENT_SEARCH_ORDER_UPDATE_NAME,
@@ -331,12 +327,12 @@ def reindex_search_txn(txn):

return len(rows), True

num_rows, finished = yield self.db_pool.runInteraction(
num_rows, finished = await self.db_pool.runInteraction(
self.EVENT_SEARCH_ORDER_UPDATE_NAME, reindex_search_txn
)

if not finished:
yield self.db_pool.updates._end_background_update(
await self.db_pool.updates._end_background_update(
self.EVENT_SEARCH_ORDER_UPDATE_NAME
)

@@ -347,8 +343,7 @@ class SearchStore(SearchBackgroundUpdateStore):
def __init__(self, database: DatabasePool, db_conn, hs):
super(SearchStore, self).__init__(database, db_conn, hs)

@defer.inlineCallbacks
def search_msgs(self, room_ids, search_term, keys):
async def search_msgs(self, room_ids, search_term, keys):
"""Performs a full text search over events with given keys.
Args:
@@ -425,15 +420,15 @@ def search_msgs(self, room_ids, search_term, keys):
# entire table from the database.
sql += " ORDER BY rank DESC LIMIT 500"

results = yield self.db_pool.execute(
results = await self.db_pool.execute(
"search_msgs", self.db_pool.cursor_to_dict, sql, *args
)

results = list(filter(lambda row: row["room_id"] in room_ids, results))

# We set redact_behaviour to BLOCK here to prevent redacted events being returned in
# search results (which is a data leak)
events = yield self.get_events_as_list(
events = await self.get_events_as_list(
[r["event_id"] for r in results],
redact_behaviour=EventRedactBehaviour.BLOCK,
)
@@ -442,11 +437,11 @@ def search_msgs(self, room_ids, search_term, keys):

highlights = None
if isinstance(self.database_engine, PostgresEngine):
highlights = yield self._find_highlights_in_postgres(search_query, events)
highlights = await self._find_highlights_in_postgres(search_query, events)

count_sql += " GROUP BY room_id"

count_results = yield self.db_pool.execute(
count_results = await self.db_pool.execute(
"search_rooms_count", self.db_pool.cursor_to_dict, count_sql, *count_args
)

@@ -462,19 +457,25 @@ def search_msgs(self, room_ids, search_term, keys):
"count": count,
}

@defer.inlineCallbacks
def search_rooms(self, room_ids, search_term, keys, limit, pagination_token=None):
async def search_rooms(
self,
room_ids: List[str],
search_term: str,
keys: List[str],
limit,
pagination_token: Optional[str] = None,
) -> List[dict]:
"""Performs a full text search over events with given keys.
Args:
room_id (list): The room_ids to search in
search_term (str): Search term to search for
keys (list): List of keys to search in, currently supports
"content.body", "content.name", "content.topic"
pagination_token (str): A pagination token previously returned
room_ids: The room_ids to search in
search_term: Search term to search for
keys: List of keys to search in, currently supports "content.body",
"content.name", "content.topic"
pagination_token: A pagination token previously returned
Returns:
list of dicts
Each match as a dictionary.
"""
clauses = []

@@ -577,15 +578,15 @@ def search_rooms(self, room_ids, search_term, keys, limit, pagination_token=None

args.append(limit)

results = yield self.db_pool.execute(
results = await self.db_pool.execute(
"search_rooms", self.db_pool.cursor_to_dict, sql, *args
)

results = list(filter(lambda row: row["room_id"] in room_ids, results))

# We set redact_behaviour to BLOCK here to prevent redacted events being returned in
# search results (which is a data leak)
events = yield self.get_events_as_list(
events = await self.get_events_as_list(
[r["event_id"] for r in results],
redact_behaviour=EventRedactBehaviour.BLOCK,
)
@@ -594,11 +595,11 @@ def search_rooms(self, room_ids, search_term, keys, limit, pagination_token=None

highlights = None
if isinstance(self.database_engine, PostgresEngine):
highlights = yield self._find_highlights_in_postgres(search_query, events)
highlights = await self._find_highlights_in_postgres(search_query, events)

count_sql += " GROUP BY room_id"

count_results = yield self.db_pool.execute(
count_results = await self.db_pool.execute(
"search_rooms_count", self.db_pool.cursor_to_dict, count_sql, *count_args
)

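With search_msgs and search_rooms now coroutines, a caller in async code awaits them directly; a hedged usage sketch (the handler and key list here are illustrative, not part of this change):

    async def search_room_messages(store, room_ids, term):
        # Returns the dict built at the end of search_msgs above,
        # including the matched events and a "count" of results.
        result = await store.search_msgs(room_ids, term, keys=["content.body"])
        return result["count"]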
7 changes: 2 additions & 5 deletions synapse/storage/databases/main/signatures.py
Expand Up @@ -15,8 +15,6 @@

from unpaddedbase64 import encode_base64

from twisted.internet import defer

from synapse.storage._base import SQLBaseStore
from synapse.util.caches.descriptors import cached, cachedList

@@ -40,9 +38,8 @@ def f(txn):

return self.db_pool.runInteraction("get_event_reference_hashes", f)

@defer.inlineCallbacks
def add_event_hashes(self, event_ids):
hashes = yield self.get_event_reference_hashes(event_ids)
async def add_event_hashes(self, event_ids):
hashes = await self.get_event_reference_hashes(event_ids)
hashes = {
e_id: {k: encode_base64(v) for k, v in h.items() if k == "sha256"}
for e_id, h in hashes.items()
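Because add_event_hashes is now a coroutine instead of a function returning a Deferred, any call site that still works in Deferred style has to bridge it explicitly; a minimal sketch, assuming a hypothetical legacy caller:

    import logging

    from twisted.internet import defer

    logger = logging.getLogger(__name__)

    def legacy_caller(store, event_ids):
        # ensureDeferred() turns the coroutine returned by the now-async
        # add_event_hashes back into a Deferred for old-style callers.
        d = defer.ensureDeferred(store.add_event_hashes(event_ids))
        d.addCallback(lambda hashes: logger.info("hashed %d events", len(hashes)))
        return d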
