Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Commit

Permalink
Add type hints for state. (#8140)
Browse files Browse the repository at this point in the history
  • Loading branch information
clokep committed Aug 24, 2020
1 parent cbd8d83 commit 5758dcf
Show file tree
Hide file tree
Showing 10 changed files with 420 additions and 203 deletions.
1 change: 1 addition & 0 deletions changelog.d/8140.misc
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Add type hints to `synapse.state`.
47 changes: 47 additions & 0 deletions stubs/frozendict.pyi
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
# -*- coding: utf-8 -*-
# Copyright 2020 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Stub for frozendict.

from typing import (
Any,
Hashable,
Iterable,
Iterator,
Mapping,
overload,
Tuple,
TypeVar,
)

_KT = TypeVar("_KT", bound=Hashable) # Key type.
_VT = TypeVar("_VT") # Value type.

class frozendict(Mapping[_KT, _VT]):
@overload
def __init__(self, **kwargs: _VT) -> None: ...
@overload
def __init__(self, __map: Mapping[_KT, _VT], **kwargs: _VT) -> None: ...
@overload
def __init__(
self, __iterable: Iterable[Tuple[_KT, _VT]], **kwargs: _VT
) -> None: ...
def __getitem__(self, key: _KT) -> _VT: ...
def __contains__(self, key: Any) -> bool: ...
def copy(self, **add_or_replace: Any) -> frozendict: ...
def __iter__(self) -> Iterator[_KT]: ...
def __len__(self) -> int: ...
def __repr__(self) -> str: ...
def __hash__(self) -> int: ...
4 changes: 2 additions & 2 deletions synapse/federation/sender/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -329,10 +329,10 @@ async def send_read_receipt(self, receipt: ReadReceipt) -> None:
room_id = receipt.room_id

# Work out which remote servers should be poked and poke them.
domains = await self.state.get_current_hosts_in_room(room_id)
domains_set = await self.state.get_current_hosts_in_room(room_id)
domains = [
d
for d in domains
for d in domains_set
if d != self.server_name
and self._federation_shard_config.should_handle(self._instance_name, d)
]
Expand Down
10 changes: 6 additions & 4 deletions synapse/handlers/federation.py
Original file line number Diff line number Diff line change
Expand Up @@ -2134,10 +2134,10 @@ async def _check_for_soft_fail(
)
state_sets = list(state_sets.values())
state_sets.append(state)
current_state_ids = await self.state_handler.resolve_events(
current_states = await self.state_handler.resolve_events(
room_version, state_sets, event
)
current_state_ids = {k: e.event_id for k, e in current_state_ids.items()}
current_state_ids = {k: e.event_id for k, e in current_states.items()}
else:
current_state_ids = await self.state_handler.get_current_state_ids(
event.room_id, latest_event_ids=extrem_ids
Expand All @@ -2149,9 +2149,11 @@ async def _check_for_soft_fail(

# Now check if event pass auth against said current state
auth_types = auth_types_for_event(event)
current_state_ids = [e for k, e in current_state_ids.items() if k in auth_types]
current_state_ids_list = [
e for k, e in current_state_ids.items() if k in auth_types
]

auth_events_map = await self.store.get_events(current_state_ids)
auth_events_map = await self.store.get_events(current_state_ids_list)
current_auth_events = {
(e.type, e.state_key): e for e in auth_events_map.values()
}
Expand Down
6 changes: 3 additions & 3 deletions synapse/handlers/presence.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.state import StateHandler
from synapse.storage.databases.main import DataStore
from synapse.types import JsonDict, UserID, get_domain_from_id
from synapse.types import Collection, JsonDict, UserID, get_domain_from_id
from synapse.util.async_helpers import Linearizer
from synapse.util.caches.descriptors import cached
from synapse.util.metrics import Measure
Expand Down Expand Up @@ -1318,7 +1318,7 @@ async def get_interested_parties(

async def get_interested_remotes(
store: DataStore, states: List[UserPresenceState], state_handler: StateHandler
) -> List[Tuple[List[str], List[UserPresenceState]]]:
) -> List[Tuple[Collection[str], List[UserPresenceState]]]:
"""Given a list of presence states figure out which remote servers
should be sent which.
Expand All @@ -1334,7 +1334,7 @@ async def get_interested_remotes(
each tuple the list of UserPresenceState should be sent to each
destination
"""
hosts_and_states = []
hosts_and_states = [] # type: List[Tuple[Collection[str], List[UserPresenceState]]]

# First we look up the rooms each user is in (as well as any explicit
# subscriptions), then for each distinct room we look up the remote
Expand Down
20 changes: 12 additions & 8 deletions synapse/handlers/room_member.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
import logging
import random
from http import HTTPStatus
from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Tuple, Union
from typing import TYPE_CHECKING, Iterable, List, Optional, Tuple, Union

from unpaddedbase64 import encode_base64

Expand All @@ -38,7 +38,15 @@
from synapse.events.snapshot import EventContext
from synapse.events.validator import EventValidator
from synapse.storage.roommember import RoomsForUser
from synapse.types import Collection, JsonDict, Requester, RoomAlias, RoomID, UserID
from synapse.types import (
Collection,
JsonDict,
Requester,
RoomAlias,
RoomID,
StateMap,
UserID,
)
from synapse.util.async_helpers import Linearizer
from synapse.util.distributor import user_joined_room, user_left_room

Expand Down Expand Up @@ -738,9 +746,7 @@ async def send_membership_event(
if prev_member_event.membership == Membership.JOIN:
await self._user_left_room(target_user, room_id)

async def _can_guest_join(
self, current_state_ids: Dict[Tuple[str, str], str]
) -> bool:
async def _can_guest_join(self, current_state_ids: StateMap[str]) -> bool:
"""
Returns whether a guest can join a room based on its current state.
"""
Expand Down Expand Up @@ -969,9 +975,7 @@ async def _make_and_store_3pid_invite(
)
return stream_id

async def _is_host_in_room(
self, current_state_ids: Dict[Tuple[str, str], str]
) -> bool:
async def _is_host_in_room(self, current_state_ids: StateMap[str]) -> bool:
# Have we just created the room, and is this about to be the very
# first member event?
create_event_id = current_state_ids.get(("m.room.create", ""))
Expand Down
Loading

0 comments on commit 5758dcf

Please sign in to comment.