Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Commit

Permalink
Add a type hint for get_device_handler() and fix incorrect types. (#…
Browse files Browse the repository at this point in the history
…14055)

This was the last untyped handler from the HomeServer object. Since
it was being treated as Any (and thus unchecked) it was being used
incorrectly in a few places.
  • Loading branch information
clokep authored and H-Shay committed Dec 13, 2022
1 parent 56f8fea commit c69d78f
Show file tree
Hide file tree
Showing 16 changed files with 185 additions and 77 deletions.
1 change: 1 addition & 0 deletions changelog.d/14055.misc
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Add missing type hints to `HomeServer`.
4 changes: 4 additions & 0 deletions synapse/handlers/deactivate_account.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from typing import TYPE_CHECKING, Optional

from synapse.api.errors import SynapseError
from synapse.handlers.device import DeviceHandler
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.types import Codes, Requester, UserID, create_requester

Expand Down Expand Up @@ -76,6 +77,9 @@ async def deactivate_account(
True if identity server supports removing threepids, otherwise False.
"""

# This can only be called on the main process.
assert isinstance(self._device_handler, DeviceHandler)

# Check if this user can be deactivated
if not await self._third_party_rules.check_can_deactivate_user(
user_id, by_admin
Expand Down
65 changes: 50 additions & 15 deletions synapse/handlers/device.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,8 @@


class DeviceWorkerHandler:
device_list_updater: "DeviceListWorkerUpdater"

def __init__(self, hs: "HomeServer"):
self.clock = hs.get_clock()
self.hs = hs
Expand All @@ -76,6 +78,8 @@ def __init__(self, hs: "HomeServer"):
self.server_name = hs.hostname
self._msc3852_enabled = hs.config.experimental.msc3852_enabled

self.device_list_updater = DeviceListWorkerUpdater(hs)

@trace
async def get_devices_by_user(self, user_id: str) -> List[JsonDict]:
"""
Expand All @@ -99,6 +103,19 @@ async def get_devices_by_user(self, user_id: str) -> List[JsonDict]:
log_kv(device_map)
return devices

async def get_dehydrated_device(
self, user_id: str
) -> Optional[Tuple[str, JsonDict]]:
"""Retrieve the information for a dehydrated device.
Args:
user_id: the user whose dehydrated device we are looking for
Returns:
a tuple whose first item is the device ID, and the second item is
the dehydrated device information
"""
return await self.store.get_dehydrated_device(user_id)

@trace
async def get_device(self, user_id: str, device_id: str) -> JsonDict:
"""Retrieve the given device
Expand Down Expand Up @@ -127,7 +144,7 @@ async def get_device(self, user_id: str, device_id: str) -> JsonDict:
@cancellable
async def get_device_changes_in_shared_rooms(
self, user_id: str, room_ids: Collection[str], from_token: StreamToken
) -> Collection[str]:
) -> Set[str]:
"""Get the set of users whose devices have changed who share a room with
the given user.
"""
Expand Down Expand Up @@ -320,6 +337,8 @@ async def handle_room_un_partial_stated(self, room_id: str) -> None:


class DeviceHandler(DeviceWorkerHandler):
device_list_updater: "DeviceListUpdater"

def __init__(self, hs: "HomeServer"):
super().__init__(hs)

Expand Down Expand Up @@ -606,19 +625,6 @@ async def store_dehydrated_device(
await self.delete_devices(user_id, [old_device_id])
return device_id

async def get_dehydrated_device(
self, user_id: str
) -> Optional[Tuple[str, JsonDict]]:
"""Retrieve the information for a dehydrated device.
Args:
user_id: the user whose dehydrated device we are looking for
Returns:
a tuple whose first item is the device ID, and the second item is
the dehydrated device information
"""
return await self.store.get_dehydrated_device(user_id)

async def rehydrate_device(
self, user_id: str, access_token: str, device_id: str
) -> dict:
Expand Down Expand Up @@ -882,7 +888,36 @@ def _update_device_from_client_ips(
)


class DeviceListUpdater:
class DeviceListWorkerUpdater:
"Handles incoming device list updates from federation and contacts the main process over replication"

def __init__(self, hs: "HomeServer"):
from synapse.replication.http.devices import (
ReplicationUserDevicesResyncRestServlet,
)

self._user_device_resync_client = (
ReplicationUserDevicesResyncRestServlet.make_client(hs)
)

async def user_device_resync(
self, user_id: str, mark_failed_as_stale: bool = True
) -> Optional[JsonDict]:
"""Fetches all devices for a user and updates the device cache with them.
Args:
user_id: The user's id whose device_list will be updated.
mark_failed_as_stale: Whether to mark the user's device list as stale
if the attempt to resync failed.
Returns:
A dict with device info as under the "devices" in the result of this
request:
https://matrix.org/docs/spec/server_server/r0.1.2#get-matrix-federation-v1-user-devices-userid
"""
return await self._user_device_resync_client(user_id=user_id)


class DeviceListUpdater(DeviceListWorkerUpdater):
"Handles incoming device list updates from federation and updates the DB"

def __init__(self, hs: "HomeServer", device_handler: DeviceHandler):
Expand Down
61 changes: 32 additions & 29 deletions synapse/handlers/e2e_keys.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,9 +27,9 @@

from synapse.api.constants import EduTypes
from synapse.api.errors import CodeMessageException, Codes, NotFoundError, SynapseError
from synapse.handlers.device import DeviceHandler
from synapse.logging.context import make_deferred_yieldable, run_in_background
from synapse.logging.opentracing import log_kv, set_tag, tag_args, trace
from synapse.replication.http.devices import ReplicationUserDevicesResyncRestServlet
from synapse.types import (
JsonDict,
UserID,
Expand All @@ -56,27 +56,23 @@ def __init__(self, hs: "HomeServer"):
self.is_mine = hs.is_mine
self.clock = hs.get_clock()

self._edu_updater = SigningKeyEduUpdater(hs, self)

federation_registry = hs.get_federation_registry()

self._is_master = hs.config.worker.worker_app is None
if not self._is_master:
self._user_device_resync_client = (
ReplicationUserDevicesResyncRestServlet.make_client(hs)
)
else:
is_master = hs.config.worker.worker_app is None
if is_master:
edu_updater = SigningKeyEduUpdater(hs)

# Only register this edu handler on master as it requires writing
# device updates to the db
federation_registry.register_edu_handler(
EduTypes.SIGNING_KEY_UPDATE,
self._edu_updater.incoming_signing_key_update,
edu_updater.incoming_signing_key_update,
)
# also handle the unstable version
# FIXME: remove this when enough servers have upgraded
federation_registry.register_edu_handler(
EduTypes.UNSTABLE_SIGNING_KEY_UPDATE,
self._edu_updater.incoming_signing_key_update,
edu_updater.incoming_signing_key_update,
)

# doesn't really work as part of the generic query API, because the
Expand Down Expand Up @@ -319,14 +315,13 @@ async def _query_devices_for_destination(
# probably be tracking their device lists. However, we haven't
# done an initial sync on the device list so we do it now.
try:
if self._is_master:
resync_results = await self.device_handler.device_list_updater.user_device_resync(
resync_results = (
await self.device_handler.device_list_updater.user_device_resync(
user_id
)
else:
resync_results = await self._user_device_resync_client(
user_id=user_id
)
)
if resync_results is None:
raise ValueError("Device resync failed")

# Add the device keys to the results.
user_devices = resync_results["devices"]
Expand Down Expand Up @@ -605,6 +600,8 @@ async def claim_client_keys(destination: str) -> None:
async def upload_keys_for_user(
self, user_id: str, device_id: str, keys: JsonDict
) -> JsonDict:
# This can only be called from the main process.
assert isinstance(self.device_handler, DeviceHandler)

time_now = self.clock.time_msec()

Expand Down Expand Up @@ -732,6 +729,8 @@ async def upload_signing_keys_for_user(
user_id: the user uploading the keys
keys: the signing keys
"""
# This can only be called from the main process.
assert isinstance(self.device_handler, DeviceHandler)

# if a master key is uploaded, then check it. Otherwise, load the
# stored master key, to check signatures on other keys
Expand Down Expand Up @@ -823,6 +822,9 @@ async def upload_signatures_for_device_keys(
Raises:
SynapseError: if the signatures dict is not valid.
"""
# This can only be called from the main process.
assert isinstance(self.device_handler, DeviceHandler)

failures = {}

# signatures to be stored. Each item will be a SignatureListItem
Expand Down Expand Up @@ -1200,6 +1202,9 @@ async def _retrieve_cross_signing_keys_for_remote_user(
A tuple of the retrieved key content, the key's ID and the matching VerifyKey.
If the key cannot be retrieved, all values in the tuple will instead be None.
"""
# This can only be called from the main process.
assert isinstance(self.device_handler, DeviceHandler)

try:
remote_result = await self.federation.query_user_devices(
user.domain, user.to_string()
Expand Down Expand Up @@ -1396,11 +1401,14 @@ class SignatureListItem:
class SigningKeyEduUpdater:
"""Handles incoming signing key updates from federation and updates the DB"""

def __init__(self, hs: "HomeServer", e2e_keys_handler: E2eKeysHandler):
def __init__(self, hs: "HomeServer"):
self.store = hs.get_datastores().main
self.federation = hs.get_federation_client()
self.clock = hs.get_clock()
self.e2e_keys_handler = e2e_keys_handler

device_handler = hs.get_device_handler()
assert isinstance(device_handler, DeviceHandler)
self._device_handler = device_handler

self._remote_edu_linearizer = Linearizer(name="remote_signing_key")

Expand Down Expand Up @@ -1445,9 +1453,6 @@ async def _handle_signing_key_updates(self, user_id: str) -> None:
user_id: the user whose updates we are processing
"""

device_handler = self.e2e_keys_handler.device_handler
device_list_updater = device_handler.device_list_updater

async with self._remote_edu_linearizer.queue(user_id):
pending_updates = self._pending_updates.pop(user_id, [])
if not pending_updates:
Expand All @@ -1459,13 +1464,11 @@ async def _handle_signing_key_updates(self, user_id: str) -> None:
logger.info("pending updates: %r", pending_updates)

for master_key, self_signing_key in pending_updates:
new_device_ids = (
await device_list_updater.process_cross_signing_key_update(
user_id,
master_key,
self_signing_key,
)
new_device_ids = await self._device_handler.device_list_updater.process_cross_signing_key_update(
user_id,
master_key,
self_signing_key,
)
device_ids = device_ids + new_device_ids

await device_handler.notify_device_update(user_id, device_ids)
await self._device_handler.notify_device_update(user_id, device_ids)
4 changes: 4 additions & 0 deletions synapse/handlers/register.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
)
from synapse.appservice import ApplicationService
from synapse.config.server import is_threepid_reserved
from synapse.handlers.device import DeviceHandler
from synapse.http.servlet import assert_params_in_dict
from synapse.replication.http.login import RegisterDeviceReplicationServlet
from synapse.replication.http.register import (
Expand Down Expand Up @@ -841,6 +842,9 @@ class and RegisterDeviceReplicationServlet.
refresh_token = None
refresh_token_id = None

# This can only run on the main process.
assert isinstance(self.device_handler, DeviceHandler)

registered_device_id = await self.device_handler.check_device_registered(
user_id,
device_id,
Expand Down
6 changes: 5 additions & 1 deletion synapse/handlers/set_password.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from typing import TYPE_CHECKING, Optional

from synapse.api.errors import Codes, StoreError, SynapseError
from synapse.handlers.device import DeviceHandler
from synapse.types import Requester

if TYPE_CHECKING:
Expand All @@ -29,7 +30,10 @@ class SetPasswordHandler:
def __init__(self, hs: "HomeServer"):
self.store = hs.get_datastores().main
self._auth_handler = hs.get_auth_handler()
self._device_handler = hs.get_device_handler()
# This can only be instantiated on the main process.
device_handler = hs.get_device_handler()
assert isinstance(device_handler, DeviceHandler)
self._device_handler = device_handler

async def set_password(
self,
Expand Down
9 changes: 9 additions & 0 deletions synapse/handlers/sso.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
from synapse.api.constants import LoginType
from synapse.api.errors import Codes, NotFoundError, RedirectException, SynapseError
from synapse.config.sso import SsoAttributeRequirement
from synapse.handlers.device import DeviceHandler
from synapse.handlers.register import init_counters_for_auth_provider
from synapse.handlers.ui_auth import UIAuthSessionDataConstants
from synapse.http import get_request_user_agent
Expand Down Expand Up @@ -1035,13 +1036,21 @@ async def revoke_sessions_for_provider_session_id(
) -> None:
"""Revoke any devices and in-flight logins tied to a provider session.
Can only be called from the main process.
Args:
auth_provider_id: A unique identifier for this SSO provider, e.g.
"oidc" or "saml".
auth_provider_session_id: The session ID from the provider to logout
expected_user_id: The user we're expecting to logout. If set, it will ignore
sessions belonging to other users and log an error.
"""

# It is expected that this is the main process.
assert isinstance(
self._device_handler, DeviceHandler
), "revoking SSO sessions can only be called on the main process"

# Invalidate any running user-mapping sessions
to_delete = []
for session_id, session in self._username_mapping_sessions.items():
Expand Down
10 changes: 9 additions & 1 deletion synapse/module_api/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,7 @@
ON_LOGGED_OUT_CALLBACK,
AuthHandler,
)
from synapse.handlers.device import DeviceHandler
from synapse.handlers.push_rules import RuleSpec, check_actions
from synapse.http.client import SimpleHttpClient
from synapse.http.server import (
Expand Down Expand Up @@ -207,6 +208,7 @@ def __init__(self, hs: "HomeServer", auth_handler: AuthHandler) -> None:
self._registration_handler = hs.get_registration_handler()
self._send_email_handler = hs.get_send_email_handler()
self._push_rules_handler = hs.get_push_rules_handler()
self._device_handler = hs.get_device_handler()
self.custom_template_dir = hs.config.server.custom_template_directory

try:
Expand Down Expand Up @@ -784,6 +786,8 @@ def invalidate_access_token(
) -> Generator["defer.Deferred[Any]", Any, None]:
"""Invalidate an access token for a user
Can only be called from the main process.
Added in Synapse v0.25.0.
Args:
Expand All @@ -796,6 +800,10 @@ def invalidate_access_token(
Raises:
synapse.api.errors.AuthError: the access token is invalid
"""
assert isinstance(
self._device_handler, DeviceHandler
), "invalidate_access_token can only be called on the main process"

# see if the access token corresponds to a device
user_info = yield defer.ensureDeferred(
self._auth.get_user_by_access_token(access_token)
Expand All @@ -805,7 +813,7 @@ def invalidate_access_token(
if device_id:
# delete the device, which will also delete its access tokens
yield defer.ensureDeferred(
self._hs.get_device_handler().delete_devices(user_id, [device_id])
self._device_handler.delete_devices(user_id, [device_id])
)
else:
# no associated device. Just delete the access token.
Expand Down
Loading

0 comments on commit c69d78f

Please sign in to comment.