From 91b730e59f2afb84083da3968cdec22e1a1944f9 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Wed, 26 May 2021 15:05:32 -0400 Subject: [PATCH 1/8] Federation servlets take a homeserver instead of a handler. --- synapse/federation/transport/server.py | 208 +++++++++++++++++-------- 1 file changed, 139 insertions(+), 69 deletions(-) diff --git a/synapse/federation/transport/server.py b/synapse/federation/transport/server.py index 00ff02c7cbb5..f58f82752c52 100644 --- a/synapse/federation/transport/server.py +++ b/synapse/federation/transport/server.py @@ -272,10 +272,17 @@ class BaseFederationServlet: RATELIMIT = True # Whether to rate limit requests or not - def __init__(self, handler, authenticator, ratelimiter, server_name): - self.handler = handler + def __init__( + self, + hs: HomeServer, + authenticator: Authenticator, + ratelimiter: FederationRateLimiter, + server_name: str, + ): + self.hs = hs self.authenticator = authenticator self.ratelimiter = ratelimiter + self.server_name = server_name def _wrap(self, func): authenticator = self.authenticator @@ -372,17 +379,25 @@ def register(self, server): ) -class FederationSendServlet(BaseFederationServlet): +class BaseFederationServerServlet(BaseFederationServlet): + def __init__( + self, + hs: HomeServer, + authenticator: Authenticator, + ratelimiter: FederationRateLimiter, + server_name: str, + ): + super().__init__(hs, authenticator, ratelimiter, server_name) + self.handler = hs.get_federation_server() + + +class FederationSendServlet(BaseFederationServerServlet): PATH = "/send/(?P[^/]*)/?" # We ratelimit manually in the handler as we queue up the requests and we # don't want to fill up the ratelimiter with blocked requests. RATELIMIT = False - def __init__(self, handler, server_name, **kwargs): - super().__init__(handler, server_name=server_name, **kwargs) - self.server_name = server_name - # This is when someone is trying to send us a bunch of data. async def on_PUT(self, origin, content, query, transaction_id): """Called on PUT /send// @@ -431,7 +446,7 @@ async def on_PUT(self, origin, content, query, transaction_id): return code, response -class FederationEventServlet(BaseFederationServlet): +class FederationEventServlet(BaseFederationServerServlet): PATH = "/event/(?P[^/]*)/?" # This is when someone asks for a data item for a given server data_id pair. @@ -439,7 +454,7 @@ async def on_GET(self, origin, content, query, event_id): return await self.handler.on_pdu_request(origin, event_id) -class FederationStateV1Servlet(BaseFederationServlet): +class FederationStateV1Servlet(BaseFederationServerServlet): PATH = "/state/(?P[^/]*)/?" # This is when someone asks for all data for a given room. @@ -451,7 +466,7 @@ async def on_GET(self, origin, content, query, room_id): ) -class FederationStateIdsServlet(BaseFederationServlet): +class FederationStateIdsServlet(BaseFederationServerServlet): PATH = "/state_ids/(?P[^/]*)/?" async def on_GET(self, origin, content, query, room_id): @@ -462,7 +477,7 @@ async def on_GET(self, origin, content, query, room_id): ) -class FederationBackfillServlet(BaseFederationServlet): +class FederationBackfillServlet(BaseFederationServerServlet): PATH = "/backfill/(?P[^/]*)/?" async def on_GET(self, origin, content, query, room_id): @@ -475,7 +490,7 @@ async def on_GET(self, origin, content, query, room_id): return await self.handler.on_backfill_request(origin, room_id, versions, limit) -class FederationQueryServlet(BaseFederationServlet): +class FederationQueryServlet(BaseFederationServerServlet): PATH = "/query/(?P[^/]*)" # This is when we receive a server-server Query @@ -485,7 +500,7 @@ async def on_GET(self, origin, content, query, query_type): return await self.handler.on_query_request(query_type, args) -class FederationMakeJoinServlet(BaseFederationServlet): +class FederationMakeJoinServlet(BaseFederationServerServlet): PATH = "/make_join/(?P[^/]*)/(?P[^/]*)" async def on_GET(self, origin, _content, query, room_id, user_id): @@ -515,7 +530,7 @@ async def on_GET(self, origin, _content, query, room_id, user_id): return 200, content -class FederationMakeLeaveServlet(BaseFederationServlet): +class FederationMakeLeaveServlet(BaseFederationServerServlet): PATH = "/make_leave/(?P[^/]*)/(?P[^/]*)" async def on_GET(self, origin, content, query, room_id, user_id): @@ -523,7 +538,7 @@ async def on_GET(self, origin, content, query, room_id, user_id): return 200, content -class FederationV1SendLeaveServlet(BaseFederationServlet): +class FederationV1SendLeaveServlet(BaseFederationServerServlet): PATH = "/send_leave/(?P[^/]*)/(?P[^/]*)" async def on_PUT(self, origin, content, query, room_id, event_id): @@ -531,7 +546,7 @@ async def on_PUT(self, origin, content, query, room_id, event_id): return 200, (200, content) -class FederationV2SendLeaveServlet(BaseFederationServlet): +class FederationV2SendLeaveServlet(BaseFederationServerServlet): PATH = "/send_leave/(?P[^/]*)/(?P[^/]*)" PREFIX = FEDERATION_V2_PREFIX @@ -541,14 +556,14 @@ async def on_PUT(self, origin, content, query, room_id, event_id): return 200, content -class FederationEventAuthServlet(BaseFederationServlet): +class FederationEventAuthServlet(BaseFederationServerServlet): PATH = "/event_auth/(?P[^/]*)/(?P[^/]*)" async def on_GET(self, origin, content, query, room_id, event_id): return await self.handler.on_event_auth(origin, room_id, event_id) -class FederationV1SendJoinServlet(BaseFederationServlet): +class FederationV1SendJoinServlet(BaseFederationServerServlet): PATH = "/send_join/(?P[^/]*)/(?P[^/]*)" async def on_PUT(self, origin, content, query, room_id, event_id): @@ -558,7 +573,7 @@ async def on_PUT(self, origin, content, query, room_id, event_id): return 200, (200, content) -class FederationV2SendJoinServlet(BaseFederationServlet): +class FederationV2SendJoinServlet(BaseFederationServerServlet): PATH = "/send_join/(?P[^/]*)/(?P[^/]*)" PREFIX = FEDERATION_V2_PREFIX @@ -570,7 +585,7 @@ async def on_PUT(self, origin, content, query, room_id, event_id): return 200, content -class FederationV1InviteServlet(BaseFederationServlet): +class FederationV1InviteServlet(BaseFederationServerServlet): PATH = "/invite/(?P[^/]*)/(?P[^/]*)" async def on_PUT(self, origin, content, query, room_id, event_id): @@ -587,7 +602,7 @@ async def on_PUT(self, origin, content, query, room_id, event_id): return 200, (200, content) -class FederationV2InviteServlet(BaseFederationServlet): +class FederationV2InviteServlet(BaseFederationServerServlet): PATH = "/invite/(?P[^/]*)/(?P[^/]*)" PREFIX = FEDERATION_V2_PREFIX @@ -611,7 +626,7 @@ async def on_PUT(self, origin, content, query, room_id, event_id): return 200, content -class FederationThirdPartyInviteExchangeServlet(BaseFederationServlet): +class FederationThirdPartyInviteExchangeServlet(BaseFederationServerServlet): PATH = "/exchange_third_party_invite/(?P[^/]*)" async def on_PUT(self, origin, content, query, room_id): @@ -619,21 +634,21 @@ async def on_PUT(self, origin, content, query, room_id): return 200, {} -class FederationClientKeysQueryServlet(BaseFederationServlet): +class FederationClientKeysQueryServlet(BaseFederationServerServlet): PATH = "/user/keys/query" async def on_POST(self, origin, content, query): return await self.handler.on_query_client_keys(origin, content) -class FederationUserDevicesQueryServlet(BaseFederationServlet): +class FederationUserDevicesQueryServlet(BaseFederationServerServlet): PATH = "/user/devices/(?P[^/]*)" async def on_GET(self, origin, content, query, user_id): return await self.handler.on_query_user_devices(origin, user_id) -class FederationClientKeysClaimServlet(BaseFederationServlet): +class FederationClientKeysClaimServlet(BaseFederationServerServlet): PATH = "/user/keys/claim" async def on_POST(self, origin, content, query): @@ -641,7 +656,7 @@ async def on_POST(self, origin, content, query): return 200, response -class FederationGetMissingEventsServlet(BaseFederationServlet): +class FederationGetMissingEventsServlet(BaseFederationServerServlet): # TODO(paul): Why does this path alone end with "/?" optional? PATH = "/get_missing_events/(?P[^/]*)/?" @@ -661,7 +676,7 @@ async def on_POST(self, origin, content, query, room_id): return 200, content -class On3pidBindServlet(BaseFederationServlet): +class On3pidBindServlet(BaseFederationServerServlet): PATH = "/3pid/onbind" REQUIRE_AUTH = False @@ -691,7 +706,7 @@ async def on_POST(self, origin, content, query): return 200, {} -class OpenIdUserInfo(BaseFederationServlet): +class OpenIdUserInfo(BaseFederationServerServlet): """ Exchange a bearer token for information about a user. @@ -767,8 +782,16 @@ class PublicRoomList(BaseFederationServlet): PATH = "/publicRooms" - def __init__(self, handler, authenticator, ratelimiter, server_name, allow_access): - super().__init__(handler, authenticator, ratelimiter, server_name) + def __init__( + self, + hs: HomeServer, + authenticator: Authenticator, + ratelimiter: FederationRateLimiter, + server_name: str, + allow_access: bool, + ): + super().__init__(hs, authenticator, ratelimiter, server_name) + self.handler = hs.get_room_list_handler() self.allow_access = allow_access async def on_GET(self, origin, content, query): @@ -853,7 +876,19 @@ async def on_GET(self, origin, content, query): ) -class FederationGroupsProfileServlet(BaseFederationServlet): +class BaseGroupsServerServlet(BaseFederationServlet): + def __init__( + self, + hs: HomeServer, + authenticator: Authenticator, + ratelimiter: FederationRateLimiter, + server_name: str, + ): + super().__init__(hs, authenticator, ratelimiter, server_name) + self.handler = hs.get_groups_server_handler() + + +class FederationGroupsProfileServlet(BaseGroupsServerServlet): """Get/set the basic profile of a group on behalf of a user""" PATH = "/groups/(?P[^/]*)/profile" @@ -879,7 +914,7 @@ async def on_POST(self, origin, content, query, group_id): return 200, new_content -class FederationGroupsSummaryServlet(BaseFederationServlet): +class FederationGroupsSummaryServlet(BaseGroupsServerServlet): PATH = "/groups/(?P[^/]*)/summary" async def on_GET(self, origin, content, query, group_id): @@ -892,7 +927,7 @@ async def on_GET(self, origin, content, query, group_id): return 200, new_content -class FederationGroupsRoomsServlet(BaseFederationServlet): +class FederationGroupsRoomsServlet(BaseGroupsServerServlet): """Get the rooms in a group on behalf of a user""" PATH = "/groups/(?P[^/]*)/rooms" @@ -907,7 +942,7 @@ async def on_GET(self, origin, content, query, group_id): return 200, new_content -class FederationGroupsAddRoomsServlet(BaseFederationServlet): +class FederationGroupsAddRoomsServlet(BaseGroupsServerServlet): """Add/remove room from group""" PATH = "/groups/(?P[^/]*)/room/(?P[^/]*)" @@ -935,7 +970,7 @@ async def on_DELETE(self, origin, content, query, group_id, room_id): return 200, new_content -class FederationGroupsAddRoomsConfigServlet(BaseFederationServlet): +class FederationGroupsAddRoomsConfigServlet(BaseGroupsServerServlet): """Update room config in group""" PATH = ( @@ -955,7 +990,7 @@ async def on_POST(self, origin, content, query, group_id, room_id, config_key): return 200, result -class FederationGroupsUsersServlet(BaseFederationServlet): +class FederationGroupsUsersServlet(BaseGroupsServerServlet): """Get the users in a group on behalf of a user""" PATH = "/groups/(?P[^/]*)/users" @@ -970,7 +1005,7 @@ async def on_GET(self, origin, content, query, group_id): return 200, new_content -class FederationGroupsInvitedUsersServlet(BaseFederationServlet): +class FederationGroupsInvitedUsersServlet(BaseGroupsServerServlet): """Get the users that have been invited to a group""" PATH = "/groups/(?P[^/]*)/invited_users" @@ -987,7 +1022,7 @@ async def on_GET(self, origin, content, query, group_id): return 200, new_content -class FederationGroupsInviteServlet(BaseFederationServlet): +class FederationGroupsInviteServlet(BaseGroupsServerServlet): """Ask a group server to invite someone to the group""" PATH = "/groups/(?P[^/]*)/users/(?P[^/]*)/invite" @@ -1004,7 +1039,7 @@ async def on_POST(self, origin, content, query, group_id, user_id): return 200, new_content -class FederationGroupsAcceptInviteServlet(BaseFederationServlet): +class FederationGroupsAcceptInviteServlet(BaseGroupsServerServlet): """Accept an invitation from the group server""" PATH = "/groups/(?P[^/]*)/users/(?P[^/]*)/accept_invite" @@ -1018,7 +1053,7 @@ async def on_POST(self, origin, content, query, group_id, user_id): return 200, new_content -class FederationGroupsJoinServlet(BaseFederationServlet): +class FederationGroupsJoinServlet(BaseGroupsServerServlet): """Attempt to join a group""" PATH = "/groups/(?P[^/]*)/users/(?P[^/]*)/join" @@ -1032,7 +1067,7 @@ async def on_POST(self, origin, content, query, group_id, user_id): return 200, new_content -class FederationGroupsRemoveUserServlet(BaseFederationServlet): +class FederationGroupsRemoveUserServlet(BaseGroupsServerServlet): """Leave or kick a user from the group""" PATH = "/groups/(?P[^/]*)/users/(?P[^/]*)/remove" @@ -1049,7 +1084,19 @@ async def on_POST(self, origin, content, query, group_id, user_id): return 200, new_content -class FederationGroupsLocalInviteServlet(BaseFederationServlet): +class BaseGroupsLocalServlet(BaseFederationServlet): + def __init__( + self, + hs: HomeServer, + authenticator: Authenticator, + ratelimiter: FederationRateLimiter, + server_name: str, + ): + super().__init__(hs, authenticator, ratelimiter, server_name) + self.handler = hs.get_groups_local_handler() + + +class FederationGroupsLocalInviteServlet(BaseGroupsLocalServlet): """A group server has invited a local user""" PATH = "/groups/local/(?P[^/]*)/users/(?P[^/]*)/invite" @@ -1063,7 +1110,7 @@ async def on_POST(self, origin, content, query, group_id, user_id): return 200, new_content -class FederationGroupsRemoveLocalUserServlet(BaseFederationServlet): +class FederationGroupsRemoveLocalUserServlet(BaseGroupsLocalServlet): """A group server has removed a local user""" PATH = "/groups/local/(?P[^/]*)/users/(?P[^/]*)/remove" @@ -1084,6 +1131,16 @@ class FederationGroupsRenewAttestaionServlet(BaseFederationServlet): PATH = "/groups/(?P[^/]*)/renew_attestation/(?P[^/]*)" + def __init__( + self, + hs: HomeServer, + authenticator: Authenticator, + ratelimiter: FederationRateLimiter, + server_name: str, + ): + super().__init__(hs, authenticator, ratelimiter, server_name) + self.handler = hs.get_groups_attestation_renewer() + async def on_POST(self, origin, content, query, group_id, user_id): # We don't need to check auth here as we check the attestation signatures @@ -1094,7 +1151,7 @@ async def on_POST(self, origin, content, query, group_id, user_id): return 200, new_content -class FederationGroupsSummaryRoomsServlet(BaseFederationServlet): +class FederationGroupsSummaryRoomsServlet(BaseGroupsServerServlet): """Add/remove a room from the group summary, with optional category. Matches both: @@ -1151,7 +1208,7 @@ async def on_DELETE(self, origin, content, query, group_id, category_id, room_id return 200, resp -class FederationGroupsCategoriesServlet(BaseFederationServlet): +class FederationGroupsCategoriesServlet(BaseGroupsServerServlet): """Get all categories for a group""" PATH = "/groups/(?P[^/]*)/categories/?" @@ -1166,7 +1223,7 @@ async def on_GET(self, origin, content, query, group_id): return 200, resp -class FederationGroupsCategoryServlet(BaseFederationServlet): +class FederationGroupsCategoryServlet(BaseGroupsServerServlet): """Add/remove/get a category in a group""" PATH = "/groups/(?P[^/]*)/categories/(?P[^/]+)" @@ -1219,7 +1276,7 @@ async def on_DELETE(self, origin, content, query, group_id, category_id): return 200, resp -class FederationGroupsRolesServlet(BaseFederationServlet): +class FederationGroupsRolesServlet(BaseGroupsServerServlet): """Get roles in a group""" PATH = "/groups/(?P[^/]*)/roles/?" @@ -1234,7 +1291,7 @@ async def on_GET(self, origin, content, query, group_id): return 200, resp -class FederationGroupsRoleServlet(BaseFederationServlet): +class FederationGroupsRoleServlet(BaseGroupsServerServlet): """Add/remove/get a role in a group""" PATH = "/groups/(?P[^/]*)/roles/(?P[^/]+)" @@ -1287,7 +1344,7 @@ async def on_DELETE(self, origin, content, query, group_id, role_id): return 200, resp -class FederationGroupsSummaryUsersServlet(BaseFederationServlet): +class FederationGroupsSummaryUsersServlet(BaseGroupsServerServlet): """Add/remove a user from the group summary, with optional role. Matches both: @@ -1342,7 +1399,7 @@ async def on_DELETE(self, origin, content, query, group_id, role_id, user_id): return 200, resp -class FederationGroupsBulkPublicisedServlet(BaseFederationServlet): +class FederationGroupsBulkPublicisedServlet(BaseGroupsLocalServlet): """Get roles in a group""" PATH = "/get_groups_publicised" @@ -1355,7 +1412,7 @@ async def on_POST(self, origin, content, query): return 200, resp -class FederationGroupsSettingJoinPolicyServlet(BaseFederationServlet): +class FederationGroupsSettingJoinPolicyServlet(BaseGroupsServerServlet): """Sets whether a group is joinable without an invite or knock""" PATH = "/groups/(?P[^/]*)/settings/m.join_policy" @@ -1376,6 +1433,16 @@ class FederationSpaceSummaryServlet(BaseFederationServlet): PREFIX = FEDERATION_UNSTABLE_PREFIX + "/org.matrix.msc2946" PATH = "/spaces/(?P[^/]*)" + def __init__( + self, + hs: HomeServer, + authenticator: Authenticator, + ratelimiter: FederationRateLimiter, + server_name: str, + ): + super().__init__(hs, authenticator, ratelimiter, server_name) + self.handler = hs.get_space_summary_handler() + async def on_GET( self, origin: str, @@ -1441,16 +1508,25 @@ class RoomComplexityServlet(BaseFederationServlet): PATH = "/rooms/(?P[^/]*)/complexity" PREFIX = FEDERATION_UNSTABLE_PREFIX - async def on_GET(self, origin, content, query, room_id): - - store = self.handler.hs.get_datastore() + def __init__( + self, + hs: HomeServer, + authenticator: Authenticator, + ratelimiter: FederationRateLimiter, + server_name: str, + ): + super().__init__(hs, authenticator, ratelimiter, server_name) + self._store = self.hs.get_datastore() - is_public = await store.is_room_world_readable_or_publicly_joinable(room_id) + async def on_GET(self, origin, content, query, room_id): + is_public = await self._store.is_room_world_readable_or_publicly_joinable( + room_id + ) if not is_public: raise SynapseError(404, "Room not found", errcode=Codes.INVALID_PARAM) - complexity = await store.get_room_complexity(room_id) + complexity = await self._store.get_room_complexity(room_id) return 200, complexity @@ -1479,6 +1555,7 @@ async def on_GET(self, origin, content, query, room_id): On3pidBindServlet, FederationVersionServlet, RoomComplexityServlet, + FederationSpaceSummaryServlet, ) # type: Tuple[Type[BaseFederationServlet], ...] OPENID_SERVLET_CLASSES = ( @@ -1556,23 +1633,16 @@ def register_servlets( if "federation" in servlet_groups: for servletclass in FEDERATION_SERVLET_CLASSES: servletclass( - handler=hs.get_federation_server(), + hs=hs, authenticator=authenticator, ratelimiter=ratelimiter, server_name=hs.hostname, ).register(resource) - FederationSpaceSummaryServlet( - handler=hs.get_space_summary_handler(), - authenticator=authenticator, - ratelimiter=ratelimiter, - server_name=hs.hostname, - ).register(resource) - if "openid" in servlet_groups: for servletclass in OPENID_SERVLET_CLASSES: servletclass( - handler=hs.get_federation_server(), + hs=hs, authenticator=authenticator, ratelimiter=ratelimiter, server_name=hs.hostname, @@ -1581,7 +1651,7 @@ def register_servlets( if "room_list" in servlet_groups: for servletclass in ROOM_LIST_CLASSES: servletclass( - handler=hs.get_room_list_handler(), + hs=hs, authenticator=authenticator, ratelimiter=ratelimiter, server_name=hs.hostname, @@ -1591,7 +1661,7 @@ def register_servlets( if "group_server" in servlet_groups: for servletclass in GROUP_SERVER_SERVLET_CLASSES: servletclass( - handler=hs.get_groups_server_handler(), + hs=hs, authenticator=authenticator, ratelimiter=ratelimiter, server_name=hs.hostname, @@ -1600,7 +1670,7 @@ def register_servlets( if "group_local" in servlet_groups: for servletclass in GROUP_LOCAL_SERVLET_CLASSES: servletclass( - handler=hs.get_groups_local_handler(), + hs=hs, authenticator=authenticator, ratelimiter=ratelimiter, server_name=hs.hostname, @@ -1609,7 +1679,7 @@ def register_servlets( if "group_attestation" in servlet_groups: for servletclass in GROUP_ATTESTATION_SERVLET_CLASSES: servletclass( - handler=hs.get_groups_attestation_renewer(), + hs=hs, authenticator=authenticator, ratelimiter=ratelimiter, server_name=hs.hostname, From c13a26e48828c7a49052aa16e2150c55174d4ae3 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Wed, 26 May 2021 15:10:24 -0400 Subject: [PATCH 2/8] Fix a missing argument to spaces summary. --- synapse/federation/transport/server.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/synapse/federation/transport/server.py b/synapse/federation/transport/server.py index f58f82752c52..23cb5d65d38a 100644 --- a/synapse/federation/transport/server.py +++ b/synapse/federation/transport/server.py @@ -1465,7 +1465,7 @@ async def on_GET( ) return 200, await self.handler.federation_space_summary( - room_id, suggested_only, max_rooms_per_space, exclude_rooms + origin, room_id, suggested_only, max_rooms_per_space, exclude_rooms ) # TODO When switching to the stable endpoint, remove the POST handler. From 468aaeda888627d490cdd3523d828a88c56b31eb Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Wed, 26 May 2021 15:13:47 -0400 Subject: [PATCH 3/8] network tuple can be none. --- synapse/handlers/room_list.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/synapse/handlers/room_list.py b/synapse/handlers/room_list.py index 141c9c044400..0a26088d3215 100644 --- a/synapse/handlers/room_list.py +++ b/synapse/handlers/room_list.py @@ -44,7 +44,7 @@ def __init__(self, hs: "HomeServer"): self.enable_room_list_search = hs.config.enable_room_list_search self.response_cache = ResponseCache( hs.get_clock(), "room_list" - ) # type: ResponseCache[Tuple[Optional[int], Optional[str], ThirdPartyInstanceID]] + ) # type: ResponseCache[Tuple[Optional[int], Optional[str], Optional[ThirdPartyInstanceID]]] self.remote_response_cache = ResponseCache( hs.get_clock(), "remote_room_list", timeout_ms=30 * 1000 ) # type: ResponseCache[Tuple[str, Optional[int], Optional[str], bool, Optional[str]]] @@ -54,7 +54,7 @@ async def get_local_public_room_list( limit: Optional[int] = None, since_token: Optional[str] = None, search_filter: Optional[dict] = None, - network_tuple: ThirdPartyInstanceID = EMPTY_THIRD_PARTY_ID, + network_tuple: Optional[ThirdPartyInstanceID] = EMPTY_THIRD_PARTY_ID, from_federation: bool = False, ) -> JsonDict: """Generate a local public room list. @@ -111,7 +111,7 @@ async def _get_public_room_list( limit: Optional[int] = None, since_token: Optional[str] = None, search_filter: Optional[dict] = None, - network_tuple: ThirdPartyInstanceID = EMPTY_THIRD_PARTY_ID, + network_tuple: Optional[ThirdPartyInstanceID] = EMPTY_THIRD_PARTY_ID, from_federation: bool = False, ) -> JsonDict: """Generate a public room list. From b07d1a8e33dc0b5f65446c7e6cacd2ce174ffbc4 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Wed, 26 May 2021 15:19:26 -0400 Subject: [PATCH 4/8] Handle workers for groups. --- synapse/federation/transport/server.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/synapse/federation/transport/server.py b/synapse/federation/transport/server.py index 23cb5d65d38a..9460ee720e40 100644 --- a/synapse/federation/transport/server.py +++ b/synapse/federation/transport/server.py @@ -28,6 +28,7 @@ FEDERATION_V1_PREFIX, FEDERATION_V2_PREFIX, ) +from synapse.handlers.groups_local import GroupsLocalHandler from synapse.http.server import HttpServer, JsonResource from synapse.http.servlet import ( parse_boolean_from_args, @@ -1105,6 +1106,10 @@ async def on_POST(self, origin, content, query, group_id, user_id): if get_domain_from_id(group_id) != origin: raise SynapseError(403, "group_id doesn't match origin") + assert isinstance( + self.handler, GroupsLocalHandler + ), "Workers cannot handle group invites." + new_content = await self.handler.on_invite(group_id, user_id, content) return 200, new_content @@ -1119,6 +1124,10 @@ async def on_POST(self, origin, content, query, group_id, user_id): if get_domain_from_id(group_id) != origin: raise SynapseError(403, "user_id doesn't match origin") + assert isinstance( + self.handler, GroupsLocalHandler + ), "Workers cannot handle group removals." + new_content = await self.handler.user_removed_from_group( group_id, user_id, content ) From 36c4c37924223af0b1637b83808cb547f99e6f65 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Thu, 27 May 2021 08:16:53 -0400 Subject: [PATCH 5/8] Newsfragment --- changelog.d/10080.misc | 1 + 1 file changed, 1 insertion(+) create mode 100644 changelog.d/10080.misc diff --git a/changelog.d/10080.misc b/changelog.d/10080.misc new file mode 100644 index 000000000000..9adb0fbd02d3 --- /dev/null +++ b/changelog.d/10080.misc @@ -0,0 +1 @@ +Add type hints to the federation servlets. From 8214e0a87c0e334d6b59506d5dab48ab293c48fe Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Tue, 1 Jun 2021 08:02:23 -0400 Subject: [PATCH 6/8] Add docstrings to base classes. --- synapse/federation/transport/server.py | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/synapse/federation/transport/server.py b/synapse/federation/transport/server.py index 9460ee720e40..0cfd9fff45fa 100644 --- a/synapse/federation/transport/server.py +++ b/synapse/federation/transport/server.py @@ -381,6 +381,10 @@ def register(self, server): class BaseFederationServerServlet(BaseFederationServlet): + """Abstract base class for federation servlet classes which provides a federation server handler. + + See BaseFederationServlet for more information. + """ def __init__( self, hs: HomeServer, @@ -878,6 +882,10 @@ async def on_GET(self, origin, content, query): class BaseGroupsServerServlet(BaseFederationServlet): + """Abstract base class for federation servlet classes which provides a groups server handler. + + See BaseFederationServlet for more information. + """ def __init__( self, hs: HomeServer, @@ -1086,6 +1094,10 @@ async def on_POST(self, origin, content, query, group_id, user_id): class BaseGroupsLocalServlet(BaseFederationServlet): + """Abstract base class for federation servlet classes which provides a groups local handler. + + See BaseFederationServlet for more information. + """ def __init__( self, hs: HomeServer, From acc6e45dc008ef4a5f625915e9c27f106200e789 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Tue, 1 Jun 2021 08:10:32 -0400 Subject: [PATCH 7/8] Lint --- synapse/federation/transport/server.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/synapse/federation/transport/server.py b/synapse/federation/transport/server.py index 0cfd9fff45fa..602ea82e17f7 100644 --- a/synapse/federation/transport/server.py +++ b/synapse/federation/transport/server.py @@ -385,6 +385,7 @@ class BaseFederationServerServlet(BaseFederationServlet): See BaseFederationServlet for more information. """ + def __init__( self, hs: HomeServer, @@ -886,6 +887,7 @@ class BaseGroupsServerServlet(BaseFederationServlet): See BaseFederationServlet for more information. """ + def __init__( self, hs: HomeServer, @@ -1098,6 +1100,7 @@ class BaseGroupsLocalServlet(BaseFederationServlet): See BaseFederationServlet for more information. """ + def __init__( self, hs: HomeServer, From 6deea89dc07ee8601efab7994b4c19c73d3fa95a Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Tue, 8 Jun 2021 09:11:28 -0400 Subject: [PATCH 8/8] Fix a few more type hints. --- synapse/federation/federation_server.py | 6 +++--- synapse/http/servlet.py | 24 ++++++++++++++++++++++++ 2 files changed, 27 insertions(+), 3 deletions(-) diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py index ace30aa45078..86562cd04f28 100644 --- a/synapse/federation/federation_server.py +++ b/synapse/federation/federation_server.py @@ -129,7 +129,7 @@ def __init__(self, hs: "HomeServer"): # come in waves. self._state_resp_cache = ResponseCache( hs.get_clock(), "state_resp", timeout_ms=30000 - ) # type: ResponseCache[Tuple[str, str]] + ) # type: ResponseCache[Tuple[str, Optional[str]]] self._state_ids_resp_cache = ResponseCache( hs.get_clock(), "state_ids_resp", timeout_ms=30000 ) # type: ResponseCache[Tuple[str, str]] @@ -406,7 +406,7 @@ async def _process_edu(edu_dict): ) async def on_room_state_request( - self, origin: str, room_id: str, event_id: str + self, origin: str, room_id: str, event_id: Optional[str] ) -> Tuple[int, Dict[str, Any]]: origin_host, _ = parse_server_name(origin) await self.check_server_matches_acl(origin_host, room_id) @@ -463,7 +463,7 @@ async def _on_state_ids_request_compute(self, room_id, event_id): return {"pdu_ids": state_ids, "auth_chain_ids": auth_chain_ids} async def _on_context_state_request_compute( - self, room_id: str, event_id: str + self, room_id: str, event_id: Optional[str] ) -> Dict[str, list]: if event_id: pdus = await self.handler.get_state_for_pdu( diff --git a/synapse/http/servlet.py b/synapse/http/servlet.py index d61563d39b6c..72e2ec78db41 100644 --- a/synapse/http/servlet.py +++ b/synapse/http/servlet.py @@ -295,6 +295,30 @@ def parse_strings_from_args( return default +@overload +def parse_string_from_args( + args: Dict[bytes, List[bytes]], + name: str, + default: Optional[str] = None, + required: Literal[True] = True, + allowed_values: Optional[Iterable[str]] = None, + encoding: str = "ascii", +) -> str: + ... + + +@overload +def parse_string_from_args( + args: Dict[bytes, List[bytes]], + name: str, + default: Optional[str] = None, + required: bool = False, + allowed_values: Optional[Iterable[str]] = None, + encoding: str = "ascii", +) -> Optional[str]: + ... + + def parse_string_from_args( args: Dict[bytes, List[bytes]], name: str,