Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Fix import cycle #11965

Merged
merged 2 commits into from
Feb 11, 2022
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changelog.d/11965.misc
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Fix an import cycle in `synapse.event_auth`.
54 changes: 31 additions & 23 deletions synapse/event_auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
# limitations under the License.

import logging
import typing
from typing import Any, Dict, Iterable, List, Optional, Set, Tuple, Union

from canonicaljson import encode_canonical_json
Expand All @@ -34,15 +35,18 @@
EventFormatVersions,
RoomVersion,
)
from synapse.events import EventBase
from synapse.events.builder import EventBuilder
from synapse.types import StateMap, UserID, get_domain_from_id

if typing.TYPE_CHECKING:
# conditional imports to avoid import cycle
from synapse.events import EventBase
from synapse.events.builder import EventBuilder

logger = logging.getLogger(__name__)


def validate_event_for_room_version(
room_version_obj: RoomVersion, event: EventBase
room_version_obj: RoomVersion, event: "EventBase"
) -> None:
"""Ensure that the event complies with the limits, and has the right signatures

Expand Down Expand Up @@ -113,7 +117,9 @@ def validate_event_for_room_version(


def check_auth_rules_for_event(
room_version_obj: RoomVersion, event: EventBase, auth_events: Iterable[EventBase]
room_version_obj: RoomVersion,
event: "EventBase",
auth_events: Iterable["EventBase"],
) -> None:
"""Check that an event complies with the auth rules

Expand Down Expand Up @@ -256,7 +262,7 @@ def check_auth_rules_for_event(
logger.debug("Allowing! %s", event)


def _check_size_limits(event: EventBase) -> None:
def _check_size_limits(event: "EventBase") -> None:
if len(event.user_id) > 255:
raise EventSizeError("'user_id' too large")
if len(event.room_id) > 255:
Expand All @@ -271,7 +277,7 @@ def _check_size_limits(event: EventBase) -> None:
raise EventSizeError("event too large")


def _can_federate(event: EventBase, auth_events: StateMap[EventBase]) -> bool:
def _can_federate(event: "EventBase", auth_events: StateMap["EventBase"]) -> bool:
creation_event = auth_events.get((EventTypes.Create, ""))
# There should always be a creation event, but if not don't federate.
if not creation_event:
Expand All @@ -281,7 +287,7 @@ def _can_federate(event: EventBase, auth_events: StateMap[EventBase]) -> bool:


def _is_membership_change_allowed(
room_version: RoomVersion, event: EventBase, auth_events: StateMap[EventBase]
room_version: RoomVersion, event: "EventBase", auth_events: StateMap["EventBase"]
) -> None:
"""
Confirms that the event which changes membership is an allowed change.
Expand Down Expand Up @@ -471,23 +477,25 @@ def _is_membership_change_allowed(


def _check_event_sender_in_room(
event: EventBase, auth_events: StateMap[EventBase]
event: "EventBase", auth_events: StateMap["EventBase"]
) -> None:
key = (EventTypes.Member, event.user_id)
member_event = auth_events.get(key)

_check_joined_room(member_event, event.user_id, event.room_id)


def _check_joined_room(member: Optional[EventBase], user_id: str, room_id: str) -> None:
def _check_joined_room(
member: Optional["EventBase"], user_id: str, room_id: str
) -> None:
if not member or member.membership != Membership.JOIN:
raise AuthError(
403, "User %s not in room %s (%s)" % (user_id, room_id, repr(member))
)


def get_send_level(
etype: str, state_key: Optional[str], power_levels_event: Optional[EventBase]
etype: str, state_key: Optional[str], power_levels_event: Optional["EventBase"]
) -> int:
"""Get the power level required to send an event of a given type

Expand Down Expand Up @@ -523,7 +531,7 @@ def get_send_level(
return int(send_level)


def _can_send_event(event: EventBase, auth_events: StateMap[EventBase]) -> bool:
def _can_send_event(event: "EventBase", auth_events: StateMap["EventBase"]) -> bool:
power_levels_event = get_power_level_event(auth_events)

send_level = get_send_level(event.type, event.get("state_key"), power_levels_event)
Expand All @@ -547,8 +555,8 @@ def _can_send_event(event: EventBase, auth_events: StateMap[EventBase]) -> bool:

def check_redaction(
room_version_obj: RoomVersion,
event: EventBase,
auth_events: StateMap[EventBase],
event: "EventBase",
auth_events: StateMap["EventBase"],
) -> bool:
"""Check whether the event sender is allowed to redact the target event.

Expand Down Expand Up @@ -585,8 +593,8 @@ def check_redaction(

def check_historical(
room_version_obj: RoomVersion,
event: EventBase,
auth_events: StateMap[EventBase],
event: "EventBase",
auth_events: StateMap["EventBase"],
) -> None:
"""Check whether the event sender is allowed to send historical related
events like "insertion", "batch", and "marker".
Expand Down Expand Up @@ -616,8 +624,8 @@ def check_historical(

def _check_power_levels(
room_version_obj: RoomVersion,
event: EventBase,
auth_events: StateMap[EventBase],
event: "EventBase",
auth_events: StateMap["EventBase"],
) -> None:
user_list = event.content.get("users", {})
# Validate users
Expand Down Expand Up @@ -710,11 +718,11 @@ def _check_power_levels(
)


def get_power_level_event(auth_events: StateMap[EventBase]) -> Optional[EventBase]:
def get_power_level_event(auth_events: StateMap["EventBase"]) -> Optional["EventBase"]:
return auth_events.get((EventTypes.PowerLevels, ""))


def get_user_power_level(user_id: str, auth_events: StateMap[EventBase]) -> int:
def get_user_power_level(user_id: str, auth_events: StateMap["EventBase"]) -> int:
"""Get a user's power level

Args:
Expand Down Expand Up @@ -750,7 +758,7 @@ def get_user_power_level(user_id: str, auth_events: StateMap[EventBase]) -> int:
return 0


def get_named_level(auth_events: StateMap[EventBase], name: str, default: int) -> int:
def get_named_level(auth_events: StateMap["EventBase"], name: str, default: int) -> int:
power_level_event = get_power_level_event(auth_events)

if not power_level_event:
Expand All @@ -763,7 +771,7 @@ def get_named_level(auth_events: StateMap[EventBase], name: str, default: int) -
return default


def _verify_third_party_invite(event: EventBase, auth_events: StateMap[EventBase]):
def _verify_third_party_invite(event: "EventBase", auth_events: StateMap["EventBase"]):
"""
Validates that the invite event is authorized by a previous third-party invite.

Expand Down Expand Up @@ -827,7 +835,7 @@ def _verify_third_party_invite(event: EventBase, auth_events: StateMap[EventBase
return False


def get_public_keys(invite_event: EventBase) -> List[Dict[str, Any]]:
def get_public_keys(invite_event: "EventBase") -> List[Dict[str, Any]]:
public_keys = []
if "public_key" in invite_event.content:
o = {"public_key": invite_event.content["public_key"]}
Expand All @@ -839,7 +847,7 @@ def get_public_keys(invite_event: EventBase) -> List[Dict[str, Any]]:


def auth_types_for_event(
room_version: RoomVersion, event: Union[EventBase, EventBuilder]
room_version: RoomVersion, event: Union["EventBase", "EventBuilder"]
) -> Set[Tuple[str, str]]:
"""Given an event, return a list of (EventType, StateKey) that may be
needed to auth the event. The returned list may be a superset of what
Expand Down