Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Commit

Permalink
Refactor storage methods to retrieve to-device messages
Browse files Browse the repository at this point in the history
This commit refactors the previously rather duplicated 'get_new_messages_for_device' and
'get_new_messages' methods into one new private method with combined logic, and two small
public methods. The public methods expose the correct interface for querying to-device
messages for either a single device (where a limit can be used) and multiple devices
(where using a limit is infeasible).
  • Loading branch information
anoadragon453 committed Jan 25, 2022
1 parent 0ac079b commit 822e92a
Showing 1 changed file with 194 additions and 87 deletions.
281 changes: 194 additions & 87 deletions synapse/storage/databases/main/deviceinbox.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
# limitations under the License.

import logging
from typing import TYPE_CHECKING, Collection, Dict, List, Optional, Tuple, cast
from typing import TYPE_CHECKING, Collection, Dict, List, Optional, Set, Tuple, cast

from synapse.logging import issue9533_logger
from synapse.logging.opentracing import log_kv, set_tag, trace
Expand Down Expand Up @@ -137,134 +137,241 @@ def process_replication_rows(self, stream_name, instance_name, token, rows):
def get_to_device_stream_token(self):
return self._device_inbox_id_gen.get_current_token()

async def get_new_messages(
async def get_messages_for_user_devices(
self,
user_ids: Collection[str],
from_stream_id: int,
to_stream_id: int,
) -> Dict[Tuple[str, str], List[JsonDict]]:
"""
Retrieve to-device messages for a given set of user IDs.
Retrieve to-device messages for a given set of users.
Only to-device messages with stream ids between the given boundaries
(from < X <= to) are returned.
Note that a stream ID can be shared by multiple copies of the same message with
different recipient devices. Each (device, message_content) tuple has their own
row in the device_inbox table.
Args:
user_ids: The users to retrieve to-device messages for.
from_stream_id: The lower boundary of stream id to filter with (exclusive).
to_stream_id: The upper boundary of stream id to filter with (inclusive).
Returns:
A list of to-device messages.
A dictionary of (user id, device id) -> list of to-device messages.
"""
# Bail out if none of these users have any messages
for user_id in user_ids:
if self._device_inbox_stream_cache.has_entity_changed(
user_id, from_stream_id
):
break
else:
return {}

def get_new_messages_txn(txn: LoggingTransaction):
# Build a query to select messages from any of the given users that are between
# the given stream id bounds
# We expect the stream ID returned by _get_new_device_messages to always
# return to_stream_id. So, no need to return it from this function.
user_id_device_id_to_messages, _ = await self._get_device_messages(
user_ids=user_ids,
from_stream_id=from_stream_id,
to_stream_id=to_stream_id,
)

# Scope to only the given users. We need to use this method as doing so is
# different across database engines.
many_clause_sql, many_clause_args = make_in_list_sql_clause(
self.database_engine, "user_id", user_ids
)
return user_id_device_id_to_messages

sql = f"""
SELECT user_id, device_id, message_json FROM device_inbox
WHERE {many_clause_sql}
AND ? < stream_id AND stream_id <= ?
ORDER BY stream_id ASC
"""
async def get_messages_for_device(
self,
user_id: str,
device_id: str,
from_stream_id: int,
to_stream_id: int,
limit: int = 100,
) -> Tuple[List[JsonDict], int]:
"""
Retrieve to-device messages for a single user device.
txn.execute(sql, (*many_clause_args, from_stream_id, to_stream_id))
Only to-device messages with stream ids between the given boundaries
(from < X <= to) are returned.
# Create a dictionary of (user ID, device ID) -> list of messages that
# that device is meant to receive.
recipient_device_to_messages: Dict[Tuple[str, str], List[JsonDict]] = {}
Args:
user_id: The ID of the user to retrieve messages for.
device_id: The ID of the device to retrieve to-device messages for.
from_stream_id: The lower boundary of stream id to filter with (exclusive).
to_stream_id: The upper boundary of stream id to filter with (inclusive).
limit: A limit on the number of to-device messages returned.
for row in txn:
recipient_user_id = row[0]
recipient_device_id = row[1]
message_dict = db_to_json(row[2])
Returns:
A tuple containing:
* A dictionary of (user id, device id) -> list of to-device messages.
* The last-processed stream ID. Subsequent calls of this function with the
same device should pass this value as 'from_stream_id'.
"""
(
user_id_device_id_to_messages,
last_processed_stream_id,
) = await self._get_device_messages(
user_ids=[user_id],
device_ids=[device_id],
from_stream_id=from_stream_id,
to_stream_id=to_stream_id,
limit=limit,
)

recipient_device_to_messages.setdefault(
(recipient_user_id, recipient_device_id), []
).append(message_dict)
if not user_id_device_id_to_messages:
# There were no messages!
return [], to_stream_id

return recipient_device_to_messages
# Extract the messages, no need to return the user and device ID again
to_device_messages = list(user_id_device_id_to_messages.values())[0]

return await self.db_pool.runInteraction(
"get_new_messages", get_new_messages_txn
)
return to_device_messages, last_processed_stream_id

async def get_new_messages_for_device(
async def _get_device_messages(
self,
user_id: str,
device_id: Optional[str],
last_stream_id: int,
current_stream_id: int,
limit: int = 100,
) -> Tuple[List[dict], int]:
user_ids: Collection[str],
from_stream_id: int,
to_stream_id: int,
device_ids: Optional[Collection[str]] = None,
limit: Optional[int] = None,
) -> Tuple[Dict[Tuple[str, str], List[JsonDict]], int]:
"""
Retrieve pending to-device messages for a collection of user devices.
Only to-device messages with stream ids between the given boundaries
(from < X <= to) are returned.
Note that a stream ID can be shared by multiple copies of the same message with
different recipient devices. Stream IDs are only unique in the context of a single
user ID / device ID pair Thus, applying a limit (of messages to return) when working
with a sliding window of stream IDs is only possible when querying messages of a
single user device.
Finally, note that device IDs are not unique across users.
Args:
user_id: The recipient user_id.
device_id: The recipient device_id.
last_stream_id: The last stream ID checked.
current_stream_id: The current position of the to device
message stream.
limit: The maximum number of messages to retrieve.
user_ids: The user IDs to filter device messages by.
from_stream_id: The lower boundary of stream id to filter with (exclusive).
to_stream_id: The upper boundary of stream id to filter with (inclusive).
device_ids: If provided, only messages destined for these device IDs will be returned.
If not provided, all device IDs for the given user IDs will be used.
limit: The maximum number of to-device messages to return. Can only be used when
passing a single user ID / device ID tuple.
Returns:
A tuple containing:
* A list of messages for the device.
* The max stream token of these messages. There may be more to retrieve
if the given limit was reached.
* A dict of (user_id, device_id) -> list of to-device messages
* The last-processed stream ID. If this is less than `to_stream_id`, then
there may be more messages to retrieve. If `limit` is not set, then this
is always equal to 'to_stream_id'.
"""
has_changed = self._device_inbox_stream_cache.has_entity_changed(
user_id, last_stream_id
)
if not has_changed:
return [], current_stream_id
# A limit can only be applied when querying for a single user ID / device ID tuple.
if limit:
if not device_ids:
raise AssertionError(
"Programming error: _get_new_device_messages was passed 'limit' "
"but not device_ids. This could lead to querying multiple user ID "
"/ device ID pairs, which is not compatible with 'limit'"
)

def get_new_messages_for_device_txn(txn):
sql = (
"SELECT stream_id, message_json FROM device_inbox"
" WHERE user_id = ? AND device_id = ?"
" AND ? < stream_id AND stream_id <= ?"
" ORDER BY stream_id ASC"
" LIMIT ?"
if len(user_ids) > 1 or len(device_ids) > 1:
raise AssertionError(
"Programming error: _get_new_device_messages was passed 'limit' "
"with >1 user id/device id pair"
)

user_ids_to_query: Set[str] = set()
device_ids_to_query: Set[str] = set()

if device_ids is not None:
# If a collection of device IDs were passed, use them to filter results.
# Otherwise, device IDs will be derived from the given collection of user IDs.
device_ids_to_query.update(device_ids)

# Determine which users have devices with pending messages
for user_id in user_ids:
if self._device_inbox_stream_cache.has_entity_changed(
user_id, from_stream_id
):
# This user has new messages sent to them. Query messages for them
user_ids_to_query.add(user_id)

def get_new_device_messages_txn(txn: LoggingTransaction):
# Build a query to select messages from any of the given devices that
# are between the given stream id bounds.

# If a list of device IDs was not provided, retrieve all devices IDs
# for the given users. We explicitly do not query hidden devices, as
# hidden devices should not receive to-device messages.
if not device_ids:
user_device_dicts = self.db_pool.simple_select_many_txn(
txn,
table="devices",
column="user_id",
iterable=user_ids_to_query,
keyvalues={"user_id": user_id, "hidden": False},
retcols=("device_id",),
)

device_ids_to_query.update(
{row["device_id"] for row in user_device_dicts}
)

if not user_ids_to_query or not device_ids_to_query:
# We've ended up with no devices to query.
return {}, to_stream_id

# We include both user IDs and device IDs in this query, as we have an index
# (device_inbox_user_stream_id) for them.
user_id_many_clause_sql, user_id_many_clause_args = make_in_list_sql_clause(
self.database_engine, "user_id", user_ids_to_query
)
txn.execute(
sql, (user_id, device_id, last_stream_id, current_stream_id, limit)
(
device_id_many_clause_sql,
device_id_many_clause_args,
) = make_in_list_sql_clause(
self.database_engine, "device_id", device_ids_to_query
)

messages = []
stream_pos = current_stream_id
sql = f"""
SELECT stream_id, user_id, device_id, message_json FROM device_inbox
WHERE {user_id_many_clause_sql}
AND {device_id_many_clause_sql}
AND ? < stream_id AND stream_id <= ?
ORDER BY stream_id ASC
"""
sql_args = (
*user_id_many_clause_args,
*device_id_many_clause_args,
from_stream_id,
to_stream_id,
)

for row in txn:
stream_pos = row[0]
messages.append(db_to_json(row[1]))
# If a limit was provided, limit the data retrieved from the database
if limit:
sql += "LIMIT ?"
sql_args += (limit,)

# If the limit was not reached we know that there's no more data for this
# user/device pair up to current_stream_id.
if len(messages) < limit:
stream_pos = current_stream_id
txn.execute(sql, sql_args)

return messages, stream_pos
# Create and fill a dictionary of (user ID, device ID) -> list of messages
# intended for each device.
recipient_device_to_messages: Dict[Tuple[str, str], List[JsonDict]] = {}
for message_count, row in enumerate(txn, start=1):
last_processed_stream_pos = row[0]
recipient_user_id = row[1]
recipient_device_id = row[2]
message_dict = db_to_json(row[3])

# Store the device details
recipient_device_to_messages.setdefault(
(recipient_user_id, recipient_device_id), []
).append(message_dict)

if limit and message_count == limit:
# We ended up hitting the message limit. There may be more messages to retrieve.
# Return what we have, as well as the last stream position that was processed.
#
# The caller is expected to set this as the lower (exclusive) bound
# for the next query of this device.
return recipient_device_to_messages, last_processed_stream_pos

# The limit was not reached, thus we know that recipient_device_to_messages
# contains all to-device messages for the given device and stream id range.
#
# We return to_stream_id, which the caller should then provide as the lower
# (exclusive) bound on the next query of this device.
return recipient_device_to_messages, to_stream_id

return await self.db_pool.runInteraction(
"get_new_messages_for_device", get_new_messages_for_device_txn
"get_new_device_messages", get_new_device_messages_txn
)

@trace
Expand Down

0 comments on commit 822e92a

Please sign in to comment.