Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Commit

Permalink
Merge commit 'aec708517' into anoa/dinsic_release_1_21_x
Browse files Browse the repository at this point in the history
* commit 'aec708517':
  Convert state and stream stores and related code to async (#8194)
  Ensure that the OpenID Connect remote ID is a string. (#8190)
  • Loading branch information
anoadragon453 committed Oct 20, 2020
2 parents 41ac123 + aec7085 commit b7672ff
Show file tree
Hide file tree
Showing 10 changed files with 94 additions and 47 deletions.
1 change: 1 addition & 0 deletions changelog.d/8190.bugfix
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Fix logging in via OpenID Connect with a provider that uses integer user IDs.
1 change: 1 addition & 0 deletions changelog.d/8194.misc
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Convert various parts of the codebase to async/await.
3 changes: 3 additions & 0 deletions synapse/handlers/oidc_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -869,6 +869,9 @@ async def _map_userinfo_to_user(
raise MappingException(
"Failed to extract subject from OIDC response: %s" % (e,)
)
# Some OIDC providers use integer IDs, but Synapse expects external IDs
# to be strings.
remote_user_id = str(remote_user_id)

logger.info(
"Looking for existing mapping for user %s:%s",
Expand Down
2 changes: 1 addition & 1 deletion synapse/handlers/room.py
Original file line number Diff line number Diff line change
Expand Up @@ -463,7 +463,7 @@ async def clone_existing_room(
old_room_member_state_events = await self.store.get_events(
old_room_member_state_ids.values()
)
for k, old_event in old_room_member_state_events.items():
for old_event in old_room_member_state_events.values():
# Only transfer ban events
if (
"membership" in old_event.content
Expand Down
19 changes: 10 additions & 9 deletions synapse/storage/databases/main/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
from synapse.storage.databases.main.events_worker import EventsWorkerStore
from synapse.storage.databases.main.roommember import RoomMemberWorkerStore
from synapse.storage.state import StateFilter
from synapse.types import StateMap
from synapse.util.caches import intern_string
from synapse.util.caches.descriptors import cached, cachedList

Expand Down Expand Up @@ -163,15 +164,15 @@ async def get_create_event_for_room(self, room_id: str) -> EventBase:
return create_event

@cached(max_entries=100000, iterable=True)
def get_current_state_ids(self, room_id):
async def get_current_state_ids(self, room_id: str) -> StateMap[str]:
"""Get the current state event ids for a room based on the
current_state_events table.
Args:
room_id (str)
room_id: The room to get the state IDs of.
Returns:
deferred: dict of (type, state_key) -> event_id
The current state of the room.
"""

def _get_current_state_ids_txn(txn):
Expand All @@ -184,14 +185,14 @@ def _get_current_state_ids_txn(txn):

return {(intern_string(r[0]), intern_string(r[1])): r[2] for r in txn}

return self.db_pool.runInteraction(
return await self.db_pool.runInteraction(
"get_current_state_ids", _get_current_state_ids_txn
)

# FIXME: how should this be cached?
def get_filtered_current_state_ids(
async def get_filtered_current_state_ids(
self, room_id: str, state_filter: StateFilter = StateFilter.all()
):
) -> StateMap[str]:
"""Get the current state event of a given type for a room based on the
current_state_events table. This may not be as up-to-date as the result
of doing a fresh state resolution as per state_handler.get_current_state
Expand All @@ -202,14 +203,14 @@ def get_filtered_current_state_ids(
from the database.
Returns:
defer.Deferred[StateMap[str]]: Map from type/state_key to event ID.
Map from type/state_key to event ID.
"""

where_clause, where_args = state_filter.make_sql_filter_clause()

if not where_clause:
# We delegate to the cached version
return self.get_current_state_ids(room_id)
return await self.get_current_state_ids(room_id)

def _get_filtered_current_state_ids_txn(txn):
results = {}
Expand All @@ -231,7 +232,7 @@ def _get_filtered_current_state_ids_txn(txn):

return results

return self.db_pool.runInteraction(
return await self.db_pool.runInteraction(
"get_filtered_current_state_ids", _get_filtered_current_state_ids_txn
)

Expand Down
21 changes: 11 additions & 10 deletions synapse/storage/databases/main/state_deltas.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,16 +14,17 @@
# limitations under the License.

import logging

from twisted.internet import defer
from typing import Any, Dict, List, Tuple

from synapse.storage._base import SQLBaseStore

logger = logging.getLogger(__name__)


class StateDeltasStore(SQLBaseStore):
def get_current_state_deltas(self, prev_stream_id: int, max_stream_id: int):
async def get_current_state_deltas(
self, prev_stream_id: int, max_stream_id: int
) -> Tuple[int, List[Dict[str, Any]]]:
"""Fetch a list of room state changes since the given stream id
Each entry in the result contains the following fields:
Expand All @@ -37,12 +38,12 @@ def get_current_state_deltas(self, prev_stream_id: int, max_stream_id: int):
if it's new state.
Args:
prev_stream_id (int): point to get changes since (exclusive)
max_stream_id (int): the point that we know has been correctly persisted
prev_stream_id: point to get changes since (exclusive)
max_stream_id: the point that we know has been correctly persisted
- ie, an upper limit to return changes from.
Returns:
Deferred[tuple[int, list[dict]]: A tuple consisting of:
A tuple consisting of:
- the stream id which these results go up to
- list of current_state_delta_stream rows. If it is empty, we are
up to date.
Expand All @@ -58,7 +59,7 @@ def get_current_state_deltas(self, prev_stream_id: int, max_stream_id: int):
# if the CSDs haven't changed between prev_stream_id and now, we
# know for certain that they haven't changed between prev_stream_id and
# max_stream_id.
return defer.succeed((max_stream_id, []))
return (max_stream_id, [])

def get_current_state_deltas_txn(txn):
# First we calculate the max stream id that will give us less than
Expand Down Expand Up @@ -102,7 +103,7 @@ def get_current_state_deltas_txn(txn):
txn.execute(sql, (prev_stream_id, clipped_stream_id))
return clipped_stream_id, self.db_pool.cursor_to_dict(txn)

return self.db_pool.runInteraction(
return await self.db_pool.runInteraction(
"get_current_state_deltas", get_current_state_deltas_txn
)

Expand All @@ -114,8 +115,8 @@ def _get_max_stream_id_in_current_state_deltas_txn(self, txn):
retcol="COALESCE(MAX(stream_id), -1)",
)

def get_max_stream_id_in_current_state_deltas(self):
return self.db_pool.runInteraction(
async def get_max_stream_id_in_current_state_deltas(self):
return await self.db_pool.runInteraction(
"get_max_stream_id_in_current_state_deltas",
self._get_max_stream_id_in_current_state_deltas_txn,
)
11 changes: 7 additions & 4 deletions synapse/storage/databases/main/stream.py
Original file line number Diff line number Diff line change
Expand Up @@ -539,16 +539,17 @@ async def get_recent_event_ids_for_room(

return rows, token

def get_room_event_before_stream_ordering(self, room_id: str, stream_ordering: int):
async def get_room_event_before_stream_ordering(
self, room_id: str, stream_ordering: int
) -> Tuple[int, int, str]:
"""Gets details of the first event in a room at or before a stream ordering
Args:
room_id:
stream_ordering:
Returns:
Deferred[(int, int, str)]:
(stream ordering, topological ordering, event_id)
A tuple of (stream ordering, topological ordering, event_id)
"""

def _f(txn):
Expand All @@ -563,7 +564,9 @@ def _f(txn):
txn.execute(sql, (room_id, stream_ordering))
return txn.fetchone()

return self.db_pool.runInteraction("get_room_event_before_stream_ordering", _f)
return await self.db_pool.runInteraction(
"get_room_event_before_stream_ordering", _f
)

async def get_room_events_max_id(self, room_id: Optional[str] = None) -> str:
"""Returns the current token for rooms stream.
Expand Down
26 changes: 13 additions & 13 deletions synapse/storage/databases/state/store.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,6 @@
from collections import namedtuple
from typing import Dict, Iterable, List, Set, Tuple

from twisted.internet import defer

from synapse.api.constants import EventTypes
from synapse.storage._base import SQLBaseStore
from synapse.storage.database import DatabasePool
Expand Down Expand Up @@ -103,7 +101,7 @@ def get_max_state_group_txn(txn: Cursor):
)

@cached(max_entries=10000, iterable=True)
def get_state_group_delta(self, state_group):
async def get_state_group_delta(self, state_group):
"""Given a state group try to return a previous group and a delta between
the old and the new.
Expand Down Expand Up @@ -135,7 +133,7 @@ def _get_state_group_delta_txn(txn):
{(row["type"], row["state_key"]): row["event_id"] for row in delta_ids},
)

return self.db_pool.runInteraction(
return await self.db_pool.runInteraction(
"get_state_group_delta", _get_state_group_delta_txn
)

Expand Down Expand Up @@ -367,9 +365,9 @@ def _insert_into_cache(
fetched_keys=non_member_types,
)

def store_state_group(
async def store_state_group(
self, event_id, room_id, prev_group, delta_ids, current_state_ids
):
) -> int:
"""Store a new set of state, returning a newly assigned state group.
Args:
Expand All @@ -383,7 +381,7 @@ def store_state_group(
to event_id.
Returns:
Deferred[int]: The state group ID
The state group ID
"""

def _store_state_group_txn(txn):
Expand Down Expand Up @@ -484,11 +482,13 @@ def _store_state_group_txn(txn):

return state_group

return self.db_pool.runInteraction("store_state_group", _store_state_group_txn)
return await self.db_pool.runInteraction(
"store_state_group", _store_state_group_txn
)

def purge_unreferenced_state_groups(
async def purge_unreferenced_state_groups(
self, room_id: str, state_groups_to_delete
) -> defer.Deferred:
) -> None:
"""Deletes no longer referenced state groups and de-deltas any state
groups that reference them.
Expand All @@ -499,7 +499,7 @@ def purge_unreferenced_state_groups(
to delete.
"""

return self.db_pool.runInteraction(
await self.db_pool.runInteraction(
"purge_unreferenced_state_groups",
self._purge_unreferenced_state_groups,
room_id,
Expand Down Expand Up @@ -594,15 +594,15 @@ async def get_previous_state_groups(

return {row["state_group"]: row["prev_state_group"] for row in rows}

def purge_room_state(self, room_id, state_groups_to_delete):
async def purge_room_state(self, room_id, state_groups_to_delete):
"""Deletes all record of a room from state tables
Args:
room_id (str):
state_groups_to_delete (list[int]): State groups to delete
"""

return self.db_pool.runInteraction(
await self.db_pool.runInteraction(
"purge_room_state",
self._purge_room_state_txn,
room_id,
Expand Down
16 changes: 8 additions & 8 deletions synapse/storage/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -333,19 +333,19 @@ class StateGroupStorage(object):
def __init__(self, hs, stores):
self.stores = stores

def get_state_group_delta(self, state_group: int):
async def get_state_group_delta(self, state_group: int):
"""Given a state group try to return a previous group and a delta between
the old and the new.
Args:
state_group: The state group used to retrieve state deltas.
Returns:
Deferred[Tuple[Optional[int], Optional[StateMap[str]]]]:
Tuple[Optional[int], Optional[StateMap[str]]]:
(prev_group, delta_ids)
"""

return self.stores.state.get_state_group_delta(state_group)
return await self.stores.state.get_state_group_delta(state_group)

async def get_state_groups_ids(
self, _room_id: str, event_ids: Iterable[str]
Expand Down Expand Up @@ -525,7 +525,7 @@ async def get_state_ids_for_event(
state_filter: The state filter used to fetch state from the database.
Returns:
A deferred dict from (type, state_key) -> state_event
A dict from (type, state_key) -> state_event
"""
state_map = await self.get_state_ids_for_events([event_id], state_filter)
return state_map[event_id]
Expand All @@ -546,14 +546,14 @@ def _get_state_for_groups(
"""
return self.stores.state._get_state_for_groups(groups, state_filter)

def store_state_group(
async def store_state_group(
self,
event_id: str,
room_id: str,
prev_group: Optional[int],
delta_ids: Optional[dict],
current_state_ids: dict,
):
) -> int:
"""Store a new set of state, returning a newly assigned state group.
Args:
Expand All @@ -567,8 +567,8 @@ def store_state_group(
to event_id.
Returns:
Deferred[int]: The state group ID
The state group ID
"""
return self.stores.state.store_state_group(
return await self.stores.state.store_state_group(
event_id, room_id, prev_group, delta_ids, current_state_ids
)
Loading

0 comments on commit b7672ff

Please sign in to comment.